Compare commits

...

66 commits
0.4.11 ... main

Author SHA1 Message Date
Musa
473ec70b5c
Bump version to 0.4.21 (#911)
Some checks are pending
CI / pre-commit (push) Waiting to run
CI / plano-tools-tests (push) Waiting to run
CI / native-smoke-test (push) Waiting to run
CI / docker-build (push) Waiting to run
CI / validate-config (push) Waiting to run
CI / security-scan (push) Blocked by required conditions
CI / test-prompt-gateway (push) Blocked by required conditions
CI / test-model-alias-routing (push) Blocked by required conditions
CI / test-responses-api-with-state (push) Blocked by required conditions
CI / e2e-plano-tests (3.10) (push) Blocked by required conditions
CI / e2e-plano-tests (3.11) (push) Blocked by required conditions
CI / e2e-plano-tests (3.12) (push) Blocked by required conditions
CI / e2e-plano-tests (3.13) (push) Blocked by required conditions
CI / e2e-plano-tests (3.14) (push) Blocked by required conditions
CI / e2e-demo-preference (push) Blocked by required conditions
CI / e2e-demo-currency (push) Blocked by required conditions
Publish docker image (latest) / build-arm64 (push) Waiting to run
Publish docker image (latest) / build-amd64 (push) Waiting to run
Publish docker image (latest) / create-manifest (push) Blocked by required conditions
Build and Deploy Documentation / build (push) Waiting to run
2026-04-24 14:27:32 -07:00
Syed A. Hashmi
dafd245332
signals: restore the pre-port flag marker emoji (🚩) (#913)
* signals: restore the pre-port flag marker emoji

#903 inadvertently replaced the legacy FLAG_MARKER (U+1F6A9, '🚩') with
'[!]', which broke any downstream dashboard / alert that searches span
names for the flag emoji. Restores the original marker and updates the
#910 docs pass to match.

- crates/brightstaff/src/signals/analyzer.rs: FLAG_MARKER back to
  "\\u{1F6A9}" with a comment noting the backwards-compatibility
  reason so it doesn't drift again.
- docs/source/concepts/signals.rst and docs/source/guides/observability/
  tracing.rst: swap every '[!]' reference (subheading text, example
  span name, tip box, dashboard query hint) back to 🚩.
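The restored constant can be sketched as follows (a hedged reconstruction from the commit message; the real analyzer.rs may differ):

```rust
// Hedged sketch of the restored marker constant described above.
// Backwards compatibility: downstream dashboards and alerts search span
// names for this emoji, so it must stay U+1F6A9 and not drift to "[!]".
pub const FLAG_MARKER: &str = "\u{1F6A9}";

fn main() {
    assert_eq!(FLAG_MARKER, "🚩");
}
```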

Verified: cargo test -p brightstaff --lib (162 passed, 1 ignored);
sphinx-build clean on both files; rendered HTML shows 🚩 in all
flag-marker references.

Made-with: Cursor

* fix: silence manual_checked_ops clippy lint (rustc 1.95)

Pre-existing warning in router/stress_tests.rs that becomes an error
under CI's -D warnings with rustc 1.95. Replace the manual if/else
with growth.checked_div(num_iterations).unwrap_or(0) as clippy
suggests.
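The before/after of that lint fix looks roughly like this (`growth` and `num_iterations` are names from the commit message; values here are illustrative):

```rust
// Sketch of the clippy-suggested rewrite. The manual zero check
//   if num_iterations == 0 { 0 } else { growth / num_iterations }
// trips the manual_checked_ops lint on rustc 1.95; checked_div
// returns None on division by zero, so unwrap_or(0) is equivalent.
fn per_iteration(growth: u64, num_iterations: u64) -> u64 {
    growth.checked_div(num_iterations).unwrap_or(0)
}

fn main() {
    assert_eq!(per_iteration(100, 4), 25);
    assert_eq!(per_iteration(100, 0), 0); // no divide-by-zero panic
}
```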

Made-with: Cursor
2026-04-24 13:54:53 -07:00
Musa
897fda2deb
fix(routing): auto-migrate v0.3.0 inline routing_preferences to v0.4.0 top-level (#912)
* fix(routing): auto-migrate v0.3.0 inline routing_preferences to v0.4.0 top-level

Lift inline routing_preferences under each model_provider into the
top-level routing_preferences list with merged models[] and bump
version to v0.4.0, with a deprecation warning. Existing v0.3.0
demo configs (Claude Code, Codex, preference_based_routing, etc.)
keep working unchanged. Schema flags the inline shape as deprecated
but still accepts it. Docs and skills updated to canonical top-level
multi-model form.
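The migration roughly rewrites configs like this (a hedged sketch; only `routing_preferences`, `models`, and the version bump come from the commit message — the surrounding field names are assumptions):

```yaml
# Before (v0.3.0): inline routing_preferences under each provider (assumed shape).
version: v0.3.0
model_providers:
  - model: openai/gpt-4o
    routing_preferences:
      - name: code_generation
        description: generating new code
---
# After (v0.4.0): preferences lifted to one top-level list with merged models[].
version: v0.4.0
model_providers:
  - model: openai/gpt-4o
routing_preferences:
  - name: code_generation
    description: generating new code
    models:
      - openai/gpt-4o
```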

* test(common): bump reference config assertion to v0.4.0

The rendered reference config was bumped to v0.4.0 when its inline
routing_preferences were lifted to the top level; align the
configuration deserialization test with that change.

* fix(config_generator): bump version to v0.4.0 up front in migration

Move the v0.3.0 -> v0.4.0 version bump to the top of
migrate_inline_routing_preferences so it runs unconditionally,
including for configs that already declare top-level
routing_preferences at v0.3.0. Previously the bump only fired
when inline migration produced entries, leaving top-level v0.3.0
configs rejected by brightstaff's v0.4.0 gate. Tests updated to
cover the new behavior and to confirm we never downgrade newer
versions.

* fix(config_generator): gate routing_preferences migration on version < v0.4.0

Short-circuit the migration when the config already declares v0.4.0
or newer. Anything at v0.4.0+ is assumed to be on the canonical
top-level shape and is passed through untouched, including stray
inline preferences (which are the author's bug to fix). Only v0.3.0
and older configs are rewritten and bumped.
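The gate described above amounts to a simple version comparison (a minimal sketch; the real config_generator types and function names may differ):

```rust
// Hypothetical sketch of the migration gate. Versions compare
// lexicographically by numeric component, e.g. v0.3.0 < v0.4.0.
fn version_lt(a: &str, b: &str) -> bool {
    let parse = |v: &str| -> Vec<u32> {
        v.trim_start_matches('v')
            .split('.')
            .filter_map(|p| p.parse().ok())
            .collect()
    };
    parse(a) < parse(b)
}

fn should_migrate(config_version: &str) -> bool {
    // v0.4.0 and newer are assumed canonical and passed through untouched.
    version_lt(config_version, "v0.4.0")
}

fn main() {
    assert!(should_migrate("v0.3.0"));
    assert!(!should_migrate("v0.4.0"));
    assert!(!should_migrate("v0.4.1"));
}
```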
2026-04-24 12:31:44 -07:00
Syed A. Hashmi
5a652eb666
docs: align signals page with paper taxonomy (#910)
* docs: align signals page with paper taxonomy

Updates docs/source/concepts/signals.rst and the tracing guide's signals
subsection to reflect the three-layer taxonomy shipped in #903:

- Introduces the paper reference (arXiv:2604.00356) and the three layers
  (interaction, execution, environment) with all 20 leaf signal types in
  three reference tables
- Documents the new layered OTel attribute set
  (signals.interaction.*, signals.execution.*, signals.environment.*)
  and marks the legacy aggregate keys (signals.follow_up.repair.*,
  signals.frustration.*, signals.repetition.count,
  signals.escalation.requested, signals.positive_feedback.count) as
  deprecated-but-still-emitted
- Adds a Span Events section describing the per-instance signal.<type>
  events with confidence / snippet / metadata attributes
- Fixes the flag marker reference ([!] in the code vs 🚩 in the old docs)
- Updates all example attributes, dashboard queries, and alert rules to
  use the layered keys
- Updates the tracing guide's behavioral-signals subsection to match
- Notes that the triage sampler is a planned follow-up and today sampling
  is consumer-side via observability-platform filters

Build verified locally: sphinx-build produces no warnings on these files.

Made-with: Cursor

* docs: reframe signals intro around the improvement flywheel

Addresses review feedback on #910:

- Replace the triage-only framing at the top with an instrument -> sample
  & triage -> construct data -> optimize -> deploy flywheel that explains
  why signals matter, not just what they surface. Paper's 82% / 1.52x
  numbers move into step 2 of the flywheel where they belong.
- Remove the 'Signals vs Response Quality' section. Per review, signals
  and response quality overlap rather than complement each other, so the
  comparison is misleading.
- Borrow the per-category summaries and leaf-type descriptions verbatim
  from the katanemo/signals reference implementation (module docstrings)
  so the documentation and the detector contract stay in sync. Drops the
  hand-crafted examples that were not strictly accurate (e.g. 'semantic
  overlap is high' for rephrase, 'user explicitly corrects the agent'
  for correction).

Made-with: Cursor

* docs: address signals flywheel review feedback

Addresses review comments on #910:

- Shorten the paper citation to (Chen et al., 2026) per common citation
  practice (replacing the full-author-list form).
- Replace the Why Signals Matter section with the review-suggested
  rewrite verbatim: more formal intro framing, renumbered steps to
  Instrument / Sample & triage / Data Construction / Model Optimization
  / Deploy, removes 'routing decisions' from the data-construction
  step, and adds DPO/RLHF/SFT as model-optimization examples.
- Renders tau and O(messages) as proper math glyphs via the sphinx
  built-in :math: role (enabled by adding sphinx.ext.mathjax to
  conf.py). Using the RST role form rather than raw $...$ inline so
  sphinx only injects MathJax on pages that actually have math,
  instead of loading ~1MB of JS on every page.
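The conf.py side of that change is a one-liner (a sketch assuming a standard Sphinx conf.py; the project's real extensions list will have more entries):

```python
# docs/source/conf.py (sketch) -- sphinx.ext.mathjax ships with Sphinx,
# and MathJax JS is only injected on pages that actually contain math,
# rather than on every page.
extensions = [
    "sphinx.ext.mathjax",
]
```

In the RST source, tau is then written with the role form, e.g. ``:math:`\tau```, rather than raw ``$...$`` inline math.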

Build verified locally: sphinx-build produces no warnings on the
changed files and the rendered HTML wraps tau and O(messages) in
MathJax-ready <span class="math">\(\tau\)</span> containers.

Made-with: Cursor
2026-04-24 12:31:14 -07:00
Musa
b81eb7266c
feat(providers): add Vercel AI Gateway and OpenRouter support (#902)

* add Vercel and OpenRouter as OpenAI-compatible LLM providers

* fix(fmt): fix cargo fmt line length issues in provider id tests

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* style(hermesllm): fix rustfmt formatting in provider id tests

* Add Vercel and OpenRouter to zero-config planoai up defaults

Wires `vercel/*` and `openrouter/*` into the synthesized default config so
`planoai up` with no user config exposes both providers out of the box
(env-keyed via AI_GATEWAY_API_KEY / OPENROUTER_API_KEY, pass-through
otherwise). Registers both in SUPPORTED_PROVIDERS_WITHOUT_BASE_URL so
wildcard model entries validate without an explicit provider_interface.
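Hedged sketch of what the synthesized default entries might look like (the provider prefixes and env-var names come from the commit message; the exact field names are assumptions):

```yaml
model_providers:
  - model: vercel/*
    access_key: $AI_GATEWAY_API_KEY   # pass-through when unset (assumed)
  - model: openrouter/*
    access_key: $OPENROUTER_API_KEY   # pass-through when unset (assumed)
```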

---------

Co-authored-by: Musa Malik <musam@uw.edu>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-23 15:54:39 -07:00
Musa
78dc4edad9
Add first-class ChatGPT subscription provider support (#881)
* Add first-class ChatGPT subscription provider support

* Address PR feedback: move uuid import to top, reuse parsed config in up()

* Add ChatGPT token watchdog for seamless long-lived sessions

* Address PR feedback: error on stream=false for ChatGPT, fix auth file permissions

* Replace ChatGPT watchdog/restart with passthrough_auth

---------

Co-authored-by: Musa Malik <musam@uw.edu>
2026-04-23 15:34:44 -07:00
Adil Hafeez
aa726b1bba
add jemalloc and /debug/memstats endpoint for OOM diagnosis (#885)
2026-04-23 13:59:12 -07:00
Syed A. Hashmi
c8079ac971
signals: feature parity with the latest Signals paper. Porting logic from python repo (#903)
* signals: port to layered taxonomy with dual-emit OTel

Made-with: Cursor

* fix: silence collapsible_match clippy lint (rustc 1.95)

Made-with: Cursor

* test: parity harness for rust vs python signals analyzer

Validates the brightstaff signals port against the katanemo/signals Python
reference on lmsys/lmsys-chat-1m. Adds a signals_replay bin emitting python-
compatible JSON, a pyarrow-based driver (bypasses the datasets loader pickle
bug on python 3.14), a 3-tier comparator, and an on-demand workflow_dispatch
CI job.

Made-with: Cursor

* Remove signals test from the gitops flow

* style: format parity harness with black

Made-with: Cursor

* signals: group summary by taxonomy, factor misalignment_ratio

Addresses #903 review feedback from @nehcgs:

- generate_summary() now renders explicit Interaction / Execution /
  Environment headers so the paper taxonomy is visible at a glance,
  even when no signals fired in a given layer. Quality-driving callouts
  (high misalignment rate, looping detected, escalation requested) are
  appended after the layer summary as an alerts tail.

- repair_ratio (legacy taxonomy name) renamed to misalignment_ratio
  and factored into a single InteractionSignals::misalignment_ratio()
  helper so assess_quality and generate_summary share one source of
  truth instead of recomputing the same divide twice.
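The factored helper can be sketched like this (field names are assumptions based on the commit message; only `InteractionSignals::misalignment_ratio()` is named in it):

```rust
// Hedged sketch of the single-source-of-truth ratio helper shared by
// assess_quality and generate_summary.
struct InteractionSignals {
    misaligned_turns: usize,
    total_turns: usize,
}

impl InteractionSignals {
    /// Fraction of turns flagged as misaligned; 0.0 when there are no turns,
    /// so callers never divide by zero.
    fn misalignment_ratio(&self) -> f64 {
        if self.total_turns == 0 {
            0.0
        } else {
            self.misaligned_turns as f64 / self.total_turns as f64
        }
    }
}

fn main() {
    let s = InteractionSignals { misaligned_turns: 2, total_turns: 8 };
    assert!((s.misalignment_ratio() - 0.25).abs() < 1e-12);
    let empty = InteractionSignals { misaligned_turns: 0, total_turns: 0 };
    assert_eq!(empty.misalignment_ratio(), 0.0);
}
```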

Two new unit tests pin the layer headers and the (sev N) severity
suffix. Parity with the python reference is preserved at the Tier-A
level (per-type counts + overall_quality); only the human-readable
summary string diverges, which the parity comparator already classifies
as Tier-C.

Made-with: Cursor
2026-04-23 12:02:30 -07:00
Adil Hafeez
6701195a5d
add overrides.disable_signals to skip CPU-heavy signal analysis (#906) 2026-04-23 11:38:29 -07:00
Adil Hafeez
22f332f62d
Add Prometheus metrics endpoint and Grafana dashboard for brightstaff (#904)
2026-04-22 11:19:10 -07:00
Adil Hafeez
9812540602
Improve obs model name matching, latency metrics, and error reporting (#900)
2026-04-18 21:21:15 -07:00
Adil Hafeez
c3c213b2fd
Fix request closures during long-running streaming (#899) 2026-04-18 21:20:34 -07:00
Adil Hafeez
78d8c90184
Add claude-opus-4-7 to anthropic provider models (#901)
2026-04-18 19:10:57 -07:00
Adil Hafeez
ffea891dba
fix: prevent index-out-of-bounds panic in signal analyzer follow-up (#896) 2026-04-18 16:24:02 -07:00
Adil Hafeez
e7464b817a
fix(anthropic-stream): avoid bare/duplicate message_stop on OpenAI upstream (#898) 2026-04-18 15:57:34 -07:00
Adil Hafeez
254d2b03bc
release: bump version to 0.4.20 (#897)
2026-04-17 21:16:12 -07:00
Adil Hafeez
95a7beaab3
fix: truncate oversized user messages in orchestrator routing prompt (#895) 2026-04-17 21:01:30 -07:00
Adil Hafeez
37600fd07a
fix: passthrough_auth accepts Anthropic x-api-key and normalizes to upstream format (#892)
2026-04-17 17:23:05 -07:00
Adil Hafeez
0f67b2c806
planoai obs: live LLM observability TUI (#891) 2026-04-17 14:03:47 -07:00
Adil Hafeez
1f701258cb
Zero-config planoai up: pass-through proxy with auto-detected providers (#890) 2026-04-17 13:11:12 -07:00
Adil Hafeez
711e4dd07d
Add DigitalOcean as a first-class LLM provider (#889) 2026-04-17 12:25:34 -07:00
Musa
743d074184
add Plano agent skills framework and rule set (#797)
* feat: add initial documentation for Plano Agent Skills

* feat: readme with examples

* feat: add detailed skills documentation and examples for Plano

---------

Co-authored-by: Adil Hafeez <adil.hafeez@gmail.com>
2026-04-16 13:16:51 -07:00
Adil Hafeez
d39d7ddd1c
release 0.4.19 (#887)
2026-04-15 16:49:50 -07:00
Adil Hafeez
90b926c2ce
use plano-orchestrator for LLM routing, remove arch-router (#886) 2026-04-15 16:41:42 -07:00
Musa
980faef6be
Redis-backed session cache for cross-replica model affinity (#879)
* add pluggable session cache with Redis backend

* add Redis session affinity demos (Docker Compose and Kubernetes)

* address PR review feedback on session cache

* document Redis session cache backend for model affinity

* sync rendered config reference with session_cache addition

* add tenant-scoped Redis session cache keys and remove dead log_affinity_hit

- Add tenant_header to SessionCacheConfig; when set, cache keys are scoped
  as plano:affinity:{tenant_id}:{session_id} for multi-tenant isolation
- Thread tenant_id through RouterService, routing_service, and llm handlers
- Use Cow<'_, str> in session_key to avoid allocation when no tenant is set
- Remove unused log_affinity_hit (logging was already inlined at call sites)

* remove session_affinity_redis and session_affinity_redis_k8s demos
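The tenant-scoped key and the `Cow` allocation trick can be sketched as follows (only the `plano:affinity:{tenant_id}:{session_id}` key shape and the use of `Cow<'_, str>` come from the commit message; function names here are hypothetical):

```rust
use std::borrow::Cow;

// Hedged sketch: scope the session id by tenant when one is configured,
// borrowing the plain session id (no allocation) when no tenant is set.
fn scoped_session_id<'a>(session_id: &'a str, tenant_id: Option<&str>) -> Cow<'a, str> {
    match tenant_id {
        Some(t) => Cow::Owned(format!("{t}:{session_id}")),
        None => Cow::Borrowed(session_id),
    }
}

// Full cache key, e.g. plano:affinity:acme:s1 for tenant "acme".
fn cache_key(session_id: &str, tenant_id: Option<&str>) -> String {
    format!("plano:affinity:{}", scoped_session_id(session_id, tenant_id))
}

fn main() {
    assert_eq!(cache_key("s1", Some("acme")), "plano:affinity:acme:s1");
    assert_eq!(cache_key("s1", None), "plano:affinity:s1");
}
```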
2026-04-13 19:30:47 -07:00
Musa
128059e7c1
release 0.4.18 (#878)
2026-04-09 13:12:45 -07:00
Adil Hafeez
8dedf0bec1
Model affinity for consistent model selection in agentic loops (#827)
2026-04-08 17:32:02 -07:00
Musa
978b1ea722
Add first-class Xiaomi provider support (#863)
* feat(provider): add xiaomi as first-class provider

* feat(demos): add xiaomi mimo integration demo

* refactor(demos): remove Xiaomi MiMo integration demo and update documentation

* updating model list and adding the Xiaomi models

---------

Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-389.local>
2026-04-04 09:58:36 -07:00
Adil Hafeez
9406af3a09
release 0.4.17 (#869)
2026-04-03 10:05:33 -07:00
Musa
aa16a6dc4b
ci(e2e): stabilize preference demo test execution (#865) 2026-04-02 21:32:20 -04:00
Adil Hafeez
7606c55b4b
support developer role in chat completions API (#867) 2026-04-02 18:10:32 -07:00
Adil Hafeez
1d3f4d6c05
Publish docker images to DigitalOcean Container Registry (#868) 2026-04-02 18:08:49 -07:00
Adil Hafeez
5d79e7a7d4
fix: resolve all open Dependabot security alerts (#866) 2026-04-02 18:00:28 -07:00
Musa
76ff353c1e
fix(web): refresh blog content and featured post selection (#862) 2026-04-02 06:18:19 -07:00
Musa
39b430d74b
feat(web): merge DigitalOcean release announcement updates (#860)
* feat(web): announce DigitalOcean acquisition across sites

* fix(web): make blog routes resilient without Sanity config

* fix(web): add mobile arrow cue to announcement banner

* fix(web): point acquisition links to announcement post
2026-04-02 06:03:52 -07:00
Musa
0857cfafbf
release 0.4.16 (#859) 2026-03-31 17:45:28 -07:00
Musa
f68c21f8df
Handle null prefer in inline routing policy (#856)
* Handle null prefer in inline routing policy

* Use serde defaulting for null selection preference

* Add tests for default selection policy behavior in routing preferences
2026-03-31 17:41:25 -07:00
Musa
3dbda9741e
fix: route Perplexity OpenAI endpoints without /v1 (#854)
* fix: route Perplexity OpenAI paths without /v1

* add tests for Perplexity provider handling in LLM module

* refactor: use constant for Perplexity provider prefix in LLM module

* moving const to top of file
2026-03-31 17:40:42 -07:00
Adil Hafeez
d8f4fd76e3
replace production panics with graceful error handling in common crate (#844) 2026-03-31 14:28:11 -07:00
Musa
36fa42b364
Improve planoai up/down CLI progress output (#858) 2026-03-31 14:26:32 -07:00
Musa
82f34f82f2
Update black hook for Python 3.14 (#857)
* Update pre-commit black to latest release

* Reformat Python files for new black version
2026-03-31 13:18:45 -07:00
Adil Hafeez
f019f05738
release 0.4.15 (#853) 2026-03-30 17:33:40 -07:00
Adil Hafeez
af98c11a6d
restructure model_metrics_sources to type + provider (#855) 2026-03-30 17:12:20 -07:00
Adil Hafeez
e5751d6b13
model routing: cost/latency ranking with ranked fallback list (#849) 2026-03-30 13:46:52 -07:00
Musa
3a531ce22a
expand configuration reference with missing fields (#851) 2026-03-30 12:25:05 -07:00
Adil Hafeez
406fa92802
release 0.4.14 (#840) 2026-03-20 00:51:37 -07:00
Salman Paracha
69df124c47
the orchestrator had a bug where it was setting the wrong headers for archfc.katanemo.dev (#839)
Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-389.local>
2026-03-20 00:40:47 -07:00
Adil Hafeez
180a9cb748
separate config generation from process startup in supervisord (#838) 2026-03-19 22:37:56 -07:00
Adil Hafeez
cdad02c5ee
release 0.4.13 (#837) 2026-03-19 19:51:58 -07:00
Adil Hafeez
1ad3e0f64e
refactor brightstaff (#736) 2026-03-19 17:58:33 -07:00
Adil Hafeez
1f23c573bf
add output filter chain (#822) 2026-03-18 17:58:20 -07:00
Adil Hafeez
de2d8847f3
fix CVE-2026-0861: upgrade glibc via apt-get upgrade in Dockerfile (#832) 2026-03-16 13:27:37 -07:00
Adil Hafeez
5388c6777f
add k8s deployment manifests and docs for self-hosted Arch-Router (#831) 2026-03-16 12:05:30 -07:00
Adil Hafeez
f1b8c03e2f
release 0.4.12 (#830) 2026-03-15 13:03:32 -07:00
Salman Paracha
4bb5c6404f
adding new supported models to plano (#829)
Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-389.local>
2026-03-15 12:37:20 -07:00
Adil Hafeez
5d85829bc2
Improve config validation error messages (#825)
* improve config validation error messages and update getting started demo

* fix black formatting
2026-03-15 09:36:58 -07:00
Adil Hafeez
bc059aed4d
Unified overrides for custom router and orchestrator models (#820)
* support configurable orchestrator model via orchestration config section

* add self-hosting docs and demo for Plano-Orchestrator

* list all Plano-Orchestrator model variants in docs

* use overrides for custom routing and orchestration model

* update docs

* update orchestrator model name

* rename arch provider to plano, use llm_routing_model and agent_orchestration_model

* regenerate rendered config reference
2026-03-15 09:36:11 -07:00
Adil Hafeez
785bf7e021
add build-cli and build-brightstaff skills (#824) 2026-03-13 00:28:35 -07:00
Adil Hafeez
2f52774c0e
Add Claude Code skills and streamline CLAUDE.md (#823)
* add claude code skills and streamline CLAUDE.md

* remove claude code attribution from PR skill

* update pr skill
2026-03-13 00:18:41 -07:00
Adil Hafeez
5400b0a2fa
add instructions on hosting arch-router locally (#819) 2026-03-11 15:28:50 -07:00
Adil Hafeez
b4313d93a4
Run demos without Docker (#809) 2026-03-11 12:49:36 -07:00
Musa
6610097659
Support for Codex via Plano (#808)
* Add Codex CLI support; xAI response improvements

* Add native Plano running check and update CLI agent error handling

* adding PR suggestions for transformations and code quality

* message extraction logic in ResponsesAPIRequest

* xAI support for Responses API by routing to native endpoint + refactor code
2026-03-10 20:54:14 -07:00
Adil Hafeez
5189f7907a
add k8s deploy guide (#816) 2026-03-10 12:27:31 -07:00
Adil Hafeez
97b7a390ef
support inline routing_policy in request body (#811) (#815) 2026-03-10 12:23:18 -07:00
Adil Hafeez
028a2cd196
add routing service (#814)
fixes https://github.com/katanemo/plano/issues/810
2026-03-09 16:32:16 -07:00
Adil Hafeez
b9f01c8471
support native mode in planoai logs command (#807) 2026-03-05 18:34:06 -08:00
329 changed files with 40769 additions and 17558 deletions

@@ -0,0 +1,12 @@
---
name: build-brightstaff
description: Build the brightstaff native binary. Use when brightstaff code changes.
---
Build brightstaff:
```
cd crates && cargo build --release -p brightstaff
```
If the build fails, diagnose and fix the errors.

@@ -0,0 +1,10 @@
---
name: build-cli
description: Build and install the Python CLI (planoai). Use after making changes to cli/ code to install locally.
---
1. `cd cli && uv sync` — ensure dependencies are installed
2. `cd cli && uv tool install --editable .` — install the CLI locally
3. Verify the installation: `cd cli && uv run planoai --help`
If the build or install fails, diagnose and fix the issues.

@@ -0,0 +1,12 @@
---
name: build-wasm
description: Build the WASM plugins for Envoy. Use when WASM plugin code changes.
---
Build the WASM plugins:
```
cd crates && cargo build --release --target=wasm32-wasip1 -p llm_gateway -p prompt_gateway
```
If the build fails, diagnose and fix the errors.

@@ -0,0 +1,12 @@
---
name: check
description: Run Rust fmt, clippy, and unit tests. Use after making Rust code changes.
---
Run all local checks in order:
1. `cd crates && cargo fmt --all -- --check` — if formatting fails, run `cargo fmt --all` to fix it
2. `cd crates && cargo clippy --locked --all-targets --all-features -- -D warnings` — fix any warnings
3. `cd crates && cargo test --lib` — ensure all unit tests pass
Report a summary of what passed/failed.

@@ -0,0 +1,17 @@
---
name: new-provider
description: Add a new LLM provider to hermesllm. Use when integrating a new AI provider.
disable-model-invocation: true
user-invocable: true
---
Add a new LLM provider to hermesllm. The user will provide the provider name as $ARGUMENTS.
1. Add a new variant to `ProviderId` enum in `crates/hermesllm/src/providers/id.rs`
2. Implement string parsing in the `TryFrom<&str>` impl for the new provider
3. If the provider uses a non-OpenAI API format, create request/response types in `crates/hermesllm/src/apis/`
4. Add variant to `ProviderRequestType` and `ProviderResponseType` enums and update all match arms
5. Add model list to `crates/hermesllm/src/providers/provider_models.yaml`
6. Update `SupportedUpstreamAPIs` mapping if needed
After making changes, run `cd crates && cargo test --lib` to verify everything compiles and tests pass.
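Steps 1 and 2 above amount to a new enum variant plus a string-parsing arm. A minimal, self-contained sketch of that shape (the variant and provider names here are hypothetical; the real enum in `crates/hermesllm/src/providers/id.rs` has many more arms):

```rust
// Hypothetical sketch of steps 1-2; not the real hermesllm enum.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
enum ProviderId {
    OpenAI,
    Anthropic,
    // New variant for the provider being integrated:
    MyProvider,
}

impl TryFrom<&str> for ProviderId {
    type Error = String;

    fn try_from(s: &str) -> Result<Self, Self::Error> {
        // Normalize case so config values like "MyProvider" still match.
        match s.to_ascii_lowercase().as_str() {
            "openai" => Ok(ProviderId::OpenAI),
            "anthropic" => Ok(ProviderId::Anthropic),
            "myprovider" => Ok(ProviderId::MyProvider),
            other => Err(format!("unknown provider: {other}")),
        }
    }
}

fn main() {
    assert_eq!(ProviderId::try_from("myprovider"), Ok(ProviderId::MyProvider));
    assert!(ProviderId::try_from("nope").is_err());
}
```

Exhaustive `match` on the enum is what forces step 4: adding the variant makes the compiler flag every match arm that still needs updating.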

@@ -0,0 +1,16 @@
---
name: pr
description: Create a feature branch and open a pull request for the current changes.
disable-model-invocation: true
user-invocable: true
---
Create a pull request for the current changes:
1. Determine the GitHub username via `gh api user --jq .login`. If the login is `adilhafeez`, use `adil` instead.
2. Create a feature branch using format `<username>/<feature_name>` — infer the feature name from the changes
3. Run `cd crates && cargo fmt --all -- --check` and `cd crates && cargo clippy --locked --all-targets --all-features -- -D warnings` to verify Rust code is clean
4. Commit all changes with a short, concise commit message (one line, no Co-Authored-By)
5. Push the branch and create a PR targeting `main`
Keep the PR title short (under 70 chars). Include a brief summary in the body. Never include a "Test plan" section or any "Generated with Claude Code" attribution.

@@ -0,0 +1,30 @@
---
name: release
description: Bump the Plano version across all required files. Use when preparing a release.
disable-model-invocation: true
user-invocable: true
---
Prepare a release version bump. The user may provide the new version number as $ARGUMENTS (e.g., `/release 0.4.12`), or a bump type (`major`, `minor`, `patch`).
If no argument is provided, read the current version from `cli/planoai/__init__.py`, auto-increment the patch version (e.g., `0.4.11` → `0.4.12`), and confirm with the user before proceeding.
Update the version string in ALL of these files:
- `.github/workflows/ci.yml`
- `cli/planoai/__init__.py`
- `cli/planoai/consts.py`
- `cli/pyproject.toml`
- `build_filter_image.sh`
- `config/validate_plano_config.sh`
- `docs/source/conf.py`
- `docs/source/get_started/quickstart.rst`
- `docs/source/resources/deployment.rst`
- `apps/www/src/components/Hero.tsx`
- `demos/llm_routing/preference_based_routing/README.md`
Do NOT change version strings in `*.lock` files or `Cargo.lock`.
After updating all version strings, run `cd cli && uv lock` to update the lock file with the new version.
After making changes, show a summary of all files modified and the old → new version.
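The auto-increment rule is just a patch-component bump on a `MAJOR.MINOR.PATCH` string. A small illustrative sketch (function name and error handling are my own, not part of the skill):

```rust
// Sketch of the default bump rule: "0.4.11" -> "0.4.12".
// Returns None for anything that is not a plain three-part version.
fn bump_patch(version: &str) -> Option<String> {
    let mut parts: Vec<u64> = version
        .split('.')
        .map(|p| p.parse().ok())
        .collect::<Option<Vec<_>>>()?;
    if parts.len() != 3 {
        return None;
    }
    parts[2] += 1;
    Some(format!("{}.{}.{}", parts[0], parts[1], parts[2]))
}

fn main() {
    assert_eq!(bump_patch("0.4.11").as_deref(), Some("0.4.12"));
    assert_eq!(bump_patch("not-a-version"), None);
}
```

Rejecting malformed input is deliberate: silently "bumping" a string that is not a version would propagate garbage into all eleven files listed above.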

@@ -0,0 +1,9 @@
---
name: test-python
description: Run Python CLI tests. Use after making changes to cli/ code.
---
1. `cd cli && uv sync` — ensure dependencies are installed
2. `cd cli && uv run pytest -v` — run all tests
If tests fail, diagnose and fix the issues.

@@ -133,13 +133,13 @@ jobs:
load: true
tags: |
${{ env.PLANO_DOCKER_IMAGE }}
${{ env.DOCKER_IMAGE }}:0.4.11
${{ env.DOCKER_IMAGE }}:0.4.21
${{ env.DOCKER_IMAGE }}:latest
cache-from: type=gha
cache-to: type=gha,mode=max
- name: Save image as artifact
run: docker save ${{ env.PLANO_DOCKER_IMAGE }} ${{ env.DOCKER_IMAGE }}:0.4.11 ${{ env.DOCKER_IMAGE }}:latest -o /tmp/plano-image.tar
run: docker save ${{ env.PLANO_DOCKER_IMAGE }} ${{ env.DOCKER_IMAGE }}:0.4.21 ${{ env.DOCKER_IMAGE }}:latest -o /tmp/plano-image.tar
- name: Upload image artifact
uses: actions/upload-artifact@v6
@@ -477,7 +477,7 @@ jobs:
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
run: |
source venv/bin/activate
cd demos/shared/test_runner && sh run_demo_tests.sh llm_routing/preference_based_routing
cd demos/shared/test_runner && bash run_demo_tests.sh llm_routing/preference_based_routing
# ──────────────────────────────────────────────
# E2E: demo — currency conversion
@@ -527,4 +527,4 @@ jobs:
GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
run: |
source venv/bin/activate
cd demos/shared/test_runner && sh run_demo_tests.sh advanced/currency_exchange
cd demos/shared/test_runner && bash run_demo_tests.sh advanced/currency_exchange

@@ -3,6 +3,8 @@ name: Publish docker image (release)
env:
DOCKER_IMAGE: katanemo/plano
GHCR_IMAGE: ghcr.io/${{ github.repository_owner }}/plano
DOCR_IMAGE: registry.digitalocean.com/genai-prod/plano
DOCR_PREVIEW_IMAGE: registry.digitalocean.com/genai-preview/plano
on:
release:
@@ -33,6 +35,14 @@ jobs:
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Install doctl
uses: digitalocean/action-doctl@v2
with:
token: ${{ secrets.PLATFORM_DIGITALOCEAN_TOKEN }}
- name: Log in to DOCR (prod)
run: doctl registry login
- name: Extract metadata for Docker
id: meta
uses: docker/metadata-action@v5
@@ -51,6 +61,21 @@ jobs:
tags: |
${{ steps.meta.outputs.tags }}-arm64
${{ env.GHCR_IMAGE }}:${{ github.event.release.tag_name }}-arm64
${{ env.DOCR_IMAGE }}:${{ github.event.release.tag_name }}-arm64
- name: Switch to DOCR preview
uses: digitalocean/action-doctl@v2
with:
token: ${{ secrets.PLATFORM_DIGITALOCEAN_TOKEN_PREVIEW }}
- name: Log in to DOCR (preview)
run: doctl registry login
- name: Push to DOCR Preview
run: |
docker buildx imagetools create \
-t ${{ env.DOCR_PREVIEW_IMAGE }}:${{ github.event.release.tag_name }}-arm64 \
${{ steps.meta.outputs.tags }}-arm64
# Build AMD64 image on GitHub's AMD64 runner — push to both registries
build-amd64:
@@ -72,6 +97,14 @@ jobs:
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Install doctl
uses: digitalocean/action-doctl@v2
with:
token: ${{ secrets.PLATFORM_DIGITALOCEAN_TOKEN }}
- name: Log in to DOCR (prod)
run: doctl registry login
- name: Extract metadata for Docker
id: meta
uses: docker/metadata-action@v5
@@ -90,6 +123,21 @@ jobs:
tags: |
${{ steps.meta.outputs.tags }}-amd64
${{ env.GHCR_IMAGE }}:${{ github.event.release.tag_name }}-amd64
${{ env.DOCR_IMAGE }}:${{ github.event.release.tag_name }}-amd64
- name: Switch to DOCR preview
uses: digitalocean/action-doctl@v2
with:
token: ${{ secrets.PLATFORM_DIGITALOCEAN_TOKEN_PREVIEW }}
- name: Log in to DOCR (preview)
run: doctl registry login
- name: Push to DOCR Preview
run: |
docker buildx imagetools create \
-t ${{ env.DOCR_PREVIEW_IMAGE }}:${{ github.event.release.tag_name }}-amd64 \
${{ steps.meta.outputs.tags }}-amd64
# Combine ARM64 and AMD64 images into multi-arch manifests for both registries
create-manifest:
@@ -109,6 +157,14 @@ jobs:
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Install doctl
uses: digitalocean/action-doctl@v2
with:
token: ${{ secrets.PLATFORM_DIGITALOCEAN_TOKEN }}
- name: Log in to DOCR (prod)
run: doctl registry login
- name: Extract metadata for Docker
id: meta
uses: docker/metadata-action@v5
@@ -131,3 +187,27 @@ jobs:
-t ${{ env.GHCR_IMAGE }}:${TAG} \
${{ env.GHCR_IMAGE }}:${TAG}-arm64 \
${{ env.GHCR_IMAGE }}:${TAG}-amd64
- name: Create DOCR Prod Multi-Arch Manifest
run: |
TAG=${{ github.event.release.tag_name }}
docker buildx imagetools create \
-t ${{ env.DOCR_IMAGE }}:${TAG} \
${{ env.DOCR_IMAGE }}:${TAG}-arm64 \
${{ env.DOCR_IMAGE }}:${TAG}-amd64
- name: Switch to DOCR preview
uses: digitalocean/action-doctl@v2
with:
token: ${{ secrets.PLATFORM_DIGITALOCEAN_TOKEN_PREVIEW }}
- name: Log in to DOCR (preview)
run: doctl registry login
- name: Create DOCR Preview Multi-Arch Manifest
run: |
TAG=${{ github.event.release.tag_name }}
docker buildx imagetools create \
-t ${{ env.DOCR_PREVIEW_IMAGE }}:${TAG} \
${{ steps.meta.outputs.tags }}-arm64 \
${{ steps.meta.outputs.tags }}-amd64

.gitignore

@@ -152,3 +152,4 @@ apps/*/dist/
.cursor/
.agents
docs/do/

@@ -4,6 +4,7 @@ repos:
hooks:
- id: check-yaml
exclude: config/envoy.template*
args: [--allow-multiple-documents]
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: local
@@ -36,7 +37,7 @@ repos:
- id: gitleaks
- repo: https://github.com/psf/black
rev: 23.1.0
rev: 26.3.1
hooks:
- id: black
language_version: python3

CLAUDE.md

@@ -1,152 +1,106 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Project Overview
Plano is an AI-native proxy server and data plane for agentic applications, built on Envoy proxy. It centralizes agent orchestration, LLM routing, observability, and safety guardrails as an out-of-process dataplane.
## Build & Test Commands
### Rust (crates/)
```bash
# Build WASM plugins (must target wasm32-wasip1)
# Rust — WASM plugins (must target wasm32-wasip1)
cd crates && cargo build --release --target=wasm32-wasip1 -p llm_gateway -p prompt_gateway
# Build brightstaff binary (native target)
# Rust — brightstaff binary (native target)
cd crates && cargo build --release -p brightstaff
# Run unit tests
# Rust — tests, format, lint
cd crates && cargo test --lib
# Format check
cd crates && cargo fmt --all -- --check
# Lint
cd crates && cargo clippy --locked --all-targets --all-features -- -D warnings
```
### Python CLI (cli/)
# Python CLI
cd cli && uv sync && uv run pytest -v
```bash
cd cli && uv sync # Install dependencies
cd cli && uv run pytest -v # Run tests
cd cli && uv run planoai --help # Run CLI
```
# JS/TS (Turbo monorepo)
npm run build && npm run lint && npm run typecheck
### JavaScript/TypeScript (apps/, packages/)
```bash
npm run build # Build all (via Turbo)
npm run lint # Lint all
npm run dev # Dev servers
npm run typecheck # Type check
```
### Pre-commit (runs fmt, clippy, cargo test, black, yaml checks)
```bash
# Pre-commit (fmt, clippy, cargo test, black, yaml)
pre-commit run --all-files
```
### Docker
```bash
# Docker
docker build -t katanemo/plano:latest .
```
### E2E Tests (tests/e2e/)
E2E tests require a built Docker image and API keys. They run via `tests/e2e/run_e2e_tests.sh` which executes four test suites: `test_prompt_gateway.py`, `test_model_alias_routing.py`, `test_openai_responses_api_client.py`, and `test_openai_responses_api_client_with_state.py`.
E2E tests require a Docker image and API keys: `tests/e2e/run_e2e_tests.sh`
## Architecture
### Core Data Flow
Requests flow through Envoy proxy with two WASM filter plugins, backed by a native Rust binary:
```
Client → Envoy (prompt_gateway.wasm → llm_gateway.wasm) → Agents/LLM Providers
brightstaff (native binary: state, routing, signals, tracing)
```
### Rust Crates (crates/)
### Crates (crates/)
All crates share a Cargo workspace. Two compile to `wasm32-wasip1` for Envoy, the rest are native:
- **prompt_gateway** (WASM) — Proxy-WASM filter for prompt/message processing, guardrails, and filter chains
- **prompt_gateway** (WASM) — Proxy-WASM filter for prompt processing, guardrails, filter chains
- **llm_gateway** (WASM) — Proxy-WASM filter for LLM request/response handling and routing
- **brightstaff** (native binary) — Core application server: handlers, router, signals, state management, tracing
- **common** (library) — Shared across all crates: configuration, LLM provider abstractions, HTTP utilities, routing logic, rate limiting, tokenizer, PII detection, tracing
- **hermesllm** (library) — Translates LLM API formats between providers (OpenAI, Anthropic, Gemini, Mistral, Grok, AWS Bedrock, Azure, together.ai). Key types: `ProviderId`, `ProviderRequest`, `ProviderResponse`, `ProviderStreamResponse`
- **brightstaff** (native) — Core server: handlers, router, signals, state, tracing
- **common** (lib) — Shared: config, HTTP, routing, rate limiting, tokenizer, PII, tracing
- **hermesllm** (lib) — LLM API translation between providers. Key types: `ProviderId`, `ProviderRequest`, `ProviderResponse`, `ProviderStreamResponse`
### Python CLI (cli/planoai/)
The `planoai` CLI manages the Plano lifecycle. Key commands:
- `planoai up <config.yaml>` — Validate config, check API keys, start Docker container
- `planoai down` — Stop container
- `planoai build` — Build Docker image from repo root
- `planoai logs` — Stream access/debug logs
- `planoai trace` — OTEL trace collection and analysis
- `planoai init` — Initialize new project
- `planoai cli_agent` — Start a CLI agent connected to Plano
- `planoai generate_prompt_targets` — Generate prompt_targets from python methods
Entry point: `main.py`. Built with `rich-click`. Commands: `up`, `down`, `build`, `logs`, `trace`, `init`, `cli_agent`, `generate_prompt_targets`.
Entry point: `cli/planoai/main.py`. Container lifecycle in `core.py`. Docker operations in `docker_cli.py`.
### Config (config/)
### Configuration System (config/)
- `plano_config_schema.yaml` — JSON Schema for validating user configs
- `envoy.template.yaml` — Jinja2 template → Envoy config
- `supervisord.conf` — Process supervisor for Envoy + brightstaff
- `plano_config_schema.yaml` — JSON Schema (draft-07) for validating user config files
- `envoy.template.yaml` — Jinja2 template rendered into Envoy proxy config
- `supervisord.conf` — Process supervisor for Envoy + brightstaff in the container
### JS Apps (apps/, packages/)
User configs define: `agents` (id + url), `model_providers` (model + access_key), `listeners` (type: agent/model/prompt, with router strategy), `filters` (filter chains), and `tracing` settings.
Turbo monorepo with Next.js 16 / React 19. Not part of the core proxy.
### JavaScript Apps (apps/, packages/)
## WASM Plugin Rules
Turbo monorepo with Next.js 16 / React 19 applications and shared packages (UI components, Tailwind config, TypeScript config). Not part of the core proxy — these are web applications.
Code in `prompt_gateway` and `llm_gateway` runs in Envoy's WASM sandbox:
- **No std networking/filesystem** — use proxy-wasm host calls only
- **No tokio/async** — synchronous, callback-driven. `Action::Pause` / `Action::Continue` for flow control
- **Lifecycle**: `RootContext` → `on_configure`, `create_http_context`; `HttpContext` → `on_http_request/response_headers/body`
- **HTTP callouts**: `dispatch_http_call()` → store context in `callouts: RefCell<HashMap<u32, CallContext>>` → match in `on_http_call_response()`
- **Config**: `Rc`-wrapped, loaded once in `on_configure()` via `serde_yaml::from_slice()`
- **Dependencies must be no_std compatible** (e.g., `governor` with `features = ["no_std"]`)
- **Crate type**: `cdylib` → produces `.wasm`
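The callout rule above can be sketched with plain std types. This is an illustrative, std-only sketch of the bookkeeping pattern (struct and field names are hypothetical); in the real filters the token comes from proxy-wasm's `dispatch_http_call()` and the lookup happens inside `on_http_call_response()`:

```rust
use std::cell::RefCell;
use std::collections::HashMap;

// Per-callout state we need back when the response arrives.
#[derive(Debug, PartialEq)]
struct CallContext {
    upstream: String,
}

struct Filter {
    // RefCell because proxy-wasm contexts are single-threaded and
    // callbacks only get &self.
    callouts: RefCell<HashMap<u32, CallContext>>,
}

impl Filter {
    // Mirrors storing the context keyed by the dispatch token.
    fn on_dispatch(&self, token: u32, ctx: CallContext) {
        self.callouts.borrow_mut().insert(token, ctx);
    }

    // Mirrors matching the response back to its originating call;
    // remove() so each token is consumed exactly once.
    fn on_http_call_response(&self, token: u32) -> Option<CallContext> {
        self.callouts.borrow_mut().remove(&token)
    }
}

fn main() {
    let f = Filter { callouts: RefCell::new(HashMap::new()) };
    f.on_dispatch(7, CallContext { upstream: "llm_provider".into() });
    assert!(f.on_http_call_response(7).is_some());
    assert!(f.on_http_call_response(7).is_none());
}
```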
## Adding a New LLM Provider
1. Add variant to `ProviderId` in `crates/hermesllm/src/providers/id.rs` + `TryFrom<&str>`
2. Create request/response types in `crates/hermesllm/src/apis/` if non-OpenAI format
3. Add variant to `ProviderRequestType`/`ProviderResponseType` enums, update all match arms
4. Add models to `crates/hermesllm/src/providers/provider_models.yaml`
5. Update `SupportedUpstreamAPIs` mapping if needed
## Release Process
To prepare a release (e.g., bumping from `0.4.6` to `0.4.7`), update the version string in all of the following files:
Update version (e.g., `0.4.11` → `0.4.12`) in all of these files:
**CI Workflow:**
- `.github/workflows/ci.yml` — docker build/save tags
- `.github/workflows/ci.yml`, `build_filter_image.sh`, `config/validate_plano_config.sh`
- `cli/planoai/__init__.py`, `cli/planoai/consts.py`, `cli/pyproject.toml`
- `docs/source/conf.py`, `docs/source/get_started/quickstart.rst`, `docs/source/resources/deployment.rst`
- `apps/www/src/components/Hero.tsx`, `demos/llm_routing/preference_based_routing/README.md`
**CLI:**
- `cli/planoai/__init__.py``__version__`
- `cli/planoai/consts.py``PLANO_DOCKER_IMAGE` default
- `cli/pyproject.toml``version`
**Build & Config:**
- `build_filter_image.sh` — docker build tag
- `config/validate_plano_config.sh` — docker image tag
**Docs:**
- `docs/source/conf.py``release`
- `docs/source/get_started/quickstart.rst` — install commands and example output
- `docs/source/resources/deployment.rst` — docker image tag
**Website & Demos:**
- `apps/www/src/components/Hero.tsx` — version badge
- `demos/llm_routing/preference_based_routing/README.md` — example output
**Important:** Do NOT change `0.4.6` references in `*.lock` files or `Cargo.lock` — those refer to the `colorama` and `http-body` dependency versions, not Plano.
Commit message format: `release X.Y.Z`
Do NOT change version strings in `*.lock` files or `Cargo.lock`. Commit message: `release X.Y.Z`
## Workflow Preferences
- **Git commits:** Do NOT add `Co-Authored-By` lines. Keep commit messages short and concise (one line, no verbose descriptions). NEVER commit and push directly to `main`—always use a feature branch and PR.
- **Git branches:** Use the format `<github_username>/<feature_name>` when creating branches for PRs. Determine the username from `gh api user --jq .login`.
- **GitHub issues:** When a GitHub issue URL is pasted, fetch all requirements and context from the issue first. The end goal is always a PR with all tests passing.
- **Commits:** No `Co-Authored-By`. Short one-line messages. Never push directly to `main` — always feature branch + PR.
- **Branches:** Use `adil/<feature_name>` format.
- **Issues:** When a GitHub issue URL is pasted, fetch all context first. Goal is always a PR with passing tests.
## Key Conventions
- Rust edition 2021, formatted with `cargo fmt`, linted with `cargo clippy -D warnings`
- Python formatted with Black
- WASM plugins must target `wasm32-wasip1` — they run inside Envoy, not as native binaries
- The Docker image bundles Envoy + WASM plugins + brightstaff + Python CLI into a single container managed by supervisord
- API keys come from environment variables or `.env` files, never hardcoded
- Rust edition 2021, `cargo fmt`, `cargo clippy -D warnings`
- Python: Black. Rust errors: `thiserror` with `#[from]`
- API keys from env vars or `.env`, never hardcoded
- Provider dispatch: `ProviderRequestType`/`ProviderResponseType` enums implementing `ProviderRequest`/`ProviderResponse` traits

@@ -49,7 +49,8 @@ FROM python:3.14-slim AS arch
RUN set -eux; \
apt-get update; \
apt-get install -y --no-install-recommends gettext-base curl; \
apt-get upgrade -y; \
apt-get install -y --no-install-recommends gettext-base curl procps; \
apt-get clean; rm -rf /var/lib/apt/lists/*
RUN pip install --no-cache-dir supervisor

@@ -37,7 +37,7 @@ Plano pulls rote plumbing out of your framework so you can stay focused on what
**Jump to our [docs](https://docs.planoai.dev)** to learn how you can use Plano to improve the speed, safety, and observability of your agentic applications.
> [!IMPORTANT]
> Plano and the Arch family of LLMs (like Plano-Orchestrator-4B, Arch-Router, etc) are hosted free of charge in the US-central region to give you a great first-run developer experience of Plano. To scale and run in production, you can either run these LLMs locally or contact us on [Discord](https://discord.gg/pGZf2gcwEc) for API keys.
> Plano and the Plano family of LLMs (like Plano-Orchestrator) are hosted free of charge in the US-central region to give you a great first-run developer experience of Plano. To scale and run in production, you can either run these LLMs locally or contact us on [Discord](https://discord.gg/pGZf2gcwEc) for API keys.
---

@@ -12,6 +12,7 @@
"clean": "rm -rf .next"
},
"dependencies": {
"@heroicons/react": "^2.2.0",
"@katanemo/shared-styles": "*",
"@katanemo/ui": "*",
"next": "^16.1.6",

@@ -66,7 +66,9 @@ export default function RootLayout({
}>) {
return (
<html lang="en">
<body className={`${ibmPlexSans.variable} antialiased text-white`}>
<body
className={`${ibmPlexSans.variable} overflow-hidden antialiased text-white`}
>
{/* Google tag (gtag.js) */}
<Script
src="https://www.googletagmanager.com/gtag/js?id=G-RLD5BDNW5N"
@@ -80,7 +82,9 @@
gtag('config', 'G-RLD5BDNW5N');
`}
</Script>
<div className="min-h-screen">{children}</div>
<div className="h-screen overflow-hidden">
{children}
</div>
</body>
</html>
);

@@ -1,11 +1,23 @@
import Image from "next/image";
import Link from "next/link";
import LogoSlider from "../components/LogoSlider";
import { ArrowRightIcon } from "@heroicons/react/16/solid";
export default function HomePage() {
return (
<main className="relative flex min-h-screen items-center justify-center overflow-hidden px-6 pt-12 pb-16 font-sans sm:pt-20 lg:items-start lg:justify-start lg:pt-24">
<main className="relative flex h-full items-center justify-center overflow-hidden px-6 pt-12 pb-16 font-sans sm:pt-20 lg:items-start lg:justify-start lg:pt-24">
<div className="relative mx-auto w-full max-w-6xl flex flex-col items-center justify-center text-left lg:items-start lg:justify-start">
<Link
href="https://digitalocean.com/blog/digitalocean-acquires-katanemo-labs-inc"
target="_blank"
rel="noopener noreferrer"
className="mb-7 inline-flex max-w-[20rem] items-center gap-1 self-start rounded-full border border-[#22A875]/30 bg-[#22A875]/30 px-2.5 py-1 text-left text-[12px] leading-tight font-medium text-white transition-opacity hover:opacity-90 lg:hidden"
>
<span>
DigitalOcean acquires Katanemo Labs, Inc.
</span>
<ArrowRightIcon aria-hidden className="h-3 w-3 shrink-0 text-white/90" />
</Link>
<div className="pointer-events-none mb-6 w-full self-start lg:hidden">
<Image
src="/KatanemoLogo.svg"
@@ -17,6 +29,20 @@ export default function HomePage() {
/>
</div>
<div className="relative z-10 max-w-xl sm:max-w-2xl lg:max-w-2xl xl:max-w-8xl lg:pr-[26vw] xl:pr-[2vw] sm:right-0 md:right-0 lg:right-0 xl:right-20 2xl:right-50 sm:mt-36 mt-0">
<Link
href="https://digitalocean.com/blog/digitalocean-acquires-katanemo-labs-inc"
target="_blank"
rel="noopener noreferrer"
className="mb-4 hidden max-w-full items-center gap-2 rounded-full border border-[#22A875]/70 bg-[#22A875]/50 px-4 py-1 text-left text-sm font-medium text-white transition-opacity hover:opacity-90 lg:inline-flex"
>
<span>
DigitalOcean acquires Katanemo Labs, Inc.
</span>
<ArrowRightIcon
aria-hidden
className="h-4 w-4 shrink-0 text-white/90"
/>
</Link>
<h1 className="text-3xl sm:text-4xl md:text-5xl lg:text-6xl font-sans font-medium leading-tight tracking-tight text-white">
Forward-deployed AI infrastructure engineers.
</h1>

@@ -16,6 +16,7 @@
"migrate:blogs": "tsx scripts/migrate-blogs.ts"
},
"dependencies": {
"@heroicons/react": "^2.2.0",
"@katanemo/shared-styles": "*",
"@katanemo/ui": "*",
"@portabletext/react": "^5.0.0",
@@ -32,8 +33,8 @@
"next": "^16.1.6",
"next-sanity": "^11.6.9",
"papaparse": "^5.5.3",
"react": "19.2.0",
"react-dom": "19.2.0",
"react": "19.2.3",
"react-dom": "19.2.3",
"react-markdown": "^10.1.0",
"react-syntax-highlighter": "^16.1.0",
"remark-gfm": "^4.0.1",

@@ -39,6 +39,10 @@ function loadFont(fileName: string, baseUrl: string) {
}
async function getBlogPost(slug: string) {
if (!client) {
return null;
}
const query = `*[_type == "blog" && slug.current == $slug && published == true][0] {
_id,
title,
@@ -53,8 +57,13 @@
}
}`;
const post = await client.fetch(query, { slug });
return post;
try {
const post = await client.fetch(query, { slug });
return post;
} catch (error) {
console.error("Error fetching blog post for OG image:", error);
return null;
}
}
function formatDate(dateString: string): string {

@@ -17,6 +17,10 @@ interface BlogPost {
}
async function getBlogPost(slug: string): Promise<BlogPost | null> {
if (!client) {
return null;
}
const query = `*[_type == "blog" && slug.current == $slug && published == true][0] {
_id,
title,
@@ -26,8 +30,13 @@ async function getBlogPost(slug: string): Promise<BlogPost | null> {
author
}`;
const post = await client.fetch(query, { slug });
return post || null;
try {
const post = await client.fetch(query, { slug });
return post || null;
} catch (error) {
console.error("Error fetching blog post metadata:", error);
return null;
}
}
export async function generateMetadata({

@@ -25,6 +25,10 @@ interface BlogPost {
}
async function getBlogPost(slug: string): Promise<BlogPost | null> {
if (!client) {
return null;
}
const query = `*[_type == "blog" && slug.current == $slug && published == true][0] {
_id,
title,
@@ -51,17 +55,31 @@ async function getBlogPost(slug: string): Promise<BlogPost | null> {
author
}`;
const post = await client.fetch(query, { slug });
return post || null;
try {
const post = await client.fetch(query, { slug });
return post || null;
} catch (error) {
console.error("Error fetching blog post:", error);
return null;
}
}
async function getAllBlogSlugs(): Promise<string[]> {
if (!client) {
return [];
}
const query = `*[_type == "blog" && published == true] {
"slug": slug.current
}`;
const posts = await client.fetch(query);
return posts.map((post: { slug: string }) => post.slug);
try {
const posts = await client.fetch(query);
return posts.map((post: { slug: string }) => post.slug);
} catch (error) {
console.error("Error fetching blog slugs:", error);
return [];
}
}
export async function generateStaticParams() {

@@ -8,6 +8,7 @@ import { BlogSectionHeader } from "@/components/BlogSectionHeader";
import { pageMetadata } from "@/lib/metadata";
export const metadata: Metadata = pageMetadata.blog;
export const dynamic = "force-dynamic";
interface BlogPost {
_id: string;
@@ -44,6 +45,10 @@ function formatDate(dateString: string): string {
}
async function getBlogPosts(): Promise<BlogPost[]> {
if (!client) {
return [];
}
const query = `*[_type == "blog" && published == true] | order(publishedAt desc) {
_id,
title,
@@ -58,12 +63,48 @@ async function getBlogPosts(): Promise<BlogPost[]> {
featured
}`;
return await client.fetch(query);
try {
return await client.fetch(query);
} catch (error) {
console.error("Error fetching blog posts:", error);
return [];
}
}
async function getFeaturedBlogPost(): Promise<BlogPost | null> {
if (!client) {
return null;
}
const query = `*[_type == "blog" && published == true && featured == true] | order(_updatedAt desc, publishedAt desc)[0] {
_id,
title,
slug,
summary,
publishedAt,
mainImage,
mainImageUrl,
thumbnailImage,
thumbnailImageUrl,
author,
featured
}`;
try {
const post = await client.fetch(query);
return post || null;
} catch (error) {
console.error("Error fetching featured blog post:", error);
return null;
}
}
export default async function BlogPage() {
const posts = await getBlogPosts();
const featuredPost = posts.find((post) => post.featured) || posts[0];
const [posts, featuredCandidate] = await Promise.all([
getBlogPosts(),
getFeaturedBlogPost(),
]);
const featuredPost = featuredCandidate || posts[0];
const recentPosts = posts
.filter((post) => post._id !== featuredPost?._id)
.slice(0, 3);
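The hunk above switches BlogPage from a single sequential fetch to `Promise.all`, running the post-list and featured-post queries concurrently and falling back to the newest post when no featured candidate exists. A minimal Python sketch of the same pattern (hypothetical helper names, not part of the diff):

```python
import asyncio

# Fetch the post list and the featured post concurrently, then fall back
# to the newest post when no featured candidate exists.
async def get_blog_posts():
    return [{"_id": "a"}, {"_id": "b"}]

async def get_featured_blog_post():
    return None  # e.g. the featured-post query matched nothing

async def load_blog_page():
    posts, featured = await asyncio.gather(
        get_blog_posts(), get_featured_blog_post()
    )
    return featured or (posts[0] if posts else None)

print(asyncio.run(load_blog_page()))  # → {'_id': 'a'}
```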


@@ -1,4 +1,6 @@
import type { Metadata } from "next";
import { ArrowRightIcon } from "@heroicons/react/16/solid";
import Link from "next/link";
import Script from "next/script";
import "@katanemo/shared-styles/globals.css";
import { Analytics } from "@vercel/analytics/next";
@@ -35,6 +37,27 @@ export default function RootLayout({
gtag('config', 'G-ML7B1X9HY2');
`}
</Script>
<Link
href="https://digitalocean.com/blog/digitalocean-acquires-katanemo-labs-inc"
target="_blank"
rel="noopener noreferrer"
className="block w-full bg-[#7780D9] py-3 text-white transition-opacity"
>
<div className="mx-auto flex max-w-[85rem] items-center justify-center gap-4 px-6 text-center md:justify-between md:text-left lg:px-8">
<span className="w-full text-xs font-medium leading-snug md:w-auto md:text-base flex items-center">
DigitalOcean acquires Katanemo Labs, Inc. to accelerate AI
development
<ArrowRightIcon
aria-hidden
className="ml-1 inline-block h-3 w-3 align-[-1px] text-white/90 md:hidden"
/>
</span>
<span className="hidden shrink-0 items-center gap-1 text-base font-medium tracking-[-0.989px] font-mono leading-snug opacity-70 transition-opacity hover:opacity-100 md:inline-flex">
Read the announcement
<ArrowRightIcon aria-hidden className="h-3.5 w-3.5 text-white/70" />
</span>
</div>
</Link>
<ConditionalLayout>{children}</ConditionalLayout>
<Analytics />
</body>


@@ -10,6 +10,10 @@ interface BlogPost {
}
async function getBlogPosts(): Promise<BlogPost[]> {
if (!client) {
return [];
}
const query = `*[_type == "blog" && published == true] | order(publishedAt desc) {
slug,
publishedAt,


@@ -24,7 +24,7 @@ export function Hero() {
>
<div className="inline-flex flex-wrap items-center gap-1.5 sm:gap-2 px-3 sm:px-4 py-1 rounded-full bg-[rgba(185,191,255,0.4)] border border-[var(--secondary)] shadow backdrop-blur hover:bg-[rgba(185,191,255,0.6)] transition-colors cursor-pointer">
<span className="text-xs sm:text-sm font-medium text-black/65">
v0.4.11
v0.4.21
</span>
<span className="text-xs sm:text-sm font-medium text-black ">


@@ -2,19 +2,33 @@ import { createClient } from "@sanity/client";
import imageUrlBuilder from "@sanity/image-url";
import type { SanityImageSource } from "@sanity/image-url/lib/types/types";
const projectId = process.env.NEXT_PUBLIC_SANITY_PROJECT_ID;
const dataset = process.env.NEXT_PUBLIC_SANITY_DATASET;
const apiVersion = process.env.NEXT_PUBLIC_SANITY_API_VERSION;
const projectId =
process.env.NEXT_PUBLIC_SANITY_PROJECT_ID ||
"71ny25bn";
const dataset =
process.env.NEXT_PUBLIC_SANITY_DATASET ||
"production";
const apiVersion =
process.env.NEXT_PUBLIC_SANITY_API_VERSION ||
"2025-01-01";
export const client = createClient({
projectId,
dataset,
apiVersion,
useCdn: true, // Set to false if statically generating pages, using ISR or using the on-demand revalidation API
});
export const hasSanityConfig = Boolean(projectId && dataset && apiVersion);
const builder = imageUrlBuilder(client);
export const client = hasSanityConfig
? createClient({
projectId,
dataset,
apiVersion,
// Keep blog/admin updates visible immediately after publishing.
useCdn: false,
})
: null;
const builder = client ? imageUrlBuilder(client) : null;
export function urlFor(source: SanityImageSource) {
if (!builder) {
throw new Error("Sanity client is not configured.");
}
return builder.image(source);
}
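The Sanity config hunk above reads each setting from the environment with a hard-coded fallback, and only constructs a client when configuration resolves. A tiny Python sketch of that env-var-with-fallback step (function name is hypothetical):

```python
import os

# Read a setting from the environment, falling back to a default when the
# variable is unset or empty (mirrors the `|| "..."` fallbacks in the diff).
def resolve_setting(env_var: str, default: str) -> str:
    return os.environ.get(env_var) or default

project_id = resolve_setting("NEXT_PUBLIC_SANITY_PROJECT_ID", "71ny25bn")
dataset = resolve_setting("NEXT_PUBLIC_SANITY_DATASET", "production")
has_config = bool(project_id and dataset)
```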


@@ -1 +1 @@
docker build -f Dockerfile . -t katanemo/plano -t katanemo/plano:0.4.11
docker build -f Dockerfile . -t katanemo/plano -t katanemo/plano:0.4.21

bun.lock Normal file

File diff suppressed because it is too large


@@ -1,3 +1,3 @@
"""Plano CLI - Intelligent Prompt Gateway."""
__version__ = "0.4.11"
__version__ = "0.4.21"

cli/planoai/chatgpt_auth.py Normal file

@@ -0,0 +1,290 @@
"""
ChatGPT subscription OAuth device-flow authentication.
Implements the device code flow used by OpenAI Codex CLI to authenticate
with a ChatGPT Plus/Pro subscription. Tokens are stored locally in
~/.plano/chatgpt/auth.json and auto-refreshed when expired.
"""
import base64
import json
import os
import time
from typing import Any, Dict, Optional, Tuple
import requests
from planoai.consts import PLANO_HOME
# OAuth + API constants (derived from openai/codex)
CHATGPT_AUTH_BASE = "https://auth.openai.com"
CHATGPT_DEVICE_CODE_URL = f"{CHATGPT_AUTH_BASE}/api/accounts/deviceauth/usercode"
CHATGPT_DEVICE_TOKEN_URL = f"{CHATGPT_AUTH_BASE}/api/accounts/deviceauth/token"
CHATGPT_OAUTH_TOKEN_URL = f"{CHATGPT_AUTH_BASE}/oauth/token"
CHATGPT_DEVICE_VERIFY_URL = f"{CHATGPT_AUTH_BASE}/codex/device"
CHATGPT_API_BASE = "https://chatgpt.com/backend-api/codex"
CHATGPT_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
# Local storage
CHATGPT_AUTH_DIR = os.path.join(PLANO_HOME, "chatgpt")
CHATGPT_AUTH_FILE = os.path.join(CHATGPT_AUTH_DIR, "auth.json")
# Timeouts
TOKEN_EXPIRY_SKEW_SECONDS = 60
DEVICE_CODE_TIMEOUT_SECONDS = 15 * 60
DEVICE_CODE_POLL_SECONDS = 5
def _ensure_auth_dir():
os.makedirs(CHATGPT_AUTH_DIR, exist_ok=True)
def load_auth() -> Optional[Dict[str, Any]]:
"""Load auth data from disk."""
try:
with open(CHATGPT_AUTH_FILE, "r") as f:
return json.load(f)
except (IOError, json.JSONDecodeError):
return None
def save_auth(data: Dict[str, Any]):
"""Save auth data to disk."""
_ensure_auth_dir()
fd = os.open(CHATGPT_AUTH_FILE, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
with os.fdopen(fd, "w") as f:
json.dump(data, f, indent=2)
def delete_auth():
"""Remove stored credentials."""
try:
os.remove(CHATGPT_AUTH_FILE)
except FileNotFoundError:
pass
def _decode_jwt_claims(token: str) -> Dict[str, Any]:
"""Decode JWT payload without verification."""
try:
parts = token.split(".")
if len(parts) < 2:
return {}
payload_b64 = parts[1]
payload_b64 += "=" * (-len(payload_b64) % 4)
return json.loads(base64.urlsafe_b64decode(payload_b64).decode("utf-8"))
except Exception:
return {}
def _get_expires_at(token: str) -> Optional[int]:
"""Extract expiration time from JWT."""
claims = _decode_jwt_claims(token)
exp = claims.get("exp")
return int(exp) if isinstance(exp, (int, float)) else None
def _extract_account_id(token: Optional[str]) -> Optional[str]:
"""Extract ChatGPT account ID from JWT claims."""
if not token:
return None
claims = _decode_jwt_claims(token)
auth_claims = claims.get("https://api.openai.com/auth")
if isinstance(auth_claims, dict):
account_id = auth_claims.get("chatgpt_account_id")
if isinstance(account_id, str) and account_id:
return account_id
return None
def _is_token_expired(auth_data: Dict[str, Any]) -> bool:
"""Check if the access token is expired."""
expires_at = auth_data.get("expires_at")
if expires_at is None:
access_token = auth_data.get("access_token")
if access_token:
expires_at = _get_expires_at(access_token)
if expires_at:
auth_data["expires_at"] = expires_at
save_auth(auth_data)
if expires_at is None:
return True
return time.time() >= float(expires_at) - TOKEN_EXPIRY_SKEW_SECONDS
def _refresh_tokens(refresh_token: str) -> Dict[str, str]:
"""Refresh the access token using the refresh token."""
resp = requests.post(
CHATGPT_OAUTH_TOKEN_URL,
json={
"client_id": CHATGPT_CLIENT_ID,
"grant_type": "refresh_token",
"refresh_token": refresh_token,
"scope": "openid profile email",
},
)
resp.raise_for_status()
data = resp.json()
access_token = data.get("access_token")
id_token = data.get("id_token")
if not access_token or not id_token:
raise RuntimeError(f"Refresh response missing fields: {data}")
return {
"access_token": access_token,
"refresh_token": data.get("refresh_token", refresh_token),
"id_token": id_token,
}
def _build_auth_record(tokens: Dict[str, str]) -> Dict[str, Any]:
"""Build the auth record to persist."""
access_token = tokens.get("access_token")
id_token = tokens.get("id_token")
expires_at = _get_expires_at(access_token) if access_token else None
account_id = _extract_account_id(id_token or access_token)
return {
"access_token": access_token,
"refresh_token": tokens.get("refresh_token"),
"id_token": id_token,
"expires_at": expires_at,
"account_id": account_id,
}
def request_device_code() -> Dict[str, str]:
"""Request a device code from OpenAI's device auth endpoint."""
resp = requests.post(
CHATGPT_DEVICE_CODE_URL,
json={"client_id": CHATGPT_CLIENT_ID},
)
resp.raise_for_status()
data = resp.json()
device_auth_id = data.get("device_auth_id")
user_code = data.get("user_code") or data.get("usercode")
interval = data.get("interval")
if not device_auth_id or not user_code:
raise RuntimeError(f"Device code response missing fields: {data}")
return {
"device_auth_id": device_auth_id,
"user_code": user_code,
"interval": str(interval or "5"),
}
def poll_for_authorization(device_code: Dict[str, str]) -> Dict[str, str]:
"""Poll until the user completes authorization. Returns code_data."""
interval = int(device_code.get("interval", "5"))
start_time = time.time()
while time.time() - start_time < DEVICE_CODE_TIMEOUT_SECONDS:
try:
resp = requests.post(
CHATGPT_DEVICE_TOKEN_URL,
json={
"device_auth_id": device_code["device_auth_id"],
"user_code": device_code["user_code"],
},
)
if resp.status_code == 200:
data = resp.json()
if all(
key in data
for key in ("authorization_code", "code_challenge", "code_verifier")
):
return data
if resp.status_code in (403, 404):
time.sleep(max(interval, DEVICE_CODE_POLL_SECONDS))
continue
resp.raise_for_status()
except requests.HTTPError as exc:
if exc.response is not None and exc.response.status_code in (403, 404):
time.sleep(max(interval, DEVICE_CODE_POLL_SECONDS))
continue
raise RuntimeError(f"Polling failed: {exc}") from exc
time.sleep(max(interval, DEVICE_CODE_POLL_SECONDS))
raise RuntimeError("Timed out waiting for device authorization")
def exchange_code_for_tokens(code_data: Dict[str, str]) -> Dict[str, str]:
"""Exchange the authorization code for access/refresh/id tokens."""
redirect_uri = f"{CHATGPT_AUTH_BASE}/deviceauth/callback"
body = (
"grant_type=authorization_code"
f"&code={code_data['authorization_code']}"
f"&redirect_uri={redirect_uri}"
f"&client_id={CHATGPT_CLIENT_ID}"
f"&code_verifier={code_data['code_verifier']}"
)
resp = requests.post(
CHATGPT_OAUTH_TOKEN_URL,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data=body,
)
resp.raise_for_status()
data = resp.json()
if not all(key in data for key in ("access_token", "refresh_token", "id_token")):
raise RuntimeError(f"Token exchange response missing fields: {data}")
return {
"access_token": data["access_token"],
"refresh_token": data["refresh_token"],
"id_token": data["id_token"],
}
def login() -> Dict[str, Any]:
"""Run the full device code login flow. Returns the auth record."""
device_code = request_device_code()
auth_record = _build_auth_record({})
auth_record["device_code_requested_at"] = time.time()
save_auth(auth_record)
print(
"\nSign in with your ChatGPT account:\n"
f" 1) Visit: {CHATGPT_DEVICE_VERIFY_URL}\n"
f" 2) Enter code: {device_code['user_code']}\n\n"
"Device codes are a common phishing target. Never share this code.\n",
flush=True,
)
code_data = poll_for_authorization(device_code)
tokens = exchange_code_for_tokens(code_data)
auth_record = _build_auth_record(tokens)
save_auth(auth_record)
return auth_record
def get_access_token() -> Tuple[str, Optional[str]]:
"""
Get a valid access token and account ID.
Refreshes automatically if expired. Raises if no auth data exists.
Returns (access_token, account_id).
"""
auth_data = load_auth()
if not auth_data:
raise RuntimeError(
"No ChatGPT credentials found. Run 'planoai chatgpt login' first."
)
access_token = auth_data.get("access_token")
if access_token and not _is_token_expired(auth_data):
return access_token, auth_data.get("account_id")
# Try refresh
refresh_token = auth_data.get("refresh_token")
if refresh_token:
tokens = _refresh_tokens(refresh_token)
auth_record = _build_auth_record(tokens)
save_auth(auth_record)
return auth_record["access_token"], auth_record.get("account_id")
raise RuntimeError(
"ChatGPT token expired and refresh failed. Run 'planoai chatgpt login' again."
)
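The `_decode_jwt_claims` helper in the new module reads claims such as `exp` and the ChatGPT account ID without verifying the signature. A standalone sketch of that decoding step, with an unsigned demo token built locally to exercise it:

```python
import base64
import json

def decode_jwt_claims(token: str) -> dict:
    # Split the compact JWS form (header.payload.signature), re-pad the
    # base64url payload, and decode it WITHOUT signature verification.
    # Fine for reading local metadata like `exp`; never for trusting input.
    parts = token.split(".")
    if len(parts) < 2:
        return {}
    payload = parts[1] + "=" * (-len(parts[1]) % 4)
    return json.loads(base64.urlsafe_b64decode(payload).decode("utf-8"))

# Build an unsigned demo token to exercise the decoder.
claims = {"exp": 1800000000, "sub": "demo"}
payload_b64 = (
    base64.urlsafe_b64encode(json.dumps(claims).encode()).rstrip(b"=").decode()
)
token = f"header.{payload_b64}.sig"
print(decode_jwt_claims(token)["exp"])  # → 1800000000
```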


@@ -0,0 +1,86 @@
"""
CLI commands for ChatGPT subscription management.
Usage:
planoai chatgpt login - Authenticate with ChatGPT via device code flow
planoai chatgpt status - Check authentication status
planoai chatgpt logout - Remove stored credentials
"""
import datetime
import click
from rich.console import Console
from planoai import chatgpt_auth
console = Console()
@click.group()
def chatgpt():
"""ChatGPT subscription management."""
pass
@chatgpt.command()
def login():
"""Authenticate with your ChatGPT subscription using device code flow."""
try:
auth_record = chatgpt_auth.login()
account_id = auth_record.get("account_id", "unknown")
console.print(
f"\n[green]Successfully authenticated with ChatGPT![/green]"
f"\nAccount ID: {account_id}"
f"\nCredentials saved to: {chatgpt_auth.CHATGPT_AUTH_FILE}"
)
except Exception as e:
console.print(f"\n[red]Authentication failed:[/red] {e}")
raise SystemExit(1)
@chatgpt.command()
def status():
"""Check ChatGPT authentication status."""
auth_data = chatgpt_auth.load_auth()
if not auth_data or not auth_data.get("access_token"):
console.print(
"[yellow]Not authenticated.[/yellow] Run 'planoai chatgpt login'."
)
return
account_id = auth_data.get("account_id", "unknown")
expires_at = auth_data.get("expires_at")
if expires_at:
expiry_time = datetime.datetime.fromtimestamp(
expires_at, tz=datetime.timezone.utc
)
now = datetime.datetime.now(tz=datetime.timezone.utc)
if expiry_time > now:
remaining = expiry_time - now
console.print(
f"[green]Authenticated[/green]"
f"\n Account ID: {account_id}"
f"\n Token expires: {expiry_time.strftime('%Y-%m-%d %H:%M:%S UTC')}"
f" ({remaining.seconds // 60}m remaining)"
)
else:
console.print(
f"[yellow]Token expired[/yellow]"
f"\n Account ID: {account_id}"
f"\n Expired at: {expiry_time.strftime('%Y-%m-%d %H:%M:%S UTC')}"
f"\n Will auto-refresh on next use, or run 'planoai chatgpt login'."
)
else:
console.print(
f"[green]Authenticated[/green] (no expiry info)"
f"\n Account ID: {account_id}"
)
@chatgpt.command()
def logout():
"""Remove stored ChatGPT credentials."""
chatgpt_auth.delete_auth()
console.print("[green]ChatGPT credentials removed.[/green]")
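The `status` command above reports the minutes remaining before token expiry. A hedged sketch (not from the diff) of that arithmetic, assuming a Unix `expires_at` timestamp:

```python
import datetime

def remaining_minutes(expires_at, now=None):
    # Whole minutes until a Unix `expires_at` timestamp (negative if past).
    now = now or datetime.datetime.now(tz=datetime.timezone.utc)
    expiry = datetime.datetime.fromtimestamp(expires_at, tz=datetime.timezone.utc)
    return int((expiry - now).total_seconds() // 60)
```

Note that `timedelta.total_seconds()` counts across day boundaries, whereas the `seconds` attribute of a `timedelta` holds only the sub-day component and wraps every 24 hours.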


@@ -1,20 +1,20 @@
import json
import os
import uuid
from planoai.utils import convert_legacy_listeners
from jinja2 import Environment, FileSystemLoader
import yaml
from jsonschema import validate
from jsonschema import validate, ValidationError
from urllib.parse import urlparse
from copy import deepcopy
from planoai.consts import DEFAULT_OTEL_TRACING_GRPC_ENDPOINT
SUPPORTED_PROVIDERS_WITH_BASE_URL = [
"azure_openai",
"ollama",
"qwen",
"amazon_bedrock",
"arch",
"plano",
]
SUPPORTED_PROVIDERS_WITHOUT_BASE_URL = [
@@ -22,14 +22,23 @@ SUPPORTED_PROVIDERS_WITHOUT_BASE_URL = [
"groq",
"mistral",
"openai",
"xiaomi",
"gemini",
"anthropic",
"together_ai",
"xai",
"moonshotai",
"zhipu",
"chatgpt",
"digitalocean",
"vercel",
"openrouter",
]
CHATGPT_API_BASE = "https://chatgpt.com/backend-api/codex"
CHATGPT_DEFAULT_ORIGINATOR = "codex_cli_rs"
CHATGPT_DEFAULT_USER_AGENT = "codex_cli_rs/0.0.0 (Unknown 0; unknown) unknown"
SUPPORTED_PROVIDERS = (
SUPPORTED_PROVIDERS_WITHOUT_BASE_URL + SUPPORTED_PROVIDERS_WITH_BASE_URL
)
@@ -49,6 +58,110 @@ def get_endpoint_and_port(endpoint, protocol):
return endpoint, port
def migrate_inline_routing_preferences(config_yaml):
"""Lift v0.3.0-style inline ``routing_preferences`` under each
``model_providers`` entry to the v0.4.0 top-level ``routing_preferences``
list with ``models: [...]``.
This function is a no-op for configs whose ``version`` is already
``v0.4.0`` or newer; those are assumed to be on the canonical
top-level shape and are passed through untouched.
For older configs, the version is bumped to ``v0.4.0`` up front so
brightstaff's v0.4.0 gate for top-level ``routing_preferences``
accepts the rendered config, then inline preferences under each
provider are lifted into the top-level list. Preferences with the
same ``name`` across multiple providers are merged into a single
top-level entry whose ``models`` list contains every provider's
full ``<provider>/<model>`` string in declaration order. The first
``description`` encountered wins; conflicts are warned, not errored,
so existing v0.3.0 configs keep compiling. Any top-level preference
already defined by the user is preserved as-is.
"""
current_version = str(config_yaml.get("version", ""))
if _version_tuple(current_version) >= (0, 4, 0):
return
config_yaml["version"] = "v0.4.0"
model_providers = config_yaml.get("model_providers") or []
if not model_providers:
return
migrated = {}
for model_provider in model_providers:
inline_prefs = model_provider.get("routing_preferences")
if not inline_prefs:
continue
full_model_name = model_provider.get("model")
if not full_model_name:
continue
if "/" in full_model_name and full_model_name.split("/")[-1].strip() == "*":
raise Exception(
f"Model {full_model_name} has routing_preferences but uses wildcard (*). Models with routing preferences cannot be wildcards."
)
for pref in inline_prefs:
name = pref.get("name")
description = pref.get("description", "")
if not name:
continue
if name in migrated:
entry = migrated[name]
if description and description != entry["description"]:
print(
f"WARNING: routing preference '{name}' has conflicting descriptions across providers; keeping the first one."
)
if full_model_name not in entry["models"]:
entry["models"].append(full_model_name)
else:
migrated[name] = {
"name": name,
"description": description,
"models": [full_model_name],
}
if not migrated:
return
for model_provider in model_providers:
if "routing_preferences" in model_provider:
del model_provider["routing_preferences"]
existing_top_level = config_yaml.get("routing_preferences") or []
existing_names = {entry.get("name") for entry in existing_top_level}
merged = list(existing_top_level)
for name, entry in migrated.items():
if name in existing_names:
continue
merged.append(entry)
config_yaml["routing_preferences"] = merged
print(
"WARNING: inline routing_preferences under model_providers is deprecated "
"and has been auto-migrated to top-level routing_preferences. Update your "
"config to v0.4.0 top-level form. See docs/routing-api.md"
)
def _version_tuple(version_string):
stripped = version_string.strip().lstrip("vV")
if not stripped:
return (0, 0, 0)
parts = stripped.split("-", 1)[0].split(".")
out = []
for part in parts[:3]:
try:
out.append(int(part))
except ValueError:
out.append(0)
while len(out) < 3:
out.append(0)
return tuple(out)
def validate_and_render_schema():
ENVOY_CONFIG_TEMPLATE_FILE = os.getenv(
"ENVOY_CONFIG_TEMPLATE_FILE", "envoy.template.yaml"
@@ -92,6 +205,8 @@ def validate_and_render_schema():
config_yaml["model_providers"] = config_yaml["llm_providers"]
del config_yaml["llm_providers"]
migrate_inline_routing_preferences(config_yaml)
listeners, llm_gateway, prompt_gateway = convert_legacy_listeners(
config_yaml.get("listeners"), config_yaml.get("model_providers")
)
@@ -191,7 +306,16 @@ def validate_and_render_schema():
model_provider_name_set = set()
llms_with_usage = []
model_name_keys = set()
model_usage_name_keys = set()
top_level_preferences = config_yaml.get("routing_preferences") or []
seen_pref_names = set()
for pref in top_level_preferences:
pref_name = pref.get("name")
if pref_name in seen_pref_names:
raise Exception(
f'Duplicate routing preference name "{pref_name}", please provide unique name for each routing preference'
)
seen_pref_names.add(pref_name)
print("listeners: ", listeners)
@@ -250,10 +374,6 @@ def validate_and_render_schema():
raise Exception(
f"Model {model_name} is configured as default but uses wildcard (*). Default models cannot be wildcards."
)
if model_provider.get("routing_preferences"):
raise Exception(
f"Model {model_name} has routing_preferences but uses wildcard (*). Models with routing preferences cannot be wildcards."
)
# Validate azure_openai and ollama provider requires base_url
if (provider in SUPPORTED_PROVIDERS_WITH_BASE_URL) and model_provider.get(
@@ -302,13 +422,6 @@ def validate_and_render_schema():
)
model_name_keys.add(model_id)
for routing_preference in model_provider.get("routing_preferences", []):
if routing_preference.get("name") in model_usage_name_keys:
raise Exception(
f'Duplicate routing preference name "{routing_preference.get("name")}", please provide unique name for each routing preference'
)
model_usage_name_keys.add(routing_preference.get("name"))
# Warn if both passthrough_auth and access_key are configured
if model_provider.get("passthrough_auth") and model_provider.get(
"access_key"
@@ -331,6 +444,25 @@ def validate_and_render_schema():
provider = model_provider["provider"]
model_provider["provider_interface"] = provider
del model_provider["provider"]
# Auto-wire ChatGPT provider: inject base_url, passthrough_auth, and extra headers
if provider == "chatgpt":
if not model_provider.get("base_url"):
model_provider["base_url"] = CHATGPT_API_BASE
if not model_provider.get("access_key") and not model_provider.get(
"passthrough_auth"
):
model_provider["passthrough_auth"] = True
headers = model_provider.get("headers", {})
headers.setdefault(
"ChatGPT-Account-Id",
os.environ.get("CHATGPT_ACCOUNT_ID", ""),
)
headers.setdefault("originator", CHATGPT_DEFAULT_ORIGINATOR)
headers.setdefault("user-agent", CHATGPT_DEFAULT_USER_AGENT)
headers.setdefault("session_id", str(uuid.uuid4()))
model_provider["headers"] = headers
updated_model_providers.append(model_provider)
if model_provider.get("base_url", None):
@@ -368,47 +500,51 @@ def validate_and_render_schema():
llms_with_endpoint.append(model_provider)
llms_with_endpoint_cluster_names.add(cluster_name)
if len(model_usage_name_keys) > 0:
routing_model_provider = config_yaml.get("routing", {}).get(
"model_provider", None
overrides_config = config_yaml.get("overrides", {})
# Build lookup of model names (already prefix-stripped by config processing)
model_name_set = {mp.get("model") for mp in updated_model_providers}
# Auto-add plano-orchestrator provider if routing preferences exist and no provider matches the routing model
router_model = overrides_config.get("llm_routing_model", "Plano-Orchestrator")
router_model_id = (
router_model.split("/", 1)[1] if "/" in router_model else router_model
)
if len(seen_pref_names) > 0 and router_model_id not in model_name_set:
updated_model_providers.append(
{
"name": "plano-orchestrator",
"provider_interface": "plano",
"model": router_model_id,
"internal": True,
}
)
if (
routing_model_provider
and routing_model_provider not in model_provider_name_set
):
raise Exception(
f"Routing model_provider {routing_model_provider} is not defined in model_providers"
)
if (
routing_model_provider is None
and "arch-router" not in model_provider_name_set
):
updated_model_providers.append(
{
"name": "arch-router",
"provider_interface": "arch",
"model": config_yaml.get("routing", {}).get("model", "Arch-Router"),
"internal": True,
}
)
# Always add arch-function model provider if not already defined
if "arch-function" not in model_provider_name_set:
updated_model_providers.append(
{
"name": "arch-function",
"provider_interface": "arch",
"provider_interface": "plano",
"model": "Arch-Function",
"internal": True,
}
)
if "plano-orchestrator" not in model_provider_name_set:
# Auto-add plano-orchestrator provider if no provider matches the orchestrator model
orchestrator_model = overrides_config.get(
"agent_orchestration_model", "Plano-Orchestrator"
)
orchestrator_model_id = (
orchestrator_model.split("/", 1)[1]
if "/" in orchestrator_model
else orchestrator_model
)
if orchestrator_model_id not in model_name_set:
updated_model_providers.append(
{
"name": "plano-orchestrator",
"provider_interface": "arch",
"model": "Plano-Orchestrator",
"name": "plano/orchestrator",
"provider_interface": "plano",
"model": orchestrator_model_id,
"internal": True,
}
)
@@ -426,6 +562,16 @@ def validate_and_render_schema():
"Please provide model_providers either under listeners or at root level, not both. Currently we don't support multiple listeners with model_providers"
)
# Validate input_filters IDs on listeners reference valid agent/filter IDs
for listener in listeners:
listener_input_filters = listener.get("input_filters", [])
for fc_id in listener_input_filters:
if fc_id not in agent_id_keys:
raise Exception(
f"Listener '{listener.get('name', 'unknown')}' references input_filters id '{fc_id}' "
f"which is not defined in agents or filters. Available ids: {', '.join(sorted(agent_id_keys))}"
)
# Validate model aliases if present
if "model_aliases" in config_yaml:
model_aliases = config_yaml["model_aliases"]
@@ -503,11 +649,15 @@ def validate_prompt_config(plano_config_file, plano_config_schema_file):
try:
validate(config_yaml, config_schema_yaml)
except Exception as e:
print(
f"Error validating plano_config file: {plano_config_file}, schema file: {plano_config_schema_file}, error: {e}"
except ValidationError as e:
path = (
"".join(str(p) for p in e.absolute_path) if e.absolute_path else "root"
)
raise e
raise ValidationError(
f"{e.message}\n Location: {path}\n Value: {e.instance}"
) from None
except Exception as e:
raise
if __name__ == "__main__":
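The migration above is gated on `_version_tuple`, which normalizes strings like `"v0.4.0-rc1"` before comparison. A runnable sketch mirroring that helper, with the gate check the migration performs:

```python
def version_tuple(version_string: str) -> tuple:
    # Mirrors the `_version_tuple` helper in the diff: strip a leading v/V,
    # drop any pre-release suffix after '-', and pad to three integers.
    stripped = version_string.strip().lstrip("vV")
    if not stripped:
        return (0, 0, 0)
    parts = stripped.split("-", 1)[0].split(".")
    out = []
    for part in parts[:3]:
        try:
            out.append(int(part))
        except ValueError:
            out.append(0)
    while len(out) < 3:
        out.append(0)
    return tuple(out)

# The migration gate: configs at v0.4.0 or newer pass through untouched.
print(version_tuple("v0.4.0-rc1") >= (0, 4, 0))  # → True
print(version_tuple("v0.3.0") >= (0, 4, 0))      # → False
```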


@@ -5,7 +5,7 @@ PLANO_COLOR = "#969FF4"
SERVICE_NAME_ARCHGW = "plano"
PLANO_DOCKER_NAME = "plano"
PLANO_DOCKER_IMAGE = os.getenv("PLANO_DOCKER_IMAGE", "katanemo/plano:0.4.11")
PLANO_DOCKER_IMAGE = os.getenv("PLANO_DOCKER_IMAGE", "katanemo/plano:0.4.21")
DEFAULT_OTEL_TRACING_GRPC_ENDPOINT = "http://localhost:4317"
# Native mode constants


@@ -10,7 +10,6 @@ from planoai.consts import (
PLANO_DOCKER_IMAGE,
PLANO_DOCKER_NAME,
)
import subprocess
from planoai.docker_cli import (
docker_container_status,
docker_remove_container,
@@ -20,7 +19,6 @@ from planoai.docker_cli import (
stream_gateway_logs,
)
log = getLogger(__name__)
@@ -147,26 +145,48 @@ def stop_docker_container(service=PLANO_DOCKER_NAME):
log.info(f"Failed to shut down services: {str(e)}")
def start_cli_agent(plano_config_file=None, settings_json="{}"):
"""Start a CLI client connected to Plano."""
with open(plano_config_file, "r") as file:
plano_config = file.read()
plano_config_yaml = yaml.safe_load(plano_config)
# Get egress listener configuration
egress_config = plano_config_yaml.get("listeners", {}).get("egress_traffic", {})
host = egress_config.get("host", "127.0.0.1")
port = egress_config.get("port", 12000)
# Parse additional settings from command line
def _parse_cli_agent_settings(settings_json: str) -> dict:
try:
additional_settings = json.loads(settings_json) if settings_json else {}
return json.loads(settings_json) if settings_json else {}
except json.JSONDecodeError:
log.error("Settings must be valid JSON")
sys.exit(1)
# Set up environment variables
def _resolve_cli_agent_endpoint(plano_config_yaml: dict) -> tuple[str, int]:
listeners = plano_config_yaml.get("listeners")
if isinstance(listeners, dict):
egress_config = listeners.get("egress_traffic", {})
host = egress_config.get("host") or egress_config.get("address") or "0.0.0.0"
port = egress_config.get("port", 12000)
return host, port
if isinstance(listeners, list):
for listener in listeners:
if listener.get("type") == "model":
host = listener.get("host") or listener.get("address") or "0.0.0.0"
port = listener.get("port", 12000)
return host, port
return "0.0.0.0", 12000
def _apply_non_interactive_env(env: dict, additional_settings: dict) -> None:
if additional_settings.get("NON_INTERACTIVE_MODE", False):
env.update(
{
"CI": "true",
"FORCE_COLOR": "0",
"NODE_NO_READLINE": "1",
"TERM": "dumb",
}
)
def _start_claude_cli_agent(
host: str, port: int, plano_config_yaml: dict, additional_settings: dict
) -> None:
env = os.environ.copy()
env.update(
{
@@ -186,7 +206,6 @@ def start_cli_agent(plano_config_file=None, settings_json="{}"):
"ANTHROPIC_SMALL_FAST_MODEL"
]
else:
# Check if arch.claude.code.small.fast alias exists in model_aliases
model_aliases = plano_config_yaml.get("model_aliases", {})
if "arch.claude.code.small.fast" in model_aliases:
env["ANTHROPIC_SMALL_FAST_MODEL"] = "arch.claude.code.small.fast"
@@ -196,23 +215,10 @@ def start_cli_agent(plano_config_file=None, settings_json="{}"):
)
log.info("Or provide ANTHROPIC_SMALL_FAST_MODEL in --settings JSON")
# Non-interactive mode configuration from additional_settings only
if additional_settings.get("NON_INTERACTIVE_MODE", False):
env.update(
{
"CI": "true",
"FORCE_COLOR": "0",
"NODE_NO_READLINE": "1",
"TERM": "dumb",
}
)
_apply_non_interactive_env(env, additional_settings)
# Build claude command arguments
claude_args = []
# Add settings if provided, excluding those already handled as environment variables
if additional_settings:
# Filter out settings that are already processed as environment variables
claude_settings = {
k: v
for k, v in additional_settings.items()
@@ -221,10 +227,8 @@ def start_cli_agent(plano_config_file=None, settings_json="{}"):
if claude_settings:
claude_args.append(f"--settings={json.dumps(claude_settings)}")
# Use claude from PATH
claude_path = "claude"
log.info(f"Connecting Claude Code Agent to Plano at {host}:{port}")
try:
subprocess.run([claude_path] + claude_args, env=env, check=True)
except subprocess.CalledProcessError as e:
@@ -235,3 +239,61 @@ def start_cli_agent(plano_config_file=None, settings_json="{}"):
f"{claude_path} not found. Make sure Claude Code is installed: npm install -g @anthropic-ai/claude-code"
)
sys.exit(1)
def _start_codex_cli_agent(host: str, port: int, additional_settings: dict) -> None:
env = os.environ.copy()
env.update(
{
"OPENAI_API_KEY": "test", # Use test token for plano
"OPENAI_BASE_URL": f"http://{host}:{port}/v1",
"NO_PROXY": host,
"DISABLE_TELEMETRY": "true",
}
)
_apply_non_interactive_env(env, additional_settings)
codex_model = additional_settings.get("CODEX_MODEL", "gpt-5.3-codex")
codex_path = "codex"
codex_args = ["--model", codex_model]
log.info(
f"Connecting Codex CLI Agent to Plano at {host}:{port} (default model: {codex_model})"
)
try:
subprocess.run([codex_path] + codex_args, env=env, check=True)
except subprocess.CalledProcessError as e:
log.error(f"Error starting codex: {e}")
sys.exit(1)
except FileNotFoundError:
log.error(
f"{codex_path} not found. Make sure Codex CLI is installed: npm install -g @openai/codex"
)
sys.exit(1)
def start_cli_agent(
plano_config_file=None, cli_agent_type="claude", settings_json="{}"
):
"""Start a CLI client connected to Plano."""
with open(plano_config_file, "r") as file:
plano_config = file.read()
plano_config_yaml = yaml.safe_load(plano_config)
host, port = _resolve_cli_agent_endpoint(plano_config_yaml)
additional_settings = _parse_cli_agent_settings(settings_json)
if cli_agent_type == "claude":
_start_claude_cli_agent(host, port, plano_config_yaml, additional_settings)
return
if cli_agent_type == "codex":
_start_codex_cli_agent(host, port, additional_settings)
return
log.error(
f"Unsupported cli agent type '{cli_agent_type}'. Supported values: claude, codex"
)
sys.exit(1)

cli/planoai/defaults.py (new file, +178 lines)

@@ -0,0 +1,178 @@
"""Default config synthesizer for zero-config ``planoai up``.
When the user runs ``planoai up`` in a directory with no ``config.yaml`` /
``plano_config.yaml``, we synthesize a pass-through config that covers the
common LLM providers and auto-wires OTel export to ``localhost:4317`` so
``planoai obs`` works out of the box.
Auth handling:
- If the provider's env var is set, bind ``access_key: $ENV_VAR``.
- Otherwise set ``passthrough_auth: true`` so the client's own Authorization
header is forwarded. No env var is required to start the proxy.
"""
from __future__ import annotations
import os
from dataclasses import dataclass
DEFAULT_LLM_LISTENER_PORT = 12000
# plano_config validation requires an http:// scheme on the OTLP endpoint.
DEFAULT_OTLP_ENDPOINT = "http://localhost:4317"
@dataclass(frozen=True)
class ProviderDefault:
name: str
env_var: str
base_url: str
model_pattern: str
# Only set for providers whose prefix in the model pattern is NOT one of the
# built-in SUPPORTED_PROVIDERS in cli/planoai/config_generator.py. For
# built-ins, the validator infers the interface from the model prefix and
# rejects configs that set this field explicitly.
provider_interface: str | None = None
# Keep ordering stable so synthesized configs diff cleanly across runs.
PROVIDER_DEFAULTS: list[ProviderDefault] = [
ProviderDefault(
name="openai",
env_var="OPENAI_API_KEY",
base_url="https://api.openai.com/v1",
model_pattern="openai/*",
),
ProviderDefault(
name="anthropic",
env_var="ANTHROPIC_API_KEY",
base_url="https://api.anthropic.com/v1",
model_pattern="anthropic/*",
),
ProviderDefault(
name="gemini",
env_var="GEMINI_API_KEY",
base_url="https://generativelanguage.googleapis.com/v1beta",
model_pattern="gemini/*",
),
ProviderDefault(
name="groq",
env_var="GROQ_API_KEY",
base_url="https://api.groq.com/openai/v1",
model_pattern="groq/*",
),
ProviderDefault(
name="deepseek",
env_var="DEEPSEEK_API_KEY",
base_url="https://api.deepseek.com/v1",
model_pattern="deepseek/*",
),
ProviderDefault(
name="mistral",
env_var="MISTRAL_API_KEY",
base_url="https://api.mistral.ai/v1",
model_pattern="mistral/*",
),
# DigitalOcean Gradient is a first-class provider post-#889 — the
# `digitalocean/` model prefix routes to the built-in Envoy cluster, no
# base_url needed at runtime.
ProviderDefault(
name="digitalocean",
env_var="DO_API_KEY",
base_url="https://inference.do-ai.run/v1",
model_pattern="digitalocean/*",
),
ProviderDefault(
name="vercel",
env_var="AI_GATEWAY_API_KEY",
base_url="https://ai-gateway.vercel.sh/v1",
model_pattern="vercel/*",
),
# OpenRouter is a first-class provider — the `openrouter/` model prefix is
# accepted by the schema and brightstaff's ProviderId parser, so no
# provider_interface override is needed.
ProviderDefault(
name="openrouter",
env_var="OPENROUTER_API_KEY",
base_url="https://openrouter.ai/api/v1",
model_pattern="openrouter/*",
),
]
@dataclass
class DetectionResult:
with_keys: list[ProviderDefault]
passthrough: list[ProviderDefault]
@property
def summary(self) -> str:
parts = []
if self.with_keys:
parts.append("env-keyed: " + ", ".join(p.name for p in self.with_keys))
if self.passthrough:
parts.append("pass-through: " + ", ".join(p.name for p in self.passthrough))
return " | ".join(parts) if parts else "no providers"
def detect_providers(env: dict[str, str] | None = None) -> DetectionResult:
env = env if env is not None else dict(os.environ)
with_keys: list[ProviderDefault] = []
passthrough: list[ProviderDefault] = []
for p in PROVIDER_DEFAULTS:
val = env.get(p.env_var)
if val:
with_keys.append(p)
else:
passthrough.append(p)
return DetectionResult(with_keys=with_keys, passthrough=passthrough)
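The partition performed by `detect_providers` can be sketched standalone; the two-provider table below is illustrative (the real list is `PROVIDER_DEFAULTS`):

```python
# Illustrative two-provider table; the real names/env vars live in PROVIDER_DEFAULTS.
PROVIDERS = [("openai", "OPENAI_API_KEY"), ("anthropic", "ANTHROPIC_API_KEY")]

def partition_providers(env: dict) -> tuple[list, list]:
    """Split providers into env-keyed and pass-through, as detect_providers does."""
    with_keys = [name for name, var in PROVIDERS if env.get(var)]
    passthrough = [name for name, var in PROVIDERS if not env.get(var)]
    return with_keys, passthrough

print(partition_providers({"OPENAI_API_KEY": "sk-test"}))
# → (['openai'], ['anthropic'])
```

A provider whose env var is unset is never an error here; it simply lands in the pass-through bucket.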
def synthesize_default_config(
env: dict[str, str] | None = None,
*,
listener_port: int = DEFAULT_LLM_LISTENER_PORT,
otel_endpoint: str = DEFAULT_OTLP_ENDPOINT,
) -> dict:
"""Build a pass-through config dict suitable for validation + envoy rendering.
The returned dict can be dumped to YAML and handed to the existing `planoai up`
pipeline unchanged.
"""
detection = detect_providers(env)
def _entry(p: ProviderDefault, base: dict) -> dict:
row: dict = {"name": p.name, "model": p.model_pattern, "base_url": p.base_url}
if p.provider_interface is not None:
row["provider_interface"] = p.provider_interface
row.update(base)
return row
model_providers: list[dict] = []
for p in detection.with_keys:
model_providers.append(_entry(p, {"access_key": f"${p.env_var}"}))
for p in detection.passthrough:
model_providers.append(_entry(p, {"passthrough_auth": True}))
# No explicit `default: true` entry is synthesized: the plano config
# validator rejects wildcard models as defaults, and brightstaff already
# registers bare model names as lookup keys during wildcard expansion
# (crates/common/src/llm_providers.rs), so `{"model": "gpt-4o-mini"}`
# without a prefix resolves via the openai wildcard without needing
# `default: true`. See discussion on #890.
return {
"version": "v0.4.0",
"listeners": [
{
"name": "llm",
"type": "model",
"port": listener_port,
"address": "0.0.0.0",
}
],
"model_providers": model_providers,
"tracing": {
"random_sampling": 100,
"opentracing_grpc_endpoint": otel_endpoint,
},
}
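The auth-handling rule from the module docstring (bind `access_key: $ENV_VAR` when the key is present, else `passthrough_auth: true`) reduces to a small helper; this is a sketch of the `_entry` shaping, not the full synthesizer:

```python
def provider_entry(name: str, pattern: str, base_url: str, env_var: str, env: dict) -> dict:
    """One model_providers row: bind access_key when the env var is set,
    otherwise fall back to passthrough_auth so the client's own
    Authorization header is forwarded."""
    row = {"name": name, "model": pattern, "base_url": base_url}
    if env.get(env_var):
        row["access_key"] = f"${env_var}"  # a reference, resolved at startup
    else:
        row["passthrough_auth"] = True
    return row

env = {"OPENAI_API_KEY": "sk-test"}
keyed = provider_entry("openai", "openai/*", "https://api.openai.com/v1", "OPENAI_API_KEY", env)
anon = provider_entry("mistral", "mistral/*", "https://api.mistral.ai/v1", "MISTRAL_API_KEY", env)
print(keyed["access_key"])       # $OPENAI_API_KEY
print(anon["passthrough_auth"])  # True
```

Note that the synthesized YAML never inlines the secret itself, only the `$ENV_VAR` reference.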


@@ -1,9 +1,18 @@
import json
import os
import multiprocessing
import subprocess
import sys
import contextlib
import logging
import rich_click as click
import yaml
from planoai import targets
from planoai.defaults import (
DEFAULT_LLM_LISTENER_PORT,
detect_providers,
synthesize_default_config,
)
# Brand color - Plano purple
PLANO_COLOR = "#969FF4"
@@ -28,9 +37,13 @@ from planoai.core import (
)
from planoai.init_cmd import init as init_cmd
from planoai.trace_cmd import trace as trace_cmd, start_trace_listener_background
from planoai.chatgpt_cmd import chatgpt as chatgpt_cmd
from planoai.obs_cmd import obs as obs_cmd
from planoai.consts import (
DEFAULT_OTEL_TRACING_GRPC_ENDPOINT,
DEFAULT_NATIVE_OTEL_TRACING_GRPC_ENDPOINT,
NATIVE_PID_FILE,
PLANO_RUN_DIR,
PLANO_DOCKER_IMAGE,
PLANO_DOCKER_NAME,
)
@@ -40,6 +53,30 @@ from planoai.versioning import check_version_status, get_latest_version, get_ver
log = getLogger(__name__)
def _is_native_plano_running() -> bool:
if not os.path.exists(NATIVE_PID_FILE):
return False
try:
with open(NATIVE_PID_FILE, "r") as f:
pids = json.load(f)
except (OSError, json.JSONDecodeError):
return False
envoy_pid = pids.get("envoy_pid")
brightstaff_pid = pids.get("brightstaff_pid")
if not isinstance(envoy_pid, int) or not isinstance(brightstaff_pid, int):
return False
for pid in (envoy_pid, brightstaff_pid):
try:
os.kill(pid, 0)
except ProcessLookupError:
return False
except PermissionError:
continue
return True
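The liveness probe above relies on `os.kill(pid, 0)`: signal 0 delivers nothing but still performs the existence and permission checks. A minimal sketch of that idiom:

```python
import os

def pid_alive(pid: int) -> bool:
    """Signal 0 delivers no signal but still reports whether the PID exists.
    PermissionError means the process exists under another user, which the
    PID-file check above also treats as alive."""
    try:
        os.kill(pid, 0)
    except ProcessLookupError:
        return False
    except PermissionError:
        return True
    return True

print(pid_alive(os.getpid()))  # True: the current process certainly exists
```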
def _is_port_in_use(port: int) -> bool:
"""Check if a TCP port is already bound on localhost."""
import socket
@@ -75,6 +112,42 @@ def _print_cli_header(console) -> None:
)
@contextlib.contextmanager
def _temporary_cli_log_level(level: str | None):
if level is None:
yield
return
current_level = logging.getLevelName(logging.getLogger().level).lower()
set_log_level(level)
try:
yield
finally:
set_log_level(current_level)
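`_temporary_cli_log_level` follows the save/override/restore context-manager shape. A self-contained sketch against the stdlib root logger (the real helper goes through Plano's `set_log_level` instead):

```python
import contextlib
import logging

@contextlib.contextmanager
def temporary_log_level(level):
    """Save the root logger level, override it, restore on exit; None is a no-op."""
    root = logging.getLogger()
    if level is None:
        yield
        return
    previous = root.level
    root.setLevel(level)
    try:
        yield
    finally:
        root.setLevel(previous)

logging.getLogger().setLevel(logging.INFO)
with temporary_log_level(logging.WARNING):
    inside = logging.getLogger().level
print(inside, logging.getLogger().level)  # 30 20
```

The `try/finally` guarantees the previous level is restored even if the body raises.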
def _inject_chatgpt_tokens_if_needed(config, env, console):
"""If config uses chatgpt providers, resolve tokens from ~/.plano/chatgpt/auth.json."""
providers = config.get("model_providers") or config.get("llm_providers") or []
has_chatgpt = any(str(p.get("model", "")).startswith("chatgpt/") for p in providers)
if not has_chatgpt:
return
try:
from planoai.chatgpt_auth import get_access_token
access_token, account_id = get_access_token()
env["CHATGPT_ACCESS_TOKEN"] = access_token
if account_id:
env["CHATGPT_ACCOUNT_ID"] = account_id
except Exception as e:
console.print(
f"\n[red]ChatGPT auth error:[/red] {e}\n"
f"[dim]Run 'planoai chatgpt login' to authenticate.[/dim]\n"
)
sys.exit(1)
def _print_missing_keys(console, missing_keys: list[str]) -> None:
console.print(f"\n[red]✗[/red] [red]Missing API keys![/red]\n")
for key in missing_keys:
@@ -267,163 +340,239 @@ def build(docker):
help="Run Plano inside Docker instead of natively.",
is_flag=True,
)
def up(file, path, foreground, with_tracing, tracing_port, docker):
@click.option(
"--verbose",
"-v",
default=False,
help="Show detailed startup logs with timestamps.",
is_flag=True,
)
@click.option(
"--listener-port",
default=DEFAULT_LLM_LISTENER_PORT,
type=int,
show_default=True,
help="Override the LLM listener port when running without a config file. Ignored when a config file is present.",
)
def up(
file,
path,
foreground,
with_tracing,
tracing_port,
docker,
verbose,
listener_port,
):
"""Starts Plano."""
from rich.status import Status
console = _console()
_print_cli_header(console)
# Use the utility function to find config file
plano_config_file = find_config_file(path, file)
with _temporary_cli_log_level("warning" if not verbose else None):
# Use the utility function to find config file
plano_config_file = find_config_file(path, file)
# Check if the file exists
if not os.path.exists(plano_config_file):
console.print(
f"[red]✗[/red] Config file not found: [dim]{plano_config_file}[/dim]"
# Zero-config fallback: when no user config is present, synthesize a
# pass-through config that covers the common LLM providers and
# auto-wires OTel export to ``planoai obs``. See cli/planoai/defaults.py.
if not os.path.exists(plano_config_file):
detection = detect_providers()
cfg_dict = synthesize_default_config(listener_port=listener_port)
default_dir = os.path.expanduser("~/.plano")
os.makedirs(default_dir, exist_ok=True)
synthesized_path = os.path.join(default_dir, "default_config.yaml")
with open(synthesized_path, "w") as fh:
yaml.safe_dump(cfg_dict, fh, sort_keys=False)
plano_config_file = synthesized_path
console.print(
f"[dim]No plano config found; using defaults ({detection.summary}). "
f"Listening on :{listener_port}, tracing -> http://localhost:4317.[/dim]"
)
if not docker:
from planoai.native_runner import native_validate_config
with Status(
"[dim]Validating configuration[/dim]",
spinner="dots",
spinner_style="dim",
):
try:
native_validate_config(plano_config_file)
except SystemExit:
console.print(f"[red]✗[/red] Validation failed")
sys.exit(1)
except Exception as e:
console.print(f"[red]✗[/red] Validation failed")
console.print(f" [dim]{str(e).strip()}[/dim]")
sys.exit(1)
else:
with Status(
"[dim]Validating configuration (Docker)[/dim]",
spinner="dots",
spinner_style="dim",
):
(
validation_return_code,
_,
validation_stderr,
) = docker_validate_plano_schema(plano_config_file)
if validation_return_code != 0:
console.print(f"[red]✗[/red] Validation failed")
if validation_stderr:
console.print(f" [dim]{validation_stderr.strip()}[/dim]")
sys.exit(1)
log.info("Configuration valid")
# Set up environment
default_otel = (
DEFAULT_OTEL_TRACING_GRPC_ENDPOINT
if docker
else DEFAULT_NATIVE_OTEL_TRACING_GRPC_ENDPOINT
)
sys.exit(1)
env_stage = {
"OTEL_TRACING_GRPC_ENDPOINT": default_otel,
}
env = os.environ.copy()
env.pop("PATH", None)
if not docker:
from planoai.native_runner import native_validate_config
import yaml
with Status(
"[dim]Validating configuration[/dim]",
spinner="dots",
spinner_style="dim",
):
try:
native_validate_config(plano_config_file)
except SystemExit:
console.print(f"[red]✗[/red] Validation failed")
sys.exit(1)
except Exception as e:
console.print(f"[red]✗[/red] Validation failed")
console.print(f" [dim]{str(e).strip()}[/dim]")
sys.exit(1)
else:
with Status(
"[dim]Validating configuration (Docker)[/dim]",
spinner="dots",
spinner_style="dim",
):
(
validation_return_code,
_,
validation_stderr,
) = docker_validate_plano_schema(plano_config_file)
with open(plano_config_file, "r") as f:
plano_config = yaml.safe_load(f)
if validation_return_code != 0:
console.print(f"[red]✗[/red] Validation failed")
if validation_stderr:
console.print(f" [dim]{validation_stderr.strip()}[/dim]")
# Inject ChatGPT tokens from ~/.plano/chatgpt/auth.json if any provider needs them
_inject_chatgpt_tokens_if_needed(plano_config, env, console)
# Check access keys
access_keys = get_llm_provider_access_keys(plano_config_file=plano_config_file)
access_keys = set(access_keys)
access_keys = [
item[1:] if item.startswith("$") else item for item in access_keys
]
missing_keys = []
if access_keys:
if file:
app_env_file = os.path.join(
os.path.dirname(os.path.abspath(file)), ".env"
)
else:
app_env_file = os.path.abspath(os.path.join(path, ".env"))
if not os.path.exists(app_env_file):
for access_key in access_keys:
if env.get(access_key) is None:
missing_keys.append(access_key)
else:
env_stage[access_key] = env.get(access_key)
else:
env_file_dict = load_env_file_to_dict(app_env_file)
for access_key in access_keys:
if env_file_dict.get(access_key) is None:
missing_keys.append(access_key)
else:
env_stage[access_key] = env_file_dict[access_key]
if missing_keys:
_print_missing_keys(console, missing_keys)
sys.exit(1)
log.info("Configuration valid")
# Pass log level to the Docker container — supervisord uses LOG_LEVEL
# to set RUST_LOG (brightstaff) and envoy component log levels
env_stage["LOG_LEVEL"] = os.environ.get("LOG_LEVEL", "info")
# Set up environment
default_otel = (
DEFAULT_OTEL_TRACING_GRPC_ENDPOINT
if docker
else DEFAULT_NATIVE_OTEL_TRACING_GRPC_ENDPOINT
)
env_stage = {
"OTEL_TRACING_GRPC_ENDPOINT": default_otel,
}
env = os.environ.copy()
env.pop("PATH", None)
# Start the local OTLP trace collector if --with-tracing is set
trace_server = None
if with_tracing:
if _is_port_in_use(tracing_port):
# A listener is already running (e.g. `planoai trace listen`)
console.print(
f"[green]✓[/green] Trace collector already running on port [cyan]{tracing_port}[/cyan]"
)
else:
try:
trace_server = start_trace_listener_background(
grpc_port=tracing_port
)
console.print(
f"[green]✓[/green] Trace collector listening on [cyan]0.0.0.0:{tracing_port}[/cyan]"
)
except Exception as e:
console.print(
f"[red]✗[/red] Failed to start trace collector on port {tracing_port}: {e}"
)
console.print(
f"\n[dim]Check if another process is using port {tracing_port}:[/dim]"
)
console.print(f" [cyan]lsof -i :{tracing_port}[/cyan]")
console.print(f"\n[dim]Or use a different port:[/dim]")
console.print(
f" [cyan]planoai up --with-tracing --tracing-port 4318[/cyan]\n"
)
sys.exit(1)
# Check access keys
access_keys = get_llm_provider_access_keys(plano_config_file=plano_config_file)
access_keys = set(access_keys)
access_keys = [item[1:] if item.startswith("$") else item for item in access_keys]
# Update the OTEL endpoint so the gateway sends traces to the right port
tracing_host = "host.docker.internal" if docker else "localhost"
otel_endpoint = f"http://{tracing_host}:{tracing_port}"
env_stage["OTEL_TRACING_GRPC_ENDPOINT"] = otel_endpoint
missing_keys = []
if access_keys:
if file:
app_env_file = os.path.join(os.path.dirname(os.path.abspath(file)), ".env")
else:
app_env_file = os.path.abspath(os.path.join(path, ".env"))
env.update(env_stage)
try:
if not docker:
from planoai.native_runner import start_native
if not os.path.exists(app_env_file):
for access_key in access_keys:
if env.get(access_key) is None:
missing_keys.append(access_key)
if not verbose:
with Status(
f"[{PLANO_COLOR}]Starting Plano...[/{PLANO_COLOR}]",
spinner="dots",
) as status:
def _update_progress(message: str) -> None:
console.print(f"[dim]{message}[/dim]")
start_native(
plano_config_file,
env,
foreground=foreground,
with_tracing=with_tracing,
progress_callback=_update_progress,
)
console.print("")
console.print(
"[green]✓[/green] Plano is running! [dim](native mode)[/dim]"
)
log_dir = os.path.join(PLANO_RUN_DIR, "logs")
console.print(f"Logs: {log_dir}")
console.print("Run 'planoai down' to stop.")
else:
env_stage[access_key] = env.get(access_key)
else:
env_file_dict = load_env_file_to_dict(app_env_file)
for access_key in access_keys:
if env_file_dict.get(access_key) is None:
missing_keys.append(access_key)
else:
env_stage[access_key] = env_file_dict[access_key]
start_native(
plano_config_file,
env,
foreground=foreground,
with_tracing=with_tracing,
)
else:
start_plano(plano_config_file, env, foreground=foreground)
if missing_keys:
_print_missing_keys(console, missing_keys)
sys.exit(1)
# Pass log level to the Docker container — supervisord uses LOG_LEVEL
# to set RUST_LOG (brightstaff) and envoy component log levels
env_stage["LOG_LEVEL"] = os.environ.get("LOG_LEVEL", "info")
# Start the local OTLP trace collector if --with-tracing is set
trace_server = None
if with_tracing:
if _is_port_in_use(tracing_port):
# A listener is already running (e.g. `planoai trace listen`)
console.print(
f"[green]✓[/green] Trace collector already running on port [cyan]{tracing_port}[/cyan]"
)
else:
try:
trace_server = start_trace_listener_background(grpc_port=tracing_port)
# When tracing is enabled but --foreground is not, keep the process
# alive so the OTLP collector continues to receive spans.
if trace_server is not None and not foreground:
console.print(
f"[green]✓[/green] Trace collector listening on [cyan]0.0.0.0:{tracing_port}[/cyan]"
f"[dim]Plano is running. Trace collector active on port {tracing_port}. Press Ctrl+C to stop.[/dim]"
)
except Exception as e:
console.print(
f"[red]✗[/red] Failed to start trace collector on port {tracing_port}: {e}"
)
console.print(
f"\n[dim]Check if another process is using port {tracing_port}:[/dim]"
)
console.print(f" [cyan]lsof -i :{tracing_port}[/cyan]")
console.print(f"\n[dim]Or use a different port:[/dim]")
console.print(
f" [cyan]planoai up --with-tracing --tracing-port 4318[/cyan]\n"
)
sys.exit(1)
# Update the OTEL endpoint so the gateway sends traces to the right port
tracing_host = "host.docker.internal" if docker else "localhost"
otel_endpoint = f"http://{tracing_host}:{tracing_port}"
env_stage["OTEL_TRACING_GRPC_ENDPOINT"] = otel_endpoint
env.update(env_stage)
try:
if not docker:
from planoai.native_runner import start_native
start_native(
plano_config_file, env, foreground=foreground, with_tracing=with_tracing
)
else:
start_plano(plano_config_file, env, foreground=foreground)
# When tracing is enabled but --foreground is not, keep the process
# alive so the OTLP collector continues to receive spans.
if trace_server is not None and not foreground:
console.print(
f"[dim]Plano is running. Trace collector active on port {tracing_port}. Press Ctrl+C to stop.[/dim]"
)
trace_server.wait_for_termination()
except KeyboardInterrupt:
if trace_server is not None:
console.print(f"\n[dim]Stopping trace collector...[/dim]")
finally:
if trace_server is not None:
trace_server.stop(grace=2)
trace_server.wait_for_termination()
except KeyboardInterrupt:
if trace_server is not None:
console.print(f"\n[dim]Stopping trace collector...[/dim]")
finally:
if trace_server is not None:
trace_server.stop(grace=2)
@click.command()
@@ -433,25 +582,46 @@ def up(file, path, foreground, with_tracing, tracing_port, docker):
help="Stop a Docker-based Plano instance.",
is_flag=True,
)
def down(docker):
@click.option(
"--verbose",
"-v",
default=False,
help="Show detailed shutdown logs with timestamps.",
is_flag=True,
)
def down(docker, verbose):
"""Stops Plano."""
console = _console()
_print_cli_header(console)
if not docker:
from planoai.native_runner import stop_native
with _temporary_cli_log_level("warning" if not verbose else None):
if not docker:
from planoai.native_runner import stop_native
with console.status(
f"[{PLANO_COLOR}]Shutting down Plano...[/{PLANO_COLOR}]",
spinner="dots",
):
stop_native()
else:
with console.status(
f"[{PLANO_COLOR}]Shutting down Plano (Docker)...[/{PLANO_COLOR}]",
spinner="dots",
):
stop_docker_container()
with console.status(
f"[{PLANO_COLOR}]Shutting down Plano...[/{PLANO_COLOR}]",
spinner="dots",
):
stopped = stop_native()
if not verbose:
if stopped:
console.print(
"[green]✓[/green] Plano stopped! [dim](native mode)[/dim]"
)
else:
console.print(
"[dim]No Plano instance was running (native mode)[/dim]"
)
else:
with console.status(
f"[{PLANO_COLOR}]Shutting down Plano (Docker)...[/{PLANO_COLOR}]",
spinner="dots",
):
stop_docker_container()
if not verbose:
console.print(
"[green]✓[/green] Plano stopped! [dim](docker mode)[/dim]"
)
@click.command()
@@ -483,9 +653,21 @@ def generate_prompt_targets(file):
is_flag=True,
)
@click.option("--follow", help="Follow the logs", is_flag=True)
def logs(debug, follow):
@click.option(
"--docker",
default=False,
help="Stream logs from a Docker-based Plano instance.",
is_flag=True,
)
def logs(debug, follow, docker):
"""Stream logs from access logs services."""
if not docker:
from planoai.native_runner import native_logs
native_logs(debug=debug, follow=follow)
return
plano_process = None
try:
if debug:
@@ -511,7 +693,7 @@ def logs(debug, follow):
@click.command()
@click.argument("type", type=click.Choice(["claude"]), required=True)
@click.argument("type", type=click.Choice(["claude", "codex"]), required=True)
@click.argument("file", required=False) # Optional file argument
@click.option(
"--path", default=".", help="Path to the directory containing plano_config.yaml"
@@ -524,14 +706,19 @@ def logs(debug, follow):
def cli_agent(type, file, path, settings):
"""Start a CLI agent connected to Plano.
CLI_AGENT: The type of CLI agent to start (currently only 'claude' is supported)
CLI_AGENT: The type of CLI agent to start ('claude' or 'codex')
"""
# Check if plano docker container is running
plano_status = docker_container_status(PLANO_DOCKER_NAME)
if plano_status != "running":
log.error(f"plano docker container is not running (status: {plano_status})")
log.error("Please start plano using the 'planoai up' command.")
native_running = _is_native_plano_running()
docker_running = False
if not native_running:
docker_running = docker_container_status(PLANO_DOCKER_NAME) == "running"
if not (native_running or docker_running):
log.error("Plano is not running.")
log.error(
"Start Plano first using 'planoai up <config.yaml>' (native or --docker mode)."
)
sys.exit(1)
# Determine plano_config.yaml path
@@ -541,7 +728,7 @@ def cli_agent(type, file, path, settings):
sys.exit(1)
try:
start_cli_agent(plano_config_file, settings)
start_cli_agent(plano_config_file, type, settings)
except SystemExit:
# Re-raise SystemExit to preserve exit codes
raise
@@ -559,6 +746,8 @@ main.add_command(cli_agent)
main.add_command(generate_prompt_targets)
main.add_command(init_cmd, name="init")
main.add_command(trace_cmd, name="trace")
main.add_command(chatgpt_cmd, name="chatgpt")
main.add_command(obs_cmd, name="obs")
if __name__ == "__main__":
main()


@@ -6,6 +6,7 @@ import signal
import subprocess
import sys
import time
from collections.abc import Callable
from planoai.consts import (
NATIVE_PID_FILE,
@@ -157,7 +158,13 @@ def render_native_config(plano_config_file, env, with_tracing=False):
return envoy_config_path, plano_config_rendered_path
def start_native(plano_config_file, env, foreground=False, with_tracing=False):
def start_native(
plano_config_file,
env,
foreground=False,
with_tracing=False,
progress_callback: Callable[[str], None] | None = None,
):
"""Start Envoy and brightstaff natively."""
from planoai.core import _get_gateway_ports
@@ -174,6 +181,8 @@ def start_native(plano_config_file, env, foreground=False, with_tracing=False):
)
log.info("Configuration rendered")
if progress_callback:
progress_callback("Configuration valid...")
log_dir = os.path.join(PLANO_RUN_DIR, "logs")
os.makedirs(log_dir, exist_ok=True)
@@ -194,6 +203,8 @@ def start_native(plano_config_file, env, foreground=False, with_tracing=False):
os.path.join(log_dir, "brightstaff.log"),
)
log.info(f"Started brightstaff (PID {brightstaff_pid})")
if progress_callback:
progress_callback(f"Started brightstaff (PID: {brightstaff_pid})...")
# Start envoy
envoy_pid = _daemon_exec(
@@ -210,6 +221,8 @@ def start_native(plano_config_file, env, foreground=False, with_tracing=False):
os.path.join(log_dir, "envoy.log"),
)
log.info(f"Started envoy (PID {envoy_pid})")
if progress_callback:
progress_callback(f"Started envoy (PID: {envoy_pid})...")
# Save PIDs
os.makedirs(PLANO_RUN_DIR, exist_ok=True)
@@ -225,6 +238,8 @@ def start_native(plano_config_file, env, foreground=False, with_tracing=False):
# Health check
gateway_ports = _get_gateway_ports(plano_config_file)
log.info("Waiting for listeners to become healthy...")
if progress_callback:
progress_callback("Waiting for listeners to become healthy...")
start_time = time.time()
timeout = 60
@@ -238,6 +253,7 @@ def start_native(plano_config_file, env, foreground=False, with_tracing=False):
log.info("Plano is running (native mode)")
for port in gateway_ports:
log.info(f" http://localhost:{port}")
break
# Check if processes are still alive
@@ -352,11 +368,19 @@ def _kill_pid(pid):
pass
def stop_native():
"""Stop natively-running Envoy and brightstaff processes."""
def stop_native(skip_pids: set | None = None):
"""Stop natively-running Envoy, brightstaff, and watchdog processes.
Args:
skip_pids: Set of PIDs to skip (used by the watchdog to avoid self-termination).
Returns:
bool: True if at least one process was running and received a stop signal,
False if no running native Plano process was found.
"""
if not os.path.exists(NATIVE_PID_FILE):
log.info("No native Plano instance found (PID file missing).")
return
return False
with open(NATIVE_PID_FILE, "r") as f:
pids = json.load(f)
@@ -364,12 +388,19 @@ def stop_native():
envoy_pid = pids.get("envoy_pid")
brightstaff_pid = pids.get("brightstaff_pid")
for name, pid in [("envoy", envoy_pid), ("brightstaff", brightstaff_pid)]:
had_running_process = False
for name, pid in [
("envoy", envoy_pid),
("brightstaff", brightstaff_pid),
]:
if skip_pids and pid in skip_pids:
continue
if pid is None:
continue
try:
os.kill(pid, signal.SIGTERM)
log.info(f"Sent SIGTERM to {name} (PID {pid})")
had_running_process = True
except ProcessLookupError:
log.info(f"{name} (PID {pid}) already stopped")
continue
@@ -394,7 +425,11 @@ def stop_native():
pass
os.unlink(NATIVE_PID_FILE)
log.info("Plano stopped (native mode).")
if had_running_process:
log.info("Plano stopped (native mode).")
else:
log.info("No native Plano instance was running.")
return had_running_process
def native_validate_config(plano_config_file):
@@ -420,6 +455,51 @@ def native_validate_config(plano_config_file):
with _temporary_env(overrides):
from planoai.config_generator import validate_and_render_schema
# Suppress verbose print output from config_generator
with contextlib.redirect_stdout(io.StringIO()):
validate_and_render_schema()
# Suppress verbose print output from config_generator but capture errors
captured = io.StringIO()
try:
with contextlib.redirect_stdout(captured):
validate_and_render_schema()
except SystemExit:
# validate_and_render_schema calls exit(1) on failure after
# printing to stdout; re-raise so the caller gets a useful message.
output = captured.getvalue().strip()
        raise Exception(output or "Config validation failed")
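The capture-and-rethrow pattern above (redirect stdout, then surface the captured text as the exception message when the callee calls `exit(1)`) can be sketched in isolation with a stand-in validator; the error string here is purely illustrative:

```python
import contextlib
import io
import sys

def noisy_validator():
    # Stand-in for validate_and_render_schema: prints the problem, then exits.
    print("example validation error")
    sys.exit(1)

def run_validation():
    captured = io.StringIO()
    try:
        with contextlib.redirect_stdout(captured):
            noisy_validator()
    except SystemExit:
        # SystemExit carries no message here; the useful text went to stdout.
        output = captured.getvalue().strip()
        raise Exception(output or "Config validation failed")

try:
    run_validation()
except Exception as e:
    msg = str(e)

print(msg)  # example validation error
```

This works because `SystemExit` derives from `BaseException`, so it bypasses ordinary `except Exception` handlers in callers unless converted like this.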
def native_logs(debug=False, follow=False):
"""Stream logs from native-mode Plano."""
import glob as glob_mod
log_dir = os.path.join(PLANO_RUN_DIR, "logs")
if not os.path.isdir(log_dir):
log.error(f"No native log directory found at {log_dir}")
log.error("Is Plano running? Start it with: planoai up <config.yaml>")
sys.exit(1)
log_files = sorted(glob_mod.glob(os.path.join(log_dir, "access_*.log")))
if debug:
log_files.extend(
[
os.path.join(log_dir, "envoy.log"),
os.path.join(log_dir, "brightstaff.log"),
]
)
# Filter to files that exist
log_files = [f for f in log_files if os.path.exists(f)]
if not log_files:
log.error(f"No log files found in {log_dir}")
sys.exit(1)
tail_args = ["tail"]
if follow:
tail_args.append("-f")
tail_args.extend(log_files)
try:
proc = subprocess.Popen(tail_args, stdout=sys.stdout, stderr=sys.stderr)
proc.wait()
except KeyboardInterrupt:
if proc.poll() is None:
proc.terminate()


@@ -0,0 +1,6 @@
"""Plano observability console: in-memory live view of LLM traffic."""
from planoai.obs.collector import LLMCall, LLMCallStore, ObsCollector
from planoai.obs.pricing import PricingCatalog
__all__ = ["LLMCall", "LLMCallStore", "ObsCollector", "PricingCatalog"]


@@ -0,0 +1,266 @@
"""In-memory collector for LLM calls, fed by OTLP/gRPC spans from brightstaff."""
from __future__ import annotations
import threading
from collections import deque
from concurrent import futures
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Iterable
import grpc
from opentelemetry.proto.collector.trace.v1 import (
trace_service_pb2,
trace_service_pb2_grpc,
)
DEFAULT_GRPC_PORT = 4317
DEFAULT_CAPACITY = 1000
@dataclass
class LLMCall:
"""One LLM call as reconstructed from a brightstaff LLM span.
Fields default to ``None`` when the underlying span attribute was absent.
"""
request_id: str
timestamp: datetime
model: str
provider: str | None = None
request_model: str | None = None
session_id: str | None = None
route_name: str | None = None
is_streaming: bool | None = None
status_code: int | None = None
prompt_tokens: int | None = None
completion_tokens: int | None = None
total_tokens: int | None = None
cached_input_tokens: int | None = None
cache_creation_tokens: int | None = None
reasoning_tokens: int | None = None
ttft_ms: float | None = None
duration_ms: float | None = None
routing_strategy: str | None = None
routing_reason: str | None = None
cost_usd: float | None = None
@property
def tpt_ms(self) -> float | None:
if self.duration_ms is None or self.completion_tokens in (None, 0):
return None
ttft = self.ttft_ms or 0.0
generate_ms = max(0.0, self.duration_ms - ttft)
if generate_ms <= 0:
return None
return generate_ms / self.completion_tokens
@property
def tokens_per_sec(self) -> float | None:
tpt = self.tpt_ms
if tpt is None or tpt <= 0:
return None
return 1000.0 / tpt
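Worked numbers for the two derived metrics above: a 1200 ms call with 200 ms time-to-first-token and 100 completion tokens spent 1000 ms generating, so time-per-token is 10 ms and throughput is 100 tokens/sec. A standalone sketch of that arithmetic:

```python
def tpt_ms(duration_ms, ttft_ms, completion_tokens):
    # Generation time excludes time-to-first-token; result is ms per output token.
    if duration_ms is None or not completion_tokens:
        return None
    generate_ms = max(0.0, duration_ms - (ttft_ms or 0.0))
    if generate_ms <= 0:
        return None
    return generate_ms / completion_tokens

tpt = tpt_ms(1200.0, 200.0, 100)
print(tpt, 1000.0 / tpt)  # 10.0 100.0
```

Returning `None` rather than dividing by zero keeps the properties safe on spans that never produced output tokens.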
class LLMCallStore:
"""Thread-safe ring buffer of recent LLM calls."""
def __init__(self, capacity: int = DEFAULT_CAPACITY) -> None:
self._capacity = capacity
self._calls: deque[LLMCall] = deque(maxlen=capacity)
self._lock = threading.Lock()
@property
def capacity(self) -> int:
return self._capacity
def add(self, call: LLMCall) -> None:
with self._lock:
self._calls.append(call)
def clear(self) -> None:
with self._lock:
self._calls.clear()
def snapshot(self) -> list[LLMCall]:
with self._lock:
return list(self._calls)
def __len__(self) -> int:
with self._lock:
return len(self._calls)
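The store above leans on `deque(maxlen=...)`, which silently evicts the oldest entry on overflow; the lock matters only because `snapshot` copies while the gRPC thread appends. A minimal sketch of the same pattern:

```python
import threading
from collections import deque

class RingStore:
    """Minimal sketch of the LLMCallStore ring-buffer pattern."""
    def __init__(self, capacity: int = 3) -> None:
        self._calls = deque(maxlen=capacity)  # evicts oldest on overflow
        self._lock = threading.Lock()

    def add(self, item) -> None:
        with self._lock:
            self._calls.append(item)

    def snapshot(self) -> list:
        with self._lock:
            return list(self._calls)  # copy so readers never race appends

store = RingStore(capacity=3)
for i in range(5):
    store.add(i)
print(store.snapshot())  # [2, 3, 4]
```

With capacity 3 and five appends, the first two entries are dropped without any explicit trimming code.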
# Span attribute keys used below are the canonical OTel / Plano keys emitted by
# brightstaff — see crates/brightstaff/src/tracing/constants.rs for the source
# of truth.
def _anyvalue_to_python(value: Any) -> Any: # AnyValue from OTLP
kind = value.WhichOneof("value")
if kind == "string_value":
return value.string_value
if kind == "bool_value":
return value.bool_value
if kind == "int_value":
return value.int_value
if kind == "double_value":
return value.double_value
return None
def _attrs_to_dict(attrs: Iterable[Any]) -> dict[str, Any]:
out: dict[str, Any] = {}
for kv in attrs:
py = _anyvalue_to_python(kv.value)
if py is not None:
out[kv.key] = py
return out
def _maybe_int(value: Any) -> int | None:
if value is None:
return None
try:
return int(value)
except (TypeError, ValueError):
return None
def _maybe_float(value: Any) -> float | None:
if value is None:
return None
try:
return float(value)
except (TypeError, ValueError):
return None
def span_to_llm_call(
span: Any, service_name: str, pricing: Any | None = None
) -> LLMCall | None:
"""Convert an OTLP span into an LLMCall, or return None if it isn't one.
A span is considered an LLM call iff it carries the ``llm.model`` attribute.
"""
attrs = _attrs_to_dict(span.attributes)
model = attrs.get("llm.model")
if not model:
return None
# Prefer explicit span attributes; fall back to likely aliases.
request_id = next(
(
str(attrs[key])
for key in ("request_id", "http.request_id")
if key in attrs and attrs[key] is not None
),
span.span_id.hex() if span.span_id else "",
)
start_ns = span.start_time_unix_nano or 0
ts = (
datetime.fromtimestamp(start_ns / 1_000_000_000, tz=timezone.utc).astimezone()
if start_ns
else datetime.now().astimezone()
)
call = LLMCall(
request_id=str(request_id),
timestamp=ts,
model=str(model),
provider=(
str(attrs["llm.provider"]) if "llm.provider" in attrs else service_name
),
request_model=(
str(attrs["model.requested"]) if "model.requested" in attrs else None
),
session_id=(
str(attrs["plano.session_id"]) if "plano.session_id" in attrs else None
),
route_name=(
str(attrs["plano.route.name"]) if "plano.route.name" in attrs else None
),
is_streaming=(
bool(attrs["llm.is_streaming"]) if "llm.is_streaming" in attrs else None
),
status_code=_maybe_int(attrs.get("http.status_code")),
prompt_tokens=_maybe_int(attrs.get("llm.usage.prompt_tokens")),
completion_tokens=_maybe_int(attrs.get("llm.usage.completion_tokens")),
total_tokens=_maybe_int(attrs.get("llm.usage.total_tokens")),
cached_input_tokens=_maybe_int(attrs.get("llm.usage.cached_input_tokens")),
cache_creation_tokens=_maybe_int(attrs.get("llm.usage.cache_creation_tokens")),
reasoning_tokens=_maybe_int(attrs.get("llm.usage.reasoning_tokens")),
ttft_ms=_maybe_float(attrs.get("llm.time_to_first_token")),
duration_ms=_maybe_float(attrs.get("llm.duration_ms")),
routing_strategy=(
str(attrs["routing.strategy"]) if "routing.strategy" in attrs else None
),
routing_reason=(
str(attrs["routing.selection_reason"])
if "routing.selection_reason" in attrs
else None
),
)
if pricing is not None:
call.cost_usd = pricing.cost_for_call(call)
return call
class _ObsServicer(trace_service_pb2_grpc.TraceServiceServicer):
def __init__(self, store: LLMCallStore, pricing: Any | None) -> None:
self._store = store
self._pricing = pricing
def Export(self, request, context): # noqa: N802 — gRPC generated name
for resource_spans in request.resource_spans:
service_name = "unknown"
for attr in resource_spans.resource.attributes:
if attr.key == "service.name":
val = _anyvalue_to_python(attr.value)
if val is not None:
service_name = str(val)
break
for scope_spans in resource_spans.scope_spans:
for span in scope_spans.spans:
call = span_to_llm_call(span, service_name, self._pricing)
if call is not None:
self._store.add(call)
return trace_service_pb2.ExportTraceServiceResponse()
@dataclass
class ObsCollector:
"""Owns the OTLP/gRPC server and the in-memory LLMCall ring buffer."""
store: LLMCallStore = field(default_factory=LLMCallStore)
pricing: Any | None = None
host: str = "0.0.0.0"
port: int = DEFAULT_GRPC_PORT
_server: grpc.Server | None = field(default=None, init=False, repr=False)
def start(self) -> None:
if self._server is not None:
return
server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
trace_service_pb2_grpc.add_TraceServiceServicer_to_server(
_ObsServicer(self.store, self.pricing), server
)
address = f"{self.host}:{self.port}"
bound = server.add_insecure_port(address)
if bound == 0:
raise OSError(
f"Failed to bind OTLP listener on {address}: port already in use. "
"Stop tracing via `planoai trace down` or pick another port with --port."
)
server.start()
self._server = server
def stop(self, grace: float = 2.0) -> None:
if self._server is not None:
self._server.stop(grace)
self._server = None

cli/planoai/obs/pricing.py (new file)
@@ -0,0 +1,321 @@
"""DigitalOcean Gradient pricing catalog for the obs console.
Ported loosely from ``crates/brightstaff/src/router/model_metrics.rs::fetch_do_pricing``.
Single-source: one fetch at startup, cached for the life of the process.
"""
from __future__ import annotations
import logging
import re
import threading
from dataclasses import dataclass
from typing import Any
import requests
DEFAULT_PRICING_URL = "https://api.digitalocean.com/v2/gen-ai/models/catalog"
FETCH_TIMEOUT_SECS = 5.0
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class ModelPrice:
"""Input/output $/token rates. Token counts are multiplied by these."""
input_per_token_usd: float
output_per_token_usd: float
cached_input_per_token_usd: float | None = None
class PricingCatalog:
"""In-memory pricing lookup keyed by model id.
DO's catalog uses ids like ``openai-gpt-5.4``; Plano's resolved model names
may arrive as ``do/openai-gpt-5.4`` or bare ``openai-gpt-5.4``. We strip the
leading provider prefix when looking up.
"""
def __init__(self, prices: dict[str, ModelPrice] | None = None) -> None:
self._prices: dict[str, ModelPrice] = prices or {}
self._lock = threading.Lock()
def __len__(self) -> int:
with self._lock:
return len(self._prices)
def sample_models(self, n: int = 5) -> list[str]:
with self._lock:
return list(self._prices.keys())[:n]
@classmethod
def fetch(cls, url: str = DEFAULT_PRICING_URL) -> "PricingCatalog":
"""Fetch pricing from DO's catalog endpoint. On failure, returns an
empty catalog (cost column will be blank).
The catalog endpoint is public (no auth required, no signup), so
``planoai obs`` gets cost data on first run out of the box.
"""
try:
resp = requests.get(url, timeout=FETCH_TIMEOUT_SECS)
resp.raise_for_status()
data = resp.json()
except Exception as exc: # noqa: BLE001 — best-effort; never fatal
logger.warning(
"DO pricing fetch failed: %s; cost column will be blank.",
exc,
)
return cls()
prices = _parse_do_pricing(data)
if not prices:
# Dump the first entry's raw shape so we can see which fields DO
# actually returned — helps when the catalog adds new fields or
# the response doesn't match our parser.
import json as _json
sample_items = _coerce_items(data)
sample = sample_items[0] if sample_items else data
logger.warning(
"DO pricing response had no parseable entries; cost column "
"will be blank. Sample entry: %s",
_json.dumps(sample, default=str)[:400],
)
return cls(prices)
def price_for(self, model_name: str | None) -> ModelPrice | None:
if not model_name:
return None
with self._lock:
# Try the full name first, then stripped prefix, then lowercased variants.
for candidate in _model_key_candidates(model_name):
hit = self._prices.get(candidate)
if hit is not None:
return hit
return None
def cost_for_call(self, call: Any) -> float | None:
"""Compute USD cost for an LLMCall. Returns None when pricing is unknown."""
price = self.price_for(getattr(call, "model", None)) or self.price_for(
getattr(call, "request_model", None)
)
if price is None:
return None
prompt = int(getattr(call, "prompt_tokens", 0) or 0)
completion = int(getattr(call, "completion_tokens", 0) or 0)
cached = int(getattr(call, "cached_input_tokens", 0) or 0)
# Cached input tokens are priced separately at the cached rate when known;
# otherwise they're already counted in prompt tokens at the regular rate.
fresh_prompt = prompt
if price.cached_input_per_token_usd is not None and cached:
fresh_prompt = max(0, prompt - cached)
cost_cached = cached * price.cached_input_per_token_usd
else:
cost_cached = 0.0
cost = (
fresh_prompt * price.input_per_token_usd
+ completion * price.output_per_token_usd
+ cost_cached
)
return round(cost, 6)
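A worked example of the cached-token split in `cost_for_call`, using hypothetical rates (the dollar figures below are illustrative, not DO's actual pricing):

```python
# prompt=1000 tokens of which 400 were cache reads, completion=200;
# hypothetical rates: input $5e-6/tok, output $15e-6/tok, cached $5e-7/tok.
prompt, cached, completion = 1000, 400, 200
input_rate, output_rate, cached_rate = 5e-6, 15e-6, 5e-7

fresh_prompt = max(0, prompt - cached)  # 600 tokens billed at the full rate
cost = (
    fresh_prompt * input_rate      # 600 * 5e-6  = 0.0030
    + completion * output_rate     # 200 * 15e-6 = 0.0030
    + cached * cached_rate         # 400 * 5e-7  = 0.0002
)
print(round(cost, 6))  # 0.0062
```

If the catalog carries no cached rate, the 400 cached tokens simply stay inside `prompt` and are billed at the full input rate.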
_DATE_SUFFIX_RE = re.compile(r"-\d{8}$")
_PROVIDER_PREFIXES = ("anthropic", "openai", "google", "meta", "cohere", "mistral")
_ANTHROPIC_FAMILIES = {"opus", "sonnet", "haiku"}
def _model_key_candidates(model_name: str) -> list[str]:
"""Lookup-side variants of a Plano-emitted model name.
Plano resolves names like ``claude-haiku-4-5-20251001``; the catalog stores
them as ``anthropic-claude-haiku-4.5``. We strip the date suffix and the
``provider/`` prefix here; the catalog itself registers the dash/dot and
family-order aliases at parse time (see :func:`_expand_aliases`).
"""
base = model_name.strip()
out = [base]
if "/" in base:
out.append(base.split("/", 1)[1])
for k in list(out):
stripped = _DATE_SUFFIX_RE.sub("", k)
if stripped != k:
out.append(stripped)
out.extend([v.lower() for v in list(out)])
seen: set[str] = set()
uniq = []
for key in out:
if key not in seen:
seen.add(key)
uniq.append(key)
return uniq
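The two normalizations the function performs can be seen on a single example; this sketch reuses the same date-suffix regex and prefix split:

```python
import re

# Same pattern as _DATE_SUFFIX_RE: a trailing "-YYYYMMDD" segment.
DATE_SUFFIX_RE = re.compile(r"-\d{8}$")

name = "do/claude-haiku-4-5-20251001"
bare = name.split("/", 1)[1]         # strip the "do/" provider prefix
print(DATE_SUFFIX_RE.sub("", bare))  # claude-haiku-4-5
```

`_model_key_candidates` keeps every intermediate form (prefixed, unprefixed, dated, undated, lowercased) so any of them can hit a catalog alias.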
def _expand_aliases(model_id: str) -> set[str]:
"""Catalog-side variants of a DO model id.
DO publishes Anthropic models under ids like ``anthropic-claude-opus-4.7``
or ``anthropic-claude-4.6-sonnet`` while Plano emits ``claude-opus-4-7`` /
``claude-sonnet-4-6``. Generate a set covering provider-prefix stripping,
dash/dot in version segments, and family/version word order so a single
catalog entry matches every name shape we'll see at lookup.
"""
aliases: set[str] = set()
def add(name: str) -> None:
if not name:
return
aliases.add(name)
aliases.add(name.lower())
add(model_id)
base = model_id
head, _, rest = base.partition("-")
if head.lower() in _PROVIDER_PREFIXES and rest:
add(rest)
base = rest
for key in list(aliases):
if "." in key:
add(key.replace(".", "-"))
parts = base.split("-")
if len(parts) >= 3 and parts[0].lower() == "claude":
rest_parts = parts[1:]
for i, p in enumerate(rest_parts):
if p.lower() in _ANTHROPIC_FAMILIES:
others = rest_parts[:i] + rest_parts[i + 1 :]
if not others:
break
family_last = "claude-" + "-".join(others) + "-" + p
family_first = "claude-" + p + "-" + "-".join(others)
add(family_last)
add(family_first)
add(family_last.replace(".", "-"))
add(family_first.replace(".", "-"))
break
return aliases
def _parse_do_pricing(data: Any) -> dict[str, ModelPrice]:
"""Parse DO catalog response into a ModelPrice map keyed by model id.
DO's shape (as of 2026-04):
{
"data": [
{"model_id": "openai-gpt-5.4",
"pricing": {"input_price_per_million": 5.0,
"output_price_per_million": 15.0}},
...
]
}
Older/alternate shapes are also accepted (flat top-level fields, or the
``id``/``model``/``name`` key).
"""
prices: dict[str, ModelPrice] = {}
items = _coerce_items(data)
for item in items:
model_id = (
item.get("model_id")
or item.get("id")
or item.get("model")
or item.get("name")
)
if not model_id:
continue
# DO nests rates under `pricing`; try that first, then fall back to
# top-level fields for alternate response shapes.
sources = [item]
if isinstance(item.get("pricing"), dict):
sources.insert(0, item["pricing"])
input_rate = _extract_rate_from_sources(
sources,
["input_per_token", "input_token_price", "price_input"],
["input_price_per_million", "input_per_million", "input_per_mtok"],
)
output_rate = _extract_rate_from_sources(
sources,
["output_per_token", "output_token_price", "price_output"],
["output_price_per_million", "output_per_million", "output_per_mtok"],
)
cached_rate = _extract_rate_from_sources(
sources,
[
"cached_input_per_token",
"cached_input_token_price",
"prompt_cache_read_per_token",
],
[
"cached_input_price_per_million",
"cached_input_per_million",
"cached_input_per_mtok",
],
)
if input_rate is None or output_rate is None:
continue
# Treat 0-rate entries as "unknown" so cost falls back to `—` rather
# than showing a misleading $0.0000. DO's catalog sometimes omits
# rates for promo/open-weight models.
if input_rate == 0 and output_rate == 0:
continue
price = ModelPrice(
input_per_token_usd=input_rate,
output_per_token_usd=output_rate,
cached_input_per_token_usd=cached_rate,
)
for alias in _expand_aliases(str(model_id)):
prices.setdefault(alias, price)
return prices
def _coerce_items(data: Any) -> list[dict]:
if isinstance(data, list):
return [x for x in data if isinstance(x, dict)]
if isinstance(data, dict):
for key in ("data", "models", "pricing", "items"):
val = data.get(key)
if isinstance(val, list):
return [x for x in val if isinstance(x, dict)]
return []
def _extract_rate_from_sources(
sources: list[dict],
per_token_keys: list[str],
per_million_keys: list[str],
) -> float | None:
"""Return a per-token rate in USD, or None if unknown.
Some DO catalog responses put per-token values under a field whose name
says ``_per_million`` (e.g. ``input_price_per_million: 5E-8`` that's
$5e-8 per token, not per million). Heuristic: values < 1 are already
per-token (real per-million rates are ~0.1 to ~100); values >= 1 are
treated as per-million and divided by 1,000,000.
"""
for src in sources:
for key in per_token_keys:
if key in src and src[key] is not None:
try:
return float(src[key])
except (TypeError, ValueError):
continue
for key in per_million_keys:
if key in src and src[key] is not None:
try:
v = float(src[key])
except (TypeError, ValueError):
continue
if v >= 1:
return v / 1_000_000
return v
return None
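The magnitude heuristic above reduces to a one-line normalization; a sketch showing both branches:

```python
# Values >= 1 are assumed to be per-million and divided down;
# values < 1 are taken as already per-token and passed through.
def normalize_rate(v: float) -> float:
    return v / 1_000_000 if v >= 1 else v

print(normalize_rate(5.0))    # 5e-06  (a genuine $5/MTok rate)
print(normalize_rate(5e-08))  # 5e-08  (already per-token, left alone)
```

The cutoff at 1 works because the two regimes are orders of magnitude apart: real per-million rates sit roughly between 0.1 and 100, while per-token rates are well below 0.001.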

cli/planoai/obs/render.py (new file)
@@ -0,0 +1,634 @@
"""Rich TUI renderer for the observability console."""
from __future__ import annotations
from collections import Counter
from dataclasses import dataclass
from datetime import datetime
from http import HTTPStatus
from rich.align import Align
from rich.box import SIMPLE, SIMPLE_HEAVY
from rich.console import Group
from rich.panel import Panel
from rich.table import Table
from rich.text import Text
from planoai.obs.collector import LLMCall
MAX_WIDTH = 160
@dataclass
class AggregateStats:
count: int
total_cost_usd: float
total_input_tokens: int
total_output_tokens: int
distinct_sessions: int
current_session: str | None
p50_latency_ms: float | None = None
p95_latency_ms: float | None = None
p99_latency_ms: float | None = None
p50_ttft_ms: float | None = None
p95_ttft_ms: float | None = None
p99_ttft_ms: float | None = None
error_count: int = 0
errors_4xx: int = 0
errors_5xx: int = 0
has_cost: bool = False
@dataclass
class ModelRollup:
model: str
requests: int
input_tokens: int
output_tokens: int
cache_write: int
cache_read: int
cost_usd: float
has_cost: bool = False
avg_tokens_per_sec: float | None = None
def _percentile(values: list[float], pct: float) -> float | None:
if not values:
return None
s = sorted(values)
k = max(0, min(len(s) - 1, int(round((pct / 100.0) * (len(s) - 1)))))
return s[k]
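`_percentile` is the nearest-rank method: sort, then index at `round(pct/100 * (n - 1))`, clamped to the valid range. On a five-element sample:

```python
values = [100.0, 200.0, 300.0, 400.0, 500.0]

idx_p95 = round(0.95 * (len(values) - 1))  # round(3.8) -> 4
print(sorted(values)[idx_p95])             # 500.0

idx_p50 = round(0.50 * (len(values) - 1))  # round(2.0) -> 2
print(sorted(values)[idx_p50])             # 300.0
```

Nearest-rank is a deliberate simplification: with the small sample sizes a ring buffer holds, interpolating between ranks would add noise without adding insight.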
def aggregates(calls: list[LLMCall]) -> AggregateStats:
total_cost = sum((c.cost_usd or 0.0) for c in calls)
total_input = sum(int(c.prompt_tokens or 0) for c in calls)
total_output = sum(int(c.completion_tokens or 0) for c in calls)
session_ids = {c.session_id for c in calls if c.session_id}
current = next(
(c.session_id for c in reversed(calls) if c.session_id is not None), None
)
durations = [c.duration_ms for c in calls if c.duration_ms is not None]
ttfts = [c.ttft_ms for c in calls if c.ttft_ms is not None]
errors_4xx = sum(
1 for c in calls if c.status_code is not None and 400 <= c.status_code < 500
)
errors_5xx = sum(
1 for c in calls if c.status_code is not None and c.status_code >= 500
)
has_cost = any(c.cost_usd is not None for c in calls)
return AggregateStats(
count=len(calls),
total_cost_usd=total_cost,
total_input_tokens=total_input,
total_output_tokens=total_output,
distinct_sessions=len(session_ids),
current_session=current,
p50_latency_ms=_percentile(durations, 50),
p95_latency_ms=_percentile(durations, 95),
p99_latency_ms=_percentile(durations, 99),
p50_ttft_ms=_percentile(ttfts, 50),
p95_ttft_ms=_percentile(ttfts, 95),
p99_ttft_ms=_percentile(ttfts, 99),
error_count=errors_4xx + errors_5xx,
errors_4xx=errors_4xx,
errors_5xx=errors_5xx,
has_cost=has_cost,
)
def model_rollups(calls: list[LLMCall]) -> list[ModelRollup]:
buckets: dict[str, dict[str, float | int | bool]] = {}
tps_samples: dict[str, list[float]] = {}
for c in calls:
key = c.model
b = buckets.setdefault(
key,
{
"requests": 0,
"input": 0,
"output": 0,
"cache_write": 0,
"cache_read": 0,
"cost": 0.0,
"has_cost": False,
},
)
b["requests"] = int(b["requests"]) + 1
b["input"] = int(b["input"]) + int(c.prompt_tokens or 0)
b["output"] = int(b["output"]) + int(c.completion_tokens or 0)
b["cache_write"] = int(b["cache_write"]) + int(c.cache_creation_tokens or 0)
b["cache_read"] = int(b["cache_read"]) + int(c.cached_input_tokens or 0)
b["cost"] = float(b["cost"]) + (c.cost_usd or 0.0)
if c.cost_usd is not None:
b["has_cost"] = True
tps = c.tokens_per_sec
if tps is not None:
tps_samples.setdefault(key, []).append(tps)
rollups: list[ModelRollup] = []
for model, b in buckets.items():
samples = tps_samples.get(model)
avg_tps = (sum(samples) / len(samples)) if samples else None
rollups.append(
ModelRollup(
model=model,
requests=int(b["requests"]),
input_tokens=int(b["input"]),
output_tokens=int(b["output"]),
cache_write=int(b["cache_write"]),
cache_read=int(b["cache_read"]),
cost_usd=float(b["cost"]),
has_cost=bool(b["has_cost"]),
avg_tokens_per_sec=avg_tps,
)
)
rollups.sort(key=lambda r: (r.cost_usd, r.requests), reverse=True)
return rollups
@dataclass
class RouteHit:
route: str
hits: int
pct: float
p95_latency_ms: float | None
error_count: int
def route_hits(calls: list[LLMCall]) -> list[RouteHit]:
counts: Counter[str] = Counter()
per_route_latency: dict[str, list[float]] = {}
per_route_errors: dict[str, int] = {}
for c in calls:
if not c.route_name:
continue
counts[c.route_name] += 1
if c.duration_ms is not None:
per_route_latency.setdefault(c.route_name, []).append(c.duration_ms)
if c.status_code is not None and c.status_code >= 400:
per_route_errors[c.route_name] = per_route_errors.get(c.route_name, 0) + 1
total = sum(counts.values())
if total == 0:
return []
return [
RouteHit(
route=r,
hits=n,
pct=(n / total) * 100.0,
p95_latency_ms=_percentile(per_route_latency.get(r, []), 95),
error_count=per_route_errors.get(r, 0),
)
for r, n in counts.most_common()
]
def _fmt_cost(v: float | None, *, zero: str = "") -> str:
if v is None:
return "—"
if v == 0:
return zero
if abs(v) < 0.0001:
return f"${v:.8f}".rstrip("0").rstrip(".")
if abs(v) < 0.01:
return f"${v:.6f}".rstrip("0").rstrip(".")
if abs(v) < 1:
return f"${v:.4f}"
return f"${v:,.2f}"
def _fmt_ms(v: float | None) -> str:
if v is None:
return ""
if v >= 1000:
return f"{v / 1000:.1f}s"
return f"{v:.0f}ms"
def _fmt_int(v: int | None) -> str:
if v is None or v == 0:
return ""
return f"{v:,}"
def _fmt_tokens(v: int | None) -> str:
if v is None:
return ""
return f"{v:,}"
def _fmt_tps(v: float | None) -> str:
if v is None or v <= 0:
return ""
if v >= 100:
return f"{v:.0f}/s"
return f"{v:.1f}/s"
def _latency_style(v: float | None) -> str:
if v is None:
return "dim"
if v < 500:
return "green"
if v < 2000:
return "yellow"
return "red"
def _ttft_style(v: float | None) -> str:
if v is None:
return "dim"
if v < 300:
return "green"
if v < 1000:
return "yellow"
return "red"
def _truncate_model(name: str, limit: int = 32) -> str:
if len(name) <= limit:
return name
return name[: limit - 1] + "…"
def _status_text(code: int | None) -> Text:
if code is None:
return Text("", style="dim")
if 200 <= code < 300:
return Text("● ok", style="green")
if 300 <= code < 400:
return Text(f"{code}", style="yellow")
if 400 <= code < 500:
return Text(f"{code}", style="yellow bold")
return Text(f"{code}", style="red bold")
def _summary_panel(last: LLMCall | None, stats: AggregateStats) -> Panel:
# Content-sized columns with a fixed gutter keep the two blocks close
# together instead of stretching across the full terminal on wide screens.
grid = Table.grid(padding=(0, 4))
grid.add_column(no_wrap=True)
grid.add_column(no_wrap=True)
# Left: latest request snapshot.
left = Table.grid(padding=(0, 1))
left.add_column(style="dim", no_wrap=True)
left.add_column(no_wrap=True)
if last is None:
left.add_row("latest", Text("waiting for spans…", style="dim italic"))
else:
model_text = Text(_truncate_model(last.model, 48), style="bold cyan")
if last.is_streaming:
model_text.append(" ⟳ stream", style="dim")
left.add_row("model", model_text)
if last.request_model and last.request_model != last.model:
left.add_row(
"requested", Text(_truncate_model(last.request_model, 48), style="cyan")
)
if last.route_name:
left.add_row("route", Text(last.route_name, style="yellow"))
left.add_row("status", _status_text(last.status_code))
tokens = Text()
tokens.append(_fmt_tokens(last.prompt_tokens))
tokens.append(" in", style="dim")
tokens.append(" · ", style="dim")
tokens.append(_fmt_tokens(last.completion_tokens), style="green")
tokens.append(" out", style="dim")
if last.cached_input_tokens:
tokens.append(" · ", style="dim")
tokens.append(_fmt_tokens(last.cached_input_tokens), style="yellow")
tokens.append(" cached", style="dim")
left.add_row("tokens", tokens)
timing = Text()
timing.append("TTFT ", style="dim")
timing.append(_fmt_ms(last.ttft_ms), style=_ttft_style(last.ttft_ms))
timing.append(" · ", style="dim")
timing.append("lat ", style="dim")
timing.append(_fmt_ms(last.duration_ms), style=_latency_style(last.duration_ms))
tps = last.tokens_per_sec
if tps:
timing.append(" · ", style="dim")
timing.append(_fmt_tps(tps), style="green")
left.add_row("timing", timing)
left.add_row("cost", Text(_fmt_cost(last.cost_usd), style="green bold"))
# Right: lifetime totals.
right = Table.grid(padding=(0, 1))
right.add_column(style="dim", no_wrap=True)
right.add_column(no_wrap=True)
right.add_row(
"requests",
Text(f"{stats.count:,}", style="bold"),
)
if stats.error_count:
err_text = Text()
err_text.append(f"{stats.error_count:,}", style="red bold")
parts: list[str] = []
if stats.errors_4xx:
parts.append(f"{stats.errors_4xx} 4xx")
if stats.errors_5xx:
parts.append(f"{stats.errors_5xx} 5xx")
if parts:
err_text.append(f" ({' · '.join(parts)})", style="dim")
right.add_row("errors", err_text)
cost_str = _fmt_cost(stats.total_cost_usd) if stats.has_cost else "—"
right.add_row("total cost", Text(cost_str, style="green bold"))
tokens_total = Text()
tokens_total.append(_fmt_tokens(stats.total_input_tokens))
tokens_total.append(" in", style="dim")
tokens_total.append(" · ", style="dim")
tokens_total.append(_fmt_tokens(stats.total_output_tokens), style="green")
tokens_total.append(" out", style="dim")
right.add_row("tokens", tokens_total)
lat_text = Text()
lat_text.append("p50 ", style="dim")
lat_text.append(
_fmt_ms(stats.p50_latency_ms), style=_latency_style(stats.p50_latency_ms)
)
lat_text.append(" · ", style="dim")
lat_text.append("p95 ", style="dim")
lat_text.append(
_fmt_ms(stats.p95_latency_ms), style=_latency_style(stats.p95_latency_ms)
)
lat_text.append(" · ", style="dim")
lat_text.append("p99 ", style="dim")
lat_text.append(
_fmt_ms(stats.p99_latency_ms), style=_latency_style(stats.p99_latency_ms)
)
right.add_row("latency", lat_text)
ttft_text = Text()
ttft_text.append("p50 ", style="dim")
ttft_text.append(_fmt_ms(stats.p50_ttft_ms), style=_ttft_style(stats.p50_ttft_ms))
ttft_text.append(" · ", style="dim")
ttft_text.append("p95 ", style="dim")
ttft_text.append(_fmt_ms(stats.p95_ttft_ms), style=_ttft_style(stats.p95_ttft_ms))
ttft_text.append(" · ", style="dim")
ttft_text.append("p99 ", style="dim")
ttft_text.append(_fmt_ms(stats.p99_ttft_ms), style=_ttft_style(stats.p99_ttft_ms))
right.add_row("TTFT", ttft_text)
sess = Text()
sess.append(f"{stats.distinct_sessions}")
if stats.current_session:
sess.append(" · current ", style="dim")
sess.append(stats.current_session, style="magenta")
right.add_row("sessions", sess)
grid.add_row(left, right)
return Panel(
grid,
title="[bold]live LLM traffic[/]",
border_style="cyan",
box=SIMPLE_HEAVY,
padding=(0, 1),
)
def _model_rollup_table(rollups: list[ModelRollup]) -> Table:
table = Table(
title="by model",
title_justify="left",
title_style="bold dim",
caption="cost via DigitalOcean Gradient catalog",
caption_justify="left",
caption_style="dim italic",
box=SIMPLE,
header_style="bold",
pad_edge=False,
padding=(0, 1),
)
table.add_column("model", style="cyan", no_wrap=True)
table.add_column("req", justify="right")
table.add_column("input", justify="right")
table.add_column("output", justify="right", style="green")
table.add_column("cache wr", justify="right", style="yellow")
table.add_column("cache rd", justify="right", style="yellow")
table.add_column("tok/s", justify="right")
table.add_column("cost", justify="right", style="green")
if not rollups:
table.add_row(
Text("no requests yet", style="dim italic"),
*([""] * 7),
)
return table
for r in rollups:
cost_cell = _fmt_cost(r.cost_usd) if r.has_cost else "—"
table.add_row(
_truncate_model(r.model),
f"{r.requests:,}",
_fmt_tokens(r.input_tokens),
_fmt_tokens(r.output_tokens),
_fmt_int(r.cache_write),
_fmt_int(r.cache_read),
_fmt_tps(r.avg_tokens_per_sec),
cost_cell,
)
return table
def _route_hit_table(hits: list[RouteHit]) -> Table:
table = Table(
title="route share",
title_justify="left",
title_style="bold dim",
box=SIMPLE,
header_style="bold",
pad_edge=False,
padding=(0, 1),
)
table.add_column("route", style="cyan")
table.add_column("hits", justify="right")
table.add_column("%", justify="right")
table.add_column("p95", justify="right")
table.add_column("err", justify="right")
for h in hits:
err_cell = (
Text(f"{h.error_count:,}", style="red bold") if h.error_count else ""
)
table.add_row(
h.route,
f"{h.hits:,}",
f"{h.pct:5.1f}%",
Text(_fmt_ms(h.p95_latency_ms), style=_latency_style(h.p95_latency_ms)),
err_cell,
)
return table
def _recent_table(calls: list[LLMCall], limit: int = 15) -> Table:
show_route = any(c.route_name for c in calls)
show_cache = any((c.cached_input_tokens or 0) > 0 for c in calls)
show_rsn = any((c.reasoning_tokens or 0) > 0 for c in calls)
caption_parts = ["in·new = fresh prompt tokens"]
if show_cache:
caption_parts.append("in·cache = cached read")
if show_rsn:
caption_parts.append("rsn = reasoning")
caption_parts.append("lat = total latency")
table = Table(
title=f"recent · last {min(limit, len(calls)) if calls else 0}",
title_justify="left",
title_style="bold dim",
caption=" · ".join(caption_parts),
caption_justify="left",
caption_style="dim italic",
box=SIMPLE,
header_style="bold",
pad_edge=False,
padding=(0, 1),
)
table.add_column("time", no_wrap=True)
table.add_column("model", style="cyan", no_wrap=True)
if show_route:
table.add_column("route", style="yellow", no_wrap=True)
table.add_column("in·new", justify="right")
if show_cache:
table.add_column("in·cache", justify="right", style="yellow")
table.add_column("out", justify="right", style="green")
if show_rsn:
table.add_column("rsn", justify="right")
table.add_column("tok/s", justify="right")
table.add_column("TTFT", justify="right")
table.add_column("lat", justify="right")
table.add_column("cost", justify="right", style="green")
table.add_column("status")
if not calls:
cols = len(table.columns)
table.add_row(
Text("waiting for spans…", style="dim italic"),
*([""] * (cols - 1)),
)
return table
recent = list(reversed(calls))[:limit]
for idx, c in enumerate(recent):
is_newest = idx == 0
time_style = "bold white" if is_newest else None
model_style = "bold cyan" if is_newest else "cyan"
row: list[object] = [
(
Text(c.timestamp.strftime("%H:%M:%S"), style=time_style)
if time_style
else c.timestamp.strftime("%H:%M:%S")
),
Text(_truncate_model(c.model), style=model_style),
]
if show_route:
row.append(c.route_name or "")
row.append(_fmt_tokens(c.prompt_tokens))
if show_cache:
row.append(_fmt_int(c.cached_input_tokens))
row.append(_fmt_tokens(c.completion_tokens))
if show_rsn:
row.append(_fmt_int(c.reasoning_tokens))
row.extend(
[
_fmt_tps(c.tokens_per_sec),
Text(_fmt_ms(c.ttft_ms), style=_ttft_style(c.ttft_ms)),
Text(_fmt_ms(c.duration_ms), style=_latency_style(c.duration_ms)),
_fmt_cost(c.cost_usd),
_status_text(c.status_code),
]
)
table.add_row(*row)
return table
def _last_error(calls: list[LLMCall]) -> LLMCall | None:
for c in reversed(calls):
if c.status_code is not None and c.status_code >= 400:
return c
return None
def _http_reason(code: int) -> str:
try:
return HTTPStatus(code).phrase
except ValueError:
return ""
def _fmt_ago(ts: datetime) -> str:
# `ts` is produced in collector.py via datetime.now(tz=...), but fall back
# gracefully if a naive timestamp ever sneaks in.
now = datetime.now(tz=ts.tzinfo) if ts.tzinfo else datetime.now()
delta = (now - ts).total_seconds()
if delta < 0:
delta = 0
if delta < 60:
return f"{int(delta)}s ago"
if delta < 3600:
return f"{int(delta // 60)}m ago"
return f"{int(delta // 3600)}h ago"
def _error_banner(call: LLMCall) -> Panel:
code = call.status_code or 0
border = "red" if code >= 500 else "yellow"
header = Text()
header.append(f"{code}", style=f"{border} bold")
reason = _http_reason(code)
if reason:
header.append(f" {reason}", style=border)
header.append(" · ", style="dim")
header.append(_truncate_model(call.model, 48), style="cyan")
if call.route_name:
header.append(" · ", style="dim")
header.append(call.route_name, style="yellow")
header.append(" · ", style="dim")
header.append(_fmt_ago(call.timestamp), style="dim")
if call.request_id:
header.append(" · req ", style="dim")
header.append(call.request_id, style="magenta")
return Panel(
header,
title="[bold]last error[/]",
title_align="left",
border_style=border,
box=SIMPLE,
padding=(0, 1),
)
def _footer(stats: AggregateStats) -> Text:
waiting = stats.count == 0
text = Text()
text.append("Ctrl-C ", style="bold")
text.append("exit", style="dim")
text.append(" · OTLP :4317", style="dim")
text.append(" · pricing: DigitalOcean ", style="dim")
if waiting:
text.append("waiting for spans", style="yellow")
text.append(
" — set tracing.opentracing_grpc_endpoint=localhost:4317", style="dim"
)
else:
text.append(f"receiving · {stats.count:,} call(s) buffered", style="green")
return text
def render(calls: list[LLMCall]) -> Align:
last = calls[-1] if calls else None
stats = aggregates(calls)
rollups = model_rollups(calls)
hits = route_hits(calls)
parts: list[object] = [_summary_panel(last, stats)]
err = _last_error(calls)
if err is not None:
parts.append(_error_banner(err))
if hits:
split = Table.grid(padding=(0, 2))
split.add_column(no_wrap=False)
split.add_column(no_wrap=False)
split.add_row(_model_rollup_table(rollups), _route_hit_table(hits))
parts.append(split)
else:
parts.append(_model_rollup_table(rollups))
parts.append(_recent_table(calls))
parts.append(_footer(stats))
# Cap overall width so wide terminals don't stretch the layout into a
# mostly-whitespace gap between columns.
return Align.left(Group(*parts), width=MAX_WIDTH)

cli/planoai/obs_cmd.py (new file)
@@ -0,0 +1,99 @@
"""`planoai obs` — live observability TUI."""
from __future__ import annotations
import time
import rich_click as click
from rich.console import Console
from rich.live import Live
from planoai.consts import PLANO_COLOR
from planoai.obs.collector import (
DEFAULT_CAPACITY,
DEFAULT_GRPC_PORT,
LLMCallStore,
ObsCollector,
)
from planoai.obs.pricing import PricingCatalog
from planoai.obs.render import render
@click.command(name="obs", help="Live observability console for Plano LLM traffic.")
@click.option(
"--port",
type=int,
default=DEFAULT_GRPC_PORT,
show_default=True,
help="OTLP/gRPC port to listen on. Must match the brightstaff tracing endpoint.",
)
@click.option(
"--host",
type=str,
default="0.0.0.0",
show_default=True,
help="Host to bind the OTLP listener.",
)
@click.option(
"--capacity",
type=int,
default=DEFAULT_CAPACITY,
show_default=True,
help="Max LLM calls kept in memory; older calls evicted FIFO.",
)
@click.option(
"--refresh-ms",
type=int,
default=500,
show_default=True,
help="TUI refresh interval.",
)
def obs(port: int, host: str, capacity: int, refresh_ms: int) -> None:
console = Console()
console.print(
f"[bold {PLANO_COLOR}]planoai obs[/] — loading DO pricing catalog...",
end="",
)
pricing = PricingCatalog.fetch()
if len(pricing):
sample = ", ".join(pricing.sample_models(3))
console.print(
f" [green]{len(pricing)} models loaded[/] [dim]({sample}, ...)[/]"
)
else:
console.print(
" [yellow]no pricing loaded[/] — "
"[dim]cost column will be blank (DO catalog unreachable)[/]"
)
store = LLMCallStore(capacity=capacity)
collector = ObsCollector(store=store, pricing=pricing, host=host, port=port)
try:
collector.start()
except OSError as exc:
console.print(f"[red]{exc}[/]")
raise SystemExit(1)
console.print(
f"Listening for OTLP spans on [bold]{host}:{port}[/]. "
"Ensure plano config has [cyan]tracing.opentracing_grpc_endpoint: http://localhost:4317[/] "
"and [cyan]tracing.random_sampling: 100[/] (or run [bold]planoai up[/] "
"with no config — it wires this automatically)."
)
console.print("Press [bold]Ctrl-C[/] to exit.\n")
refresh = max(0.05, refresh_ms / 1000.0)
try:
with Live(
render(store.snapshot()),
console=console,
refresh_per_second=1.0 / refresh,
screen=False,
) as live:
while True:
time.sleep(refresh)
live.update(render(store.snapshot()))
except KeyboardInterrupt:
console.print("\n[dim]obs stopped[/]")
finally:
collector.stop()
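The refresh math above clamps the interval at 50 ms so Rich's `refresh_per_second` stays finite. As a standalone sketch (the helper name is illustrative, not from the repo):

```python
def refresh_params(refresh_ms: int) -> tuple[float, float]:
    """Clamp the user-supplied interval to 50 ms and derive the
    refresh_per_second value handed to rich.live.Live, mirroring
    the obs command above."""
    interval_s = max(0.05, refresh_ms / 1000.0)
    return interval_s, 1.0 / interval_s
```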


@@ -61,7 +61,7 @@ def configure_rich_click(plano_color: str) -> None:
},
{
"name": "Observability",
"commands": ["trace"],
"commands": ["trace", "obs"],
},
{
"name": "Utilities",


@@ -189,26 +189,26 @@ def get_function_parameters(node: ast.FunctionDef, tree: ast.AST) -> list:
if isinstance(
arg.annotation.value, ast.Name
) and arg.annotation.value.id in ["list", "tuple", "set", "dict"]:
param_info[
"type"
] = f"{arg.annotation.value.id}" # e.g., "List", "Tuple", etc.
param_info["type"] = (
f"{arg.annotation.value.id}" # e.g., "List", "Tuple", etc.
)
else:
param_info["type"] = "[UNKNOWN - PLEASE FIX]"
# Default for unknown types
else:
param_info[
"type"
] = "[UNKNOWN - PLEASE FIX]" # If unable to detect type
param_info["type"] = (
"[UNKNOWN - PLEASE FIX]" # If unable to detect type
)
# Handle default values
if default is not None:
if isinstance(default, ast.Constant) or isinstance(
default, ast.NameConstant
):
param_info[
"default"
] = default.value # Use the default value directly
param_info["default"] = (
default.value
) # Use the default value directly
else:
param_info["default"] = "[UNKNOWN DEFAULT]" # Unknown default type
param_info["required"] = False # Optional since it has a default value


@@ -6,7 +6,6 @@ import yaml
import logging
from planoai.consts import PLANO_DOCKER_NAME
# Standard env var for log level across all Plano components
LOG_LEVEL_ENV = "LOG_LEVEL"
@@ -89,19 +88,24 @@ def convert_legacy_listeners(
) -> tuple[list, dict | None, dict | None]:
llm_gateway_listener = {
"name": "egress_traffic",
"type": "model_listener",
"type": "model",
"port": 12000,
"address": "0.0.0.0",
"timeout": "30s",
# LLM streaming responses routinely exceed 30s (extended thinking,
# long tool reasoning, large completions). Match the 300s ceiling
# used by the direct upstream-provider routes so Envoy doesn't
# abort streams with UT mid-response. Users can override via their
# plano_config.yaml `listeners.timeout` field.
"timeout": "300s",
"model_providers": model_providers or [],
}
prompt_gateway_listener = {
"name": "ingress_traffic",
"type": "prompt_listener",
"type": "prompt",
"port": 10000,
"address": "0.0.0.0",
"timeout": "30s",
"timeout": "300s",
}
# Handle None case


@@ -1,6 +1,6 @@
[project]
name = "planoai"
version = "0.4.11"
version = "0.4.21"
description = "Python-based CLI tool to manage Plano."
authors = [{name = "Katanemo Labs, Inc."}]
readme = "README.md"
@@ -13,7 +13,7 @@ dependencies = [
"opentelemetry-proto>=1.20.0",
"questionary>=2.1.1,<3.0.0",
"pyyaml>=6.0.2,<7.0.0",
"requests>=2.31.0,<3.0.0",
"requests>=2.33.0,<3.0.0",
"urllib3>=2.6.3",
"rich>=14.2.0",
"rich-click>=1.9.5",


@@ -1,7 +1,11 @@
import json
import pytest
import yaml
from unittest import mock
from planoai.config_generator import validate_and_render_schema
from planoai.config_generator import (
validate_and_render_schema,
migrate_inline_routing_preferences,
)
@pytest.fixture(autouse=True)
@@ -104,13 +108,13 @@ listeners:
agents:
- id: simple_tmobile_rag_agent
description: t-mobile virtual assistant for device contracts.
filter_chain:
input_filters:
- query_rewriter
- context_builder
- response_generator
- id: research_agent
description: agent to research and gather information from various sources.
filter_chain:
input_filters:
- research_agent
- response_generator
port: 8000
@@ -253,38 +257,72 @@ llm_providers:
base_url: "http://custom.com/api/v2"
provider_interface: openai
""",
},
{
"id": "vercel_is_supported_provider",
"expected_error": None,
"plano_config": """
version: v0.4.0
listeners:
- name: llm
type: model
port: 12000
model_providers:
- model: vercel/*
base_url: https://ai-gateway.vercel.sh/v1
passthrough_auth: true
""",
},
{
"id": "openrouter_is_supported_provider",
"expected_error": None,
"plano_config": """
version: v0.4.0
listeners:
- name: llm
type: model
port: 12000
model_providers:
- model: openrouter/*
base_url: https://openrouter.ai/api/v1
passthrough_auth: true
""",
},
{
"id": "duplicate_routeing_preference_name",
"expected_error": "Duplicate routing preference name",
"plano_config": """
version: v0.1.0
version: v0.4.0
listeners:
egress_traffic:
address: 0.0.0.0
- name: llm
type: model
port: 12000
message_format: openai
timeout: 30s
llm_providers:
model_providers:
- model: openai/gpt-4o-mini
access_key: $OPENAI_API_KEY
default: true
- model: openai/gpt-4o
access_key: $OPENAI_API_KEY
routing_preferences:
- name: code understanding
description: understand and explain existing code snippets, functions, or libraries
- model: openai/gpt-4.1
access_key: $OPENAI_API_KEY
routing_preferences:
- name: code understanding
description: generating new code snippets, functions, or boilerplate based on user prompts or requirements
routing_preferences:
- name: code understanding
description: understand and explain existing code snippets, functions, or libraries
models:
- openai/gpt-4o
- name: code understanding
description: generating new code snippets, functions, or boilerplate based on user prompts or requirements
models:
- openai/gpt-4o-mini
tracing:
random_sampling: 100
@@ -376,7 +414,7 @@ def test_convert_legacy_llm_providers():
assert updated_providers == [
{
"name": "egress_traffic",
"type": "model_listener",
"type": "model",
"port": 12000,
"address": "0.0.0.0",
"timeout": "30s",
@@ -384,7 +422,7 @@ def test_convert_legacy_llm_providers():
},
{
"name": "ingress_traffic",
"type": "prompt_listener",
"type": "prompt",
"port": 10000,
"address": "0.0.0.0",
"timeout": "30s",
@@ -400,7 +438,7 @@ def test_convert_legacy_llm_providers():
},
],
"name": "egress_traffic",
"type": "model_listener",
"type": "model",
"port": 12000,
"timeout": "30s",
}
@@ -410,7 +448,7 @@ def test_convert_legacy_llm_providers():
"name": "ingress_traffic",
"port": 10000,
"timeout": "30s",
"type": "prompt_listener",
"type": "prompt",
}
@@ -449,7 +487,7 @@ def test_convert_legacy_llm_providers_no_prompt_gateway():
"name": "egress_traffic",
"port": 12000,
"timeout": "30s",
"type": "model_listener",
"type": "model",
}
]
assert llm_gateway == {
@@ -461,7 +499,242 @@ def test_convert_legacy_llm_providers_no_prompt_gateway():
},
],
"name": "egress_traffic",
"type": "model_listener",
"type": "model",
"port": 12000,
"timeout": "30s",
}
def test_inline_routing_preferences_migrated_to_top_level():
plano_config = """
version: v0.3.0
listeners:
- type: model
name: model_listener
port: 12000
model_providers:
- model: openai/gpt-4o-mini
access_key: $OPENAI_API_KEY
default: true
- model: openai/gpt-4o
access_key: $OPENAI_API_KEY
routing_preferences:
- name: code understanding
description: understand and explain existing code snippets, functions, or libraries
- model: anthropic/claude-sonnet-4-20250514
access_key: $ANTHROPIC_API_KEY
routing_preferences:
- name: code generation
description: generating new code snippets, functions, or boilerplate based on user prompts or requirements
"""
config_yaml = yaml.safe_load(plano_config)
migrate_inline_routing_preferences(config_yaml)
assert config_yaml["version"] == "v0.4.0"
for provider in config_yaml["model_providers"]:
assert "routing_preferences" not in provider
top_level = config_yaml["routing_preferences"]
by_name = {entry["name"]: entry for entry in top_level}
assert set(by_name) == {"code understanding", "code generation"}
assert by_name["code understanding"]["models"] == ["openai/gpt-4o"]
assert by_name["code generation"]["models"] == [
"anthropic/claude-sonnet-4-20250514"
]
assert (
by_name["code understanding"]["description"]
== "understand and explain existing code snippets, functions, or libraries"
)
def test_inline_same_name_across_providers_merges_models():
plano_config = """
version: v0.3.0
listeners:
- type: model
name: model_listener
port: 12000
model_providers:
- model: openai/gpt-4o
access_key: $OPENAI_API_KEY
routing_preferences:
- name: code generation
description: generating new code snippets, functions, or boilerplate based on user prompts or requirements
- model: anthropic/claude-sonnet-4-20250514
access_key: $ANTHROPIC_API_KEY
routing_preferences:
- name: code generation
description: generating new code snippets, functions, or boilerplate based on user prompts or requirements
"""
config_yaml = yaml.safe_load(plano_config)
migrate_inline_routing_preferences(config_yaml)
top_level = config_yaml["routing_preferences"]
assert len(top_level) == 1
entry = top_level[0]
assert entry["name"] == "code generation"
assert entry["models"] == [
"openai/gpt-4o",
"anthropic/claude-sonnet-4-20250514",
]
assert config_yaml["version"] == "v0.4.0"
def test_existing_top_level_routing_preferences_preserved():
plano_config = """
version: v0.4.0
listeners:
- type: model
name: model_listener
port: 12000
model_providers:
- model: openai/gpt-4o
access_key: $OPENAI_API_KEY
- model: anthropic/claude-sonnet-4-20250514
access_key: $ANTHROPIC_API_KEY
routing_preferences:
- name: code generation
description: generating new code snippets or boilerplate
models:
- openai/gpt-4o
- anthropic/claude-sonnet-4-20250514
"""
config_yaml = yaml.safe_load(plano_config)
before = yaml.safe_dump(config_yaml, sort_keys=True)
migrate_inline_routing_preferences(config_yaml)
after = yaml.safe_dump(config_yaml, sort_keys=True)
assert before == after
def test_existing_top_level_wins_over_inline_migration():
plano_config = """
version: v0.3.0
listeners:
- type: model
name: model_listener
port: 12000
model_providers:
- model: openai/gpt-4o
access_key: $OPENAI_API_KEY
routing_preferences:
- name: code generation
description: inline description should lose
routing_preferences:
- name: code generation
description: user-defined top-level description wins
models:
- openai/gpt-4o
"""
config_yaml = yaml.safe_load(plano_config)
migrate_inline_routing_preferences(config_yaml)
top_level = config_yaml["routing_preferences"]
assert len(top_level) == 1
entry = top_level[0]
assert entry["description"] == "user-defined top-level description wins"
assert entry["models"] == ["openai/gpt-4o"]
def test_wildcard_with_inline_routing_preferences_errors():
plano_config = """
version: v0.3.0
listeners:
- type: model
name: model_listener
port: 12000
model_providers:
- model: openrouter/*
base_url: https://openrouter.ai/api/v1
passthrough_auth: true
routing_preferences:
- name: code generation
description: generating code
"""
config_yaml = yaml.safe_load(plano_config)
with pytest.raises(Exception) as excinfo:
migrate_inline_routing_preferences(config_yaml)
assert "wildcard" in str(excinfo.value).lower()
def test_migration_bumps_version_even_without_inline_preferences():
plano_config = """
version: v0.3.0
listeners:
- type: model
name: model_listener
port: 12000
model_providers:
- model: openai/gpt-4o
access_key: $OPENAI_API_KEY
"""
config_yaml = yaml.safe_load(plano_config)
migrate_inline_routing_preferences(config_yaml)
assert "routing_preferences" not in config_yaml
assert config_yaml["version"] == "v0.4.0"
def test_migration_is_noop_on_v040_config_with_stray_inline_preferences():
# v0.4.0 configs are assumed to be on the canonical top-level shape.
# The migration intentionally does not rescue stray inline preferences
# at v0.4.0+ so that the deprecation boundary is a clean version gate.
plano_config = """
version: v0.4.0
listeners:
- type: model
name: model_listener
port: 12000
model_providers:
- model: openai/gpt-4o
access_key: $OPENAI_API_KEY
routing_preferences:
- name: code generation
description: generating new code
"""
config_yaml = yaml.safe_load(plano_config)
migrate_inline_routing_preferences(config_yaml)
assert config_yaml["version"] == "v0.4.0"
assert "routing_preferences" not in config_yaml
assert config_yaml["model_providers"][0]["routing_preferences"] == [
{"name": "code generation", "description": "generating new code"}
]
def test_migration_does_not_downgrade_newer_versions():
plano_config = """
version: v0.5.0
listeners:
- type: model
name: model_listener
port: 12000
model_providers:
- model: openai/gpt-4o
access_key: $OPENAI_API_KEY
"""
config_yaml = yaml.safe_load(plano_config)
migrate_inline_routing_preferences(config_yaml)
assert config_yaml["version"] == "v0.5.0"

cli/test/test_defaults.py (new file)

@@ -0,0 +1,111 @@
from pathlib import Path
import jsonschema
import yaml
from planoai.defaults import (
PROVIDER_DEFAULTS,
detect_providers,
synthesize_default_config,
)
_SCHEMA_PATH = Path(__file__).parents[2] / "config" / "plano_config_schema.yaml"
def _schema() -> dict:
return yaml.safe_load(_SCHEMA_PATH.read_text())
def test_zero_env_vars_produces_pure_passthrough():
cfg = synthesize_default_config(env={})
assert cfg["version"] == "v0.4.0"
assert cfg["listeners"][0]["port"] == 12000
for provider in cfg["model_providers"]:
assert provider.get("passthrough_auth") is True
assert "access_key" not in provider
# No provider should be marked default in pure pass-through mode.
assert provider.get("default") is not True
# All known providers should be listed.
names = {p["name"] for p in cfg["model_providers"]}
assert "digitalocean" in names
assert "vercel" in names
assert "openrouter" in names
assert "openai" in names
assert "anthropic" in names
def test_env_keys_promote_providers_to_env_keyed():
cfg = synthesize_default_config(
env={"OPENAI_API_KEY": "sk-1", "DO_API_KEY": "do-1"}
)
by_name = {p["name"]: p for p in cfg["model_providers"]}
assert by_name["openai"].get("access_key") == "$OPENAI_API_KEY"
assert by_name["openai"].get("passthrough_auth") is None
assert by_name["digitalocean"].get("access_key") == "$DO_API_KEY"
# Unset env keys remain pass-through.
assert by_name["anthropic"].get("passthrough_auth") is True
def test_no_default_is_synthesized():
# Bare model names resolve via brightstaff's wildcard expansion registering
# bare keys, so the synthesizer intentionally never sets `default: true`.
cfg = synthesize_default_config(
env={"OPENAI_API_KEY": "sk-1", "ANTHROPIC_API_KEY": "a-1"}
)
assert not any(p.get("default") is True for p in cfg["model_providers"])
def test_listener_port_is_configurable():
cfg = synthesize_default_config(env={}, listener_port=11000)
assert cfg["listeners"][0]["port"] == 11000
def test_detection_summary_strings():
det = detect_providers(env={"OPENAI_API_KEY": "sk", "DO_API_KEY": "d"})
summary = det.summary
assert "env-keyed" in summary and "openai" in summary and "digitalocean" in summary
assert "pass-through" in summary
def test_tracing_block_points_at_local_console():
cfg = synthesize_default_config(env={})
tracing = cfg["tracing"]
assert tracing["opentracing_grpc_endpoint"] == "http://localhost:4317"
# random_sampling is a percentage in the plano config — 100 = every span.
assert tracing["random_sampling"] == 100
def test_synthesized_config_validates_against_schema():
cfg = synthesize_default_config(env={"OPENAI_API_KEY": "sk"})
jsonschema.validate(cfg, _schema())
def test_provider_defaults_digitalocean_is_configured():
by_name = {p.name: p for p in PROVIDER_DEFAULTS}
assert "digitalocean" in by_name
assert by_name["digitalocean"].env_var == "DO_API_KEY"
assert by_name["digitalocean"].base_url == "https://inference.do-ai.run/v1"
assert by_name["digitalocean"].model_pattern == "digitalocean/*"
def test_provider_defaults_vercel_is_configured():
by_name = {p.name: p for p in PROVIDER_DEFAULTS}
assert "vercel" in by_name
assert by_name["vercel"].env_var == "AI_GATEWAY_API_KEY"
assert by_name["vercel"].base_url == "https://ai-gateway.vercel.sh/v1"
assert by_name["vercel"].model_pattern == "vercel/*"
def test_provider_defaults_openrouter_is_configured():
by_name = {p.name: p for p in PROVIDER_DEFAULTS}
assert "openrouter" in by_name
assert by_name["openrouter"].env_var == "OPENROUTER_API_KEY"
assert by_name["openrouter"].base_url == "https://openrouter.ai/api/v1"
assert by_name["openrouter"].model_pattern == "openrouter/*"
def test_openrouter_env_key_promotes_to_env_keyed():
cfg = synthesize_default_config(env={"OPENROUTER_API_KEY": "or-1"})
by_name = {p["name"]: p for p in cfg["model_providers"]}
assert by_name["openrouter"].get("access_key") == "$OPENROUTER_API_KEY"
assert by_name["openrouter"].get("passthrough_auth") is None


@@ -0,0 +1,145 @@
import time
from datetime import datetime, timezone
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from planoai.obs.collector import LLMCall, LLMCallStore, span_to_llm_call
def _mk_attr(key: str, value):
v = MagicMock()
if isinstance(value, bool):
v.WhichOneof.return_value = "bool_value"
v.bool_value = value
elif isinstance(value, int):
v.WhichOneof.return_value = "int_value"
v.int_value = value
elif isinstance(value, float):
v.WhichOneof.return_value = "double_value"
v.double_value = value
else:
v.WhichOneof.return_value = "string_value"
v.string_value = str(value)
kv = MagicMock()
kv.key = key
kv.value = v
return kv
def _mk_span(
attrs: dict, start_ns: int | None = None, span_id_hex: str = "ab"
) -> MagicMock:
span = MagicMock()
span.attributes = [_mk_attr(k, v) for k, v in attrs.items()]
span.start_time_unix_nano = start_ns or int(time.time() * 1_000_000_000)
span.span_id.hex.return_value = span_id_hex
return span
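These mocks mimic OTLP's AnyValue oneof; reading an attribute back follows the same WhichOneof pattern (a sketch, assuming the oneof is named "value" as in the OTLP common proto):

```python
def attr_value(kv):
    """Extract the populated variant of an OTLP KeyValue's AnyValue.
    WhichOneof names the set oneof field (e.g. 'string_value');
    returns None when no variant is set."""
    which = kv.value.WhichOneof("value")
    return getattr(kv.value, which) if which else None
```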
def test_span_without_llm_model_is_ignored():
span = _mk_span({"http.method": "POST"})
assert span_to_llm_call(span, "plano(llm)") is None
def test_span_with_full_llm_attrs_produces_call():
span = _mk_span(
{
"llm.model": "openai-gpt-5.4",
"model.requested": "router:software-engineering",
"plano.session_id": "sess-abc",
"plano.route.name": "software-engineering",
"llm.is_streaming": False,
"llm.duration_ms": 1234,
"llm.time_to_first_token": 210,
"llm.usage.prompt_tokens": 100,
"llm.usage.completion_tokens": 50,
"llm.usage.total_tokens": 150,
"llm.usage.cached_input_tokens": 30,
"llm.usage.cache_creation_tokens": 5,
"llm.usage.reasoning_tokens": 200,
"http.status_code": 200,
"request_id": "req-42",
}
)
call = span_to_llm_call(span, "plano(llm)")
assert call is not None
assert call.request_id == "req-42"
assert call.model == "openai-gpt-5.4"
assert call.request_model == "router:software-engineering"
assert call.session_id == "sess-abc"
assert call.route_name == "software-engineering"
assert call.is_streaming is False
assert call.duration_ms == 1234.0
assert call.ttft_ms == 210.0
assert call.prompt_tokens == 100
assert call.completion_tokens == 50
assert call.total_tokens == 150
assert call.cached_input_tokens == 30
assert call.cache_creation_tokens == 5
assert call.reasoning_tokens == 200
assert call.status_code == 200
def test_pricing_lookup_attaches_cost():
class StubPricing:
def cost_for_call(self, call):
# Simple: 2 * prompt + 3 * completion, in cents
return 0.02 * (call.prompt_tokens or 0) + 0.03 * (
call.completion_tokens or 0
)
span = _mk_span(
{
"llm.model": "do/openai-gpt-5.4",
"llm.usage.prompt_tokens": 10,
"llm.usage.completion_tokens": 2,
}
)
call = span_to_llm_call(span, "plano(llm)", pricing=StubPricing())
assert call is not None
assert call.cost_usd == pytest.approx(0.26)
def test_tpt_and_tokens_per_sec_derived():
call = LLMCall(
request_id="x",
timestamp=datetime.now(tz=timezone.utc),
model="m",
duration_ms=1000,
ttft_ms=200,
completion_tokens=80,
)
# (1000 - 200) / 80 = 10ms per token => 100 tokens/sec
assert call.tpt_ms == 10.0
assert call.tokens_per_sec == 100.0
def test_tpt_returns_none_when_no_completion_tokens():
call = LLMCall(
request_id="x",
timestamp=datetime.now(tz=timezone.utc),
model="m",
duration_ms=1000,
ttft_ms=200,
completion_tokens=0,
)
assert call.tpt_ms is None
assert call.tokens_per_sec is None
def test_store_evicts_fifo_at_capacity():
store = LLMCallStore(capacity=3)
now = datetime.now(tz=timezone.utc)
for i in range(5):
store.add(
LLMCall(
request_id=f"r{i}",
timestamp=now,
model="m",
)
)
snap = store.snapshot()
assert len(snap) == 3
assert [c.request_id for c in snap] == ["r2", "r3", "r4"]
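FIFO eviction at capacity is exactly what a bounded deque provides; a minimal store sketch (not the shipped LLMCallStore):

```python
from collections import deque

class BoundedStore:
    """Keeps at most `capacity` items, silently evicting the oldest."""

    def __init__(self, capacity: int):
        # deque(maxlen=n) drops from the left (oldest) once full.
        self._items = deque(maxlen=capacity)

    def add(self, item) -> None:
        self._items.append(item)

    def snapshot(self) -> list:
        return list(self._items)
```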


@@ -0,0 +1,146 @@
from datetime import datetime, timezone
from planoai.obs.collector import LLMCall
from planoai.obs.pricing import ModelPrice, PricingCatalog
def _call(model: str, prompt: int, completion: int, cached: int = 0) -> LLMCall:
return LLMCall(
request_id="r",
timestamp=datetime.now(tz=timezone.utc),
model=model,
prompt_tokens=prompt,
completion_tokens=completion,
cached_input_tokens=cached,
)
def test_lookup_matches_bare_and_prefixed():
prices = {
"openai-gpt-5.4": ModelPrice(
input_per_token_usd=0.000001, output_per_token_usd=0.000002
)
}
catalog = PricingCatalog(prices)
assert catalog.price_for("openai-gpt-5.4") is not None
# do/openai-gpt-5.4 should resolve after stripping the provider prefix.
assert catalog.price_for("do/openai-gpt-5.4") is not None
assert catalog.price_for("unknown-model") is None
def test_cost_computation_without_cache():
prices = {
"m": ModelPrice(input_per_token_usd=0.000001, output_per_token_usd=0.000002)
}
cost = PricingCatalog(prices).cost_for_call(_call("m", 1000, 500))
assert cost == 0.002 # 1000 * 1e-6 + 500 * 2e-6
def test_cost_computation_with_cached_discount():
prices = {
"m": ModelPrice(
input_per_token_usd=0.000001,
output_per_token_usd=0.000002,
cached_input_per_token_usd=0.0000001,
)
}
# 800 fresh @ 1e-6 = 8e-4; 200 cached @ 1e-7 = 2e-5; 500 out @ 2e-6 = 1e-3
cost = PricingCatalog(prices).cost_for_call(_call("m", 1000, 500, cached=200))
assert cost == round(0.0008 + 0.00002 + 0.001, 6)
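The arithmetic in that comment generalizes to: fresh input at the full rate, cached input at the discounted rate, output at the output rate. A sketch (parameter names are illustrative):

```python
def cost_usd(prompt: int, completion: int, cached: int,
             in_rate: float, out_rate: float, cached_rate: float) -> float:
    """Blend full-rate and cached-rate input pricing, per the
    worked example in the comment above; rounded to 6 decimals."""
    fresh = max(prompt - cached, 0)
    return round(fresh * in_rate + cached * cached_rate + completion * out_rate, 6)
```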
def test_empty_catalog_returns_none():
assert PricingCatalog().cost_for_call(_call("m", 100, 50)) is None
def test_parse_do_catalog_treats_small_values_as_per_token():
"""DO's real catalog uses per-token values under the `_per_million` key
(e.g. 5E-8 for GPT-oss-20b). We treat values < 1 as already per-token."""
from planoai.obs.pricing import _parse_do_pricing
sample = {
"data": [
{
"model_id": "openai-gpt-oss-20b",
"pricing": {
"input_price_per_million": 5e-8,
"output_price_per_million": 4.5e-7,
},
},
{
"model_id": "openai-gpt-oss-120b",
"pricing": {
"input_price_per_million": 1e-7,
"output_price_per_million": 7e-7,
},
},
]
}
prices = _parse_do_pricing(sample)
# Values < 1 are assumed to already be per-token — no extra division.
assert prices["openai-gpt-oss-20b"].input_per_token_usd == 5e-8
assert prices["openai-gpt-oss-20b"].output_per_token_usd == 4.5e-7
assert prices["openai-gpt-oss-120b"].input_per_token_usd == 1e-7
def test_anthropic_aliases_match_plano_emitted_names():
"""DO publishes 'anthropic-claude-opus-4.7' and 'anthropic-claude-haiku-4.5';
Plano emits 'claude-opus-4-7' and 'claude-haiku-4-5-20251001'. Aliases
registered at parse time should bridge the gap."""
from planoai.obs.pricing import _parse_do_pricing
sample = {
"data": [
{
"model_id": "anthropic-claude-opus-4.7",
"pricing": {
"input_price_per_million": 15.0,
"output_price_per_million": 75.0,
},
},
{
"model_id": "anthropic-claude-haiku-4.5",
"pricing": {
"input_price_per_million": 1.0,
"output_price_per_million": 5.0,
},
},
{
"model_id": "anthropic-claude-4.6-sonnet",
"pricing": {
"input_price_per_million": 3.0,
"output_price_per_million": 15.0,
},
},
]
}
catalog = PricingCatalog(_parse_do_pricing(sample))
# Family-last shapes Plano emits.
assert catalog.price_for("claude-opus-4-7") is not None
assert catalog.price_for("claude-haiku-4-5") is not None
# Date-suffixed name (Anthropic API style).
assert catalog.price_for("claude-haiku-4-5-20251001") is not None
# Word-order swap: DO has 'claude-4.6-sonnet', Plano emits 'claude-sonnet-4-6'.
assert catalog.price_for("claude-sonnet-4-6") is not None
# Original DO ids still resolve.
assert catalog.price_for("anthropic-claude-opus-4.7") is not None
def test_parse_do_catalog_divides_large_values_as_per_million():
"""A provider that genuinely reports $5-per-million in that field gets divided."""
from planoai.obs.pricing import _parse_do_pricing
sample = {
"data": [
{
"model_id": "mystery-model",
"pricing": {
"input_price_per_million": 5.0, # > 1 → treated as per-million
"output_price_per_million": 15.0,
},
},
]
}
prices = _parse_do_pricing(sample)
assert prices["mystery-model"].input_per_token_usd == 5.0 / 1_000_000
assert prices["mystery-model"].output_per_token_usd == 15.0 / 1_000_000

cli/test/test_obs_render.py (new file)

@@ -0,0 +1,106 @@
from datetime import datetime, timedelta, timezone
from planoai.obs.collector import LLMCall
from planoai.obs.render import aggregates, model_rollups, route_hits
def _call(
model: str,
ts: datetime,
prompt=0,
completion=0,
cost=None,
route=None,
session=None,
cache_read=0,
cache_write=0,
):
return LLMCall(
request_id="r",
timestamp=ts,
model=model,
prompt_tokens=prompt,
completion_tokens=completion,
cached_input_tokens=cache_read,
cache_creation_tokens=cache_write,
cost_usd=cost,
route_name=route,
session_id=session,
)
def test_aggregates_sum_and_session_counts():
now = datetime.now(tz=timezone.utc).astimezone()
calls = [
_call(
"m1",
now - timedelta(seconds=50),
prompt=10,
completion=5,
cost=0.001,
session="s1",
),
_call(
"m2",
now - timedelta(seconds=40),
prompt=20,
completion=10,
cost=0.002,
session="s1",
),
_call(
"m1",
now - timedelta(seconds=30),
prompt=30,
completion=15,
cost=0.003,
session="s2",
),
]
stats = aggregates(calls)
assert stats.count == 3
assert stats.total_cost_usd == 0.006
assert stats.total_input_tokens == 60
assert stats.total_output_tokens == 30
assert stats.distinct_sessions == 2
assert stats.current_session == "s2"
def test_rollups_split_by_model_and_cache():
now = datetime.now(tz=timezone.utc).astimezone()
calls = [
_call(
"m1", now, prompt=10, completion=5, cost=0.001, cache_write=3, cache_read=7
),
_call("m1", now, prompt=20, completion=10, cost=0.002, cache_read=1),
_call("m2", now, prompt=30, completion=15, cost=0.004),
]
rollups = model_rollups(calls)
by_model = {r.model: r for r in rollups}
assert by_model["m1"].requests == 2
assert by_model["m1"].input_tokens == 30
assert by_model["m1"].cache_write == 3
assert by_model["m1"].cache_read == 8
assert by_model["m2"].input_tokens == 30
def test_route_hits_only_for_routed_calls():
now = datetime.now(tz=timezone.utc).astimezone()
calls = [
_call("m", now, route="code"),
_call("m", now, route="code"),
_call("m", now, route="summarization"),
_call("m", now), # no route
]
hits = route_hits(calls)
# Only calls with route names are counted.
assert sum(h.hits for h in hits) == 3
hits_by_name = {h.route: h for h in hits}
assert hits_by_name["code"].hits == 2
assert hits_by_name["summarization"].hits == 1
def test_route_hits_empty_when_no_routes():
now = datetime.now(tz=timezone.utc).astimezone()
calls = [_call("m", now), _call("m", now)]
assert route_hits(calls) == []
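route_hits as pinned down here is a Counter over the routed subset; a sketch assuming calls expose a route_name attribute:

```python
from collections import Counter

def route_hit_counts(calls) -> dict[str, int]:
    """Count hits per named route; calls without a route_name are skipped."""
    return dict(Counter(c.route_name for c in calls if c.route_name))
```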

cli/uv.lock (generated)

@@ -337,7 +337,7 @@ wheels = [
[[package]]
name = "planoai"
version = "0.4.7"
version = "0.4.21"
source = { editable = "." }
dependencies = [
{ name = "click" },
@@ -373,7 +373,7 @@ requires-dist = [
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=8.4.1,<9.0.0" },
{ name = "pyyaml", specifier = ">=6.0.2,<7.0.0" },
{ name = "questionary", specifier = ">=2.1.1,<3.0.0" },
{ name = "requests", specifier = ">=2.31.0,<3.0.0" },
{ name = "requests", specifier = ">=2.33.0,<3.0.0" },
{ name = "rich", specifier = ">=14.2.0" },
{ name = "rich-click", specifier = ">=1.9.5" },
{ name = "urllib3", specifier = ">=2.6.3" },
@@ -421,11 +421,11 @@ wheels = [
[[package]]
name = "pygments"
version = "2.19.2"
version = "2.20.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631, upload-time = "2025-06-21T13:39:12.283Z" }
sdist = { url = "https://files.pythonhosted.org/packages/c3/b2/bc9c9196916376152d655522fdcebac55e66de6603a76a02bca1b6414f6c/pygments-2.20.0.tar.gz", hash = "sha256:6757cd03768053ff99f3039c1a36d6c0aa0b263438fcab17520b30a303a82b5f", size = 4955991, upload-time = "2026-03-29T13:29:33.898Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" },
{ url = "https://files.pythonhosted.org/packages/f4/7e/a72dd26f3b0f4f2bf1dd8923c85f7ceb43172af56d63c7383eb62b332364/pygments-2.20.0-py3-none-any.whl", hash = "sha256:81a9e26dd42fd28a23a2d169d86d7ac03b46e2f8b59ed4698fb4785f946d0176", size = 1231151, upload-time = "2026-03-29T13:29:30.038Z" },
]
[[package]]
@@ -518,7 +518,7 @@ wheels = [
[[package]]
name = "requests"
version = "2.32.5"
version = "2.33.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "certifi" },
@@ -526,9 +526,9 @@ dependencies = [
{ name = "idna" },
{ name = "urllib3" },
]
sdist = { url = "https://files.pythonhosted.org/packages/c9/74/b3ff8e6c8446842c3f5c837e9c3dfcfe2018ea6ecef224c710c85ef728f4/requests-2.32.5.tar.gz", hash = "sha256:dbba0bac56e100853db0ea71b82b4dfd5fe2bf6d3754a8893c3af500cec7d7cf", size = 134517, upload-time = "2025-08-18T20:46:02.573Z" }
sdist = { url = "https://files.pythonhosted.org/packages/5f/a4/98b9c7c6428a668bf7e42ebb7c79d576a1c3c1e3ae2d47e674b468388871/requests-2.33.1.tar.gz", hash = "sha256:18817f8c57c6263968bc123d237e3b8b08ac046f5456bd1e307ee8f4250d3517", size = 134120, upload-time = "2026-03-30T16:09:15.531Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" },
{ url = "https://files.pythonhosted.org/packages/d7/8e/7540e8a2036f79a125c1d2ebadf69ed7901608859186c856fa0388ef4197/requests-2.33.1-py3-none-any.whl", hash = "sha256:4e6d1ef462f3626a1f0a0a9c42dd93c63bad33f9f1c1937509b8c5c8718ab56a", size = 64947, upload-time = "2026-03-30T16:09:13.83Z" },
]
[[package]]


@@ -594,13 +594,13 @@ static_resources:
clusters:
- name: arch
- name: plano
connect_timeout: {{ upstream_connect_timeout | default('5s') }}
type: LOGICAL_DNS
dns_lookup_family: V4_ONLY
lb_policy: ROUND_ROBIN
load_assignment:
cluster_name: arch
cluster_name: plano
endpoints:
- lb_endpoints:
- endpoint:
@@ -901,6 +901,60 @@ static_resources:
validation_context:
trusted_ca:
filename: {{ upstream_tls_ca_path | default('/etc/ssl/certs/ca-certificates.crt') }}
- name: digitalocean
connect_timeout: {{ upstream_connect_timeout | default('5s') }}
type: LOGICAL_DNS
dns_lookup_family: V4_ONLY
lb_policy: ROUND_ROBIN
load_assignment:
cluster_name: digitalocean
endpoints:
- lb_endpoints:
- endpoint:
address:
socket_address:
address: inference.do-ai.run
port_value: 443
hostname: "inference.do-ai.run"
transport_socket:
name: envoy.transport_sockets.tls
typed_config:
"@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext
sni: inference.do-ai.run
common_tls_context:
tls_params:
tls_minimum_protocol_version: TLSv1_2
tls_maximum_protocol_version: TLSv1_3
validation_context:
trusted_ca:
filename: {{ upstream_tls_ca_path | default('/etc/ssl/certs/ca-certificates.crt') }}
- name: xiaomi
connect_timeout: {{ upstream_connect_timeout | default('5s') }}
type: LOGICAL_DNS
dns_lookup_family: V4_ONLY
lb_policy: ROUND_ROBIN
load_assignment:
cluster_name: xiaomi
endpoints:
- lb_endpoints:
- endpoint:
address:
socket_address:
address: api.xiaomimimo.com
port_value: 443
hostname: "api.xiaomimimo.com"
transport_socket:
name: envoy.transport_sockets.tls
typed_config:
"@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext
sni: api.xiaomimimo.com
common_tls_context:
tls_params:
tls_minimum_protocol_version: TLSv1_2
tls_maximum_protocol_version: TLSv1_3
validation_context:
trusted_ca:
filename: {{ upstream_tls_ca_path | default('/etc/ssl/certs/ca-certificates.crt') }}
- name: mistral_7b_instruct
connect_timeout: 0.5s
type: STRICT_DNS


@ -0,0 +1,541 @@
{
"annotations": {
"list": [
{
"builtIn": 1,
"datasource": "-- Grafana --",
"enable": true,
"hide": true,
"iconColor": "rgba(0, 211, 255, 1)",
"name": "Annotations & Alerts",
"type": "dashboard"
}
]
},
"description": "RED, LLM upstream, routing service, and process metrics for brightstaff. Pair with Envoy admin metrics from cluster=bright_staff.",
"editable": true,
"fiscalYearStartMonth": 0,
"graphTooltip": 1,
"id": null,
"links": [],
"liveNow": false,
"panels": [
{
"collapsed": false,
"gridPos": { "h": 1, "w": 24, "x": 0, "y": 0 },
"id": 100,
"panels": [],
"title": "HTTP RED",
"type": "row"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"fieldConfig": {
"defaults": {
"color": { "mode": "palette-classic" },
"custom": {
"axisLabel": "req/s",
"drawStyle": "line",
"fillOpacity": 10,
"lineWidth": 1,
"showPoints": "never"
},
"unit": "reqps"
}
},
"gridPos": { "h": 8, "w": 12, "x": 0, "y": 1 },
"id": 1,
"options": {
"legend": { "displayMode": "table", "placement": "bottom", "showLegend": true },
"tooltip": { "mode": "multi" }
},
"targets": [
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "sum by (handler) (rate(brightstaff_http_requests_total[1m]))",
"legendFormat": "{{handler}}",
"refId": "A"
}
],
"title": "Rate — brightstaff RPS by handler",
"type": "timeseries"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"description": "5xx fraction over 5m. Page-worthy when sustained above ~1%.",
"fieldConfig": {
"defaults": {
"color": { "mode": "thresholds" },
"thresholds": {
"mode": "absolute",
"steps": [
{ "color": "green", "value": null },
{ "color": "yellow", "value": 0.01 },
{ "color": "red", "value": 0.05 }
]
},
"unit": "percentunit"
}
},
"gridPos": { "h": 8, "w": 12, "x": 12, "y": 1 },
"id": 2,
"options": {
"colorMode": "background",
"graphMode": "area",
"reduceOptions": { "calcs": ["lastNotNull"], "fields": "", "values": false }
},
"targets": [
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "sum(rate(brightstaff_http_requests_total{status_class=\"5xx\"}[5m])) / clamp_min(sum(rate(brightstaff_http_requests_total[5m])), 1)",
"legendFormat": "5xx rate",
"refId": "A"
}
],
"title": "Errors — brightstaff 5xx rate",
"type": "stat"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"description": "p50/p95/p99 by handler, computed from histogram buckets over 5m.",
"fieldConfig": {
"defaults": {
"color": { "mode": "palette-classic" },
"custom": { "drawStyle": "line", "fillOpacity": 5, "lineWidth": 1, "showPoints": "never" },
"unit": "s"
}
},
"gridPos": { "h": 9, "w": 24, "x": 0, "y": 9 },
"id": 3,
"options": {
"legend": { "displayMode": "table", "placement": "bottom", "showLegend": true },
"tooltip": { "mode": "multi" }
},
"targets": [
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "histogram_quantile(0.50, sum by (le, handler) (rate(brightstaff_http_request_duration_seconds_bucket[5m])))",
"legendFormat": "p50 {{handler}}",
"refId": "A"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "histogram_quantile(0.95, sum by (le, handler) (rate(brightstaff_http_request_duration_seconds_bucket[5m])))",
"legendFormat": "p95 {{handler}}",
"refId": "B"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "histogram_quantile(0.99, sum by (le, handler) (rate(brightstaff_http_request_duration_seconds_bucket[5m])))",
"legendFormat": "p99 {{handler}}",
"refId": "C"
}
],
"title": "Duration — p50 / p95 / p99 by handler",
"type": "timeseries"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"description": "In-flight requests by handler. Climbs before latency does when brightstaff is saturated.",
"fieldConfig": {
"defaults": {
"color": { "mode": "palette-classic" },
"custom": { "drawStyle": "line", "fillOpacity": 10, "lineWidth": 1, "showPoints": "never" },
"unit": "short"
}
},
"gridPos": { "h": 8, "w": 24, "x": 0, "y": 18 },
"id": 4,
"options": {
"legend": { "displayMode": "table", "placement": "bottom", "showLegend": true },
"tooltip": { "mode": "multi" }
},
"targets": [
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "sum by (handler) (brightstaff_http_in_flight_requests)",
"legendFormat": "{{handler}}",
"refId": "A"
}
],
"title": "In-flight requests by handler",
"type": "timeseries"
},
{
"collapsed": false,
"gridPos": { "h": 1, "w": 24, "x": 0, "y": 26 },
"id": 200,
"panels": [],
"title": "LLM upstream",
"type": "row"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"fieldConfig": {
"defaults": {
"color": { "mode": "palette-classic" },
"custom": { "drawStyle": "line", "fillOpacity": 5, "lineWidth": 1, "showPoints": "never" },
"unit": "s"
}
},
"gridPos": { "h": 9, "w": 12, "x": 0, "y": 27 },
"id": 5,
"options": {
"legend": { "displayMode": "table", "placement": "bottom", "showLegend": true },
"tooltip": { "mode": "multi" }
},
"targets": [
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "histogram_quantile(0.95, sum by (le, provider, model) (rate(brightstaff_llm_upstream_duration_seconds_bucket[5m])))",
"legendFormat": "p95 {{provider}}/{{model}}",
"refId": "A"
}
],
"title": "LLM upstream p95 by provider/model",
"type": "timeseries"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"description": "All non-success error classes. timeout/connect = network, 5xx/429 = provider, parse = body shape mismatch, stream = mid-stream disconnect.",
"fieldConfig": {
"defaults": {
"color": { "mode": "palette-classic" },
"custom": { "drawStyle": "line", "fillOpacity": 30, "lineWidth": 1, "showPoints": "never", "stacking": { "mode": "normal" } },
"unit": "reqps"
}
},
"gridPos": { "h": 9, "w": 12, "x": 12, "y": 27 },
"id": 6,
"options": {
"legend": { "displayMode": "table", "placement": "bottom", "showLegend": true },
"tooltip": { "mode": "multi" }
},
"targets": [
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "sum by (provider, error_class) (rate(brightstaff_llm_upstream_requests_total{error_class!=\"none\"}[5m]))",
"legendFormat": "{{provider}} / {{error_class}}",
"refId": "A"
}
],
"title": "LLM upstream errors by provider / class",
"type": "timeseries"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"description": "Streaming only. Empty if the route never streams.",
"fieldConfig": {
"defaults": {
"color": { "mode": "palette-classic" },
"custom": { "drawStyle": "line", "fillOpacity": 5, "lineWidth": 1, "showPoints": "never" },
"unit": "s"
}
},
"gridPos": { "h": 9, "w": 12, "x": 0, "y": 36 },
"id": 7,
"options": {
"legend": { "displayMode": "table", "placement": "bottom", "showLegend": true },
"tooltip": { "mode": "multi" }
},
"targets": [
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "histogram_quantile(0.95, sum by (le, provider, model) (rate(brightstaff_llm_time_to_first_token_seconds_bucket[5m])))",
"legendFormat": "p95 {{provider}}/{{model}}",
"refId": "A"
}
],
"title": "Time-to-first-token p95 (streaming)",
"type": "timeseries"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"description": "Tokens/sec by provider/model/kind — proxy for cost. Stacked.",
"fieldConfig": {
"defaults": {
"color": { "mode": "palette-classic" },
"custom": { "drawStyle": "line", "fillOpacity": 30, "lineWidth": 1, "showPoints": "never", "stacking": { "mode": "normal" } },
"unit": "tokens/s"
}
},
"gridPos": { "h": 9, "w": 12, "x": 12, "y": 36 },
"id": 8,
"options": {
"legend": { "displayMode": "table", "placement": "bottom", "showLegend": true },
"tooltip": { "mode": "multi" }
},
"targets": [
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "sum by (provider, model, kind) (rate(brightstaff_llm_tokens_total[5m]))",
"legendFormat": "{{provider}}/{{model}} {{kind}}",
"refId": "A"
}
],
"title": "Token throughput by provider / model / kind",
"type": "timeseries"
},
{
"collapsed": false,
"gridPos": { "h": 1, "w": 24, "x": 0, "y": 45 },
"id": 300,
"panels": [],
"title": "Routing service",
"type": "row"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"description": "Which models the orchestrator picked over the last 15 minutes.",
"fieldConfig": {
"defaults": {
"color": { "mode": "palette-classic" },
"unit": "short"
}
},
"gridPos": { "h": 9, "w": 12, "x": 0, "y": 46 },
"id": 9,
"options": {
"displayMode": "gradient",
"orientation": "horizontal",
"reduceOptions": { "calcs": ["lastNotNull"], "fields": "", "values": false }
},
"targets": [
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "sum by (selected_model) (increase(brightstaff_router_decisions_total[15m]))",
"legendFormat": "{{selected_model}}",
"refId": "A"
}
],
"title": "Model selection distribution (last 15m)",
"type": "bargauge"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"description": "Fraction of decisions that fell back (orchestrator returned `none` or errored). High = router can't classify intent or no candidates configured.",
"fieldConfig": {
"defaults": {
"color": { "mode": "palette-classic" },
"custom": { "drawStyle": "line", "fillOpacity": 10, "lineWidth": 1, "showPoints": "never" },
"unit": "percentunit"
}
},
"gridPos": { "h": 9, "w": 12, "x": 12, "y": 46 },
"id": 10,
"options": {
"legend": { "displayMode": "table", "placement": "bottom", "showLegend": true },
"tooltip": { "mode": "multi" }
},
"targets": [
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "sum by (route) (rate(brightstaff_router_decisions_total{fallback=\"true\"}[5m])) / clamp_min(sum by (route) (rate(brightstaff_router_decisions_total[5m])), 1)",
"legendFormat": "{{route}}",
"refId": "A"
}
],
"title": "Fallback rate by route",
"type": "timeseries"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"fieldConfig": {
"defaults": {
"color": { "mode": "palette-classic" },
"custom": { "drawStyle": "line", "fillOpacity": 5, "lineWidth": 1, "showPoints": "never" },
"unit": "s"
}
},
"gridPos": { "h": 8, "w": 12, "x": 0, "y": 55 },
"id": 11,
"options": {
"legend": { "displayMode": "table", "placement": "bottom", "showLegend": true },
"tooltip": { "mode": "multi" }
},
"targets": [
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "histogram_quantile(0.95, sum by (le, route) (rate(brightstaff_router_decision_duration_seconds_bucket[5m])))",
"legendFormat": "p95 {{route}}",
"refId": "A"
}
],
"title": "Router decision p95 latency",
"type": "timeseries"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"description": "Hit / (hit + miss). Low ratio = sessions aren't being reused or TTL too short.",
"fieldConfig": {
"defaults": {
"color": { "mode": "thresholds" },
"thresholds": {
"mode": "absolute",
"steps": [
{ "color": "red", "value": null },
{ "color": "yellow", "value": 0.5 },
{ "color": "green", "value": 0.8 }
]
},
"unit": "percentunit",
"min": 0,
"max": 1
}
},
"gridPos": { "h": 8, "w": 6, "x": 12, "y": 55 },
"id": 12,
"options": {
"colorMode": "background",
"graphMode": "area",
"reduceOptions": { "calcs": ["lastNotNull"], "fields": "", "values": false }
},
"targets": [
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "sum(rate(brightstaff_session_cache_events_total{outcome=\"hit\"}[5m])) / clamp_min(sum(rate(brightstaff_session_cache_events_total{outcome=~\"hit|miss\"}[5m])), 1)",
"legendFormat": "hit rate",
"refId": "A"
}
],
"title": "Session cache hit rate",
"type": "stat"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"description": "decision_served = a real model picked. no_candidates = sentinel `none` returned. policy_error = orchestrator failed.",
"fieldConfig": {
"defaults": {
"color": { "mode": "palette-classic" },
"custom": { "drawStyle": "line", "fillOpacity": 30, "lineWidth": 1, "showPoints": "never", "stacking": { "mode": "normal" } },
"unit": "reqps"
}
},
"gridPos": { "h": 8, "w": 6, "x": 18, "y": 55 },
"id": 13,
"options": {
"legend": { "displayMode": "list", "placement": "bottom", "showLegend": true },
"tooltip": { "mode": "multi" }
},
"targets": [
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "sum by (outcome) (rate(brightstaff_routing_service_requests_total[5m]))",
"legendFormat": "{{outcome}}",
"refId": "A"
}
],
"title": "/routing/* outcomes",
"type": "timeseries"
},
{
"collapsed": false,
"gridPos": { "h": 1, "w": 24, "x": 0, "y": 63 },
"id": 400,
"panels": [],
"title": "Process & Envoy link",
"type": "row"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"description": "Compare to brightstaff RPS (panel 1) — sustained gap = network or Envoy queueing.",
"fieldConfig": {
"defaults": {
"color": { "mode": "palette-classic" },
"custom": { "drawStyle": "line", "fillOpacity": 10, "lineWidth": 1, "showPoints": "never" },
"unit": "reqps"
}
},
"gridPos": { "h": 8, "w": 12, "x": 0, "y": 64 },
"id": 14,
"options": {
"legend": { "displayMode": "table", "placement": "bottom", "showLegend": true },
"tooltip": { "mode": "multi" }
},
"targets": [
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "sum(rate(envoy_cluster_upstream_rq_total{envoy_cluster_name=\"bright_staff\"}[1m]))",
"legendFormat": "envoy → bright_staff",
"refId": "A"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "sum(rate(brightstaff_http_requests_total[1m]))",
"legendFormat": "brightstaff served",
"refId": "B"
}
],
"title": "Envoy → brightstaff link health",
"type": "timeseries"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"fieldConfig": {
"defaults": {
"color": { "mode": "palette-classic" },
"custom": { "drawStyle": "line", "fillOpacity": 10, "lineWidth": 1, "showPoints": "never" }
},
"overrides": [
{
"matcher": { "id": "byName", "options": "RSS" },
"properties": [{ "id": "unit", "value": "bytes" }]
},
{
"matcher": { "id": "byName", "options": "CPU" },
"properties": [{ "id": "unit", "value": "percentunit" }]
}
]
},
"gridPos": { "h": 8, "w": 12, "x": 12, "y": 64 },
"id": 15,
"options": {
"legend": { "displayMode": "table", "placement": "bottom", "showLegend": true },
"tooltip": { "mode": "multi" }
},
"targets": [
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "process_resident_memory_bytes{job=\"brightstaff\"}",
"legendFormat": "RSS",
"refId": "A"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "rate(process_cpu_seconds_total{job=\"brightstaff\"}[1m])",
"legendFormat": "CPU",
"refId": "B"
}
],
"title": "Brightstaff process RSS / CPU",
"type": "timeseries"
}
],
"refresh": "30s",
"schemaVersion": 39,
"tags": ["plano", "brightstaff", "llm"],
"templating": {
"list": [
{
"name": "DS_PROMETHEUS",
"label": "Prometheus",
"type": "datasource",
"query": "prometheus",
"current": { "selected": false, "text": "Prometheus", "value": "DS_PROMETHEUS" },
"hide": 0,
"refresh": 1,
"regex": "",
"skipUrlSync": false,
"includeAll": false,
"multi": false
}
]
},
"time": { "from": "now-1h", "to": "now" },
"timepicker": {},
"timezone": "browser",
"title": "Brightstaff (Plano dataplane)",
"uid": "brightstaff",
"version": 1,
"weekStart": ""
}
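The stat panels above (the 5xx error rate and the session-cache hit rate) both divide one rate by another and guard the denominator with PromQL's `clamp_min(..., 1)` so that an idle service reads 0 instead of NaN. A minimal Python sketch of the same guard; the function name is illustrative, not part of the dashboard:

```python
def guarded_ratio(numerator: float, denominator: float, floor: float = 1.0) -> float:
    """Mirror PromQL's x / clamp_min(y, floor): never divide by less than floor."""
    return numerator / max(denominator, floor)

# With zero traffic both rates are 0, so the ratio is 0 rather than NaN.
print(guarded_ratio(0.0, 0.0))   # 0.0
# With real traffic the floor is inactive: 2 errors/s over 40 req/s = 5%.
print(guarded_ratio(2.0, 40.0))  # 0.05
```

Note that flooring the denominator at 1 req/s deliberately underreports the ratio at very low traffic, which keeps the stat panels from flapping on a quiet service.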


@ -0,0 +1,43 @@
# One-command Prometheus + Grafana stack for observing a locally-running
# Plano (Envoy admin :9901 + brightstaff :9092 on the host).
#
# cd config/grafana
# docker compose up -d
# open http://localhost:3000 (admin / admin)
#
# Grafana is preloaded with:
# - Prometheus datasource (uid=DS_PROMETHEUS) → http://prometheus:9090
# - Brightstaff dashboard (auto-imported from brightstaff_dashboard.json)
#
# Prometheus scrapes the host's :9092 and :9901 via host.docker.internal.
# On Linux this works because of the `extra_hosts: host-gateway` mapping
# below. On Mac/Win it works natively.
services:
prometheus:
image: prom/prometheus:latest
container_name: plano-prometheus
ports:
- "9090:9090"
volumes:
- ./prometheus_scrape.yaml:/etc/prometheus/prometheus.yml:ro
extra_hosts:
- "host.docker.internal:host-gateway"
restart: unless-stopped
grafana:
image: grafana/grafana:latest
container_name: plano-grafana
ports:
- "3000:3000"
environment:
GF_SECURITY_ADMIN_USER: admin
GF_SECURITY_ADMIN_PASSWORD: admin
GF_AUTH_ANONYMOUS_ENABLED: "true"
GF_AUTH_ANONYMOUS_ORG_ROLE: Viewer
volumes:
- ./provisioning:/etc/grafana/provisioning:ro
- ./brightstaff_dashboard.json:/var/lib/grafana/dashboards/brightstaff_dashboard.json:ro
depends_on:
- prometheus
restart: unless-stopped


@ -0,0 +1,44 @@
# Prometheus config that scrapes Plano (Envoy admin + brightstaff). This is
# a complete Prometheus config — mount it directly at
# /etc/prometheus/prometheus.yml. The included docker-compose.yaml does this
# for you.
#
# Targets:
# - envoy:9901 Envoy admin → envoy_cluster_*, envoy_http_*, envoy_server_*.
# - brightstaff:9092 Native dataplane → brightstaff_http_*, brightstaff_llm_*,
# brightstaff_router_*, process_*.
#
# Hostname `host.docker.internal` works on Docker Desktop (Mac/Win) and on
# Linux when the container is started with `--add-host=host.docker.internal:
# host-gateway` (the included compose does this). If Plano runs *inside*
# Docker on the same network as Prometheus, replace it with the container
# name (e.g. `plano:9092`).
#
# This file is unrelated to demos/llm_routing/model_routing_service/prometheus.yaml,
# which scrapes a fake metrics service to feed the routing engine.
global:
scrape_interval: 15s
scrape_timeout: 10s
evaluation_interval: 15s
scrape_configs:
- job_name: envoy
honor_timestamps: true
metrics_path: /stats
params:
format: ["prometheus"]
static_configs:
- targets:
- host.docker.internal:9901
labels:
service: plano
- job_name: brightstaff
honor_timestamps: true
metrics_path: /metrics
static_configs:
- targets:
- host.docker.internal:9092
labels:
service: plano
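Each scrape job above resolves to a pull URL of the form `http://{target}{metrics_path}?{params}` (so the envoy job hits `/stats?format=prometheus` on :9901 and the brightstaff job hits `/metrics` on :9092). A simplified sketch of that resolution, ignoring schemes, TLS, and relabeling:

```python
from urllib.parse import urlencode

def scrape_url(target, metrics_path="/metrics", params=None):
    """Build the effective scrape URL for one static target."""
    query = f"?{urlencode(params, doseq=True)}" if params else ""
    return f"http://{target}{metrics_path}{query}"

print(scrape_url("host.docker.internal:9901", "/stats", {"format": ["prometheus"]}))
# -> http://host.docker.internal:9901/stats?format=prometheus
print(scrape_url("host.docker.internal:9092"))
# -> http://host.docker.internal:9092/metrics
```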


@ -0,0 +1,15 @@
# Auto-load the brightstaff dashboard JSON on Grafana startup.
apiVersion: 1
providers:
- name: brightstaff
orgId: 1
folder: Plano
type: file
disableDeletion: false
updateIntervalSeconds: 30
allowUiUpdates: true
options:
path: /var/lib/grafana/dashboards
foldersFromFilesStructure: false


@ -0,0 +1,14 @@
# Auto-provision the Prometheus datasource so the bundled dashboard wires up
# without any clicks. The `uid: DS_PROMETHEUS` matches the templated input in
# brightstaff_dashboard.json.
apiVersion: 1
datasources:
- name: Prometheus
uid: DS_PROMETHEUS
type: prometheus
access: proxy
url: http://prometheus:9090
isDefault: true
editable: true


@ -9,6 +9,7 @@ properties:
- 0.1-beta
- 0.2.0
- v0.3.0
- v0.4.0
agents:
type: array
@ -85,7 +86,7 @@ properties:
type: string
default:
type: boolean
filter_chain:
input_filters:
type: array
items:
type: string
@ -93,6 +94,14 @@ properties:
required:
- id
- description
input_filters:
type: array
items:
type: string
output_filters:
type: array
items:
type: string
additionalProperties: false
required:
- type
@ -173,15 +182,26 @@ properties:
provider_interface:
type: string
enum:
- arch
- plano
- claude
- deepseek
- groq
- mistral
- openai
- xiaomi
- gemini
- chatgpt
- digitalocean
- vercel
- openrouter
headers:
type: object
additionalProperties:
type: string
description: "Additional headers to send with upstream requests (e.g., ChatGPT-Account-Id, originator)."
routing_preferences:
type: array
description: "[DEPRECATED] Inline routing_preferences under a model_provider are auto-migrated to the top-level routing_preferences list by the config generator. New configs should declare routing_preferences at the top level with an explicit models: [...] list. See docs/routing-api.md."
items:
type: object
properties:
@ -220,15 +240,26 @@ properties:
provider_interface:
type: string
enum:
- arch
- plano
- claude
- deepseek
- groq
- mistral
- openai
- xiaomi
- gemini
- chatgpt
- digitalocean
- vercel
- openrouter
headers:
type: object
additionalProperties:
type: string
description: "Additional headers to send with upstream requests (e.g., ChatGPT-Account-Id, originator)."
routing_preferences:
type: array
description: "[DEPRECATED] Inline routing_preferences under an llm_provider are auto-migrated to the top-level routing_preferences list by the config generator. New configs should declare routing_preferences at the top level with an explicit models: [...] list. See docs/routing-api.md."
items:
type: object
properties:
@ -265,12 +296,24 @@ properties:
type: boolean
use_agent_orchestrator:
type: boolean
disable_signals:
type: boolean
description: "Disable agentic signal analysis (frustration, repetition, escalation, etc.) on LLM responses to save CPU. Default false."
upstream_connect_timeout:
type: string
description: "Connect timeout for upstream provider clusters (e.g., '5s', '10s'). Default is '5s'."
upstream_tls_ca_path:
type: string
description: "Path to the trusted CA bundle for upstream TLS verification. Default is '/etc/ssl/certs/ca-certificates.crt'."
llm_routing_model:
type: string
description: "Model name for the LLM router (e.g., 'Plano-Orchestrator'). Must match a model in model_providers."
agent_orchestration_model:
type: string
description: "Model name for the agent orchestrator (e.g., 'Plano-Orchestrator'). Must match a model in model_providers."
orchestrator_model_context_length:
type: integer
description: "Maximum token length for the orchestrator/routing model context window. Default is 8192."
system_prompt:
type: string
prompt_targets:
@ -415,6 +458,34 @@ properties:
type: string
model:
type: string
session_ttl_seconds:
type: integer
minimum: 1
description: TTL in seconds for session-pinned routing cache entries. Default 600 (10 minutes).
session_max_entries:
type: integer
minimum: 1
maximum: 10000
description: Maximum number of session-pinned routing cache entries. Default 10000.
session_cache:
type: object
properties:
type:
type: string
enum:
- memory
- redis
default: memory
description: Session cache backend. "memory" (default) is in-process; "redis" is shared across replicas.
url:
type: string
description: Redis URL, e.g. redis://localhost:6379. Required when type is redis.
tenant_header:
type: string
description: >
Optional HTTP header name whose value is used as a tenant prefix in the cache key.
When set, keys are scoped as plano:affinity:{tenant_id}:{session_id}.
additionalProperties: false
additionalProperties: false
state_storage:
type: object
@ -464,6 +535,88 @@ properties:
additionalProperties: false
required:
- jailbreak
routing_preferences:
type: array
items:
type: object
properties:
name:
type: string
description:
type: string
models:
type: array
items:
type: string
minItems: 1
selection_policy:
type: object
properties:
prefer:
type: string
enum:
- cheapest
- fastest
- none
additionalProperties: false
required:
- prefer
additionalProperties: false
required:
- name
- description
- models
model_metrics_sources:
type: array
items:
oneOf:
- type: object
properties:
type:
type: string
const: cost
provider:
type: string
enum:
- digitalocean
refresh_interval:
type: integer
minimum: 1
description: "Refresh interval in seconds"
model_aliases:
type: object
description: "Map DO catalog keys (lowercase(creator)/model_id) to Plano model names used in routing_preferences. Example: 'openai/openai-gpt-oss-120b: openai/gpt-4o'"
additionalProperties:
type: string
required:
- type
- provider
additionalProperties: false
- type: object
properties:
type:
type: string
const: latency
provider:
type: string
enum:
- prometheus
url:
type: string
query:
type: string
refresh_interval:
type: integer
minimum: 1
description: "Refresh interval in seconds"
required:
- type
- provider
- url
- query
additionalProperties: false
additionalProperties: false
required:
- version
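As a worked example of the `session_cache` rules above (the backend must be `memory` or `redis`, `url` is required only for `redis`, and `additionalProperties: false` rejects unknown keys), here is a small hand-rolled check. It is a sketch of the constraints, not the project's actual validator:

```python
def check_session_cache(cfg: dict) -> list:
    """Return a list of violations of the session_cache sub-schema."""
    errors = []
    backend = cfg.get("type", "memory")          # schema default: memory
    if backend not in ("memory", "redis"):
        errors.append(f"unknown backend: {backend!r}")
    if backend == "redis" and "url" not in cfg:
        errors.append("url is required when type is redis")
    allowed = {"type", "url", "tenant_header"}
    for key in cfg:
        if key not in allowed:                   # additionalProperties: false
            errors.append(f"unexpected property: {key!r}")
    return errors

print(check_session_cache({"type": "redis"}))
# -> ['url is required when type is redis']
print(check_session_cache({"type": "redis", "url": "redis://localhost:6379",
                           "tenant_header": "x-tenant-id"}))
# -> []
```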


@ -1,14 +1,33 @@
[supervisord]
nodaemon=true
pidfile=/var/run/supervisord.pid
[program:config_generator]
command=/bin/sh -c "\
uv run python -m planoai.config_generator && \
envsubst < /app/plano_config_rendered.yaml > /app/plano_config_rendered.env_sub.yaml && \
envsubst < /etc/envoy/envoy.yaml > /etc/envoy.env_sub.yaml && \
touch /tmp/config_ready || \
(echo 'Config generation failed, shutting down'; kill -15 $(cat /var/run/supervisord.pid))"
priority=10
autorestart=false
startsecs=0
stdout_logfile=/dev/stdout
redirect_stderr=true
stdout_logfile_maxbytes=0
stderr_logfile_maxbytes=0
[program:brightstaff]
command=sh -c "\
envsubst < /app/plano_config_rendered.yaml > /app/plano_config_rendered.env_sub.yaml && \
while [ ! -f /tmp/config_ready ]; do echo '[brightstaff] Waiting for config generation...'; sleep 0.5; done && \
RUST_LOG=${LOG_LEVEL:-info} \
PLANO_CONFIG_PATH_RENDERED=/app/plano_config_rendered.env_sub.yaml \
/app/brightstaff 2>&1 | \
tee /var/log/brightstaff.log | \
while IFS= read -r line; do echo '[brightstaff]' \"$line\"; done"
while IFS= read -r line; do echo '[brightstaff]' \"$line\"; done; \
echo '[brightstaff] Process exited, shutting down'; kill -15 $(cat /var/run/supervisord.pid)"
priority=20
autorestart=false
stdout_logfile=/dev/stdout
redirect_stderr=true
stdout_logfile_maxbytes=0
@ -16,13 +35,15 @@ stderr_logfile_maxbytes=0
[program:envoy]
command=/bin/sh -c "\
uv run python -m planoai.config_generator && \
envsubst < /etc/envoy/envoy.yaml > /etc/envoy.env_sub.yaml && \
envoy -c /etc/envoy.env_sub.yaml \
--component-log-level wasm:${LOG_LEVEL:-info} \
--log-format '[%%Y-%%m-%%d %%T.%%e][%%l] %%v' 2>&1 | \
while [ ! -f /tmp/config_ready ]; do echo '[plano_logs] Waiting for config generation...'; sleep 0.5; done && \
envoy -c /etc/envoy.env_sub.yaml \
--component-log-level wasm:${LOG_LEVEL:-info} \
--log-format '[%%Y-%%m-%%d %%T.%%e][%%l] %%v' 2>&1 | \
tee /var/log/envoy.log | \
while IFS= read -r line; do echo '[plano_logs]' \"$line\"; done"
while IFS= read -r line; do echo '[plano_logs]' \"$line\"; done; \
echo '[plano_logs] Process exited, shutting down'; kill -15 $(cat /var/run/supervisord.pid)"
priority=20
autorestart=false
stdout_logfile=/dev/stdout
redirect_stderr=true
stdout_logfile_maxbytes=0
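The supervisord programs above serialize startup with a file flag: the config generator touches `/tmp/config_ready`, and both brightstaff and envoy poll for it before starting. A sketch of that readiness gate in Python; the path and timeout values are illustrative:

```python
import os
import tempfile
import time

def wait_for_flag(path, timeout=30.0, interval=0.5):
    """Poll until the flag file exists, mirroring the `while [ ! -f ... ]` loops."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if os.path.exists(path):
            return True
        time.sleep(interval)
    return False

# Simulate the generator touching the flag before a dependent program starts.
flag = os.path.join(tempfile.mkdtemp(), "config_ready")
open(flag, "w").close()          # the `touch /tmp/config_ready` step
assert wait_for_flag(flag, timeout=2.0, interval=0.05)
```

supervisord itself only orders programs by `priority`; it is this flag, together with `autorestart=false` and the explicit `kill -15` on exit, that gives the container its fail-fast semantics.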


@ -1,6 +1,16 @@
#!/bin/bash
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
CLI_DIR="$REPO_ROOT/cli"
# Use uv run if available and cli/ has a pyproject.toml, otherwise fall back to bare python
if command -v uv &> /dev/null && [ -f "$CLI_DIR/pyproject.toml" ]; then
PYTHON_CMD="uv run --directory $CLI_DIR python"
else
PYTHON_CMD="python"
fi
failed_files=()
for file in $(find . -name config.yaml -o -name plano_config_full_reference.yaml); do
@ -14,7 +24,7 @@ for file in $(find . -name config.yaml -o -name plano_config_full_reference.yaml
ENVOY_CONFIG_TEMPLATE_FILE="envoy.template.yaml" \
PLANO_CONFIG_FILE_RENDERED="$rendered_file" \
ENVOY_CONFIG_FILE_RENDERED="/dev/null" \
python -m planoai.config_generator 2>&1 > /dev/null
$PYTHON_CMD -m planoai.config_generator 2>&1 > /dev/null
if [ $? -ne 0 ]; then
echo "Validation failed for $file"

crates/Cargo.lock (generated, 2135 changed lines): file diff suppressed because it is too large.


@ -3,6 +3,18 @@ name = "brightstaff"
version = "0.1.0"
edition = "2021"
[features]
default = ["jemalloc"]
jemalloc = ["tikv-jemallocator", "tikv-jemalloc-ctl"]
[[bin]]
name = "brightstaff"
path = "src/main.rs"
[[bin]]
name = "signals_replay"
path = "src/bin/signals_replay.rs"
[dependencies]
async-openai = "0.30.1"
async-trait = "0.1"
@ -26,6 +38,12 @@ opentelemetry-stdout = "0.31"
opentelemetry_sdk = { version = "0.31", features = ["rt-tokio"] }
pretty_assertions = "1.4.1"
rand = "0.9.2"
regex = "1.10"
lru = "0.12"
metrics = "0.23"
metrics-exporter-prometheus = { version = "0.15", default-features = false, features = ["http-listener"] }
metrics-process = "2.1"
redis = { version = "0.27", features = ["tokio-comp"] }
reqwest = { version = "0.12.15", features = ["stream"] }
serde = { version = "1.0.219", features = ["derive"] }
serde_json = "1.0.140"
@ -33,6 +51,8 @@ serde_with = "3.13.0"
strsim = "0.11"
serde_yaml = "0.9.34"
thiserror = "2.0.12"
tikv-jemallocator = { version = "0.6", optional = true }
tikv-jemalloc-ctl = { version = "0.6", features = ["stats"], optional = true }
tokio = { version = "1.44.2", features = ["full"] }
tokio-postgres = { version = "0.7", features = ["with-serde_json-1"] }
tokio-stream = "0.1"


@ -0,0 +1,30 @@
use std::collections::HashMap;
use std::sync::Arc;
use common::configuration::{Agent, FilterPipeline, Listener, ModelAlias, SpanAttributes};
use common::llm_providers::LlmProviders;
use tokio::sync::RwLock;
use crate::router::orchestrator::OrchestratorService;
use crate::state::StateStorage;
/// Shared application state bundled into a single Arc-wrapped struct.
///
/// Instead of cloning 8+ individual `Arc`s per connection, a single
/// `Arc<AppState>` is cloned once and passed to the request handler.
pub struct AppState {
pub orchestrator_service: Arc<OrchestratorService>,
pub model_aliases: Option<HashMap<String, ModelAlias>>,
pub llm_providers: Arc<RwLock<LlmProviders>>,
pub agents_list: Option<Vec<Agent>>,
pub listeners: Vec<Listener>,
pub state_storage: Option<Arc<dyn StateStorage>>,
pub llm_provider_url: String,
pub span_attributes: Option<SpanAttributes>,
/// Shared HTTP client for upstream LLM requests (connection pooling / keep-alive).
pub http_client: reqwest::Client,
pub filter_pipeline: Arc<FilterPipeline>,
/// When false, agentic signal analysis is skipped on LLM responses to save CPU.
/// Controlled by `overrides.disable_signals` in plano config.
pub signals_enabled: bool,
}


@ -0,0 +1,175 @@
//! `signals-replay` — batch driver for the `brightstaff` signal analyzer.
//!
//! Reads JSONL conversations from stdin (one per line) and emits matching
//! JSONL reports on stdout, one per input conversation, in the same order.
//!
//! Input shape (per line):
//! ```json
//! {"id": "convo-42", "messages": [{"from": "human", "value": "..."}, ...]}
//! ```
//!
//! Output shape (per line, success):
//! ```json
//! {"id": "convo-42", "report": { ...python-compatible SignalReport dict... }}
//! ```
//!
//! On per-line failure (parse / analyzer error), emits:
//! ```json
//! {"id": "convo-42", "error": "..."}
//! ```
//!
//! The output report dict is shaped to match the Python reference's
//! `SignalReport.to_dict()` byte-for-byte so the parity comparator can do a
//! direct structural diff.
use std::io::{self, BufRead, BufWriter, Write};
use serde::Deserialize;
use serde_json::{json, Map, Value};
use brightstaff::signals::{SignalAnalyzer, SignalGroup, SignalReport};
#[derive(Debug, Deserialize)]
struct InputLine {
id: Value,
messages: Vec<MessageRow>,
}
#[derive(Debug, Deserialize)]
struct MessageRow {
#[serde(default)]
from: String,
#[serde(default)]
value: String,
}
fn main() {
let stdin = io::stdin();
let stdout = io::stdout();
let mut out = BufWriter::new(stdout.lock());
let analyzer = SignalAnalyzer::default();
for line in stdin.lock().lines() {
let line = match line {
Ok(l) => l,
Err(e) => {
eprintln!("read error: {e}");
std::process::exit(1);
}
};
let trimmed = line.trim();
if trimmed.is_empty() {
continue;
}
let result = process_line(&analyzer, trimmed);
// Always emit one line per input line so id ordering stays aligned.
if let Err(e) = writeln!(out, "{result}") {
eprintln!("write error: {e}");
std::process::exit(1);
}
        // Flushing per line isn't strictly needed; BufWriter batches writes,
        // and the parent process reads the whole stream once we're done.
}
let _ = out.flush();
}
fn process_line(analyzer: &SignalAnalyzer, line: &str) -> Value {
let parsed: InputLine = match serde_json::from_str(line) {
Ok(p) => p,
Err(e) => {
return json!({
"id": Value::Null,
"error": format!("input parse: {e}"),
});
}
};
let id = parsed.id.clone();
let view: Vec<brightstaff::signals::analyzer::ShareGptMessage<'_>> = parsed
.messages
.iter()
.map(|m| brightstaff::signals::analyzer::ShareGptMessage {
from: m.from.as_str(),
value: m.value.as_str(),
})
.collect();
let report = analyzer.analyze_sharegpt(&view);
let report_dict = report_to_python_dict(&report);
json!({
"id": id,
"report": report_dict,
})
}
/// Convert a `SignalReport` into the Python reference's `to_dict()` shape.
///
/// Ordering of category keys in each layer dict follows the Python source
/// exactly so even string-equality comparisons behave deterministically.
fn report_to_python_dict(r: &SignalReport) -> Value {
let mut interaction = Map::new();
interaction.insert(
"misalignment".to_string(),
signal_group_to_python(&r.interaction.misalignment),
);
interaction.insert(
"stagnation".to_string(),
signal_group_to_python(&r.interaction.stagnation),
);
interaction.insert(
"disengagement".to_string(),
signal_group_to_python(&r.interaction.disengagement),
);
interaction.insert(
"satisfaction".to_string(),
signal_group_to_python(&r.interaction.satisfaction),
);
let mut execution = Map::new();
execution.insert(
"failure".to_string(),
signal_group_to_python(&r.execution.failure),
);
execution.insert(
"loops".to_string(),
signal_group_to_python(&r.execution.loops),
);
let mut environment = Map::new();
environment.insert(
"exhaustion".to_string(),
signal_group_to_python(&r.environment.exhaustion),
);
json!({
"interaction_signals": Value::Object(interaction),
"execution_signals": Value::Object(execution),
"environment_signals": Value::Object(environment),
"overall_quality": r.overall_quality.as_str(),
"summary": r.summary,
})
}
fn signal_group_to_python(g: &SignalGroup) -> Value {
let signals: Vec<Value> = g
.signals
.iter()
.map(|s| {
json!({
"signal_type": s.signal_type.as_str(),
"message_index": s.message_index,
"snippet": s.snippet,
"confidence": s.confidence,
"metadata": s.metadata,
})
})
.collect();
json!({
"category": g.category,
"count": g.count,
"severity": g.severity,
"signals": signals,
})
}
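For reference, an input line matching the `InputLine` / `MessageRow` shape above would look like this (the `id` and message text are hypothetical):

```json
{"id": "convo-42", "messages": [{"from": "human", "value": "That's not what I asked for."}, {"from": "gpt", "value": "Let me try again."}]}
```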


@@ -0,0 +1,41 @@
use bytes::Bytes;
use http_body_util::combinators::BoxBody;
use hyper::Response;
use serde_json::json;
use tracing::{info, warn};
use crate::handlers::response::ResponseHandler;
/// Build a JSON error response from an `AgentFilterChainError`, logging the
/// full error chain along the way.
///
/// Returns `Ok(Response)` so it can be used directly as a handler return value.
pub fn build_error_chain_response<E: std::error::Error>(
err: &E,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let mut error_chain = Vec::new();
let mut current: &dyn std::error::Error = err;
loop {
error_chain.push(current.to_string());
match current.source() {
Some(source) => current = source,
None => break,
}
}
warn!(error_chain = ?error_chain, "agent chat error chain");
warn!(root_error = ?err, "root error");
let error_json = json!({
"error": {
"type": "AgentFilterChainError",
"message": err.to_string(),
"error_chain": error_chain,
"debug_info": format!("{:?}", err)
}
});
info!(error = %error_json, "structured error info");
Ok(ResponseHandler::create_json_error_response(&error_json))
}
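The `source()` walk used by `build_error_chain_response` can be exercised in isolation. A minimal std-only sketch, with hypothetical `Inner`/`Outer` error types standing in for a real `AgentFilterChainError`:

```rust
use std::error::Error;
use std::fmt;

#[derive(Debug)]
struct Inner;
impl fmt::Display for Inner {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "inner failure")
    }
}
impl Error for Inner {}

#[derive(Debug)]
struct Outer(Inner);
impl fmt::Display for Outer {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "outer failure")
    }
}
impl Error for Outer {
    // Expose the wrapped error so the chain walk can descend into it.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        Some(&self.0)
    }
}

// Walk the source() chain the same way build_error_chain_response does,
// collecting one Display string per layer.
fn collect_chain(err: &dyn Error) -> Vec<String> {
    let mut chain = Vec::new();
    let mut current: Option<&dyn Error> = Some(err);
    while let Some(e) = current {
        chain.push(e.to_string());
        current = e.source();
    }
    chain
}

fn main() {
    let chain = collect_chain(&Outer(Inner));
    assert_eq!(chain, ["outer failure", "inner failure"]);
    println!("{chain:?}");
}
```

Because thiserror's `#[from]` wrappers implement `source()`, each wrapping variant in the enum shows up as its own entry in the logged `error_chain`.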


@@ -0,0 +1,5 @@
pub mod errors;
pub mod jsonrpc;
pub mod orchestrator;
pub mod pipeline;
pub mod selector;


@@ -2,63 +2,56 @@ use std::sync::Arc;
use std::time::Instant;
use bytes::Bytes;
use common::configuration::SpanAttributes;
use common::errors::BrightStaffError;
use common::llm_providers::LlmProviders;
use hermesllm::apis::OpenAIMessage;
use hermesllm::clients::SupportedAPIsFromClient;
use hermesllm::providers::request::ProviderRequest;
use hermesllm::ProviderRequestType;
use http_body_util::combinators::BoxBody;
use http_body_util::BodyExt;
use hyper::{Request, Response, StatusCode};
use hyper::{Request, Response};
use opentelemetry::trace::get_active_span;
use serde::ser::Error as SerError;
use tokio::sync::RwLock;
use tracing::{debug, info, info_span, warn, Instrument};
use super::agent_selector::{AgentSelectionError, AgentSelector};
use super::pipeline_processor::{PipelineError, PipelineProcessor};
use super::response_handler::ResponseHandler;
use crate::router::plano_orchestrator::OrchestratorService;
use super::errors::build_error_chain_response;
use super::pipeline::{PipelineError, PipelineProcessor};
use super::selector::{AgentSelectionError, AgentSelector};
use crate::app_state::AppState;
use crate::handlers::extract_request_id;
use crate::handlers::response::ResponseHandler;
use crate::tracing::{collect_custom_trace_attributes, operation_component, set_service_name};
/// Main errors for agent chat completions
#[derive(Debug, thiserror::Error)]
pub enum AgentFilterChainError {
#[error("Forwarded error: {0}")]
Brightstaff(#[from] BrightStaffError),
#[error("Agent selection error: {0}")]
Selection(#[from] AgentSelectionError),
#[error("Pipeline processing error: {0}")]
Pipeline(#[from] PipelineError),
#[error("Response handling error: {0}")]
Response(#[from] common::errors::BrightStaffError),
#[error("Request parsing error: {0}")]
RequestParsing(#[from] serde_json::Error),
RequestParsing(String),
#[error("HTTP error: {0}")]
Http(#[from] hyper::Error),
#[error("Unsupported endpoint: {0}")]
UnsupportedEndpoint(String),
#[error("No agents configured")]
NoAgentsConfigured,
#[error("Agent '{0}' not found in configuration")]
AgentNotFound(String),
#[error("No messages in conversation history")]
EmptyHistory,
#[error("Agent chain completed without producing a response")]
IncompleteChain,
}
pub async fn agent_chat(
request: Request<hyper::body::Incoming>,
orchestrator_service: Arc<OrchestratorService>,
_: String,
agents_list: Arc<tokio::sync::RwLock<Option<Vec<common::configuration::Agent>>>>,
listeners: Arc<tokio::sync::RwLock<Vec<common::configuration::Listener>>>,
span_attributes: Arc<Option<SpanAttributes>>,
llm_providers: Arc<RwLock<LlmProviders>>,
state: Arc<AppState>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let request_id = extract_request_id(&request);
let custom_attrs =
collect_custom_trace_attributes(request.headers(), span_attributes.as_ref().as_ref());
// Extract request_id from headers or generate a new one
let request_id: String = match request
.headers()
.get(common::consts::REQUEST_ID_HEADER)
.and_then(|h| h.to_str().ok())
.map(|s| s.to_string())
{
Some(id) => id,
None => uuid::Uuid::new_v4().to_string(),
};
collect_custom_trace_attributes(request.headers(), state.span_attributes.as_ref());
// Create a span with request_id that will be included in all log lines
let request_span = info_span!(
@@ -74,17 +67,7 @@ pub async fn agent_chat(
// Set service name for orchestrator operations
set_service_name(operation_component::ORCHESTRATOR);
match handle_agent_chat_inner(
request,
orchestrator_service,
agents_list,
listeners,
llm_providers,
request_id,
custom_attrs,
)
.await
{
match handle_agent_chat_inner(request, state, request_id, custom_attrs).await {
Ok(response) => Ok(response),
Err(err) => {
// Check if this is a client error from the pipeline that should be cascaded
@@ -101,7 +84,6 @@ pub async fn agent_chat(
"client error from agent"
);
// Create error response with the original status code and body
let error_json = serde_json::json!({
"error": "ClientError",
"agent": agent,
@@ -109,52 +91,19 @@ pub async fn agent_chat(
"agent_response": body
});
let status_code = hyper::StatusCode::from_u16(*status)
.unwrap_or(hyper::StatusCode::INTERNAL_SERVER_ERROR);
let json_string = error_json.to_string();
return Ok(BrightStaffError::ForwardedError {
status_code,
message: json_string,
}
.into_response());
let mut response =
Response::new(ResponseHandler::create_full_body(json_string));
*response.status_mut() = hyper::StatusCode::from_u16(*status)
.unwrap_or(hyper::StatusCode::INTERNAL_SERVER_ERROR);
response.headers_mut().insert(
hyper::header::CONTENT_TYPE,
hyper::header::HeaderValue::from_static("application/json"),
);
return Ok(response);
}
// Print detailed error information with full error chain for other errors
let mut error_chain = Vec::new();
let mut current_error: &dyn std::error::Error = &err;
// Collect the full error chain
loop {
error_chain.push(current_error.to_string());
match current_error.source() {
Some(source) => current_error = source,
None => break,
}
}
// Log the complete error chain
warn!(error_chain = ?error_chain, "agent chat error chain");
warn!(root_error = ?err, "root error");
// Create structured error response as JSON
let error_json = serde_json::json!({
"error": {
"type": "AgentFilterChainError",
"message": err.to_string(),
"error_chain": error_chain,
"debug_info": format!("{:?}", err)
}
});
// Log the error for debugging
info!(error = %error_json, "structured error info");
Ok(BrightStaffError::ForwardedError {
status_code: StatusCode::BAD_REQUEST,
message: error_json.to_string(),
}
.into_response())
build_error_chain_response(&err)
}
}
}
@@ -162,19 +111,22 @@ pub async fn agent_chat(
.await
}
async fn handle_agent_chat_inner(
/// Parsed and validated agent request data.
struct AgentRequest {
client_request: ProviderRequestType,
messages: Vec<OpenAIMessage>,
request_headers: hyper::HeaderMap,
request_id: Option<String>,
}
/// Parse the incoming HTTP request, resolve the listener, and extract messages.
async fn parse_agent_request(
request: Request<hyper::body::Incoming>,
orchestrator_service: Arc<OrchestratorService>,
agents_list: Arc<tokio::sync::RwLock<Option<Vec<common::configuration::Agent>>>>,
listeners: Arc<tokio::sync::RwLock<Vec<common::configuration::Listener>>>,
llm_providers: Arc<RwLock<LlmProviders>>,
request_id: String,
custom_attrs: std::collections::HashMap<String, String>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, AgentFilterChainError> {
// Initialize services
let agent_selector = AgentSelector::new(orchestrator_service);
let mut pipeline_processor = PipelineProcessor::default();
let response_handler = ResponseHandler::new();
state: &AppState,
request_id: &str,
custom_attrs: &std::collections::HashMap<String, String>,
) -> Result<(AgentRequest, common::configuration::Listener, AgentSelector), AgentFilterChainError> {
let agent_selector = AgentSelector::new(Arc::clone(&state.orchestrator_service));
// Extract listener name from headers
let listener_name = request
@@ -183,16 +135,11 @@ async fn handle_agent_chat_inner(
.and_then(|name| name.to_str().ok());
// Find the appropriate listener
let listener: common::configuration::Listener = {
let listeners = listeners.read().await;
agent_selector
.find_listener(listener_name, &listeners)
.await?
};
let listener = agent_selector.find_listener(listener_name, &state.listeners)?;
get_active_span(|span| {
span.update_name(listener.name.to_string());
for (key, value) in &custom_attrs {
for (key, value) in custom_attrs {
span.set_attribute(opentelemetry::KeyValue::new(key.clone(), value.clone()));
}
});
@@ -200,24 +147,20 @@ async fn handle_agent_chat_inner(
info!(listener = %listener.name, "handling request");
// Parse request body
let request_path = request
.uri()
.path()
.to_string()
let full_path = request.uri().path().to_string();
let request_path = full_path
.strip_prefix("/agents")
.unwrap()
.unwrap_or(&full_path)
.to_string();
let request_headers = {
let mut headers = request.headers().clone();
headers.remove(common::consts::ENVOY_ORIGINAL_PATH_HEADER);
// Set the request_id in headers if not already present
if !headers.contains_key(common::consts::REQUEST_ID_HEADER) {
headers.insert(
common::consts::REQUEST_ID_HEADER,
hyper::header::HeaderValue::from_str(&request_id).unwrap(),
);
if let Ok(val) = hyper::header::HeaderValue::from_str(request_id) {
headers.insert(common::consts::REQUEST_ID_HEADER, val);
}
}
headers
@@ -230,63 +173,62 @@ async fn handle_agent_chat_inner(
"received request body"
);
// Determine the API type from the endpoint
let api_type =
SupportedAPIsFromClient::from_endpoint(request_path.as_str()).ok_or_else(|| {
let err_msg = format!("Unsupported endpoint: {}", request_path);
warn!("{}", err_msg);
AgentFilterChainError::RequestParsing(serde_json::Error::custom(err_msg))
warn!(path = %request_path, "unsupported endpoint");
AgentFilterChainError::UnsupportedEndpoint(request_path.clone())
})?;
let mut client_request =
match ProviderRequestType::try_from((&chat_request_bytes[..], &api_type)) {
Ok(request) => request,
Err(err) => {
warn!("failed to parse request as ProviderRequestType: {}", err);
let err_msg = format!("Failed to parse request: {}", err);
return Err(AgentFilterChainError::RequestParsing(
serde_json::Error::custom(err_msg),
));
}
};
let client_request = ProviderRequestType::try_from((&chat_request_bytes[..], &api_type))
.map_err(|err| {
warn!(error = %err, "failed to parse request as ProviderRequestType");
AgentFilterChainError::RequestParsing(format!("Failed to parse request: {}", err))
})?;
// If model is not specified in the request, resolve from default provider
if client_request.model().is_empty() {
match llm_providers.read().await.default() {
Some(default_provider) => {
let default_model = default_provider.name.clone();
info!(default_model = %default_model, "no model specified in request, using default provider");
client_request.set_model(default_model);
}
None => {
let err_msg = "No model specified in request and no default provider configured";
warn!("{}", err_msg);
return Ok(BrightStaffError::NoModelSpecified.into_response());
}
}
}
let message: Vec<OpenAIMessage> = client_request.get_messages();
let messages: Vec<OpenAIMessage> = client_request.get_messages();
let request_id = request_headers
.get(common::consts::REQUEST_ID_HEADER)
.and_then(|val| val.to_str().ok())
.map(|s| s.to_string());
// Create agent map for pipeline processing and agent selection
let agent_map = {
let agents = agents_list.read().await;
let agents = agents.as_ref().unwrap();
agent_selector.create_agent_map(agents)
};
Ok((
AgentRequest {
client_request,
messages,
request_headers,
request_id,
},
listener,
agent_selector,
))
}
/// Select agents via the orchestrator model and record selection metrics.
async fn select_and_build_agent_map(
agent_selector: &AgentSelector,
state: &AppState,
messages: &[OpenAIMessage],
listener: &common::configuration::Listener,
request_id: Option<String>,
) -> Result<
(
Vec<common::configuration::AgentFilterChain>,
std::collections::HashMap<String, common::configuration::Agent>,
),
AgentFilterChainError,
> {
let agents = state
.agents_list
.as_ref()
.ok_or(AgentFilterChainError::NoAgentsConfigured)?;
let agent_map = agent_selector.create_agent_map(agents);
// Select appropriate agents using arch orchestrator llm model
let selection_start = Instant::now();
let selected_agents = agent_selector
.select_agents(&message, &listener, request_id.clone())
.select_agents(messages, listener, request_id)
.await?;
// Record selection attributes on the current orchestrator span
let selection_elapsed_ms = selection_start.elapsed().as_secs_f64() * 1000.0;
get_active_span(|span| {
span.set_attribute(opentelemetry::KeyValue::new(
@@ -316,12 +258,25 @@ async fn handle_agent_chat_inner(
"selected agents for execution"
);
// Execute agents sequentially, passing output from one to the next
let mut current_messages = message.clone();
Ok((selected_agents, agent_map))
}
/// Execute the agent chain: run each selected agent sequentially, streaming
/// the final agent's response back to the client.
async fn execute_agent_chain(
selected_agents: &[common::configuration::AgentFilterChain],
agent_map: &std::collections::HashMap<String, common::configuration::Agent>,
client_request: ProviderRequestType,
messages: Vec<OpenAIMessage>,
request_headers: &hyper::HeaderMap,
custom_attrs: &std::collections::HashMap<String, String>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, AgentFilterChainError> {
let mut pipeline_processor = PipelineProcessor::default();
let response_handler = ResponseHandler::new();
let mut current_messages = messages;
let agent_count = selected_agents.len();
for (agent_index, selected_agent) in selected_agents.iter().enumerate() {
// Get agent name
let agent_name = selected_agent.id.clone();
let is_last_agent = agent_index == agent_count - 1;
@@ -332,18 +287,40 @@ async fn handle_agent_chat_inner(
"processing agent"
);
// Process the filter chain
let chat_history = pipeline_processor
.process_filter_chain(
&current_messages,
selected_agent,
&agent_map,
&request_headers,
)
.await?;
let chat_history = if selected_agent
.input_filters
.as_ref()
.map(|f| !f.is_empty())
.unwrap_or(false)
{
let filter_body = serde_json::json!({
"model": client_request.model(),
"messages": current_messages,
});
let filter_bytes =
serde_json::to_vec(&filter_body).map_err(PipelineError::ParseError)?;
// Get agent details and invoke
let agent = agent_map.get(&agent_name).unwrap();
let filtered_bytes = pipeline_processor
.process_raw_filter_chain(
&filter_bytes,
selected_agent,
agent_map,
request_headers,
"/v1/chat/completions",
)
.await?;
let filtered_body: serde_json::Value =
serde_json::from_slice(&filtered_bytes).map_err(PipelineError::ParseError)?;
serde_json::from_value(filtered_body["messages"].clone())
.map_err(PipelineError::ParseError)?
} else {
current_messages.clone()
};
let agent = agent_map
.get(&agent_name)
.ok_or_else(|| AgentFilterChainError::AgentNotFound(agent_name.clone()))?;
debug!(agent = %agent_name, "invoking agent");
@@ -357,7 +334,7 @@ async fn handle_agent_chat_inner(
set_service_name(operation_component::AGENT);
get_active_span(|span| {
span.update_name(format!("{} /v1/chat/completions", agent_name));
for (key, value) in &custom_attrs {
for (key, value) in custom_attrs {
span.set_attribute(opentelemetry::KeyValue::new(key.clone(), value.clone()));
}
});
@@ -367,28 +344,25 @@ async fn handle_agent_chat_inner(
&chat_history,
client_request.clone(),
agent,
&request_headers,
request_headers,
)
.await
}
.instrument(agent_span.clone())
.await?;
// If this is the last agent, return the streaming response
if is_last_agent {
info!(
agent = %agent_name,
"completed agent chain, returning response"
);
// Capture the orchestrator span (parent of the agent span) so it
// stays open for the full streaming duration alongside the agent span.
let orchestrator_span = tracing::Span::current();
return async {
response_handler
.create_streaming_response(
llm_response,
tracing::Span::current(), // agent span (inner)
orchestrator_span, // orchestrator span (outer)
tracing::Span::current(),
orchestrator_span,
)
.await
.map_err(AgentFilterChainError::from)
@@ -397,7 +371,6 @@ async fn handle_agent_chat_inner(
.await;
}
// For intermediate agents, collect the full response and pass to next agent
debug!(agent = %agent_name, "collecting response from intermediate agent");
let response_text = async { response_handler.collect_full_response(llm_response).await }
.instrument(agent_span)
@@ -409,11 +382,11 @@ async fn handle_agent_chat_inner(
"agent completed, passing response to next agent"
);
// remove last message and add new one at the end
let last_message = current_messages.pop().unwrap();
let Some(last_message) = current_messages.pop() else {
warn!(agent = %agent_name, "no messages in conversation history");
return Err(AgentFilterChainError::EmptyHistory);
};
// Create a new message with the agent's response as assistant message
// and add it to the conversation history
current_messages.push(OpenAIMessage {
role: hermesllm::apis::openai::Role::Assistant,
content: Some(hermesllm::apis::openai::MessageContent::Text(response_text)),
@@ -425,6 +398,34 @@ async fn handle_agent_chat_inner(
current_messages.push(last_message);
}
// This should never be reached since we return in the last agent iteration
unreachable!("Agent execution loop should have returned a response")
Err(AgentFilterChainError::IncompleteChain)
}
async fn handle_agent_chat_inner(
request: Request<hyper::body::Incoming>,
state: Arc<AppState>,
request_id: String,
custom_attrs: std::collections::HashMap<String, String>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, AgentFilterChainError> {
let (agent_req, listener, agent_selector) =
parse_agent_request(request, &state, &request_id, &custom_attrs).await?;
let (selected_agents, agent_map) = select_and_build_agent_map(
&agent_selector,
&state,
&agent_req.messages,
&listener,
agent_req.request_id,
)
.await?;
execute_agent_chain(
&selected_agents,
&agent_map,
agent_req.client_request,
agent_req.messages,
&agent_req.request_headers,
&custom_attrs,
)
.await
}


@@ -1,5 +1,6 @@
use std::collections::HashMap;
use bytes::Bytes;
use common::configuration::{Agent, AgentFilterChain};
use common::consts::{
ARCH_UPSTREAM_HOST_HEADER, BRIGHT_STAFF_SERVICE_NAME, ENVOY_RETRY_HEADER, TRACE_PARENT_HEADER,
@@ -11,7 +12,7 @@ use opentelemetry::global;
use opentelemetry_http::HeaderInjector;
use tracing::{debug, info, instrument, warn};
use crate::handlers::jsonrpc::{
use super::jsonrpc::{
JsonRpcId, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, JSON_RPC_VERSION,
MCP_INITIALIZE, MCP_INITIALIZE_NOTIFICATION, TOOL_CALL_METHOD,
};
@@ -35,8 +36,6 @@ pub enum PipelineError {
NoResultInResponse(String),
#[error("No structured content in response from agent '{0}'")]
NoStructuredContentInResponse(String),
#[error("No messages in response from agent '{0}'")]
NoMessagesInResponse(String),
#[error("Client error from agent '{agent}' (HTTP {status}): {body}")]
ClientError {
agent: String,
@@ -79,74 +78,11 @@ impl PipelineProcessor {
}
}
// /// Process the filter chain of agents (all except the terminal agent)
// #[instrument(
// skip(self, chat_history, agent_filter_chain, agent_map, request_headers),
// fields(
// filter_count = agent_filter_chain.filter_chain.as_ref().map(|fc| fc.len()).unwrap_or(0),
// message_count = chat_history.len()
// )
// )]
#[allow(clippy::too_many_arguments)]
pub async fn process_filter_chain(
&mut self,
chat_history: &[Message],
agent_filter_chain: &AgentFilterChain,
agent_map: &HashMap<String, Agent>,
request_headers: &HeaderMap,
) -> Result<Vec<Message>, PipelineError> {
let mut chat_history_updated = chat_history.to_vec();
// If filter_chain is None or empty, proceed without filtering
let filter_chain = match agent_filter_chain.filter_chain.as_ref() {
Some(fc) if !fc.is_empty() => fc,
_ => return Ok(chat_history_updated),
};
for agent_name in filter_chain {
debug!(agent = %agent_name, "processing filter agent");
let agent = agent_map
.get(agent_name)
.ok_or_else(|| PipelineError::AgentNotFound(agent_name.clone()))?;
let tool_name = agent.tool.as_deref().unwrap_or(&agent.id);
info!(
agent = %agent_name,
tool = %tool_name,
url = %agent.url,
agent_type = %agent.agent_type.as_deref().unwrap_or("mcp"),
conversation_len = chat_history.len(),
"executing filter"
);
if agent.agent_type.as_deref().unwrap_or("mcp") == "mcp" {
chat_history_updated = self
.execute_mcp_filter(&chat_history_updated, agent, request_headers)
.await?;
} else {
chat_history_updated = self
.execute_http_filter(&chat_history_updated, agent, request_headers)
.await?;
}
info!(
agent = %agent_name,
updated_len = chat_history_updated.len(),
"filter completed"
);
}
Ok(chat_history_updated)
}
/// Build common MCP headers for requests
fn build_mcp_headers(
&self,
/// Prepare headers shared by all agent/filter requests: remove
/// content-length, inject trace context, set upstream host and retry.
fn build_agent_headers(
request_headers: &HeaderMap,
agent_id: &str,
session_id: Option<&str>,
) -> Result<HeaderMap, PipelineError> {
let mut headers = request_headers.clone();
headers.remove(hyper::header::CONTENT_LENGTH);
@@ -167,24 +103,34 @@ impl PipelineProcessor {
headers.insert(
ENVOY_RETRY_HEADER,
hyper::header::HeaderValue::from_str("3").unwrap(),
hyper::header::HeaderValue::from_static("3"),
);
Ok(headers)
}
/// Build headers for MCP requests (adds Accept, Content-Type, optional session id).
fn build_mcp_headers(
&self,
request_headers: &HeaderMap,
agent_id: &str,
session_id: Option<&str>,
) -> Result<HeaderMap, PipelineError> {
let mut headers = Self::build_agent_headers(request_headers, agent_id)?;
headers.insert(
"Accept",
hyper::header::HeaderValue::from_static("application/json, text/event-stream"),
);
headers.insert(
"Content-Type",
hyper::header::HeaderValue::from_static("application/json"),
);
if let Some(sid) = session_id {
headers.insert(
"mcp-session-id",
hyper::header::HeaderValue::from_str(sid).unwrap(),
);
if let Ok(val) = hyper::header::HeaderValue::from_str(sid) {
headers.insert("mcp-session-id", val);
}
}
Ok(headers)
@@ -261,14 +207,17 @@ impl PipelineProcessor {
Ok(response)
}
/// Build a tools/call JSON-RPC request
fn build_tool_call_request(
/// Build a tools/call JSON-RPC request with a full body dict and path hint.
/// Used by execute_mcp_filter_raw so MCP tools receive the same contract as HTTP filters.
fn build_tool_call_request_with_body(
&self,
tool_name: &str,
messages: &[Message],
body: &serde_json::Value,
path: &str,
) -> Result<JsonRpcRequest, PipelineError> {
let mut arguments = HashMap::new();
arguments.insert("messages".to_string(), serde_json::to_value(messages)?);
arguments.insert("body".to_string(), serde_json::to_value(body)?);
arguments.insert("path".to_string(), serde_json::to_value(path)?);
let mut params = HashMap::new();
params.insert("name".to_string(), serde_json::to_value(tool_name)?);
@@ -282,35 +231,28 @@ impl PipelineProcessor {
})
}
/// Send request to a specific agent and return the response content
#[instrument(
skip(self, messages, agent, request_headers),
fields(
agent_id = %agent.id,
filter_name = %agent.id,
message_count = messages.len()
)
)]
async fn execute_mcp_filter(
/// Like the former execute_mcp_filter, but passes the full raw body dict + path hint as MCP tool arguments.
/// The MCP tool receives (body: dict, path: str) and returns the modified body dict.
async fn execute_mcp_filter_raw(
&mut self,
messages: &[Message],
raw_bytes: &[u8],
agent: &Agent,
request_headers: &HeaderMap,
) -> Result<Vec<Message>, PipelineError> {
// Set service name for this filter span
request_path: &str,
) -> Result<Bytes, PipelineError> {
set_service_name(operation_component::AGENT_FILTER);
// Update current span name to include filter name
use opentelemetry::trace::get_active_span;
get_active_span(|span| {
span.update_name(format!("execute_mcp_filter ({})", agent.id));
span.update_name(format!("execute_mcp_filter_raw ({})", agent.id));
});
// Get or create MCP session
let body: serde_json::Value =
serde_json::from_slice(raw_bytes).map_err(PipelineError::ParseError)?;
let mcp_session_id = if let Some(session_id) = self.agent_id_session_map.get(&agent.id) {
session_id.clone()
} else {
let session_id = self.get_new_session_id(&agent.id, request_headers).await;
let session_id = self.get_new_session_id(&agent.id, request_headers).await?;
self.agent_id_session_map
.insert(agent.id.clone(), session_id.clone());
session_id
@@ -321,11 +263,10 @@ impl PipelineProcessor {
mcp_session_id, agent.id
);
// Build JSON-RPC request
let tool_name = agent.tool.as_deref().unwrap_or(&agent.id);
let json_rpc_request = self.build_tool_call_request(tool_name, messages)?;
let json_rpc_request =
self.build_tool_call_request_with_body(tool_name, &body, request_path)?;
// Build headers
let agent_headers =
self.build_mcp_headers(request_headers, &agent.id, Some(&mcp_session_id))?;
@@ -335,7 +276,6 @@ impl PipelineProcessor {
let http_status = response.status();
let response_bytes = response.bytes().await?;
// Handle HTTP errors
if !http_status.is_success() {
let error_body = String::from_utf8_lossy(&response_bytes).to_string();
return Err(if http_status.is_client_error() {
@@ -353,20 +293,12 @@ impl PipelineProcessor {
});
}
info!(
"Response from agent {}: {}",
agent.id,
String::from_utf8_lossy(&response_bytes)
);
// Parse SSE response
let data_chunk = self.parse_sse_response(&response_bytes, &agent.id)?;
let response: JsonRpcResponse = serde_json::from_str(&data_chunk)?;
let response_result = response
.result
.ok_or_else(|| PipelineError::NoResultInResponse(agent.id.clone()))?;
// Check if error field is set in response result
if response_result
.get("isError")
.and_then(|v| v.as_bool())
@@ -388,21 +320,28 @@ impl PipelineProcessor {
});
}
// Extract structured content and parse messages
let response_json = response_result
// FastMCP puts structured Pydantic return values in structuredContent.result,
// but plain dicts land in content[0].text as a JSON string. Try both.
let result = if let Some(structured) = response_result
.get("structuredContent")
.ok_or_else(|| PipelineError::NoStructuredContentInResponse(agent.id.clone()))?;
.and_then(|v| v.get("result"))
.cloned()
{
structured
} else {
let text = response_result
.get("content")
.and_then(|v| v.as_array())
.and_then(|arr| arr.first())
.and_then(|v| v.get("text"))
.and_then(|v| v.as_str())
.ok_or_else(|| PipelineError::NoStructuredContentInResponse(agent.id.clone()))?;
serde_json::from_str(text).map_err(PipelineError::ParseError)?
};
let messages: Vec<Message> = response_json
.get("result")
.and_then(|v| v.as_array())
.ok_or_else(|| PipelineError::NoMessagesInResponse(agent.id.clone()))?
.iter()
.map(|msg_value| serde_json::from_value(msg_value.clone()))
.collect::<Result<Vec<Message>, _>>()
.map_err(PipelineError::ParseError)?;
Ok(messages)
Ok(Bytes::from(
serde_json::to_vec(&result).map_err(PipelineError::ParseError)?,
))
}
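The fallback above distinguishes two tool-result shapes; roughly (hypothetical payloads, field names as read by the extraction code):

```json
{"structuredContent": {"result": {"messages": []}}}

{"content": [{"type": "text", "text": "{\"messages\": []}"}]}
```

The first is what FastMCP emits for structured Pydantic return values; the second is how a plain dict return arrives, serialized as a JSON string in `content[0].text`.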
/// Build an initialize JSON-RPC request
@@ -464,18 +403,19 @@ impl PipelineProcessor {
Ok(())
}
async fn get_new_session_id(&self, agent_id: &str, request_headers: &HeaderMap) -> String {
async fn get_new_session_id(
&self,
agent_id: &str,
request_headers: &HeaderMap,
) -> Result<String, PipelineError> {
info!("initializing MCP session for agent {}", agent_id);
let initialize_request = self.build_initialize_request();
let headers = self
.build_mcp_headers(request_headers, agent_id, None)
.expect("Failed to build headers for initialization");
let headers = self.build_mcp_headers(request_headers, agent_id, None)?;
let response = self
.send_mcp_request(&initialize_request, &headers, agent_id)
.await
.expect("Failed to initialize MCP session");
.await?;
info!("initialize response status: {}", response.status());
@@ -483,8 +423,13 @@ impl PipelineProcessor {
.headers()
.get("mcp-session-id")
.and_then(|v| v.to_str().ok())
.expect("No mcp-session-id in response")
.to_string();
.map(|s| s.to_string())
.ok_or_else(|| {
PipelineError::NoContentInResponse(format!(
"No mcp-session-id header in initialize response from agent {}",
agent_id
))
})?;
info!(
"created new MCP session for agent {}: {}",
@@ -493,88 +438,63 @@ impl PipelineProcessor {
// Send initialized notification
self.send_initialized_notification(agent_id, &session_id, &headers)
.await
.expect("Failed to send initialized notification");
.await?;
session_id
Ok(session_id)
}
/// Execute a HTTP-based filter agent
/// Execute a raw bytes filter — POST bytes to agent.url, receive bytes back.
/// Used for input and output filters where the full raw request/response is passed through.
/// No MCP protocol wrapping; agent_type is ignored.
#[instrument(
skip(self, messages, agent, request_headers),
skip(self, raw_bytes, agent, request_headers),
fields(
agent_id = %agent.id,
agent_url = %agent.url,
filter_name = %agent.id,
message_count = messages.len()
bytes_len = raw_bytes.len()
)
)]
async fn execute_http_filter(
async fn execute_raw_filter(
&mut self,
messages: &[Message],
raw_bytes: &[u8],
agent: &Agent,
request_headers: &HeaderMap,
) -> Result<Vec<Message>, PipelineError> {
// Set service name for this filter span
request_path: &str,
) -> Result<Bytes, PipelineError> {
set_service_name(operation_component::AGENT_FILTER);
// Update current span name to include filter name
use opentelemetry::trace::get_active_span;
get_active_span(|span| {
span.update_name(format!("execute_http_filter ({})", agent.id));
span.update_name(format!("execute_raw_filter ({})", agent.id));
});
// Build headers
let mut agent_headers = request_headers.clone();
agent_headers.remove(hyper::header::CONTENT_LENGTH);
// Inject OpenTelemetry trace context automatically
agent_headers.remove(TRACE_PARENT_HEADER);
global::get_text_map_propagator(|propagator| {
let cx =
tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
propagator.inject_context(&cx, &mut HeaderInjector(&mut agent_headers));
});
agent_headers.insert(
ARCH_UPSTREAM_HOST_HEADER,
hyper::header::HeaderValue::from_str(&agent.id)
.map_err(|_| PipelineError::AgentNotFound(agent.id.clone()))?,
);
agent_headers.insert(
ENVOY_RETRY_HEADER,
hyper::header::HeaderValue::from_str("3").unwrap(),
);
let mut agent_headers = Self::build_agent_headers(request_headers, &agent.id)?;
agent_headers.insert(
"Accept",
hyper::header::HeaderValue::from_static("application/json"),
);
agent_headers.insert(
"Content-Type",
hyper::header::HeaderValue::from_static("application/json"),
);
debug!(
"Sending HTTP request to agent {} at URL: {}",
agent.id, agent.url
);
// Append the original request path so the filter endpoint encodes the API format.
// e.g. agent.url="http://host/anonymize" + request_path="/v1/chat/completions"
// -> POST http://host/anonymize/v1/chat/completions
let url = format!("{}{}", agent.url, request_path);
debug!(agent = %agent.id, url = %url, "sending raw filter request");
// Send messages array directly as request body
let response = self
.client
.post(&agent.url)
.post(&url)
.headers(agent_headers)
.json(&messages)
.body(raw_bytes.to_vec())
.send()
.await?;
let http_status = response.status();
let response_bytes = response.bytes().await?;
// Handle HTTP errors
if !http_status.is_success() {
let error_body = String::from_utf8_lossy(&response_bytes).to_string();
return Err(if http_status.is_client_error() {
@@ -592,17 +512,56 @@ impl PipelineProcessor {
});
}
debug!(
"Response from HTTP agent {}: {}",
agent.id,
String::from_utf8_lossy(&response_bytes)
);
debug!(agent = %agent.id, bytes_len = response_bytes.len(), "raw filter response received");
Ok(response_bytes)
}
// Parse response - expecting array of messages directly
let messages: Vec<Message> =
serde_json::from_slice(&response_bytes).map_err(PipelineError::ParseError)?;
/// Process a chain of raw-bytes filters sequentially.
/// Input: raw request or response bytes. Output: filtered bytes.
/// Each agent receives the output of the previous one.
pub async fn process_raw_filter_chain(
&mut self,
raw_bytes: &[u8],
agent_filter_chain: &AgentFilterChain,
agent_map: &HashMap<String, Agent>,
request_headers: &HeaderMap,
request_path: &str,
) -> Result<Bytes, PipelineError> {
let filter_chain = match agent_filter_chain.input_filters.as_ref() {
Some(fc) if !fc.is_empty() => fc,
_ => return Ok(Bytes::copy_from_slice(raw_bytes)),
};
Ok(messages)
let mut current_bytes = Bytes::copy_from_slice(raw_bytes);
for agent_name in filter_chain {
debug!(agent = %agent_name, "processing raw filter agent");
let agent = agent_map
.get(agent_name)
.ok_or_else(|| PipelineError::AgentNotFound(agent_name.clone()))?;
let agent_type = agent.agent_type.as_deref().unwrap_or("mcp");
info!(
agent = %agent_name,
url = %agent.url,
agent_type = %agent_type,
bytes_len = current_bytes.len(),
"executing raw filter"
);
current_bytes = if agent_type == "mcp" {
self.execute_mcp_filter_raw(&current_bytes, agent, request_headers, request_path)
.await?
} else {
self.execute_raw_filter(&current_bytes, agent, request_headers, request_path)
.await?
};
info!(agent = %agent_name, bytes_len = current_bytes.len(), "raw filter completed");
}
Ok(current_bytes)
}
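The new `process_raw_filter_chain` is a plain left fold: each filter maps bytes to bytes, and the output of one becomes the input of the next. A minimal stdlib-only sketch of that shape — hypothetical filter functions, with none of the MCP/HTTP plumbing the real agents use:

```rust
// Sketch of the byte-folding done by process_raw_filter_chain (hypothetical
// filters; the real code dispatches to MCP or raw-HTTP agents instead).
fn run_chain(raw: &[u8], filters: &[fn(&[u8]) -> Vec<u8>]) -> Vec<u8> {
    let mut current = raw.to_vec();
    for f in filters {
        // Each filter receives the previous filter's output.
        current = f(&current);
    }
    current
}

fn main() {
    let upper = |b: &[u8]| b.to_ascii_uppercase();
    let rev = |b: &[u8]| b.iter().rev().copied().collect::<Vec<u8>>();
    // Non-capturing closures coerce to fn pointers.
    let out = run_chain(b"abc", &[upper as fn(&[u8]) -> Vec<u8>, rev]);
    assert_eq!(out, b"CBA");
    // An empty chain returns the input unchanged, matching the early return above.
    assert_eq!(run_chain(b"abc", &[]), b"abc".to_vec());
}
```

An empty (or absent) filter list short-circuits and returns the original bytes, which is what the early `return Ok(Bytes::copy_from_slice(raw_bytes))` branch encodes.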
/// Send request to terminal agent and return the raw response for streaming
@@ -615,36 +574,15 @@ impl PipelineProcessor {
terminal_agent: &Agent,
request_headers: &HeaderMap,
) -> Result<reqwest::Response, PipelineError> {
// let mut request = original_request.clone();
original_request.set_messages(messages);
let request_url = "/v1/chat/completions";
let request_body = ProviderRequestType::to_bytes(&original_request).unwrap();
// let request_body = serde_json::to_string(&request)?;
let request_body = ProviderRequestType::to_bytes(&original_request)
.map_err(|e| PipelineError::NoContentInResponse(e.to_string()))?;
debug!("sending request to terminal agent {}", terminal_agent.id);
let mut agent_headers = request_headers.clone();
agent_headers.remove(hyper::header::CONTENT_LENGTH);
// Inject OpenTelemetry trace context automatically
agent_headers.remove(TRACE_PARENT_HEADER);
global::get_text_map_propagator(|propagator| {
let cx =
tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
propagator.inject_context(&cx, &mut HeaderInjector(&mut agent_headers));
});
agent_headers.insert(
ARCH_UPSTREAM_HOST_HEADER,
hyper::header::HeaderValue::from_str(&terminal_agent.id)
.map_err(|_| PipelineError::AgentNotFound(terminal_agent.id.clone()))?,
);
agent_headers.insert(
ENVOY_RETRY_HEADER,
hyper::header::HeaderValue::from_str("3").unwrap(),
);
let agent_headers = Self::build_agent_headers(request_headers, &terminal_agent.id)?;
let response = self
.client
@@ -661,24 +599,13 @@ impl PipelineProcessor {
#[cfg(test)]
mod tests {
use super::*;
use hermesllm::apis::openai::{Message, MessageContent, Role};
use mockito::Server;
use std::collections::HashMap;
fn create_test_message(role: Role, content: &str) -> Message {
Message {
role,
content: Some(MessageContent::Text(content.to_string())),
name: None,
tool_calls: None,
tool_call_id: None,
}
}
fn create_test_pipeline(agents: Vec<&str>) -> AgentFilterChain {
AgentFilterChain {
id: "test-agent".to_string(),
filter_chain: Some(agents.iter().map(|s| s.to_string()).collect()),
input_filters: Some(agents.iter().map(|s| s.to_string()).collect()),
description: None,
default: None,
}
@@ -690,12 +617,19 @@ mod tests {
let agent_map = HashMap::new();
let request_headers = HeaderMap::new();
let messages = vec![create_test_message(Role::User, "Hello")];
let body = serde_json::json!({"messages": [{"role": "user", "content": "Hello"}]});
let raw_bytes = serde_json::to_vec(&body).unwrap();
let pipeline = create_test_pipeline(vec!["nonexistent-agent", "terminal-agent"]);
let result = processor
.process_filter_chain(&messages, &pipeline, &agent_map, &request_headers)
.process_raw_filter_chain(
&raw_bytes,
&pipeline,
&agent_map,
&request_headers,
"/v1/chat/completions",
)
.await;
assert!(result.is_err());
@@ -725,11 +659,12 @@ mod tests {
agent_type: None,
};
let messages = vec![create_test_message(Role::User, "Hello")];
let body = serde_json::json!({"messages": [{"role": "user", "content": "Hello"}]});
let raw_bytes = serde_json::to_vec(&body).unwrap();
let request_headers = HeaderMap::new();
let result = processor
.execute_mcp_filter(&messages, &agent, &request_headers)
.execute_mcp_filter_raw(&raw_bytes, &agent, &request_headers, "/v1/chat/completions")
.await;
match result {
@@ -764,11 +699,12 @@ mod tests {
agent_type: None,
};
let messages = vec![create_test_message(Role::User, "Ping")];
let body = serde_json::json!({"messages": [{"role": "user", "content": "Ping"}]});
let raw_bytes = serde_json::to_vec(&body).unwrap();
let request_headers = HeaderMap::new();
let result = processor
.execute_mcp_filter(&messages, &agent, &request_headers)
.execute_mcp_filter_raw(&raw_bytes, &agent, &request_headers, "/v1/chat/completions")
.await;
match result {
@@ -816,11 +752,12 @@ mod tests {
agent_type: None,
};
let messages = vec![create_test_message(Role::User, "Hi")];
let body = serde_json::json!({"messages": [{"role": "user", "content": "Hi"}]});
let raw_bytes = serde_json::to_vec(&body).unwrap();
let request_headers = HeaderMap::new();
let result = processor
.execute_mcp_filter(&messages, &agent, &request_headers)
.execute_mcp_filter_raw(&raw_bytes, &agent, &request_headers, "/v1/chat/completions")
.await;
match result {


@@ -7,7 +7,7 @@ use common::configuration::{
use hermesllm::apis::openai::Message;
use tracing::{debug, warn};
use crate::router::plano_orchestrator::OrchestratorService;
use crate::router::orchestrator::OrchestratorService;
/// Errors that can occur during agent selection
#[derive(Debug, thiserror::Error)]
@@ -37,7 +37,7 @@ impl AgentSelector {
}
/// Find listener by name from the request headers
pub async fn find_listener(
pub fn find_listener(
&self,
listener_name: Option<&str>,
listeners: &[common::configuration::Listener],
@@ -84,7 +84,7 @@ impl AgentSelector {
}
/// Convert agent descriptions to orchestration preferences
async fn convert_agent_description_to_orchestration_preferences(
fn convert_agent_description_to_orchestration_preferences(
&self,
agents: &[AgentFilterChain],
) -> Vec<AgentUsagePreference> {
@@ -121,9 +121,7 @@ impl AgentSelector {
return Ok(vec![agents[0].clone()]);
}
let usage_preferences = self
.convert_agent_description_to_orchestration_preferences(agents)
.await;
let usage_preferences = self.convert_agent_description_to_orchestration_preferences(agents);
debug!(
"Agents usage preferences for orchestration: {}",
serde_json::to_string(&usage_preferences).unwrap_or_default()
@@ -172,12 +170,14 @@ impl AgentSelector {
#[cfg(test)]
mod tests {
use super::*;
use common::configuration::{AgentFilterChain, Listener};
use common::configuration::{AgentFilterChain, Listener, ListenerType};
fn create_test_orchestrator_service() -> Arc<OrchestratorService> {
Arc::new(OrchestratorService::new(
"http://localhost:8080".to_string(),
"test-model".to_string(),
"plano-orchestrator".to_string(),
crate::router::orchestrator_model_v1::MAX_TOKEN_LEN,
))
}
@@ -186,14 +186,17 @@ mod tests {
id: name.to_string(),
description: Some(description.to_string()),
default: Some(is_default),
filter_chain: Some(vec![name.to_string()]),
input_filters: Some(vec![name.to_string()]),
}
}
fn create_test_listener(name: &str, agents: Vec<AgentFilterChain>) -> Listener {
Listener {
listener_type: ListenerType::Agent,
name: name.to_string(),
agents: Some(agents),
input_filters: None,
output_filters: None,
port: 8080,
router: None,
}
@@ -218,9 +221,7 @@ mod tests {
let listener2 = create_test_listener("other-listener", vec![]);
let listeners = vec![listener1.clone(), listener2];
let result = selector
.find_listener(Some("test-listener"), &listeners)
.await;
let result = selector.find_listener(Some("test-listener"), &listeners);
assert!(result.is_ok());
assert_eq!(result.unwrap().name, "test-listener");
@@ -233,9 +234,7 @@ mod tests {
let listeners = vec![create_test_listener("other-listener", vec![])];
let result = selector
.find_listener(Some("nonexistent"), &listeners)
.await;
let result = selector.find_listener(Some("nonexistent"), &listeners);
assert!(result.is_err());
matches!(


@@ -0,0 +1,53 @@
use bytes::Bytes;
use http_body_util::combinators::BoxBody;
use hyper::{Response, StatusCode};
use super::full;
#[derive(serde::Serialize)]
struct MemStats {
allocated_bytes: usize,
resident_bytes: usize,
#[serde(skip_serializing_if = "Option::is_none")]
error: Option<String>,
}
/// Returns jemalloc memory statistics as JSON.
/// Falls back to a stub when the jemalloc feature is disabled.
pub async fn memstats() -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let stats = get_jemalloc_stats();
let json = serde_json::to_string(&stats).unwrap();
Ok(Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "application/json")
.body(full(json))
.unwrap())
}
#[cfg(feature = "jemalloc")]
fn get_jemalloc_stats() -> MemStats {
use tikv_jemalloc_ctl::{epoch, stats};
if let Err(e) = epoch::advance() {
return MemStats {
allocated_bytes: 0,
resident_bytes: 0,
error: Some(format!("failed to advance jemalloc epoch: {e}")),
};
}
MemStats {
allocated_bytes: stats::allocated::read().unwrap_or(0),
resident_bytes: stats::resident::read().unwrap_or(0),
error: None,
}
}
#[cfg(not(feature = "jemalloc"))]
fn get_jemalloc_stats() -> MemStats {
MemStats {
allocated_bytes: 0,
resident_bytes: 0,
error: Some("jemalloc feature not enabled".to_string()),
}
}
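The handler above always answers 200 with a JSON body; only the `error` field distinguishes real jemalloc stats from the feature-disabled stub, since `skip_serializing_if = "Option::is_none"` drops the field entirely when it is `None`. A stdlib-only sketch of the two payload shapes (hand-rolled strings standing in for serde):

```rust
// Hypothetical stand-in for serde serialization of MemStats: the `error`
// field appears only when Some, mirroring skip_serializing_if above.
fn to_json(allocated: usize, resident: usize, error: Option<&str>) -> String {
    match error {
        Some(e) => format!(
            r#"{{"allocated_bytes":{},"resident_bytes":{},"error":"{}"}}"#,
            allocated, resident, e
        ),
        None => format!(
            r#"{{"allocated_bytes":{},"resident_bytes":{}}}"#,
            allocated, resident
        ),
    }
}

fn main() {
    // Healthy path: no error key at all.
    assert_eq!(
        to_json(1024, 4096, None),
        r#"{"allocated_bytes":1024,"resident_bytes":4096}"#
    );
    // Stub path: zeros plus an explanatory error.
    assert!(to_json(0, 0, Some("jemalloc feature not enabled")).contains("\"error\""));
}
```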


@@ -441,10 +441,8 @@ impl ArchFunctionHandler {
}
}
// Handle str/string conversions
"str" | "string" => {
if !value.is_string() {
return Ok(json!(value.to_string()));
}
"str" | "string" if !value.is_string() => {
return Ok(json!(value.to_string()));
}
_ => {}
}
@@ -762,7 +760,7 @@ impl ArchFunctionHandler {
// Keep system message if present
if let Some(first) = messages.first() {
if first.role == Role::System {
if first.role == Role::System || first.role == Role::Developer {
if let Some(MessageContent::Text(content)) = &first.content {
num_tokens += content.len() / 4; // Approximate 4 chars per token
}


@@ -3,12 +3,11 @@ use std::sync::Arc;
use hermesllm::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role};
use hyper::header::HeaderMap;
use crate::handlers::agent_selector::{AgentSelectionError, AgentSelector};
use crate::handlers::pipeline_processor::PipelineProcessor;
use crate::router::plano_orchestrator::OrchestratorService;
use common::errors::BrightStaffError;
use http_body_util::BodyExt;
use hyper::StatusCode;
use crate::handlers::agents::pipeline::PipelineProcessor;
use crate::handlers::agents::selector::{AgentSelectionError, AgentSelector};
use crate::handlers::response::ResponseHandler;
use crate::router::orchestrator::OrchestratorService;
/// Integration test that demonstrates the modular agent chat flow
/// This test shows how the three main components work together:
/// 1. AgentSelector - selects the appropriate agents based on orchestration
@@ -17,12 +16,14 @@ use hyper::StatusCode;
#[cfg(test)]
mod tests {
use super::*;
use common::configuration::{Agent, AgentFilterChain, Listener};
use common::configuration::{Agent, AgentFilterChain, Listener, ListenerType};
fn create_test_orchestrator_service() -> Arc<OrchestratorService> {
Arc::new(OrchestratorService::new(
"http://localhost:8080".to_string(),
"test-model".to_string(),
"plano-orchestrator".to_string(),
crate::router::orchestrator_model_v1::MAX_TOKEN_LEN,
))
}
@@ -63,7 +64,7 @@ mod tests {
let agent_pipeline = AgentFilterChain {
id: "terminal-agent".to_string(),
filter_chain: Some(vec![
input_filters: Some(vec![
"filter-agent".to_string(),
"terminal-agent".to_string(),
]),
@@ -72,8 +73,11 @@ };
};
let listener = Listener {
listener_type: ListenerType::Agent,
name: "test-listener".to_string(),
agents: Some(vec![agent_pipeline.clone()]),
input_filters: None,
output_filters: None,
port: 8080,
router: None,
};
@@ -82,9 +86,7 @@ mod tests {
let messages = vec![create_test_message(Role::User, "Hello world!")];
// Test 1: Agent Selection
let selected_listener = agent_selector
.find_listener(Some("test-listener"), &listeners)
.await;
let selected_listener = agent_selector.find_listener(Some("test-listener"), &listeners);
assert!(selected_listener.is_ok());
let listener = selected_listener.unwrap();
@@ -106,58 +108,51 @@ mod tests {
// Create a pipeline with empty filter chain to avoid network calls
let test_pipeline = AgentFilterChain {
id: "terminal-agent".to_string(),
filter_chain: Some(vec![]), // Empty filter chain - no network calls needed
input_filters: Some(vec![]), // Empty filter chain - no network calls needed
description: None,
default: None,
};
let headers = HeaderMap::new();
let request_bytes = serde_json::to_vec(&request).expect("failed to serialize request");
let result = pipeline_processor
.process_filter_chain(&request.messages, &test_pipeline, &agent_map, &headers)
.process_raw_filter_chain(
&request_bytes,
&test_pipeline,
&agent_map,
&headers,
"/v1/chat/completions",
)
.await;
println!("Pipeline processing result: {:?}", result);
assert!(result.is_ok());
let processed_messages = result.unwrap();
// With empty filter chain, should return the original messages unchanged
assert_eq!(processed_messages.len(), 1);
if let Some(MessageContent::Text(content)) = &processed_messages[0].content {
let processed_bytes = result.unwrap();
// With empty filter chain, should return the original bytes unchanged
let processed_request: ChatCompletionsRequest =
serde_json::from_slice(&processed_bytes).expect("failed to deserialize response");
assert_eq!(processed_request.messages.len(), 1);
if let Some(MessageContent::Text(content)) = &processed_request.messages[0].content {
assert_eq!(content, "Hello world!");
} else {
panic!("Expected text content");
}
// Test 4: Error Response Creation
let err = BrightStaffError::ModelNotFound("gpt-5-secret".to_string());
let response = err.into_response();
assert_eq!(response.status(), StatusCode::NOT_FOUND);
// Helper to extract body as JSON
let body_bytes = response.into_body().collect().await.unwrap().to_bytes();
let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
assert_eq!(body["error"]["code"], "ModelNotFound");
assert_eq!(
body["error"]["details"]["rejected_model_id"],
"gpt-5-secret"
);
assert!(body["error"]["message"]
.as_str()
.unwrap()
.contains("gpt-5-secret"));
let error_response = ResponseHandler::create_bad_request("Test error");
assert_eq!(error_response.status(), hyper::StatusCode::BAD_REQUEST);
println!("✅ All modular components working correctly!");
}
#[tokio::test]
async fn test_error_handling_flow() {
let router_service = create_test_orchestrator_service();
let agent_selector = AgentSelector::new(router_service);
let orchestrator_service = create_test_orchestrator_service();
let agent_selector = AgentSelector::new(orchestrator_service);
// Test listener not found
let result = agent_selector.find_listener(Some("nonexistent"), &[]).await;
let result = agent_selector.find_listener(Some("nonexistent"), &[]);
assert!(result.is_err());
assert!(matches!(
@@ -165,21 +160,12 @@ mod tests {
AgentSelectionError::ListenerNotFound(_)
));
let technical_reason = "Database connection timed out";
let err = BrightStaffError::InternalServerError(technical_reason.to_string());
let response = err.into_response();
// --- 1. EXTRACT BYTES ---
let body_bytes = response.into_body().collect().await.unwrap().to_bytes();
// --- 2. DECLARE body_json HERE ---
let body_json: serde_json::Value =
serde_json::from_slice(&body_bytes).expect("Failed to parse JSON body");
// --- 3. USE body_json ---
assert_eq!(body_json["error"]["code"], "InternalServerError");
assert_eq!(body_json["error"]["details"]["reason"], technical_reason);
// Test error response creation
let error_response = ResponseHandler::create_internal_error("Pipeline failed");
assert_eq!(
error_response.status(),
hyper::StatusCode::INTERNAL_SERVER_ERROR
);
println!("✅ Error handling working correctly!");
}


@@ -1,553 +0,0 @@
use bytes::Bytes;
use common::configuration::{ModelAlias, SpanAttributes};
use common::consts::{
ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
};
use common::llm_providers::LlmProviders;
use hermesllm::apis::openai_responses::InputParam;
use hermesllm::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
use hermesllm::{ProviderRequest, ProviderRequestType};
use http_body_util::combinators::BoxBody;
use http_body_util::BodyExt;
use hyper::header::{self};
use hyper::{Request, Response};
use opentelemetry::global;
use opentelemetry::trace::get_active_span;
use opentelemetry_http::HeaderInjector;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info, info_span, warn, Instrument};
use crate::handlers::router_chat::router_chat_get_upstream_model;
use crate::handlers::utils::{
create_streaming_response, truncate_message, ObservableStreamProcessor,
};
use crate::router::llm_router::RouterService;
use crate::state::response_state_processor::ResponsesStateProcessor;
use crate::state::{
extract_input_items, retrieve_and_combine_input, StateStorage, StateStorageError,
};
use crate::tracing::{
collect_custom_trace_attributes, llm as tracing_llm, operation_component, set_service_name,
};
use common::errors::BrightStaffError;
pub async fn llm_chat(
request: Request<hyper::body::Incoming>,
router_service: Arc<RouterService>,
full_qualified_llm_provider_url: String,
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
llm_providers: Arc<RwLock<LlmProviders>>,
span_attributes: Arc<Option<SpanAttributes>>,
state_storage: Option<Arc<dyn StateStorage>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let request_path = request.uri().path().to_string();
let request_headers = request.headers().clone();
let request_id: String = match request_headers
.get(REQUEST_ID_HEADER)
.and_then(|h| h.to_str().ok())
.map(|s| s.to_string())
{
Some(id) => id,
None => uuid::Uuid::new_v4().to_string(),
};
let custom_attrs =
collect_custom_trace_attributes(&request_headers, span_attributes.as_ref().as_ref());
// Create a span with request_id that will be included in all log lines
let request_span = info_span!(
"llm",
component = "llm",
request_id = %request_id,
http.method = %request.method(),
http.path = %request_path,
llm.model = tracing::field::Empty,
llm.tools = tracing::field::Empty,
llm.user_message_preview = tracing::field::Empty,
llm.temperature = tracing::field::Empty,
);
// Execute the rest of the handler inside the span
llm_chat_inner(
request,
router_service,
full_qualified_llm_provider_url,
model_aliases,
llm_providers,
custom_attrs,
state_storage,
request_id,
request_path,
request_headers,
)
.instrument(request_span)
.await
}
#[allow(clippy::too_many_arguments)]
async fn llm_chat_inner(
request: Request<hyper::body::Incoming>,
router_service: Arc<RouterService>,
full_qualified_llm_provider_url: String,
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
llm_providers: Arc<RwLock<LlmProviders>>,
custom_attrs: HashMap<String, String>,
state_storage: Option<Arc<dyn StateStorage>>,
request_id: String,
request_path: String,
mut request_headers: hyper::HeaderMap,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
// Set service name for LLM operations
set_service_name(operation_component::LLM);
get_active_span(|span| {
for (key, value) in &custom_attrs {
span.set_attribute(opentelemetry::KeyValue::new(key.clone(), value.clone()));
}
});
// Extract or generate traceparent - this establishes the trace context for all spans
let traceparent: String = match request_headers
.get(TRACE_PARENT_HEADER)
.and_then(|h| h.to_str().ok())
.map(|s| s.to_string())
{
Some(tp) => tp,
None => {
use uuid::Uuid;
let trace_id = Uuid::new_v4().to_string().replace("-", "");
let generated_tp = format!("00-{}-0000000000000000-01", trace_id);
warn!(
generated_traceparent = %generated_tp,
"TRACE_PARENT header missing, generated new traceparent"
);
generated_tp
}
};
let chat_request_bytes = request.collect().await?.to_bytes();
debug!(
body = %String::from_utf8_lossy(&chat_request_bytes),
"request body received"
);
let mut client_request = match ProviderRequestType::try_from((
&chat_request_bytes[..],
&SupportedAPIsFromClient::from_endpoint(request_path.as_str()).unwrap(),
)) {
Ok(request) => request,
Err(err) => {
warn!(
error = %err,
"failed to parse request as ProviderRequestType"
);
return Ok(BrightStaffError::InvalidRequest(format!(
"Failed to parse request: {}",
err
))
.into_response());
}
};
// === v1/responses state management: Extract input items early ===
let mut original_input_items = Vec::new();
let client_api = SupportedAPIsFromClient::from_endpoint(request_path.as_str());
let is_responses_api_client = matches!(
client_api,
Some(SupportedAPIsFromClient::OpenAIResponsesAPI(_))
);
// If model is not specified in the request, resolve from default provider
let model_from_request = client_request.model().to_string();
let model_from_request = if model_from_request.is_empty() {
match llm_providers.read().await.default() {
Some(default_provider) => {
let default_model = default_provider.name.clone();
info!(default_model = %default_model, "no model specified in request, using default provider");
client_request.set_model(default_model.clone());
default_model
}
None => {
let err_msg = "No model specified in request and no default provider configured";
warn!("{}", err_msg);
return Ok(BrightStaffError::NoModelSpecified.into_response());
}
}
} else {
model_from_request
};
// Model alias resolution: update model field in client_request immediately
// This ensures all downstream objects use the resolved model
let temperature = client_request.get_temperature();
let is_streaming_request = client_request.is_streaming();
let alias_resolved_model = resolve_model_alias(&model_from_request, &model_aliases);
// Validate that the requested model exists in configuration
// This matches the validation in llm_gateway routing.rs
if llm_providers
.read()
.await
.get(&alias_resolved_model)
.is_none()
{
warn!(model = %alias_resolved_model, "model not found in configured providers");
return Ok(BrightStaffError::ModelNotFound(alias_resolved_model).into_response());
}
// Handle provider/model slug format (e.g., "openai/gpt-4")
// Extract just the model name for upstream (providers don't understand the slug)
let model_name_only = if let Some((_, model)) = alias_resolved_model.split_once('/') {
model.to_string()
} else {
alias_resolved_model.clone()
};
// Extract tool names and user message preview for span attributes
let tool_names = client_request.get_tool_names();
let user_message_preview = client_request
.get_recent_user_message()
.map(|msg| truncate_message(&msg, 50));
let span = tracing::Span::current();
if let Some(temp) = temperature {
span.record(tracing_llm::TEMPERATURE, tracing::field::display(temp));
}
if let Some(tools) = &tool_names {
let formatted_tools = tools
.iter()
.map(|name| format!("{}(...)", name))
.collect::<Vec<_>>()
.join("\n");
span.record(tracing_llm::TOOLS, formatted_tools.as_str());
}
if let Some(preview) = &user_message_preview {
span.record(tracing_llm::USER_MESSAGE_PREVIEW, preview.as_str());
}
// Extract messages for signal analysis (clone before moving client_request)
let messages_for_signals = Some(client_request.get_messages());
// Set the model to just the model name (without provider prefix)
// This ensures upstream receives "gpt-4" not "openai/gpt-4"
client_request.set_model(model_name_only.clone());
if client_request.remove_metadata_key("plano_preference_config") {
debug!("removed plano_preference_config from metadata");
}
// === v1/responses state management: Determine upstream API and combine input if needed ===
// Do this BEFORE routing since routing consumes the request
// Only process state if state_storage is configured
let mut should_manage_state = false;
if is_responses_api_client {
if let (
ProviderRequestType::ResponsesAPIRequest(ref mut responses_req),
Some(ref state_store),
) = (&mut client_request, &state_storage)
{
// Extract original input once
original_input_items = extract_input_items(&responses_req.input);
// Get the upstream path and check if it's ResponsesAPI
let upstream_path = get_upstream_path(
&llm_providers,
&alias_resolved_model,
&request_path,
&alias_resolved_model,
is_streaming_request,
)
.await;
let upstream_api = SupportedUpstreamAPIs::from_endpoint(&upstream_path);
// Only manage state if upstream is NOT OpenAIResponsesAPI (needs translation)
should_manage_state = !matches!(
upstream_api,
Some(SupportedUpstreamAPIs::OpenAIResponsesAPI(_))
);
if should_manage_state {
// Retrieve and combine conversation history if previous_response_id exists
if let Some(ref prev_resp_id) = responses_req.previous_response_id {
match retrieve_and_combine_input(
state_store.clone(),
prev_resp_id,
original_input_items, // Pass ownership instead of cloning
)
.await
{
Ok(combined_input) => {
// Update both the request and original_input_items
responses_req.input = InputParam::Items(combined_input.clone());
original_input_items = combined_input;
info!(
items = original_input_items.len(),
"updated request with conversation history"
);
}
Err(StateStorageError::NotFound(_)) => {
// Return 409 Conflict when previous_response_id not found
warn!(previous_response_id = %prev_resp_id, "previous response_id not found");
return Ok(BrightStaffError::ConversationStateNotFound(
prev_resp_id.to_string(),
)
.into_response());
}
Err(e) => {
// Log warning but continue on other storage errors
warn!(
previous_response_id = %prev_resp_id,
error = %e,
"failed to retrieve conversation state"
);
// Restore original_input_items since we passed ownership
original_input_items = extract_input_items(&responses_req.input);
}
}
}
} else {
debug!("upstream supports ResponsesAPI natively");
}
}
}
// Serialize request for upstream BEFORE router consumes it
let client_request_bytes_for_upstream = ProviderRequestType::to_bytes(&client_request).unwrap();
// Determine routing using the dedicated router_chat module
// This gets its own span for latency and error tracking
let routing_span = info_span!(
"routing",
component = "routing",
http.method = "POST",
http.target = %request_path,
model.requested = %model_from_request,
model.alias_resolved = %alias_resolved_model,
route.selected_model = tracing::field::Empty,
routing.determination_ms = tracing::field::Empty,
);
let routing_result = match async {
set_service_name(operation_component::ROUTING);
router_chat_get_upstream_model(
router_service,
client_request, // Pass the original request - router_chat will convert it
&traceparent,
&request_path,
&request_id,
)
.await
}
.instrument(routing_span)
.await
{
Ok(result) => result,
Err(err) => {
return Ok(BrightStaffError::ForwardedError {
status_code: err.status_code,
message: err.message,
}
.into_response());
}
};
// Determine final model to use
// Router returns "none" as a sentinel value when it doesn't select a specific model
let router_selected_model = routing_result.model_name;
let resolved_model = if router_selected_model != "none" {
// Router selected a specific model via routing preferences
router_selected_model
} else {
// Router returned "none" sentinel, use validated resolved_model from request
alias_resolved_model.clone()
};
tracing::Span::current().record(tracing_llm::MODEL_NAME, resolved_model.as_str());
let span_name = if model_from_request == resolved_model {
format!("POST {} {}", request_path, resolved_model)
} else {
format!(
"POST {} {} -> {}",
request_path, model_from_request, resolved_model
)
};
get_active_span(|span| {
span.update_name(span_name.clone());
});
debug!(
url = %full_qualified_llm_provider_url,
provider_hint = %resolved_model,
upstream_model = %model_name_only,
"Routing to upstream"
);
request_headers.insert(
ARCH_PROVIDER_HINT_HEADER,
header::HeaderValue::from_str(&resolved_model).unwrap(),
);
request_headers.insert(
header::HeaderName::from_static(ARCH_IS_STREAMING_HEADER),
header::HeaderValue::from_str(&is_streaming_request.to_string()).unwrap(),
);
// remove content-length header if it exists
request_headers.remove(header::CONTENT_LENGTH);
// Inject current LLM span's trace context so upstream spans are children of plano(llm)
global::get_text_map_propagator(|propagator| {
let cx = tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
propagator.inject_context(&cx, &mut HeaderInjector(&mut request_headers));
});
// Capture start time right before sending request to upstream
let request_start_time = std::time::Instant::now();
let _request_start_system_time = std::time::SystemTime::now();
let llm_response = match reqwest::Client::new()
.post(&full_qualified_llm_provider_url)
.headers(request_headers)
.body(client_request_bytes_for_upstream)
.send()
.await
{
Ok(res) => res,
Err(err) => {
return Ok(BrightStaffError::InternalServerError(format!(
"Failed to send request: {}",
err
))
.into_response());
}
};
// copy over the headers and status code from the original response
let response_headers = llm_response.headers().clone();
let upstream_status = llm_response.status();
let mut response = Response::builder().status(upstream_status);
let headers = response.headers_mut().unwrap();
for (header_name, header_value) in response_headers.iter() {
headers.insert(header_name, header_value.clone());
}
// Build LLM span with actual status code using constants
let byte_stream = llm_response.bytes_stream();
// Create base processor for metrics and tracing
let base_processor = ObservableStreamProcessor::new(
operation_component::LLM,
span_name,
request_start_time,
messages_for_signals,
);
// === v1/responses state management: Wrap with ResponsesStateProcessor ===
// Only wrap if we need to manage state (client is ResponsesAPI AND upstream is NOT ResponsesAPI AND state_storage is configured)
let streaming_response = if let (true, false, Some(state_store)) = (
should_manage_state,
original_input_items.is_empty(),
state_storage,
) {
// Extract Content-Encoding header to handle decompression for state parsing
let content_encoding = response_headers
.get("content-encoding")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
// Wrap with state management processor to store state after response completes
let state_processor = ResponsesStateProcessor::new(
base_processor,
state_store,
original_input_items,
alias_resolved_model.clone(),
resolved_model.clone(),
is_streaming_request,
false, // Not OpenAI upstream since should_manage_state is true
content_encoding,
request_id,
);
create_streaming_response(byte_stream, state_processor, 16)
} else {
// Use base processor without state management
create_streaming_response(byte_stream, base_processor, 16)
};
match response.body(streaming_response.body) {
Ok(response) => Ok(response),
Err(err) => Ok(BrightStaffError::InternalServerError(format!(
"Failed to create response: {}",
err
))
.into_response()),
}
}
/// Resolves model aliases by looking up the requested model in the model_aliases map.
/// Returns the target model if an alias is found, otherwise returns the original model.
fn resolve_model_alias(
model_from_request: &str,
model_aliases: &Arc<Option<HashMap<String, ModelAlias>>>,
) -> String {
if let Some(aliases) = model_aliases.as_ref() {
if let Some(model_alias) = aliases.get(model_from_request) {
debug!(
"Model Alias: 'From {}' -> 'To {}'",
model_from_request, model_alias.target
);
return model_alias.target.clone();
}
}
model_from_request.to_string()
}
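The alias resolution above is a map probe with pass-through on miss. A stdlib-only sketch, using a plain `HashMap<String, String>` as a hypothetical stand-in for the real `HashMap<String, ModelAlias>`:

```rust
use std::collections::HashMap;

// Hypothetical simplification: alias -> target model, standing in for
// the ModelAlias map used by resolve_model_alias above.
fn resolve_alias(aliases: &HashMap<String, String>, requested: &str) -> String {
    aliases
        .get(requested)
        .cloned()
        // Unknown names fall through unchanged, mirroring the real function.
        .unwrap_or_else(|| requested.to_string())
}

fn main() {
    let mut aliases = HashMap::new();
    aliases.insert("fast".to_string(), "openai/gpt-4o-mini".to_string());
    assert_eq!(resolve_alias(&aliases, "fast"), "openai/gpt-4o-mini");
    assert_eq!(resolve_alias(&aliases, "gpt-4"), "gpt-4");
}
```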
/// Calculates the upstream path for the provider based on the model name.
/// Looks up provider configuration, gets the ProviderId and base_url_path_prefix,
/// then uses target_endpoint_for_provider to calculate the correct upstream path.
async fn get_upstream_path(
llm_providers: &Arc<RwLock<LlmProviders>>,
model_name: &str,
request_path: &str,
resolved_model: &str,
is_streaming: bool,
) -> String {
let (provider_id, base_url_path_prefix) = get_provider_info(llm_providers, model_name).await;
// Calculate the upstream path using the proper API
let client_api = SupportedAPIsFromClient::from_endpoint(request_path)
.expect("Should have valid API endpoint");
client_api.target_endpoint_for_provider(
&provider_id,
request_path,
resolved_model,
is_streaming,
base_url_path_prefix.as_deref(),
)
}
/// Helper function to get provider info (ProviderId and base_url_path_prefix)
async fn get_provider_info(
llm_providers: &Arc<RwLock<LlmProviders>>,
model_name: &str,
) -> (hermesllm::ProviderId, Option<String>) {
let providers_lock = llm_providers.read().await;
// Try to find by model name or provider name using LlmProviders::get
// This handles both "gpt-4" and "openai/gpt-4" formats
if let Some(provider) = providers_lock.get(model_name) {
let provider_id = provider.provider_interface.to_provider_id();
let prefix = provider.base_url_path_prefix.clone();
return (provider_id, prefix);
}
// Fall back to default provider
if let Some(provider) = providers_lock.default() {
let provider_id = provider.provider_interface.to_provider_id();
let prefix = provider.base_url_path_prefix.clone();
(provider_id, prefix)
} else {
// Last resort: use OpenAI as hardcoded fallback
warn!("No default provider found, falling back to OpenAI");
(hermesllm::ProviderId::OpenAI, None)
}
}
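The lookup order in `get_provider_info` (named provider, then configured default, then a hardcoded OpenAI last resort) can be sketched with stdlib `Option` chaining; the map and names here are hypothetical stand-ins, not the real `LlmProviders` API:

```rust
use std::collections::HashMap;

// Hypothetical stand-ins: model -> provider name, plus an optional default.
fn provider_for(
    providers: &HashMap<String, String>,
    default: Option<&str>,
    model: &str,
) -> String {
    providers
        .get(model)
        .cloned()
        // Fall back to the configured default provider, if any.
        .or_else(|| default.map(|d| d.to_string()))
        // Last resort mirrors the hardcoded OpenAI fallback above.
        .unwrap_or_else(|| "openai".to_string())
}

fn main() {
    let mut providers = HashMap::new();
    providers.insert("claude-3".to_string(), "anthropic".to_string());
    assert_eq!(provider_for(&providers, Some("mistral"), "claude-3"), "anthropic");
    assert_eq!(provider_for(&providers, Some("mistral"), "gpt-4"), "mistral");
    assert_eq!(provider_for(&providers, None, "gpt-4"), "openai");
}
```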



@@ -1,15 +1,34 @@
use common::configuration::ModelUsagePreference;
use common::configuration::TopLevelRoutingPreference;
use hermesllm::clients::endpoints::SupportedUpstreamAPIs;
use hermesllm::{ProviderRequest, ProviderRequestType};
use hermesllm::ProviderRequestType;
use hyper::StatusCode;
use std::sync::Arc;
use tracing::{debug, info, warn};
use crate::router::llm_router::RouterService;
use crate::metrics as bs_metrics;
use crate::metrics::labels as metric_labels;
use crate::router::orchestrator::OrchestratorService;
use crate::streaming::truncate_message;
use crate::tracing::routing;
/// Classify a request path (already stripped of `/agents` or `/routing` by
/// the caller) into the fixed `route` label used on routing metrics.
fn route_label_for_path(request_path: &str) -> &'static str {
if request_path.starts_with("/agents") {
metric_labels::ROUTE_AGENT
} else if request_path.starts_with("/routing") {
metric_labels::ROUTE_ROUTING
} else {
metric_labels::ROUTE_LLM
}
}
pub struct RoutingResult {
/// Primary model to use (first in the ranked list).
pub model_name: String,
/// Full ranked list — use subsequent entries as fallbacks on 429/5xx.
pub models: Vec<String>,
pub route_name: Option<String>,
}
pub struct RoutingError {
@@ -32,15 +51,12 @@ impl RoutingError {
/// * `Ok(RoutingResult)` - Contains the selected model name and span ID
/// * `Err(RoutingError)` - Contains error details and optional span ID
pub async fn router_chat_get_upstream_model(
router_service: Arc<RouterService>,
orchestrator_service: Arc<OrchestratorService>,
client_request: ProviderRequestType,
traceparent: &str,
request_path: &str,
request_id: &str,
inline_routing_preferences: Option<Vec<TopLevelRoutingPreference>>,
) -> Result<RoutingResult, RoutingError> {
// Clone metadata for routing before converting (which consumes client_request)
let routing_metadata = client_request.metadata().clone();
// Convert to ChatCompletionsRequest for routing (regardless of input type)
let chat_request = match ProviderRequestType::try_from((
client_request,
@@ -75,17 +91,6 @@ pub async fn router_chat_get_upstream_model(
"router request"
);
// Extract usage preferences from metadata
let usage_preferences_str: Option<String> = routing_metadata.as_ref().and_then(|metadata| {
metadata
.get("plano_preference_config")
.map(|value| value.to_string())
});
let usage_preferences: Option<Vec<ModelUsagePreference>> = usage_preferences_str
.as_ref()
.and_then(|s| serde_yaml::from_str(s).ok());
// Prepare log message with latest message from chat request
let latest_message_for_log = chat_request
.messages
@@ -96,19 +101,9 @@ pub async fn router_chat_get_upstream_model(
.map_or("None".to_string(), |c| c.to_string().replace('\n', "\\n"))
});
const MAX_MESSAGE_LENGTH: usize = 50;
let latest_message_for_log = if latest_message_for_log.chars().count() > MAX_MESSAGE_LENGTH {
let truncated: String = latest_message_for_log
.chars()
.take(MAX_MESSAGE_LENGTH)
.collect();
format!("{}...", truncated)
} else {
latest_message_for_log
};
let latest_message_for_log = truncate_message(&latest_message_for_log, 50);
info!(
has_usage_preferences = usage_preferences.is_some(),
path = %request_path,
latest_message = %latest_message_for_log,
"processing router request"
@@ -117,39 +112,59 @@ pub async fn router_chat_get_upstream_model(
// Capture start time for routing span
let routing_start_time = std::time::Instant::now();
// Attempt to determine route using the router service
let routing_result = router_service
let routing_result = orchestrator_service
.determine_route(
&chat_request.messages,
traceparent,
usage_preferences,
inline_routing_preferences,
request_id,
)
.await;
let determination_ms = routing_start_time.elapsed().as_millis() as i64;
let determination_elapsed = routing_start_time.elapsed();
let determination_ms = determination_elapsed.as_millis() as i64;
let current_span = tracing::Span::current();
current_span.record(routing::ROUTE_DETERMINATION_MS, determination_ms);
let route_label = route_label_for_path(request_path);
match routing_result {
Ok(route) => match route {
Some((_, model_name)) => {
Some((route_name, ranked_models)) => {
let model_name = ranked_models.first().cloned().unwrap_or_default();
current_span.record("route.selected_model", model_name.as_str());
Ok(RoutingResult { model_name })
bs_metrics::record_router_decision(
route_label,
&model_name,
false,
determination_elapsed,
);
Ok(RoutingResult {
model_name,
models: ranked_models,
route_name: Some(route_name),
})
}
None => {
// No route determined, return sentinel value "none"
// This signals to llm.rs to use the original validated request model
current_span.record("route.selected_model", "none");
info!("no route determined, using default model");
bs_metrics::record_router_decision(
route_label,
"none",
true,
determination_elapsed,
);
Ok(RoutingResult {
model_name: "none".to_string(),
models: vec!["none".to_string()],
route_name: None,
})
}
},
Err(err) => {
current_span.record("route.selected_model", "unknown");
bs_metrics::record_router_decision(route_label, "unknown", true, determination_elapsed);
Err(RoutingError::internal_error(format!(
"Failed to determine route: {}",
err


@@ -1,13 +1,58 @@
pub mod agent_chat_completions;
pub mod agent_selector;
pub mod agents;
pub mod debug;
pub mod function_calling;
pub mod jsonrpc;
pub mod llm;
pub mod models;
pub mod pipeline_processor;
pub mod response_handler;
pub mod router_chat;
pub mod utils;
pub mod response;
pub mod routing_service;
#[cfg(test)]
mod integration_tests;
use bytes::Bytes;
use common::consts::TRACE_PARENT_HEADER;
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Empty, Full};
use hyper::Request;
use tracing::warn;
/// Wrap a chunk into a `BoxBody` for hyper responses.
pub fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
Full::new(chunk.into())
.map_err(|never| match never {})
.boxed()
}
/// An empty HTTP body (used for 404 / OPTIONS responses).
pub fn empty() -> BoxBody<Bytes, hyper::Error> {
Empty::<Bytes>::new()
.map_err(|never| match never {})
.boxed()
}
/// Extract request ID from incoming request headers, or generate a new UUID v4.
pub fn extract_request_id<T>(request: &Request<T>) -> String {
request
.headers()
.get(common::consts::REQUEST_ID_HEADER)
.and_then(|h| h.to_str().ok())
.map(|s| s.to_string())
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string())
}
/// Extract or generate a W3C `traceparent` header value.
pub fn extract_or_generate_traceparent(headers: &hyper::HeaderMap) -> String {
headers
.get(TRACE_PARENT_HEADER)
.and_then(|h| h.to_str().ok())
.map(|s| s.to_string())
.unwrap_or_else(|| {
let trace_id = uuid::Uuid::new_v4().to_string().replace("-", "");
let tp = format!("00-{}-0000000000000000-01", trace_id);
warn!(
generated_traceparent = %tp,
"TRACE_PARENT header missing, generated new traceparent"
);
tp
})
}
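The generated fallback above follows the W3C `traceparent` shape `00-{trace_id}-{span_id}-{flags}` with an all-zero span id; a stdlib-only sketch of building one and extracting the trace id back out (the trace id literal is hypothetical, since the real code derives it from a UUID v4 with dashes stripped):

```rust
// Build a traceparent in the same shape as the generated fallback above:
// version 00, caller-supplied trace id, all-zero span id, flags 01.
fn make_traceparent(trace_id: &str) -> String {
    format!("00-{}-0000000000000000-01", trace_id)
}

// Same extraction used elsewhere in the handlers: the trace id is the
// second dash-separated field of the traceparent.
fn trace_id_of(traceparent: &str) -> &str {
    traceparent.split('-').nth(1).unwrap_or("unknown")
}

fn main() {
    let tp = make_traceparent("4bf92f3577b34da6a3ce929d0e0e4736");
    assert_eq!(tp, "00-4bf92f3577b34da6a3ce929d0e0e4736-0000000000000000-01");
    assert_eq!(trace_id_of(&tp), "4bf92f3577b34da6a3ce929d0e0e4736");
}
```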


@@ -1,10 +1,11 @@
use bytes::Bytes;
use common::llm_providers::LlmProviders;
use http_body_util::{combinators::BoxBody, BodyExt, Full};
use http_body_util::combinators::BoxBody;
use hyper::{Response, StatusCode};
use serde_json;
use std::sync::Arc;
use super::full;
pub async fn list_models(
llm_providers: Arc<tokio::sync::RwLock<LlmProviders>>,
) -> Response<BoxBody<Bytes, hyper::Error>> {
@@ -12,27 +13,15 @@ pub async fn list_models(
let models = prov.to_models();
match serde_json::to_string(&models) {
Ok(json) => {
let body = Full::new(Bytes::from(json))
.map_err(|never| match never {})
.boxed();
Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "application/json")
.body(body)
.unwrap()
}
Err(_) => {
let body = Full::new(Bytes::from_static(
b"{\"error\":\"Failed to serialize models\"}",
))
.map_err(|never| match never {})
.boxed();
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.header("Content-Type", "application/json")
.body(body)
.unwrap()
}
Ok(json) => Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "application/json")
.body(full(json))
.unwrap(),
Err(_) => Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.header("Content-Type", "application/json")
.body(full("{\"error\":\"Failed to serialize models\"}"))
.unwrap(),
}
}


@@ -4,7 +4,7 @@ use hermesllm::apis::OpenAIApi;
use hermesllm::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
use hermesllm::SseEvent;
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full, StreamBody};
use http_body_util::StreamBody;
use hyper::body::Frame;
use hyper::Response;
use tokio::sync::mpsc;
@@ -12,6 +12,8 @@ use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::StreamExt;
use tracing::{info, warn, Instrument};
use super::full;
/// Service for handling HTTP responses and streaming
pub struct ResponseHandler;
@@ -22,9 +24,40 @@ impl ResponseHandler {
/// Create a full response body from bytes
pub fn create_full_body<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
Full::new(chunk.into())
.map_err(|never| match never {})
.boxed()
full(chunk)
}
/// Create a JSON error response with BAD_REQUEST status
pub fn create_json_error_response(
json: &serde_json::Value,
) -> Response<BoxBody<Bytes, hyper::Error>> {
let body = Self::create_full_body(json.to_string());
let mut response = Response::new(body);
*response.status_mut() = hyper::StatusCode::BAD_REQUEST;
response.headers_mut().insert(
hyper::header::CONTENT_TYPE,
hyper::header::HeaderValue::from_static("application/json"),
);
response
}
/// Create a BAD_REQUEST error response with a message
pub fn create_bad_request(message: &str) -> Response<BoxBody<Bytes, hyper::Error>> {
let json = serde_json::json!({"error": message});
Self::create_json_error_response(&json)
}
/// Create an INTERNAL_SERVER_ERROR response with a message
pub fn create_internal_error(message: &str) -> Response<BoxBody<Bytes, hyper::Error>> {
let json = serde_json::json!({"error": message});
let body = Self::create_full_body(json.to_string());
let mut response = Response::new(body);
*response.status_mut() = hyper::StatusCode::INTERNAL_SERVER_ERROR;
response.headers_mut().insert(
hyper::header::CONTENT_TYPE,
hyper::header::HeaderValue::from_static("application/json"),
);
response
}
/// Create a streaming response from a reqwest response.
@@ -112,7 +145,9 @@ impl ResponseHandler {
let upstream_api =
SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
let sse_iter = SseStreamIter::try_from(response_bytes.as_ref()).unwrap();
let sse_iter = SseStreamIter::try_from(response_bytes.as_ref()).map_err(|e| {
BrightStaffError::StreamError(format!("Failed to parse SSE stream: {}", e))
})?;
let mut accumulated_text = String::new();
for sse_event in sse_iter {
@@ -122,7 +157,13 @@ impl ResponseHandler {
}
let transformed_event =
SseEvent::try_from((sse_event, &client_api, &upstream_api)).unwrap();
match SseEvent::try_from((sse_event, &client_api, &upstream_api)) {
Ok(event) => event,
Err(e) => {
warn!(error = ?e, "failed to transform SSE event, skipping");
continue;
}
};
// Try to get provider response and extract content delta
match transformed_event.provider_response() {


@@ -0,0 +1,436 @@
use bytes::Bytes;
use common::configuration::{SpanAttributes, TopLevelRoutingPreference};
use common::consts::{MODEL_AFFINITY_HEADER, REQUEST_ID_HEADER};
use common::errors::BrightStaffError;
use hermesllm::clients::SupportedAPIsFromClient;
use hermesllm::ProviderRequestType;
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full};
use hyper::{Request, Response, StatusCode};
use std::sync::Arc;
use tracing::{debug, info, info_span, warn, Instrument};
use super::extract_or_generate_traceparent;
use crate::handlers::llm::model_selection::router_chat_get_upstream_model;
use crate::metrics as bs_metrics;
use crate::metrics::labels as metric_labels;
use crate::router::orchestrator::OrchestratorService;
use crate::tracing::{collect_custom_trace_attributes, operation_component, set_service_name};
/// Extracts `routing_preferences` from a JSON body, returning the cleaned body bytes
/// and the parsed preferences. The field is removed from the JSON before re-serializing
/// so downstream parsers don't see it.
pub fn extract_routing_policy(
raw_bytes: &[u8],
) -> Result<(Bytes, Option<Vec<TopLevelRoutingPreference>>), String> {
let mut json_body: serde_json::Value = serde_json::from_slice(raw_bytes)
.map_err(|err| format!("Failed to parse JSON: {}", err))?;
let routing_preferences = json_body
.as_object_mut()
.and_then(|o| o.remove("routing_preferences"))
.and_then(
|value| match serde_json::from_value::<Vec<TopLevelRoutingPreference>>(value) {
Ok(prefs) => {
info!(
num_routes = prefs.len(),
"using inline routing_preferences from request body"
);
Some(prefs)
}
Err(err) => {
warn!(error = %err, "failed to parse routing_preferences");
None
}
},
);
let bytes = Bytes::from(serde_json::to_vec(&json_body).unwrap());
Ok((bytes, routing_preferences))
}
#[derive(serde::Serialize)]
struct RoutingDecisionResponse {
/// Ranked model list — use first, fall back to next on 429/5xx.
models: Vec<String>,
route: Option<String>,
trace_id: String,
#[serde(skip_serializing_if = "Option::is_none")]
session_id: Option<String>,
pinned: bool,
}
pub async fn routing_decision(
request: Request<hyper::body::Incoming>,
orchestrator_service: Arc<OrchestratorService>,
request_path: String,
span_attributes: &Option<SpanAttributes>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let request_headers = request.headers().clone();
let request_id: String = request_headers
.get(REQUEST_ID_HEADER)
.and_then(|h| h.to_str().ok())
.map(|s| s.to_string())
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
let session_id: Option<String> = request_headers
.get(MODEL_AFFINITY_HEADER)
.and_then(|h| h.to_str().ok())
.map(|s| s.to_string());
let tenant_id: Option<String> = orchestrator_service
.tenant_header()
.and_then(|hdr| request_headers.get(hdr))
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let custom_attrs = collect_custom_trace_attributes(&request_headers, span_attributes.as_ref());
let request_span = info_span!(
"routing_decision",
component = "routing",
request_id = %request_id,
http.method = %request.method(),
http.path = %request_path,
);
routing_decision_inner(
request,
orchestrator_service,
request_id,
request_path,
request_headers,
custom_attrs,
session_id,
tenant_id,
)
.instrument(request_span)
.await
}
#[allow(clippy::too_many_arguments)]
async fn routing_decision_inner(
request: Request<hyper::body::Incoming>,
orchestrator_service: Arc<OrchestratorService>,
request_id: String,
request_path: String,
request_headers: hyper::HeaderMap,
custom_attrs: std::collections::HashMap<String, String>,
session_id: Option<String>,
tenant_id: Option<String>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
set_service_name(operation_component::ROUTING);
opentelemetry::trace::get_active_span(|span| {
for (key, value) in &custom_attrs {
span.set_attribute(opentelemetry::KeyValue::new(key.clone(), value.clone()));
}
});
let traceparent = extract_or_generate_traceparent(&request_headers);
// Extract trace_id from traceparent (format: 00-{trace_id}-{span_id}-{flags})
let trace_id = traceparent
.split('-')
.nth(1)
.unwrap_or("unknown")
.to_string();
if let Some(ref sid) = session_id {
if let Some(cached) = orchestrator_service
.get_cached_route(sid, tenant_id.as_deref())
.await
{
info!(
session_id = %sid,
model = %cached.model_name,
route = ?cached.route_name,
"returning pinned routing decision from cache"
);
let response = RoutingDecisionResponse {
models: vec![cached.model_name],
route: cached.route_name,
trace_id,
session_id: Some(sid.clone()),
pinned: true,
};
let json = serde_json::to_string(&response).unwrap();
let body = Full::new(Bytes::from(json))
.map_err(|never| match never {})
.boxed();
return Ok(Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "application/json")
.body(body)
.unwrap());
}
}
// Parse request body
let raw_bytes = request.collect().await?.to_bytes();
debug!(
body = %String::from_utf8_lossy(&raw_bytes),
"routing decision request body received"
);
// Extract routing_preferences from body before parsing as ProviderRequestType
let (chat_request_bytes, inline_routing_preferences) = match extract_routing_policy(&raw_bytes)
{
Ok(result) => result,
Err(err) => {
warn!(error = %err, "failed to parse request JSON");
return Ok(BrightStaffError::InvalidRequest(format!(
"Failed to parse request JSON: {}",
err
))
.into_response());
}
};
let client_request = match ProviderRequestType::try_from((
&chat_request_bytes[..],
&SupportedAPIsFromClient::from_endpoint(request_path.as_str()).unwrap(),
)) {
Ok(request) => request,
Err(err) => {
warn!(error = %err, "failed to parse request for routing decision");
return Ok(BrightStaffError::InvalidRequest(format!(
"Failed to parse request: {}",
err
))
.into_response());
}
};
let routing_result = router_chat_get_upstream_model(
Arc::clone(&orchestrator_service),
client_request,
&request_path,
&request_id,
inline_routing_preferences,
)
.await;
match routing_result {
Ok(result) => {
if let Some(ref sid) = session_id {
orchestrator_service
.cache_route(
sid.clone(),
tenant_id.as_deref(),
result.model_name.clone(),
result.route_name.clone(),
)
.await;
}
let response = RoutingDecisionResponse {
models: result.models,
route: result.route_name,
trace_id,
session_id,
pinned: false,
};
// Distinguish "decision served" (a concrete model picked) from
// "no_candidates" (the sentinel "none" returned when nothing
// matched). The handler still responds 200 in both cases, so RED
// metrics alone can't tell them apart.
let outcome = if response.models.first().map(|m| m == "none").unwrap_or(true) {
metric_labels::ROUTING_SVC_NO_CANDIDATES
} else {
metric_labels::ROUTING_SVC_DECISION_SERVED
};
bs_metrics::record_routing_service_outcome(outcome);
info!(
primary_model = %response.models.first().map(|s| s.as_str()).unwrap_or("none"),
total_models = response.models.len(),
route = ?response.route,
"routing decision completed"
);
let json = serde_json::to_string(&response).unwrap();
let body = Full::new(Bytes::from(json))
.map_err(|never| match never {})
.boxed();
Ok(Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "application/json")
.body(body)
.unwrap())
}
Err(err) => {
bs_metrics::record_routing_service_outcome(metric_labels::ROUTING_SVC_POLICY_ERROR);
warn!(error = %err.message, "routing decision failed");
Ok(BrightStaffError::InternalServerError(err.message).into_response())
}
}
}
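The session-pinning flow in `routing_decision_inner` (serve a cached decision when the affinity header matches, store new decisions keyed by session and tenant) can be sketched with an in-memory map; the `RouteCache` type here is a hypothetical stand-in for the orchestrator's real cache, not its API:

```rust
use std::collections::HashMap;

// Hypothetical in-memory stand-in for the orchestrator's route cache:
// a (tenant, session) key pinned to a (model, route) decision.
struct RouteCache {
    entries: HashMap<(Option<String>, String), (String, Option<String>)>,
}

impl RouteCache {
    fn new() -> Self {
        Self { entries: HashMap::new() }
    }

    fn cache_route(&mut self, session: &str, tenant: Option<&str>,
                   model: String, route: Option<String>) {
        let key = (tenant.map(str::to_string), session.to_string());
        self.entries.insert(key, (model, route));
    }

    fn get_cached_route(&self, session: &str, tenant: Option<&str>)
        -> Option<&(String, Option<String>)> {
        self.entries.get(&(tenant.map(str::to_string), session.to_string()))
    }
}

fn main() {
    let mut cache = RouteCache::new();
    cache.cache_route("sess-abc", None, "openai/gpt-4o".to_string(),
                      Some("coding".to_string()));
    // A repeat request for the same session gets the pinned decision;
    // other sessions (or tenants) miss and go through routing again.
    assert!(cache.get_cached_route("sess-abc", None).is_some());
    assert!(cache.get_cached_route("sess-xyz", None).is_none());
    assert!(cache.get_cached_route("sess-abc", Some("t1")).is_none());
}
```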
#[cfg(test)]
mod tests {
use super::*;
use common::configuration::SelectionPreference;
fn make_chat_body(extra_fields: &str) -> Vec<u8> {
let extra = if extra_fields.is_empty() {
String::new()
} else {
format!(", {}", extra_fields)
};
format!(
r#"{{"model": "gpt-4o-mini", "messages": [{{"role": "user", "content": "hello"}}]{}}}"#,
extra
)
.into_bytes()
}
#[test]
fn extract_routing_policy_no_policy() {
let body = make_chat_body("");
let (cleaned, prefs) = extract_routing_policy(&body).unwrap();
assert!(prefs.is_none());
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
assert_eq!(cleaned_json["model"], "gpt-4o-mini");
}
#[test]
fn extract_routing_policy_invalid_json_returns_error() {
let body = b"not valid json";
let result = extract_routing_policy(body);
assert!(result.is_err());
assert!(result.unwrap_err().contains("Failed to parse JSON"));
}
#[test]
fn extract_routing_policy_routing_preferences() {
let policy = r#""routing_preferences": [
{
"name": "code generation",
"description": "generate new code",
"models": ["openai/gpt-4o", "openai/gpt-4o-mini"],
"selection_policy": {"prefer": "fastest"}
}
]"#;
let body = make_chat_body(policy);
let (cleaned, prefs) = extract_routing_policy(&body).unwrap();
let prefs = prefs.expect("should have parsed routing_preferences");
assert_eq!(prefs.len(), 1);
assert_eq!(prefs[0].name, "code generation");
assert_eq!(prefs[0].models, vec!["openai/gpt-4o", "openai/gpt-4o-mini"]);
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
assert!(cleaned_json.get("routing_preferences").is_none());
}
#[test]
fn extract_routing_policy_preserves_other_fields() {
let policy = r#""routing_preferences": [{"name": "test", "description": "test", "models": ["gpt-4o"], "selection_policy": {"prefer": "none"}}], "temperature": 0.5, "max_tokens": 100"#;
let body = make_chat_body(policy);
let (cleaned, prefs) = extract_routing_policy(&body).unwrap();
assert!(prefs.is_some());
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
assert_eq!(cleaned_json["temperature"], 0.5);
assert_eq!(cleaned_json["max_tokens"], 100);
assert!(cleaned_json.get("routing_preferences").is_none());
}
#[test]
fn extract_routing_policy_prefer_null_defaults_to_none() {
let policy = r#""routing_preferences": [
{
"name": "coding",
"description": "code generation, writing functions, debugging",
"models": ["openai/gpt-4o", "openai/gpt-4o-mini"],
"selection_policy": {"prefer": null}
}
]"#;
let body = make_chat_body(policy);
let (_cleaned, prefs) = extract_routing_policy(&body).unwrap();
let prefs = prefs.expect("should parse routing_preferences when prefer is null");
assert_eq!(prefs.len(), 1);
assert_eq!(prefs[0].selection_policy.prefer, SelectionPreference::None);
}
#[test]
fn extract_routing_policy_selection_policy_missing_defaults_to_none() {
let policy = r#""routing_preferences": [
{
"name": "coding",
"description": "code generation, writing functions, debugging",
"models": ["openai/gpt-4o", "openai/gpt-4o-mini"]
}
]"#;
let body = make_chat_body(policy);
let (_cleaned, prefs) = extract_routing_policy(&body).unwrap();
let prefs =
prefs.expect("should parse routing_preferences when selection_policy is missing");
assert_eq!(prefs.len(), 1);
assert_eq!(prefs[0].selection_policy.prefer, SelectionPreference::None);
}
#[test]
fn extract_routing_policy_prefer_empty_string_defaults_to_none() {
let policy = r#""routing_preferences": [
{
"name": "coding",
"description": "code generation, writing functions, debugging",
"models": ["openai/gpt-4o", "openai/gpt-4o-mini"],
"selection_policy": {"prefer": ""}
}
]"#;
let body = make_chat_body(policy);
let (_cleaned, prefs) = extract_routing_policy(&body).unwrap();
let prefs =
prefs.expect("should parse routing_preferences when selection_policy.prefer is empty");
assert_eq!(prefs.len(), 1);
assert_eq!(prefs[0].selection_policy.prefer, SelectionPreference::None);
}
#[test]
fn routing_decision_response_serialization() {
let response = RoutingDecisionResponse {
models: vec![
"openai/gpt-4o-mini".to_string(),
"openai/gpt-4o".to_string(),
],
route: Some("code_generation".to_string()),
trace_id: "abc123".to_string(),
session_id: Some("sess-abc".to_string()),
pinned: true,
};
let json = serde_json::to_string(&response).unwrap();
let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
assert_eq!(parsed["models"][0], "openai/gpt-4o-mini");
assert_eq!(parsed["models"][1], "openai/gpt-4o");
assert_eq!(parsed["route"], "code_generation");
assert_eq!(parsed["trace_id"], "abc123");
assert_eq!(parsed["session_id"], "sess-abc");
assert_eq!(parsed["pinned"], true);
}
#[test]
fn routing_decision_response_serialization_no_route() {
let response = RoutingDecisionResponse {
models: vec!["none".to_string()],
route: None,
trace_id: "abc123".to_string(),
session_id: None,
pinned: false,
};
let json = serde_json::to_string(&response).unwrap();
let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
assert_eq!(parsed["models"][0], "none");
assert!(parsed["route"].is_null());
assert!(parsed.get("session_id").is_none());
assert_eq!(parsed["pinned"], false);
}
}


@@ -1,288 +0,0 @@
use bytes::Bytes;
use http_body_util::combinators::BoxBody;
use http_body_util::StreamBody;
use hyper::body::Frame;
use opentelemetry::trace::TraceContextExt;
use opentelemetry::KeyValue;
use std::time::Instant;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::StreamExt;
use tracing::{info, warn, Instrument};
use tracing_opentelemetry::OpenTelemetrySpanExt;
use crate::signals::{InteractionQuality, SignalAnalyzer, TextBasedSignalAnalyzer, FLAG_MARKER};
use crate::tracing::{llm, set_service_name, signals as signal_constants};
use hermesllm::apis::openai::Message;
/// Trait for processing streaming chunks
/// Implementors can inject custom logic during streaming (e.g., hallucination detection, logging)
pub trait StreamProcessor: Send + 'static {
/// Process an incoming chunk of bytes
fn process_chunk(&mut self, chunk: Bytes) -> Result<Option<Bytes>, String>;
/// Called when the first bytes are received (for time-to-first-token tracking)
fn on_first_bytes(&mut self) {}
/// Called when streaming completes successfully
fn on_complete(&mut self) {}
/// Called when streaming encounters an error
fn on_error(&mut self, _error: &str) {}
}
/// A processor that tracks streaming metrics
pub struct ObservableStreamProcessor {
service_name: String,
operation_name: String,
total_bytes: usize,
chunk_count: usize,
start_time: Instant,
time_to_first_token: Option<u128>,
messages: Option<Vec<Message>>,
}
impl ObservableStreamProcessor {
/// Create a new passthrough processor
///
/// # Arguments
/// * `service_name` - The service name for this span (e.g., "plano(llm)")
/// This will be set as the `service.name.override` attribute on the current span,
/// allowing the ServiceNameOverrideExporter to route spans to different services.
/// * `operation_name` - The current span operation name (e.g., "POST /v1/chat/completions gpt-4")
/// Used to append the flag marker when concerning signals are detected.
/// * `start_time` - When the request started (for duration calculation)
/// * `messages` - Optional conversation messages for signal analysis
pub fn new(
service_name: impl Into<String>,
operation_name: impl Into<String>,
start_time: Instant,
messages: Option<Vec<Message>>,
) -> Self {
let service_name = service_name.into();
// Set the service name override on the current span for OpenTelemetry export
// This allows the ServiceNameOverrideExporter to route this span to the correct service
set_service_name(&service_name);
Self {
service_name,
operation_name: operation_name.into(),
total_bytes: 0,
chunk_count: 0,
start_time,
time_to_first_token: None,
messages,
}
}
}
impl StreamProcessor for ObservableStreamProcessor {
fn process_chunk(&mut self, chunk: Bytes) -> Result<Option<Bytes>, String> {
self.total_bytes += chunk.len();
self.chunk_count += 1;
Ok(Some(chunk))
}
fn on_first_bytes(&mut self) {
// Record time to first token (only for streaming)
if self.time_to_first_token.is_none() {
self.time_to_first_token = Some(self.start_time.elapsed().as_millis());
}
}
fn on_complete(&mut self) {
// Record time-to-first-token as an OTel span attribute + event (streaming only)
if let Some(ttft) = self.time_to_first_token {
let span = tracing::Span::current();
let otel_context = span.context();
let otel_span = otel_context.span();
otel_span.set_attribute(KeyValue::new(llm::TIME_TO_FIRST_TOKEN_MS, ttft as i64));
otel_span.add_event(
llm::TIME_TO_FIRST_TOKEN_MS,
vec![KeyValue::new(llm::TIME_TO_FIRST_TOKEN_MS, ttft as i64)],
);
}
// Analyze signals if messages are available and record as span attributes
if let Some(ref messages) = self.messages {
let analyzer: Box<dyn SignalAnalyzer> = Box::new(TextBasedSignalAnalyzer::new());
let report = analyzer.analyze(messages);
// Get the current OTel span to set signal attributes
let span = tracing::Span::current();
let otel_context = span.context();
let otel_span = otel_context.span();
// Add overall quality
otel_span.set_attribute(KeyValue::new(
signal_constants::QUALITY,
format!("{:?}", report.overall_quality),
));
// Add repair/follow-up metrics if concerning
if report.follow_up.is_concerning || report.follow_up.repair_count > 0 {
otel_span.set_attribute(KeyValue::new(
signal_constants::REPAIR_COUNT,
report.follow_up.repair_count as i64,
));
otel_span.set_attribute(KeyValue::new(
signal_constants::REPAIR_RATIO,
format!("{:.3}", report.follow_up.repair_ratio),
));
}
// Add frustration metrics
if report.frustration.has_frustration {
otel_span.set_attribute(KeyValue::new(
signal_constants::FRUSTRATION_COUNT,
report.frustration.frustration_count as i64,
));
otel_span.set_attribute(KeyValue::new(
signal_constants::FRUSTRATION_SEVERITY,
report.frustration.severity as i64,
));
}
// Add repetition metrics
if report.repetition.has_looping {
otel_span.set_attribute(KeyValue::new(
signal_constants::REPETITION_COUNT,
report.repetition.repetition_count as i64,
));
}
// Add escalation metrics
if report.escalation.escalation_requested {
otel_span
.set_attribute(KeyValue::new(signal_constants::ESCALATION_REQUESTED, true));
}
// Add positive feedback metrics
if report.positive_feedback.has_positive_feedback {
otel_span.set_attribute(KeyValue::new(
signal_constants::POSITIVE_FEEDBACK_COUNT,
report.positive_feedback.positive_count as i64,
));
}
// Flag the span name if any concerning signal is detected
let should_flag = report.frustration.has_frustration
|| report.repetition.has_looping
|| report.escalation.escalation_requested
|| matches!(
report.overall_quality,
InteractionQuality::Poor | InteractionQuality::Severe
);
if should_flag {
otel_span.update_name(format!("{} {}", self.operation_name, FLAG_MARKER));
}
}
info!(
service = %self.service_name,
total_bytes = self.total_bytes,
chunk_count = self.chunk_count,
duration_ms = self.start_time.elapsed().as_millis(),
time_to_first_token_ms = ?self.time_to_first_token,
"streaming completed"
);
}
fn on_error(&mut self, error_msg: &str) {
warn!(
service = %self.service_name,
error = error_msg,
duration_ms = self.start_time.elapsed().as_millis(),
"stream error"
);
}
}
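The time-to-first-token bookkeeping in `on_first_bytes`/`on_complete` above reduces to "stamp elapsed time once, on the first chunk, and never overwrite it". A minimal std-only sketch of that invariant (struct and field names hypothetical, not the real `StreamProcessor`):

```rust
use std::time::Instant;

// Hypothetical miniature of the TTFT tracking: one start instant, one
// optional stamp that only the first chunk is allowed to set.
struct Ttft {
    start: Instant,
    first_token_ms: Option<u128>,
}

impl Ttft {
    fn new() -> Self {
        Self { start: Instant::now(), first_token_ms: None }
    }

    // Called for every chunk; only the first call records a value.
    fn on_first_bytes(&mut self) {
        if self.first_token_ms.is_none() {
            self.first_token_ms = Some(self.start.elapsed().as_millis());
        }
    }
}

fn main() {
    let mut t = Ttft::new();
    t.on_first_bytes();
    let first = t.first_token_ms;
    std::thread::sleep(std::time::Duration::from_millis(5));
    t.on_first_bytes(); // later chunks must not overwrite the stamp
    assert_eq!(t.first_token_ms, first);
    assert!(first.is_some());
}
```

The `is_none()` guard is what makes calling the hook on every chunk safe, so the stream loop doesn't need to track "first chunk seen" separately for timing.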
/// Result of creating a streaming response
pub struct StreamingResponse {
pub body: BoxBody<Bytes, hyper::Error>,
pub processor_handle: tokio::task::JoinHandle<()>,
}
pub fn create_streaming_response<S, P>(
mut byte_stream: S,
mut processor: P,
buffer_size: usize,
) -> StreamingResponse
where
S: StreamExt<Item = Result<Bytes, reqwest::Error>> + Send + Unpin + 'static,
P: StreamProcessor,
{
let (tx, rx) = mpsc::channel::<Bytes>(buffer_size);
// Capture the current span so the spawned task inherits the request context
let current_span = tracing::Span::current();
// Spawn a task to process and forward chunks
let processor_handle = tokio::spawn(
async move {
let mut is_first_chunk = true;
while let Some(item) = byte_stream.next().await {
let chunk = match item {
Ok(chunk) => chunk,
Err(err) => {
let err_msg = format!("Error receiving chunk: {:?}", err);
warn!(error = %err_msg, "stream error");
processor.on_error(&err_msg);
break;
}
};
// Call on_first_bytes for the first chunk
if is_first_chunk {
processor.on_first_bytes();
is_first_chunk = false;
}
// Process the chunk
match processor.process_chunk(chunk) {
Ok(Some(processed_chunk)) => {
if tx.send(processed_chunk).await.is_err() {
warn!("receiver dropped");
break;
}
}
Ok(None) => {
// Skip this chunk
continue;
}
Err(err) => {
warn!("processor error: {}", err);
processor.on_error(&err);
break;
}
}
}
processor.on_complete();
}
.instrument(current_span),
);
// Convert channel receiver to HTTP stream
let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk)));
let stream_body = BoxBody::new(StreamBody::new(stream));
StreamingResponse {
body: stream_body,
processor_handle,
}
}
/// Truncates a message to the specified maximum length, adding "..." if truncated.
pub fn truncate_message(message: &str, max_length: usize) -> String {
if message.chars().count() > max_length {
let truncated: String = message.chars().take(max_length).collect();
format!("{}...", truncated)
} else {
message.to_string()
}
}
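Because `truncate_message` counts `chars()` rather than bytes, it never slices through a multi-byte UTF-8 character. A quick check of its behavior (function body reproduced verbatim from above):

```rust
fn truncate_message(message: &str, max_length: usize) -> String {
    if message.chars().count() > max_length {
        let truncated: String = message.chars().take(max_length).collect();
        format!("{}...", truncated)
    } else {
        message.to_string()
    }
}

fn main() {
    assert_eq!(truncate_message("hello", 3), "hel...");
    // At or under the limit: returned unchanged, no ellipsis.
    assert_eq!(truncate_message("hi", 3), "hi");
    // chars(), not bytes: multi-byte input is cut on character boundaries.
    assert_eq!(truncate_message("héllo", 2), "hé...");
}
```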


@@ -1,6 +1,9 @@
pub mod app_state;
pub mod handlers;
pub mod metrics;
pub mod router;
pub mod session_cache;
pub mod signals;
pub mod state;
pub mod streaming;
pub mod tracing;
pub mod utils;


@@ -1,293 +1,590 @@
use brightstaff::handlers::agent_chat_completions::agent_chat;
#[cfg(feature = "jemalloc")]
#[global_allocator]
static ALLOC: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
use brightstaff::app_state::AppState;
use brightstaff::handlers::agents::orchestrator::agent_chat;
use brightstaff::handlers::debug;
use brightstaff::handlers::empty;
use brightstaff::handlers::function_calling::function_calling_chat_handler;
use brightstaff::handlers::llm::llm_chat;
use brightstaff::handlers::models::list_models;
use brightstaff::router::llm_router::RouterService;
use brightstaff::router::plano_orchestrator::OrchestratorService;
use brightstaff::handlers::routing_service::routing_decision;
use brightstaff::metrics as bs_metrics;
use brightstaff::metrics::labels as metric_labels;
use brightstaff::router::model_metrics::ModelMetricsService;
use brightstaff::router::orchestrator::OrchestratorService;
use brightstaff::session_cache::init_session_cache;
use brightstaff::state::memory::MemoryConversationalStorage;
use brightstaff::state::postgresql::PostgreSQLConversationStorage;
use brightstaff::state::StateStorage;
use brightstaff::utils::tracing::init_tracer;
use brightstaff::tracing::init_tracer;
use bytes::Bytes;
use common::configuration::{Agent, Configuration};
use common::consts::{
CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH, PLANO_ORCHESTRATOR_MODEL_NAME,
use common::configuration::{
Agent, Configuration, FilterPipeline, ListenerType, ResolvedFilterChain,
};
use common::consts::{CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH};
use common::llm_providers::LlmProviders;
use http_body_util::{combinators::BoxBody, BodyExt, Empty};
use http_body_util::combinators::BoxBody;
use hyper::body::Incoming;
use hyper::header::HeaderValue;
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::{Method, Request, Response, StatusCode};
use hyper_util::rt::TokioIo;
use opentelemetry::global;
use opentelemetry::trace::FutureExt;
use opentelemetry::{global, Context};
use opentelemetry_http::HeaderExtractor;
use std::collections::HashMap;
use std::sync::Arc;
use std::{env, fs};
use tokio::net::TcpListener;
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
pub mod router;
const BIND_ADDRESS: &str = "0.0.0.0:9091";
const DEFAULT_ROUTING_LLM_PROVIDER: &str = "arch-router";
const DEFAULT_ROUTING_MODEL_NAME: &str = "Arch-Router";
const DEFAULT_ORCHESTRATOR_LLM_PROVIDER: &str = "plano-orchestrator";
const DEFAULT_ORCHESTRATOR_MODEL_NAME: &str = "Plano-Orchestrator";
// Utility function to extract the context from the incoming request headers
fn extract_context_from_request(req: &Request<Incoming>) -> Context {
global::get_text_map_propagator(|propagator| {
propagator.extract(&HeaderExtractor(req.headers()))
})
/// Parse a version string like `v0.4.0`, `v0.3.0`, `0.2.0` into a `(major, minor, patch)` tuple.
/// Missing parts default to 0. Non-numeric parts are treated as 0.
fn parse_semver(version: &str) -> (u32, u32, u32) {
let v = version.trim_start_matches('v');
let mut parts = v.splitn(3, '.').map(|p| p.parse::<u32>().unwrap_or(0));
let major = parts.next().unwrap_or(0);
let minor = parts.next().unwrap_or(0);
let patch = parts.next().unwrap_or(0);
(major, minor, patch)
}
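Returning a `(major, minor, patch)` tuple lets the version gate below use Rust's lexicographic tuple comparison directly (`config_version >= (0, 4, 0)`). A few checks of the parsing rules (function reproduced from above):

```rust
/// Missing parts default to 0; non-numeric parts are treated as 0.
fn parse_semver(version: &str) -> (u32, u32, u32) {
    let v = version.trim_start_matches('v');
    let mut parts = v.splitn(3, '.').map(|p| p.parse::<u32>().unwrap_or(0));
    let major = parts.next().unwrap_or(0);
    let minor = parts.next().unwrap_or(0);
    let patch = parts.next().unwrap_or(0);
    (major, minor, patch)
}

fn main() {
    assert_eq!(parse_semver("v0.4.21"), (0, 4, 21));
    // Missing patch component defaults to 0.
    assert_eq!(parse_semver("0.3"), (0, 3, 0));
    // Tuple comparison is lexicographic, so the v0.4.0+ gate works as expected.
    assert!(parse_semver("v0.4.0") >= (0, 4, 0));
    assert!(parse_semver("v0.3.9") < (0, 4, 0));
}
```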
fn empty() -> BoxBody<Bytes, hyper::Error> {
Empty::<Bytes>::new()
.map_err(|never| match never {})
.boxed()
/// CORS pre-flight response for the models endpoint.
fn cors_preflight() -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let mut response = Response::new(empty());
*response.status_mut() = StatusCode::NO_CONTENT;
let h = response.headers_mut();
h.insert("Allow", HeaderValue::from_static("GET, OPTIONS"));
h.insert("Access-Control-Allow-Origin", HeaderValue::from_static("*"));
h.insert(
"Access-Control-Allow-Headers",
HeaderValue::from_static("Authorization, Content-Type"),
);
h.insert(
"Access-Control-Allow-Methods",
HeaderValue::from_static("GET, POST, OPTIONS"),
);
h.insert("Content-Type", HeaderValue::from_static("application/json"));
Ok(response)
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let bind_address = env::var("BIND_ADDRESS").unwrap_or_else(|_| BIND_ADDRESS.to_string());
// ---------------------------------------------------------------------------
// Configuration loading
// ---------------------------------------------------------------------------
// loading plano_config.yaml file (before tracing init so we can read tracing config)
let plano_config_path = env::var("PLANO_CONFIG_PATH_RENDERED")
/// Load and parse the YAML configuration file.
///
/// The path is read from `PLANO_CONFIG_PATH_RENDERED` (env) or falls back to
/// `./plano_config_rendered.yaml`.
fn load_config() -> Result<Configuration, Box<dyn std::error::Error + Send + Sync>> {
let path = env::var("PLANO_CONFIG_PATH_RENDERED")
.unwrap_or_else(|_| "./plano_config_rendered.yaml".to_string());
eprintln!("loading plano_config.yaml from {}", plano_config_path);
eprintln!("loading plano_config.yaml from {}", path);
let config_contents =
fs::read_to_string(&plano_config_path).expect("Failed to read plano_config.yaml");
let contents = fs::read_to_string(&path).map_err(|e| format!("failed to read {path}: {e}"))?;
let config: Configuration =
serde_yaml::from_str(&config_contents).expect("Failed to parse plano_config.yaml");
serde_yaml::from_str(&contents).map_err(|e| format!("failed to parse {path}: {e}"))?;
// Initialize tracing using config.yaml tracing section
let _tracer_provider = init_tracer(config.tracing.as_ref());
info!(path = %plano_config_path, "loaded plano_config.yaml");
Ok(config)
}
let plano_config = Arc::new(config);
// ---------------------------------------------------------------------------
// Application state initialization
// ---------------------------------------------------------------------------
// combine agents and filters into a single list of agents
let all_agents: Vec<Agent> = plano_config
/// Build the shared [`AppState`] from a parsed [`Configuration`].
async fn init_app_state(
config: &Configuration,
) -> Result<AppState, Box<dyn std::error::Error + Send + Sync>> {
let llm_provider_url =
env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string());
// Combine agents and filters into a single list
let all_agents: Vec<Agent> = config
.agents
.as_deref()
.unwrap_or_default()
.iter()
.chain(plano_config.filters.as_deref().unwrap_or_default())
.chain(config.filters.as_deref().unwrap_or_default())
.cloned()
.collect();
// Create expanded provider list for /v1/models endpoint
let llm_providers = LlmProviders::try_from(plano_config.model_providers.clone())
.expect("Failed to create LlmProviders");
let llm_providers = Arc::new(RwLock::new(llm_providers));
let combined_agents_filters_list = Arc::new(RwLock::new(Some(all_agents)));
let listeners = Arc::new(RwLock::new(plano_config.listeners.clone()));
let llm_provider_url =
env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string());
let global_agent_map: HashMap<String, Agent> = all_agents
.iter()
.map(|a| (a.id.clone(), a.clone()))
.collect();
let listener = TcpListener::bind(bind_address).await?;
let routing_model_name: String = plano_config
.routing
.as_ref()
.and_then(|r| r.model.clone())
.unwrap_or_else(|| DEFAULT_ROUTING_MODEL_NAME.to_string());
let llm_providers = LlmProviders::try_from(config.model_providers.clone())
.map_err(|e| format!("failed to create LlmProviders: {e}"))?;
let routing_llm_provider = plano_config
.routing
.as_ref()
.and_then(|r| r.model_provider.clone())
.unwrap_or_else(|| DEFAULT_ROUTING_LLM_PROVIDER.to_string());
let model_listener_count = config
.listeners
.iter()
.filter(|l| l.listener_type == ListenerType::Model)
.count();
if model_listener_count > 1 {
return Err(format!(
"only one model listener is allowed, found {}",
model_listener_count
)
.into());
}
let model_listener = config
.listeners
.iter()
.find(|l| l.listener_type == ListenerType::Model);
let resolve_chain = |filter_ids: Option<Vec<String>>| -> Option<ResolvedFilterChain> {
filter_ids.map(|ids| {
let agents = ids
.iter()
.filter_map(|id| {
global_agent_map
.get(id)
.map(|a: &Agent| (id.clone(), a.clone()))
})
.collect();
ResolvedFilterChain {
filter_ids: ids,
agents,
}
})
};
let filter_pipeline = Arc::new(FilterPipeline {
input: resolve_chain(model_listener.and_then(|l| l.input_filters.clone())),
output: resolve_chain(model_listener.and_then(|l| l.output_filters.clone())),
});
let router_service: Arc<RouterService> = Arc::new(RouterService::new(
plano_config.model_providers.clone(),
format!("{llm_provider_url}{CHAT_COMPLETIONS_PATH}"),
routing_model_name,
routing_llm_provider,
));
let overrides = config.overrides.clone().unwrap_or_default();
let orchestrator_service: Arc<OrchestratorService> = Arc::new(OrchestratorService::new(
format!("{llm_provider_url}{CHAT_COMPLETIONS_PATH}"),
PLANO_ORCHESTRATOR_MODEL_NAME.to_string(),
));
let session_ttl_seconds = config.routing.as_ref().and_then(|r| r.session_ttl_seconds);
let session_cache = init_session_cache(config).await?;
let model_aliases = Arc::new(plano_config.model_aliases.clone());
let span_attributes = Arc::new(
plano_config
.tracing
.as_ref()
.and_then(|tracing| tracing.span_attributes.clone()),
);
// Top-level routing_preferences requires config version v0.4.0+.
let config_version = parse_semver(&config.version);
let is_v040_plus = config_version >= (0, 4, 0);
// Initialize trace collector and start background flusher
// Tracing is enabled if the tracing config is present in plano_config.yaml
// Pass Some(true/false) to override, or None to use env var OTEL_TRACING_ENABLED
// OpenTelemetry automatic instrumentation is configured in utils/tracing.rs
if !is_v040_plus && config.routing_preferences.is_some() {
return Err(
"top-level routing_preferences requires version v0.4.0 or above. \
Update the version field or remove routing_preferences."
.into(),
);
}
// Initialize conversation state storage for v1/responses
// Configurable via plano_config.yaml state_storage section
// If not configured, state management is disabled
// Environment variables are substituted by envsubst before config is read
let state_storage: Option<Arc<dyn StateStorage>> =
if let Some(storage_config) = &plano_config.state_storage {
let storage: Arc<dyn StateStorage> = match storage_config.storage_type {
common::configuration::StateStorageType::Memory => {
info!(
storage_type = "memory",
"initialized conversation state storage"
);
Arc::new(MemoryConversationalStorage::new())
}
common::configuration::StateStorageType::Postgres => {
let connection_string = storage_config
.connection_string
.as_ref()
.expect("connection_string is required for postgres state_storage");
debug!(connection_string = %connection_string, "postgres connection");
info!(
storage_type = "postgres",
"initializing conversation state storage"
);
Arc::new(
PostgreSQLConversationStorage::new(connection_string.clone())
.await
.expect("Failed to initialize Postgres state storage"),
// Validate that all models referenced in top-level routing_preferences exist in model_providers.
// The CLI renders model_providers with `name` = "openai/gpt-4o" and `model` = "gpt-4o",
// so we accept a match against either field.
if let Some(ref route_prefs) = config.routing_preferences {
let provider_model_names: std::collections::HashSet<&str> = config
.model_providers
.iter()
.flat_map(|p| std::iter::once(p.name.as_str()).chain(p.model.as_deref()))
.collect();
for pref in route_prefs {
for model in &pref.models {
if !provider_model_names.contains(model.as_str()) {
return Err(format!(
"routing_preferences route '{}' references model '{}' \
which is not declared in model_providers",
pref.name, model
)
}
};
Some(storage)
} else {
info!("no state_storage configured, conversation state management disabled");
None
};
loop {
let (stream, _) = listener.accept().await?;
let peer_addr = stream.peer_addr()?;
let io = TokioIo::new(stream);
let router_service: Arc<RouterService> = Arc::clone(&router_service);
let orchestrator_service: Arc<OrchestratorService> = Arc::clone(&orchestrator_service);
let model_aliases: Arc<
Option<std::collections::HashMap<String, common::configuration::ModelAlias>>,
> = Arc::clone(&model_aliases);
let llm_provider_url = llm_provider_url.clone();
let llm_providers = llm_providers.clone();
let agents_list = combined_agents_filters_list.clone();
let listeners = listeners.clone();
let span_attributes = span_attributes.clone();
let state_storage = state_storage.clone();
let service = service_fn(move |req| {
let router_service = Arc::clone(&router_service);
let orchestrator_service = Arc::clone(&orchestrator_service);
let parent_cx = extract_context_from_request(&req);
let llm_provider_url = llm_provider_url.clone();
let llm_providers = llm_providers.clone();
let model_aliases = Arc::clone(&model_aliases);
let agents_list = agents_list.clone();
let listeners = listeners.clone();
let span_attributes = span_attributes.clone();
let state_storage = state_storage.clone();
async move {
let path = req.uri().path();
// Check if path starts with /agents
if path.starts_with("/agents") {
// Check if it matches one of the agent API paths
let stripped_path = path.strip_prefix("/agents").unwrap();
if matches!(
stripped_path,
CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH
) {
let fully_qualified_url = format!("{}{}", llm_provider_url, stripped_path);
return agent_chat(
req,
orchestrator_service,
fully_qualified_url,
agents_list,
listeners,
span_attributes,
llm_providers,
)
.with_context(parent_cx)
.await;
}
}
match (req.method(), path) {
(
&Method::POST,
CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH,
) => {
let fully_qualified_url = format!("{}{}", llm_provider_url, path);
llm_chat(
req,
router_service,
fully_qualified_url,
model_aliases,
llm_providers,
span_attributes,
state_storage,
)
.with_context(parent_cx)
.await
}
(&Method::POST, "/function_calling") => {
let fully_qualified_url =
format!("{}{}", llm_provider_url, "/v1/chat/completions");
function_calling_chat_handler(req, fully_qualified_url)
.with_context(parent_cx)
.await
}
(&Method::GET, "/v1/models" | "/agents/v1/models") => {
Ok(list_models(llm_providers).await)
}
// hack for now to get open-webui to work
(&Method::OPTIONS, "/v1/models" | "/agents/v1/models") => {
let mut response = Response::new(empty());
*response.status_mut() = StatusCode::NO_CONTENT;
response
.headers_mut()
.insert("Allow", "GET, OPTIONS".parse().unwrap());
response
.headers_mut()
.insert("Access-Control-Allow-Origin", "*".parse().unwrap());
response.headers_mut().insert(
"Access-Control-Allow-Headers",
"Authorization, Content-Type".parse().unwrap(),
);
response.headers_mut().insert(
"Access-Control-Allow-Methods",
"GET, POST, OPTIONS".parse().unwrap(),
);
response
.headers_mut()
.insert("Content-Type", "application/json".parse().unwrap());
Ok(response)
}
_ => {
debug!(method = %req.method(), path = %req.uri().path(), "no route found");
let mut not_found = Response::new(empty());
*not_found.status_mut() = StatusCode::NOT_FOUND;
Ok(not_found)
}
.into());
}
}
});
}
}
tokio::task::spawn(async move {
debug!(peer = ?peer_addr, "accepted connection");
if let Err(err) = http1::Builder::new()
// .serve_connection(io, service_fn(chat_completion))
.serve_connection(io, service)
.await
{
warn!(error = ?err, "error serving connection");
// Validate and initialize ModelMetricsService if model_metrics_sources is configured.
let metrics_service: Option<Arc<ModelMetricsService>> = if let Some(ref sources) =
config.model_metrics_sources
{
use common::configuration::MetricsSource;
let cost_count = sources
.iter()
.filter(|s| matches!(s, MetricsSource::Cost(_)))
.count();
let latency_count = sources
.iter()
.filter(|s| matches!(s, MetricsSource::Latency(_)))
.count();
if cost_count > 1 {
return Err("model_metrics_sources: only one cost metrics source is allowed".into());
}
if latency_count > 1 {
return Err("model_metrics_sources: only one latency metrics source is allowed".into());
}
let svc = ModelMetricsService::new(sources, reqwest::Client::new()).await;
Some(Arc::new(svc))
} else {
None
};
// Validate that selection_policy.prefer is compatible with the configured metric sources.
if let Some(ref prefs) = config.routing_preferences {
use common::configuration::{MetricsSource, SelectionPreference};
let has_cost_source = config
.model_metrics_sources
.as_deref()
.unwrap_or_default()
.iter()
.any(|s| matches!(s, MetricsSource::Cost(_)));
let has_latency_source = config
.model_metrics_sources
.as_deref()
.unwrap_or_default()
.iter()
.any(|s| matches!(s, MetricsSource::Latency(_)));
for pref in prefs {
if pref.selection_policy.prefer == SelectionPreference::Cheapest && !has_cost_source {
return Err(format!(
"routing_preferences route '{}' uses prefer: cheapest but no cost metrics source is configured — \
add a cost metrics source to model_metrics_sources",
pref.name
)
.into());
}
});
if pref.selection_policy.prefer == SelectionPreference::Fastest && !has_latency_source {
return Err(format!(
"routing_preferences route '{}' uses prefer: fastest but no latency metrics source is configured — \
add a latency metrics source to model_metrics_sources",
pref.name
)
.into());
}
}
}
// Warn about models in routing_preferences that have no matching pricing/latency data.
if let (Some(ref prefs), Some(ref svc)) = (&config.routing_preferences, &metrics_service) {
let cost_data = svc.cost_snapshot().await;
let latency_data = svc.latency_snapshot().await;
for pref in prefs {
use common::configuration::SelectionPreference;
for model in &pref.models {
let missing = match pref.selection_policy.prefer {
SelectionPreference::Cheapest => !cost_data.contains_key(model.as_str()),
SelectionPreference::Fastest => !latency_data.contains_key(model.as_str()),
_ => false,
};
if missing {
warn!(
model = %model,
route = %pref.name,
"model has no metric data — will be ranked last"
);
}
}
}
}
let session_tenant_header = config
.routing
.as_ref()
.and_then(|r| r.session_cache.as_ref())
.and_then(|c| c.tenant_header.clone());
// Resolve model name: prefer llm_routing_model override, then agent_orchestration_model, then default.
let orchestrator_model_name: String = overrides
.llm_routing_model
.as_deref()
.or(overrides.agent_orchestration_model.as_deref())
.map(|m| m.split_once('/').map(|(_, id)| id).unwrap_or(m))
.unwrap_or(DEFAULT_ORCHESTRATOR_MODEL_NAME)
.to_string();
let orchestrator_llm_provider: String = config
.model_providers
.iter()
.find(|p| p.model.as_deref() == Some(orchestrator_model_name.as_str()))
.map(|p| p.name.clone())
.unwrap_or_else(|| DEFAULT_ORCHESTRATOR_LLM_PROVIDER.to_string());
let orchestrator_max_tokens = overrides
.orchestrator_model_context_length
.unwrap_or(brightstaff::router::orchestrator_model_v1::MAX_TOKEN_LEN);
let orchestrator_service = Arc::new(OrchestratorService::with_routing(
format!("{llm_provider_url}{CHAT_COMPLETIONS_PATH}"),
orchestrator_model_name,
orchestrator_llm_provider,
config.routing_preferences.clone(),
metrics_service,
session_ttl_seconds,
session_cache,
session_tenant_header,
orchestrator_max_tokens,
));
let state_storage = init_state_storage(config).await?;
let span_attributes = config
.tracing
.as_ref()
.and_then(|tracing| tracing.span_attributes.clone());
let signals_enabled = !overrides.disable_signals.unwrap_or(false);
Ok(AppState {
orchestrator_service,
model_aliases: config.model_aliases.clone(),
llm_providers: Arc::new(RwLock::new(llm_providers)),
agents_list: Some(all_agents),
listeners: config.listeners.clone(),
state_storage,
llm_provider_url,
span_attributes,
http_client: reqwest::Client::new(),
filter_pipeline,
signals_enabled,
})
}
/// Initialize the conversation state storage backend (if configured).
async fn init_state_storage(
config: &Configuration,
) -> Result<Option<Arc<dyn StateStorage>>, Box<dyn std::error::Error + Send + Sync>> {
let Some(storage_config) = &config.state_storage else {
info!("no state_storage configured, conversation state management disabled");
return Ok(None);
};
let storage: Arc<dyn StateStorage> = match storage_config.storage_type {
common::configuration::StateStorageType::Memory => {
info!(
storage_type = "memory",
"initialized conversation state storage"
);
Arc::new(MemoryConversationalStorage::new())
}
common::configuration::StateStorageType::Postgres => {
let connection_string = storage_config
.connection_string
.as_ref()
.ok_or("connection_string is required for postgres state_storage")?;
debug!(connection_string = %connection_string, "postgres connection");
info!(
storage_type = "postgres",
"initializing conversation state storage"
);
Arc::new(
PostgreSQLConversationStorage::new(connection_string.clone())
.await
.map_err(|e| format!("failed to initialize Postgres state storage: {e}"))?,
)
}
};
Ok(Some(storage))
}
// ---------------------------------------------------------------------------
// Request routing
// ---------------------------------------------------------------------------
/// Normalized method label — limited set so we never emit a free-form string.
fn method_label(method: &Method) -> &'static str {
match *method {
Method::GET => "GET",
Method::POST => "POST",
Method::PUT => "PUT",
Method::DELETE => "DELETE",
Method::PATCH => "PATCH",
Method::HEAD => "HEAD",
Method::OPTIONS => "OPTIONS",
_ => "OTHER",
}
}
/// Compute the fixed `handler` metric label from the request's path+method.
/// Unknown method/path combinations fall through to `HANDLER_NOT_FOUND`,
/// matching the catch-all 404 branch in `dispatch()`.
fn handler_label_for(method: &Method, path: &str) -> &'static str {
if let Some(stripped) = path.strip_prefix("/agents") {
if matches!(
stripped,
CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH
) {
return metric_labels::HANDLER_AGENT_CHAT;
}
}
if let Some(stripped) = path.strip_prefix("/routing") {
if matches!(
stripped,
CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH
) {
return metric_labels::HANDLER_ROUTING_DECISION;
}
}
match (method, path) {
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH) => {
metric_labels::HANDLER_LLM_CHAT
}
(&Method::POST, "/function_calling") => metric_labels::HANDLER_FUNCTION_CALLING,
(&Method::GET, "/v1/models" | "/agents/v1/models") => metric_labels::HANDLER_LIST_MODELS,
(&Method::OPTIONS, "/v1/models" | "/agents/v1/models") => {
metric_labels::HANDLER_CORS_PREFLIGHT
}
_ => metric_labels::HANDLER_NOT_FOUND,
}
}
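The prefix matching above can be exercised in isolation. A simplified standalone sketch (constants inlined as string literals; only the `/v1/chat/completions` path is modeled, and the label values mirror those in `metric_labels`):

```rust
// Hypothetical miniature of handler_label_for: /agents-prefixed paths win
// first, then the plain method+path table, then the not_found fallback.
fn handler_label_for(method: &str, path: &str) -> &'static str {
    const CHAT: &str = "/v1/chat/completions";
    if let Some(stripped) = path.strip_prefix("/agents") {
        if stripped == CHAT {
            return "agent_chat";
        }
    }
    match (method, path) {
        ("POST", CHAT) => "llm_chat",
        ("GET", "/v1/models") => "list_models",
        _ => "not_found",
    }
}

fn main() {
    assert_eq!(handler_label_for("POST", "/agents/v1/chat/completions"), "agent_chat");
    assert_eq!(handler_label_for("POST", "/v1/chat/completions"), "llm_chat");
    assert_eq!(handler_label_for("DELETE", "/nope"), "not_found");
}
```

Since every return value is one of a fixed set of `&'static str` labels, the metric's `handler` dimension stays bounded no matter what paths clients send.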
/// Route an incoming HTTP request to the appropriate handler.
async fn route(
req: Request<Incoming>,
state: Arc<AppState>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let handler = handler_label_for(req.method(), req.uri().path());
let method = method_label(req.method());
let started = std::time::Instant::now();
let _in_flight = bs_metrics::InFlightGuard::new(handler);
let result = dispatch(req, state).await;
let status = match &result {
Ok(resp) => resp.status().as_u16(),
// hyper::Error here means the body couldn't be produced; conventionally 500.
Err(_) => 500,
};
bs_metrics::record_http(handler, method, status, started);
result
}
/// Inner dispatcher split out so `route()` can wrap it with metrics without
/// duplicating the match tree.
async fn dispatch(
req: Request<Incoming>,
state: Arc<AppState>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let parent_cx = global::get_text_map_propagator(|p| p.extract(&HeaderExtractor(req.headers())));
let path = req.uri().path().to_string();
// --- Agent routes (/agents/...) ---
if let Some(stripped) = path.strip_prefix("/agents") {
if matches!(
stripped,
CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH
) {
return agent_chat(req, Arc::clone(&state))
.with_context(parent_cx)
.await;
}
}
// --- Routing decision routes (/routing/...) ---
if let Some(stripped) = path.strip_prefix("/routing") {
let stripped = stripped.to_string();
if matches!(
stripped.as_str(),
CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH
) {
return routing_decision(
req,
Arc::clone(&state.orchestrator_service),
stripped,
&state.span_attributes,
)
.with_context(parent_cx)
.await;
}
}
// --- Standard routes ---
match (req.method(), path.as_str()) {
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH) => {
llm_chat(req, Arc::clone(&state))
.with_context(parent_cx)
.await
}
(&Method::POST, "/function_calling") => {
let url = format!("{}/v1/chat/completions", state.llm_provider_url);
function_calling_chat_handler(req, url)
.with_context(parent_cx)
.await
}
(&Method::GET, "/v1/models" | "/agents/v1/models") => {
Ok(list_models(Arc::clone(&state.llm_providers)).await)
}
(&Method::OPTIONS, "/v1/models" | "/agents/v1/models") => cors_preflight(),
(&Method::GET, "/debug/memstats") => debug::memstats().await,
_ => {
debug!(method = %req.method(), path = %path, "no route found");
let mut not_found = Response::new(empty());
*not_found.status_mut() = StatusCode::NOT_FOUND;
Ok(not_found)
}
}
}
// ---------------------------------------------------------------------------
// Server loop
// ---------------------------------------------------------------------------
/// Accept connections and spawn a task per connection.
///
/// Listens for `SIGINT` / `ctrl-c` and shuts down gracefully, allowing
/// in-flight connections to finish.
async fn run_server(state: Arc<AppState>) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let bind_address = env::var("BIND_ADDRESS").unwrap_or_else(|_| BIND_ADDRESS.to_string());
let listener = TcpListener::bind(&bind_address).await?;
info!(address = %bind_address, "server listening");
let shutdown = tokio::signal::ctrl_c();
tokio::pin!(shutdown);
loop {
tokio::select! {
result = listener.accept() => {
let (stream, _) = result?;
let peer_addr = stream.peer_addr()?;
let io = TokioIo::new(stream);
let state = Arc::clone(&state);
tokio::task::spawn(async move {
debug!(peer = ?peer_addr, "accepted connection");
let service = service_fn(move |req| {
let state = Arc::clone(&state);
async move { route(req, state).await }
});
if let Err(err) = http1::Builder::new().serve_connection(io, service).await {
warn!(error = ?err, "error serving connection");
}
});
}
_ = &mut shutdown => {
info!("received shutdown signal, stopping server");
break;
}
}
}
Ok(())
}
// ---------------------------------------------------------------------------
// Entry point
// ---------------------------------------------------------------------------
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let config = load_config()?;
let _tracer_provider = init_tracer(config.tracing.as_ref());
bs_metrics::init();
info!("loaded plano_config.yaml");
let state = Arc::new(init_app_state(&config).await?);
run_server(state).await
}


@@ -0,0 +1,38 @@
//! Fixed label-value constants so callers never emit free-form strings
//! (which would blow up cardinality).
// Handler enum — derived from the path+method match in `route()`.
pub const HANDLER_AGENT_CHAT: &str = "agent_chat";
pub const HANDLER_ROUTING_DECISION: &str = "routing_decision";
pub const HANDLER_LLM_CHAT: &str = "llm_chat";
pub const HANDLER_FUNCTION_CALLING: &str = "function_calling";
pub const HANDLER_LIST_MODELS: &str = "list_models";
pub const HANDLER_CORS_PREFLIGHT: &str = "cors_preflight";
pub const HANDLER_NOT_FOUND: &str = "not_found";
// Router "route" class — which brightstaff endpoint prompted the decision.
pub const ROUTE_AGENT: &str = "agent";
pub const ROUTE_ROUTING: &str = "routing";
pub const ROUTE_LLM: &str = "llm";
// Token kind for brightstaff_llm_tokens_total.
pub const TOKEN_KIND_PROMPT: &str = "prompt";
pub const TOKEN_KIND_COMPLETION: &str = "completion";
// LLM error_class values (match docstring in metrics/mod.rs).
pub const LLM_ERR_NONE: &str = "none";
pub const LLM_ERR_TIMEOUT: &str = "timeout";
pub const LLM_ERR_CONNECT: &str = "connect";
pub const LLM_ERR_PARSE: &str = "parse";
pub const LLM_ERR_OTHER: &str = "other";
pub const LLM_ERR_STREAM: &str = "stream";
// Routing service outcome values.
pub const ROUTING_SVC_DECISION_SERVED: &str = "decision_served";
pub const ROUTING_SVC_NO_CANDIDATES: &str = "no_candidates";
pub const ROUTING_SVC_POLICY_ERROR: &str = "policy_error";
// Session cache outcome values.
pub const SESSION_CACHE_HIT: &str = "hit";
pub const SESSION_CACHE_MISS: &str = "miss";
pub const SESSION_CACHE_STORE: &str = "store";


@@ -0,0 +1,377 @@
//! Prometheus metrics for brightstaff.
//!
//! Installs the `metrics` global recorder backed by
//! `metrics-exporter-prometheus` and exposes a `/metrics` HTTP endpoint on a
//! dedicated admin port (default `0.0.0.0:9092`, overridable via
//! `METRICS_BIND_ADDRESS`).
//!
//! Emitted metric families (see `describe_all` for full list):
//! - HTTP RED: `brightstaff_http_requests_total`,
//! `brightstaff_http_request_duration_seconds`,
//! `brightstaff_http_in_flight_requests`.
//! - LLM upstream: `brightstaff_llm_upstream_requests_total`,
//! `brightstaff_llm_upstream_duration_seconds`,
//! `brightstaff_llm_time_to_first_token_seconds`,
//! `brightstaff_llm_tokens_total`,
//! `brightstaff_llm_tokens_usage_missing_total`.
//! - Routing: `brightstaff_router_decisions_total`,
//! `brightstaff_router_decision_duration_seconds`,
//! `brightstaff_routing_service_requests_total`,
//! `brightstaff_session_cache_events_total`.
//! - Process: via `metrics-process`.
//! - Build: `brightstaff_build_info`.
use std::net::SocketAddr;
use std::sync::OnceLock;
use std::time::{Duration, Instant};
use metrics::{counter, describe_counter, describe_gauge, describe_histogram, gauge, histogram};
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder};
use tracing::{info, warn};
pub mod labels;
/// Guard flag so tests don't re-install the global recorder.
static INIT: OnceLock<()> = OnceLock::new();
const DEFAULT_METRICS_BIND: &str = "0.0.0.0:9092";
/// HTTP request duration buckets (seconds). Capped at 60s.
const HTTP_BUCKETS: &[f64] = &[
0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 30.0, 60.0,
];
/// LLM upstream / TTFT buckets (seconds). Capped at 120s because provider
/// completions routinely run that long.
const LLM_BUCKETS: &[f64] = &[0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 30.0, 60.0, 120.0];
/// Router decision buckets (seconds). The orchestrator call itself is usually
/// sub-second but bucketed generously in case of upstream slowness.
const ROUTER_BUCKETS: &[f64] = &[
0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 30.0,
];
/// Install the global recorder and spawn the `/metrics` HTTP listener.
///
/// Safe to call more than once; subsequent calls are no-ops so tests that
/// construct their own recorder still work.
pub fn init() {
if INIT.get().is_some() {
return;
}
let bind: SocketAddr = std::env::var("METRICS_BIND_ADDRESS")
.unwrap_or_else(|_| DEFAULT_METRICS_BIND.to_string())
.parse()
.unwrap_or_else(|err| {
warn!(error = %err, default = DEFAULT_METRICS_BIND, "invalid METRICS_BIND_ADDRESS, falling back to default");
DEFAULT_METRICS_BIND.parse().expect("default bind parses")
});
let builder = PrometheusBuilder::new()
.with_http_listener(bind)
.set_buckets_for_metric(
Matcher::Full("brightstaff_http_request_duration_seconds".to_string()),
HTTP_BUCKETS,
)
.and_then(|b| {
b.set_buckets_for_metric(Matcher::Prefix("brightstaff_llm_".to_string()), LLM_BUCKETS)
})
.and_then(|b| {
b.set_buckets_for_metric(
Matcher::Full("brightstaff_router_decision_duration_seconds".to_string()),
ROUTER_BUCKETS,
)
});
let builder = match builder {
Ok(b) => b,
Err(err) => {
warn!(error = %err, "failed to configure metrics buckets, using defaults");
PrometheusBuilder::new().with_http_listener(bind)
}
};
if let Err(err) = builder.install() {
warn!(error = %err, "failed to install Prometheus recorder; metrics disabled");
return;
}
let _ = INIT.set(());
describe_all();
emit_build_info();
// Register process-level collector (RSS, CPU, FDs).
let collector = metrics_process::Collector::default();
collector.describe();
// Prime once at startup. The exporter only renders what was last
// collected, so also refresh on a short interval to keep process gauges
// current between scrapes.
collector.collect();
tokio::spawn(async move {
let mut tick = tokio::time::interval(Duration::from_secs(10));
tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop {
tick.tick().await;
collector.collect();
}
});
info!(address = %bind, "metrics listener started");
}
fn describe_all() {
describe_counter!(
"brightstaff_http_requests_total",
"Total HTTP requests served by brightstaff, by handler and status class."
);
describe_histogram!(
"brightstaff_http_request_duration_seconds",
"Wall-clock duration of HTTP requests served by brightstaff, by handler."
);
describe_gauge!(
"brightstaff_http_in_flight_requests",
"Number of HTTP requests currently being served by brightstaff, by handler."
);
describe_counter!(
"brightstaff_llm_upstream_requests_total",
"LLM upstream request outcomes, by provider, model, status class and error class."
);
describe_histogram!(
"brightstaff_llm_upstream_duration_seconds",
"Wall-clock duration of LLM upstream calls (stream close for streaming), by provider and model."
);
describe_histogram!(
"brightstaff_llm_time_to_first_token_seconds",
"Time from request start to first streamed byte, by provider and model (streaming only)."
);
describe_counter!(
"brightstaff_llm_tokens_total",
"Tokens reported in the provider `usage` field, by provider, model and kind (prompt/completion)."
);
describe_counter!(
"brightstaff_llm_tokens_usage_missing_total",
"LLM responses that completed without a usable `usage` block (so token counts are unknown)."
);
describe_counter!(
"brightstaff_router_decisions_total",
"Routing decisions made by the orchestrator, by route, selected model, and whether a fallback was used."
);
describe_histogram!(
"brightstaff_router_decision_duration_seconds",
"Time spent in the orchestrator deciding a route, by route."
);
describe_counter!(
"brightstaff_routing_service_requests_total",
"Outcomes of /routing/* decision requests: decision_served, no_candidates, policy_error."
);
describe_counter!(
"brightstaff_session_cache_events_total",
"Session affinity cache lookups and stores, by outcome."
);
describe_gauge!(
"brightstaff_build_info",
"Build metadata. Always 1; labels carry version and git SHA."
);
}
fn emit_build_info() {
let version = env!("CARGO_PKG_VERSION");
let git_sha = option_env!("GIT_SHA").unwrap_or("unknown");
gauge!(
"brightstaff_build_info",
"version" => version.to_string(),
"git_sha" => git_sha.to_string(),
)
.set(1.0);
}
/// Split a provider-qualified model id like `"openai/gpt-4o"` into
/// `(provider, model)`. Returns `("unknown", raw)` when there is no `/`.
pub fn split_provider_model(full: &str) -> (&str, &str) {
match full.split_once('/') {
Some((p, m)) => (p, m),
None => ("unknown", full),
}
}
/// Bucket an HTTP status code into a class label: `"1xx"` through `"5xx"`,
/// or `"other"` for anything outside 100..=599.
pub fn status_class(status: u16) -> &'static str {
match status {
100..=199 => "1xx",
200..=299 => "2xx",
300..=399 => "3xx",
400..=499 => "4xx",
500..=599 => "5xx",
_ => "other",
}
}
// ---------------------------------------------------------------------------
// HTTP RED helpers
// ---------------------------------------------------------------------------
/// RAII guard that increments the in-flight gauge on construction and
/// decrements on drop. Pair with [`HttpTimer`] in the `route()` wrapper so the
/// gauge drops even on error paths.
pub struct InFlightGuard {
handler: &'static str,
}
impl InFlightGuard {
pub fn new(handler: &'static str) -> Self {
gauge!(
"brightstaff_http_in_flight_requests",
"handler" => handler,
)
.increment(1.0);
Self { handler }
}
}
impl Drop for InFlightGuard {
fn drop(&mut self) {
gauge!(
"brightstaff_http_in_flight_requests",
"handler" => self.handler,
)
.decrement(1.0);
}
}
/// Record the HTTP request counter + duration histogram.
pub fn record_http(handler: &'static str, method: &'static str, status: u16, started: Instant) {
let class = status_class(status);
counter!(
"brightstaff_http_requests_total",
"handler" => handler,
"method" => method,
"status_class" => class,
)
.increment(1);
histogram!(
"brightstaff_http_request_duration_seconds",
"handler" => handler,
)
.record(started.elapsed().as_secs_f64());
}
// ---------------------------------------------------------------------------
// LLM upstream helpers
// ---------------------------------------------------------------------------
/// Classify an outcome of an LLM upstream call for the `error_class` label.
pub fn llm_error_class_from_reqwest(err: &reqwest::Error) -> &'static str {
if err.is_timeout() {
"timeout"
} else if err.is_connect() {
"connect"
} else if err.is_decode() {
"parse"
} else {
"other"
}
}
/// Record the outcome of an LLM upstream call. `status` is the HTTP status
/// the upstream returned (0 if the call never produced one, e.g. send failure).
/// `error_class` is `"none"` on success, or a discriminated error label.
pub fn record_llm_upstream(
provider: &str,
model: &str,
status: u16,
error_class: &str,
duration: Duration,
) {
let class = if status == 0 {
"error"
} else {
status_class(status)
};
counter!(
"brightstaff_llm_upstream_requests_total",
"provider" => provider.to_string(),
"model" => model.to_string(),
"status_class" => class,
"error_class" => error_class.to_string(),
)
.increment(1);
histogram!(
"brightstaff_llm_upstream_duration_seconds",
"provider" => provider.to_string(),
"model" => model.to_string(),
)
.record(duration.as_secs_f64());
}
pub fn record_llm_ttft(provider: &str, model: &str, ttft: Duration) {
histogram!(
"brightstaff_llm_time_to_first_token_seconds",
"provider" => provider.to_string(),
"model" => model.to_string(),
)
.record(ttft.as_secs_f64());
}
pub fn record_llm_tokens(provider: &str, model: &str, kind: &'static str, count: u64) {
counter!(
"brightstaff_llm_tokens_total",
"provider" => provider.to_string(),
"model" => model.to_string(),
"kind" => kind,
)
.increment(count);
}
pub fn record_llm_tokens_usage_missing(provider: &str, model: &str) {
counter!(
"brightstaff_llm_tokens_usage_missing_total",
"provider" => provider.to_string(),
"model" => model.to_string(),
)
.increment(1);
}
// ---------------------------------------------------------------------------
// Router helpers
// ---------------------------------------------------------------------------
pub fn record_router_decision(
route: &'static str,
selected_model: &str,
fallback: bool,
duration: Duration,
) {
counter!(
"brightstaff_router_decisions_total",
"route" => route,
"selected_model" => selected_model.to_string(),
"fallback" => if fallback { "true" } else { "false" },
)
.increment(1);
histogram!(
"brightstaff_router_decision_duration_seconds",
"route" => route,
)
.record(duration.as_secs_f64());
}
pub fn record_routing_service_outcome(outcome: &'static str) {
counter!(
"brightstaff_routing_service_requests_total",
"outcome" => outcome,
)
.increment(1);
}
pub fn record_session_cache_event(outcome: &'static str) {
counter!(
"brightstaff_session_cache_events_total",
"outcome" => outcome,
)
.increment(1);
}

View file
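Two of the helpers above are pure functions whose edge cases are worth pinning down: `split_provider_model` splits on the first `/` only and reports `"unknown"` when there is none, and `status_class` maps out-of-range codes to `"other"`. Restating both here so the behavior can be checked in isolation:

```rust
// Same shape as the helpers in metrics/mod.rs above.
fn split_provider_model(full: &str) -> (&str, &str) {
    match full.split_once('/') {
        Some((p, m)) => (p, m),
        None => ("unknown", full),
    }
}

fn status_class(status: u16) -> &'static str {
    match status {
        100..=199 => "1xx",
        200..=299 => "2xx",
        300..=399 => "3xx",
        400..=499 => "4xx",
        500..=599 => "5xx",
        _ => "other",
    }
}

fn main() {
    assert_eq!(split_provider_model("openai/gpt-4o"), ("openai", "gpt-4o"));
    // Only the first '/' splits, so nested ids keep their tail intact.
    assert_eq!(split_provider_model("a/b/c"), ("a", "b/c"));
    assert_eq!(split_provider_model("gpt-4o"), ("unknown", "gpt-4o"));
    assert_eq!(status_class(204), "2xx");
    // 0 is the sentinel used for "no response produced"; it lands in "other".
    assert_eq!(status_class(0), "other");
}
```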

@ -0,0 +1,173 @@
use hermesllm::apis::openai::ChatCompletionsResponse;
use hyper::header;
use serde::Deserialize;
use thiserror::Error;
use tracing::warn;
/// Max bytes of raw upstream body we include in a log message or error text
/// when the body is not a recognizable error envelope. Keeps logs from being
/// flooded by huge HTML error pages.
const RAW_BODY_LOG_LIMIT: usize = 512;
#[derive(Debug, Error)]
pub enum HttpError {
#[error("Failed to send request: {0}")]
Request(#[from] reqwest::Error),
#[error("Failed to parse JSON response: {0}")]
Json(serde_json::Error, String),
#[error("Upstream returned {status}: {message}")]
Upstream { status: u16, message: String },
}
/// Shape of an OpenAI-style error response body, e.g.
/// `{"error": {"message": "...", "type": "...", "param": "...", "code": ...}}`.
#[derive(Debug, Deserialize)]
struct UpstreamErrorEnvelope {
error: UpstreamErrorBody,
}
#[derive(Debug, Deserialize)]
struct UpstreamErrorBody {
message: String,
#[serde(default, rename = "type")]
err_type: Option<String>,
#[serde(default)]
param: Option<String>,
}
/// Extract a human-readable error message from an upstream response body.
/// Tries to parse an OpenAI-style `{"error": {"message": ...}}` envelope; if
/// that fails, falls back to the first `RAW_BODY_LOG_LIMIT` bytes of the raw
/// body (UTF-8 safe).
fn extract_upstream_error_message(body: &str) -> String {
if let Ok(env) = serde_json::from_str::<UpstreamErrorEnvelope>(body) {
let mut msg = env.error.message;
if let Some(param) = env.error.param {
msg.push_str(&format!(" (param={param})"));
}
if let Some(err_type) = env.error.err_type {
msg.push_str(&format!(" [type={err_type}]"));
}
return msg;
}
truncate_for_log(body).to_string()
}
fn truncate_for_log(s: &str) -> &str {
if s.len() <= RAW_BODY_LOG_LIMIT {
return s;
}
let mut end = RAW_BODY_LOG_LIMIT;
while end > 0 && !s.is_char_boundary(end) {
end -= 1;
}
&s[..end]
}
/// Sends a POST request to the given URL and extracts the text content
/// from the first choice of the `ChatCompletionsResponse`.
///
/// Returns `Some((content, elapsed))` on success, `None` if the response
/// had no choices or the first choice had no content. Returns
/// `HttpError::Upstream` for any non-2xx status, carrying a message
/// extracted from the OpenAI-style error envelope (or a truncated raw body
/// if the body is not in that shape).
pub async fn post_and_extract_content(
client: &reqwest::Client,
url: &str,
headers: header::HeaderMap,
body: String,
) -> Result<Option<(String, std::time::Duration)>, HttpError> {
let start_time = std::time::Instant::now();
let res = client.post(url).headers(headers).body(body).send().await?;
let status = res.status();
let body = res.text().await?;
let elapsed = start_time.elapsed();
if !status.is_success() {
let message = extract_upstream_error_message(&body);
warn!(
status = status.as_u16(),
message = %message,
body_size = body.len(),
"upstream returned error response"
);
return Err(HttpError::Upstream {
status: status.as_u16(),
message,
});
}
let response: ChatCompletionsResponse = serde_json::from_str(&body).map_err(|err| {
warn!(
error = %err,
body = %truncate_for_log(&body),
"failed to parse json response",
);
HttpError::Json(err, format!("Failed to parse JSON: {}", truncate_for_log(&body)))
})?;
if response.choices.is_empty() {
warn!(body = %truncate_for_log(&body), "no choices in response");
return Ok(None);
}
Ok(response.choices[0]
.message
.content
.as_ref()
.map(|c| (c.clone(), elapsed)))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn extracts_message_from_openai_style_error_envelope() {
let body = r#"{"error":{"code":400,"message":"This model's maximum context length is 32768 tokens. However, you requested 0 output tokens and your prompt contains at least 32769 input tokens, for a total of at least 32769 tokens.","param":"input_tokens","type":"BadRequestError"}}"#;
let msg = extract_upstream_error_message(body);
assert!(
msg.starts_with("This model's maximum context length is 32768 tokens."),
"unexpected message: {msg}"
);
assert!(msg.contains("(param=input_tokens)"));
assert!(msg.contains("[type=BadRequestError]"));
}
#[test]
fn extracts_message_without_optional_fields() {
let body = r#"{"error":{"message":"something broke"}}"#;
let msg = extract_upstream_error_message(body);
assert_eq!(msg, "something broke");
}
#[test]
fn falls_back_to_raw_body_when_not_error_envelope() {
let body = "<html><body>502 Bad Gateway</body></html>";
let msg = extract_upstream_error_message(body);
assert_eq!(msg, body);
}
#[test]
fn truncates_non_envelope_bodies_in_logs() {
let body = "x".repeat(RAW_BODY_LOG_LIMIT * 3);
let msg = extract_upstream_error_message(&body);
assert_eq!(msg.len(), RAW_BODY_LOG_LIMIT);
}
#[test]
fn truncate_for_log_respects_utf8_boundaries() {
// 2-byte characters; picking a length that would split mid-char.
let body = "é".repeat(RAW_BODY_LOG_LIMIT);
let out = truncate_for_log(&body);
// Should be a valid &str (implicit — would panic if we returned
// a non-boundary slice) and at most RAW_BODY_LOG_LIMIT bytes.
assert!(out.len() <= RAW_BODY_LOG_LIMIT);
assert!(out.chars().all(|c| c == 'é'));
}
}

View file
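The `truncate_for_log` helper above cannot simply slice at `RAW_BODY_LOG_LIMIT` bytes: a naive `&s[..512]` panics when byte 512 falls inside a multi-byte UTF-8 character. Walking backwards to the nearest char boundary keeps the slice valid, as this standalone restatement shows:

```rust
const RAW_BODY_LOG_LIMIT: usize = 512;

// Same boundary-walking truncation as `truncate_for_log` in the file above:
// back off until the cut lands on a char boundary so the slice stays valid UTF-8.
fn truncate_for_log(s: &str) -> &str {
    if s.len() <= RAW_BODY_LOG_LIMIT {
        return s;
    }
    let mut end = RAW_BODY_LOG_LIMIT;
    while end > 0 && !s.is_char_boundary(end) {
        end -= 1;
    }
    &s[..end]
}

fn main() {
    let ascii = "x".repeat(RAW_BODY_LOG_LIMIT * 2);
    assert_eq!(truncate_for_log(&ascii).len(), RAW_BODY_LOG_LIMIT);
    // 'é' is 2 bytes; a naive byte slice could split it in half and panic.
    let twobyte = "é".repeat(RAW_BODY_LOG_LIMIT);
    let out = truncate_for_log(&twobyte);
    assert!(out.len() <= RAW_BODY_LOG_LIMIT);
    assert!(out.chars().all(|c| c == 'é'));
}
```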

@ -1,187 +0,0 @@
use std::{collections::HashMap, sync::Arc};
use common::{
configuration::{LlmProvider, ModelUsagePreference, RoutingPreference},
consts::{ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER},
};
use hermesllm::apis::openai::{ChatCompletionsResponse, Message};
use hyper::header;
use thiserror::Error;
use tracing::{debug, info, warn};
use crate::router::router_model_v1::{self};
use super::router_model::RouterModel;
pub struct RouterService {
router_url: String,
client: reqwest::Client,
router_model: Arc<dyn RouterModel>,
#[allow(dead_code)]
routing_provider_name: String,
llm_usage_defined: bool,
}
#[derive(Debug, Error)]
pub enum RoutingError {
#[error("Failed to send request: {0}")]
RequestError(#[from] reqwest::Error),
#[error("Failed to parse JSON: {0}, JSON: {1}")]
JsonError(serde_json::Error, String),
#[error("Router model error: {0}")]
RouterModelError(#[from] super::router_model::RoutingModelError),
}
pub type Result<T> = std::result::Result<T, RoutingError>;
impl RouterService {
pub fn new(
providers: Vec<LlmProvider>,
router_url: String,
routing_model_name: String,
routing_provider_name: String,
) -> Self {
let providers_with_usage = providers
.iter()
.filter(|provider| provider.routing_preferences.is_some())
.cloned()
.collect::<Vec<LlmProvider>>();
let llm_routes: HashMap<String, Vec<RoutingPreference>> = providers_with_usage
.iter()
.filter_map(|provider| {
provider
.routing_preferences
.as_ref()
.map(|prefs| (provider.name.clone(), prefs.clone()))
})
.collect();
let router_model = Arc::new(router_model_v1::RouterModelV1::new(
llm_routes,
routing_model_name,
router_model_v1::MAX_TOKEN_LEN,
));
RouterService {
router_url,
client: reqwest::Client::new(),
router_model,
routing_provider_name,
llm_usage_defined: !providers_with_usage.is_empty(),
}
}
pub async fn determine_route(
&self,
messages: &[Message],
traceparent: &str,
usage_preferences: Option<Vec<ModelUsagePreference>>,
request_id: &str,
) -> Result<Option<(String, String)>> {
if messages.is_empty() {
return Ok(None);
}
if (usage_preferences.is_none() || usage_preferences.as_ref().unwrap().len() < 2)
&& !self.llm_usage_defined
{
return Ok(None);
}
let router_request = self
.router_model
.generate_request(messages, &usage_preferences);
debug!(
model = %self.router_model.get_model_name(),
endpoint = %self.router_url,
"sending request to arch-router"
);
debug!(
body = %serde_json::to_string(&router_request).unwrap(),
"arch router request"
);
let mut llm_route_request_headers = header::HeaderMap::new();
llm_route_request_headers.insert(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
);
llm_route_request_headers.insert(
header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER),
header::HeaderValue::from_str(&self.routing_provider_name).unwrap(),
);
llm_route_request_headers.insert(
header::HeaderName::from_static(TRACE_PARENT_HEADER),
header::HeaderValue::from_str(traceparent).unwrap(),
);
llm_route_request_headers.insert(
header::HeaderName::from_static(REQUEST_ID_HEADER),
header::HeaderValue::from_str(request_id).unwrap(),
);
llm_route_request_headers.insert(
header::HeaderName::from_static("model"),
header::HeaderValue::from_static("arch-router"),
);
let start_time = std::time::Instant::now();
let res = self
.client
.post(&self.router_url)
.headers(llm_route_request_headers)
.body(serde_json::to_string(&router_request).unwrap())
.send()
.await?;
let body = res.text().await?;
let router_response_time = start_time.elapsed();
let chat_completion_response: ChatCompletionsResponse = match serde_json::from_str(&body) {
Ok(response) => response,
Err(err) => {
warn!(
error = %err,
body = %serde_json::to_string(&body).unwrap(),
"failed to parse json response"
);
return Err(RoutingError::JsonError(
err,
format!("Failed to parse JSON: {}", body),
));
}
};
if chat_completion_response.choices.is_empty() {
warn!(body = %body, "no choices in router response");
return Ok(None);
}
if let Some(content) = &chat_completion_response.choices[0].message.content {
let parsed_response = self
.router_model
.parse_response(content, &usage_preferences)?;
info!(
content = %content.replace("\n", "\\n"),
selected_model = ?parsed_response,
response_time_ms = router_response_time.as_millis(),
"arch-router determined route"
);
if let Some(ref parsed_response) = parsed_response {
return Ok(Some(parsed_response.clone()));
}
Ok(None)
} else {
Ok(None)
}
}
}

View file

@ -1,6 +1,7 @@
pub mod llm_router;
pub(crate) mod http;
pub mod model_metrics;
pub mod orchestrator;
pub mod orchestrator_model;
pub mod orchestrator_model_v1;
pub mod plano_orchestrator;
pub mod router_model;
pub mod router_model_v1;
#[cfg(test)]
mod stress_tests;

View file

@ -0,0 +1,388 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use common::configuration::{
CostProvider, LatencyProvider, MetricsSource, SelectionPolicy, SelectionPreference,
};
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
const DO_PRICING_URL: &str = "https://api.digitalocean.com/v2/gen-ai/models/catalog";
pub struct ModelMetricsService {
cost: Arc<RwLock<HashMap<String, f64>>>,
latency: Arc<RwLock<HashMap<String, f64>>>,
}
impl ModelMetricsService {
pub async fn new(sources: &[MetricsSource], client: reqwest::Client) -> Self {
let cost_data = Arc::new(RwLock::new(HashMap::new()));
let latency_data = Arc::new(RwLock::new(HashMap::new()));
for source in sources {
match source {
MetricsSource::Cost(cfg) => match cfg.provider {
CostProvider::Digitalocean => {
let aliases = cfg.model_aliases.clone().unwrap_or_default();
let data = fetch_do_pricing(&client, &aliases).await;
info!(models = data.len(), "fetched digitalocean pricing");
*cost_data.write().await = data;
if let Some(interval_secs) = cfg.refresh_interval {
let cost_clone = Arc::clone(&cost_data);
let client_clone = client.clone();
let interval = Duration::from_secs(interval_secs);
tokio::spawn(async move {
loop {
tokio::time::sleep(interval).await;
let data = fetch_do_pricing(&client_clone, &aliases).await;
info!(models = data.len(), "refreshed digitalocean pricing");
*cost_clone.write().await = data;
}
});
}
}
},
MetricsSource::Latency(cfg) => match cfg.provider {
LatencyProvider::Prometheus => {
let data = fetch_prometheus_metrics(&cfg.url, &cfg.query, &client).await;
info!(models = data.len(), url = %cfg.url, "fetched latency metrics");
*latency_data.write().await = data;
if let Some(interval_secs) = cfg.refresh_interval {
let latency_clone = Arc::clone(&latency_data);
let client_clone = client.clone();
let url = cfg.url.clone();
let query = cfg.query.clone();
let interval = Duration::from_secs(interval_secs);
tokio::spawn(async move {
loop {
tokio::time::sleep(interval).await;
let data =
fetch_prometheus_metrics(&url, &query, &client_clone).await;
info!(models = data.len(), url = %url, "refreshed latency metrics");
*latency_clone.write().await = data;
}
});
}
}
},
}
}
ModelMetricsService {
cost: cost_data,
latency: latency_data,
}
}
/// Rank `models` by `policy`, returning them in preference order.
/// Models with no metric data are appended at the end in their original order.
pub async fn rank_models(&self, models: &[String], policy: &SelectionPolicy) -> Vec<String> {
let cost_data = self.cost.read().await;
let latency_data = self.latency.read().await;
debug!(
input_models = ?models,
cost_data = ?cost_data.iter().collect::<Vec<_>>(),
latency_data = ?latency_data.iter().collect::<Vec<_>>(),
prefer = ?policy.prefer,
"rank_models called"
);
match policy.prefer {
SelectionPreference::Cheapest => {
for m in models {
if !cost_data.contains_key(m.as_str()) {
warn!(model = %m, "no cost data for model — ranking last (prefer: cheapest)");
}
}
rank_by_ascending_metric(models, &cost_data)
}
SelectionPreference::Fastest => {
for m in models {
if !latency_data.contains_key(m.as_str()) {
warn!(model = %m, "no latency data for model — ranking last (prefer: fastest)");
}
}
rank_by_ascending_metric(models, &latency_data)
}
SelectionPreference::None => models.to_vec(),
}
}
/// Returns a snapshot of the current cost data. Used at startup to warn about unmatched models.
pub async fn cost_snapshot(&self) -> HashMap<String, f64> {
self.cost.read().await.clone()
}
/// Returns a snapshot of the current latency data. Used at startup to warn about unmatched models.
pub async fn latency_snapshot(&self) -> HashMap<String, f64> {
self.latency.read().await.clone()
}
}
fn rank_by_ascending_metric(models: &[String], data: &HashMap<String, f64>) -> Vec<String> {
let mut with_data: Vec<(&String, f64)> = models
.iter()
.filter_map(|m| {
let v = *data.get(m.as_str())?;
if v.is_nan() {
None
} else {
Some((m, v))
}
})
.collect();
with_data.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
let without_data: Vec<&String> = models
.iter()
.filter(|m| data.get(m.as_str()).is_none_or(|v| v.is_nan()))
.collect();
with_data
.iter()
.map(|(m, _)| (*m).clone())
.chain(without_data.iter().map(|m| (*m).clone()))
.collect()
}
#[derive(serde::Deserialize)]
struct DoModelList {
data: Vec<DoModel>,
}
#[derive(serde::Deserialize)]
struct DoModel {
model_id: String,
pricing: Option<DoPricing>,
}
#[derive(serde::Deserialize)]
struct DoPricing {
input_price_per_million: Option<f64>,
output_price_per_million: Option<f64>,
}
async fn fetch_do_pricing(
client: &reqwest::Client,
aliases: &HashMap<String, String>,
) -> HashMap<String, f64> {
match client.get(DO_PRICING_URL).send().await {
Ok(resp) => match resp.json::<DoModelList>().await {
Ok(list) => list
.data
.into_iter()
.filter_map(|m| {
let pricing = m.pricing?;
let raw_key = m.model_id.clone();
let key = aliases.get(&raw_key).cloned().unwrap_or(raw_key);
let cost = pricing.input_price_per_million.unwrap_or(0.0)
+ pricing.output_price_per_million.unwrap_or(0.0);
Some((key, cost))
})
.collect(),
Err(err) => {
warn!(error = %err, url = DO_PRICING_URL, "failed to parse digitalocean pricing response");
HashMap::new()
}
},
Err(err) => {
warn!(error = %err, url = DO_PRICING_URL, "failed to fetch digitalocean pricing");
HashMap::new()
}
}
}
#[derive(serde::Deserialize)]
struct PrometheusResponse {
data: PrometheusData,
}
#[derive(serde::Deserialize)]
struct PrometheusData {
result: Vec<PrometheusResult>,
}
#[derive(serde::Deserialize)]
struct PrometheusResult {
metric: HashMap<String, String>,
value: (f64, String), // (timestamp, value_str)
}
async fn fetch_prometheus_metrics(
url: &str,
query: &str,
client: &reqwest::Client,
) -> HashMap<String, f64> {
let query_url = format!("{}/api/v1/query", url.trim_end_matches('/'));
match client
.get(&query_url)
.query(&[("query", query)])
.send()
.await
{
Ok(resp) => match resp.json::<PrometheusResponse>().await {
Ok(prom) => prom
.data
.result
.into_iter()
.filter_map(|r| {
let model_name = r.metric.get("model_name")?.clone();
let value: f64 = r.value.1.parse().ok()?;
Some((model_name, value))
})
.collect(),
Err(err) => {
warn!(error = %err, url = %query_url, "failed to parse prometheus response");
HashMap::new()
}
},
Err(err) => {
warn!(error = %err, url = %query_url, "failed to fetch prometheus metrics");
HashMap::new()
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use common::configuration::SelectionPreference;
fn make_policy(prefer: SelectionPreference) -> SelectionPolicy {
SelectionPolicy { prefer }
}
#[test]
fn test_rank_by_ascending_metric_picks_lowest_first() {
let models = vec!["a".to_string(), "b".to_string(), "c".to_string()];
let mut data = HashMap::new();
data.insert("a".to_string(), 0.01);
data.insert("b".to_string(), 0.005);
data.insert("c".to_string(), 0.02);
assert_eq!(
rank_by_ascending_metric(&models, &data),
vec!["b", "a", "c"]
);
}
#[test]
fn test_rank_by_ascending_metric_no_data_preserves_order() {
let models = vec!["x".to_string(), "y".to_string()];
let data = HashMap::new();
assert_eq!(rank_by_ascending_metric(&models, &data), vec!["x", "y"]);
}
#[test]
fn test_rank_by_ascending_metric_partial_data() {
let models = vec!["a".to_string(), "b".to_string()];
let mut data = HashMap::new();
data.insert("b".to_string(), 100.0);
assert_eq!(rank_by_ascending_metric(&models, &data), vec!["b", "a"]);
}
#[tokio::test]
async fn test_rank_models_cheapest() {
let service = ModelMetricsService {
cost: Arc::new(RwLock::new({
let mut m = HashMap::new();
m.insert("gpt-4o".to_string(), 0.005);
m.insert("gpt-4o-mini".to_string(), 0.0001);
m
})),
latency: Arc::new(RwLock::new(HashMap::new())),
};
let models = vec!["gpt-4o".to_string(), "gpt-4o-mini".to_string()];
let result = service
.rank_models(&models, &make_policy(SelectionPreference::Cheapest))
.await;
assert_eq!(result, vec!["gpt-4o-mini", "gpt-4o"]);
}
#[tokio::test]
async fn test_rank_models_fastest() {
let service = ModelMetricsService {
cost: Arc::new(RwLock::new(HashMap::new())),
latency: Arc::new(RwLock::new({
let mut m = HashMap::new();
m.insert("gpt-4o".to_string(), 200.0);
m.insert("claude-sonnet".to_string(), 120.0);
m
})),
};
let models = vec!["gpt-4o".to_string(), "claude-sonnet".to_string()];
let result = service
.rank_models(&models, &make_policy(SelectionPreference::Fastest))
.await;
assert_eq!(result, vec!["claude-sonnet", "gpt-4o"]);
}
#[tokio::test]
async fn test_rank_models_fallback_no_metrics() {
let service = ModelMetricsService {
cost: Arc::new(RwLock::new(HashMap::new())),
latency: Arc::new(RwLock::new(HashMap::new())),
};
let models = vec!["model-a".to_string(), "model-b".to_string()];
let result = service
.rank_models(&models, &make_policy(SelectionPreference::Cheapest))
.await;
assert_eq!(result, vec!["model-a", "model-b"]);
}
#[tokio::test]
async fn test_rank_models_partial_data_appended_last() {
let service = ModelMetricsService {
cost: Arc::new(RwLock::new({
let mut m = HashMap::new();
m.insert("gpt-4o".to_string(), 0.005);
m
})),
latency: Arc::new(RwLock::new(HashMap::new())),
};
let models = vec!["gpt-4o-mini".to_string(), "gpt-4o".to_string()];
let result = service
.rank_models(&models, &make_policy(SelectionPreference::Cheapest))
.await;
assert_eq!(result, vec!["gpt-4o", "gpt-4o-mini"]);
}
#[tokio::test]
async fn test_rank_models_none_preserves_order() {
let service = ModelMetricsService {
cost: Arc::new(RwLock::new({
let mut m = HashMap::new();
m.insert("gpt-4o-mini".to_string(), 0.0001);
m.insert("gpt-4o".to_string(), 0.005);
m
})),
latency: Arc::new(RwLock::new(HashMap::new())),
};
let models = vec!["gpt-4o".to_string(), "gpt-4o-mini".to_string()];
let result = service
.rank_models(&models, &make_policy(SelectionPreference::None))
.await;
// none → original order, despite gpt-4o-mini being cheaper
assert_eq!(result, vec!["gpt-4o", "gpt-4o-mini"]);
}
#[test]
fn test_rank_by_ascending_metric_nan_treated_as_missing() {
let models = vec![
"a".to_string(),
"b".to_string(),
"c".to_string(),
"d".to_string(),
];
let mut data = HashMap::new();
data.insert("a".to_string(), f64::NAN);
data.insert("b".to_string(), 0.5);
data.insert("c".to_string(), 0.1);
// "d" has no entry at all
let result = rank_by_ascending_metric(&models, &data);
// c (0.1) < b (0.5), then NaN "a" and missing "d" appended in original order
assert_eq!(result, vec!["c", "b", "a", "d"]);
}
}

View file
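The ranking rule in `model_metrics.rs` above has three tiers in one function: models with finite metric values sort ascending, while NaN values and missing entries are both appended afterward in their original order. A self-contained restatement (using `map_or` instead of the file's `Option::is_none_or`, which requires a recent Rust toolchain):

```rust
use std::collections::HashMap;

// Models with finite metric data sort ascending; NaN and missing entries
// are appended at the end in their original order.
fn rank_by_ascending_metric(models: &[String], data: &HashMap<String, f64>) -> Vec<String> {
    let mut with_data: Vec<(&String, f64)> = models
        .iter()
        .filter_map(|m| {
            let v = *data.get(m.as_str())?;
            if v.is_nan() { None } else { Some((m, v)) }
        })
        .collect();
    with_data.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
    let without_data: Vec<&String> = models
        .iter()
        .filter(|m| data.get(m.as_str()).map_or(true, |v| v.is_nan()))
        .collect();
    with_data
        .iter()
        .map(|(m, _)| (*m).clone())
        .chain(without_data.iter().map(|m| (*m).clone()))
        .collect()
}

fn main() {
    let models: Vec<String> = ["a", "b", "c"].iter().map(|s| s.to_string()).collect();
    let mut data = HashMap::new();
    data.insert("a".to_string(), f64::NAN); // NaN: treated as missing
    data.insert("c".to_string(), 0.1);
    // "b" has no entry at all; "a" (NaN) and "b" keep original relative order.
    assert_eq!(rank_by_ascending_metric(&models, &data), vec!["c", "a", "b"]);
}
```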

@ -0,0 +1,433 @@
use std::{borrow::Cow, collections::HashMap, sync::Arc, time::Duration};
use common::{
configuration::{AgentUsagePreference, OrchestrationPreference, TopLevelRoutingPreference},
consts::{ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER},
};
use hermesllm::apis::openai::Message;
use hyper::header;
use opentelemetry::global;
use opentelemetry_http::HeaderInjector;
use thiserror::Error;
use tracing::{debug, info};
use super::http::{self, post_and_extract_content};
use super::model_metrics::ModelMetricsService;
use super::orchestrator_model::OrchestratorModel;
use crate::metrics as bs_metrics;
use crate::metrics::labels as metric_labels;
use crate::router::orchestrator_model_v1;
use crate::session_cache::SessionCache;
pub use crate::session_cache::CachedRoute;
const DEFAULT_SESSION_TTL_SECONDS: u64 = 600;
pub struct OrchestratorService {
orchestrator_url: String,
client: reqwest::Client,
orchestrator_model: Arc<dyn OrchestratorModel>,
orchestrator_provider_name: String,
top_level_preferences: HashMap<String, TopLevelRoutingPreference>,
metrics_service: Option<Arc<ModelMetricsService>>,
session_cache: Option<Arc<dyn SessionCache>>,
session_ttl: Duration,
tenant_header: Option<String>,
}
#[derive(Debug, Error)]
pub enum OrchestrationError {
#[error(transparent)]
Http(#[from] http::HttpError),
#[error("Orchestrator model error: {0}")]
OrchestratorModelError(#[from] super::orchestrator_model::OrchestratorModelError),
}
pub type Result<T> = std::result::Result<T, OrchestrationError>;
impl OrchestratorService {
pub fn new(
orchestrator_url: String,
orchestration_model_name: String,
orchestrator_provider_name: String,
max_token_length: usize,
) -> Self {
let orchestrator_model = Arc::new(orchestrator_model_v1::OrchestratorModelV1::new(
HashMap::new(),
orchestration_model_name,
max_token_length,
));
OrchestratorService {
orchestrator_url,
client: reqwest::Client::new(),
orchestrator_model,
orchestrator_provider_name,
top_level_preferences: HashMap::new(),
metrics_service: None,
session_cache: None,
session_ttl: Duration::from_secs(DEFAULT_SESSION_TTL_SECONDS),
tenant_header: None,
}
}
#[allow(clippy::too_many_arguments)]
pub fn with_routing(
orchestrator_url: String,
orchestration_model_name: String,
orchestrator_provider_name: String,
top_level_prefs: Option<Vec<TopLevelRoutingPreference>>,
metrics_service: Option<Arc<ModelMetricsService>>,
session_ttl_seconds: Option<u64>,
session_cache: Arc<dyn SessionCache>,
tenant_header: Option<String>,
max_token_length: usize,
) -> Self {
let top_level_preferences: HashMap<String, TopLevelRoutingPreference> = top_level_prefs
.map_or_else(HashMap::new, |prefs| {
prefs.into_iter().map(|p| (p.name.clone(), p)).collect()
});
let orchestrator_model = Arc::new(orchestrator_model_v1::OrchestratorModelV1::new(
HashMap::new(),
orchestration_model_name,
max_token_length,
));
let session_ttl =
Duration::from_secs(session_ttl_seconds.unwrap_or(DEFAULT_SESSION_TTL_SECONDS));
OrchestratorService {
orchestrator_url,
client: reqwest::Client::new(),
orchestrator_model,
orchestrator_provider_name,
top_level_preferences,
metrics_service,
session_cache: Some(session_cache),
session_ttl,
tenant_header,
}
}
// ---- Session cache methods ----
#[must_use]
pub fn tenant_header(&self) -> Option<&str> {
self.tenant_header.as_deref()
}
fn session_key<'a>(tenant_id: Option<&str>, session_id: &'a str) -> Cow<'a, str> {
match tenant_id {
Some(t) => Cow::Owned(format!("{t}:{session_id}")),
None => Cow::Borrowed(session_id),
}
}
pub async fn get_cached_route(
&self,
session_id: &str,
tenant_id: Option<&str>,
) -> Option<CachedRoute> {
let cache = self.session_cache.as_ref()?;
let result = cache.get(&Self::session_key(tenant_id, session_id)).await;
bs_metrics::record_session_cache_event(if result.is_some() {
metric_labels::SESSION_CACHE_HIT
} else {
metric_labels::SESSION_CACHE_MISS
});
result
}
pub async fn cache_route(
&self,
session_id: String,
tenant_id: Option<&str>,
model_name: String,
route_name: Option<String>,
) {
if let Some(ref cache) = self.session_cache {
cache
.put(
&Self::session_key(tenant_id, &session_id),
CachedRoute {
model_name,
route_name,
},
self.session_ttl,
)
.await;
bs_metrics::record_session_cache_event(metric_labels::SESSION_CACHE_STORE);
}
}
// ---- LLM routing ----
pub async fn determine_route(
&self,
messages: &[Message],
inline_routing_preferences: Option<Vec<TopLevelRoutingPreference>>,
request_id: &str,
) -> Result<Option<(String, Vec<String>)>> {
if messages.is_empty() {
return Ok(None);
}
let inline_top_map: Option<HashMap<String, TopLevelRoutingPreference>> =
inline_routing_preferences
.map(|prefs| prefs.into_iter().map(|p| (p.name.clone(), p)).collect());
if inline_top_map.is_none() && self.top_level_preferences.is_empty() {
return Ok(None);
}
let effective_source = inline_top_map
.as_ref()
.unwrap_or(&self.top_level_preferences);
let effective_prefs: Vec<AgentUsagePreference> = effective_source
.values()
.map(|p| AgentUsagePreference {
model: p.models.first().cloned().unwrap_or_default(),
orchestration_preferences: vec![OrchestrationPreference {
name: p.name.clone(),
description: p.description.clone(),
}],
})
.collect();
let orchestration_result = self
.determine_orchestration(
messages,
Some(effective_prefs),
Some(request_id.to_string()),
)
.await?;
let result = if let Some(ref routes) = orchestration_result {
if routes.len() > 1 {
let all_routes: Vec<&str> = routes.iter().map(|(name, _)| name.as_str()).collect();
info!(
routes = ?all_routes,
using = %all_routes.first().unwrap_or(&"none"),
"plano-orchestrator detected multiple intents, using first"
);
}
if let Some((route_name, _)) = routes.first() {
let top_pref = inline_top_map
.as_ref()
.and_then(|m| m.get(route_name))
.or_else(|| self.top_level_preferences.get(route_name));
if let Some(pref) = top_pref {
let ranked = match &self.metrics_service {
Some(svc) => svc.rank_models(&pref.models, &pref.selection_policy).await,
None => pref.models.clone(),
};
Some((route_name.clone(), ranked))
} else {
None
}
} else {
None
}
} else {
None
};
info!(
selected_model = ?result,
"plano-orchestrator determined route"
);
Ok(result)
}
// ---- Agent orchestration (existing) ----
pub async fn determine_orchestration(
&self,
messages: &[Message],
usage_preferences: Option<Vec<AgentUsagePreference>>,
request_id: Option<String>,
) -> Result<Option<Vec<(String, String)>>> {
if messages.is_empty() {
return Ok(None);
}
if usage_preferences
.as_ref()
.is_none_or(|prefs| prefs.is_empty())
{
return Ok(None);
}
let orchestrator_request = self
.orchestrator_model
.generate_request(messages, &usage_preferences);
debug!(
model = %self.orchestrator_model.get_model_name(),
endpoint = %self.orchestrator_url,
"sending request to plano-orchestrator"
);
let body = serde_json::to_string(&orchestrator_request)
.map_err(super::orchestrator_model::OrchestratorModelError::from)?;
debug!(body = %body, "plano-orchestrator request");
let mut headers = header::HeaderMap::new();
headers.insert(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
);
headers.insert(
header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER),
header::HeaderValue::from_str(&self.orchestrator_provider_name)
.unwrap_or_else(|_| header::HeaderValue::from_static("plano-orchestrator")),
);
global::get_text_map_propagator(|propagator| {
let cx =
tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
propagator.inject_context(&cx, &mut HeaderInjector(&mut headers));
});
if let Some(ref request_id) = request_id {
if let Ok(val) = header::HeaderValue::from_str(request_id) {
headers.insert(header::HeaderName::from_static(REQUEST_ID_HEADER), val);
}
}
headers.insert(
header::HeaderName::from_static("model"),
header::HeaderValue::from_static("plano-orchestrator"),
);
let Some((content, elapsed)) =
post_and_extract_content(&self.client, &self.orchestrator_url, headers, body).await?
else {
return Ok(None);
};
let parsed = self
.orchestrator_model
.parse_response(&content, &usage_preferences)?;
info!(
content = %content.replace("\n", "\\n"),
selected_routes = ?parsed,
response_time_ms = elapsed.as_millis(),
"plano-orchestrator determined routes"
);
Ok(parsed)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::session_cache::memory::MemorySessionCache;
fn make_orchestrator_service(ttl_seconds: u64, max_entries: usize) -> OrchestratorService {
let session_cache = Arc::new(MemorySessionCache::new(max_entries));
OrchestratorService::with_routing(
"http://localhost:12001/v1/chat/completions".to_string(),
"Plano-Orchestrator".to_string(),
"plano-orchestrator".to_string(),
None,
None,
Some(ttl_seconds),
session_cache,
None,
orchestrator_model_v1::MAX_TOKEN_LEN,
)
}
#[tokio::test]
async fn test_cache_miss_returns_none() {
let svc = make_orchestrator_service(600, 100);
assert!(svc
.get_cached_route("unknown-session", None)
.await
.is_none());
}
#[tokio::test]
async fn test_cache_hit_returns_cached_route() {
let svc = make_orchestrator_service(600, 100);
svc.cache_route(
"s1".to_string(),
None,
"gpt-4o".to_string(),
Some("code".to_string()),
)
.await;
let cached = svc.get_cached_route("s1", None).await.unwrap();
assert_eq!(cached.model_name, "gpt-4o");
assert_eq!(cached.route_name, Some("code".to_string()));
}
#[tokio::test]
async fn test_cache_expired_entry_returns_none() {
let svc = make_orchestrator_service(0, 100);
svc.cache_route("s1".to_string(), None, "gpt-4o".to_string(), None)
.await;
assert!(svc.get_cached_route("s1", None).await.is_none());
}
#[tokio::test]
async fn test_expired_entries_not_returned() {
let svc = make_orchestrator_service(0, 100);
svc.cache_route("s1".to_string(), None, "gpt-4o".to_string(), None)
.await;
svc.cache_route("s2".to_string(), None, "claude".to_string(), None)
.await;
assert!(svc.get_cached_route("s1", None).await.is_none());
assert!(svc.get_cached_route("s2", None).await.is_none());
}
#[tokio::test]
async fn test_cache_evicts_oldest_when_full() {
let svc = make_orchestrator_service(600, 2);
svc.cache_route("s1".to_string(), None, "model-a".to_string(), None)
.await;
tokio::time::sleep(Duration::from_millis(10)).await;
svc.cache_route("s2".to_string(), None, "model-b".to_string(), None)
.await;
svc.cache_route("s3".to_string(), None, "model-c".to_string(), None)
.await;
assert!(svc.get_cached_route("s1", None).await.is_none());
assert!(svc.get_cached_route("s2", None).await.is_some());
assert!(svc.get_cached_route("s3", None).await.is_some());
}
#[tokio::test]
async fn test_cache_update_existing_session_does_not_evict() {
let svc = make_orchestrator_service(600, 2);
svc.cache_route("s1".to_string(), None, "model-a".to_string(), None)
.await;
svc.cache_route("s2".to_string(), None, "model-b".to_string(), None)
.await;
svc.cache_route(
"s1".to_string(),
None,
"model-a-updated".to_string(),
Some("route".to_string()),
)
.await;
let s1 = svc.get_cached_route("s1", None).await.unwrap();
assert_eq!(s1.model_name, "model-a-updated");
assert!(svc.get_cached_route("s2", None).await.is_some());
}
}
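The tenant-scoped cache key rule above (`session_key`) is worth seeing in isolation: with a tenant the key is `"tenant:session"`, and without one the session id is borrowed as-is, avoiding an allocation on the common single-tenant path. This sketch copies just that helper out of the file above.

```rust
use std::borrow::Cow;

// Copy of the key-composition helper: prefix with the tenant when present,
// otherwise borrow the session id unchanged (no allocation).
fn session_key<'a>(tenant_id: Option<&str>, session_id: &'a str) -> Cow<'a, str> {
    match tenant_id {
        Some(t) => Cow::Owned(format!("{t}:{session_id}")),
        None => Cow::Borrowed(session_id),
    }
}

fn main() {
    assert_eq!(session_key(Some("acme"), "s1"), "acme:s1");
    assert_eq!(session_key(None, "s1"), "s1");
    // The no-tenant path really does borrow rather than allocate.
    assert!(matches!(session_key(None, "s1"), Cow::Borrowed(_)));
    println!("ok");
}
```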


@ -11,8 +11,7 @@ pub enum OrchestratorModelError {
pub type Result<T> = std::result::Result<T, OrchestratorModelError>;
/// OrchestratorModel trait for handling orchestration requests.
/// Unlike RouterModel which returns a single route, OrchestratorModel
/// can return multiple routes as the model output format is:
/// Returns multiple routes as the model output format is:
/// {"route": ["route_name_1", "route_name_2", ...]}
pub trait OrchestratorModel: Send + Sync {
fn generate_request(


@ -8,7 +8,19 @@ use tracing::{debug, warn};
use super::orchestrator_model::{OrchestratorModel, OrchestratorModelError};
pub const MAX_TOKEN_LEN: usize = 2048; // Default max token length for the orchestration model
pub const MAX_TOKEN_LEN: usize = 8192; // Default max token length for the orchestration model
/// Hard cap on the number of recent messages considered when building the
/// routing prompt. Bounds prompt growth for long-running conversations and
/// acts as an outer guardrail before the token-budget loop runs. The most
/// recent `MAX_ROUTING_TURNS` filtered messages are kept; older turns are
/// dropped entirely.
pub const MAX_ROUTING_TURNS: usize = 16;
/// Unicode ellipsis used to mark where content was trimmed out of a long
/// message. Helps signal to the downstream router model that the message was
/// truncated.
const TRIM_MARKER: &str = "…";
/// Custom JSON formatter that produces spaced JSON (a space after each colon and comma), matching Python's default json.dumps output
struct SpacedJsonFormatter;
@ -176,47 +188,82 @@ impl OrchestratorModel for OrchestratorModelV1 {
messages: &[Message],
usage_preferences_from_request: &Option<Vec<AgentUsagePreference>>,
) -> ChatCompletionsRequest {
// remove system prompt, tool calls, tool call response and messages without content
// if content is empty its likely a tool call
// when role == tool its tool call response
let messages_vec = messages
// Remove system/developer/tool messages and messages without extractable
// text (tool calls have no text content we can classify against).
let filtered: Vec<&Message> = messages
.iter()
.filter(|m| {
m.role != Role::System
&& m.role != Role::Developer
&& m.role != Role::Tool
&& !m.content.extract_text().is_empty()
})
.collect::<Vec<&Message>>();
.collect();
// Following code is to ensure that the conversation does not exceed max token length
// Note: we use a simple heuristic to estimate token count based on character length to optimize for performance
// Outer guardrail: only consider the last `MAX_ROUTING_TURNS` filtered
// messages when building the routing prompt. Keeps prompt growth
// predictable for long conversations regardless of per-message size.
let start = filtered.len().saturating_sub(MAX_ROUTING_TURNS);
let messages_vec: &[&Message] = &filtered[start..];
// Ensure the conversation does not exceed the configured token budget.
// We use `len() / TOKEN_LENGTH_DIVISOR` as a cheap token estimate to
// avoid running a real tokenizer on the hot path.
let mut token_count = ARCH_ORCHESTRATOR_V1_SYSTEM_PROMPT.len() / TOKEN_LENGTH_DIVISOR;
let mut selected_messages_list_reversed: Vec<&Message> = vec![];
let mut selected_messages_list_reversed: Vec<Message> = vec![];
for (selected_messsage_count, message) in messages_vec.iter().rev().enumerate() {
let message_token_count = message.content.extract_text().len() / TOKEN_LENGTH_DIVISOR;
token_count += message_token_count;
if token_count > self.max_token_length {
let message_text = message.content.extract_text();
let message_token_count = message_text.len() / TOKEN_LENGTH_DIVISOR;
if token_count + message_token_count > self.max_token_length {
let remaining_tokens = self.max_token_length.saturating_sub(token_count);
debug!(
token_count = token_count,
attempted_total_tokens = token_count + message_token_count,
max_tokens = self.max_token_length,
remaining_tokens,
selected = selected_messsage_count,
total = messages_vec.len(),
"token count exceeds max, truncating conversation"
);
if message.role == Role::User {
// If message that exceeds max token length is from user, we need to keep it
selected_messages_list_reversed.push(message);
// If the overflow message is from the user we need to keep
// some of it so the orchestrator still sees the latest user
// intent. Use a middle-trim (head + ellipsis + tail): users
// often frame the task at the start AND put the actual ask
// at the end of a long pasted block, so preserving both is
// better than a head-only cut. The ellipsis also signals to
// the router model that content was dropped.
if message.role == Role::User && remaining_tokens > 0 {
let max_bytes = remaining_tokens.saturating_mul(TOKEN_LENGTH_DIVISOR);
let truncated = trim_middle_utf8(&message_text, max_bytes);
selected_messages_list_reversed.push(Message {
role: Role::User,
content: Some(MessageContent::Text(truncated)),
name: None,
tool_calls: None,
tool_call_id: None,
});
}
break;
}
// If we are here, it means that the message is within the max token length
selected_messages_list_reversed.push(message);
token_count += message_token_count;
selected_messages_list_reversed.push(Message {
role: message.role.clone(),
content: Some(MessageContent::Text(message_text)),
name: None,
tool_calls: None,
tool_call_id: None,
});
}
if selected_messages_list_reversed.is_empty() {
debug!("no messages selected, using last message");
if let Some(last_message) = messages_vec.last() {
selected_messages_list_reversed.push(last_message);
selected_messages_list_reversed.push(Message {
role: last_message.role.clone(),
content: Some(MessageContent::Text(last_message.content.extract_text())),
name: None,
tool_calls: None,
tool_call_id: None,
});
}
}
@ -236,22 +283,8 @@ impl OrchestratorModel for OrchestratorModelV1 {
}
// Reverse the selected messages to maintain the conversation order
let selected_conversation_list = selected_messages_list_reversed
.iter()
.rev()
.map(|message| Message {
role: message.role.clone(),
content: Some(MessageContent::Text(
message
.content
.as_ref()
.map_or(String::new(), |c| c.to_string()),
)),
name: None,
tool_calls: None,
tool_call_id: None,
})
.collect::<Vec<Message>>();
let selected_conversation_list: Vec<Message> =
selected_messages_list_reversed.into_iter().rev().collect();
// Generate the orchestrator request message based on the usage preferences.
// If preferences are passed in request then we use them;
@ -404,6 +437,45 @@ fn fix_json_response(body: &str) -> String {
body.replace("'", "\"").replace("\\n", "")
}
/// Truncate `s` so the result is at most `max_bytes` bytes long, keeping
/// roughly 60% from the start and 40% from the end, with a Unicode ellipsis
/// separating the two. All splits respect UTF-8 character boundaries. When
/// `max_bytes` is too small to fit the marker at all, falls back to a
/// head-only truncation.
fn trim_middle_utf8(s: &str, max_bytes: usize) -> String {
if s.len() <= max_bytes {
return s.to_string();
}
if max_bytes <= TRIM_MARKER.len() {
// Not enough room even for the marker — just keep the start.
let mut end = max_bytes;
while end > 0 && !s.is_char_boundary(end) {
end -= 1;
}
return s[..end].to_string();
}
let available = max_bytes - TRIM_MARKER.len();
// Bias toward the start (60%) where task framing typically lives, while
// still preserving ~40% of the tail where the user's actual ask often
// appears after a long paste.
let mut start_len = available * 3 / 5;
while start_len > 0 && !s.is_char_boundary(start_len) {
start_len -= 1;
}
let end_len = available - start_len;
let mut end_start = s.len().saturating_sub(end_len);
while end_start < s.len() && !s.is_char_boundary(end_start) {
end_start += 1;
}
let mut out = String::with_capacity(start_len + TRIM_MARKER.len() + (s.len() - end_start));
out.push_str(&s[..start_len]);
out.push_str(TRIM_MARKER);
out.push_str(&s[end_start..]);
out
}
impl std::fmt::Debug for dyn OrchestratorModel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "OrchestratorModel")
@ -776,6 +848,10 @@ If no routes are needed, return an empty list for `route`.
#[test]
fn test_conversation_trim_upto_user_message() {
// With max_token_length=230, the older user message "given the image
// In style of Andy Warhol" overflows the remaining budget and gets
// middle-trimmed (head + ellipsis + tail) until it fits. Newer turns
// are kept in full.
let expected_prompt = r#"
You are a helpful assistant that selects the most suitable routes based on user intent.
You are provided with a list of available routes enclosed within <routes></routes> XML tags:
@ -788,7 +864,7 @@ You are also given the conversation context enclosed within <conversation></conv
[
{
"role": "user",
"content": "given the image In style of Andy Warhol"
"content": "given…rhol"
},
{
"role": "assistant",
@ -861,6 +937,190 @@ If no routes are needed, return an empty list for `route`.
assert_eq!(expected_prompt, prompt);
}
#[test]
fn test_huge_single_user_message_is_middle_trimmed() {
// Regression test for the case where a single, extremely large user
// message was being passed to the orchestrator verbatim and blowing
// past the upstream model's context window. The trimmer must now
// middle-trim (head + ellipsis + tail) the oversized message so the
// resulting request stays within the configured budget, and the
// trim marker must be present so the router model knows content
// was dropped.
let orchestrations_str = r#"
{
"gpt-4o": [
{"name": "Image generation", "description": "generating image"}
]
}
"#;
let agent_orchestrations = serde_json::from_str::<
HashMap<String, Vec<OrchestrationPreference>>,
>(orchestrations_str)
.unwrap();
let max_token_length = 2048;
let orchestrator = OrchestratorModelV1::new(
agent_orchestrations,
"test-model".to_string(),
max_token_length,
);
// ~500KB of content — same scale as the real payload that triggered
// the production upstream 400.
let head = "HEAD_MARKER_START ";
let tail = " TAIL_MARKER_END";
let filler = "A".repeat(500_000);
let huge_user_content = format!("{head}{filler}{tail}");
let conversation = vec![Message {
role: Role::User,
content: Some(MessageContent::Text(huge_user_content.clone())),
name: None,
tool_calls: None,
tool_call_id: None,
}];
let req = orchestrator.generate_request(&conversation, &None);
let prompt = req.messages[0].content.extract_text();
// Prompt must stay bounded. Generous ceiling = budget-in-bytes +
// scaffolding + slack. Real result should be well under this.
let byte_ceiling = max_token_length * TOKEN_LENGTH_DIVISOR
+ ARCH_ORCHESTRATOR_V1_SYSTEM_PROMPT.len()
+ 1024;
assert!(
prompt.len() < byte_ceiling,
"prompt length {} exceeded ceiling {} — truncation did not apply",
prompt.len(),
byte_ceiling,
);
// Not all 500k filler chars survive.
let a_count = prompt.chars().filter(|c| *c == 'A').count();
assert!(
a_count < filler.len(),
"expected user message to be truncated; all {} 'A's survived",
a_count
);
assert!(
a_count > 0,
"expected some of the user message to survive truncation"
);
// Head and tail of the message must both be preserved (that's the
// whole point of middle-trim over head-only).
assert!(
prompt.contains(head),
"head marker missing — head was not preserved"
);
assert!(
prompt.contains(tail),
"tail marker missing — tail was not preserved"
);
// Trim marker must be present so the router model can see that
// content was omitted.
assert!(
prompt.contains(TRIM_MARKER),
"ellipsis trim marker missing from truncated prompt"
);
// Routing prompt scaffolding remains intact.
assert!(prompt.contains("<conversation>"));
assert!(prompt.contains("<routes>"));
}
#[test]
fn test_turn_cap_limits_routing_history() {
// The outer turn-cap guardrail should keep only the last
// `MAX_ROUTING_TURNS` filtered messages regardless of how long the
// conversation is. We build a conversation with alternating
// user/assistant turns tagged with their index and verify that only
// the tail of the conversation makes it into the prompt.
let orchestrations_str = r#"
{
"gpt-4o": [
{"name": "Image generation", "description": "generating image"}
]
}
"#;
let agent_orchestrations = serde_json::from_str::<
HashMap<String, Vec<OrchestrationPreference>>,
>(orchestrations_str)
.unwrap();
let orchestrator =
OrchestratorModelV1::new(agent_orchestrations, "test-model".to_string(), usize::MAX);
let mut conversation: Vec<Message> = Vec::new();
let total_turns = MAX_ROUTING_TURNS * 2; // well past the cap
for i in 0..total_turns {
let role = if i % 2 == 0 {
Role::User
} else {
Role::Assistant
};
conversation.push(Message {
role,
content: Some(MessageContent::Text(format!("turn-{i:03}"))),
name: None,
tool_calls: None,
tool_call_id: None,
});
}
let req = orchestrator.generate_request(&conversation, &None);
let prompt = req.messages[0].content.extract_text();
// The last MAX_ROUTING_TURNS messages (indexes total-cap..total)
// must all appear.
for i in (total_turns - MAX_ROUTING_TURNS)..total_turns {
let tag = format!("turn-{i:03}");
assert!(
prompt.contains(&tag),
"expected recent turn tag {tag} to be present"
);
}
// And earlier turns (indexes 0..total-cap) must all be dropped.
for i in 0..(total_turns - MAX_ROUTING_TURNS) {
let tag = format!("turn-{i:03}");
assert!(
!prompt.contains(&tag),
"old turn tag {tag} leaked past turn cap into the prompt"
);
}
}
#[test]
fn test_trim_middle_utf8_helper() {
// No-op when already small enough.
assert_eq!(trim_middle_utf8("hello", 100), "hello");
assert_eq!(trim_middle_utf8("hello", 5), "hello");
// 60/40 split with ellipsis when too long.
let long = "a".repeat(20);
let out = trim_middle_utf8(&long, 10);
assert!(out.len() <= 10);
assert!(out.contains(TRIM_MARKER));
// Exactly one ellipsis, rest are 'a's.
assert_eq!(out.matches(TRIM_MARKER).count(), 1);
assert!(out.chars().filter(|c| *c == 'a').count() > 0);
// When max_bytes is smaller than the marker, falls back to
// head-only truncation (no marker).
let out = trim_middle_utf8("abcdefgh", 2);
assert_eq!(out, "ab");
// UTF-8 boundary safety: 2-byte chars.
let s = "é".repeat(50); // 100 bytes
let out = trim_middle_utf8(&s, 25);
assert!(out.len() <= 25);
// Must still be valid UTF-8 that only contains 'é' and the marker.
let ok = out.chars().all(|c| c == 'é' || c == '…');
assert!(ok, "unexpected char in trimmed output: {out:?}");
}
#[test]
fn test_non_text_input() {
let expected_prompt = r#"


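The middle-trim behavior tested above can be exercised standalone. This sketch copies `trim_middle_utf8` from the diff above with the ellipsis marker written out explicitly, so the roughly-60/40 head/tail split and the UTF-8 boundary handling are visible outside the test harness.

```rust
// Standalone copy of the middle-trim helper, with the 3-byte UTF-8 ellipsis
// marker explicit. Keeps ~60% head / ~40% tail and never splits a character.
const TRIM_MARKER: &str = "…";

fn trim_middle_utf8(s: &str, max_bytes: usize) -> String {
    if s.len() <= max_bytes {
        return s.to_string();
    }
    if max_bytes <= TRIM_MARKER.len() {
        // Not enough room even for the marker: head-only truncation.
        let mut end = max_bytes;
        while end > 0 && !s.is_char_boundary(end) {
            end -= 1;
        }
        return s[..end].to_string();
    }
    let available = max_bytes - TRIM_MARKER.len();
    // 60% of the budget goes to the head, the rest to the tail.
    let mut start_len = available * 3 / 5;
    while start_len > 0 && !s.is_char_boundary(start_len) {
        start_len -= 1;
    }
    let end_len = available - start_len;
    let mut end_start = s.len().saturating_sub(end_len);
    while end_start < s.len() && !s.is_char_boundary(end_start) {
        end_start += 1;
    }
    format!("{}{}{}", &s[..start_len], TRIM_MARKER, &s[end_start..])
}

fn main() {
    // 26 ASCII bytes trimmed to 13: 6-byte head + 3-byte marker + 4-byte tail.
    let out = trim_middle_utf8("abcdefghijklmnopqrstuvwxyz", 13);
    assert_eq!(out, "abcdef…wxyz");
    assert!(out.len() <= 13);
    println!("{out}");
}
```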
@ -1,168 +0,0 @@
use std::{collections::HashMap, sync::Arc};
use common::{
configuration::{AgentUsagePreference, OrchestrationPreference},
consts::{ARCH_PROVIDER_HINT_HEADER, PLANO_ORCHESTRATOR_MODEL_NAME, REQUEST_ID_HEADER},
};
use hermesllm::apis::openai::{ChatCompletionsResponse, Message};
use hyper::header;
use opentelemetry::global;
use opentelemetry_http::HeaderInjector;
use thiserror::Error;
use tracing::{debug, info, warn};
use crate::router::orchestrator_model_v1::{self};
use super::orchestrator_model::OrchestratorModel;
pub struct OrchestratorService {
orchestrator_url: String,
client: reqwest::Client,
orchestrator_model: Arc<dyn OrchestratorModel>,
}
#[derive(Debug, Error)]
pub enum OrchestrationError {
#[error("Failed to send request: {0}")]
RequestError(#[from] reqwest::Error),
#[error("Failed to parse JSON: {0}, JSON: {1}")]
JsonError(serde_json::Error, String),
#[error("Orchestrator model error: {0}")]
OrchestratorModelError(#[from] super::orchestrator_model::OrchestratorModelError),
}
pub type Result<T> = std::result::Result<T, OrchestrationError>;
impl OrchestratorService {
pub fn new(orchestrator_url: String, orchestration_model_name: String) -> Self {
// Empty agent orchestrations - will be provided via usage_preferences in requests
let agent_orchestrations: HashMap<String, Vec<OrchestrationPreference>> = HashMap::new();
let orchestrator_model = Arc::new(orchestrator_model_v1::OrchestratorModelV1::new(
agent_orchestrations,
orchestration_model_name.clone(),
orchestrator_model_v1::MAX_TOKEN_LEN,
));
OrchestratorService {
orchestrator_url,
client: reqwest::Client::new(),
orchestrator_model,
}
}
pub async fn determine_orchestration(
&self,
messages: &[Message],
usage_preferences: Option<Vec<AgentUsagePreference>>,
request_id: Option<String>,
) -> Result<Option<Vec<(String, String)>>> {
if messages.is_empty() {
return Ok(None);
}
// Require usage_preferences to be provided
if usage_preferences.is_none() || usage_preferences.as_ref().unwrap().is_empty() {
return Ok(None);
}
let orchestrator_request = self
.orchestrator_model
.generate_request(messages, &usage_preferences);
debug!(
model = %self.orchestrator_model.get_model_name(),
endpoint = %self.orchestrator_url,
"sending request to arch-orchestrator"
);
debug!(
body = %serde_json::to_string(&orchestrator_request).unwrap(),
"arch orchestrator request"
);
let mut orchestration_request_headers = header::HeaderMap::new();
orchestration_request_headers.insert(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
);
orchestration_request_headers.insert(
header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER),
header::HeaderValue::from_str(PLANO_ORCHESTRATOR_MODEL_NAME).unwrap(),
);
// Inject OpenTelemetry trace context from current span
global::get_text_map_propagator(|propagator| {
let cx =
tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
propagator.inject_context(&cx, &mut HeaderInjector(&mut orchestration_request_headers));
});
if let Some(request_id) = request_id {
orchestration_request_headers.insert(
header::HeaderName::from_static(REQUEST_ID_HEADER),
header::HeaderValue::from_str(&request_id).unwrap(),
);
}
orchestration_request_headers.insert(
header::HeaderName::from_static("model"),
header::HeaderValue::from_static(PLANO_ORCHESTRATOR_MODEL_NAME),
);
let start_time = std::time::Instant::now();
let res = self
.client
.post(&self.orchestrator_url)
.headers(orchestration_request_headers)
.body(serde_json::to_string(&orchestrator_request).unwrap())
.send()
.await?;
let body = res.text().await?;
let orchestrator_response_time = start_time.elapsed();
let chat_completion_response: ChatCompletionsResponse = match serde_json::from_str(&body) {
Ok(response) => response,
Err(err) => {
warn!(
error = %err,
body = %serde_json::to_string(&body).unwrap(),
"failed to parse json response"
);
return Err(OrchestrationError::JsonError(
err,
format!("Failed to parse JSON: {}", body),
));
}
};
if chat_completion_response.choices.is_empty() {
warn!(body = %body, "no choices in orchestrator response");
return Ok(None);
}
if let Some(content) = &chat_completion_response.choices[0].message.content {
let parsed_response = self
.orchestrator_model
.parse_response(content, &usage_preferences)?;
info!(
content = %content.replace("\n", "\\n"),
selected_routes = ?parsed_response,
response_time_ms = orchestrator_response_time.as_millis(),
"arch-orchestrator determined routes"
);
if let Some(ref parsed_response) = parsed_response {
return Ok(Some(parsed_response.clone()));
}
Ok(None)
} else {
Ok(None)
}
}
}


@ -1,25 +0,0 @@
use common::configuration::ModelUsagePreference;
use hermesllm::apis::openai::{ChatCompletionsRequest, Message};
use thiserror::Error;
#[derive(Debug, Error)]
pub enum RoutingModelError {
#[error("Failed to parse JSON: {0}")]
JsonError(#[from] serde_json::Error),
}
pub type Result<T> = std::result::Result<T, RoutingModelError>;
pub trait RouterModel: Send + Sync {
fn generate_request(
&self,
messages: &[Message],
usage_preferences: &Option<Vec<ModelUsagePreference>>,
) -> ChatCompletionsRequest;
fn parse_response(
&self,
content: &str,
usage_preferences: &Option<Vec<ModelUsagePreference>>,
) -> Result<Option<(String, String)>>;
fn get_model_name(&self) -> String;
}


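The deleted `RouterModelV1` below builds its `llm_route_to_model_map` by inverting the model-to-preferences map: each preference name becomes a key pointing back at the model that owns it. A minimal sketch of that inversion, with a hypothetical stand-in for `RoutingPreference`:

```rust
use std::collections::HashMap;

// Hypothetical stand-in for the repo's RoutingPreference (name field only).
#[derive(Clone)]
struct RoutingPreference {
    name: String,
}

// Invert model -> [preferences] into route-name -> model, as RouterModelV1::new does.
fn route_to_model(llm_routes: &HashMap<String, Vec<RoutingPreference>>) -> HashMap<String, String> {
    llm_routes
        .iter()
        .flat_map(|(model, prefs)| prefs.iter().map(|p| (p.name.clone(), model.clone())))
        .collect()
}

fn main() {
    let mut routes = HashMap::new();
    routes.insert(
        "gpt-4o".to_string(),
        vec![RoutingPreference { name: "code".to_string() }],
    );
    let map = route_to_model(&routes);
    assert_eq!(map.get("code"), Some(&"gpt-4o".to_string()));
    println!("ok");
}
```

Note that if two models declare a preference with the same name, the later iteration wins silently, which is one reason route names are expected to be unique.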
@ -1,841 +0,0 @@
use std::collections::HashMap;
use common::configuration::{ModelUsagePreference, RoutingPreference};
use hermesllm::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role};
use hermesllm::transforms::lib::ExtractText;
use serde::{Deserialize, Serialize};
use tracing::{debug, warn};
use super::router_model::{RouterModel, RoutingModelError};
pub const MAX_TOKEN_LEN: usize = 2048; // Default max token length for the routing model
pub const ARCH_ROUTER_V1_SYSTEM_PROMPT: &str = r#"
You are a helpful assistant designed to find the best suited route.
You are provided with route description within <routes></routes> XML tags:
<routes>
{routes}
</routes>
<conversation>
{conversation}
</conversation>
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
2. You must analyze the route descriptions and find the best match route for user latest intent.
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
{"route": "route_name"}
"#;
pub type Result<T> = std::result::Result<T, RoutingModelError>;
pub struct RouterModelV1 {
llm_route_json_str: String,
llm_route_to_model_map: HashMap<String, String>,
routing_model: String,
max_token_length: usize,
}
impl RouterModelV1 {
pub fn new(
llm_routes: HashMap<String, Vec<RoutingPreference>>,
routing_model: String,
max_token_length: usize,
) -> Self {
let llm_route_values: Vec<RoutingPreference> =
llm_routes.values().flatten().cloned().collect();
let llm_route_json_str =
serde_json::to_string(&llm_route_values).unwrap_or_else(|_| "[]".to_string());
let llm_route_to_model_map: HashMap<String, String> = llm_routes
.iter()
.flat_map(|(model, prefs)| prefs.iter().map(|pref| (pref.name.clone(), model.clone())))
.collect();
RouterModelV1 {
routing_model,
max_token_length,
llm_route_json_str,
llm_route_to_model_map,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct LlmRouterResponse {
pub route: Option<String>,
}
const TOKEN_LENGTH_DIVISOR: usize = 4; // Approximate token length divisor for UTF-8 characters
impl RouterModel for RouterModelV1 {
fn generate_request(
&self,
messages: &[Message],
usage_preferences_from_request: &Option<Vec<ModelUsagePreference>>,
) -> ChatCompletionsRequest {
// Remove system prompts, tool calls, tool-call responses, and messages without content.
// If content is empty it's likely a tool call; when role == Tool it's a tool-call response.
let messages_vec = messages
.iter()
.filter(|m| {
m.role != Role::System
&& m.role != Role::Tool
&& !m.content.extract_text().is_empty()
})
.collect::<Vec<&Message>>();
// The following code ensures the conversation does not exceed the max token length.
// Note: we estimate token count from character length (a cheap heuristic) rather than running a tokenizer, to optimize for performance.
let mut token_count = ARCH_ROUTER_V1_SYSTEM_PROMPT.len() / TOKEN_LENGTH_DIVISOR;
let mut selected_messages_list_reversed: Vec<&Message> = vec![];
for (selected_message_count, message) in messages_vec.iter().rev().enumerate() {
let message_token_count = message.content.extract_text().len() / TOKEN_LENGTH_DIVISOR;
token_count += message_token_count;
if token_count > self.max_token_length {
debug!(
token_count = token_count,
max_tokens = self.max_token_length,
selected = selected_message_count,
total = messages_vec.len(),
"token count exceeds max, truncating conversation"
);
if message.role == Role::User {
// If the message that pushes us over the limit is from the user, keep it anyway
selected_messages_list_reversed.push(message);
}
break;
}
// The message fits within the max token length; keep it
selected_messages_list_reversed.push(message);
}
if selected_messages_list_reversed.is_empty() {
debug!("no messages selected, using last message");
if let Some(last_message) = messages_vec.last() {
selected_messages_list_reversed.push(last_message);
}
}
// Ensure the first and last selected messages are from the user.
// Note: the list is in reverse conversation order, so `first()` is the
// most recent message and `last()` is the earliest one selected.
if let Some(first_message) = selected_messages_list_reversed.first() {
if first_message.role != Role::User {
warn!("last message is not from user, may lead to incorrect routing");
}
}
if let Some(last_message) = selected_messages_list_reversed.last() {
if last_message.role != Role::User {
warn!("first message is not from user, may lead to incorrect routing");
}
}
// Reverse the selected messages to maintain the conversation order
let selected_conversation_list = selected_messages_list_reversed
.iter()
.rev()
.map(|message| {
Message {
role: message.role.clone(),
// content is non-empty here: messages without content were filtered out above, so the map_or default never fires in practice
content: Some(MessageContent::Text(
message
.content
.as_ref()
.map_or(String::new(), |c| c.to_string()),
)),
name: None,
tool_calls: None,
tool_call_id: None,
}
})
.collect::<Vec<Message>>();
// Generate the router request message based on the usage preferences.
// If preferences are passed in the request we use them; otherwise we fall back to the configured routing-model preferences.
let router_message = match convert_to_router_preferences(usage_preferences_from_request) {
Some(prefs) => generate_router_message(&prefs, &selected_conversation_list),
None => generate_router_message(&self.llm_route_json_str, &selected_conversation_list),
};
ChatCompletionsRequest {
model: self.routing_model.clone(),
messages: vec![Message {
content: Some(MessageContent::Text(router_message)),
role: Role::User,
name: None,
tool_calls: None,
tool_call_id: None,
}],
temperature: Some(0.01),
..Default::default()
}
}
fn parse_response(
&self,
content: &str,
usage_preferences: &Option<Vec<ModelUsagePreference>>,
) -> Result<Option<(String, String)>> {
if content.is_empty() {
return Ok(None);
}
let router_resp_fixed = fix_json_response(content);
let router_response: LlmRouterResponse = serde_json::from_str(router_resp_fixed.as_str())?;
let selected_route = router_response.route.unwrap_or_default();
if selected_route.is_empty() || selected_route == "other" {
return Ok(None);
}
if let Some(usage_preferences) = usage_preferences {
// If usage preferences are defined, find the model that matches the selected route
let model_name: Option<String> = usage_preferences.iter().find_map(|pref| {
pref.routing_preferences
.iter()
.find(|routing_pref| routing_pref.name == selected_route)
.map(|_| pref.model.clone())
});
if let Some(model_name) = model_name {
return Ok(Some((selected_route, model_name)));
} else {
warn!(
route = %selected_route,
preferences = ?usage_preferences,
"no matching model found for route"
);
return Ok(None);
}
}
// If no usage preferences were passed in the request, fall back to the default routing-model preferences
if let Some(model) = self.llm_route_to_model_map.get(&selected_route).cloned() {
return Ok(Some((selected_route, model)));
}
warn!(
route = %selected_route,
preferences = ?self.llm_route_to_model_map,
"no model found for route"
);
Ok(None)
}
fn get_model_name(&self) -> String {
self.routing_model.clone()
}
}
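The truncation loop in `generate_request` estimates tokens as character length divided by `TOKEN_LENGTH_DIVISOR` rather than running a real tokenizer. A minimal standalone sketch of that heuristic (the `approx_tokens` helper is illustrative, not part of the crate's API):

```rust
// Illustrative sketch of the chars/4 token estimate used for truncation.
// It trades accuracy for speed: ASCII text averages roughly four bytes per
// token, but the estimate drifts for code or non-Latin scripts.
const TOKEN_LENGTH_DIVISOR: usize = 4;

fn approx_tokens(text: &str) -> usize {
    text.len() / TOKEN_LENGTH_DIVISOR
}

fn main() {
    // A 20-byte ASCII string is estimated at 5 tokens.
    println!("{}", approx_tokens("hello world, router!"));
}
```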
fn generate_router_message(prefs: &str, selected_conversation_list: &Vec<Message>) -> String {
ARCH_ROUTER_V1_SYSTEM_PROMPT
.replace("{routes}", prefs)
.replace(
"{conversation}",
&serde_json::to_string(&selected_conversation_list).unwrap_or_default(),
)
}
fn convert_to_router_preferences(
prefs_from_request: &Option<Vec<ModelUsagePreference>>,
) -> Option<String> {
if let Some(usage_preferences) = prefs_from_request {
let routing_preferences = usage_preferences
.iter()
.flat_map(|pref| {
pref.routing_preferences
.iter()
.map(|routing_pref| RoutingPreference {
name: routing_pref.name.clone(),
description: routing_pref.description.clone(),
})
})
.collect::<Vec<RoutingPreference>>();
return Some(serde_json::to_string(&routing_preferences).unwrap_or_default());
}
None
}
fn fix_json_response(body: &str) -> String {
// Best-effort cleanup of model output before JSON parsing: single quotes
// become double quotes, literal "\n" escape sequences are dropped, and
// markdown ```json code fences are stripped.
// Note: the quote swap is naive and will also rewrite apostrophes inside
// string values.
let mut updated_body = body.replace('\'', "\"").replace("\\n", "");
if let Some(stripped) = updated_body.strip_prefix("```json") {
updated_body = stripped.to_string();
}
if let Some(stripped) = updated_body.strip_suffix("```") {
updated_body = stripped.to_string();
}
updated_body
}
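A quick check of what this cleanup produces for a typical malformed model reply. `clean` below is a self-contained replica of `fix_json_response` for illustration only:

```rust
// Self-contained replica of fix_json_response's cleanup for illustration:
// single quotes -> double quotes, literal "\n" escape sequences dropped.
// The quote swap is best-effort and would also rewrite apostrophes inside
// string values.
fn clean(body: &str) -> String {
    let mut s = body.replace('\'', "\"").replace("\\n", "");
    if let Some(rest) = s.strip_prefix("```json") {
        s = rest.to_string();
    }
    if let Some(rest) = s.strip_suffix("```") {
        s = rest.to_string();
    }
    s
}

fn main() {
    // "\\n" in Rust source is a literal backslash-n, as a model might emit.
    println!("{}", clean("{'route': 'Image generation'}\\n"));
}
```

This mirrors test cases 6 and 7 in the module's test suite (single quotes, trailing `\n`, and code-fence markers).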
impl std::fmt::Debug for dyn RouterModel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "RouterModel")
}
}
#[cfg(test)]
mod tests {
use super::*;
use pretty_assertions::assert_eq;
#[test]
fn test_system_prompt_format() {
let expected_prompt = r#"
You are a helpful assistant designed to find the best suited route.
You are provided with route description within <routes></routes> XML tags:
<routes>
[{"name":"Image generation","description":"generating image"}]
</routes>
<conversation>
[{"role":"user","content":"hi"},{"role":"assistant","content":"Hello! How can I assist you today?"},{"role":"user","content":"given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"}]
</conversation>
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
2. You must analyze the route descriptions and find the best match route for user latest intent.
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
{"route": "route_name"}
"#;
let routes_str = r#"
{
"gpt-4o": [
{"name": "Image generation", "description": "generating image"}
]
}
"#;
let llm_routes =
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model, usize::MAX);
let conversation_str = r#"
[
{
"role": "user",
"content": "hi"
},
{
"role": "assistant",
"content": "Hello! How can I assist you today?"
},
{
"role": "user",
"content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"
}
]
"#;
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
let req = router.generate_request(&conversation, &None);
let prompt = req.messages[0].content.extract_text();
assert_eq!(expected_prompt, prompt);
}
#[test]
fn test_system_prompt_format_usage_preferences() {
let expected_prompt = r#"
You are a helpful assistant designed to find the best suited route.
You are provided with route description within <routes></routes> XML tags:
<routes>
[{"name":"code-generation","description":"generating new code snippets, functions, or boilerplate based on user prompts or requirements"}]
</routes>
<conversation>
[{"role":"user","content":"hi"},{"role":"assistant","content":"Hello! How can I assist you today?"},{"role":"user","content":"given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"}]
</conversation>
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
2. You must analyze the route descriptions and find the best match route for user latest intent.
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
{"route": "route_name"}
"#;
let routes_str = r#"
{
"gpt-4o": [
{"name": "Image generation", "description": "generating image"}
]
}
"#;
let llm_routes =
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model, usize::MAX);
let conversation_str = r#"
[
{
"role": "user",
"content": "hi"
},
{
"role": "assistant",
"content": "Hello! How can I assist you today?"
},
{
"role": "user",
"content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"
}
]
"#;
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
let usage_preferences = Some(vec![ModelUsagePreference {
model: "claude/claude-3-7-sonnet".to_string(),
routing_preferences: vec![RoutingPreference {
name: "code-generation".to_string(),
description: "generating new code snippets, functions, or boilerplate based on user prompts or requirements".to_string(),
}],
}]);
let req = router.generate_request(&conversation, &usage_preferences);
let prompt = req.messages[0].content.extract_text();
assert_eq!(expected_prompt, prompt);
}
#[test]
fn test_conversation_exceed_token_count() {
let expected_prompt = r#"
You are a helpful assistant designed to find the best suited route.
You are provided with route description within <routes></routes> XML tags:
<routes>
[{"name":"Image generation","description":"generating image"}]
</routes>
<conversation>
[{"role":"user","content":"given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"}]
</conversation>
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
2. You must analyze the route descriptions and find the best match route for user latest intent.
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
{"route": "route_name"}
"#;
let routes_str = r#"
{
"gpt-4o": [
{"name": "Image generation", "description": "generating image"}
]
}
"#;
let llm_routes =
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model, 235);
let conversation_str = r#"
[
{
"role": "user",
"content": "hi"
},
{
"role": "assistant",
"content": "Hello! How can I assist you today?"
},
{
"role": "user",
"content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"
}
]
"#;
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
let req = router.generate_request(&conversation, &None);
let prompt = req.messages[0].content.extract_text();
assert_eq!(expected_prompt, prompt);
}
#[test]
fn test_conversation_exceed_token_count_large_single_message() {
let expected_prompt = r#"
You are a helpful assistant designed to find the best suited route.
You are provided with route description within <routes></routes> XML tags:
<routes>
[{"name":"Image generation","description":"generating image"}]
</routes>
<conversation>
[{"role":"user","content":"given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson and this is a very long message that exceeds the max token length of the routing model, so it should be truncated and only the last user message should be included in the conversation for routing."}]
</conversation>
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
2. You must analyze the route descriptions and find the best match route for user latest intent.
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
{"route": "route_name"}
"#;
let routes_str = r#"
{
"gpt-4o": [
{"name": "Image generation", "description": "generating image"}
]
}
"#;
let llm_routes =
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model, 200);
let conversation_str = r#"
[
{
"role": "user",
"content": "hi"
},
{
"role": "assistant",
"content": "Hello! How can I assist you today?"
},
{
"role": "user",
"content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson and this is a very long message that exceeds the max token length of the routing model, so it should be truncated and only the last user message should be included in the conversation for routing."
}
]
"#;
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
let req = router.generate_request(&conversation, &None);
let prompt = req.messages[0].content.extract_text();
assert_eq!(expected_prompt, prompt);
}
#[test]
fn test_conversation_trim_upto_user_message() {
let expected_prompt = r#"
You are a helpful assistant designed to find the best suited route.
You are provided with route description within <routes></routes> XML tags:
<routes>
[{"name":"Image generation","description":"generating image"}]
</routes>
<conversation>
[{"role":"user","content":"given the image In style of Andy Warhol"},{"role":"assistant","content":"ok here is the image"},{"role":"user","content":"pls give me another image about Bart and Lisa"}]
</conversation>
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
2. You must analyze the route descriptions and find the best match route for user latest intent.
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
{"route": "route_name"}
"#;
let routes_str = r#"
{
"gpt-4o": [
{"name": "Image generation", "description": "generating image"}
]
}
"#;
let llm_routes =
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model, 230);
let conversation_str = r#"
[
{
"role": "user",
"content": "hi"
},
{
"role": "assistant",
"content": "Hello! How can I assist you today?"
},
{
"role": "user",
"content": "given the image In style of Andy Warhol"
},
{
"role": "assistant",
"content": "ok here is the image"
},
{
"role": "user",
"content": "pls give me another image about Bart and Lisa"
}
]
"#;
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
let req = router.generate_request(&conversation, &None);
let prompt = req.messages[0].content.extract_text();
assert_eq!(expected_prompt, prompt);
}
#[test]
fn test_non_text_input() {
let expected_prompt = r#"
You are a helpful assistant designed to find the best suited route.
You are provided with route description within <routes></routes> XML tags:
<routes>
[{"name":"Image generation","description":"generating image"}]
</routes>
<conversation>
[{"role":"user","content":"hi"},{"role":"assistant","content":"Hello! How can I assist you today?"},{"role":"user","content":"given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"}]
</conversation>
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
2. You must analyze the route descriptions and find the best match route for user latest intent.
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
{"route": "route_name"}
"#;
let routes_str = r#"
{
"gpt-4o": [
{"name": "Image generation", "description": "generating image"}
]
}
"#;
let llm_routes =
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model, usize::MAX);
let conversation_str = r#"
[
{
"role": "user",
"content": [
{
"type": "text",
"text": "hi"
},
{
"type": "image_url",
"image_url": {
"url": "https://example.com/image.png"
}
}
]
},
{
"role": "assistant",
"content": "Hello! How can I assist you today?"
},
{
"role": "user",
"content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"
}
]
"#;
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
let req = router.generate_request(&conversation, &None);
let prompt = req.messages[0].content.extract_text();
assert_eq!(expected_prompt, prompt);
}
#[test]
fn test_skip_tool_call() {
let expected_prompt = r#"
You are a helpful assistant designed to find the best suited route.
You are provided with route description within <routes></routes> XML tags:
<routes>
[{"name":"Image generation","description":"generating image"}]
</routes>
<conversation>
[{"role":"user","content":"What's the weather like in Tokyo?"},{"role":"assistant","content":"The current weather in Tokyo is 22°C and sunny."},{"role":"user","content":"What about in New York?"}]
</conversation>
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
2. You must analyze the route descriptions and find the best match route for user latest intent.
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
{"route": "route_name"}
"#;
let routes_str = r#"
{
"gpt-4o": [
{"name": "Image generation", "description": "generating image"}
]
}
"#;
let llm_routes =
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model, usize::MAX);
let conversation_str = r#"
[
{
"role": "user",
"content": "What's the weather like in Tokyo?"
},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "toolcall-abc123",
"type": "function",
"function": {
"name": "get_weather",
"arguments": "{ \"location\": \"Tokyo\" }"
}
}
]
},
{
"role": "tool",
"tool_call_id": "toolcall-abc123",
"content": "{ \"temperature\": \"22°C\", \"condition\": \"Sunny\" }"
},
{
"role": "assistant",
"content": "The current weather in Tokyo is 22°C and sunny."
},
{
"role": "user",
"content": "What about in New York?"
}
]
"#;
// The filtered conversation (tool calls and tool responses removed) is expected to look like:
// [
// {
// "role": "user",
// "content": "What's the weather like in Tokyo?"
// },
// {
// "role": "assistant",
// "content": "The current weather in Tokyo is 22°C and sunny."
// },
// {
// "role": "user",
// "content": "What about in New York?"
// }
// ]
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
let req: ChatCompletionsRequest = router.generate_request(&conversation, &None);
let prompt = req.messages[0].content.extract_text();
assert_eq!(expected_prompt, prompt);
}
#[test]
fn test_parse_response() {
let routes_str = r#"
{
"gpt-4o": [
{"name": "Image generation", "description": "generating image"}
]
}
"#;
let llm_routes =
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let router = RouterModelV1::new(llm_routes, "test-model".to_string(), 2000);
// Case 1: Valid JSON with non-empty route
let input = r#"{"route": "Image generation"}"#;
let result = router.parse_response(input, &None).unwrap();
assert_eq!(
result,
Some(("Image generation".to_string(), "gpt-4o".to_string()))
);
// Case 2: Valid JSON with empty route
let input = r#"{"route": ""}"#;
let result = router.parse_response(input, &None).unwrap();
assert_eq!(result, None);
// Case 3: Valid JSON with null route
let input = r#"{"route": null}"#;
let result = router.parse_response(input, &None).unwrap();
assert_eq!(result, None);
// Case 4: JSON missing route field
let input = r#"{}"#;
let result = router.parse_response(input, &None).unwrap();
assert_eq!(result, None);
// Case 4.1: empty string
let input = r#""#;
let result = router.parse_response(input, &None).unwrap();
assert_eq!(result, None);
// Case 5: Malformed JSON
let input = r#"{"route": "route1""#; // missing closing }
let result = router.parse_response(input, &None);
assert!(result.is_err());
// Case 6: Single quotes and \n in JSON
let input = "{'route': 'Image generation'}\\n";
let result = router.parse_response(input, &None).unwrap();
assert_eq!(
result,
Some(("Image generation".to_string(), "gpt-4o".to_string()))
);
// Case 7: Code block marker
let input = "```json\n{\"route\": \"Image generation\"}\n```";
let result = router.parse_response(input, &None).unwrap();
assert_eq!(
result,
Some(("Image generation".to_string(), "gpt-4o".to_string()))
);
}
}


@@ -0,0 +1,260 @@
#[cfg(test)]
mod tests {
use crate::router::orchestrator::OrchestratorService;
use crate::session_cache::memory::MemorySessionCache;
use common::configuration::{SelectionPolicy, SelectionPreference, TopLevelRoutingPreference};
use hermesllm::apis::openai::{Message, MessageContent, Role};
use std::sync::Arc;
fn make_messages(n: usize) -> Vec<Message> {
(0..n)
.map(|i| Message {
role: if i % 2 == 0 {
Role::User
} else {
Role::Assistant
},
content: Some(MessageContent::Text(format!(
"This is message number {i} with some padding text to make it realistic."
))),
name: None,
tool_calls: None,
tool_call_id: None,
})
.collect()
}
fn make_routing_prefs() -> Vec<TopLevelRoutingPreference> {
vec![
TopLevelRoutingPreference {
name: "code_generation".to_string(),
description: "Code generation and debugging tasks".to_string(),
models: vec![
"openai/gpt-4o".to_string(),
"openai/gpt-4o-mini".to_string(),
],
selection_policy: SelectionPolicy {
prefer: SelectionPreference::None,
},
},
TopLevelRoutingPreference {
name: "summarization".to_string(),
description: "Summarizing documents and text".to_string(),
models: vec![
"anthropic/claude-3-sonnet".to_string(),
"openai/gpt-4o-mini".to_string(),
],
selection_policy: SelectionPolicy {
prefer: SelectionPreference::None,
},
},
]
}
/// Stress test: exercise the full routing code path N times using a mock
/// HTTP server and measure jemalloc allocated bytes before/after.
///
/// This catches:
/// - Memory leaks in generate_request / parse_response
/// - Leaks in reqwest connection handling
/// - String accumulation in the orchestrator model
/// - Fragmentation (jemalloc allocated vs resident)
#[tokio::test]
async fn stress_test_routing_determine_route() {
let mut server = mockito::Server::new_async().await;
let router_url = format!("{}/v1/chat/completions", server.url());
let mock_response = serde_json::json!({
"id": "chatcmpl-mock",
"object": "chat.completion",
"created": 1234567890,
"model": "plano-orchestrator",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "{\"route\": \"code_generation\"}"
},
"finish_reason": "stop"
}],
"usage": {"prompt_tokens": 100, "completion_tokens": 10, "total_tokens": 110}
});
let _mock = server
.mock("POST", "/v1/chat/completions")
.with_status(200)
.with_header("content-type", "application/json")
.with_body(mock_response.to_string())
.expect_at_least(1)
.create_async()
.await;
let prefs = make_routing_prefs();
let session_cache = Arc::new(MemorySessionCache::new(1000));
let orchestrator_service = Arc::new(OrchestratorService::with_routing(
router_url,
"Plano-Orchestrator".to_string(),
"plano-orchestrator".to_string(),
Some(prefs.clone()),
None,
None,
session_cache,
None,
2048,
));
// Warm up: a few requests to stabilize allocator state
for _ in 0..10 {
let msgs = make_messages(5);
let _ = orchestrator_service
.determine_route(&msgs, None, "warmup")
.await;
}
// Snapshot memory after warmup
let baseline = get_allocated();
let num_iterations = 2000;
for i in 0..num_iterations {
let msgs = make_messages(5 + (i % 10));
let inline = if i % 3 == 0 {
Some(make_routing_prefs())
} else {
None
};
let _ = orchestrator_service
.determine_route(&msgs, inline, &format!("req-{i}"))
.await;
}
let after = get_allocated();
let growth = after.saturating_sub(baseline);
let growth_mb = growth as f64 / (1024.0 * 1024.0);
let per_request = growth.checked_div(num_iterations).unwrap_or(0);
eprintln!("=== Routing Stress Test Results ===");
eprintln!(" Iterations: {num_iterations}");
eprintln!(" Baseline alloc: {} bytes", baseline);
eprintln!(" Final alloc: {} bytes", after);
eprintln!(" Growth: {} bytes ({growth_mb:.2} MB)", growth);
eprintln!(" Per-request: {} bytes", per_request);
// Allow up to 256 bytes per request of retained growth (connection pool, etc.)
// A true leak would show thousands of bytes per request.
assert!(
per_request < 256,
"Possible memory leak: {per_request} bytes/request retained after {num_iterations} iterations"
);
}
/// Stress test with high concurrency: many parallel determine_route calls.
#[tokio::test]
async fn stress_test_routing_concurrent() {
let mut server = mockito::Server::new_async().await;
let router_url = format!("{}/v1/chat/completions", server.url());
let mock_response = serde_json::json!({
"id": "chatcmpl-mock",
"object": "chat.completion",
"created": 1234567890,
"model": "plano-orchestrator",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "{\"route\": \"summarization\"}"
},
"finish_reason": "stop"
}],
"usage": {"prompt_tokens": 100, "completion_tokens": 10, "total_tokens": 110}
});
let _mock = server
.mock("POST", "/v1/chat/completions")
.with_status(200)
.with_header("content-type", "application/json")
.with_body(mock_response.to_string())
.expect_at_least(1)
.create_async()
.await;
let prefs = make_routing_prefs();
let session_cache = Arc::new(MemorySessionCache::new(1000));
let orchestrator_service = Arc::new(OrchestratorService::with_routing(
router_url,
"Plano-Orchestrator".to_string(),
"plano-orchestrator".to_string(),
Some(prefs),
None,
None,
session_cache,
None,
2048,
));
// Warm up
for _ in 0..20 {
let msgs = make_messages(3);
let _ = orchestrator_service
.determine_route(&msgs, None, "warmup")
.await;
}
let baseline = get_allocated();
let concurrency = 50;
let requests_per_task = 100;
let total = concurrency * requests_per_task;
let mut handles = vec![];
for t in 0..concurrency {
let svc = Arc::clone(&orchestrator_service);
let handle = tokio::spawn(async move {
for r in 0..requests_per_task {
let msgs = make_messages(3 + (r % 8));
let _ = svc
.determine_route(&msgs, None, &format!("req-{t}-{r}"))
.await;
}
});
handles.push(handle);
}
for h in handles {
h.await.unwrap();
}
let after = get_allocated();
let growth = after.saturating_sub(baseline);
let per_request = growth / total;
eprintln!("=== Concurrent Routing Stress Test Results ===");
eprintln!(" Tasks: {concurrency} x {requests_per_task} = {total}");
eprintln!(" Baseline: {} bytes", baseline);
eprintln!(" Final: {} bytes", after);
eprintln!(
" Growth: {} bytes ({:.2} MB)",
growth,
growth as f64 / 1_048_576.0
);
eprintln!(" Per-request: {} bytes", per_request);
assert!(
per_request < 512,
"Possible memory leak under concurrency: {per_request} bytes/request retained after {total} requests"
);
}
#[cfg(feature = "jemalloc")]
fn get_allocated() -> usize {
tikv_jemalloc_ctl::epoch::advance().unwrap();
tikv_jemalloc_ctl::stats::allocated::read().unwrap_or(0)
}
#[cfg(not(feature = "jemalloc"))]
fn get_allocated() -> usize {
0
}
}


@@ -0,0 +1,82 @@
use std::{
num::NonZeroUsize,
sync::Arc,
time::{Duration, Instant},
};
use async_trait::async_trait;
use lru::LruCache;
use tokio::sync::Mutex;
use tracing::info;
use super::{CachedRoute, SessionCache};
type CacheStore = Mutex<LruCache<String, (CachedRoute, Instant, Duration)>>;
pub struct MemorySessionCache {
store: Arc<CacheStore>,
}
impl MemorySessionCache {
pub fn new(max_entries: usize) -> Self {
let capacity = NonZeroUsize::new(max_entries)
.unwrap_or_else(|| NonZeroUsize::new(10_000).expect("10_000 is non-zero"));
let store = Arc::new(Mutex::new(LruCache::new(capacity)));
// Spawn a background task to evict TTL-expired entries every 5 minutes.
let store_clone = Arc::clone(&store);
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(300));
loop {
interval.tick().await;
Self::evict_expired(&store_clone).await;
}
});
Self { store }
}
async fn evict_expired(store: &CacheStore) {
let mut cache = store.lock().await;
let expired: Vec<String> = cache
.iter()
.filter(|(_, (_, inserted_at, ttl))| inserted_at.elapsed() >= *ttl)
.map(|(k, _)| k.clone())
.collect();
let removed = expired.len();
for key in &expired {
cache.pop(key.as_str());
}
if removed > 0 {
info!(
removed = removed,
remaining = cache.len(),
"cleaned up expired session cache entries"
);
}
}
}
#[async_trait]
impl SessionCache for MemorySessionCache {
async fn get(&self, key: &str) -> Option<CachedRoute> {
let mut cache = self.store.lock().await;
if let Some((route, inserted_at, ttl)) = cache.get(key) {
if inserted_at.elapsed() < *ttl {
return Some(route.clone());
}
}
None
}
async fn put(&self, key: &str, route: CachedRoute, ttl: Duration) {
self.store
.lock()
.await
.put(key.to_string(), (route, Instant::now(), ttl));
}
async fn remove(&self, key: &str) {
self.store.lock().await.pop(key);
}
}
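The `(CachedRoute, Instant, Duration)` tuple scheme above can be shown with a std-only sketch (no `lru` crate, no locking): `get` treats an entry whose TTL has elapsed as absent, which is how `MemorySessionCache::get` behaves between eviction sweeps. `TtlCache` here is a hypothetical name for illustration, not part of the crate:

```rust
use std::collections::HashMap;
use std::time::{Duration, Instant};

// Minimal sketch of per-entry TTL: each entry stores its insertion time
// and TTL, and lookups treat an expired entry as absent.
struct TtlCache {
    store: HashMap<String, (String, Instant, Duration)>,
}

impl TtlCache {
    fn put(&mut self, key: &str, value: &str, ttl: Duration) {
        self.store.insert(key.into(), (value.into(), Instant::now(), ttl));
    }
    fn get(&self, key: &str) -> Option<&str> {
        self.store.get(key).and_then(|(v, inserted_at, ttl)| {
            (inserted_at.elapsed() < *ttl).then_some(v.as_str())
        })
    }
}

fn main() {
    let mut cache = TtlCache { store: HashMap::new() };
    cache.put("session-1", "gpt-4o", Duration::from_secs(60));
    cache.put("session-2", "claude", Duration::from_secs(0)); // already expired
    println!("{:?}", cache.get("session-1")); // Some("gpt-4o")
    println!("{:?}", cache.get("session-2")); // None
}
```

Unlike this sketch, the real cache also bounds entry count via the LRU capacity and removes expired keys in a background sweep so they don't pin memory until touched.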


@@ -0,0 +1,70 @@
use std::sync::Arc;
use async_trait::async_trait;
use common::configuration::Configuration;
use std::time::Duration;
use tracing::{debug, info};
pub mod memory;
pub mod redis;
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct CachedRoute {
pub model_name: String,
pub route_name: Option<String>,
}
#[async_trait]
pub trait SessionCache: Send + Sync {
/// Look up a cached routing decision by key.
async fn get(&self, key: &str) -> Option<CachedRoute>;
/// Store a routing decision in the session cache with the given TTL.
async fn put(&self, key: &str, route: CachedRoute, ttl: Duration);
/// Remove a cached routing decision by key.
async fn remove(&self, key: &str);
}
/// Initialize the session cache backend from config.
/// Defaults to the in-memory backend when no `session_cache` block is configured.
pub async fn init_session_cache(
config: &Configuration,
) -> Result<Arc<dyn SessionCache>, Box<dyn std::error::Error + Send + Sync>> {
use common::configuration::SessionCacheType;
let session_max_entries = config.routing.as_ref().and_then(|r| r.session_max_entries);
const DEFAULT_SESSION_MAX_ENTRIES: usize = 10_000;
const MAX_SESSION_MAX_ENTRIES: usize = 10_000;
let max_entries = session_max_entries
.unwrap_or(DEFAULT_SESSION_MAX_ENTRIES)
.min(MAX_SESSION_MAX_ENTRIES);
let cache_config = config
.routing
.as_ref()
.and_then(|r| r.session_cache.as_ref());
let cache_type = cache_config
.map(|c| &c.cache_type)
.unwrap_or(&SessionCacheType::Memory);
match cache_type {
SessionCacheType::Memory => {
info!(storage_type = "memory", "initialized session cache");
Ok(Arc::new(memory::MemorySessionCache::new(max_entries)))
}
SessionCacheType::Redis => {
let url = cache_config
.and_then(|c| c.url.as_ref())
.ok_or("session_cache.url is required when type is redis")?;
debug!(storage_type = "redis", url = %url, "initializing session cache");
let cache = redis::RedisSessionCache::new(url)
.await
.map_err(|e| format!("failed to connect to Redis session cache: {e}"))?;
Ok(Arc::new(cache))
}
}
}
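For reference, a configuration fragment selecting the Redis backend might look like the following. This is an illustrative sketch: the exact key names (`routing`, `session_max_entries`, `session_cache`, `type`, `url`) are inferred from the field accesses and error strings above, not from a documented schema.

```yaml
routing:
  session_max_entries: 5000   # capped at 10_000 by init_session_cache
  session_cache:
    type: redis               # "memory" (default) or "redis"
    url: redis://localhost:6379
```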
