mirror of
https://github.com/katanemo/plano.git
synced 2026-06-14 15:15:15 +02:00
merge origin/main, add DigitalOcean alongside Vercel and OpenRouter
This commit is contained in:
commit
013f377ddf
138 changed files with 17041 additions and 3335 deletions
372
crates/Cargo.lock
generated
372
crates/Cargo.lock
generated
|
|
@ -23,6 +23,18 @@ version = "0.3.8"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e8fd72866655d1904d6b0997d0b07ba561047d070fbe29de039031c641b61217"
|
||||
|
||||
[[package]]
|
||||
name = "ahash"
|
||||
version = "0.8.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"once_cell",
|
||||
"version_check",
|
||||
"zerocopy",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aho-corasick"
|
||||
version = "1.1.4"
|
||||
|
|
@ -257,6 +269,24 @@ dependencies = [
|
|||
"vsimd",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bindgen"
|
||||
version = "0.72.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "993776b509cfb49c750f11b8f07a46fa23e0a1386ffc01fb1e7d343efc387895"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"cexpr",
|
||||
"clang-sys",
|
||||
"itertools 0.13.0",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"regex",
|
||||
"rustc-hash 2.1.2",
|
||||
"shlex",
|
||||
"syn 2.0.117",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bit-set"
|
||||
version = "0.5.3"
|
||||
|
|
@ -316,6 +346,9 @@ dependencies = [
|
|||
"hyper 1.9.0",
|
||||
"hyper-util",
|
||||
"lru",
|
||||
"metrics 0.23.1",
|
||||
"metrics-exporter-prometheus",
|
||||
"metrics-process",
|
||||
"mockito",
|
||||
"opentelemetry",
|
||||
"opentelemetry-http",
|
||||
|
|
@ -325,6 +358,7 @@ dependencies = [
|
|||
"pretty_assertions",
|
||||
"rand 0.9.4",
|
||||
"redis",
|
||||
"regex",
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
|
|
@ -332,6 +366,8 @@ dependencies = [
|
|||
"serde_yaml",
|
||||
"strsim",
|
||||
"thiserror 2.0.18",
|
||||
"tikv-jemalloc-ctl",
|
||||
"tikv-jemallocator",
|
||||
"time",
|
||||
"tokio",
|
||||
"tokio-postgres",
|
||||
|
|
@ -391,6 +427,15 @@ dependencies = [
|
|||
"shlex",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cexpr"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766"
|
||||
dependencies = [
|
||||
"nom",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cfg-if"
|
||||
version = "1.0.4"
|
||||
|
|
@ -428,6 +473,17 @@ dependencies = [
|
|||
"windows-link",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clang-sys"
|
||||
version = "1.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4"
|
||||
dependencies = [
|
||||
"glob",
|
||||
"libc",
|
||||
"libloading",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cmov"
|
||||
version = "0.5.3"
|
||||
|
|
@ -574,6 +630,21 @@ dependencies = [
|
|||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-epoch"
|
||||
version = "0.9.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
|
||||
dependencies = [
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-utils"
|
||||
version = "0.8.21"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
|
||||
|
||||
[[package]]
|
||||
name = "crypto-common"
|
||||
version = "0.1.7"
|
||||
|
|
@ -1070,6 +1141,12 @@ dependencies = [
|
|||
"wasip3",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "glob"
|
||||
version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280"
|
||||
|
||||
[[package]]
|
||||
name = "governor"
|
||||
version = "0.6.3"
|
||||
|
|
@ -1128,7 +1205,7 @@ version = "0.8.2"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e91b62f79061a0bc2e046024cb7ba44b08419ed238ecbd9adbd787434b9e8c25"
|
||||
dependencies = [
|
||||
"ahash",
|
||||
"ahash 0.3.8",
|
||||
"autocfg",
|
||||
]
|
||||
|
||||
|
|
@ -1138,6 +1215,15 @@ version = "0.12.3"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
version = "0.14.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
|
||||
dependencies = [
|
||||
"ahash 0.8.12",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
version = "0.15.5"
|
||||
|
|
@ -1189,6 +1275,12 @@ dependencies = [
|
|||
"uuid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hermit-abi"
|
||||
version = "0.5.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c"
|
||||
|
||||
[[package]]
|
||||
name = "hex"
|
||||
version = "0.4.3"
|
||||
|
|
@ -1665,6 +1757,27 @@ version = "0.2.185"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "52ff2c0fe9bc6cb6b14a0592c2ff4fa9ceb83eea9db979b0487cd054946a2b8f"
|
||||
|
||||
[[package]]
|
||||
name = "libloading"
|
||||
version = "0.8.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d7c4b02199fee7c5d21a5ae7d8cfa79a6ef5bb2fc834d6e9058e89c825efdc55"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"windows-link",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "libproc"
|
||||
version = "0.14.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a54ad7278b8bc5301d5ffd2a94251c004feb971feba96c971ea4063645990757"
|
||||
dependencies = [
|
||||
"bindgen",
|
||||
"errno",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "libredox"
|
||||
version = "0.1.16"
|
||||
|
|
@ -1745,6 +1858,12 @@ version = "0.1.2"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154"
|
||||
|
||||
[[package]]
|
||||
name = "mach2"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dae608c151f68243f2b000364e1f7b186d9c29845f7d2d85bd31b9ad77ad552b"
|
||||
|
||||
[[package]]
|
||||
name = "matchers"
|
||||
version = "0.2.0"
|
||||
|
|
@ -1782,6 +1901,77 @@ version = "2.8.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79"
|
||||
|
||||
[[package]]
|
||||
name = "metrics"
|
||||
version = "0.23.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3045b4193fbdc5b5681f32f11070da9be3609f189a79f3390706d42587f46bb5"
|
||||
dependencies = [
|
||||
"ahash 0.8.12",
|
||||
"portable-atomic",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "metrics"
|
||||
version = "0.24.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5d5312e9ba3771cfa961b585728215e3d972c950a3eed9252aa093d6301277e8"
|
||||
dependencies = [
|
||||
"ahash 0.8.12",
|
||||
"portable-atomic",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "metrics-exporter-prometheus"
|
||||
version = "0.15.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b4f0c8427b39666bf970460908b213ec09b3b350f20c0c2eabcbba51704a08e6"
|
||||
dependencies = [
|
||||
"base64 0.22.1",
|
||||
"http-body-util",
|
||||
"hyper 1.9.0",
|
||||
"hyper-util",
|
||||
"indexmap 2.14.0",
|
||||
"ipnet",
|
||||
"metrics 0.23.1",
|
||||
"metrics-util",
|
||||
"quanta",
|
||||
"thiserror 1.0.69",
|
||||
"tokio",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "metrics-process"
|
||||
version = "2.4.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4268d87f64a752f5a651314fc683f04da10be65701ea3e721ba4d74f79163cac"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"libproc",
|
||||
"mach2",
|
||||
"metrics 0.24.3",
|
||||
"once_cell",
|
||||
"procfs",
|
||||
"rlimit",
|
||||
"windows",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "metrics-util"
|
||||
version = "0.17.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4259040465c955f9f2f1a4a8a16dc46726169bca0f88e8fb2dbeced487c3e828"
|
||||
dependencies = [
|
||||
"crossbeam-epoch",
|
||||
"crossbeam-utils",
|
||||
"hashbrown 0.14.5",
|
||||
"metrics 0.23.1",
|
||||
"num_cpus",
|
||||
"quanta",
|
||||
"sketches-ddsketch",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mime"
|
||||
version = "0.3.17"
|
||||
|
|
@ -1935,6 +2125,16 @@ dependencies = [
|
|||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num_cpus"
|
||||
version = "1.17.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "91df4bbde75afed763b708b7eee1e8e7651e02d97f6d5dd763e89367e957b23b"
|
||||
dependencies = [
|
||||
"hermit-abi",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "objc2-core-foundation"
|
||||
version = "0.3.2"
|
||||
|
|
@ -2125,6 +2325,12 @@ dependencies = [
|
|||
"windows-link",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "paste"
|
||||
version = "1.0.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
|
||||
|
||||
[[package]]
|
||||
name = "percent-encoding"
|
||||
version = "2.3.2"
|
||||
|
|
@ -2278,6 +2484,27 @@ dependencies = [
|
|||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "procfs"
|
||||
version = "0.18.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "25485360a54d6861439d60facef26de713b1e126bf015ec8f98239467a2b82f7"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"procfs-core",
|
||||
"rustix",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "procfs-core"
|
||||
version = "0.18.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e6401bf7b6af22f78b563665d15a22e9aef27775b79b149a66ca022468a4e405"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"hex",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "prompt_gateway"
|
||||
version = "0.1.0"
|
||||
|
|
@ -2333,6 +2560,21 @@ dependencies = [
|
|||
"log",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quanta"
|
||||
version = "0.12.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f3ab5a9d756f0d97bdc89019bd2e4ea098cf9cde50ee7564dde6b81ccc8f06c7"
|
||||
dependencies = [
|
||||
"crossbeam-utils",
|
||||
"libc",
|
||||
"once_cell",
|
||||
"raw-cpuid",
|
||||
"wasi 0.11.1+wasi-snapshot-preview1",
|
||||
"web-sys",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quinn"
|
||||
version = "0.11.9"
|
||||
|
|
@ -2485,6 +2727,15 @@ version = "0.10.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "63b8176103e19a2643978565ca18b50549f6101881c443590420e4dc998a3c69"
|
||||
|
||||
[[package]]
|
||||
name = "raw-cpuid"
|
||||
version = "11.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "redis"
|
||||
version = "0.27.6"
|
||||
|
|
@ -2646,6 +2897,15 @@ dependencies = [
|
|||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rlimit"
|
||||
version = "0.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f35ee2729c56bb610f6dba436bf78135f728b7373bdffae2ec815b2d3eb98cc3"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustc-hash"
|
||||
version = "1.1.0"
|
||||
|
|
@ -3098,6 +3358,12 @@ version = "1.0.2"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b2aa850e253778c88a04c3d7323b043aeda9d3e30d5971937c1855769763678e"
|
||||
|
||||
[[package]]
|
||||
name = "sketches-ddsketch"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "85636c14b73d81f541e525f585c0a2109e6744e1565b5c1668e31c70c10ed65c"
|
||||
|
||||
[[package]]
|
||||
name = "slab"
|
||||
version = "0.4.12"
|
||||
|
|
@ -3308,6 +3574,37 @@ dependencies = [
|
|||
"rustc-hash 1.1.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tikv-jemalloc-ctl"
|
||||
version = "0.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "661f1f6a57b3a36dc9174a2c10f19513b4866816e13425d3e418b11cc37bc24c"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"paste",
|
||||
"tikv-jemalloc-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tikv-jemalloc-sys"
|
||||
version = "0.6.1+5.3.0-1-ge13ca993e8ccb9ba9847cc330696e02839f328f7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cd8aa5b2ab86a2cefa406d889139c162cbb230092f7d1d7cbc1716405d852a3b"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tikv-jemallocator"
|
||||
version = "0.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0359b4327f954e0567e69fb191cf1436617748813819c94b8cd4a431422d053a"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"tikv-jemalloc-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "time"
|
||||
version = "0.3.47"
|
||||
|
|
@ -4003,6 +4300,49 @@ dependencies = [
|
|||
"web-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winapi"
|
||||
version = "0.3.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
|
||||
dependencies = [
|
||||
"winapi-i686-pc-windows-gnu",
|
||||
"winapi-x86_64-pc-windows-gnu",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winapi-i686-pc-windows-gnu"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
|
||||
|
||||
[[package]]
|
||||
name = "winapi-x86_64-pc-windows-gnu"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
|
||||
|
||||
[[package]]
|
||||
name = "windows"
|
||||
version = "0.62.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "527fadee13e0c05939a6a05d5bd6eec6cd2e3dbd648b9f8e447c6518133d8580"
|
||||
dependencies = [
|
||||
"windows-collections",
|
||||
"windows-core",
|
||||
"windows-future",
|
||||
"windows-numerics",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-collections"
|
||||
version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "23b2d95af1a8a14a3c7367e1ed4fc9c20e0a26e79551b1454d72583c97cc6610"
|
||||
dependencies = [
|
||||
"windows-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-core"
|
||||
version = "0.62.2"
|
||||
|
|
@ -4016,6 +4356,17 @@ dependencies = [
|
|||
"windows-strings",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-future"
|
||||
version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e1d6f90251fe18a279739e78025bd6ddc52a7e22f921070ccdc67dde84c605cb"
|
||||
dependencies = [
|
||||
"windows-core",
|
||||
"windows-link",
|
||||
"windows-threading",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-implement"
|
||||
version = "0.60.2"
|
||||
|
|
@ -4044,6 +4395,16 @@ version = "0.2.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5"
|
||||
|
||||
[[package]]
|
||||
name = "windows-numerics"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6e2e40844ac143cdb44aead537bbf727de9b044e107a0f1220392177d15b0f26"
|
||||
dependencies = [
|
||||
"windows-core",
|
||||
"windows-link",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-registry"
|
||||
version = "0.6.1"
|
||||
|
|
@ -4133,6 +4494,15 @@ dependencies = [
|
|||
"windows_x86_64_msvc 0.53.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-threading"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3949bd5b99cafdf1c7ca86b43ca564028dfe27d66958f2470940f73d86d75b37"
|
||||
dependencies = [
|
||||
"windows-link",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_gnullvm"
|
||||
version = "0.52.6"
|
||||
|
|
|
|||
|
|
@ -3,6 +3,18 @@ name = "brightstaff"
|
|||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[features]
|
||||
default = ["jemalloc"]
|
||||
jemalloc = ["tikv-jemallocator", "tikv-jemalloc-ctl"]
|
||||
|
||||
[[bin]]
|
||||
name = "brightstaff"
|
||||
path = "src/main.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "signals_replay"
|
||||
path = "src/bin/signals_replay.rs"
|
||||
|
||||
[dependencies]
|
||||
async-openai = "0.30.1"
|
||||
async-trait = "0.1"
|
||||
|
|
@ -26,7 +38,11 @@ opentelemetry-stdout = "0.31"
|
|||
opentelemetry_sdk = { version = "0.31", features = ["rt-tokio"] }
|
||||
pretty_assertions = "1.4.1"
|
||||
rand = "0.9.2"
|
||||
regex = "1.10"
|
||||
lru = "0.12"
|
||||
metrics = "0.23"
|
||||
metrics-exporter-prometheus = { version = "0.15", default-features = false, features = ["http-listener"] }
|
||||
metrics-process = "2.1"
|
||||
redis = { version = "0.27", features = ["tokio-comp"] }
|
||||
reqwest = { version = "0.12.15", features = ["stream"] }
|
||||
serde = { version = "1.0.219", features = ["derive"] }
|
||||
|
|
@ -35,6 +51,8 @@ serde_with = "3.13.0"
|
|||
strsim = "0.11"
|
||||
serde_yaml = "0.9.34"
|
||||
thiserror = "2.0.12"
|
||||
tikv-jemallocator = { version = "0.6", optional = true }
|
||||
tikv-jemalloc-ctl = { version = "0.6", features = ["stats"], optional = true }
|
||||
tokio = { version = "1.44.2", features = ["full"] }
|
||||
tokio-postgres = { version = "0.7", features = ["with-serde_json-1"] }
|
||||
tokio-stream = "0.1"
|
||||
|
|
|
|||
|
|
@ -24,4 +24,7 @@ pub struct AppState {
|
|||
/// Shared HTTP client for upstream LLM requests (connection pooling / keep-alive).
|
||||
pub http_client: reqwest::Client,
|
||||
pub filter_pipeline: Arc<FilterPipeline>,
|
||||
/// When false, agentic signal analysis is skipped on LLM responses to save CPU.
|
||||
/// Controlled by `overrides.disable_signals` in plano config.
|
||||
pub signals_enabled: bool,
|
||||
}
|
||||
|
|
|
|||
175
crates/brightstaff/src/bin/signals_replay.rs
Normal file
175
crates/brightstaff/src/bin/signals_replay.rs
Normal file
|
|
@ -0,0 +1,175 @@
|
|||
//! `signals-replay` — batch driver for the `brightstaff` signal analyzer.
|
||||
//!
|
||||
//! Reads JSONL conversations from stdin (one per line) and emits matching
|
||||
//! JSONL reports on stdout, one per input conversation, in the same order.
|
||||
//!
|
||||
//! Input shape (per line):
|
||||
//! ```json
|
||||
//! {"id": "convo-42", "messages": [{"from": "human", "value": "..."}, ...]}
|
||||
//! ```
|
||||
//!
|
||||
//! Output shape (per line, success):
|
||||
//! ```json
|
||||
//! {"id": "convo-42", "report": { ...python-compatible SignalReport dict... }}
|
||||
//! ```
|
||||
//!
|
||||
//! On per-line failure (parse / analyzer error), emits:
|
||||
//! ```json
|
||||
//! {"id": "convo-42", "error": "..."}
|
||||
//! ```
|
||||
//!
|
||||
//! The output report dict is shaped to match the Python reference's
|
||||
//! `SignalReport.to_dict()` byte-for-byte so the parity comparator can do a
|
||||
//! direct structural diff.
|
||||
|
||||
use std::io::{self, BufRead, BufWriter, Write};
|
||||
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Map, Value};
|
||||
|
||||
use brightstaff::signals::{SignalAnalyzer, SignalGroup, SignalReport};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct InputLine {
|
||||
id: Value,
|
||||
messages: Vec<MessageRow>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct MessageRow {
|
||||
#[serde(default)]
|
||||
from: String,
|
||||
#[serde(default)]
|
||||
value: String,
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let stdin = io::stdin();
|
||||
let stdout = io::stdout();
|
||||
let mut out = BufWriter::new(stdout.lock());
|
||||
let analyzer = SignalAnalyzer::default();
|
||||
|
||||
for line in stdin.lock().lines() {
|
||||
let line = match line {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
eprintln!("read error: {e}");
|
||||
std::process::exit(1);
|
||||
}
|
||||
};
|
||||
let trimmed = line.trim();
|
||||
if trimmed.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let result = process_line(&analyzer, trimmed);
|
||||
// Always emit one line per input line so id ordering stays aligned.
|
||||
if let Err(e) = writeln!(out, "{result}") {
|
||||
eprintln!("write error: {e}");
|
||||
std::process::exit(1);
|
||||
}
|
||||
// Flush periodically isn't strictly needed — BufWriter handles it,
|
||||
// and the parent process reads the whole stream when we're done.
|
||||
}
|
||||
let _ = out.flush();
|
||||
}
|
||||
|
||||
fn process_line(analyzer: &SignalAnalyzer, line: &str) -> Value {
|
||||
let parsed: InputLine = match serde_json::from_str(line) {
|
||||
Ok(p) => p,
|
||||
Err(e) => {
|
||||
return json!({
|
||||
"id": Value::Null,
|
||||
"error": format!("input parse: {e}"),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let id = parsed.id.clone();
|
||||
|
||||
let view: Vec<brightstaff::signals::analyzer::ShareGptMessage<'_>> = parsed
|
||||
.messages
|
||||
.iter()
|
||||
.map(|m| brightstaff::signals::analyzer::ShareGptMessage {
|
||||
from: m.from.as_str(),
|
||||
value: m.value.as_str(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
let report = analyzer.analyze_sharegpt(&view);
|
||||
let report_dict = report_to_python_dict(&report);
|
||||
json!({
|
||||
"id": id,
|
||||
"report": report_dict,
|
||||
})
|
||||
}
|
||||
|
||||
/// Convert a `SignalReport` into the Python reference's `to_dict()` shape.
|
||||
///
|
||||
/// Ordering of category keys in each layer dict follows the Python source
|
||||
/// exactly so even string-equality comparisons behave deterministically.
|
||||
fn report_to_python_dict(r: &SignalReport) -> Value {
|
||||
let mut interaction = Map::new();
|
||||
interaction.insert(
|
||||
"misalignment".to_string(),
|
||||
signal_group_to_python(&r.interaction.misalignment),
|
||||
);
|
||||
interaction.insert(
|
||||
"stagnation".to_string(),
|
||||
signal_group_to_python(&r.interaction.stagnation),
|
||||
);
|
||||
interaction.insert(
|
||||
"disengagement".to_string(),
|
||||
signal_group_to_python(&r.interaction.disengagement),
|
||||
);
|
||||
interaction.insert(
|
||||
"satisfaction".to_string(),
|
||||
signal_group_to_python(&r.interaction.satisfaction),
|
||||
);
|
||||
|
||||
let mut execution = Map::new();
|
||||
execution.insert(
|
||||
"failure".to_string(),
|
||||
signal_group_to_python(&r.execution.failure),
|
||||
);
|
||||
execution.insert(
|
||||
"loops".to_string(),
|
||||
signal_group_to_python(&r.execution.loops),
|
||||
);
|
||||
|
||||
let mut environment = Map::new();
|
||||
environment.insert(
|
||||
"exhaustion".to_string(),
|
||||
signal_group_to_python(&r.environment.exhaustion),
|
||||
);
|
||||
|
||||
json!({
|
||||
"interaction_signals": Value::Object(interaction),
|
||||
"execution_signals": Value::Object(execution),
|
||||
"environment_signals": Value::Object(environment),
|
||||
"overall_quality": r.overall_quality.as_str(),
|
||||
"summary": r.summary,
|
||||
})
|
||||
}
|
||||
|
||||
fn signal_group_to_python(g: &SignalGroup) -> Value {
|
||||
let signals: Vec<Value> = g
|
||||
.signals
|
||||
.iter()
|
||||
.map(|s| {
|
||||
json!({
|
||||
"signal_type": s.signal_type.as_str(),
|
||||
"message_index": s.message_index,
|
||||
"snippet": s.snippet,
|
||||
"confidence": s.confidence,
|
||||
"metadata": s.metadata,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
json!({
|
||||
"category": g.category,
|
||||
"count": g.count,
|
||||
"severity": g.severity,
|
||||
"signals": signals,
|
||||
})
|
||||
}
|
||||
53
crates/brightstaff/src/handlers/debug.rs
Normal file
53
crates/brightstaff/src/handlers/debug.rs
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
use bytes::Bytes;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use hyper::{Response, StatusCode};
|
||||
|
||||
use super::full;
|
||||
|
||||
#[derive(serde::Serialize)]
|
||||
struct MemStats {
|
||||
allocated_bytes: usize,
|
||||
resident_bytes: usize,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
error: Option<String>,
|
||||
}
|
||||
|
||||
/// Returns jemalloc memory statistics as JSON.
|
||||
/// Falls back to a stub when the jemalloc feature is disabled.
|
||||
pub async fn memstats() -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let stats = get_jemalloc_stats();
|
||||
let json = serde_json::to_string(&stats).unwrap();
|
||||
Ok(Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(full(json))
|
||||
.unwrap())
|
||||
}
|
||||
|
||||
#[cfg(feature = "jemalloc")]
|
||||
fn get_jemalloc_stats() -> MemStats {
|
||||
use tikv_jemalloc_ctl::{epoch, stats};
|
||||
|
||||
if let Err(e) = epoch::advance() {
|
||||
return MemStats {
|
||||
allocated_bytes: 0,
|
||||
resident_bytes: 0,
|
||||
error: Some(format!("failed to advance jemalloc epoch: {e}")),
|
||||
};
|
||||
}
|
||||
|
||||
MemStats {
|
||||
allocated_bytes: stats::allocated::read().unwrap_or(0),
|
||||
resident_bytes: stats::resident::read().unwrap_or(0),
|
||||
error: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "jemalloc"))]
|
||||
fn get_jemalloc_stats() -> MemStats {
|
||||
MemStats {
|
||||
allocated_bytes: 0,
|
||||
resident_bytes: 0,
|
||||
error: Some("jemalloc feature not enabled".to_string()),
|
||||
}
|
||||
}
|
||||
|
|
@ -441,10 +441,8 @@ impl ArchFunctionHandler {
|
|||
}
|
||||
}
|
||||
// Handle str/string conversions
|
||||
"str" | "string" => {
|
||||
if !value.is_string() {
|
||||
return Ok(json!(value.to_string()));
|
||||
}
|
||||
"str" | "string" if !value.is_string() => {
|
||||
return Ok(json!(value.to_string()));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -24,16 +24,18 @@ use crate::app_state::AppState;
|
|||
use crate::handlers::agents::pipeline::PipelineProcessor;
|
||||
use crate::handlers::extract_request_id;
|
||||
use crate::handlers::full;
|
||||
use crate::metrics as bs_metrics;
|
||||
use crate::state::response_state_processor::ResponsesStateProcessor;
|
||||
use crate::state::{
|
||||
extract_input_items, retrieve_and_combine_input, StateStorage, StateStorageError,
|
||||
};
|
||||
use crate::streaming::{
|
||||
create_streaming_response, create_streaming_response_with_output_filter, truncate_message,
|
||||
ObservableStreamProcessor, StreamProcessor,
|
||||
LlmMetricsCtx, ObservableStreamProcessor, StreamProcessor,
|
||||
};
|
||||
use crate::tracing::{
|
||||
collect_custom_trace_attributes, llm as tracing_llm, operation_component, set_service_name,
|
||||
collect_custom_trace_attributes, llm as tracing_llm, operation_component,
|
||||
plano as tracing_plano, set_service_name,
|
||||
};
|
||||
use model_selection::router_chat_get_upstream_model;
|
||||
|
||||
|
|
@ -102,15 +104,36 @@ async fn llm_chat_inner(
|
|||
.and_then(|hdr| request_headers.get(hdr))
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.to_string());
|
||||
let pinned_model: Option<String> = if let Some(ref sid) = session_id {
|
||||
let cached_route = if let Some(ref sid) = session_id {
|
||||
state
|
||||
.orchestrator_service
|
||||
.get_cached_route(sid, tenant_id.as_deref())
|
||||
.await
|
||||
.map(|c| c.model_name)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let (pinned_model, pinned_route_name): (Option<String>, Option<String>) = match cached_route {
|
||||
Some(c) => (Some(c.model_name), c.route_name),
|
||||
None => (None, None),
|
||||
};
|
||||
|
||||
// Record session id on the LLM span for the observability console.
|
||||
if let Some(ref sid) = session_id {
|
||||
get_active_span(|span| {
|
||||
span.set_attribute(opentelemetry::KeyValue::new(
|
||||
tracing_plano::SESSION_ID,
|
||||
sid.clone(),
|
||||
));
|
||||
});
|
||||
}
|
||||
if let Some(ref route_name) = pinned_route_name {
|
||||
get_active_span(|span| {
|
||||
span.set_attribute(opentelemetry::KeyValue::new(
|
||||
tracing_plano::ROUTE_NAME,
|
||||
route_name.clone(),
|
||||
));
|
||||
});
|
||||
}
|
||||
|
||||
let full_qualified_llm_provider_url = format!("{}{}", state.llm_provider_url, request_path);
|
||||
|
||||
|
|
@ -120,6 +143,7 @@ async fn llm_chat_inner(
|
|||
&request_path,
|
||||
&state.model_aliases,
|
||||
&state.llm_providers,
|
||||
state.signals_enabled,
|
||||
)
|
||||
.await
|
||||
{
|
||||
|
|
@ -311,6 +335,18 @@ async fn llm_chat_inner(
|
|||
alias_resolved_model.clone()
|
||||
};
|
||||
|
||||
// Record route name on the LLM span (only when the orchestrator produced one).
|
||||
if let Some(ref rn) = route_name {
|
||||
if !rn.is_empty() && rn != "none" {
|
||||
get_active_span(|span| {
|
||||
span.set_attribute(opentelemetry::KeyValue::new(
|
||||
tracing_plano::ROUTE_NAME,
|
||||
rn.clone(),
|
||||
));
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref sid) = session_id {
|
||||
state
|
||||
.orchestrator_service
|
||||
|
|
@ -373,6 +409,7 @@ async fn parse_and_validate_request(
|
|||
request_path: &str,
|
||||
model_aliases: &Option<HashMap<String, ModelAlias>>,
|
||||
llm_providers: &Arc<RwLock<LlmProviders>>,
|
||||
signals_enabled: bool,
|
||||
) -> Result<PreparedRequest, Response<BoxBody<Bytes, hyper::Error>>> {
|
||||
let raw_bytes = request
|
||||
.collect()
|
||||
|
|
@ -451,7 +488,11 @@ async fn parse_and_validate_request(
|
|||
let user_message_preview = client_request
|
||||
.get_recent_user_message()
|
||||
.map(|msg| truncate_message(&msg, 50));
|
||||
let messages_for_signals = Some(client_request.get_messages());
|
||||
let messages_for_signals = if signals_enabled {
|
||||
Some(client_request.get_messages())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Set the upstream model name and strip routing metadata
|
||||
client_request.set_model(model_name_only.clone());
|
||||
|
|
@ -652,6 +693,13 @@ async fn send_upstream(
|
|||
|
||||
let request_start_time = std::time::Instant::now();
|
||||
|
||||
// Labels for LLM upstream metrics. We prefer `resolved_model` (post-routing)
|
||||
// and derive the provider from its `provider/model` prefix. This matches the
|
||||
// same model id the cost/latency router keys off.
|
||||
let (metric_provider_raw, metric_model_raw) = bs_metrics::split_provider_model(resolved_model);
|
||||
let metric_provider = metric_provider_raw.to_string();
|
||||
let metric_model = metric_model_raw.to_string();
|
||||
|
||||
let llm_response = match http_client
|
||||
.post(upstream_url)
|
||||
.headers(request_headers.clone())
|
||||
|
|
@ -661,6 +709,14 @@ async fn send_upstream(
|
|||
{
|
||||
Ok(res) => res,
|
||||
Err(err) => {
|
||||
let err_class = bs_metrics::llm_error_class_from_reqwest(&err);
|
||||
bs_metrics::record_llm_upstream(
|
||||
&metric_provider,
|
||||
&metric_model,
|
||||
0,
|
||||
err_class,
|
||||
request_start_time.elapsed(),
|
||||
);
|
||||
let err_msg = format!("Failed to send request: {}", err);
|
||||
let mut internal_error = Response::new(full(err_msg));
|
||||
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
|
|
@ -671,6 +727,36 @@ async fn send_upstream(
|
|||
// Propagate upstream headers and status
|
||||
let response_headers = llm_response.headers().clone();
|
||||
let upstream_status = llm_response.status();
|
||||
|
||||
// Upstream routers (e.g. DigitalOcean Gradient) may return an
|
||||
// `x-model-router-selected-route` header indicating which task-level
|
||||
// route the request was classified into (e.g. "Code Generation"). Surface
|
||||
// it as `plano.route.name` so the obs console's Route hit % panel can
|
||||
// show the breakdown even when Plano's own orchestrator wasn't in the
|
||||
// routing path. Any value from Plano's orchestrator already set earlier
|
||||
// takes precedence — this only fires when the span doesn't already have
|
||||
// a route name.
|
||||
if let Some(upstream_route) = response_headers
|
||||
.get("x-model-router-selected-route")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
{
|
||||
if !upstream_route.is_empty() {
|
||||
get_active_span(|span| {
|
||||
span.set_attribute(opentelemetry::KeyValue::new(
|
||||
crate::tracing::plano::ROUTE_NAME,
|
||||
upstream_route.to_string(),
|
||||
));
|
||||
});
|
||||
}
|
||||
}
|
||||
// Record the upstream HTTP status on the span for the obs console.
|
||||
get_active_span(|span| {
|
||||
span.set_attribute(opentelemetry::KeyValue::new(
|
||||
crate::tracing::http::STATUS_CODE,
|
||||
upstream_status.as_u16() as i64,
|
||||
));
|
||||
});
|
||||
|
||||
let mut response = Response::builder().status(upstream_status);
|
||||
if let Some(headers) = response.headers_mut() {
|
||||
for (name, value) in response_headers.iter() {
|
||||
|
|
@ -686,7 +772,12 @@ async fn send_upstream(
|
|||
span_name,
|
||||
request_start_time,
|
||||
messages_for_signals,
|
||||
);
|
||||
)
|
||||
.with_llm_metrics(LlmMetricsCtx {
|
||||
provider: metric_provider.clone(),
|
||||
model: metric_model.clone(),
|
||||
upstream_status: upstream_status.as_u16(),
|
||||
});
|
||||
|
||||
let output_filter_request_headers = if filter_pipeline.has_output_filters() {
|
||||
Some(request_headers.clone())
|
||||
|
|
|
|||
|
|
@ -5,10 +5,24 @@ use hyper::StatusCode;
|
|||
use std::sync::Arc;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::metrics as bs_metrics;
|
||||
use crate::metrics::labels as metric_labels;
|
||||
use crate::router::orchestrator::OrchestratorService;
|
||||
use crate::streaming::truncate_message;
|
||||
use crate::tracing::routing;
|
||||
|
||||
/// Classify a request path (already stripped of `/agents` or `/routing` by
|
||||
/// the caller) into the fixed `route` label used on routing metrics.
|
||||
fn route_label_for_path(request_path: &str) -> &'static str {
|
||||
if request_path.starts_with("/agents") {
|
||||
metric_labels::ROUTE_AGENT
|
||||
} else if request_path.starts_with("/routing") {
|
||||
metric_labels::ROUTE_ROUTING
|
||||
} else {
|
||||
metric_labels::ROUTE_LLM
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RoutingResult {
|
||||
/// Primary model to use (first in the ranked list).
|
||||
pub model_name: String,
|
||||
|
|
@ -106,15 +120,23 @@ pub async fn router_chat_get_upstream_model(
|
|||
)
|
||||
.await;
|
||||
|
||||
let determination_ms = routing_start_time.elapsed().as_millis() as i64;
|
||||
let determination_elapsed = routing_start_time.elapsed();
|
||||
let determination_ms = determination_elapsed.as_millis() as i64;
|
||||
let current_span = tracing::Span::current();
|
||||
current_span.record(routing::ROUTE_DETERMINATION_MS, determination_ms);
|
||||
let route_label = route_label_for_path(request_path);
|
||||
|
||||
match routing_result {
|
||||
Ok(route) => match route {
|
||||
Some((route_name, ranked_models)) => {
|
||||
let model_name = ranked_models.first().cloned().unwrap_or_default();
|
||||
current_span.record("route.selected_model", model_name.as_str());
|
||||
bs_metrics::record_router_decision(
|
||||
route_label,
|
||||
&model_name,
|
||||
false,
|
||||
determination_elapsed,
|
||||
);
|
||||
Ok(RoutingResult {
|
||||
model_name,
|
||||
models: ranked_models,
|
||||
|
|
@ -126,6 +148,12 @@ pub async fn router_chat_get_upstream_model(
|
|||
// This signals to llm.rs to use the original validated request model
|
||||
current_span.record("route.selected_model", "none");
|
||||
info!("no route determined, using default model");
|
||||
bs_metrics::record_router_decision(
|
||||
route_label,
|
||||
"none",
|
||||
true,
|
||||
determination_elapsed,
|
||||
);
|
||||
|
||||
Ok(RoutingResult {
|
||||
model_name: "none".to_string(),
|
||||
|
|
@ -136,6 +164,7 @@ pub async fn router_chat_get_upstream_model(
|
|||
},
|
||||
Err(err) => {
|
||||
current_span.record("route.selected_model", "unknown");
|
||||
bs_metrics::record_router_decision(route_label, "unknown", true, determination_elapsed);
|
||||
Err(RoutingError::internal_error(format!(
|
||||
"Failed to determine route: {}",
|
||||
err
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
pub mod agents;
|
||||
pub mod debug;
|
||||
pub mod function_calling;
|
||||
pub mod llm;
|
||||
pub mod models;
|
||||
|
|
|
|||
|
|
@ -12,6 +12,8 @@ use tracing::{debug, info, info_span, warn, Instrument};
|
|||
|
||||
use super::extract_or_generate_traceparent;
|
||||
use crate::handlers::llm::model_selection::router_chat_get_upstream_model;
|
||||
use crate::metrics as bs_metrics;
|
||||
use crate::metrics::labels as metric_labels;
|
||||
use crate::router::orchestrator::OrchestratorService;
|
||||
use crate::tracing::{collect_custom_trace_attributes, operation_component, set_service_name};
|
||||
|
||||
|
|
@ -230,6 +232,17 @@ async fn routing_decision_inner(
|
|||
pinned: false,
|
||||
};
|
||||
|
||||
// Distinguish "decision served" (a concrete model picked) from
|
||||
// "no_candidates" (the sentinel "none" returned when nothing
|
||||
// matched). The handler still responds 200 in both cases, so RED
|
||||
// metrics alone can't tell them apart.
|
||||
let outcome = if response.models.first().map(|m| m == "none").unwrap_or(true) {
|
||||
metric_labels::ROUTING_SVC_NO_CANDIDATES
|
||||
} else {
|
||||
metric_labels::ROUTING_SVC_DECISION_SERVED
|
||||
};
|
||||
bs_metrics::record_routing_service_outcome(outcome);
|
||||
|
||||
info!(
|
||||
primary_model = %response.models.first().map(|s| s.as_str()).unwrap_or("none"),
|
||||
total_models = response.models.len(),
|
||||
|
|
@ -249,6 +262,7 @@ async fn routing_decision_inner(
|
|||
.unwrap())
|
||||
}
|
||||
Err(err) => {
|
||||
bs_metrics::record_routing_service_outcome(metric_labels::ROUTING_SVC_POLICY_ERROR);
|
||||
warn!(error = %err.message, "routing decision failed");
|
||||
Ok(BrightStaffError::InternalServerError(err.message).into_response())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
pub mod app_state;
|
||||
pub mod handlers;
|
||||
pub mod metrics;
|
||||
pub mod router;
|
||||
pub mod session_cache;
|
||||
pub mod signals;
|
||||
|
|
|
|||
|
|
@ -1,10 +1,17 @@
|
|||
#[cfg(feature = "jemalloc")]
|
||||
#[global_allocator]
|
||||
static ALLOC: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
|
||||
|
||||
use brightstaff::app_state::AppState;
|
||||
use brightstaff::handlers::agents::orchestrator::agent_chat;
|
||||
use brightstaff::handlers::debug;
|
||||
use brightstaff::handlers::empty;
|
||||
use brightstaff::handlers::function_calling::function_calling_chat_handler;
|
||||
use brightstaff::handlers::llm::llm_chat;
|
||||
use brightstaff::handlers::models::list_models;
|
||||
use brightstaff::handlers::routing_service::routing_decision;
|
||||
use brightstaff::metrics as bs_metrics;
|
||||
use brightstaff::metrics::labels as metric_labels;
|
||||
use brightstaff::router::model_metrics::ModelMetricsService;
|
||||
use brightstaff::router::orchestrator::OrchestratorService;
|
||||
use brightstaff::session_cache::init_session_cache;
|
||||
|
|
@ -326,6 +333,8 @@ async fn init_app_state(
|
|||
.as_ref()
|
||||
.and_then(|tracing| tracing.span_attributes.clone());
|
||||
|
||||
let signals_enabled = !overrides.disable_signals.unwrap_or(false);
|
||||
|
||||
Ok(AppState {
|
||||
orchestrator_service,
|
||||
model_aliases: config.model_aliases.clone(),
|
||||
|
|
@ -337,6 +346,7 @@ async fn init_app_state(
|
|||
span_attributes,
|
||||
http_client: reqwest::Client::new(),
|
||||
filter_pipeline,
|
||||
signals_enabled,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -384,10 +394,79 @@ async fn init_state_storage(
|
|||
// Request routing
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Normalized method label — limited set so we never emit a free-form string.
|
||||
fn method_label(method: &Method) -> &'static str {
|
||||
match *method {
|
||||
Method::GET => "GET",
|
||||
Method::POST => "POST",
|
||||
Method::PUT => "PUT",
|
||||
Method::DELETE => "DELETE",
|
||||
Method::PATCH => "PATCH",
|
||||
Method::HEAD => "HEAD",
|
||||
Method::OPTIONS => "OPTIONS",
|
||||
_ => "OTHER",
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the fixed `handler` metric label from the request's path+method.
|
||||
/// Returning `None` for fall-through means `route()` will hand the request to
|
||||
/// the catch-all 404 branch.
|
||||
fn handler_label_for(method: &Method, path: &str) -> &'static str {
|
||||
if let Some(stripped) = path.strip_prefix("/agents") {
|
||||
if matches!(
|
||||
stripped,
|
||||
CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH
|
||||
) {
|
||||
return metric_labels::HANDLER_AGENT_CHAT;
|
||||
}
|
||||
}
|
||||
if let Some(stripped) = path.strip_prefix("/routing") {
|
||||
if matches!(
|
||||
stripped,
|
||||
CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH
|
||||
) {
|
||||
return metric_labels::HANDLER_ROUTING_DECISION;
|
||||
}
|
||||
}
|
||||
match (method, path) {
|
||||
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH) => {
|
||||
metric_labels::HANDLER_LLM_CHAT
|
||||
}
|
||||
(&Method::POST, "/function_calling") => metric_labels::HANDLER_FUNCTION_CALLING,
|
||||
(&Method::GET, "/v1/models" | "/agents/v1/models") => metric_labels::HANDLER_LIST_MODELS,
|
||||
(&Method::OPTIONS, "/v1/models" | "/agents/v1/models") => {
|
||||
metric_labels::HANDLER_CORS_PREFLIGHT
|
||||
}
|
||||
_ => metric_labels::HANDLER_NOT_FOUND,
|
||||
}
|
||||
}
|
||||
|
||||
/// Route an incoming HTTP request to the appropriate handler.
|
||||
async fn route(
|
||||
req: Request<Incoming>,
|
||||
state: Arc<AppState>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let handler = handler_label_for(req.method(), req.uri().path());
|
||||
let method = method_label(req.method());
|
||||
let started = std::time::Instant::now();
|
||||
let _in_flight = bs_metrics::InFlightGuard::new(handler);
|
||||
|
||||
let result = dispatch(req, state).await;
|
||||
|
||||
let status = match &result {
|
||||
Ok(resp) => resp.status().as_u16(),
|
||||
// hyper::Error here means the body couldn't be produced; conventionally 500.
|
||||
Err(_) => 500,
|
||||
};
|
||||
bs_metrics::record_http(handler, method, status, started);
|
||||
result
|
||||
}
|
||||
|
||||
/// Inner dispatcher split out so `route()` can wrap it with metrics without
|
||||
/// duplicating the match tree.
|
||||
async fn dispatch(
|
||||
req: Request<Incoming>,
|
||||
state: Arc<AppState>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let parent_cx = global::get_text_map_propagator(|p| p.extract(&HeaderExtractor(req.headers())));
|
||||
let path = req.uri().path().to_string();
|
||||
|
|
@ -439,6 +518,7 @@ async fn route(
|
|||
Ok(list_models(Arc::clone(&state.llm_providers)).await)
|
||||
}
|
||||
(&Method::OPTIONS, "/v1/models" | "/agents/v1/models") => cors_preflight(),
|
||||
(&Method::GET, "/debug/memstats") => debug::memstats().await,
|
||||
_ => {
|
||||
debug!(method = %req.method(), path = %path, "no route found");
|
||||
let mut not_found = Response::new(empty());
|
||||
|
|
@ -503,6 +583,7 @@ async fn run_server(state: Arc<AppState>) -> Result<(), Box<dyn std::error::Erro
|
|||
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let config = load_config()?;
|
||||
let _tracer_provider = init_tracer(config.tracing.as_ref());
|
||||
bs_metrics::init();
|
||||
info!("loaded plano_config.yaml");
|
||||
let state = Arc::new(init_app_state(&config).await?);
|
||||
run_server(state).await
|
||||
|
|
|
|||
38
crates/brightstaff/src/metrics/labels.rs
Normal file
38
crates/brightstaff/src/metrics/labels.rs
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
//! Fixed label-value constants so callers never emit free-form strings
|
||||
//! (which would blow up cardinality).
|
||||
|
||||
// Handler enum — derived from the path+method match in `route()`.
|
||||
pub const HANDLER_AGENT_CHAT: &str = "agent_chat";
|
||||
pub const HANDLER_ROUTING_DECISION: &str = "routing_decision";
|
||||
pub const HANDLER_LLM_CHAT: &str = "llm_chat";
|
||||
pub const HANDLER_FUNCTION_CALLING: &str = "function_calling";
|
||||
pub const HANDLER_LIST_MODELS: &str = "list_models";
|
||||
pub const HANDLER_CORS_PREFLIGHT: &str = "cors_preflight";
|
||||
pub const HANDLER_NOT_FOUND: &str = "not_found";
|
||||
|
||||
// Router "route" class — which brightstaff endpoint prompted the decision.
|
||||
pub const ROUTE_AGENT: &str = "agent";
|
||||
pub const ROUTE_ROUTING: &str = "routing";
|
||||
pub const ROUTE_LLM: &str = "llm";
|
||||
|
||||
// Token kind for brightstaff_llm_tokens_total.
|
||||
pub const TOKEN_KIND_PROMPT: &str = "prompt";
|
||||
pub const TOKEN_KIND_COMPLETION: &str = "completion";
|
||||
|
||||
// LLM error_class values (match docstring in metrics/mod.rs).
|
||||
pub const LLM_ERR_NONE: &str = "none";
|
||||
pub const LLM_ERR_TIMEOUT: &str = "timeout";
|
||||
pub const LLM_ERR_CONNECT: &str = "connect";
|
||||
pub const LLM_ERR_PARSE: &str = "parse";
|
||||
pub const LLM_ERR_OTHER: &str = "other";
|
||||
pub const LLM_ERR_STREAM: &str = "stream";
|
||||
|
||||
// Routing service outcome values.
|
||||
pub const ROUTING_SVC_DECISION_SERVED: &str = "decision_served";
|
||||
pub const ROUTING_SVC_NO_CANDIDATES: &str = "no_candidates";
|
||||
pub const ROUTING_SVC_POLICY_ERROR: &str = "policy_error";
|
||||
|
||||
// Session cache outcome values.
|
||||
pub const SESSION_CACHE_HIT: &str = "hit";
|
||||
pub const SESSION_CACHE_MISS: &str = "miss";
|
||||
pub const SESSION_CACHE_STORE: &str = "store";
|
||||
377
crates/brightstaff/src/metrics/mod.rs
Normal file
377
crates/brightstaff/src/metrics/mod.rs
Normal file
|
|
@ -0,0 +1,377 @@
|
|||
//! Prometheus metrics for brightstaff.
|
||||
//!
|
||||
//! Installs the `metrics` global recorder backed by
|
||||
//! `metrics-exporter-prometheus` and exposes a `/metrics` HTTP endpoint on a
|
||||
//! dedicated admin port (default `0.0.0.0:9092`, overridable via
|
||||
//! `METRICS_BIND_ADDRESS`).
|
||||
//!
|
||||
//! Emitted metric families (see `describe_all` for full list):
|
||||
//! - HTTP RED: `brightstaff_http_requests_total`,
|
||||
//! `brightstaff_http_request_duration_seconds`,
|
||||
//! `brightstaff_http_in_flight_requests`.
|
||||
//! - LLM upstream: `brightstaff_llm_upstream_requests_total`,
|
||||
//! `brightstaff_llm_upstream_duration_seconds`,
|
||||
//! `brightstaff_llm_time_to_first_token_seconds`,
|
||||
//! `brightstaff_llm_tokens_total`,
|
||||
//! `brightstaff_llm_tokens_usage_missing_total`.
|
||||
//! - Routing: `brightstaff_router_decisions_total`,
|
||||
//! `brightstaff_router_decision_duration_seconds`,
|
||||
//! `brightstaff_routing_service_requests_total`,
|
||||
//! `brightstaff_session_cache_events_total`.
|
||||
//! - Process: via `metrics-process`.
|
||||
//! - Build: `brightstaff_build_info`.
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::OnceLock;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use metrics::{counter, describe_counter, describe_gauge, describe_histogram, gauge, histogram};
|
||||
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder};
|
||||
use tracing::{info, warn};
|
||||
|
||||
pub mod labels;
|
||||
|
||||
/// Guard flag so tests don't re-install the global recorder.
|
||||
static INIT: OnceLock<()> = OnceLock::new();
|
||||
|
||||
const DEFAULT_METRICS_BIND: &str = "0.0.0.0:9092";
|
||||
|
||||
/// HTTP request duration buckets (seconds). Capped at 60s.
|
||||
const HTTP_BUCKETS: &[f64] = &[
|
||||
0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 30.0, 60.0,
|
||||
];
|
||||
|
||||
/// LLM upstream / TTFT buckets (seconds). Capped at 120s because provider
|
||||
/// completions routinely run that long.
|
||||
const LLM_BUCKETS: &[f64] = &[0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 30.0, 60.0, 120.0];
|
||||
|
||||
/// Router decision buckets (seconds). The orchestrator call itself is usually
|
||||
/// sub-second but bucketed generously in case of upstream slowness.
|
||||
const ROUTER_BUCKETS: &[f64] = &[
|
||||
0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 30.0,
|
||||
];
|
||||
|
||||
/// Install the global recorder and spawn the `/metrics` HTTP listener.
|
||||
///
|
||||
/// Safe to call more than once; subsequent calls are no-ops so tests that
|
||||
/// construct their own recorder still work.
|
||||
pub fn init() {
|
||||
if INIT.get().is_some() {
|
||||
return;
|
||||
}
|
||||
|
||||
let bind: SocketAddr = std::env::var("METRICS_BIND_ADDRESS")
|
||||
.unwrap_or_else(|_| DEFAULT_METRICS_BIND.to_string())
|
||||
.parse()
|
||||
.unwrap_or_else(|err| {
|
||||
warn!(error = %err, default = DEFAULT_METRICS_BIND, "invalid METRICS_BIND_ADDRESS, falling back to default");
|
||||
DEFAULT_METRICS_BIND.parse().expect("default bind parses")
|
||||
});
|
||||
|
||||
let builder = PrometheusBuilder::new()
|
||||
.with_http_listener(bind)
|
||||
.set_buckets_for_metric(
|
||||
Matcher::Full("brightstaff_http_request_duration_seconds".to_string()),
|
||||
HTTP_BUCKETS,
|
||||
)
|
||||
.and_then(|b| {
|
||||
b.set_buckets_for_metric(Matcher::Prefix("brightstaff_llm_".to_string()), LLM_BUCKETS)
|
||||
})
|
||||
.and_then(|b| {
|
||||
b.set_buckets_for_metric(
|
||||
Matcher::Full("brightstaff_router_decision_duration_seconds".to_string()),
|
||||
ROUTER_BUCKETS,
|
||||
)
|
||||
});
|
||||
|
||||
let builder = match builder {
|
||||
Ok(b) => b,
|
||||
Err(err) => {
|
||||
warn!(error = %err, "failed to configure metrics buckets, using defaults");
|
||||
PrometheusBuilder::new().with_http_listener(bind)
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(err) = builder.install() {
|
||||
warn!(error = %err, "failed to install Prometheus recorder; metrics disabled");
|
||||
return;
|
||||
}
|
||||
|
||||
let _ = INIT.set(());
|
||||
|
||||
describe_all();
|
||||
emit_build_info();
|
||||
|
||||
// Register process-level collector (RSS, CPU, FDs).
|
||||
let collector = metrics_process::Collector::default();
|
||||
collector.describe();
|
||||
// Prime once at startup; subsequent scrapes refresh via the exporter's
|
||||
// per-scrape render, so we additionally refresh on a short interval to
|
||||
// keep gauges moving between scrapes without requiring client pull.
|
||||
collector.collect();
|
||||
tokio::spawn(async move {
|
||||
let mut tick = tokio::time::interval(Duration::from_secs(10));
|
||||
tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
|
||||
loop {
|
||||
tick.tick().await;
|
||||
collector.collect();
|
||||
}
|
||||
});
|
||||
|
||||
info!(address = %bind, "metrics listener started");
|
||||
}
|
||||
|
||||
fn describe_all() {
|
||||
describe_counter!(
|
||||
"brightstaff_http_requests_total",
|
||||
"Total HTTP requests served by brightstaff, by handler and status class."
|
||||
);
|
||||
describe_histogram!(
|
||||
"brightstaff_http_request_duration_seconds",
|
||||
"Wall-clock duration of HTTP requests served by brightstaff, by handler."
|
||||
);
|
||||
describe_gauge!(
|
||||
"brightstaff_http_in_flight_requests",
|
||||
"Number of HTTP requests currently being served by brightstaff, by handler."
|
||||
);
|
||||
|
||||
describe_counter!(
|
||||
"brightstaff_llm_upstream_requests_total",
|
||||
"LLM upstream request outcomes, by provider, model, status class and error class."
|
||||
);
|
||||
describe_histogram!(
|
||||
"brightstaff_llm_upstream_duration_seconds",
|
||||
"Wall-clock duration of LLM upstream calls (stream close for streaming), by provider and model."
|
||||
);
|
||||
describe_histogram!(
|
||||
"brightstaff_llm_time_to_first_token_seconds",
|
||||
"Time from request start to first streamed byte, by provider and model (streaming only)."
|
||||
);
|
||||
describe_counter!(
|
||||
"brightstaff_llm_tokens_total",
|
||||
"Tokens reported in the provider `usage` field, by provider, model and kind (prompt/completion)."
|
||||
);
|
||||
describe_counter!(
|
||||
"brightstaff_llm_tokens_usage_missing_total",
|
||||
"LLM responses that completed without a usable `usage` block (so token counts are unknown)."
|
||||
);
|
||||
|
||||
describe_counter!(
|
||||
"brightstaff_router_decisions_total",
|
||||
"Routing decisions made by the orchestrator, by route, selected model, and whether a fallback was used."
|
||||
);
|
||||
describe_histogram!(
|
||||
"brightstaff_router_decision_duration_seconds",
|
||||
"Time spent in the orchestrator deciding a route, by route."
|
||||
);
|
||||
describe_counter!(
|
||||
"brightstaff_routing_service_requests_total",
|
||||
"Outcomes of /routing/* decision requests: decision_served, no_candidates, policy_error."
|
||||
);
|
||||
describe_counter!(
|
||||
"brightstaff_session_cache_events_total",
|
||||
"Session affinity cache lookups and stores, by outcome."
|
||||
);
|
||||
|
||||
describe_gauge!(
|
||||
"brightstaff_build_info",
|
||||
"Build metadata. Always 1; labels carry version and git SHA."
|
||||
);
|
||||
}
|
||||
|
||||
fn emit_build_info() {
|
||||
let version = env!("CARGO_PKG_VERSION");
|
||||
let git_sha = option_env!("GIT_SHA").unwrap_or("unknown");
|
||||
gauge!(
|
||||
"brightstaff_build_info",
|
||||
"version" => version.to_string(),
|
||||
"git_sha" => git_sha.to_string(),
|
||||
)
|
||||
.set(1.0);
|
||||
}
|
||||
|
||||
/// Split a provider-qualified model id like `"openai/gpt-4o"` into
|
||||
/// `(provider, model)`. Returns `("unknown", raw)` when there is no `/`.
|
||||
pub fn split_provider_model(full: &str) -> (&str, &str) {
|
||||
match full.split_once('/') {
|
||||
Some((p, m)) => (p, m),
|
||||
None => ("unknown", full),
|
||||
}
|
||||
}
|
||||
|
||||
/// Bucket an HTTP status code into `"2xx"` / `"4xx"` / `"5xx"` / `"1xx"` / `"3xx"`.
|
||||
pub fn status_class(status: u16) -> &'static str {
|
||||
match status {
|
||||
100..=199 => "1xx",
|
||||
200..=299 => "2xx",
|
||||
300..=399 => "3xx",
|
||||
400..=499 => "4xx",
|
||||
500..=599 => "5xx",
|
||||
_ => "other",
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HTTP RED helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// RAII guard that increments the in-flight gauge on construction and
|
||||
/// decrements on drop. Pair with [`HttpTimer`] in the `route()` wrapper so the
|
||||
/// gauge drops even on error paths.
|
||||
pub struct InFlightGuard {
|
||||
handler: &'static str,
|
||||
}
|
||||
|
||||
impl InFlightGuard {
|
||||
pub fn new(handler: &'static str) -> Self {
|
||||
gauge!(
|
||||
"brightstaff_http_in_flight_requests",
|
||||
"handler" => handler,
|
||||
)
|
||||
.increment(1.0);
|
||||
Self { handler }
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for InFlightGuard {
|
||||
fn drop(&mut self) {
|
||||
gauge!(
|
||||
"brightstaff_http_in_flight_requests",
|
||||
"handler" => self.handler,
|
||||
)
|
||||
.decrement(1.0);
|
||||
}
|
||||
}
|
||||
|
||||
/// Record the HTTP request counter + duration histogram.
|
||||
pub fn record_http(handler: &'static str, method: &'static str, status: u16, started: Instant) {
|
||||
let class = status_class(status);
|
||||
counter!(
|
||||
"brightstaff_http_requests_total",
|
||||
"handler" => handler,
|
||||
"method" => method,
|
||||
"status_class" => class,
|
||||
)
|
||||
.increment(1);
|
||||
histogram!(
|
||||
"brightstaff_http_request_duration_seconds",
|
||||
"handler" => handler,
|
||||
)
|
||||
.record(started.elapsed().as_secs_f64());
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// LLM upstream helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Classify an outcome of an LLM upstream call for the `error_class` label.
|
||||
pub fn llm_error_class_from_reqwest(err: &reqwest::Error) -> &'static str {
|
||||
if err.is_timeout() {
|
||||
"timeout"
|
||||
} else if err.is_connect() {
|
||||
"connect"
|
||||
} else if err.is_decode() {
|
||||
"parse"
|
||||
} else {
|
||||
"other"
|
||||
}
|
||||
}
|
||||
|
||||
/// Record the outcome of an LLM upstream call. `status` is the HTTP status
|
||||
/// the upstream returned (0 if the call never produced one, e.g. send failure).
|
||||
/// `error_class` is `"none"` on success, or a discriminated error label.
|
||||
pub fn record_llm_upstream(
|
||||
provider: &str,
|
||||
model: &str,
|
||||
status: u16,
|
||||
error_class: &str,
|
||||
duration: Duration,
|
||||
) {
|
||||
let class = if status == 0 {
|
||||
"error"
|
||||
} else {
|
||||
status_class(status)
|
||||
};
|
||||
counter!(
|
||||
"brightstaff_llm_upstream_requests_total",
|
||||
"provider" => provider.to_string(),
|
||||
"model" => model.to_string(),
|
||||
"status_class" => class,
|
||||
"error_class" => error_class.to_string(),
|
||||
)
|
||||
.increment(1);
|
||||
histogram!(
|
||||
"brightstaff_llm_upstream_duration_seconds",
|
||||
"provider" => provider.to_string(),
|
||||
"model" => model.to_string(),
|
||||
)
|
||||
.record(duration.as_secs_f64());
|
||||
}
|
||||
|
||||
pub fn record_llm_ttft(provider: &str, model: &str, ttft: Duration) {
|
||||
histogram!(
|
||||
"brightstaff_llm_time_to_first_token_seconds",
|
||||
"provider" => provider.to_string(),
|
||||
"model" => model.to_string(),
|
||||
)
|
||||
.record(ttft.as_secs_f64());
|
||||
}
|
||||
|
||||
pub fn record_llm_tokens(provider: &str, model: &str, kind: &'static str, count: u64) {
|
||||
counter!(
|
||||
"brightstaff_llm_tokens_total",
|
||||
"provider" => provider.to_string(),
|
||||
"model" => model.to_string(),
|
||||
"kind" => kind,
|
||||
)
|
||||
.increment(count);
|
||||
}
|
||||
|
||||
pub fn record_llm_tokens_usage_missing(provider: &str, model: &str) {
|
||||
counter!(
|
||||
"brightstaff_llm_tokens_usage_missing_total",
|
||||
"provider" => provider.to_string(),
|
||||
"model" => model.to_string(),
|
||||
)
|
||||
.increment(1);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Router helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub fn record_router_decision(
|
||||
route: &'static str,
|
||||
selected_model: &str,
|
||||
fallback: bool,
|
||||
duration: Duration,
|
||||
) {
|
||||
counter!(
|
||||
"brightstaff_router_decisions_total",
|
||||
"route" => route,
|
||||
"selected_model" => selected_model.to_string(),
|
||||
"fallback" => if fallback { "true" } else { "false" },
|
||||
)
|
||||
.increment(1);
|
||||
histogram!(
|
||||
"brightstaff_router_decision_duration_seconds",
|
||||
"route" => route,
|
||||
)
|
||||
.record(duration.as_secs_f64());
|
||||
}
|
||||
|
||||
pub fn record_routing_service_outcome(outcome: &'static str) {
|
||||
counter!(
|
||||
"brightstaff_routing_service_requests_total",
|
||||
"outcome" => outcome,
|
||||
)
|
||||
.increment(1);
|
||||
}
|
||||
|
||||
pub fn record_session_cache_event(outcome: &'static str) {
|
||||
counter!(
|
||||
"brightstaff_session_cache_events_total",
|
||||
"outcome" => outcome,
|
||||
)
|
||||
.increment(1);
|
||||
}
|
||||
|
|
@ -1,8 +1,14 @@
|
|||
use hermesllm::apis::openai::ChatCompletionsResponse;
|
||||
use hyper::header;
|
||||
use serde::Deserialize;
|
||||
use thiserror::Error;
|
||||
use tracing::warn;
|
||||
|
||||
/// Max bytes of raw upstream body we include in a log message or error text
|
||||
/// when the body is not a recognizable error envelope. Keeps logs from being
|
||||
/// flooded by huge HTML error pages.
|
||||
const RAW_BODY_LOG_LIMIT: usize = 512;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum HttpError {
|
||||
#[error("Failed to send request: {0}")]
|
||||
|
|
@ -10,13 +16,64 @@ pub enum HttpError {
|
|||
|
||||
#[error("Failed to parse JSON response: {0}")]
|
||||
Json(serde_json::Error, String),
|
||||
|
||||
#[error("Upstream returned {status}: {message}")]
|
||||
Upstream { status: u16, message: String },
|
||||
}
|
||||
|
||||
/// Shape of an OpenAI-style error response body, e.g.
|
||||
/// `{"error": {"message": "...", "type": "...", "param": "...", "code": ...}}`.
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct UpstreamErrorEnvelope {
|
||||
error: UpstreamErrorBody,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct UpstreamErrorBody {
|
||||
message: String,
|
||||
#[serde(default, rename = "type")]
|
||||
err_type: Option<String>,
|
||||
#[serde(default)]
|
||||
param: Option<String>,
|
||||
}
|
||||
|
||||
/// Extract a human-readable error message from an upstream response body.
|
||||
/// Tries to parse an OpenAI-style `{"error": {"message": ...}}` envelope; if
|
||||
/// that fails, falls back to the first `RAW_BODY_LOG_LIMIT` bytes of the raw
|
||||
/// body (UTF-8 safe).
|
||||
fn extract_upstream_error_message(body: &str) -> String {
|
||||
if let Ok(env) = serde_json::from_str::<UpstreamErrorEnvelope>(body) {
|
||||
let mut msg = env.error.message;
|
||||
if let Some(param) = env.error.param {
|
||||
msg.push_str(&format!(" (param={param})"));
|
||||
}
|
||||
if let Some(err_type) = env.error.err_type {
|
||||
msg.push_str(&format!(" [type={err_type}]"));
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
truncate_for_log(body).to_string()
|
||||
}
|
||||
|
||||
fn truncate_for_log(s: &str) -> &str {
|
||||
if s.len() <= RAW_BODY_LOG_LIMIT {
|
||||
return s;
|
||||
}
|
||||
let mut end = RAW_BODY_LOG_LIMIT;
|
||||
while end > 0 && !s.is_char_boundary(end) {
|
||||
end -= 1;
|
||||
}
|
||||
&s[..end]
|
||||
}
|
||||
|
||||
/// Sends a POST request to the given URL and extracts the text content
|
||||
/// from the first choice of the `ChatCompletionsResponse`.
|
||||
///
|
||||
/// Returns `Some((content, elapsed))` on success, or `None` if the response
|
||||
/// had no choices or the first choice had no content.
|
||||
/// Returns `Some((content, elapsed))` on success, `None` if the response
|
||||
/// had no choices or the first choice had no content. Returns
|
||||
/// `HttpError::Upstream` for any non-2xx status, carrying a message
|
||||
/// extracted from the OpenAI-style error envelope (or a truncated raw body
|
||||
/// if the body is not in that shape).
|
||||
pub async fn post_and_extract_content(
|
||||
client: &reqwest::Client,
|
||||
url: &str,
|
||||
|
|
@ -26,17 +83,36 @@ pub async fn post_and_extract_content(
|
|||
let start_time = std::time::Instant::now();
|
||||
|
||||
let res = client.post(url).headers(headers).body(body).send().await?;
|
||||
let status = res.status();
|
||||
|
||||
let body = res.text().await?;
|
||||
let elapsed = start_time.elapsed();
|
||||
|
||||
if !status.is_success() {
|
||||
let message = extract_upstream_error_message(&body);
|
||||
warn!(
|
||||
status = status.as_u16(),
|
||||
message = %message,
|
||||
body_size = body.len(),
|
||||
"upstream returned error response"
|
||||
);
|
||||
return Err(HttpError::Upstream {
|
||||
status: status.as_u16(),
|
||||
message,
|
||||
});
|
||||
}
|
||||
|
||||
let response: ChatCompletionsResponse = serde_json::from_str(&body).map_err(|err| {
|
||||
warn!(error = %err, body = %body, "failed to parse json response");
|
||||
warn!(
|
||||
error = %err,
|
||||
body = %truncate_for_log(&body),
|
||||
"failed to parse json response",
|
||||
);
|
||||
HttpError::Json(err, format!("Failed to parse JSON: {}", body))
|
||||
})?;
|
||||
|
||||
if response.choices.is_empty() {
|
||||
warn!(body = %body, "no choices in response");
|
||||
warn!(body = %truncate_for_log(&body), "no choices in response");
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
|
|
@ -46,3 +122,52 @@ pub async fn post_and_extract_content(
|
|||
.as_ref()
|
||||
.map(|c| (c.clone(), elapsed)))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn extracts_message_from_openai_style_error_envelope() {
|
||||
let body = r#"{"error":{"code":400,"message":"This model's maximum context length is 32768 tokens. However, you requested 0 output tokens and your prompt contains at least 32769 input tokens, for a total of at least 32769 tokens.","param":"input_tokens","type":"BadRequestError"}}"#;
|
||||
let msg = extract_upstream_error_message(body);
|
||||
assert!(
|
||||
msg.starts_with("This model's maximum context length is 32768 tokens."),
|
||||
"unexpected message: {msg}"
|
||||
);
|
||||
assert!(msg.contains("(param=input_tokens)"));
|
||||
assert!(msg.contains("[type=BadRequestError]"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extracts_message_without_optional_fields() {
|
||||
let body = r#"{"error":{"message":"something broke"}}"#;
|
||||
let msg = extract_upstream_error_message(body);
|
||||
assert_eq!(msg, "something broke");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn falls_back_to_raw_body_when_not_error_envelope() {
|
||||
let body = "<html><body>502 Bad Gateway</body></html>";
|
||||
let msg = extract_upstream_error_message(body);
|
||||
assert_eq!(msg, body);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncates_non_envelope_bodies_in_logs() {
|
||||
let body = "x".repeat(RAW_BODY_LOG_LIMIT * 3);
|
||||
let msg = extract_upstream_error_message(&body);
|
||||
assert_eq!(msg.len(), RAW_BODY_LOG_LIMIT);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncate_for_log_respects_utf8_boundaries() {
|
||||
// 2-byte characters; picking a length that would split mid-char.
|
||||
let body = "é".repeat(RAW_BODY_LOG_LIMIT);
|
||||
let out = truncate_for_log(&body);
|
||||
// Should be a valid &str (implicit — would panic if we returned
|
||||
// a non-boundary slice) and at most RAW_BODY_LOG_LIMIT bytes.
|
||||
assert!(out.len() <= RAW_BODY_LOG_LIMIT);
|
||||
assert!(out.chars().all(|c| c == 'é'));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,3 +3,5 @@ pub mod model_metrics;
|
|||
pub mod orchestrator;
|
||||
pub mod orchestrator_model;
|
||||
pub mod orchestrator_model_v1;
|
||||
#[cfg(test)]
|
||||
mod stress_tests;
|
||||
|
|
|
|||
|
|
@ -15,6 +15,8 @@ use super::http::{self, post_and_extract_content};
|
|||
use super::model_metrics::ModelMetricsService;
|
||||
use super::orchestrator_model::OrchestratorModel;
|
||||
|
||||
use crate::metrics as bs_metrics;
|
||||
use crate::metrics::labels as metric_labels;
|
||||
use crate::router::orchestrator_model_v1;
|
||||
use crate::session_cache::SessionCache;
|
||||
|
||||
|
|
@ -130,7 +132,13 @@ impl OrchestratorService {
|
|||
tenant_id: Option<&str>,
|
||||
) -> Option<CachedRoute> {
|
||||
let cache = self.session_cache.as_ref()?;
|
||||
cache.get(&Self::session_key(tenant_id, session_id)).await
|
||||
let result = cache.get(&Self::session_key(tenant_id, session_id)).await;
|
||||
bs_metrics::record_session_cache_event(if result.is_some() {
|
||||
metric_labels::SESSION_CACHE_HIT
|
||||
} else {
|
||||
metric_labels::SESSION_CACHE_MISS
|
||||
});
|
||||
result
|
||||
}
|
||||
|
||||
pub async fn cache_route(
|
||||
|
|
@ -151,6 +159,7 @@ impl OrchestratorService {
|
|||
self.session_ttl,
|
||||
)
|
||||
.await;
|
||||
bs_metrics::record_session_cache_event(metric_labels::SESSION_CACHE_STORE);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,18 @@ use super::orchestrator_model::{OrchestratorModel, OrchestratorModelError};
|
|||
|
||||
pub const MAX_TOKEN_LEN: usize = 8192; // Default max token length for the orchestration model
|
||||
|
||||
/// Hard cap on the number of recent messages considered when building the
|
||||
/// routing prompt. Bounds prompt growth for long-running conversations and
|
||||
/// acts as an outer guardrail before the token-budget loop runs. The most
|
||||
/// recent `MAX_ROUTING_TURNS` filtered messages are kept; older turns are
|
||||
/// dropped entirely.
|
||||
pub const MAX_ROUTING_TURNS: usize = 16;
|
||||
|
||||
/// Unicode ellipsis used to mark where content was trimmed out of a long
|
||||
/// message. Helps signal to the downstream router model that the message was
|
||||
/// truncated.
|
||||
const TRIM_MARKER: &str = "…";
|
||||
|
||||
/// Custom JSON formatter that produces spaced JSON (space after colons and commas), same as JSON in python
|
||||
struct SpacedJsonFormatter;
|
||||
|
||||
|
|
@ -176,10 +188,9 @@ impl OrchestratorModel for OrchestratorModelV1 {
|
|||
messages: &[Message],
|
||||
usage_preferences_from_request: &Option<Vec<AgentUsagePreference>>,
|
||||
) -> ChatCompletionsRequest {
|
||||
// remove system prompt, tool calls, tool call response and messages without content
|
||||
// if content is empty its likely a tool call
|
||||
// when role == tool its tool call response
|
||||
let messages_vec = messages
|
||||
// Remove system/developer/tool messages and messages without extractable
|
||||
// text (tool calls have no text content we can classify against).
|
||||
let filtered: Vec<&Message> = messages
|
||||
.iter()
|
||||
.filter(|m| {
|
||||
m.role != Role::System
|
||||
|
|
@ -187,37 +198,72 @@ impl OrchestratorModel for OrchestratorModelV1 {
|
|||
&& m.role != Role::Tool
|
||||
&& !m.content.extract_text().is_empty()
|
||||
})
|
||||
.collect::<Vec<&Message>>();
|
||||
.collect();
|
||||
|
||||
// Following code is to ensure that the conversation does not exceed max token length
|
||||
// Note: we use a simple heuristic to estimate token count based on character length to optimize for performance
|
||||
// Outer guardrail: only consider the last `MAX_ROUTING_TURNS` filtered
|
||||
// messages when building the routing prompt. Keeps prompt growth
|
||||
// predictable for long conversations regardless of per-message size.
|
||||
let start = filtered.len().saturating_sub(MAX_ROUTING_TURNS);
|
||||
let messages_vec: &[&Message] = &filtered[start..];
|
||||
|
||||
// Ensure the conversation does not exceed the configured token budget.
|
||||
// We use `len() / TOKEN_LENGTH_DIVISOR` as a cheap token estimate to
|
||||
// avoid running a real tokenizer on the hot path.
|
||||
let mut token_count = ARCH_ORCHESTRATOR_V1_SYSTEM_PROMPT.len() / TOKEN_LENGTH_DIVISOR;
|
||||
let mut selected_messages_list_reversed: Vec<&Message> = vec![];
|
||||
let mut selected_messages_list_reversed: Vec<Message> = vec![];
|
||||
for (selected_messsage_count, message) in messages_vec.iter().rev().enumerate() {
|
||||
let message_token_count = message.content.extract_text().len() / TOKEN_LENGTH_DIVISOR;
|
||||
token_count += message_token_count;
|
||||
if token_count > self.max_token_length {
|
||||
let message_text = message.content.extract_text();
|
||||
let message_token_count = message_text.len() / TOKEN_LENGTH_DIVISOR;
|
||||
if token_count + message_token_count > self.max_token_length {
|
||||
let remaining_tokens = self.max_token_length.saturating_sub(token_count);
|
||||
debug!(
|
||||
token_count = token_count,
|
||||
attempted_total_tokens = token_count + message_token_count,
|
||||
max_tokens = self.max_token_length,
|
||||
remaining_tokens,
|
||||
selected = selected_messsage_count,
|
||||
total = messages_vec.len(),
|
||||
"token count exceeds max, truncating conversation"
|
||||
);
|
||||
if message.role == Role::User {
|
||||
// If message that exceeds max token length is from user, we need to keep it
|
||||
selected_messages_list_reversed.push(message);
|
||||
// If the overflow message is from the user we need to keep
|
||||
// some of it so the orchestrator still sees the latest user
|
||||
// intent. Use a middle-trim (head + ellipsis + tail): users
|
||||
// often frame the task at the start AND put the actual ask
|
||||
// at the end of a long pasted block, so preserving both is
|
||||
// better than a head-only cut. The ellipsis also signals to
|
||||
// the router model that content was dropped.
|
||||
if message.role == Role::User && remaining_tokens > 0 {
|
||||
let max_bytes = remaining_tokens.saturating_mul(TOKEN_LENGTH_DIVISOR);
|
||||
let truncated = trim_middle_utf8(&message_text, max_bytes);
|
||||
selected_messages_list_reversed.push(Message {
|
||||
role: Role::User,
|
||||
content: Some(MessageContent::Text(truncated)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
}
|
||||
break;
|
||||
}
|
||||
// If we are here, it means that the message is within the max token length
|
||||
selected_messages_list_reversed.push(message);
|
||||
token_count += message_token_count;
|
||||
selected_messages_list_reversed.push(Message {
|
||||
role: message.role.clone(),
|
||||
content: Some(MessageContent::Text(message_text)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
}
|
||||
|
||||
if selected_messages_list_reversed.is_empty() {
|
||||
debug!("no messages selected, using last message");
|
||||
if let Some(last_message) = messages_vec.last() {
|
||||
selected_messages_list_reversed.push(last_message);
|
||||
selected_messages_list_reversed.push(Message {
|
||||
role: last_message.role.clone(),
|
||||
content: Some(MessageContent::Text(last_message.content.extract_text())),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -237,22 +283,8 @@ impl OrchestratorModel for OrchestratorModelV1 {
|
|||
}
|
||||
|
||||
// Reverse the selected messages to maintain the conversation order
|
||||
let selected_conversation_list = selected_messages_list_reversed
|
||||
.iter()
|
||||
.rev()
|
||||
.map(|message| Message {
|
||||
role: message.role.clone(),
|
||||
content: Some(MessageContent::Text(
|
||||
message
|
||||
.content
|
||||
.as_ref()
|
||||
.map_or(String::new(), |c| c.to_string()),
|
||||
)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
})
|
||||
.collect::<Vec<Message>>();
|
||||
let selected_conversation_list: Vec<Message> =
|
||||
selected_messages_list_reversed.into_iter().rev().collect();
|
||||
|
||||
// Generate the orchestrator request message based on the usage preferences.
|
||||
// If preferences are passed in request then we use them;
|
||||
|
|
@ -405,6 +437,45 @@ fn fix_json_response(body: &str) -> String {
|
|||
body.replace("'", "\"").replace("\\n", "")
|
||||
}
|
||||
|
||||
/// Truncate `s` so the result is at most `max_bytes` bytes long, keeping
|
||||
/// roughly 60% from the start and 40% from the end, with a Unicode ellipsis
|
||||
/// separating the two. All splits respect UTF-8 character boundaries. When
|
||||
/// `max_bytes` is too small to fit the marker at all, falls back to a
|
||||
/// head-only truncation.
|
||||
fn trim_middle_utf8(s: &str, max_bytes: usize) -> String {
|
||||
if s.len() <= max_bytes {
|
||||
return s.to_string();
|
||||
}
|
||||
if max_bytes <= TRIM_MARKER.len() {
|
||||
// Not enough room even for the marker — just keep the start.
|
||||
let mut end = max_bytes;
|
||||
while end > 0 && !s.is_char_boundary(end) {
|
||||
end -= 1;
|
||||
}
|
||||
return s[..end].to_string();
|
||||
}
|
||||
|
||||
let available = max_bytes - TRIM_MARKER.len();
|
||||
// Bias toward the start (60%) where task framing typically lives, while
|
||||
// still preserving ~40% of the tail where the user's actual ask often
|
||||
// appears after a long paste.
|
||||
let mut start_len = available * 3 / 5;
|
||||
while start_len > 0 && !s.is_char_boundary(start_len) {
|
||||
start_len -= 1;
|
||||
}
|
||||
let end_len = available - start_len;
|
||||
let mut end_start = s.len().saturating_sub(end_len);
|
||||
while end_start < s.len() && !s.is_char_boundary(end_start) {
|
||||
end_start += 1;
|
||||
}
|
||||
|
||||
let mut out = String::with_capacity(start_len + TRIM_MARKER.len() + (s.len() - end_start));
|
||||
out.push_str(&s[..start_len]);
|
||||
out.push_str(TRIM_MARKER);
|
||||
out.push_str(&s[end_start..]);
|
||||
out
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for dyn OrchestratorModel {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "OrchestratorModel")
|
||||
|
|
@ -777,6 +848,10 @@ If no routes are needed, return an empty list for `route`.
|
|||
|
||||
#[test]
|
||||
fn test_conversation_trim_upto_user_message() {
|
||||
// With max_token_length=230, the older user message "given the image
|
||||
// In style of Andy Warhol" overflows the remaining budget and gets
|
||||
// middle-trimmed (head + ellipsis + tail) until it fits. Newer turns
|
||||
// are kept in full.
|
||||
let expected_prompt = r#"
|
||||
You are a helpful assistant that selects the most suitable routes based on user intent.
|
||||
You are provided with a list of available routes enclosed within <routes></routes> XML tags:
|
||||
|
|
@ -789,7 +864,7 @@ You are also given the conversation context enclosed within <conversation></conv
|
|||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "given the image In style of Andy Warhol"
|
||||
"content": "given…rhol"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
|
|
@ -862,6 +937,190 @@ If no routes are needed, return an empty list for `route`.
|
|||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_huge_single_user_message_is_middle_trimmed() {
|
||||
// Regression test for the case where a single, extremely large user
|
||||
// message was being passed to the orchestrator verbatim and blowing
|
||||
// past the upstream model's context window. The trimmer must now
|
||||
// middle-trim (head + ellipsis + tail) the oversized message so the
|
||||
// resulting request stays within the configured budget, and the
|
||||
// trim marker must be present so the router model knows content
|
||||
// was dropped.
|
||||
let orchestrations_str = r#"
|
||||
{
|
||||
"gpt-4o": [
|
||||
{"name": "Image generation", "description": "generating image"}
|
||||
]
|
||||
}
|
||||
"#;
|
||||
let agent_orchestrations = serde_json::from_str::<
|
||||
HashMap<String, Vec<OrchestrationPreference>>,
|
||||
>(orchestrations_str)
|
||||
.unwrap();
|
||||
|
||||
let max_token_length = 2048;
|
||||
let orchestrator = OrchestratorModelV1::new(
|
||||
agent_orchestrations,
|
||||
"test-model".to_string(),
|
||||
max_token_length,
|
||||
);
|
||||
|
||||
// ~500KB of content — same scale as the real payload that triggered
|
||||
// the production upstream 400.
|
||||
let head = "HEAD_MARKER_START ";
|
||||
let tail = " TAIL_MARKER_END";
|
||||
let filler = "A".repeat(500_000);
|
||||
let huge_user_content = format!("{head}{filler}{tail}");
|
||||
|
||||
let conversation = vec![Message {
|
||||
role: Role::User,
|
||||
content: Some(MessageContent::Text(huge_user_content.clone())),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}];
|
||||
|
||||
let req = orchestrator.generate_request(&conversation, &None);
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
// Prompt must stay bounded. Generous ceiling = budget-in-bytes +
|
||||
// scaffolding + slack. Real result should be well under this.
|
||||
let byte_ceiling = max_token_length * TOKEN_LENGTH_DIVISOR
|
||||
+ ARCH_ORCHESTRATOR_V1_SYSTEM_PROMPT.len()
|
||||
+ 1024;
|
||||
assert!(
|
||||
prompt.len() < byte_ceiling,
|
||||
"prompt length {} exceeded ceiling {} — truncation did not apply",
|
||||
prompt.len(),
|
||||
byte_ceiling,
|
||||
);
|
||||
|
||||
// Not all 500k filler chars survive.
|
||||
let a_count = prompt.chars().filter(|c| *c == 'A').count();
|
||||
assert!(
|
||||
a_count < filler.len(),
|
||||
"expected user message to be truncated; all {} 'A's survived",
|
||||
a_count
|
||||
);
|
||||
assert!(
|
||||
a_count > 0,
|
||||
"expected some of the user message to survive truncation"
|
||||
);
|
||||
|
||||
// Head and tail of the message must both be preserved (that's the
|
||||
// whole point of middle-trim over head-only).
|
||||
assert!(
|
||||
prompt.contains(head),
|
||||
"head marker missing — head was not preserved"
|
||||
);
|
||||
assert!(
|
||||
prompt.contains(tail),
|
||||
"tail marker missing — tail was not preserved"
|
||||
);
|
||||
|
||||
// Trim marker must be present so the router model can see that
|
||||
// content was omitted.
|
||||
assert!(
|
||||
prompt.contains(TRIM_MARKER),
|
||||
"ellipsis trim marker missing from truncated prompt"
|
||||
);
|
||||
|
||||
// Routing prompt scaffolding remains intact.
|
||||
assert!(prompt.contains("<conversation>"));
|
||||
assert!(prompt.contains("<routes>"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_turn_cap_limits_routing_history() {
|
||||
// The outer turn-cap guardrail should keep only the last
|
||||
// `MAX_ROUTING_TURNS` filtered messages regardless of how long the
|
||||
// conversation is. We build a conversation with alternating
|
||||
// user/assistant turns tagged with their index and verify that only
|
||||
// the tail of the conversation makes it into the prompt.
|
||||
let orchestrations_str = r#"
|
||||
{
|
||||
"gpt-4o": [
|
||||
{"name": "Image generation", "description": "generating image"}
|
||||
]
|
||||
}
|
||||
"#;
|
||||
let agent_orchestrations = serde_json::from_str::<
|
||||
HashMap<String, Vec<OrchestrationPreference>>,
|
||||
>(orchestrations_str)
|
||||
.unwrap();
|
||||
|
||||
let orchestrator =
|
||||
OrchestratorModelV1::new(agent_orchestrations, "test-model".to_string(), usize::MAX);
|
||||
|
||||
let mut conversation: Vec<Message> = Vec::new();
|
||||
let total_turns = MAX_ROUTING_TURNS * 2; // well past the cap
|
||||
for i in 0..total_turns {
|
||||
let role = if i % 2 == 0 {
|
||||
Role::User
|
||||
} else {
|
||||
Role::Assistant
|
||||
};
|
||||
conversation.push(Message {
|
||||
role,
|
||||
content: Some(MessageContent::Text(format!("turn-{i:03}"))),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
}
|
||||
|
||||
let req = orchestrator.generate_request(&conversation, &None);
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
// The last MAX_ROUTING_TURNS messages (indexes total-cap..total)
|
||||
// must all appear.
|
||||
for i in (total_turns - MAX_ROUTING_TURNS)..total_turns {
|
||||
let tag = format!("turn-{i:03}");
|
||||
assert!(
|
||||
prompt.contains(&tag),
|
||||
"expected recent turn tag {tag} to be present"
|
||||
);
|
||||
}
|
||||
|
||||
// And earlier turns (indexes 0..total-cap) must all be dropped.
|
||||
for i in 0..(total_turns - MAX_ROUTING_TURNS) {
|
||||
let tag = format!("turn-{i:03}");
|
||||
assert!(
|
||||
!prompt.contains(&tag),
|
||||
"old turn tag {tag} leaked past turn cap into the prompt"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_trim_middle_utf8_helper() {
|
||||
// No-op when already small enough.
|
||||
assert_eq!(trim_middle_utf8("hello", 100), "hello");
|
||||
assert_eq!(trim_middle_utf8("hello", 5), "hello");
|
||||
|
||||
// 60/40 split with ellipsis when too long.
|
||||
let long = "a".repeat(20);
|
||||
let out = trim_middle_utf8(&long, 10);
|
||||
assert!(out.len() <= 10);
|
||||
assert!(out.contains(TRIM_MARKER));
|
||||
// Exactly one ellipsis, rest are 'a's.
|
||||
assert_eq!(out.matches(TRIM_MARKER).count(), 1);
|
||||
assert!(out.chars().filter(|c| *c == 'a').count() > 0);
|
||||
|
||||
// When max_bytes is smaller than the marker, falls back to
|
||||
// head-only truncation (no marker).
|
||||
let out = trim_middle_utf8("abcdefgh", 2);
|
||||
assert_eq!(out, "ab");
|
||||
|
||||
// UTF-8 boundary safety: 2-byte chars.
|
||||
let s = "é".repeat(50); // 100 bytes
|
||||
let out = trim_middle_utf8(&s, 25);
|
||||
assert!(out.len() <= 25);
|
||||
// Must still be valid UTF-8 that only contains 'é' and the marker.
|
||||
let ok = out.chars().all(|c| c == 'é' || c == '…');
|
||||
assert!(ok, "unexpected char in trimmed output: {out:?}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_non_text_input() {
|
||||
let expected_prompt = r#"
|
||||
|
|
|
|||
264
crates/brightstaff/src/router/stress_tests.rs
Normal file
264
crates/brightstaff/src/router/stress_tests.rs
Normal file
|
|
@ -0,0 +1,264 @@
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::router::orchestrator::OrchestratorService;
|
||||
use crate::session_cache::memory::MemorySessionCache;
|
||||
use common::configuration::{SelectionPolicy, SelectionPreference, TopLevelRoutingPreference};
|
||||
use hermesllm::apis::openai::{Message, MessageContent, Role};
|
||||
use std::sync::Arc;
|
||||
|
||||
fn make_messages(n: usize) -> Vec<Message> {
|
||||
(0..n)
|
||||
.map(|i| Message {
|
||||
role: if i % 2 == 0 {
|
||||
Role::User
|
||||
} else {
|
||||
Role::Assistant
|
||||
},
|
||||
content: Some(MessageContent::Text(format!(
|
||||
"This is message number {i} with some padding text to make it realistic."
|
||||
))),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn make_routing_prefs() -> Vec<TopLevelRoutingPreference> {
|
||||
vec![
|
||||
TopLevelRoutingPreference {
|
||||
name: "code_generation".to_string(),
|
||||
description: "Code generation and debugging tasks".to_string(),
|
||||
models: vec![
|
||||
"openai/gpt-4o".to_string(),
|
||||
"openai/gpt-4o-mini".to_string(),
|
||||
],
|
||||
selection_policy: SelectionPolicy {
|
||||
prefer: SelectionPreference::None,
|
||||
},
|
||||
},
|
||||
TopLevelRoutingPreference {
|
||||
name: "summarization".to_string(),
|
||||
description: "Summarizing documents and text".to_string(),
|
||||
models: vec![
|
||||
"anthropic/claude-3-sonnet".to_string(),
|
||||
"openai/gpt-4o-mini".to_string(),
|
||||
],
|
||||
selection_policy: SelectionPolicy {
|
||||
prefer: SelectionPreference::None,
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
/// Stress test: exercise the full routing code path N times using a mock
|
||||
/// HTTP server and measure jemalloc allocated bytes before/after.
|
||||
///
|
||||
/// This catches:
|
||||
/// - Memory leaks in generate_request / parse_response
|
||||
/// - Leaks in reqwest connection handling
|
||||
/// - String accumulation in the orchestrator model
|
||||
/// - Fragmentation (jemalloc allocated vs resident)
|
||||
#[tokio::test]
|
||||
async fn stress_test_routing_determine_route() {
|
||||
let mut server = mockito::Server::new_async().await;
|
||||
let router_url = format!("{}/v1/chat/completions", server.url());
|
||||
|
||||
let mock_response = serde_json::json!({
|
||||
"id": "chatcmpl-mock",
|
||||
"object": "chat.completion",
|
||||
"created": 1234567890,
|
||||
"model": "plano-orchestrator",
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "{\"route\": \"code_generation\"}"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {"prompt_tokens": 100, "completion_tokens": 10, "total_tokens": 110}
|
||||
});
|
||||
|
||||
let _mock = server
|
||||
.mock("POST", "/v1/chat/completions")
|
||||
.with_status(200)
|
||||
.with_header("content-type", "application/json")
|
||||
.with_body(mock_response.to_string())
|
||||
.expect_at_least(1)
|
||||
.create_async()
|
||||
.await;
|
||||
|
||||
let prefs = make_routing_prefs();
|
||||
let session_cache = Arc::new(MemorySessionCache::new(1000));
|
||||
let orchestrator_service = Arc::new(OrchestratorService::with_routing(
|
||||
router_url,
|
||||
"Plano-Orchestrator".to_string(),
|
||||
"plano-orchestrator".to_string(),
|
||||
Some(prefs.clone()),
|
||||
None,
|
||||
None,
|
||||
session_cache,
|
||||
None,
|
||||
2048,
|
||||
));
|
||||
|
||||
// Warm up: a few requests to stabilize allocator state
|
||||
for _ in 0..10 {
|
||||
let msgs = make_messages(5);
|
||||
let _ = orchestrator_service
|
||||
.determine_route(&msgs, None, "warmup")
|
||||
.await;
|
||||
}
|
||||
|
||||
// Snapshot memory after warmup
|
||||
let baseline = get_allocated();
|
||||
|
||||
let num_iterations = 2000;
|
||||
|
||||
for i in 0..num_iterations {
|
||||
let msgs = make_messages(5 + (i % 10));
|
||||
let inline = if i % 3 == 0 {
|
||||
Some(make_routing_prefs())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let _ = orchestrator_service
|
||||
.determine_route(&msgs, inline, &format!("req-{i}"))
|
||||
.await;
|
||||
}
|
||||
|
||||
let after = get_allocated();
|
||||
|
||||
let growth = after.saturating_sub(baseline);
|
||||
let growth_mb = growth as f64 / (1024.0 * 1024.0);
|
||||
let per_request = if num_iterations > 0 {
|
||||
growth / num_iterations
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
eprintln!("=== Routing Stress Test Results ===");
|
||||
eprintln!(" Iterations: {num_iterations}");
|
||||
eprintln!(" Baseline alloc: {} bytes", baseline);
|
||||
eprintln!(" Final alloc: {} bytes", after);
|
||||
eprintln!(" Growth: {} bytes ({growth_mb:.2} MB)", growth);
|
||||
eprintln!(" Per-request: {} bytes", per_request);
|
||||
|
||||
// Allow up to 256 bytes per request of retained growth (connection pool, etc.)
|
||||
// A true leak would show thousands of bytes per request.
|
||||
assert!(
|
||||
per_request < 256,
|
||||
"Possible memory leak: {per_request} bytes/request retained after {num_iterations} iterations"
|
||||
);
|
||||
}
|
||||
|
||||
/// Stress test with high concurrency: many parallel determine_route calls.
|
||||
#[tokio::test]
|
||||
async fn stress_test_routing_concurrent() {
|
||||
let mut server = mockito::Server::new_async().await;
|
||||
let router_url = format!("{}/v1/chat/completions", server.url());
|
||||
|
||||
let mock_response = serde_json::json!({
|
||||
"id": "chatcmpl-mock",
|
||||
"object": "chat.completion",
|
||||
"created": 1234567890,
|
||||
"model": "plano-orchestrator",
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "{\"route\": \"summarization\"}"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {"prompt_tokens": 100, "completion_tokens": 10, "total_tokens": 110}
|
||||
});
|
||||
|
||||
let _mock = server
|
||||
.mock("POST", "/v1/chat/completions")
|
||||
.with_status(200)
|
||||
.with_header("content-type", "application/json")
|
||||
.with_body(mock_response.to_string())
|
||||
.expect_at_least(1)
|
||||
.create_async()
|
||||
.await;
|
||||
|
||||
let prefs = make_routing_prefs();
|
||||
let session_cache = Arc::new(MemorySessionCache::new(1000));
|
||||
let orchestrator_service = Arc::new(OrchestratorService::with_routing(
|
||||
router_url,
|
||||
"Plano-Orchestrator".to_string(),
|
||||
"plano-orchestrator".to_string(),
|
||||
Some(prefs),
|
||||
None,
|
||||
None,
|
||||
session_cache,
|
||||
None,
|
||||
2048,
|
||||
));
|
||||
|
||||
// Warm up
|
||||
for _ in 0..20 {
|
||||
let msgs = make_messages(3);
|
||||
let _ = orchestrator_service
|
||||
.determine_route(&msgs, None, "warmup")
|
||||
.await;
|
||||
}
|
||||
|
||||
let baseline = get_allocated();
|
||||
|
||||
let concurrency = 50;
|
||||
let requests_per_task = 100;
|
||||
let total = concurrency * requests_per_task;
|
||||
|
||||
let mut handles = vec![];
|
||||
for t in 0..concurrency {
|
||||
let svc = Arc::clone(&orchestrator_service);
|
||||
let handle = tokio::spawn(async move {
|
||||
for r in 0..requests_per_task {
|
||||
let msgs = make_messages(3 + (r % 8));
|
||||
let _ = svc
|
||||
.determine_route(&msgs, None, &format!("req-{t}-{r}"))
|
||||
.await;
|
||||
}
|
||||
});
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
for h in handles {
|
||||
h.await.unwrap();
|
||||
}
|
||||
|
||||
let after = get_allocated();
|
||||
let growth = after.saturating_sub(baseline);
|
||||
let per_request = growth / total;
|
||||
|
||||
eprintln!("=== Concurrent Routing Stress Test Results ===");
|
||||
eprintln!(" Tasks: {concurrency} x {requests_per_task} = {total}");
|
||||
eprintln!(" Baseline: {} bytes", baseline);
|
||||
eprintln!(" Final: {} bytes", after);
|
||||
eprintln!(
|
||||
" Growth: {} bytes ({:.2} MB)",
|
||||
growth,
|
||||
growth as f64 / 1_048_576.0
|
||||
);
|
||||
eprintln!(" Per-request: {} bytes", per_request);
|
||||
|
||||
assert!(
|
||||
per_request < 512,
|
||||
"Possible memory leak under concurrency: {per_request} bytes/request retained after {total} requests"
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(feature = "jemalloc")]
|
||||
fn get_allocated() -> usize {
|
||||
tikv_jemalloc_ctl::epoch::advance().unwrap();
|
||||
tikv_jemalloc_ctl::stats::allocated::read().unwrap_or(0)
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "jemalloc"))]
|
||||
fn get_allocated() -> usize {
|
||||
0
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load diff
347
crates/brightstaff/src/signals/environment/exhaustion.rs
Normal file
347
crates/brightstaff/src/signals/environment/exhaustion.rs
Normal file
|
|
@ -0,0 +1,347 @@
|
|||
//! Environment exhaustion detector. Direct port of
|
||||
//! `signals/environment/exhaustion.py`.
|
||||
|
||||
use std::sync::OnceLock;
|
||||
|
||||
use regex::Regex;
|
||||
use serde_json::json;
|
||||
|
||||
use crate::signals::analyzer::ShareGptMessage;
|
||||
use crate::signals::schemas::{SignalGroup, SignalInstance, SignalType};
|
||||
|
||||
pub const API_ERROR_PATTERNS: &[&str] = &[
|
||||
r"500\s*(internal\s+)?server\s+error",
|
||||
r"502\s*bad\s+gateway",
|
||||
r"503\s*service\s+unavailable",
|
||||
r"504\s*gateway\s+timeout",
|
||||
r"internal\s+server\s+error",
|
||||
r"service\s+unavailable",
|
||||
r"server\s+error",
|
||||
r"backend\s+error",
|
||||
r"upstream\s+error",
|
||||
r"service\s+temporarily\s+unavailable",
|
||||
r"maintenance\s+mode",
|
||||
r"under\s+maintenance",
|
||||
r"try\s+again\s+later",
|
||||
r"temporarily\s+unavailable",
|
||||
r"system\s+error",
|
||||
r"unexpected\s+error",
|
||||
r"unhandled\s+exception",
|
||||
];
|
||||
|
||||
pub const TIMEOUT_PATTERNS: &[&str] = &[
|
||||
r"timeout",
|
||||
r"timed?\s*out",
|
||||
r"etimedout",
|
||||
r"connection\s+timed?\s*out",
|
||||
r"read\s+timed?\s*out",
|
||||
r"request\s+timed?\s*out",
|
||||
r"gateway\s+timeout",
|
||||
r"deadline\s+exceeded",
|
||||
r"took\s+too\s+long",
|
||||
r"operation\s+timed?\s*out",
|
||||
r"socket\s+timeout",
|
||||
];
|
||||
|
||||
pub const RATE_LIMIT_PATTERNS: &[&str] = &[
|
||||
r"rate\s+limit",
|
||||
r"rate.limited",
|
||||
r"(status|error|http)\s*:?\s*429",
|
||||
r"429\s+(too\s+many|rate|limit)",
|
||||
r"too\s+many\s+requests?",
|
||||
r"quota\s+exceeded",
|
||||
r"quota\s+limit",
|
||||
r"throttl(ed|ing)",
|
||||
r"request\s+limit",
|
||||
r"api\s+limit",
|
||||
r"calls?\s+per\s+(second|minute|hour|day)",
|
||||
r"exceeded\s+.*\s+limit",
|
||||
r"slow\s+down",
|
||||
r"retry\s+after",
|
||||
r"requests?\s+exceeded",
|
||||
];
|
||||
|
||||
pub const NETWORK_PATTERNS: &[&str] = &[
|
||||
r"connection\s+refused",
|
||||
r"econnrefused",
|
||||
r"econnreset",
|
||||
r"connection\s+reset",
|
||||
r"enotfound",
|
||||
r"dns\s+(error|failure|lookup)",
|
||||
r"host\s+not\s+found",
|
||||
r"network\s+(error|failure|unreachable)",
|
||||
r"no\s+route\s+to\s+host",
|
||||
r"socket\s+error",
|
||||
r"connection\s+failed",
|
||||
r"unable\s+to\s+connect",
|
||||
r"cannot\s+connect",
|
||||
r"could\s+not\s+connect",
|
||||
r"connect\s+error",
|
||||
r"ssl\s+(error|handshake|certificate)",
|
||||
r"certificate\s+(error|invalid|expired)",
|
||||
];
|
||||
|
||||
pub const MALFORMED_PATTERNS: &[&str] = &[
|
||||
r"json\s+parse\s+error",
|
||||
r"invalid\s+json",
|
||||
r"unexpected\s+token",
|
||||
r"syntax\s+error.*json",
|
||||
r"malformed\s+(response|json|data)",
|
||||
r"unexpected\s+end\s+of",
|
||||
r"parse\s+error",
|
||||
r"parsing\s+failed",
|
||||
r"invalid\s+response",
|
||||
r"unexpected\s+response",
|
||||
r"response\s+format",
|
||||
r"missing\s+field.*response",
|
||||
r"unexpected\s+schema",
|
||||
r"schema\s+validation",
|
||||
r"deserialization\s+error",
|
||||
r"failed\s+to\s+decode",
|
||||
];
|
||||
|
||||
pub const CONTEXT_OVERFLOW_PATTERNS: &[&str] = &[
|
||||
r"context\s+(length|limit|overflow|exceeded)",
|
||||
r"token\s+(limit|overflow|exceeded)",
|
||||
r"max(imum)?\s+tokens?",
|
||||
r"input\s+too\s+(long|large)",
|
||||
r"exceeds?\s+(context|token|character|input)\s+limit",
|
||||
r"message\s+too\s+(long|large)",
|
||||
r"content\s+too\s+(long|large)",
|
||||
r"truncat(ed|ion)\s+(due\s+to|because|for)\s+(length|size|limit)",
|
||||
r"maximum\s+context",
|
||||
r"prompt\s+too\s+(long|large)",
|
||||
];
|
||||
|
||||
fn compile(patterns: &[&str]) -> Regex {
|
||||
let combined = patterns
|
||||
.iter()
|
||||
.map(|p| format!("({})", p))
|
||||
.collect::<Vec<_>>()
|
||||
.join("|");
|
||||
Regex::new(&format!("(?i){}", combined)).expect("exhaustion pattern regex must compile")
|
||||
}
|
||||
|
||||
fn api_error_re() -> &'static Regex {
|
||||
static R: OnceLock<Regex> = OnceLock::new();
|
||||
R.get_or_init(|| compile(API_ERROR_PATTERNS))
|
||||
}
|
||||
fn timeout_re() -> &'static Regex {
|
||||
static R: OnceLock<Regex> = OnceLock::new();
|
||||
R.get_or_init(|| compile(TIMEOUT_PATTERNS))
|
||||
}
|
||||
fn rate_limit_re() -> &'static Regex {
|
||||
static R: OnceLock<Regex> = OnceLock::new();
|
||||
R.get_or_init(|| compile(RATE_LIMIT_PATTERNS))
|
||||
}
|
||||
fn network_re() -> &'static Regex {
|
||||
static R: OnceLock<Regex> = OnceLock::new();
|
||||
R.get_or_init(|| compile(NETWORK_PATTERNS))
|
||||
}
|
||||
fn malformed_re() -> &'static Regex {
|
||||
static R: OnceLock<Regex> = OnceLock::new();
|
||||
R.get_or_init(|| compile(MALFORMED_PATTERNS))
|
||||
}
|
||||
fn context_overflow_re() -> &'static Regex {
|
||||
static R: OnceLock<Regex> = OnceLock::new();
|
||||
R.get_or_init(|| compile(CONTEXT_OVERFLOW_PATTERNS))
|
||||
}
|
||||
|
||||
fn snippet_around(text: &str, m: regex::Match<'_>, context: usize) -> String {
|
||||
let start = m.start().saturating_sub(context);
|
||||
let end = (m.end() + context).min(text.len());
|
||||
let start = align_char_boundary(text, start, false);
|
||||
let end = align_char_boundary(text, end, true);
|
||||
let mut snippet = String::new();
|
||||
if start > 0 {
|
||||
snippet.push_str("...");
|
||||
}
|
||||
snippet.push_str(&text[start..end]);
|
||||
if end < text.len() {
|
||||
snippet.push_str("...");
|
||||
}
|
||||
snippet
|
||||
}
|
||||
|
||||
fn align_char_boundary(s: &str, mut idx: usize, forward: bool) -> usize {
|
||||
if idx >= s.len() {
|
||||
return s.len();
|
||||
}
|
||||
while !s.is_char_boundary(idx) {
|
||||
if forward {
|
||||
idx += 1;
|
||||
} else if idx == 0 {
|
||||
break;
|
||||
} else {
|
||||
idx -= 1;
|
||||
}
|
||||
}
|
||||
idx
|
||||
}
|
||||
|
||||
pub fn analyze_exhaustion(messages: &[ShareGptMessage<'_>]) -> SignalGroup {
|
||||
let mut group = SignalGroup::new("exhaustion");
|
||||
|
||||
for (i, msg) in messages.iter().enumerate() {
|
||||
if msg.from != "observation" {
|
||||
continue;
|
||||
}
|
||||
let value = msg.value;
|
||||
let lower = value.to_lowercase();
|
||||
|
||||
if let Some(m) = rate_limit_re().find(&lower) {
|
||||
group.add_signal(emit(
|
||||
SignalType::EnvironmentExhaustionRateLimit,
|
||||
i,
|
||||
snippet_around(value, m, 50),
|
||||
0.95,
|
||||
"rate_limit",
|
||||
m.as_str(),
|
||||
));
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(m) = api_error_re().find(&lower) {
|
||||
group.add_signal(emit(
|
||||
SignalType::EnvironmentExhaustionApiError,
|
||||
i,
|
||||
snippet_around(value, m, 50),
|
||||
0.9,
|
||||
"api_error",
|
||||
m.as_str(),
|
||||
));
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(m) = timeout_re().find(&lower) {
|
||||
group.add_signal(emit(
|
||||
SignalType::EnvironmentExhaustionTimeout,
|
||||
i,
|
||||
snippet_around(value, m, 50),
|
||||
0.9,
|
||||
"timeout",
|
||||
m.as_str(),
|
||||
));
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(m) = network_re().find(&lower) {
|
||||
group.add_signal(emit(
|
||||
SignalType::EnvironmentExhaustionNetwork,
|
||||
i,
|
||||
snippet_around(value, m, 50),
|
||||
0.9,
|
||||
"network",
|
||||
m.as_str(),
|
||||
));
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(m) = malformed_re().find(&lower) {
|
||||
group.add_signal(emit(
|
||||
SignalType::EnvironmentExhaustionMalformed,
|
||||
i,
|
||||
snippet_around(value, m, 50),
|
||||
0.85,
|
||||
"malformed_response",
|
||||
m.as_str(),
|
||||
));
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(m) = context_overflow_re().find(&lower) {
|
||||
group.add_signal(emit(
|
||||
SignalType::EnvironmentExhaustionContextOverflow,
|
||||
i,
|
||||
snippet_around(value, m, 50),
|
||||
0.9,
|
||||
"context_overflow",
|
||||
m.as_str(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
group
|
||||
}
|
||||
|
||||
fn emit(
|
||||
t: SignalType,
|
||||
idx: usize,
|
||||
snippet: String,
|
||||
confidence: f32,
|
||||
kind: &str,
|
||||
matched: &str,
|
||||
) -> SignalInstance {
|
||||
SignalInstance::new(t, idx, snippet)
|
||||
.with_confidence(confidence)
|
||||
.with_metadata(json!({
|
||||
"exhaustion_type": kind,
|
||||
"matched": matched,
|
||||
}))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn obs(value: &str) -> ShareGptMessage<'_> {
|
||||
ShareGptMessage {
|
||||
from: "observation",
|
||||
value,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_rate_limit() {
|
||||
let g = analyze_exhaustion(&[obs("HTTP 429: too many requests, retry after 30s")]);
|
||||
assert!(g
|
||||
.signals
|
||||
.iter()
|
||||
.any(|s| matches!(s.signal_type, SignalType::EnvironmentExhaustionRateLimit)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_api_error() {
|
||||
let g = analyze_exhaustion(&[obs("503 service unavailable - try again later")]);
|
||||
assert!(g
|
||||
.signals
|
||||
.iter()
|
||||
.any(|s| matches!(s.signal_type, SignalType::EnvironmentExhaustionApiError)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_timeout() {
|
||||
let g = analyze_exhaustion(&[obs("Connection timed out after 30 seconds")]);
|
||||
assert!(g
|
||||
.signals
|
||||
.iter()
|
||||
.any(|s| matches!(s.signal_type, SignalType::EnvironmentExhaustionTimeout)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_network_failure() {
|
||||
let g = analyze_exhaustion(&[obs("ECONNREFUSED: connection refused by remote host")]);
|
||||
assert!(g
|
||||
.signals
|
||||
.iter()
|
||||
.any(|s| matches!(s.signal_type, SignalType::EnvironmentExhaustionNetwork)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_malformed_response() {
|
||||
let g = analyze_exhaustion(&[obs("Invalid JSON: unexpected token at position 42")]);
|
||||
assert!(g
|
||||
.signals
|
||||
.iter()
|
||||
.any(|s| matches!(s.signal_type, SignalType::EnvironmentExhaustionMalformed)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_context_overflow() {
|
||||
let g = analyze_exhaustion(&[obs("Maximum context length exceeded for this model")]);
|
||||
assert!(g.signals.iter().any(|s| matches!(
|
||||
s.signal_type,
|
||||
SignalType::EnvironmentExhaustionContextOverflow
|
||||
)));
|
||||
}
|
||||
}
|
||||
3
crates/brightstaff/src/signals/environment/mod.rs
Normal file
3
crates/brightstaff/src/signals/environment/mod.rs
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
//! Environment signals: exhaustion (external system failures and constraints).
|
||||
|
||||
pub mod exhaustion;
|
||||
388
crates/brightstaff/src/signals/execution/failure.rs
Normal file
388
crates/brightstaff/src/signals/execution/failure.rs
Normal file
|
|
@ -0,0 +1,388 @@
|
|||
//! Execution failure detector. Direct port of `signals/execution/failure.py`.
|
||||
|
||||
use std::sync::OnceLock;
|
||||
|
||||
use regex::Regex;
|
||||
use serde_json::json;
|
||||
|
||||
use crate::signals::analyzer::ShareGptMessage;
|
||||
use crate::signals::schemas::{SignalGroup, SignalInstance, SignalType};
|
||||
|
||||
pub const INVALID_ARGS_PATTERNS: &[&str] = &[
|
||||
r"invalid\s+argument",
|
||||
r"invalid\s+parameter",
|
||||
r"invalid\s+type",
|
||||
r"type\s*error",
|
||||
r"expected\s+\w+\s*,?\s*got\s+\w+",
|
||||
r"required\s+field",
|
||||
r"required\s+parameter",
|
||||
r"missing\s+required",
|
||||
r"missing\s+argument",
|
||||
r"validation\s+failed",
|
||||
r"validation\s+error",
|
||||
r"invalid\s+value",
|
||||
r"invalid\s+format",
|
||||
r"must\s+be\s+(a|an)\s+\w+",
|
||||
r"cannot\s+be\s+(null|empty|none)",
|
||||
r"is\s+not\s+valid",
|
||||
r"does\s+not\s+match",
|
||||
r"out\s+of\s+range",
|
||||
r"invalid\s+date",
|
||||
r"invalid\s+json",
|
||||
r"malformed\s+request",
|
||||
];
|
||||
|
||||
pub const BAD_QUERY_PATTERNS: &[&str] = &[
|
||||
r"invalid\s+query",
|
||||
r"query\s+syntax\s+error",
|
||||
r"malformed\s+query",
|
||||
r"unknown\s+field",
|
||||
r"invalid\s+field",
|
||||
r"invalid\s+filter",
|
||||
r"invalid\s+search",
|
||||
r"unknown\s+id",
|
||||
r"invalid\s+id",
|
||||
r"id\s+format\s+error",
|
||||
r"invalid\s+identifier",
|
||||
r"query\s+failed",
|
||||
r"search\s+error",
|
||||
r"invalid\s+operator",
|
||||
r"unsupported\s+query",
|
||||
];
|
||||
|
||||
pub const TOOL_NOT_FOUND_PATTERNS: &[&str] = &[
|
||||
r"unknown\s+function",
|
||||
r"unknown\s+tool",
|
||||
r"function\s+not\s+found",
|
||||
r"tool\s+not\s+found",
|
||||
r"no\s+such\s+function",
|
||||
r"no\s+such\s+tool",
|
||||
r"undefined\s+function",
|
||||
r"action\s+not\s+supported",
|
||||
r"invalid\s+tool",
|
||||
r"invalid\s+function",
|
||||
r"unrecognized\s+function",
|
||||
];
|
||||
|
||||
pub const AUTH_MISUSE_PATTERNS: &[&str] = &[
|
||||
r"\bunauthorized\b",
|
||||
r"(status|error|http|code)\s*:?\s*401",
|
||||
r"401\s+unauthorized",
|
||||
r"403\s+forbidden",
|
||||
r"permission\s+denied",
|
||||
r"access\s+denied",
|
||||
r"authentication\s+required",
|
||||
r"invalid\s+credentials",
|
||||
r"invalid\s+token",
|
||||
r"token\s+expired",
|
||||
r"missing\s+authorization",
|
||||
r"\bforbidden\b",
|
||||
r"not\s+authorized",
|
||||
r"insufficient\s+permissions?",
|
||||
];
|
||||
|
||||
pub const STATE_ERROR_PATTERNS: &[&str] = &[
|
||||
r"invalid\s+state",
|
||||
r"illegal\s+state",
|
||||
r"must\s+call\s+\w+\s+first",
|
||||
r"must\s+\w+\s+before",
|
||||
r"cannot\s+\w+\s+before",
|
||||
r"already\s+(exists?|created|started|finished)",
|
||||
r"not\s+initialized",
|
||||
r"not\s+started",
|
||||
r"already\s+in\s+progress",
|
||||
r"operation\s+in\s+progress",
|
||||
r"sequence\s+error",
|
||||
r"precondition\s+failed",
|
||||
r"(status|error|http)\s*:?\s*409",
|
||||
r"409\s+conflict",
|
||||
r"\bconflict\b",
|
||||
];
|
||||
|
||||
fn compile(patterns: &[&str]) -> Regex {
|
||||
// Use `(?i)` flag for case-insensitive matching, matching Python's `re.IGNORECASE`.
|
||||
let combined = patterns
|
||||
.iter()
|
||||
.map(|p| format!("({})", p))
|
||||
.collect::<Vec<_>>()
|
||||
.join("|");
|
||||
Regex::new(&format!("(?i){}", combined)).expect("failure pattern regex must compile")
|
||||
}
|
||||
|
||||
fn invalid_args_re() -> &'static Regex {
|
||||
static R: OnceLock<Regex> = OnceLock::new();
|
||||
R.get_or_init(|| compile(INVALID_ARGS_PATTERNS))
|
||||
}
|
||||
fn bad_query_re() -> &'static Regex {
|
||||
static R: OnceLock<Regex> = OnceLock::new();
|
||||
R.get_or_init(|| compile(BAD_QUERY_PATTERNS))
|
||||
}
|
||||
fn tool_not_found_re() -> &'static Regex {
|
||||
static R: OnceLock<Regex> = OnceLock::new();
|
||||
R.get_or_init(|| compile(TOOL_NOT_FOUND_PATTERNS))
|
||||
}
|
||||
fn auth_misuse_re() -> &'static Regex {
|
||||
static R: OnceLock<Regex> = OnceLock::new();
|
||||
R.get_or_init(|| compile(AUTH_MISUSE_PATTERNS))
|
||||
}
|
||||
fn state_error_re() -> &'static Regex {
|
||||
static R: OnceLock<Regex> = OnceLock::new();
|
||||
R.get_or_init(|| compile(STATE_ERROR_PATTERNS))
|
||||
}
|
||||
|
||||
/// Pull tool name + args from a `function_call` message. Mirrors
|
||||
/// `_extract_tool_info` in the reference.
|
||||
pub(crate) fn extract_tool_info(value: &str) -> (String, String) {
|
||||
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(value) {
|
||||
if let Some(obj) = parsed.as_object() {
|
||||
let name = obj
|
||||
.get("name")
|
||||
.or_else(|| obj.get("function"))
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
let args = match obj.get("arguments").or_else(|| obj.get("args")) {
|
||||
Some(serde_json::Value::Object(o)) => {
|
||||
serde_json::to_string(&serde_json::Value::Object(o.clone())).unwrap_or_default()
|
||||
}
|
||||
Some(other) => other
|
||||
.as_str()
|
||||
.map(|s| s.to_string())
|
||||
.unwrap_or_else(|| serde_json::to_string(other).unwrap_or_default()),
|
||||
None => String::new(),
|
||||
};
|
||||
return (name, args);
|
||||
}
|
||||
}
|
||||
let mut snippet: String = value.chars().take(200).collect();
|
||||
snippet.shrink_to_fit();
|
||||
("unknown".to_string(), snippet)
|
||||
}
|
||||
|
||||
/// Build a context-window snippet around a regex match, with leading/trailing
|
||||
/// ellipses when truncated. Mirrors `_get_snippet`.
|
||||
fn snippet_around(text: &str, m: regex::Match<'_>, context: usize) -> String {
|
||||
let start = m.start().saturating_sub(context);
|
||||
let end = (m.end() + context).min(text.len());
|
||||
// Ensure we cut on UTF-8 boundaries.
|
||||
let start = align_char_boundary(text, start, false);
|
||||
let end = align_char_boundary(text, end, true);
|
||||
let mut snippet = String::new();
|
||||
if start > 0 {
|
||||
snippet.push_str("...");
|
||||
}
|
||||
snippet.push_str(&text[start..end]);
|
||||
if end < text.len() {
|
||||
snippet.push_str("...");
|
||||
}
|
||||
snippet
|
||||
}
|
||||
|
||||
fn align_char_boundary(s: &str, mut idx: usize, forward: bool) -> usize {
|
||||
if idx >= s.len() {
|
||||
return s.len();
|
||||
}
|
||||
while !s.is_char_boundary(idx) {
|
||||
if forward {
|
||||
idx += 1;
|
||||
} else if idx == 0 {
|
||||
break;
|
||||
} else {
|
||||
idx -= 1;
|
||||
}
|
||||
}
|
||||
idx
|
||||
}
|
||||
|
||||
pub fn analyze_failure(messages: &[ShareGptMessage<'_>]) -> SignalGroup {
|
||||
let mut group = SignalGroup::new("failure");
|
||||
let mut last_call: Option<(usize, String, String)> = None;
|
||||
|
||||
for (i, msg) in messages.iter().enumerate() {
|
||||
match msg.from {
|
||||
"function_call" => {
|
||||
let (name, args) = extract_tool_info(msg.value);
|
||||
last_call = Some((i, name, args));
|
||||
continue;
|
||||
}
|
||||
"observation" => {}
|
||||
_ => continue,
|
||||
}
|
||||
|
||||
let value = msg.value;
|
||||
let lower = value.to_lowercase();
|
||||
let (call_index, tool_name) = match &last_call {
|
||||
Some((idx, name, _)) => (*idx, name.clone()),
|
||||
None => (i.saturating_sub(1), "unknown".to_string()),
|
||||
};
|
||||
|
||||
if let Some(m) = invalid_args_re().find(&lower) {
|
||||
group.add_signal(
|
||||
SignalInstance::new(
|
||||
SignalType::ExecutionFailureInvalidArgs,
|
||||
i,
|
||||
snippet_around(value, m, 50),
|
||||
)
|
||||
.with_confidence(0.9)
|
||||
.with_metadata(json!({
|
||||
"tool_name": tool_name,
|
||||
"call_index": call_index,
|
||||
"error_type": "invalid_args",
|
||||
"matched": m.as_str(),
|
||||
})),
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(m) = tool_not_found_re().find(&lower) {
|
||||
group.add_signal(
|
||||
SignalInstance::new(
|
||||
SignalType::ExecutionFailureToolNotFound,
|
||||
i,
|
||||
snippet_around(value, m, 50),
|
||||
)
|
||||
.with_confidence(0.95)
|
||||
.with_metadata(json!({
|
||||
"tool_name": tool_name,
|
||||
"call_index": call_index,
|
||||
"error_type": "tool_not_found",
|
||||
"matched": m.as_str(),
|
||||
})),
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(m) = auth_misuse_re().find(&lower) {
|
||||
group.add_signal(
|
||||
SignalInstance::new(
|
||||
SignalType::ExecutionFailureAuthMisuse,
|
||||
i,
|
||||
snippet_around(value, m, 50),
|
||||
)
|
||||
.with_confidence(0.8)
|
||||
.with_metadata(json!({
|
||||
"tool_name": tool_name,
|
||||
"call_index": call_index,
|
||||
"error_type": "auth_misuse",
|
||||
"matched": m.as_str(),
|
||||
})),
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(m) = state_error_re().find(&lower) {
|
||||
group.add_signal(
|
||||
SignalInstance::new(
|
||||
SignalType::ExecutionFailureStateError,
|
||||
i,
|
||||
snippet_around(value, m, 50),
|
||||
)
|
||||
.with_confidence(0.85)
|
||||
.with_metadata(json!({
|
||||
"tool_name": tool_name,
|
||||
"call_index": call_index,
|
||||
"error_type": "state_error",
|
||||
"matched": m.as_str(),
|
||||
})),
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(m) = bad_query_re().find(&lower) {
|
||||
let confidence = if ["error", "invalid", "failed"]
|
||||
.iter()
|
||||
.any(|w| lower.contains(w))
|
||||
{
|
||||
0.8
|
||||
} else {
|
||||
0.6
|
||||
};
|
||||
group.add_signal(
|
||||
SignalInstance::new(
|
||||
SignalType::ExecutionFailureBadQuery,
|
||||
i,
|
||||
snippet_around(value, m, 50),
|
||||
)
|
||||
.with_confidence(confidence)
|
||||
.with_metadata(json!({
|
||||
"tool_name": tool_name,
|
||||
"call_index": call_index,
|
||||
"error_type": "bad_query",
|
||||
"matched": m.as_str(),
|
||||
})),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
group
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn fc(value: &str) -> ShareGptMessage<'_> {
|
||||
ShareGptMessage {
|
||||
from: "function_call",
|
||||
value,
|
||||
}
|
||||
}
|
||||
fn obs(value: &str) -> ShareGptMessage<'_> {
|
||||
ShareGptMessage {
|
||||
from: "observation",
|
||||
value,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_invalid_args() {
|
||||
let msgs = vec![
|
||||
fc(r#"{"name":"create_user","arguments":{"age":"twelve"}}"#),
|
||||
obs("Error: validation failed - expected integer got string for field age"),
|
||||
];
|
||||
let g = analyze_failure(&msgs);
|
||||
assert!(g
|
||||
.signals
|
||||
.iter()
|
||||
.any(|s| matches!(s.signal_type, SignalType::ExecutionFailureInvalidArgs)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_tool_not_found() {
|
||||
let msgs = vec![
|
||||
fc(r#"{"name":"send_thought","arguments":{}}"#),
|
||||
obs("Error: unknown function 'send_thought'"),
|
||||
];
|
||||
let g = analyze_failure(&msgs);
|
||||
assert!(g
|
||||
.signals
|
||||
.iter()
|
||||
.any(|s| matches!(s.signal_type, SignalType::ExecutionFailureToolNotFound)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_auth_misuse() {
|
||||
let msgs = vec![
|
||||
fc(r#"{"name":"get_secret","arguments":{}}"#),
|
||||
obs("HTTP 401 Unauthorized"),
|
||||
];
|
||||
let g = analyze_failure(&msgs);
|
||||
assert!(g
|
||||
.signals
|
||||
.iter()
|
||||
.any(|s| matches!(s.signal_type, SignalType::ExecutionFailureAuthMisuse)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_state_error() {
|
||||
let msgs = vec![
|
||||
fc(r#"{"name":"commit_tx","arguments":{}}"#),
|
||||
obs("must call begin_tx first"),
|
||||
];
|
||||
let g = analyze_failure(&msgs);
|
||||
assert!(g
|
||||
.signals
|
||||
.iter()
|
||||
.any(|s| matches!(s.signal_type, SignalType::ExecutionFailureStateError)));
|
||||
}
|
||||
}
|
||||
433
crates/brightstaff/src/signals/execution/loops.rs
Normal file
433
crates/brightstaff/src/signals/execution/loops.rs
Normal file
|
|
@ -0,0 +1,433 @@
|
|||
//! Execution loops detector. Direct port of `signals/execution/loops.py`.
|
||||
|
||||
use serde_json::json;
|
||||
|
||||
use crate::signals::analyzer::ShareGptMessage;
|
||||
use crate::signals::schemas::{SignalGroup, SignalInstance, SignalType};
|
||||
|
||||
pub const RETRY_THRESHOLD: usize = 3;
|
||||
pub const PARAMETER_DRIFT_THRESHOLD: usize = 3;
|
||||
pub const OSCILLATION_CYCLES_THRESHOLD: usize = 3;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ToolCall {
|
||||
pub index: usize,
|
||||
pub name: String,
|
||||
/// Canonical JSON string of arguments (sorted keys when parseable).
|
||||
pub args: String,
|
||||
pub args_dict: Option<serde_json::Map<String, serde_json::Value>>,
|
||||
}
|
||||
|
||||
impl ToolCall {
|
||||
pub fn args_equal(&self, other: &ToolCall) -> bool {
|
||||
match (&self.args_dict, &other.args_dict) {
|
||||
(Some(a), Some(b)) => a == b,
|
||||
_ => self.args == other.args,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_tool_call(index: usize, msg: &ShareGptMessage<'_>) -> Option<ToolCall> {
|
||||
if msg.from != "function_call" {
|
||||
return None;
|
||||
}
|
||||
let value = msg.value;
|
||||
|
||||
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(value) {
|
||||
if let Some(obj) = parsed.as_object() {
|
||||
let name = obj
|
||||
.get("name")
|
||||
.or_else(|| obj.get("function"))
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
let raw_args = obj.get("arguments").or_else(|| obj.get("args"));
|
||||
let (args_str, args_dict) = match raw_args {
|
||||
Some(serde_json::Value::Object(o)) => {
|
||||
let mut keys: Vec<&String> = o.keys().collect();
|
||||
keys.sort();
|
||||
let mut canon = serde_json::Map::new();
|
||||
for k in keys {
|
||||
canon.insert(k.clone(), o[k].clone());
|
||||
}
|
||||
(
|
||||
serde_json::to_string(&serde_json::Value::Object(canon.clone()))
|
||||
.unwrap_or_default(),
|
||||
Some(canon),
|
||||
)
|
||||
}
|
||||
Some(other) => (
|
||||
other
|
||||
.as_str()
|
||||
.map(|s| s.to_string())
|
||||
.unwrap_or_else(|| serde_json::to_string(other).unwrap_or_default()),
|
||||
None,
|
||||
),
|
||||
None => (String::new(), None),
|
||||
};
|
||||
return Some(ToolCall {
|
||||
index,
|
||||
name,
|
||||
args: args_str,
|
||||
args_dict,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(paren) = value.find('(') {
|
||||
if paren > 0 {
|
||||
let name = value[..paren].trim().to_string();
|
||||
let args_part = &value[paren..];
|
||||
if args_part.starts_with('(') && args_part.ends_with(')') {
|
||||
let inner = args_part[1..args_part.len() - 1].trim();
|
||||
if let Ok(serde_json::Value::Object(o)) =
|
||||
serde_json::from_str::<serde_json::Value>(inner)
|
||||
{
|
||||
let mut keys: Vec<&String> = o.keys().collect();
|
||||
keys.sort();
|
||||
let mut canon = serde_json::Map::new();
|
||||
for k in keys {
|
||||
canon.insert(k.clone(), o[k].clone());
|
||||
}
|
||||
return Some(ToolCall {
|
||||
index,
|
||||
name,
|
||||
args: serde_json::to_string(&serde_json::Value::Object(canon.clone()))
|
||||
.unwrap_or_default(),
|
||||
args_dict: Some(canon),
|
||||
});
|
||||
}
|
||||
return Some(ToolCall {
|
||||
index,
|
||||
name,
|
||||
args: inner.to_string(),
|
||||
args_dict: None,
|
||||
});
|
||||
}
|
||||
return Some(ToolCall {
|
||||
index,
|
||||
name,
|
||||
args: args_part.to_string(),
|
||||
args_dict: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Some(ToolCall {
|
||||
index,
|
||||
name: value.trim().to_string(),
|
||||
args: String::new(),
|
||||
args_dict: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn extract_tool_calls(messages: &[ShareGptMessage<'_>]) -> Vec<ToolCall> {
|
||||
let mut out = Vec::new();
|
||||
for (i, msg) in messages.iter().enumerate() {
|
||||
if let Some(c) = parse_tool_call(i, msg) {
|
||||
out.push(c);
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn detect_retry(calls: &[ToolCall]) -> Vec<(usize, usize, String)> {
|
||||
if calls.len() < RETRY_THRESHOLD {
|
||||
return Vec::new();
|
||||
}
|
||||
let mut patterns = Vec::new();
|
||||
let mut i = 0;
|
||||
while i < calls.len() {
|
||||
let current = &calls[i];
|
||||
let mut j = i + 1;
|
||||
let mut run_length = 1;
|
||||
while j < calls.len() {
|
||||
if calls[j].name == current.name && calls[j].args_equal(current) {
|
||||
run_length += 1;
|
||||
j += 1;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if run_length >= RETRY_THRESHOLD {
|
||||
patterns.push((calls[i].index, calls[j - 1].index, current.name.clone()));
|
||||
i = j;
|
||||
} else {
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
patterns
|
||||
}
|
||||
|
||||
fn detect_parameter_drift(calls: &[ToolCall]) -> Vec<(usize, usize, String, usize)> {
|
||||
if calls.len() < PARAMETER_DRIFT_THRESHOLD {
|
||||
return Vec::new();
|
||||
}
|
||||
let mut patterns = Vec::new();
|
||||
let mut i = 0;
|
||||
while i < calls.len() {
|
||||
let current_name = calls[i].name.clone();
|
||||
let mut seen_args: Vec<String> = vec![calls[i].args.clone()];
|
||||
let mut unique_args = 1;
|
||||
let mut j = i + 1;
|
||||
while j < calls.len() {
|
||||
if calls[j].name != current_name {
|
||||
break;
|
||||
}
|
||||
if !seen_args.iter().any(|a| a == &calls[j].args) {
|
||||
seen_args.push(calls[j].args.clone());
|
||||
unique_args += 1;
|
||||
}
|
||||
j += 1;
|
||||
}
|
||||
let run_length = j - i;
|
||||
if run_length >= PARAMETER_DRIFT_THRESHOLD && unique_args >= 2 {
|
||||
patterns.push((
|
||||
calls[i].index,
|
||||
calls[j - 1].index,
|
||||
current_name,
|
||||
unique_args,
|
||||
));
|
||||
i = j;
|
||||
} else {
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
patterns
|
||||
}
|
||||
|
||||
fn detect_oscillation(calls: &[ToolCall]) -> Vec<(usize, usize, Vec<String>, usize)> {
|
||||
let min_calls = 2 * OSCILLATION_CYCLES_THRESHOLD;
|
||||
if calls.len() < min_calls {
|
||||
return Vec::new();
|
||||
}
|
||||
let mut patterns = Vec::new();
|
||||
let mut i: usize = 0;
|
||||
while i + min_calls <= calls.len() {
|
||||
let max_pat_len = (5usize).min(calls.len() - i);
|
||||
let mut found_for_i = false;
|
||||
for pat_len in 2..=max_pat_len {
|
||||
let pattern_names: Vec<String> =
|
||||
(0..pat_len).map(|k| calls[i + k].name.clone()).collect();
|
||||
let unique: std::collections::HashSet<&String> = pattern_names.iter().collect();
|
||||
if unique.len() < 2 {
|
||||
continue;
|
||||
}
|
||||
let mut cycles = 1;
|
||||
let mut pos = i + pat_len;
|
||||
while pos + pat_len <= calls.len() {
|
||||
let mut all_match = true;
|
||||
for k in 0..pat_len {
|
||||
if calls[pos + k].name != pattern_names[k] {
|
||||
all_match = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if all_match {
|
||||
cycles += 1;
|
||||
pos += pat_len;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if cycles >= OSCILLATION_CYCLES_THRESHOLD {
|
||||
let end_idx_in_calls = i + (cycles * pat_len) - 1;
|
||||
patterns.push((
|
||||
calls[i].index,
|
||||
calls[end_idx_in_calls].index,
|
||||
pattern_names,
|
||||
cycles,
|
||||
));
|
||||
// Mirror Python: `i = end_idx + 1 - pattern_len`. We set `i` so that
|
||||
// the next outer iteration begins after we account for overlap.
|
||||
i = end_idx_in_calls + 1 - pat_len;
|
||||
found_for_i = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if !found_for_i {
|
||||
i += 1;
|
||||
} else {
|
||||
// Match Python's `i = end_idx + 1 - pattern_len; break` then loop.
|
||||
// We'll continue; the outer while re-checks i.
|
||||
}
|
||||
}
|
||||
if patterns.len() > 1 {
|
||||
patterns = deduplicate_patterns(patterns);
|
||||
}
|
||||
patterns
|
||||
}
|
||||
|
||||
fn deduplicate_patterns(
|
||||
mut patterns: Vec<(usize, usize, Vec<String>, usize)>,
|
||||
) -> Vec<(usize, usize, Vec<String>, usize)> {
|
||||
if patterns.is_empty() {
|
||||
return patterns;
|
||||
}
|
||||
patterns.sort_by(|a, b| {
|
||||
let ord = a.0.cmp(&b.0);
|
||||
if ord != std::cmp::Ordering::Equal {
|
||||
ord
|
||||
} else {
|
||||
(b.1 - b.0).cmp(&(a.1 - a.0))
|
||||
}
|
||||
});
|
||||
let mut result = Vec::new();
|
||||
let mut last_end: i64 = -1;
|
||||
for p in patterns {
|
||||
if (p.0 as i64) > last_end {
|
||||
last_end = p.1 as i64;
|
||||
result.push(p);
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
pub fn analyze_loops(messages: &[ShareGptMessage<'_>]) -> SignalGroup {
|
||||
let mut group = SignalGroup::new("loops");
|
||||
let calls = extract_tool_calls(messages);
|
||||
if calls.len() < RETRY_THRESHOLD {
|
||||
return group;
|
||||
}
|
||||
|
||||
let retries = detect_retry(&calls);
|
||||
for (start_idx, end_idx, tool_name) in &retries {
|
||||
let call_count = calls
|
||||
.iter()
|
||||
.filter(|c| *start_idx <= c.index && c.index <= *end_idx)
|
||||
.count();
|
||||
group.add_signal(
|
||||
SignalInstance::new(
|
||||
SignalType::ExecutionLoopsRetry,
|
||||
*start_idx,
|
||||
format!(
|
||||
"Tool '{}' called {} times with identical arguments",
|
||||
tool_name, call_count
|
||||
),
|
||||
)
|
||||
.with_confidence(0.95)
|
||||
.with_metadata(json!({
|
||||
"tool_name": tool_name,
|
||||
"start_index": start_idx,
|
||||
"end_index": end_idx,
|
||||
"call_count": call_count,
|
||||
"loop_type": "retry",
|
||||
})),
|
||||
);
|
||||
}
|
||||
|
||||
let drifts = detect_parameter_drift(&calls);
|
||||
for (start_idx, end_idx, tool_name, variation_count) in &drifts {
|
||||
let overlaps_retry = retries
|
||||
.iter()
|
||||
.any(|r| !(*end_idx < r.0 || *start_idx > r.1));
|
||||
if overlaps_retry {
|
||||
continue;
|
||||
}
|
||||
let call_count = calls
|
||||
.iter()
|
||||
.filter(|c| *start_idx <= c.index && c.index <= *end_idx)
|
||||
.count();
|
||||
group.add_signal(
|
||||
SignalInstance::new(
|
||||
SignalType::ExecutionLoopsParameterDrift,
|
||||
*start_idx,
|
||||
format!(
|
||||
"Tool '{}' called {} times with {} different argument variations",
|
||||
tool_name, call_count, variation_count
|
||||
),
|
||||
)
|
||||
.with_confidence(0.85)
|
||||
.with_metadata(json!({
|
||||
"tool_name": tool_name,
|
||||
"start_index": start_idx,
|
||||
"end_index": end_idx,
|
||||
"call_count": call_count,
|
||||
"variation_count": variation_count,
|
||||
"loop_type": "parameter_drift",
|
||||
})),
|
||||
);
|
||||
}
|
||||
|
||||
let oscillations = detect_oscillation(&calls);
|
||||
for (start_idx, end_idx, tool_names, cycle_count) in &oscillations {
|
||||
let pattern_str = tool_names.join(" \u{2192} ");
|
||||
group.add_signal(
|
||||
SignalInstance::new(
|
||||
SignalType::ExecutionLoopsOscillation,
|
||||
*start_idx,
|
||||
format!(
|
||||
"Oscillation pattern [{}] repeated {} times",
|
||||
pattern_str, cycle_count
|
||||
),
|
||||
)
|
||||
.with_confidence(0.9)
|
||||
.with_metadata(json!({
|
||||
"pattern": tool_names,
|
||||
"start_index": start_idx,
|
||||
"end_index": end_idx,
|
||||
"cycle_count": cycle_count,
|
||||
"loop_type": "oscillation",
|
||||
})),
|
||||
);
|
||||
}
|
||||
|
||||
group
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn fc(value: &str) -> ShareGptMessage<'_> {
|
||||
ShareGptMessage {
|
||||
from: "function_call",
|
||||
value,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_retry_loop() {
|
||||
let arg = r#"{"name":"check_status","arguments":{"id":"abc"}}"#;
|
||||
let msgs = vec![fc(arg), fc(arg), fc(arg), fc(arg)];
|
||||
let g = analyze_loops(&msgs);
|
||||
assert!(g
|
||||
.signals
|
||||
.iter()
|
||||
.any(|s| matches!(s.signal_type, SignalType::ExecutionLoopsRetry)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_parameter_drift() {
|
||||
let msgs = vec![
|
||||
fc(r#"{"name":"search","arguments":{"q":"a"}}"#),
|
||||
fc(r#"{"name":"search","arguments":{"q":"ab"}}"#),
|
||||
fc(r#"{"name":"search","arguments":{"q":"abc"}}"#),
|
||||
fc(r#"{"name":"search","arguments":{"q":"abcd"}}"#),
|
||||
];
|
||||
let g = analyze_loops(&msgs);
|
||||
assert!(g
|
||||
.signals
|
||||
.iter()
|
||||
.any(|s| matches!(s.signal_type, SignalType::ExecutionLoopsParameterDrift)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_oscillation() {
|
||||
let a = r#"{"name":"toolA","arguments":{}}"#;
|
||||
let b = r#"{"name":"toolB","arguments":{}}"#;
|
||||
let msgs = vec![fc(a), fc(b), fc(a), fc(b), fc(a), fc(b)];
|
||||
let g = analyze_loops(&msgs);
|
||||
assert!(g
|
||||
.signals
|
||||
.iter()
|
||||
.any(|s| matches!(s.signal_type, SignalType::ExecutionLoopsOscillation)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_signals_when_few_calls() {
|
||||
let msgs = vec![fc(r#"{"name":"only_once","arguments":{}}"#)];
|
||||
let g = analyze_loops(&msgs);
|
||||
assert!(g.signals.is_empty());
|
||||
}
|
||||
}
|
||||
5
crates/brightstaff/src/signals/execution/mod.rs
Normal file
5
crates/brightstaff/src/signals/execution/mod.rs
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
//! Execution signals: failure (agent-caused tool errors) and loops
|
||||
//! (repetitive tool-call behavior).
|
||||
|
||||
pub mod failure;
|
||||
pub mod loops;
|
||||
193
crates/brightstaff/src/signals/interaction/constants.rs
Normal file
193
crates/brightstaff/src/signals/interaction/constants.rs
Normal file
|
|
@ -0,0 +1,193 @@
|
|||
//! Shared constants for the interaction layer detectors.
|
||||
//!
|
||||
//! Direct port of `signals/interaction/constants.py`.
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::sync::OnceLock;
|
||||
|
||||
pub const POSITIVE_PREFIXES: &[&str] = &[
|
||||
"yes",
|
||||
"yeah",
|
||||
"yep",
|
||||
"yup",
|
||||
"sure",
|
||||
"ok",
|
||||
"okay",
|
||||
"great",
|
||||
"awesome",
|
||||
"perfect",
|
||||
"thanks",
|
||||
"thank",
|
||||
"wonderful",
|
||||
"excellent",
|
||||
"amazing",
|
||||
"nice",
|
||||
"good",
|
||||
"cool",
|
||||
"absolutely",
|
||||
"definitely",
|
||||
"please",
|
||||
];
|
||||
|
||||
pub const CONFIRMATION_PREFIXES: &[&str] = &[
|
||||
"yes",
|
||||
"yeah",
|
||||
"yep",
|
||||
"yup",
|
||||
"correct",
|
||||
"right",
|
||||
"that's correct",
|
||||
"thats correct",
|
||||
"that's right",
|
||||
"thats right",
|
||||
"that is correct",
|
||||
"that is right",
|
||||
];
|
||||
|
||||
const STOPWORD_LIST: &[&str] = &[
|
||||
"a",
|
||||
"about",
|
||||
"above",
|
||||
"after",
|
||||
"again",
|
||||
"against",
|
||||
"all",
|
||||
"am",
|
||||
"an",
|
||||
"and",
|
||||
"any",
|
||||
"are",
|
||||
"as",
|
||||
"at",
|
||||
"be",
|
||||
"because",
|
||||
"been",
|
||||
"before",
|
||||
"being",
|
||||
"below",
|
||||
"between",
|
||||
"both",
|
||||
"but",
|
||||
"by",
|
||||
"can",
|
||||
"could",
|
||||
"did",
|
||||
"do",
|
||||
"does",
|
||||
"doing",
|
||||
"down",
|
||||
"during",
|
||||
"each",
|
||||
"few",
|
||||
"for",
|
||||
"from",
|
||||
"further",
|
||||
"had",
|
||||
"has",
|
||||
"have",
|
||||
"having",
|
||||
"he",
|
||||
"her",
|
||||
"here",
|
||||
"hers",
|
||||
"herself",
|
||||
"him",
|
||||
"himself",
|
||||
"his",
|
||||
"how",
|
||||
"i",
|
||||
"if",
|
||||
"in",
|
||||
"into",
|
||||
"is",
|
||||
"it",
|
||||
"its",
|
||||
"itself",
|
||||
"just",
|
||||
"me",
|
||||
"more",
|
||||
"most",
|
||||
"my",
|
||||
"myself",
|
||||
"no",
|
||||
"nor",
|
||||
"not",
|
||||
"now",
|
||||
"of",
|
||||
"off",
|
||||
"on",
|
||||
"once",
|
||||
"only",
|
||||
"or",
|
||||
"other",
|
||||
"our",
|
||||
"ours",
|
||||
"ourselves",
|
||||
"out",
|
||||
"over",
|
||||
"own",
|
||||
"same",
|
||||
"she",
|
||||
"should",
|
||||
"so",
|
||||
"some",
|
||||
"such",
|
||||
"than",
|
||||
"that",
|
||||
"the",
|
||||
"their",
|
||||
"theirs",
|
||||
"them",
|
||||
"themselves",
|
||||
"then",
|
||||
"there",
|
||||
"these",
|
||||
"they",
|
||||
"this",
|
||||
"those",
|
||||
"through",
|
||||
"to",
|
||||
"too",
|
||||
"under",
|
||||
"until",
|
||||
"up",
|
||||
"very",
|
||||
"was",
|
||||
"we",
|
||||
"were",
|
||||
"what",
|
||||
"when",
|
||||
"where",
|
||||
"which",
|
||||
"while",
|
||||
"who",
|
||||
"whom",
|
||||
"why",
|
||||
"with",
|
||||
"would",
|
||||
"you",
|
||||
"your",
|
||||
"yours",
|
||||
"yourself",
|
||||
"yourselves",
|
||||
];
|
||||
|
||||
pub fn stopwords() -> &'static HashSet<&'static str> {
|
||||
static SET: OnceLock<HashSet<&'static str>> = OnceLock::new();
|
||||
SET.get_or_init(|| STOPWORD_LIST.iter().copied().collect())
|
||||
}
|
||||
|
||||
/// Returns true if `text` (case-insensitive, trimmed) starts with any of the
|
||||
/// given prefixes treated as **whole tokens or token sequences**. This matches
|
||||
/// the Python's `text_lower.startswith(prefix)` plus the natural intent that
|
||||
/// `"please"` shouldn't fire on `"pleased"`.
|
||||
pub fn starts_with_prefix(text: &str, prefixes: &[&str]) -> bool {
|
||||
let lowered = text.to_lowercase();
|
||||
let trimmed = lowered.trim_start();
|
||||
for prefix in prefixes {
|
||||
if trimmed.starts_with(prefix) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
445
crates/brightstaff/src/signals/interaction/disengagement.rs
Normal file
445
crates/brightstaff/src/signals/interaction/disengagement.rs
Normal file
|
|
@ -0,0 +1,445 @@
|
|||
//! Disengagement signals: escalation, quit, negative stance.
|
||||
//!
|
||||
//! Direct port of `signals/interaction/disengagement.py`.
|
||||
|
||||
use std::sync::OnceLock;
|
||||
|
||||
use regex::Regex;
|
||||
use serde_json::json;
|
||||
|
||||
use super::constants::{starts_with_prefix, POSITIVE_PREFIXES};
|
||||
use crate::signals::schemas::{SignalGroup, SignalInstance, SignalType};
|
||||
use crate::signals::text_processing::{normalize_patterns, NormalizedMessage, NormalizedPattern};
|
||||
|
||||
const ESCALATION_PATTERN_TEXTS: &[&str] = &[
|
||||
// Human requests
|
||||
"speak to a human",
|
||||
"talk to a human",
|
||||
"connect me to a human",
|
||||
"connect me with a human",
|
||||
"transfer me to a human",
|
||||
"get me a human",
|
||||
"chat with a human",
|
||||
// Person requests
|
||||
"speak to a person",
|
||||
"talk to a person",
|
||||
"connect me to a person",
|
||||
"connect me with a person",
|
||||
"transfer me to a person",
|
||||
"get me a person",
|
||||
"chat with a person",
|
||||
// Real person requests
|
||||
"speak to a real person",
|
||||
"talk to a real person",
|
||||
"connect me to a real person",
|
||||
"connect me with a real person",
|
||||
"transfer me to a real person",
|
||||
"get me a real person",
|
||||
"chat with a real person",
|
||||
// Actual person requests
|
||||
"speak to an actual person",
|
||||
"talk to an actual person",
|
||||
"connect me to an actual person",
|
||||
"connect me with an actual person",
|
||||
"transfer me to an actual person",
|
||||
"get me an actual person",
|
||||
"chat with an actual person",
|
||||
// Supervisor requests
|
||||
"speak to a supervisor",
|
||||
"talk to a supervisor",
|
||||
"connect me to a supervisor",
|
||||
"connect me with a supervisor",
|
||||
"transfer me to a supervisor",
|
||||
"get me a supervisor",
|
||||
"chat with a supervisor",
|
||||
// Manager requests
|
||||
"speak to a manager",
|
||||
"talk to a manager",
|
||||
"connect me to a manager",
|
||||
"connect me with a manager",
|
||||
"transfer me to a manager",
|
||||
"get me a manager",
|
||||
"chat with a manager",
|
||||
// Customer service requests
|
||||
"speak to customer service",
|
||||
"talk to customer service",
|
||||
"connect me to customer service",
|
||||
"connect me with customer service",
|
||||
"transfer me to customer service",
|
||||
"get me customer service",
|
||||
"chat with customer service",
|
||||
// Customer support requests
|
||||
"speak to customer support",
|
||||
"talk to customer support",
|
||||
"connect me to customer support",
|
||||
"connect me with customer support",
|
||||
"transfer me to customer support",
|
||||
"get me customer support",
|
||||
"chat with customer support",
|
||||
// Support requests
|
||||
"speak to support",
|
||||
"talk to support",
|
||||
"connect me to support",
|
||||
"connect me with support",
|
||||
"transfer me to support",
|
||||
"get me support",
|
||||
"chat with support",
|
||||
// Tech support requests
|
||||
"speak to tech support",
|
||||
"talk to tech support",
|
||||
"connect me to tech support",
|
||||
"connect me with tech support",
|
||||
"transfer me to tech support",
|
||||
"get me tech support",
|
||||
"chat with tech support",
|
||||
// Help desk requests
|
||||
"speak to help desk",
|
||||
"talk to help desk",
|
||||
"connect me to help desk",
|
||||
"connect me with help desk",
|
||||
"transfer me to help desk",
|
||||
"get me help desk",
|
||||
"chat with help desk",
|
||||
// Explicit escalation
|
||||
"escalate this",
|
||||
];
|
||||
|
||||
const QUIT_PATTERN_TEXTS: &[&str] = &[
|
||||
"i give up",
|
||||
"i'm giving up",
|
||||
"im giving up",
|
||||
"i'm going to quit",
|
||||
"i quit",
|
||||
"forget it",
|
||||
"forget this",
|
||||
"screw it",
|
||||
"screw this",
|
||||
"don't bother trying",
|
||||
"don't bother with this",
|
||||
"don't bother with it",
|
||||
"don't even bother",
|
||||
"why bother",
|
||||
"not worth it",
|
||||
"this is hopeless",
|
||||
"going elsewhere",
|
||||
"try somewhere else",
|
||||
"look elsewhere",
|
||||
];
|
||||
|
||||
const NEGATIVE_STANCE_PATTERN_TEXTS: &[&str] = &[
|
||||
"this is useless",
|
||||
"not helpful",
|
||||
"doesn't help",
|
||||
"not helping",
|
||||
"you're not helping",
|
||||
"youre not helping",
|
||||
"this doesn't work",
|
||||
"this doesnt work",
|
||||
"this isn't working",
|
||||
"this isnt working",
|
||||
"still doesn't work",
|
||||
"still doesnt work",
|
||||
"still not working",
|
||||
"still isn't working",
|
||||
"still isnt working",
|
||||
"waste of time",
|
||||
"wasting my time",
|
||||
"this is ridiculous",
|
||||
"this is absurd",
|
||||
"this is insane",
|
||||
"this is stupid",
|
||||
"this is dumb",
|
||||
"this sucks",
|
||||
"this is frustrating",
|
||||
"not good enough",
|
||||
"why can't you",
|
||||
"why cant you",
|
||||
"same issue",
|
||||
"did that already",
|
||||
"done that already",
|
||||
"tried that already",
|
||||
"already tried that",
|
||||
"i've done that",
|
||||
"ive done that",
|
||||
"i've tried that",
|
||||
"ive tried that",
|
||||
"i'm disappointed",
|
||||
"im disappointed",
|
||||
"disappointed with you",
|
||||
"disappointed in you",
|
||||
"useless bot",
|
||||
"dumb bot",
|
||||
"stupid bot",
|
||||
];
|
||||
|
||||
const AGENT_DIRECTED_PROFANITY_PATTERN_TEXTS: &[&str] = &[
|
||||
"this is bullshit",
|
||||
"what bullshit",
|
||||
"such bullshit",
|
||||
"total bullshit",
|
||||
"complete bullshit",
|
||||
"this is crap",
|
||||
"what crap",
|
||||
"this is shit",
|
||||
"what the hell is wrong with you",
|
||||
"what the fuck is wrong with you",
|
||||
"you're fucking useless",
|
||||
"youre fucking useless",
|
||||
"you are fucking useless",
|
||||
"fucking useless",
|
||||
"this bot is shit",
|
||||
"this bot is crap",
|
||||
"damn bot",
|
||||
"fucking bot",
|
||||
"stupid fucking",
|
||||
"are you fucking kidding",
|
||||
"wtf is wrong with you",
|
||||
"wtf is this",
|
||||
"ffs just",
|
||||
"for fucks sake",
|
||||
"for fuck's sake",
|
||||
"what the f**k",
|
||||
"what the f*ck",
|
||||
"what the f***",
|
||||
"that's bullsh*t",
|
||||
"thats bullsh*t",
|
||||
"that's bull***t",
|
||||
"thats bull***t",
|
||||
"that's bs",
|
||||
"thats bs",
|
||||
"this is bullsh*t",
|
||||
"this is bull***t",
|
||||
"this is bs",
|
||||
];
|
||||
|
||||
fn escalation_patterns() -> &'static Vec<NormalizedPattern> {
|
||||
static PATS: OnceLock<Vec<NormalizedPattern>> = OnceLock::new();
|
||||
PATS.get_or_init(|| normalize_patterns(ESCALATION_PATTERN_TEXTS))
|
||||
}
|
||||
|
||||
fn quit_patterns() -> &'static Vec<NormalizedPattern> {
|
||||
static PATS: OnceLock<Vec<NormalizedPattern>> = OnceLock::new();
|
||||
PATS.get_or_init(|| normalize_patterns(QUIT_PATTERN_TEXTS))
|
||||
}
|
||||
|
||||
fn negative_stance_patterns() -> &'static Vec<NormalizedPattern> {
|
||||
static PATS: OnceLock<Vec<NormalizedPattern>> = OnceLock::new();
|
||||
PATS.get_or_init(|| normalize_patterns(NEGATIVE_STANCE_PATTERN_TEXTS))
|
||||
}
|
||||
|
||||
fn profanity_patterns() -> &'static Vec<NormalizedPattern> {
|
||||
static PATS: OnceLock<Vec<NormalizedPattern>> = OnceLock::new();
|
||||
PATS.get_or_init(|| normalize_patterns(AGENT_DIRECTED_PROFANITY_PATTERN_TEXTS))
|
||||
}
|
||||
|
||||
fn re_consecutive_q() -> &'static Regex {
|
||||
static R: OnceLock<Regex> = OnceLock::new();
|
||||
R.get_or_init(|| Regex::new(r"\?{2,}").unwrap())
|
||||
}
|
||||
fn re_consecutive_e() -> &'static Regex {
|
||||
static R: OnceLock<Regex> = OnceLock::new();
|
||||
R.get_or_init(|| Regex::new(r"!{2,}").unwrap())
|
||||
}
|
||||
fn re_mixed_punct() -> &'static Regex {
|
||||
static R: OnceLock<Regex> = OnceLock::new();
|
||||
R.get_or_init(|| Regex::new(r"[?!]{3,}").unwrap())
|
||||
}
|
||||
|
||||
pub fn analyze_disengagement(
|
||||
normalized_messages: &[(usize, &str, NormalizedMessage)],
|
||||
char_ngram_threshold: f32,
|
||||
token_cosine_threshold: f32,
|
||||
) -> SignalGroup {
|
||||
let mut group = SignalGroup::new("disengagement");
|
||||
|
||||
for (idx, role, norm_msg) in normalized_messages {
|
||||
if *role != "human" {
|
||||
continue;
|
||||
}
|
||||
|
||||
let text = &norm_msg.raw;
|
||||
|
||||
// All-caps shouting check.
|
||||
let alpha_chars: String = text.chars().filter(|c| c.is_alphabetic()).collect();
|
||||
if alpha_chars.chars().count() >= 10 {
|
||||
let upper_count = alpha_chars.chars().filter(|c| c.is_uppercase()).count();
|
||||
let upper_ratio = upper_count as f32 / alpha_chars.chars().count() as f32;
|
||||
if upper_ratio >= 0.8 {
|
||||
let snippet: String = text.chars().take(50).collect();
|
||||
group.add_signal(
|
||||
SignalInstance::new(SignalType::DisengagementNegativeStance, *idx, snippet)
|
||||
.with_metadata(json!({
|
||||
"indicator_type": "all_caps",
|
||||
"upper_ratio": upper_ratio,
|
||||
})),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Excessive consecutive punctuation.
|
||||
let starts_with_positive = starts_with_prefix(text, POSITIVE_PREFIXES);
|
||||
let cq = re_consecutive_q().find_iter(text).count();
|
||||
let ce = re_consecutive_e().find_iter(text).count();
|
||||
let mixed = re_mixed_punct().find_iter(text).count();
|
||||
if !starts_with_positive && (cq >= 1 || ce >= 1 || mixed >= 1) {
|
||||
let snippet: String = text.chars().take(50).collect();
|
||||
group.add_signal(
|
||||
SignalInstance::new(SignalType::DisengagementNegativeStance, *idx, snippet)
|
||||
.with_metadata(json!({
|
||||
"indicator_type": "excessive_punctuation",
|
||||
"consecutive_questions": cq,
|
||||
"consecutive_exclamations": ce,
|
||||
"mixed_punctuation": mixed,
|
||||
})),
|
||||
);
|
||||
}
|
||||
|
||||
// Escalation patterns.
|
||||
let mut found_escalation = false;
|
||||
for pattern in escalation_patterns() {
|
||||
if norm_msg.matches_normalized_pattern(
|
||||
pattern,
|
||||
char_ngram_threshold,
|
||||
token_cosine_threshold,
|
||||
) {
|
||||
group.add_signal(
|
||||
SignalInstance::new(
|
||||
SignalType::DisengagementEscalation,
|
||||
*idx,
|
||||
pattern.raw.clone(),
|
||||
)
|
||||
.with_metadata(json!({"pattern_type": "escalation"})),
|
||||
);
|
||||
found_escalation = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Quit patterns (independent of escalation).
|
||||
for pattern in quit_patterns() {
|
||||
if norm_msg.matches_normalized_pattern(
|
||||
pattern,
|
||||
char_ngram_threshold,
|
||||
token_cosine_threshold,
|
||||
) {
|
||||
group.add_signal(
|
||||
SignalInstance::new(SignalType::DisengagementQuit, *idx, pattern.raw.clone())
|
||||
.with_metadata(json!({"pattern_type": "quit"})),
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Profanity (more specific) before generic negative stance.
|
||||
let mut found_profanity = false;
|
||||
for pattern in profanity_patterns() {
|
||||
if norm_msg.matches_normalized_pattern(
|
||||
pattern,
|
||||
char_ngram_threshold,
|
||||
token_cosine_threshold,
|
||||
) {
|
||||
group.add_signal(
|
||||
SignalInstance::new(
|
||||
SignalType::DisengagementNegativeStance,
|
||||
*idx,
|
||||
pattern.raw.clone(),
|
||||
)
|
||||
.with_metadata(json!({
|
||||
"indicator_type": "profanity",
|
||||
"pattern": pattern.raw,
|
||||
})),
|
||||
);
|
||||
found_profanity = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if !found_escalation && !found_profanity {
|
||||
for pattern in negative_stance_patterns() {
|
||||
if norm_msg.matches_normalized_pattern(
|
||||
pattern,
|
||||
char_ngram_threshold,
|
||||
token_cosine_threshold,
|
||||
) {
|
||||
group.add_signal(
|
||||
SignalInstance::new(
|
||||
SignalType::DisengagementNegativeStance,
|
||||
*idx,
|
||||
pattern.raw.clone(),
|
||||
)
|
||||
.with_metadata(json!({
|
||||
"indicator_type": "complaint",
|
||||
"pattern": pattern.raw,
|
||||
})),
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
group
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn nm(s: &str) -> NormalizedMessage {
|
||||
NormalizedMessage::from_text(s, 2000)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_human_escalation_request() {
|
||||
let msgs = vec![(
|
||||
0usize,
|
||||
"human",
|
||||
nm("This is taking forever, get me a human"),
|
||||
)];
|
||||
let g = analyze_disengagement(&msgs, 0.65, 0.6);
|
||||
assert!(g
|
||||
.signals
|
||||
.iter()
|
||||
.any(|s| matches!(s.signal_type, SignalType::DisengagementEscalation)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_quit_intent() {
|
||||
let msgs = vec![(0usize, "human", nm("Forget it, I give up"))];
|
||||
let g = analyze_disengagement(&msgs, 0.65, 0.6);
|
||||
assert!(g
|
||||
.signals
|
||||
.iter()
|
||||
.any(|s| matches!(s.signal_type, SignalType::DisengagementQuit)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_negative_stance_complaint() {
|
||||
let msgs = vec![(0usize, "human", nm("This is useless"))];
|
||||
let g = analyze_disengagement(&msgs, 0.65, 0.6);
|
||||
assert!(g
|
||||
.signals
|
||||
.iter()
|
||||
.any(|s| matches!(s.signal_type, SignalType::DisengagementNegativeStance)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_excessive_punctuation_as_negative_stance() {
|
||||
let msgs = vec![(0usize, "human", nm("WHY isn't this working???"))];
|
||||
let g = analyze_disengagement(&msgs, 0.65, 0.6);
|
||||
assert!(g
|
||||
.signals
|
||||
.iter()
|
||||
.any(|s| matches!(s.signal_type, SignalType::DisengagementNegativeStance)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn positive_excitement_is_not_disengagement() {
|
||||
let msgs = vec![(0usize, "human", nm("Yes!! That's perfect!!!"))];
|
||||
let g = analyze_disengagement(&msgs, 0.65, 0.6);
|
||||
assert!(g
|
||||
.signals
|
||||
.iter()
|
||||
.all(|s| !matches!(s.signal_type, SignalType::DisengagementNegativeStance)));
|
||||
}
|
||||
}
|
||||
338
crates/brightstaff/src/signals/interaction/misalignment.rs
Normal file
338
crates/brightstaff/src/signals/interaction/misalignment.rs
Normal file
|
|
@ -0,0 +1,338 @@
|
|||
//! Misalignment signals: corrections, rephrases, clarifications.
|
||||
//!
|
||||
//! Direct port of `signals/interaction/misalignment.py`.
|
||||
|
||||
use std::sync::OnceLock;
|
||||
|
||||
use serde_json::json;
|
||||
|
||||
use super::constants::{stopwords, CONFIRMATION_PREFIXES};
|
||||
use crate::signals::schemas::{SignalGroup, SignalInstance, SignalType};
|
||||
use crate::signals::text_processing::{normalize_patterns, NormalizedMessage, NormalizedPattern};
|
||||
|
||||
const CORRECTION_PATTERN_TEXTS: &[&str] = &[
|
||||
"no, i meant",
|
||||
"no i meant",
|
||||
"no, i said",
|
||||
"no i said",
|
||||
"no, i asked",
|
||||
"no i asked",
|
||||
"nah, i meant",
|
||||
"nope, i meant",
|
||||
"not what i said",
|
||||
"not what i asked",
|
||||
"that's not what i said",
|
||||
"that's not what i asked",
|
||||
"that's not what i meant",
|
||||
"thats not what i said",
|
||||
"thats not what i asked",
|
||||
"thats not what i meant",
|
||||
"that's not what you",
|
||||
"no that's not what i",
|
||||
"no, that's not what i",
|
||||
"you're not quite right",
|
||||
"youre not quite right",
|
||||
"you're not exactly right",
|
||||
"youre not exactly right",
|
||||
"you're wrong about",
|
||||
"youre wrong about",
|
||||
"i just said",
|
||||
"i already said",
|
||||
"i already told you",
|
||||
];
|
||||
|
||||
const REPHRASE_PATTERN_TEXTS: &[&str] = &[
|
||||
"let me rephrase",
|
||||
"let me explain again",
|
||||
"what i'm trying to say",
|
||||
"what i'm saying is",
|
||||
"in other words",
|
||||
];
|
||||
|
||||
const CLARIFICATION_PATTERN_TEXTS: &[&str] = &[
|
||||
"i don't understand",
|
||||
"don't understand",
|
||||
"not understanding",
|
||||
"can't understand",
|
||||
"don't get it",
|
||||
"don't follow",
|
||||
"i'm confused",
|
||||
"so confused",
|
||||
"makes no sense",
|
||||
"doesn't make sense",
|
||||
"not making sense",
|
||||
"what do you mean",
|
||||
"what does that mean",
|
||||
"what are you saying",
|
||||
"i'm lost",
|
||||
"totally lost",
|
||||
"lost me",
|
||||
"no clue what you",
|
||||
"no idea what you",
|
||||
"no clue what that",
|
||||
"no idea what that",
|
||||
"come again",
|
||||
"say that again",
|
||||
"repeat that",
|
||||
"trouble following",
|
||||
"hard to follow",
|
||||
"can't follow",
|
||||
];
|
||||
|
||||
fn correction_patterns() -> &'static Vec<NormalizedPattern> {
|
||||
static PATS: OnceLock<Vec<NormalizedPattern>> = OnceLock::new();
|
||||
PATS.get_or_init(|| normalize_patterns(CORRECTION_PATTERN_TEXTS))
|
||||
}
|
||||
|
||||
fn rephrase_patterns() -> &'static Vec<NormalizedPattern> {
|
||||
static PATS: OnceLock<Vec<NormalizedPattern>> = OnceLock::new();
|
||||
PATS.get_or_init(|| normalize_patterns(REPHRASE_PATTERN_TEXTS))
|
||||
}
|
||||
|
||||
fn clarification_patterns() -> &'static Vec<NormalizedPattern> {
|
||||
static PATS: OnceLock<Vec<NormalizedPattern>> = OnceLock::new();
|
||||
PATS.get_or_init(|| normalize_patterns(CLARIFICATION_PATTERN_TEXTS))
|
||||
}
|
||||
|
||||
fn is_confirmation_message(text: &str) -> bool {
|
||||
let lowered = text.to_lowercase();
|
||||
let trimmed = lowered.trim();
|
||||
CONFIRMATION_PREFIXES.iter().any(|p| trimmed.starts_with(p))
|
||||
}
|
||||
|
||||
/// Detect whether two user messages appear to be rephrases of each other.
|
||||
pub fn is_similar_rephrase(
|
||||
norm_msg1: &NormalizedMessage,
|
||||
norm_msg2: &NormalizedMessage,
|
||||
overlap_threshold: f32,
|
||||
min_meaningful_tokens: usize,
|
||||
max_new_content_ratio: f32,
|
||||
) -> bool {
|
||||
if norm_msg1.tokens.len() < 3 || norm_msg2.tokens.len() < 3 {
|
||||
return false;
|
||||
}
|
||||
if is_confirmation_message(&norm_msg1.raw) {
|
||||
return false;
|
||||
}
|
||||
|
||||
let stops = stopwords();
|
||||
let tokens1: std::collections::HashSet<&str> = norm_msg1
|
||||
.tokens
|
||||
.iter()
|
||||
.filter(|t| !stops.contains(t.as_str()))
|
||||
.map(|s| s.as_str())
|
||||
.collect();
|
||||
let tokens2: std::collections::HashSet<&str> = norm_msg2
|
||||
.tokens
|
||||
.iter()
|
||||
.filter(|t| !stops.contains(t.as_str()))
|
||||
.map(|s| s.as_str())
|
||||
.collect();
|
||||
|
||||
if tokens1.len() < min_meaningful_tokens || tokens2.len() < min_meaningful_tokens {
|
||||
return false;
|
||||
}
|
||||
|
||||
let new_tokens: std::collections::HashSet<&&str> = tokens1.difference(&tokens2).collect();
|
||||
let new_content_ratio = if tokens1.is_empty() {
|
||||
0.0
|
||||
} else {
|
||||
new_tokens.len() as f32 / tokens1.len() as f32
|
||||
};
|
||||
if new_content_ratio > max_new_content_ratio {
|
||||
return false;
|
||||
}
|
||||
|
||||
let intersection = tokens1.intersection(&tokens2).count();
|
||||
let min_size = tokens1.len().min(tokens2.len());
|
||||
if min_size == 0 {
|
||||
return false;
|
||||
}
|
||||
let overlap_ratio = intersection as f32 / min_size as f32;
|
||||
overlap_ratio >= overlap_threshold
|
||||
}
|
||||
|
||||
/// Analyze user messages for misalignment signals.
|
||||
pub fn analyze_misalignment(
|
||||
normalized_messages: &[(usize, &str, NormalizedMessage)],
|
||||
char_ngram_threshold: f32,
|
||||
token_cosine_threshold: f32,
|
||||
) -> SignalGroup {
|
||||
let mut group = SignalGroup::new("misalignment");
|
||||
|
||||
let mut prev_user_idx: Option<usize> = None;
|
||||
let mut prev_user_msg: Option<&NormalizedMessage> = None;
|
||||
|
||||
for (idx, role, norm_msg) in normalized_messages {
|
||||
if *role != "human" {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut found_in_turn = false;
|
||||
|
||||
for pattern in correction_patterns() {
|
||||
if norm_msg.matches_normalized_pattern(
|
||||
pattern,
|
||||
char_ngram_threshold,
|
||||
token_cosine_threshold,
|
||||
) {
|
||||
group.add_signal(
|
||||
SignalInstance::new(
|
||||
SignalType::MisalignmentCorrection,
|
||||
*idx,
|
||||
pattern.raw.clone(),
|
||||
)
|
||||
.with_metadata(json!({"pattern_type": "correction"})),
|
||||
);
|
||||
found_in_turn = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if found_in_turn {
|
||||
prev_user_idx = Some(*idx);
|
||||
prev_user_msg = Some(norm_msg);
|
||||
continue;
|
||||
}
|
||||
|
||||
for pattern in rephrase_patterns() {
|
||||
if norm_msg.matches_normalized_pattern(
|
||||
pattern,
|
||||
char_ngram_threshold,
|
||||
token_cosine_threshold,
|
||||
) {
|
||||
group.add_signal(
|
||||
SignalInstance::new(
|
||||
SignalType::MisalignmentRephrase,
|
||||
*idx,
|
||||
pattern.raw.clone(),
|
||||
)
|
||||
.with_metadata(json!({"pattern_type": "rephrase"})),
|
||||
);
|
||||
found_in_turn = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if found_in_turn {
|
||||
prev_user_idx = Some(*idx);
|
||||
prev_user_msg = Some(norm_msg);
|
||||
continue;
|
||||
}
|
||||
|
||||
for pattern in clarification_patterns() {
|
||||
if norm_msg.matches_normalized_pattern(
|
||||
pattern,
|
||||
char_ngram_threshold,
|
||||
token_cosine_threshold,
|
||||
) {
|
||||
group.add_signal(
|
||||
SignalInstance::new(
|
||||
SignalType::MisalignmentClarification,
|
||||
*idx,
|
||||
pattern.raw.clone(),
|
||||
)
|
||||
.with_metadata(json!({"pattern_type": "clarification"})),
|
||||
);
|
||||
found_in_turn = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if found_in_turn {
|
||||
prev_user_idx = Some(*idx);
|
||||
prev_user_msg = Some(norm_msg);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Semantic rephrase vs the previous user message (recent only).
|
||||
if let (Some(prev_idx), Some(prev_msg)) = (prev_user_idx, prev_user_msg) {
|
||||
let turns_between = idx.saturating_sub(prev_idx);
|
||||
if turns_between <= 3 && is_similar_rephrase(norm_msg, prev_msg, 0.75, 4, 0.5) {
|
||||
group.add_signal(
|
||||
SignalInstance::new(
|
||||
SignalType::MisalignmentRephrase,
|
||||
*idx,
|
||||
"[similar rephrase detected]",
|
||||
)
|
||||
.with_confidence(0.8)
|
||||
.with_metadata(json!({
|
||||
"pattern_type": "semantic_rephrase",
|
||||
"compared_to": prev_idx,
|
||||
})),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
prev_user_idx = Some(*idx);
|
||||
prev_user_msg = Some(norm_msg);
|
||||
}
|
||||
|
||||
group
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn nm(s: &str) -> NormalizedMessage {
|
||||
NormalizedMessage::from_text(s, 2000)
|
||||
}
|
||||
|
||||
fn make(items: &[(&'static str, &str)]) -> Vec<(usize, &'static str, NormalizedMessage)> {
|
||||
items
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, (role, text))| (i, *role, nm(text)))
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_explicit_correction() {
|
||||
let msgs = make(&[
|
||||
("human", "Show me my orders"),
|
||||
("gpt", "Sure, here are your invoices"),
|
||||
("human", "No, I meant my recent orders"),
|
||||
]);
|
||||
let g = analyze_misalignment(&msgs, 0.65, 0.6);
|
||||
assert!(g
|
||||
.signals
|
||||
.iter()
|
||||
.any(|s| matches!(s.signal_type, SignalType::MisalignmentCorrection)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_rephrase_marker() {
|
||||
let msgs = make(&[
|
||||
("human", "Show me X"),
|
||||
("gpt", "Sure"),
|
||||
("human", "Let me rephrase: I want X grouped by date"),
|
||||
]);
|
||||
let g = analyze_misalignment(&msgs, 0.65, 0.6);
|
||||
assert!(g
|
||||
.signals
|
||||
.iter()
|
||||
.any(|s| matches!(s.signal_type, SignalType::MisalignmentRephrase)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_clarification_request() {
|
||||
let msgs = make(&[
|
||||
("human", "Run the report"),
|
||||
("gpt", "Foobar quux baz."),
|
||||
("human", "I don't understand what you mean"),
|
||||
]);
|
||||
let g = analyze_misalignment(&msgs, 0.65, 0.6);
|
||||
assert!(g
|
||||
.signals
|
||||
.iter()
|
||||
.any(|s| matches!(s.signal_type, SignalType::MisalignmentClarification)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn confirmation_is_not_a_rephrase() {
|
||||
let m1 = nm("Yes, that's correct, please proceed with the order");
|
||||
let m2 = nm("please proceed with the order for the same product");
|
||||
assert!(!is_similar_rephrase(&m1, &m2, 0.75, 4, 0.5));
|
||||
}
|
||||
}
|
||||
10
crates/brightstaff/src/signals/interaction/mod.rs
Normal file
10
crates/brightstaff/src/signals/interaction/mod.rs
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
//! Interaction signals: misalignment, stagnation, disengagement, satisfaction.
|
||||
//!
|
||||
//! These signals capture how the dialogue itself unfolds (semantic alignment,
|
||||
//! progress, engagement, closure) independent of tool execution outcomes.
|
||||
|
||||
pub mod constants;
|
||||
pub mod disengagement;
|
||||
pub mod misalignment;
|
||||
pub mod satisfaction;
|
||||
pub mod stagnation;
|
||||
177
crates/brightstaff/src/signals/interaction/satisfaction.rs
Normal file
177
crates/brightstaff/src/signals/interaction/satisfaction.rs
Normal file
|
|
@ -0,0 +1,177 @@
|
|||
//! Satisfaction signals: gratitude, confirmation, success.
|
||||
//!
|
||||
//! Direct port of `signals/interaction/satisfaction.py`.
|
||||
|
||||
use std::sync::OnceLock;
|
||||
|
||||
use serde_json::json;
|
||||
|
||||
use crate::signals::schemas::{SignalGroup, SignalInstance, SignalType};
|
||||
use crate::signals::text_processing::{normalize_patterns, NormalizedMessage, NormalizedPattern};
|
||||
|
||||
const GRATITUDE_PATTERN_TEXTS: &[&str] = &[
|
||||
"that's helpful",
|
||||
"that helps",
|
||||
"this helps",
|
||||
"appreciate it",
|
||||
"appreciate that",
|
||||
"that's perfect",
|
||||
"exactly what i needed",
|
||||
"just what i needed",
|
||||
"you're the best",
|
||||
"you rock",
|
||||
"you're awesome",
|
||||
"you're amazing",
|
||||
"you're great",
|
||||
];
|
||||
|
||||
const CONFIRMATION_PATTERN_TEXTS: &[&str] = &[
|
||||
"that works",
|
||||
"this works",
|
||||
"that's great",
|
||||
"that's amazing",
|
||||
"this is great",
|
||||
"that's awesome",
|
||||
"love it",
|
||||
"love this",
|
||||
"love that",
|
||||
];
|
||||
|
||||
const SUCCESS_PATTERN_TEXTS: &[&str] = &[
|
||||
"it worked",
|
||||
"that worked",
|
||||
"this worked",
|
||||
"it's working",
|
||||
"that's working",
|
||||
"this is working",
|
||||
];
|
||||
|
||||
fn gratitude_patterns() -> &'static Vec<NormalizedPattern> {
|
||||
static PATS: OnceLock<Vec<NormalizedPattern>> = OnceLock::new();
|
||||
PATS.get_or_init(|| normalize_patterns(GRATITUDE_PATTERN_TEXTS))
|
||||
}
|
||||
|
||||
fn confirmation_patterns() -> &'static Vec<NormalizedPattern> {
|
||||
static PATS: OnceLock<Vec<NormalizedPattern>> = OnceLock::new();
|
||||
PATS.get_or_init(|| normalize_patterns(CONFIRMATION_PATTERN_TEXTS))
|
||||
}
|
||||
|
||||
fn success_patterns() -> &'static Vec<NormalizedPattern> {
|
||||
static PATS: OnceLock<Vec<NormalizedPattern>> = OnceLock::new();
|
||||
PATS.get_or_init(|| normalize_patterns(SUCCESS_PATTERN_TEXTS))
|
||||
}
|
||||
|
||||
pub fn analyze_satisfaction(
|
||||
normalized_messages: &[(usize, &str, NormalizedMessage)],
|
||||
char_ngram_threshold: f32,
|
||||
token_cosine_threshold: f32,
|
||||
) -> SignalGroup {
|
||||
let mut group = SignalGroup::new("satisfaction");
|
||||
|
||||
for (idx, role, norm_msg) in normalized_messages {
|
||||
if *role != "human" {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut found = false;
|
||||
|
||||
for pattern in gratitude_patterns() {
|
||||
if norm_msg.matches_normalized_pattern(
|
||||
pattern,
|
||||
char_ngram_threshold,
|
||||
token_cosine_threshold,
|
||||
) {
|
||||
group.add_signal(
|
||||
SignalInstance::new(
|
||||
SignalType::SatisfactionGratitude,
|
||||
*idx,
|
||||
pattern.raw.clone(),
|
||||
)
|
||||
.with_metadata(json!({"pattern_type": "gratitude"})),
|
||||
);
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if found {
|
||||
continue;
|
||||
}
|
||||
|
||||
for pattern in confirmation_patterns() {
|
||||
if norm_msg.matches_normalized_pattern(
|
||||
pattern,
|
||||
char_ngram_threshold,
|
||||
token_cosine_threshold,
|
||||
) {
|
||||
group.add_signal(
|
||||
SignalInstance::new(
|
||||
SignalType::SatisfactionConfirmation,
|
||||
*idx,
|
||||
pattern.raw.clone(),
|
||||
)
|
||||
.with_metadata(json!({"pattern_type": "confirmation"})),
|
||||
);
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if found {
|
||||
continue;
|
||||
}
|
||||
|
||||
for pattern in success_patterns() {
|
||||
if norm_msg.matches_normalized_pattern(
|
||||
pattern,
|
||||
char_ngram_threshold,
|
||||
token_cosine_threshold,
|
||||
) {
|
||||
group.add_signal(
|
||||
SignalInstance::new(SignalType::SatisfactionSuccess, *idx, pattern.raw.clone())
|
||||
.with_metadata(json!({"pattern_type": "success"})),
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
group
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn nm(s: &str) -> NormalizedMessage {
|
||||
NormalizedMessage::from_text(s, 2000)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_gratitude() {
|
||||
let msgs = vec![(0usize, "human", nm("That's perfect, appreciate it!"))];
|
||||
let g = analyze_satisfaction(&msgs, 0.65, 0.6);
|
||||
assert!(g
|
||||
.signals
|
||||
.iter()
|
||||
.any(|s| matches!(s.signal_type, SignalType::SatisfactionGratitude)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_confirmation() {
|
||||
let msgs = vec![(0usize, "human", nm("That works for me, thanks"))];
|
||||
let g = analyze_satisfaction(&msgs, 0.65, 0.6);
|
||||
assert!(g
|
||||
.signals
|
||||
.iter()
|
||||
.any(|s| matches!(s.signal_type, SignalType::SatisfactionConfirmation)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_success() {
|
||||
let msgs = vec![(0usize, "human", nm("Great, it worked!"))];
|
||||
let g = analyze_satisfaction(&msgs, 0.65, 0.6);
|
||||
assert!(g
|
||||
.signals
|
||||
.iter()
|
||||
.any(|s| matches!(s.signal_type, SignalType::SatisfactionSuccess)));
|
||||
}
|
||||
}
|
||||
241
crates/brightstaff/src/signals/interaction/stagnation.rs
Normal file
241
crates/brightstaff/src/signals/interaction/stagnation.rs
Normal file
|
|
@ -0,0 +1,241 @@
|
|||
//! Stagnation signals: dragging (turn-count efficiency) and repetition.
|
||||
//!
|
||||
//! Direct port of `signals/interaction/stagnation.py`.
|
||||
|
||||
use serde_json::json;
|
||||
|
||||
use super::constants::{starts_with_prefix, POSITIVE_PREFIXES};
|
||||
use crate::signals::schemas::{SignalGroup, SignalInstance, SignalType, TurnMetrics};
|
||||
use crate::signals::text_processing::NormalizedMessage;
|
||||
|
||||
/// Adapter row used by stagnation::dragging detector. Mirrors the ShareGPT
|
||||
/// `{"from": role, "value": text}` shape used in the Python reference.
|
||||
pub struct ShareGptMsg<'a> {
|
||||
pub from: &'a str,
|
||||
}
|
||||
|
||||
pub fn analyze_dragging(
|
||||
messages: &[ShareGptMsg<'_>],
|
||||
baseline_turns: usize,
|
||||
efficiency_threshold: f32,
|
||||
) -> (SignalGroup, TurnMetrics) {
|
||||
let mut group = SignalGroup::new("stagnation");
|
||||
|
||||
let mut user_turns: usize = 0;
|
||||
let mut assistant_turns: usize = 0;
|
||||
for m in messages {
|
||||
match m.from {
|
||||
"human" => user_turns += 1,
|
||||
"gpt" => assistant_turns += 1,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
let total_turns = user_turns;
|
||||
let efficiency_score: f32 = if total_turns == 0 || total_turns <= baseline_turns {
|
||||
1.0
|
||||
} else {
|
||||
let excess = (total_turns - baseline_turns) as f32;
|
||||
1.0 / (1.0 + excess * 0.25)
|
||||
};
|
||||
|
||||
let is_dragging = efficiency_score < efficiency_threshold;
|
||||
let metrics = TurnMetrics {
|
||||
total_turns,
|
||||
user_turns,
|
||||
assistant_turns,
|
||||
is_dragging,
|
||||
efficiency_score,
|
||||
};
|
||||
|
||||
if is_dragging {
|
||||
let last_idx = messages.len().saturating_sub(1);
|
||||
group.add_signal(
|
||||
SignalInstance::new(
|
||||
SignalType::StagnationDragging,
|
||||
last_idx,
|
||||
format!(
|
||||
"Conversation dragging: {} turns (efficiency: {:.2})",
|
||||
total_turns, efficiency_score
|
||||
),
|
||||
)
|
||||
.with_confidence(1.0 - efficiency_score)
|
||||
.with_metadata(json!({
|
||||
"total_turns": total_turns,
|
||||
"efficiency_score": efficiency_score,
|
||||
"baseline_turns": baseline_turns,
|
||||
})),
|
||||
);
|
||||
}
|
||||
|
||||
(group, metrics)
|
||||
}
|
||||
|
||||
pub fn analyze_repetition(
|
||||
normalized_messages: &[(usize, &str, NormalizedMessage)],
|
||||
lookback: usize,
|
||||
exact_threshold: f32,
|
||||
near_duplicate_threshold: f32,
|
||||
) -> SignalGroup {
|
||||
let mut group = SignalGroup::new("stagnation");
|
||||
|
||||
// We keep references into `normalized_messages`. Since `normalized_messages`
|
||||
// is borrowed for the whole function, this avoids cloning.
|
||||
let mut prev_human: Vec<(usize, &NormalizedMessage)> = Vec::new();
|
||||
let mut prev_gpt: Vec<(usize, &NormalizedMessage)> = Vec::new();
|
||||
|
||||
for (idx, role, norm_msg) in normalized_messages {
|
||||
if *role != "human" && *role != "gpt" {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Skip human positive-prefix messages; they're naturally repetitive.
|
||||
if *role == "human" && starts_with_prefix(&norm_msg.raw, POSITIVE_PREFIXES) {
|
||||
prev_human.push((*idx, norm_msg));
|
||||
continue;
|
||||
}
|
||||
|
||||
if norm_msg.tokens.len() < 5 {
|
||||
if *role == "human" {
|
||||
prev_human.push((*idx, norm_msg));
|
||||
} else {
|
||||
prev_gpt.push((*idx, norm_msg));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
let prev = if *role == "human" {
|
||||
&prev_human
|
||||
} else {
|
||||
&prev_gpt
|
||||
};
|
||||
let start = prev.len().saturating_sub(lookback);
|
||||
let mut matched = false;
|
||||
for (prev_idx, prev_msg) in &prev[start..] {
|
||||
if prev_msg.tokens.len() < 5 {
|
||||
continue;
|
||||
}
|
||||
let similarity = norm_msg.ngram_similarity_with_message(prev_msg);
|
||||
if similarity >= exact_threshold {
|
||||
group.add_signal(
|
||||
SignalInstance::new(
|
||||
SignalType::StagnationRepetition,
|
||||
*idx,
|
||||
format!("Exact repetition with message {}", prev_idx),
|
||||
)
|
||||
.with_confidence(similarity)
|
||||
.with_metadata(json!({
|
||||
"repetition_type": "exact",
|
||||
"compared_to": prev_idx,
|
||||
"similarity": similarity,
|
||||
"role": role,
|
||||
})),
|
||||
);
|
||||
matched = true;
|
||||
break;
|
||||
} else if similarity >= near_duplicate_threshold {
|
||||
group.add_signal(
|
||||
SignalInstance::new(
|
||||
SignalType::StagnationRepetition,
|
||||
*idx,
|
||||
format!("Near-duplicate with message {}", prev_idx),
|
||||
)
|
||||
.with_confidence(similarity)
|
||||
.with_metadata(json!({
|
||||
"repetition_type": "near_duplicate",
|
||||
"compared_to": prev_idx,
|
||||
"similarity": similarity,
|
||||
"role": role,
|
||||
})),
|
||||
);
|
||||
matched = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
let _ = matched;
|
||||
|
||||
if *role == "human" {
|
||||
prev_human.push((*idx, norm_msg));
|
||||
} else {
|
||||
prev_gpt.push((*idx, norm_msg));
|
||||
}
|
||||
}
|
||||
|
||||
group
|
||||
}
|
||||
|
||||
/// Combined stagnation analyzer: dragging + repetition.
|
||||
pub fn analyze_stagnation(
|
||||
messages: &[ShareGptMsg<'_>],
|
||||
normalized_messages: &[(usize, &str, NormalizedMessage)],
|
||||
baseline_turns: usize,
|
||||
) -> (SignalGroup, TurnMetrics) {
|
||||
let (dragging_group, metrics) = analyze_dragging(messages, baseline_turns, 0.5);
|
||||
let repetition_group = analyze_repetition(normalized_messages, 2, 0.95, 0.85);
|
||||
|
||||
let mut combined = SignalGroup::new("stagnation");
|
||||
for s in dragging_group.signals.iter().cloned() {
|
||||
combined.add_signal(s);
|
||||
}
|
||||
for s in repetition_group.signals.iter().cloned() {
|
||||
combined.add_signal(s);
|
||||
}
|
||||
(combined, metrics)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn nm(s: &str) -> NormalizedMessage {
|
||||
NormalizedMessage::from_text(s, 2000)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dragging_after_many_user_turns() {
|
||||
let msgs: Vec<_> = (0..15)
|
||||
.flat_map(|_| [ShareGptMsg { from: "human" }, ShareGptMsg { from: "gpt" }])
|
||||
.collect();
|
||||
let (g, m) = analyze_dragging(&msgs, 5, 0.5);
|
||||
assert!(m.is_dragging);
|
||||
assert!(g
|
||||
.signals
|
||||
.iter()
|
||||
.any(|s| matches!(s.signal_type, SignalType::StagnationDragging)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_dragging_below_baseline() {
|
||||
let msgs = vec![
|
||||
ShareGptMsg { from: "human" },
|
||||
ShareGptMsg { from: "gpt" },
|
||||
ShareGptMsg { from: "human" },
|
||||
ShareGptMsg { from: "gpt" },
|
||||
];
|
||||
let (g, m) = analyze_dragging(&msgs, 5, 0.5);
|
||||
assert!(!m.is_dragging);
|
||||
assert!(g.signals.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_exact_repetition_in_user_messages() {
|
||||
let n = vec![
|
||||
(
|
||||
0usize,
|
||||
"human",
|
||||
nm("This widget is broken and needs repair right now"),
|
||||
),
|
||||
(1, "gpt", nm("Sorry to hear that. Let me look into it.")),
|
||||
(
|
||||
2,
|
||||
"human",
|
||||
nm("This widget is broken and needs repair right now"),
|
||||
),
|
||||
];
|
||||
let g = analyze_repetition(&n, 2, 0.95, 0.85);
|
||||
assert!(g
|
||||
.signals
|
||||
.iter()
|
||||
.any(|s| matches!(s.signal_type, SignalType::StagnationRepetition)));
|
||||
}
|
||||
}
|
||||
|
|
@ -1,3 +1,26 @@
|
|||
mod analyzer;
|
||||
//! Plano signals: behavioral quality indicators for agent interactions.
|
||||
//!
|
||||
//! This is a Rust port of the paper-aligned Python reference implementation at
|
||||
//! `https://github.com/katanemo/signals` (or `/Users/shashmi/repos/signals`).
|
||||
//!
|
||||
//! Three layers of signals are detected from a conversation transcript:
|
||||
//!
|
||||
//! - **Interaction**: misalignment, stagnation, disengagement, satisfaction
|
||||
//! - **Execution**: failure, loops
|
||||
//! - **Environment**: exhaustion
|
||||
//!
|
||||
//! See `SignalType` for the full hierarchy.
|
||||
|
||||
pub use analyzer::*;
|
||||
pub mod analyzer;
|
||||
pub mod environment;
|
||||
pub mod execution;
|
||||
pub mod interaction;
|
||||
pub mod otel;
|
||||
pub mod schemas;
|
||||
pub mod text_processing;
|
||||
|
||||
pub use analyzer::{SignalAnalyzer, FLAG_MARKER};
|
||||
pub use schemas::{
|
||||
EnvironmentSignals, ExecutionSignals, InteractionQuality, InteractionSignals, SignalGroup,
|
||||
SignalInstance, SignalLayer, SignalReport, SignalType, TurnMetrics,
|
||||
};
|
||||
|
|
|
|||
241
crates/brightstaff/src/signals/otel.rs
Normal file
241
crates/brightstaff/src/signals/otel.rs
Normal file
|
|
@ -0,0 +1,241 @@
|
|||
//! Helpers for emitting `SignalReport` data to OpenTelemetry spans.
|
||||
//!
|
||||
//! Two sets of attributes are emitted:
|
||||
//!
|
||||
//! - **Legacy** keys under `signals.*` (e.g. `signals.frustration.count`),
|
||||
//! computed from the new layered counts. Preserved for one release for
|
||||
//! backward compatibility with existing dashboards.
|
||||
//! - **New** layered keys (e.g. `signals.interaction.misalignment.count`),
|
||||
//! one set of `count`/`severity` attributes per category, plus per-instance
|
||||
//! span events named `signal.<dotted_signal_type>`.
|
||||
|
||||
use opentelemetry::trace::SpanRef;
|
||||
use opentelemetry::KeyValue;
|
||||
|
||||
use crate::signals::schemas::{SignalGroup, SignalReport, SignalType};
|
||||
|
||||
/// Emit both legacy and layered OTel attributes/events for a `SignalReport`.
|
||||
///
|
||||
/// Returns `true` if any "concerning" signal was found, mirroring the previous
|
||||
/// behavior used to flag the span operation name.
|
||||
pub fn emit_signals_to_span(span: &SpanRef<'_>, report: &SignalReport) -> bool {
|
||||
emit_overall(span, report);
|
||||
emit_layered_attributes(span, report);
|
||||
emit_legacy_attributes(span, report);
|
||||
emit_signal_events(span, report);
|
||||
|
||||
is_concerning(report)
|
||||
}
|
||||
|
||||
fn emit_overall(span: &SpanRef<'_>, report: &SignalReport) {
|
||||
span.set_attribute(KeyValue::new(
|
||||
"signals.quality",
|
||||
report.overall_quality.as_str().to_string(),
|
||||
));
|
||||
span.set_attribute(KeyValue::new(
|
||||
"signals.quality_score",
|
||||
report.quality_score as f64,
|
||||
));
|
||||
span.set_attribute(KeyValue::new(
|
||||
"signals.turn_count",
|
||||
report.turn_metrics.total_turns as i64,
|
||||
));
|
||||
span.set_attribute(KeyValue::new(
|
||||
"signals.efficiency_score",
|
||||
report.turn_metrics.efficiency_score as f64,
|
||||
));
|
||||
}
|
||||
|
||||
fn emit_group(span: &SpanRef<'_>, prefix: &str, group: &SignalGroup) {
|
||||
if group.count == 0 {
|
||||
return;
|
||||
}
|
||||
span.set_attribute(KeyValue::new(
|
||||
format!("{}.count", prefix),
|
||||
group.count as i64,
|
||||
));
|
||||
span.set_attribute(KeyValue::new(
|
||||
format!("{}.severity", prefix),
|
||||
group.severity as i64,
|
||||
));
|
||||
}
|
||||
|
||||
fn emit_layered_attributes(span: &SpanRef<'_>, report: &SignalReport) {
|
||||
emit_group(
|
||||
span,
|
||||
"signals.interaction.misalignment",
|
||||
&report.interaction.misalignment,
|
||||
);
|
||||
emit_group(
|
||||
span,
|
||||
"signals.interaction.stagnation",
|
||||
&report.interaction.stagnation,
|
||||
);
|
||||
emit_group(
|
||||
span,
|
||||
"signals.interaction.disengagement",
|
||||
&report.interaction.disengagement,
|
||||
);
|
||||
emit_group(
|
||||
span,
|
||||
"signals.interaction.satisfaction",
|
||||
&report.interaction.satisfaction,
|
||||
);
|
||||
emit_group(span, "signals.execution.failure", &report.execution.failure);
|
||||
emit_group(span, "signals.execution.loops", &report.execution.loops);
|
||||
emit_group(
|
||||
span,
|
||||
"signals.environment.exhaustion",
|
||||
&report.environment.exhaustion,
|
||||
);
|
||||
}
|
||||
|
||||
fn count_of(report: &SignalReport, t: SignalType) -> usize {
|
||||
report.iter_signals().filter(|s| s.signal_type == t).count()
|
||||
}
|
||||
|
||||
/// Emit the legacy attribute keys consumed by existing dashboards. These are
|
||||
/// derived from the new `SignalReport` so no detector contract is broken.
|
||||
fn emit_legacy_attributes(span: &SpanRef<'_>, report: &SignalReport) {
|
||||
use crate::tracing::signals as legacy;
|
||||
|
||||
// signals.follow_up.repair.{count,ratio} - misalignment proxies repairs.
|
||||
let repair_count = report.interaction.misalignment.count;
|
||||
let user_turns = report.turn_metrics.user_turns.max(1) as f32;
|
||||
if repair_count > 0 {
|
||||
span.set_attribute(KeyValue::new(legacy::REPAIR_COUNT, repair_count as i64));
|
||||
let ratio = repair_count as f32 / user_turns;
|
||||
span.set_attribute(KeyValue::new(legacy::REPAIR_RATIO, format!("{:.3}", ratio)));
|
||||
}
|
||||
|
||||
// signals.frustration.{count,severity} - disengagement.negative_stance is
|
||||
// the closest legacy analog of "frustration".
|
||||
let frustration_count = count_of(report, SignalType::DisengagementNegativeStance);
|
||||
if frustration_count > 0 {
|
||||
span.set_attribute(KeyValue::new(
|
||||
legacy::FRUSTRATION_COUNT,
|
||||
frustration_count as i64,
|
||||
));
|
||||
let severity = match frustration_count {
|
||||
0 => 0,
|
||||
1..=2 => 1,
|
||||
3..=4 => 2,
|
||||
_ => 3,
|
||||
};
|
||||
span.set_attribute(KeyValue::new(legacy::FRUSTRATION_SEVERITY, severity as i64));
|
||||
}
|
||||
|
||||
// signals.repetition.count - stagnation (repetition + dragging).
|
||||
if report.interaction.stagnation.count > 0 {
|
||||
span.set_attribute(KeyValue::new(
|
||||
legacy::REPETITION_COUNT,
|
||||
report.interaction.stagnation.count as i64,
|
||||
));
|
||||
}
|
||||
|
||||
// signals.escalation.requested - any escalation/quit signal.
|
||||
let escalated = report.interaction.disengagement.signals.iter().any(|s| {
|
||||
matches!(
|
||||
s.signal_type,
|
||||
SignalType::DisengagementEscalation | SignalType::DisengagementQuit
|
||||
)
|
||||
});
|
||||
if escalated {
|
||||
span.set_attribute(KeyValue::new(legacy::ESCALATION_REQUESTED, true));
|
||||
}
|
||||
|
||||
// signals.positive_feedback.count - satisfaction signals.
|
||||
if report.interaction.satisfaction.count > 0 {
|
||||
span.set_attribute(KeyValue::new(
|
||||
legacy::POSITIVE_FEEDBACK_COUNT,
|
||||
report.interaction.satisfaction.count as i64,
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
fn emit_signal_events(span: &SpanRef<'_>, report: &SignalReport) {
|
||||
for sig in report.iter_signals() {
|
||||
let event_name = format!("signal.{}", sig.signal_type.as_str());
|
||||
let mut attrs: Vec<KeyValue> = vec![
|
||||
KeyValue::new("signal.type", sig.signal_type.as_str().to_string()),
|
||||
KeyValue::new("signal.message_index", sig.message_index as i64),
|
||||
KeyValue::new("signal.confidence", sig.confidence as f64),
|
||||
];
|
||||
if !sig.snippet.is_empty() {
|
||||
attrs.push(KeyValue::new("signal.snippet", sig.snippet.clone()));
|
||||
}
|
||||
if !sig.metadata.is_null() {
|
||||
attrs.push(KeyValue::new("signal.metadata", sig.metadata.to_string()));
|
||||
}
|
||||
span.add_event(event_name, attrs);
|
||||
}
|
||||
}
|
||||
|
||||
fn is_concerning(report: &SignalReport) -> bool {
|
||||
use crate::signals::schemas::InteractionQuality;
|
||||
if matches!(
|
||||
report.overall_quality,
|
||||
InteractionQuality::Poor | InteractionQuality::Severe
|
||||
) {
|
||||
return true;
|
||||
}
|
||||
if report.interaction.disengagement.count > 0 {
|
||||
return true;
|
||||
}
|
||||
if report.interaction.stagnation.count > 2 {
|
||||
return true;
|
||||
}
|
||||
if report.execution.failure.count > 0 || report.execution.loops.count > 0 {
|
||||
return true;
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::signals::schemas::{
|
||||
EnvironmentSignals, ExecutionSignals, InteractionQuality, InteractionSignals, SignalGroup,
|
||||
SignalInstance, SignalReport, SignalType, TurnMetrics,
|
||||
};
|
||||
|
||||
fn report_with_escalation() -> SignalReport {
|
||||
let mut diseng = SignalGroup::new("disengagement");
|
||||
diseng.add_signal(SignalInstance::new(
|
||||
SignalType::DisengagementEscalation,
|
||||
3,
|
||||
"get me a human",
|
||||
));
|
||||
SignalReport {
|
||||
interaction: InteractionSignals {
|
||||
disengagement: diseng,
|
||||
..InteractionSignals::default()
|
||||
},
|
||||
execution: ExecutionSignals::default(),
|
||||
environment: EnvironmentSignals::default(),
|
||||
overall_quality: InteractionQuality::Severe,
|
||||
quality_score: 0.0,
|
||||
turn_metrics: TurnMetrics {
|
||||
total_turns: 3,
|
||||
user_turns: 2,
|
||||
assistant_turns: 1,
|
||||
is_dragging: false,
|
||||
efficiency_score: 1.0,
|
||||
},
|
||||
summary: String::new(),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_concerning_flags_disengagement() {
|
||||
let r = report_with_escalation();
|
||||
assert!(is_concerning(&r));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn count_of_returns_per_type_count() {
|
||||
let r = report_with_escalation();
|
||||
assert_eq!(count_of(&r, SignalType::DisengagementEscalation), 1);
|
||||
assert_eq!(count_of(&r, SignalType::DisengagementNegativeStance), 0);
|
||||
}
|
||||
}
|
||||
431
crates/brightstaff/src/signals/schemas.rs
Normal file
431
crates/brightstaff/src/signals/schemas.rs
Normal file
|
|
@ -0,0 +1,431 @@
|
|||
//! Data shapes for the signal analyzer.
|
||||
//!
|
||||
//! Mirrors `signals/schemas.py` from the reference implementation. Where the
|
||||
//! Python library exposes a `Dict[str, SignalGroup]` partitioned by category,
|
||||
//! the Rust port uses strongly-typed sub-structs (`InteractionSignals`,
|
||||
//! `ExecutionSignals`, `EnvironmentSignals`) for the same partitioning.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Hierarchical signal type. The 20 leaf variants mirror the paper taxonomy
|
||||
/// and the Python reference's `SignalType` string enum.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub enum SignalType {
|
||||
// Interaction > Misalignment
|
||||
MisalignmentCorrection,
|
||||
MisalignmentRephrase,
|
||||
MisalignmentClarification,
|
||||
|
||||
// Interaction > Stagnation
|
||||
StagnationDragging,
|
||||
StagnationRepetition,
|
||||
|
||||
// Interaction > Disengagement
|
||||
DisengagementEscalation,
|
||||
DisengagementQuit,
|
||||
DisengagementNegativeStance,
|
||||
|
||||
// Interaction > Satisfaction
|
||||
SatisfactionGratitude,
|
||||
SatisfactionConfirmation,
|
||||
SatisfactionSuccess,
|
||||
|
||||
// Execution > Failure
|
||||
ExecutionFailureInvalidArgs,
|
||||
ExecutionFailureBadQuery,
|
||||
ExecutionFailureToolNotFound,
|
||||
ExecutionFailureAuthMisuse,
|
||||
ExecutionFailureStateError,
|
||||
|
||||
// Execution > Loops
|
||||
ExecutionLoopsRetry,
|
||||
ExecutionLoopsParameterDrift,
|
||||
ExecutionLoopsOscillation,
|
||||
|
||||
// Environment > Exhaustion
|
||||
EnvironmentExhaustionApiError,
|
||||
EnvironmentExhaustionTimeout,
|
||||
EnvironmentExhaustionRateLimit,
|
||||
EnvironmentExhaustionNetwork,
|
||||
EnvironmentExhaustionMalformed,
|
||||
EnvironmentExhaustionContextOverflow,
|
||||
}
|
||||
|
||||
impl SignalType {
|
||||
/// Dotted hierarchical string identifier, e.g.
|
||||
/// `"interaction.misalignment.correction"`. Matches the Python reference's
|
||||
/// `SignalType` enum *value* strings byte-for-byte.
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
SignalType::MisalignmentCorrection => "interaction.misalignment.correction",
|
||||
SignalType::MisalignmentRephrase => "interaction.misalignment.rephrase",
|
||||
SignalType::MisalignmentClarification => "interaction.misalignment.clarification",
|
||||
SignalType::StagnationDragging => "interaction.stagnation.dragging",
|
||||
SignalType::StagnationRepetition => "interaction.stagnation.repetition",
|
||||
SignalType::DisengagementEscalation => "interaction.disengagement.escalation",
|
||||
SignalType::DisengagementQuit => "interaction.disengagement.quit",
|
||||
SignalType::DisengagementNegativeStance => "interaction.disengagement.negative_stance",
|
||||
SignalType::SatisfactionGratitude => "interaction.satisfaction.gratitude",
|
||||
SignalType::SatisfactionConfirmation => "interaction.satisfaction.confirmation",
|
||||
SignalType::SatisfactionSuccess => "interaction.satisfaction.success",
|
||||
SignalType::ExecutionFailureInvalidArgs => "execution.failure.invalid_args",
|
||||
SignalType::ExecutionFailureBadQuery => "execution.failure.bad_query",
|
||||
SignalType::ExecutionFailureToolNotFound => "execution.failure.tool_not_found",
|
||||
SignalType::ExecutionFailureAuthMisuse => "execution.failure.auth_misuse",
|
||||
SignalType::ExecutionFailureStateError => "execution.failure.state_error",
|
||||
SignalType::ExecutionLoopsRetry => "execution.loops.retry",
|
||||
SignalType::ExecutionLoopsParameterDrift => "execution.loops.parameter_drift",
|
||||
SignalType::ExecutionLoopsOscillation => "execution.loops.oscillation",
|
||||
SignalType::EnvironmentExhaustionApiError => "environment.exhaustion.api_error",
|
||||
SignalType::EnvironmentExhaustionTimeout => "environment.exhaustion.timeout",
|
||||
SignalType::EnvironmentExhaustionRateLimit => "environment.exhaustion.rate_limit",
|
||||
SignalType::EnvironmentExhaustionNetwork => "environment.exhaustion.network",
|
||||
SignalType::EnvironmentExhaustionMalformed => {
|
||||
"environment.exhaustion.malformed_response"
|
||||
}
|
||||
SignalType::EnvironmentExhaustionContextOverflow => {
|
||||
"environment.exhaustion.context_overflow"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn layer(&self) -> SignalLayer {
|
||||
match self {
|
||||
SignalType::MisalignmentCorrection
|
||||
| SignalType::MisalignmentRephrase
|
||||
| SignalType::MisalignmentClarification
|
||||
| SignalType::StagnationDragging
|
||||
| SignalType::StagnationRepetition
|
||||
| SignalType::DisengagementEscalation
|
||||
| SignalType::DisengagementQuit
|
||||
| SignalType::DisengagementNegativeStance
|
||||
| SignalType::SatisfactionGratitude
|
||||
| SignalType::SatisfactionConfirmation
|
||||
| SignalType::SatisfactionSuccess => SignalLayer::Interaction,
|
||||
SignalType::ExecutionFailureInvalidArgs
|
||||
| SignalType::ExecutionFailureBadQuery
|
||||
| SignalType::ExecutionFailureToolNotFound
|
||||
| SignalType::ExecutionFailureAuthMisuse
|
||||
| SignalType::ExecutionFailureStateError
|
||||
| SignalType::ExecutionLoopsRetry
|
||||
| SignalType::ExecutionLoopsParameterDrift
|
||||
| SignalType::ExecutionLoopsOscillation => SignalLayer::Execution,
|
||||
SignalType::EnvironmentExhaustionApiError
|
||||
| SignalType::EnvironmentExhaustionTimeout
|
||||
| SignalType::EnvironmentExhaustionRateLimit
|
||||
| SignalType::EnvironmentExhaustionNetwork
|
||||
| SignalType::EnvironmentExhaustionMalformed
|
||||
| SignalType::EnvironmentExhaustionContextOverflow => SignalLayer::Environment,
|
||||
}
|
||||
}
|
||||
|
||||
/// Category name within the layer (e.g. `"misalignment"`, `"failure"`).
|
||||
pub fn category(&self) -> &'static str {
|
||||
// Strip the layer prefix and take everything before the next dot.
|
||||
let s = self.as_str();
|
||||
let after_layer = s.split_once('.').map(|(_, rest)| rest).unwrap_or(s);
|
||||
after_layer
|
||||
.split_once('.')
|
||||
.map(|(c, _)| c)
|
||||
.unwrap_or(after_layer)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub enum SignalLayer {
|
||||
Interaction,
|
||||
Execution,
|
||||
Environment,
|
||||
}
|
||||
|
||||
impl SignalLayer {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
SignalLayer::Interaction => "interaction",
|
||||
SignalLayer::Execution => "execution",
|
||||
SignalLayer::Environment => "environment",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Overall quality assessment for an agent interaction session.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum InteractionQuality {
|
||||
Excellent,
|
||||
Good,
|
||||
Neutral,
|
||||
Poor,
|
||||
Severe,
|
||||
}
|
||||
|
||||
impl InteractionQuality {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
InteractionQuality::Excellent => "excellent",
|
||||
InteractionQuality::Good => "good",
|
||||
InteractionQuality::Neutral => "neutral",
|
||||
InteractionQuality::Poor => "poor",
|
||||
InteractionQuality::Severe => "severe",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A single detected signal instance.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SignalInstance {
|
||||
pub signal_type: SignalType,
|
||||
/// Absolute index into the original conversation `Vec<Message>`.
|
||||
pub message_index: usize,
|
||||
pub snippet: String,
|
||||
pub confidence: f32,
|
||||
/// Free-form metadata payload mirroring the Python `Dict[str, Any]`.
|
||||
/// Stored as a JSON object so we can faithfully reproduce the reference's
|
||||
/// flexible per-detector metadata.
|
||||
#[serde(default)]
|
||||
pub metadata: serde_json::Value,
|
||||
}
|
||||
|
||||
impl SignalInstance {
|
||||
pub fn new(signal_type: SignalType, message_index: usize, snippet: impl Into<String>) -> Self {
|
||||
Self {
|
||||
signal_type,
|
||||
message_index,
|
||||
snippet: snippet.into(),
|
||||
confidence: 1.0,
|
||||
metadata: serde_json::Value::Object(serde_json::Map::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_confidence(mut self, c: f32) -> Self {
|
||||
self.confidence = c;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_metadata(mut self, m: serde_json::Value) -> Self {
|
||||
self.metadata = m;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Aggregated signals for a specific category.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SignalGroup {
|
||||
pub category: String,
|
||||
pub count: usize,
|
||||
pub signals: Vec<SignalInstance>,
|
||||
/// Severity level (0-3: none, mild, moderate, severe).
|
||||
pub severity: u8,
|
||||
}
|
||||
|
||||
impl SignalGroup {
|
||||
pub fn new(category: impl Into<String>) -> Self {
|
||||
Self {
|
||||
category: category.into(),
|
||||
count: 0,
|
||||
signals: Vec::new(),
|
||||
severity: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_signal(&mut self, signal: SignalInstance) {
|
||||
self.signals.push(signal);
|
||||
self.count = self.signals.len();
|
||||
self.update_severity();
|
||||
}
|
||||
|
||||
fn update_severity(&mut self) {
|
||||
self.severity = match self.count {
|
||||
0 => 0,
|
||||
1..=2 => 1,
|
||||
3..=4 => 2,
|
||||
_ => 3,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/// Turn count and efficiency metrics, used by stagnation.dragging.
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct TurnMetrics {
|
||||
pub total_turns: usize,
|
||||
pub user_turns: usize,
|
||||
pub assistant_turns: usize,
|
||||
pub is_dragging: bool,
|
||||
pub efficiency_score: f32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct InteractionSignals {
|
||||
pub misalignment: SignalGroup,
|
||||
pub stagnation: SignalGroup,
|
||||
pub disengagement: SignalGroup,
|
||||
pub satisfaction: SignalGroup,
|
||||
}
|
||||
|
||||
impl Default for InteractionSignals {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
misalignment: SignalGroup::new("misalignment"),
|
||||
stagnation: SignalGroup::new("stagnation"),
|
||||
disengagement: SignalGroup::new("disengagement"),
|
||||
satisfaction: SignalGroup::new("satisfaction"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl InteractionSignals {
|
||||
/// Ratio of misalignment instances to user turns. Used as a quality
|
||||
/// scoring input and as a threshold for the "high misalignment rate"
|
||||
/// summary callout. Mirrors `misalignment.count / max(user_turns, 1)`
|
||||
/// from the Python reference's `_assess_quality` and `_generate_summary`.
|
||||
pub fn misalignment_ratio(&self, user_turns: usize) -> f32 {
|
||||
let denom = user_turns.max(1) as f32;
|
||||
self.misalignment.count as f32 / denom
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ExecutionSignals {
|
||||
pub failure: SignalGroup,
|
||||
pub loops: SignalGroup,
|
||||
}
|
||||
|
||||
impl Default for ExecutionSignals {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
failure: SignalGroup::new("failure"),
|
||||
loops: SignalGroup::new("loops"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EnvironmentSignals {
|
||||
pub exhaustion: SignalGroup,
|
||||
}
|
||||
|
||||
impl Default for EnvironmentSignals {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
exhaustion: SignalGroup::new("exhaustion"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Complete signal analysis report for a conversation.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SignalReport {
|
||||
pub interaction: InteractionSignals,
|
||||
pub execution: ExecutionSignals,
|
||||
pub environment: EnvironmentSignals,
|
||||
pub overall_quality: InteractionQuality,
|
||||
pub quality_score: f32,
|
||||
pub turn_metrics: TurnMetrics,
|
||||
pub summary: String,
|
||||
}
|
||||
|
||||
impl Default for SignalReport {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
interaction: InteractionSignals::default(),
|
||||
execution: ExecutionSignals::default(),
|
||||
environment: EnvironmentSignals::default(),
|
||||
overall_quality: InteractionQuality::Neutral,
|
||||
quality_score: 50.0,
|
||||
turn_metrics: TurnMetrics::default(),
|
||||
summary: String::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SignalReport {
|
||||
/// Iterate over every `SignalInstance` across all layers and groups.
|
||||
pub fn iter_signals(&self) -> impl Iterator<Item = &SignalInstance> {
|
||||
self.interaction
|
||||
.misalignment
|
||||
.signals
|
||||
.iter()
|
||||
.chain(self.interaction.stagnation.signals.iter())
|
||||
.chain(self.interaction.disengagement.signals.iter())
|
||||
.chain(self.interaction.satisfaction.signals.iter())
|
||||
.chain(self.execution.failure.signals.iter())
|
||||
.chain(self.execution.loops.signals.iter())
|
||||
.chain(self.environment.exhaustion.signals.iter())
|
||||
}
|
||||
|
||||
pub fn has_signal_type(&self, t: SignalType) -> bool {
|
||||
self.iter_signals().any(|s| s.signal_type == t)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn signal_type_strings_match_paper_taxonomy() {
|
||||
assert_eq!(
|
||||
SignalType::MisalignmentCorrection.as_str(),
|
||||
"interaction.misalignment.correction"
|
||||
);
|
||||
assert_eq!(
|
||||
SignalType::ExecutionFailureInvalidArgs.as_str(),
|
||||
"execution.failure.invalid_args"
|
||||
);
|
||||
assert_eq!(
|
||||
SignalType::EnvironmentExhaustionMalformed.as_str(),
|
||||
"environment.exhaustion.malformed_response"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn signal_type_layer_and_category() {
|
||||
assert_eq!(
|
||||
SignalType::MisalignmentRephrase.layer(),
|
||||
SignalLayer::Interaction
|
||||
);
|
||||
assert_eq!(SignalType::MisalignmentRephrase.category(), "misalignment");
|
||||
assert_eq!(
|
||||
SignalType::ExecutionLoopsRetry.layer(),
|
||||
SignalLayer::Execution
|
||||
);
|
||||
assert_eq!(SignalType::ExecutionLoopsRetry.category(), "loops");
|
||||
assert_eq!(
|
||||
SignalType::EnvironmentExhaustionTimeout.layer(),
|
||||
SignalLayer::Environment
|
||||
);
|
||||
assert_eq!(
|
||||
SignalType::EnvironmentExhaustionTimeout.category(),
|
||||
"exhaustion"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn signal_group_severity_buckets_match_python() {
|
||||
let mut g = SignalGroup::new("misalignment");
|
||||
assert_eq!(g.severity, 0);
|
||||
for n in 1..=2 {
|
||||
g.add_signal(SignalInstance::new(
|
||||
SignalType::MisalignmentCorrection,
|
||||
n,
|
||||
"x",
|
||||
));
|
||||
}
|
||||
assert_eq!(g.severity, 1);
|
||||
for n in 3..=4 {
|
||||
g.add_signal(SignalInstance::new(
|
||||
SignalType::MisalignmentCorrection,
|
||||
n,
|
||||
"x",
|
||||
));
|
||||
}
|
||||
assert_eq!(g.severity, 2);
|
||||
for n in 5..=6 {
|
||||
g.add_signal(SignalInstance::new(
|
||||
SignalType::MisalignmentCorrection,
|
||||
n,
|
||||
"x",
|
||||
));
|
||||
}
|
||||
assert_eq!(g.severity, 3);
|
||||
}
|
||||
}
|
||||
401
crates/brightstaff/src/signals/text_processing.rs
Normal file
401
crates/brightstaff/src/signals/text_processing.rs
Normal file
|
|
@ -0,0 +1,401 @@
|
|||
//! Text normalization and similarity primitives.
|
||||
//!
|
||||
//! Direct Rust port of `signals/text_processing.py` from the reference. The
|
||||
//! shapes (`NormalizedMessage`, `NormalizedPattern`) and similarity formulas
|
||||
//! match the Python implementation exactly so that pattern matching produces
|
||||
//! the same results on the same inputs.
|
||||
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
/// Size of character n-grams used for fuzzy similarity (3 = trigrams).
|
||||
pub const NGRAM_SIZE: usize = 3;
|
||||
|
||||
const PUNCT_TRIM: &[char] = &[
|
||||
'!', '"', '#', '$', '%', '&', '\'', '(', ')', '*', '+', ',', '-', '.', '/', ':', ';', '<', '=',
|
||||
'>', '?', '@', '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~',
|
||||
];
|
||||
|
||||
/// Pre-processed message with normalized text and tokens for efficient matching.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct NormalizedMessage {
|
||||
pub raw: String,
|
||||
pub tokens: Vec<String>,
|
||||
pub token_set: HashSet<String>,
|
||||
pub bigram_set: HashSet<String>,
|
||||
pub char_ngram_set: HashSet<String>,
|
||||
pub token_frequency: HashMap<String, usize>,
|
||||
}
|
||||
|
||||
impl NormalizedMessage {
|
||||
/// Create a normalized message from raw text. Mirrors
|
||||
/// `NormalizedMessage.from_text` in the reference, including the
|
||||
/// head-20%/tail-80% truncation strategy when text exceeds `max_length`.
|
||||
pub fn from_text(text: &str, max_length: usize) -> Self {
|
||||
let char_count = text.chars().count();
|
||||
|
||||
let raw: String = if char_count <= max_length {
|
||||
text.to_string()
|
||||
} else {
|
||||
let head_len = max_length / 5;
|
||||
// Reserve one char for the joining space.
|
||||
let tail_len = max_length.saturating_sub(head_len + 1);
|
||||
let head: String = text.chars().take(head_len).collect();
|
||||
let tail: String = text
|
||||
.chars()
|
||||
.skip(char_count.saturating_sub(tail_len))
|
||||
.collect();
|
||||
format!("{} {}", head, tail)
|
||||
};
|
||||
|
||||
// Normalize unicode punctuation to ASCII equivalents.
|
||||
let normalized_unicode = raw
|
||||
.replace(['\u{2019}', '\u{2018}'], "'")
|
||||
.replace(['\u{201c}', '\u{201d}'], "\"")
|
||||
.replace(['\u{2013}', '\u{2014}'], "-");
|
||||
|
||||
// Lowercase + collapse whitespace (matches Python's `" ".join(s.split())`).
|
||||
let normalized: String = normalized_unicode
|
||||
.to_lowercase()
|
||||
.split_whitespace()
|
||||
.collect::<Vec<_>>()
|
||||
.join(" ");
|
||||
|
||||
let mut tokens: Vec<String> = Vec::new();
|
||||
for word in normalized.split_whitespace() {
|
||||
let stripped: String = word.trim_matches(PUNCT_TRIM).to_string();
|
||||
if !stripped.is_empty() {
|
||||
tokens.push(stripped);
|
||||
}
|
||||
}
|
||||
|
||||
let token_set: HashSet<String> = tokens.iter().cloned().collect();
|
||||
|
||||
let mut bigram_set: HashSet<String> = HashSet::new();
|
||||
for i in 0..tokens.len().saturating_sub(1) {
|
||||
bigram_set.insert(format!("{} {}", tokens[i], tokens[i + 1]));
|
||||
}
|
||||
|
||||
let tokens_text = tokens.join(" ");
|
||||
let char_ngram_set = char_ngrams(&tokens_text, NGRAM_SIZE);
|
||||
|
||||
let mut token_frequency: HashMap<String, usize> = HashMap::new();
|
||||
for t in &tokens {
|
||||
*token_frequency.entry(t.clone()).or_insert(0) += 1;
|
||||
}
|
||||
|
||||
Self {
|
||||
raw,
|
||||
tokens,
|
||||
token_set,
|
||||
bigram_set,
|
||||
char_ngram_set,
|
||||
token_frequency,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn contains_token(&self, token: &str) -> bool {
|
||||
self.token_set.contains(token)
|
||||
}
|
||||
|
||||
pub fn contains_phrase(&self, phrase: &str) -> bool {
|
||||
let phrase_tokens: Vec<&str> = phrase.split_whitespace().collect();
|
||||
if phrase_tokens.is_empty() {
|
||||
return false;
|
||||
}
|
||||
if phrase_tokens.len() == 1 {
|
||||
return self.contains_token(phrase_tokens[0]);
|
||||
}
|
||||
if phrase_tokens.len() > self.tokens.len() {
|
||||
return false;
|
||||
}
|
||||
let n = phrase_tokens.len();
|
||||
for i in 0..=self.tokens.len() - n {
|
||||
if self.tokens[i..i + n]
|
||||
.iter()
|
||||
.zip(phrase_tokens.iter())
|
||||
.all(|(a, b)| a == b)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Character n-gram (Jaccard) similarity vs another normalized message.
|
||||
pub fn ngram_similarity_with_message(&self, other: &NormalizedMessage) -> f32 {
|
||||
jaccard(&self.char_ngram_set, &other.char_ngram_set)
|
||||
}
|
||||
|
||||
/// Character n-gram (Jaccard) similarity vs a raw pattern string.
|
||||
pub fn ngram_similarity_with_pattern(&self, pattern: &str) -> f32 {
|
||||
let normalized = strip_non_word_chars(&pattern.to_lowercase());
|
||||
let pattern_ngrams = char_ngrams(&normalized, NGRAM_SIZE);
|
||||
jaccard(&self.char_ngram_set, &pattern_ngrams)
|
||||
}
|
||||
|
||||
/// Fraction of pattern's ngrams contained in this message's ngram set.
|
||||
pub fn char_ngram_containment(&self, pattern: &str) -> f32 {
|
||||
let normalized = strip_non_word_chars(&pattern.to_lowercase());
|
||||
let pattern_ngrams = char_ngrams(&normalized, NGRAM_SIZE);
|
||||
if pattern_ngrams.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
let contained = pattern_ngrams
|
||||
.iter()
|
||||
.filter(|ng| self.char_ngram_set.contains(*ng))
|
||||
.count();
|
||||
contained as f32 / pattern_ngrams.len() as f32
|
||||
}
|
||||
|
||||
/// Token-frequency cosine similarity vs a raw pattern string.
|
||||
pub fn token_cosine_similarity(&self, pattern: &str) -> f32 {
|
||||
let mut pattern_freq: HashMap<String, usize> = HashMap::new();
|
||||
for word in pattern.to_lowercase().split_whitespace() {
|
||||
let stripped = word.trim_matches(PUNCT_TRIM);
|
||||
if !stripped.is_empty() {
|
||||
*pattern_freq.entry(stripped.to_string()).or_insert(0) += 1;
|
||||
}
|
||||
}
|
||||
cosine_freq(&self.token_frequency, &pattern_freq)
|
||||
}
|
||||
|
||||
/// Layered match against a pre-normalized pattern. Mirrors
|
||||
/// `matches_normalized_pattern` from the reference: exact phrase ->
|
||||
/// char-ngram Jaccard -> token cosine.
|
||||
pub fn matches_normalized_pattern(
|
||||
&self,
|
||||
pattern: &NormalizedPattern,
|
||||
char_ngram_threshold: f32,
|
||||
token_cosine_threshold: f32,
|
||||
) -> bool {
|
||||
// Layer 0: exact phrase match using pre-tokenized message.
|
||||
let plen = pattern.tokens.len();
|
||||
let slen = self.tokens.len();
|
||||
if plen > 0 && plen <= slen {
|
||||
for i in 0..=slen - plen {
|
||||
if self.tokens[i..i + plen] == pattern.tokens[..] {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Layer 1: character n-gram Jaccard similarity.
|
||||
if !self.char_ngram_set.is_empty() && !pattern.char_ngram_set.is_empty() {
|
||||
let inter = self
|
||||
.char_ngram_set
|
||||
.intersection(&pattern.char_ngram_set)
|
||||
.count();
|
||||
let union = self.char_ngram_set.union(&pattern.char_ngram_set).count();
|
||||
if union > 0 {
|
||||
let sim = inter as f32 / union as f32;
|
||||
if sim >= char_ngram_threshold {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Layer 2: token frequency cosine similarity.
|
||||
if !self.token_frequency.is_empty() && !pattern.token_frequency.is_empty() {
|
||||
let sim = cosine_freq(&self.token_frequency, &pattern.token_frequency);
|
||||
if sim >= token_cosine_threshold {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Pre-processed pattern with normalized text and pre-computed n-grams/tokens.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct NormalizedPattern {
|
||||
pub raw: String,
|
||||
pub tokens: Vec<String>,
|
||||
pub char_ngram_set: HashSet<String>,
|
||||
pub token_frequency: HashMap<String, usize>,
|
||||
}
|
||||
|
||||
impl NormalizedPattern {
|
||||
pub fn from_text(pattern: &str) -> Self {
|
||||
let normalized = pattern
|
||||
.to_lowercase()
|
||||
.replace(['\u{2019}', '\u{2018}'], "'")
|
||||
.replace(['\u{201c}', '\u{201d}'], "\"")
|
||||
.replace(['\u{2013}', '\u{2014}'], "-");
|
||||
let normalized: String = normalized.split_whitespace().collect::<Vec<_>>().join(" ");
|
||||
|
||||
// Tokenize the same way as NormalizedMessage (trim boundary punctuation,
|
||||
// keep internal punctuation).
|
||||
let mut tokens: Vec<String> = Vec::new();
|
||||
for word in normalized.split_whitespace() {
|
||||
let stripped = word.trim_matches(PUNCT_TRIM);
|
||||
if !stripped.is_empty() {
|
||||
tokens.push(stripped.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
// For ngrams + cosine, strip ALL punctuation (matches Python's
|
||||
// `re.sub(r"[^\w\s]", "", normalized)`).
|
||||
let normalized_for_ngrams = strip_non_word_chars(&normalized);
|
||||
let char_ngram_set = char_ngrams(&normalized_for_ngrams, NGRAM_SIZE);
|
||||
|
||||
let tokens_no_punct: Vec<&str> = normalized_for_ngrams.split_whitespace().collect();
|
||||
let mut token_frequency: HashMap<String, usize> = HashMap::new();
|
||||
for t in &tokens_no_punct {
|
||||
*token_frequency.entry((*t).to_string()).or_insert(0) += 1;
|
||||
}
|
||||
|
||||
Self {
|
||||
raw: pattern.to_string(),
|
||||
tokens,
|
||||
char_ngram_set,
|
||||
token_frequency,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convenience: normalize a list of raw pattern strings into `NormalizedPattern`s.
|
||||
pub fn normalize_patterns(patterns: &[&str]) -> Vec<NormalizedPattern> {
|
||||
patterns
|
||||
.iter()
|
||||
.map(|p| NormalizedPattern::from_text(p))
|
||||
.collect()
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Similarity primitives
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn char_ngrams(s: &str, n: usize) -> HashSet<String> {
|
||||
// Python iterates by character index, not byte; mirror that with .chars().
|
||||
let chars: Vec<char> = s.chars().collect();
|
||||
let mut out: HashSet<String> = HashSet::new();
|
||||
if chars.len() < n {
|
||||
return out;
|
||||
}
|
||||
for i in 0..=chars.len() - n {
|
||||
out.insert(chars[i..i + n].iter().collect());
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn jaccard(a: &HashSet<String>, b: &HashSet<String>) -> f32 {
|
||||
if a.is_empty() && b.is_empty() {
|
||||
return 1.0;
|
||||
}
|
||||
if a.is_empty() || b.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
let inter = a.intersection(b).count();
|
||||
let union = a.union(b).count();
|
||||
if union == 0 {
|
||||
0.0
|
||||
} else {
|
||||
inter as f32 / union as f32
|
||||
}
|
||||
}
|
||||
|
||||
fn cosine_freq(a: &HashMap<String, usize>, b: &HashMap<String, usize>) -> f32 {
|
||||
if a.is_empty() && b.is_empty() {
|
||||
return 1.0;
|
||||
}
|
||||
if a.is_empty() || b.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
let mut dot: f64 = 0.0;
|
||||
let mut n1_sq: f64 = 0.0;
|
||||
let mut n2_sq: f64 = 0.0;
|
||||
for (token, &freq2) in b {
|
||||
let freq1 = *a.get(token).unwrap_or(&0);
|
||||
dot += (freq1 * freq2) as f64;
|
||||
n2_sq += (freq2 * freq2) as f64;
|
||||
}
|
||||
for &freq1 in a.values() {
|
||||
n1_sq += (freq1 * freq1) as f64;
|
||||
}
|
||||
let n1 = n1_sq.sqrt();
|
||||
let n2 = n2_sq.sqrt();
|
||||
if n1 == 0.0 || n2 == 0.0 {
|
||||
0.0
|
||||
} else {
|
||||
(dot / (n1 * n2)) as f32
|
||||
}
|
||||
}
|
||||
|
||||
/// Python equivalent: `re.sub(r"[^\w\s]", "", text)` followed by whitespace
|
||||
/// collapse. Python's `\w` is `[A-Za-z0-9_]` plus unicode word characters; we
|
||||
/// use Rust's `char::is_alphanumeric()` plus `_` for an equivalent definition.
|
||||
fn strip_non_word_chars(text: &str) -> String {
|
||||
let mut out = String::with_capacity(text.len());
|
||||
for c in text.chars() {
|
||||
if c.is_alphanumeric() || c == '_' || c.is_whitespace() {
|
||||
out.push(c);
|
||||
}
|
||||
}
|
||||
out.split_whitespace().collect::<Vec<_>>().join(" ")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn normalize_lowercases_and_strips_punctuation() {
|
||||
let m = NormalizedMessage::from_text("Hello, World!", 2000);
|
||||
assert_eq!(m.tokens, vec!["hello".to_string(), "world".to_string()]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn normalizes_smart_quotes() {
|
||||
let m = NormalizedMessage::from_text("don\u{2019}t", 2000);
|
||||
assert!(m.tokens.contains(&"don't".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncates_long_text_with_head_tail() {
|
||||
let long = "a".repeat(3000);
|
||||
let m = NormalizedMessage::from_text(&long, 2000);
|
||||
// raw should be ~ 2000 chars (head + space + tail)
|
||||
assert!(m.raw.chars().count() <= 2001);
|
||||
assert!(m.raw.starts_with("aa"));
|
||||
assert!(m.raw.ends_with("aa"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn contains_phrase_matches_consecutive_tokens() {
|
||||
let m = NormalizedMessage::from_text("I think this is great work", 2000);
|
||||
assert!(m.contains_phrase("this is great"));
|
||||
assert!(!m.contains_phrase("great this"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn matches_pattern_via_exact_phrase() {
|
||||
let m = NormalizedMessage::from_text("No, I meant the second one", 2000);
|
||||
let p = NormalizedPattern::from_text("no i meant");
|
||||
assert!(m.matches_normalized_pattern(&p, 0.65, 0.6));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn matches_pattern_via_char_ngram_fuzziness() {
|
||||
// Typo in "meant" -> "ment" so layer 0 (exact phrase) cannot match,
|
||||
// forcing the matcher to fall back to layer 1 (char n-gram Jaccard).
|
||||
let m = NormalizedMessage::from_text("No I ment", 2000);
|
||||
let p = NormalizedPattern::from_text("no i meant");
|
||||
assert!(m.matches_normalized_pattern(&p, 0.4, 0.6));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn jaccard_identical_sets_is_one() {
|
||||
let a: HashSet<String> = ["abc", "bcd"].iter().map(|s| s.to_string()).collect();
|
||||
assert!((jaccard(&a, &a) - 1.0).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cosine_freq_orthogonal_is_zero() {
|
||||
let mut a: HashMap<String, usize> = HashMap::new();
|
||||
a.insert("hello".to_string(), 1);
|
||||
let mut b: HashMap<String, usize> = HashMap::new();
|
||||
b.insert("world".to_string(), 1);
|
||||
assert_eq!(cosine_freq(&a, &b), 0.0);
|
||||
}
|
||||
}
|
||||
|
|
@ -16,10 +16,134 @@ use tracing_opentelemetry::OpenTelemetrySpanExt;
|
|||
use crate::handlers::agents::pipeline::{PipelineError, PipelineProcessor};
|
||||
|
||||
const STREAM_BUFFER_SIZE: usize = 16;
|
||||
use crate::signals::{InteractionQuality, SignalAnalyzer, TextBasedSignalAnalyzer, FLAG_MARKER};
|
||||
use crate::tracing::{llm, set_service_name, signals as signal_constants};
|
||||
/// Cap on accumulated response bytes kept for usage extraction.
|
||||
/// Most chat responses are well under this; pathological ones are dropped without
|
||||
/// affecting pass-through streaming to the client.
|
||||
const USAGE_BUFFER_MAX: usize = 2 * 1024 * 1024;
|
||||
use crate::metrics as bs_metrics;
|
||||
use crate::metrics::labels as metric_labels;
|
||||
use crate::signals::otel::emit_signals_to_span;
|
||||
use crate::signals::{SignalAnalyzer, FLAG_MARKER};
|
||||
use crate::tracing::{llm, set_service_name};
|
||||
use hermesllm::apis::openai::Message;
|
||||
|
||||
/// Parsed usage + resolved-model details from a provider response.
|
||||
#[derive(Debug, Default, Clone)]
|
||||
struct ExtractedUsage {
|
||||
prompt_tokens: Option<i64>,
|
||||
completion_tokens: Option<i64>,
|
||||
total_tokens: Option<i64>,
|
||||
cached_input_tokens: Option<i64>,
|
||||
cache_creation_tokens: Option<i64>,
|
||||
reasoning_tokens: Option<i64>,
|
||||
/// The model the upstream actually used. For router aliases (e.g.
|
||||
/// `router:software-engineering`), this differs from the request model.
|
||||
resolved_model: Option<String>,
|
||||
}
|
||||
|
||||
impl ExtractedUsage {
|
||||
fn is_empty(&self) -> bool {
|
||||
self.prompt_tokens.is_none()
|
||||
&& self.completion_tokens.is_none()
|
||||
&& self.total_tokens.is_none()
|
||||
&& self.resolved_model.is_none()
|
||||
}
|
||||
|
||||
fn from_json(value: &serde_json::Value) -> Self {
|
||||
let mut out = Self::default();
|
||||
if let Some(model) = value.get("model").and_then(|v| v.as_str()) {
|
||||
if !model.is_empty() {
|
||||
out.resolved_model = Some(model.to_string());
|
||||
}
|
||||
}
|
||||
if let Some(u) = value.get("usage") {
|
||||
// OpenAI-shape usage
|
||||
out.prompt_tokens = u.get("prompt_tokens").and_then(|v| v.as_i64());
|
||||
out.completion_tokens = u.get("completion_tokens").and_then(|v| v.as_i64());
|
||||
out.total_tokens = u.get("total_tokens").and_then(|v| v.as_i64());
|
||||
out.cached_input_tokens = u
|
||||
.get("prompt_tokens_details")
|
||||
.and_then(|d| d.get("cached_tokens"))
|
||||
.and_then(|v| v.as_i64());
|
||||
out.reasoning_tokens = u
|
||||
.get("completion_tokens_details")
|
||||
.and_then(|d| d.get("reasoning_tokens"))
|
||||
.and_then(|v| v.as_i64());
|
||||
|
||||
// Anthropic-shape fallbacks
|
||||
if out.prompt_tokens.is_none() {
|
||||
out.prompt_tokens = u.get("input_tokens").and_then(|v| v.as_i64());
|
||||
}
|
||||
if out.completion_tokens.is_none() {
|
||||
out.completion_tokens = u.get("output_tokens").and_then(|v| v.as_i64());
|
||||
}
|
||||
if out.total_tokens.is_none() {
|
||||
if let (Some(p), Some(c)) = (out.prompt_tokens, out.completion_tokens) {
|
||||
out.total_tokens = Some(p + c);
|
||||
}
|
||||
}
|
||||
if out.cached_input_tokens.is_none() {
|
||||
out.cached_input_tokens = u.get("cache_read_input_tokens").and_then(|v| v.as_i64());
|
||||
}
|
||||
if out.cached_input_tokens.is_none() {
|
||||
out.cached_input_tokens =
|
||||
u.get("cached_content_token_count").and_then(|v| v.as_i64());
|
||||
}
|
||||
out.cache_creation_tokens = u
|
||||
.get("cache_creation_input_tokens")
|
||||
.and_then(|v| v.as_i64());
|
||||
if out.reasoning_tokens.is_none() {
|
||||
out.reasoning_tokens = u.get("thoughts_token_count").and_then(|v| v.as_i64());
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to pull usage out of an accumulated response body.
|
||||
/// Handles both a single JSON object (non-streaming) and SSE streams where the
|
||||
/// final `data: {...}` event carries the `usage` field.
|
||||
fn extract_usage_from_bytes(buf: &[u8]) -> ExtractedUsage {
|
||||
if buf.is_empty() {
|
||||
return ExtractedUsage::default();
|
||||
}
|
||||
|
||||
// Fast path: full-body JSON (non-streaming).
|
||||
if let Ok(value) = serde_json::from_slice::<serde_json::Value>(buf) {
|
||||
let u = ExtractedUsage::from_json(&value);
|
||||
if !u.is_empty() {
|
||||
return u;
|
||||
}
|
||||
}
|
||||
|
||||
// SSE path: scan from the end for a `data:` line containing a usage object.
|
||||
let text = match std::str::from_utf8(buf) {
|
||||
Ok(t) => t,
|
||||
Err(_) => return ExtractedUsage::default(),
|
||||
};
|
||||
for line in text.lines().rev() {
|
||||
let trimmed = line.trim_start();
|
||||
let payload = match trimmed.strip_prefix("data:") {
|
||||
Some(p) => p.trim_start(),
|
||||
None => continue,
|
||||
};
|
||||
if payload == "[DONE]" || payload.is_empty() {
|
||||
continue;
|
||||
}
|
||||
if !payload.contains("\"usage\"") {
|
||||
continue;
|
||||
}
|
||||
if let Ok(value) = serde_json::from_str::<serde_json::Value>(payload) {
|
||||
let u = ExtractedUsage::from_json(&value);
|
||||
if !u.is_empty() {
|
||||
return u;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ExtractedUsage::default()
|
||||
}
|
||||
|
||||
/// Trait for processing streaming chunks
|
||||
/// Implementors can inject custom logic during streaming (e.g., hallucination detection, logging)
|
||||
pub trait StreamProcessor: Send + 'static {
|
||||
|
|
@ -51,6 +175,18 @@ impl StreamProcessor for Box<dyn StreamProcessor> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Optional Prometheus-metric context for an LLM upstream call. When present,
|
||||
/// [`ObservableStreamProcessor`] emits `brightstaff_llm_*` metrics at
|
||||
/// first-byte / complete / error callbacks.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LlmMetricsCtx {
|
||||
pub provider: String,
|
||||
pub model: String,
|
||||
/// HTTP status of the upstream response. Used to pick `status_class` and
|
||||
/// `error_class` on `on_complete`.
|
||||
pub upstream_status: u16,
|
||||
}
|
||||
|
||||
/// A processor that tracks streaming metrics
|
||||
pub struct ObservableStreamProcessor {
|
||||
service_name: String,
|
||||
|
|
@ -60,6 +196,12 @@ pub struct ObservableStreamProcessor {
|
|||
start_time: Instant,
|
||||
time_to_first_token: Option<u128>,
|
||||
messages: Option<Vec<Message>>,
|
||||
/// Accumulated response bytes used only for best-effort usage extraction
|
||||
/// on `on_complete`. Capped at `USAGE_BUFFER_MAX`; excess chunks are dropped
|
||||
/// from the buffer (they still pass through to the client).
|
||||
response_buffer: Vec<u8>,
|
||||
llm_metrics: Option<LlmMetricsCtx>,
|
||||
metrics_recorded: bool,
|
||||
}
|
||||
|
||||
impl ObservableStreamProcessor {
|
||||
|
|
@ -93,21 +235,42 @@ impl ObservableStreamProcessor {
|
|||
start_time,
|
||||
time_to_first_token: None,
|
||||
messages,
|
||||
response_buffer: Vec::new(),
|
||||
llm_metrics: None,
|
||||
metrics_recorded: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Attach LLM upstream metric context so the processor emits
|
||||
/// `brightstaff_llm_*` metrics on first-byte / complete / error.
|
||||
pub fn with_llm_metrics(mut self, ctx: LlmMetricsCtx) -> Self {
|
||||
self.llm_metrics = Some(ctx);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamProcessor for ObservableStreamProcessor {
|
||||
fn process_chunk(&mut self, chunk: Bytes) -> Result<Option<Bytes>, String> {
|
||||
self.total_bytes += chunk.len();
|
||||
self.chunk_count += 1;
|
||||
// Accumulate for best-effort usage extraction; drop further chunks once
|
||||
// the cap is reached so we don't retain huge response bodies in memory.
|
||||
if self.response_buffer.len() < USAGE_BUFFER_MAX {
|
||||
let remaining = USAGE_BUFFER_MAX - self.response_buffer.len();
|
||||
let take = chunk.len().min(remaining);
|
||||
self.response_buffer.extend_from_slice(&chunk[..take]);
|
||||
}
|
||||
Ok(Some(chunk))
|
||||
}
|
||||
|
||||
fn on_first_bytes(&mut self) {
|
||||
// Record time to first token (only for streaming)
|
||||
if self.time_to_first_token.is_none() {
|
||||
self.time_to_first_token = Some(self.start_time.elapsed().as_millis());
|
||||
let elapsed = self.start_time.elapsed();
|
||||
self.time_to_first_token = Some(elapsed.as_millis());
|
||||
if let Some(ref ctx) = self.llm_metrics {
|
||||
bs_metrics::record_llm_ttft(&ctx.provider, &ctx.model, elapsed);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -124,77 +287,98 @@ impl StreamProcessor for ObservableStreamProcessor {
|
|||
);
|
||||
}
|
||||
|
||||
// Analyze signals if messages are available and record as span attributes
|
||||
if let Some(ref messages) = self.messages {
|
||||
let analyzer: Box<dyn SignalAnalyzer> = Box::new(TextBasedSignalAnalyzer::new());
|
||||
let report = analyzer.analyze(messages);
|
||||
// Record total duration on the span for the observability console.
|
||||
let duration_ms = self.start_time.elapsed().as_millis() as i64;
|
||||
{
|
||||
let span = tracing::Span::current();
|
||||
let otel_context = span.context();
|
||||
let otel_span = otel_context.span();
|
||||
otel_span.set_attribute(KeyValue::new(llm::DURATION_MS, duration_ms));
|
||||
otel_span.set_attribute(KeyValue::new(llm::RESPONSE_BYTES, self.total_bytes as i64));
|
||||
}
|
||||
|
||||
// Best-effort usage extraction + emission (works for both streaming
|
||||
// SSE and non-streaming JSON responses that include a `usage` object).
|
||||
let usage = extract_usage_from_bytes(&self.response_buffer);
|
||||
if !usage.is_empty() {
|
||||
let span = tracing::Span::current();
|
||||
let otel_context = span.context();
|
||||
let otel_span = otel_context.span();
|
||||
if let Some(v) = usage.prompt_tokens {
|
||||
otel_span.set_attribute(KeyValue::new(llm::PROMPT_TOKENS, v));
|
||||
}
|
||||
if let Some(v) = usage.completion_tokens {
|
||||
otel_span.set_attribute(KeyValue::new(llm::COMPLETION_TOKENS, v));
|
||||
}
|
||||
if let Some(v) = usage.total_tokens {
|
||||
otel_span.set_attribute(KeyValue::new(llm::TOTAL_TOKENS, v));
|
||||
}
|
||||
if let Some(v) = usage.cached_input_tokens {
|
||||
otel_span.set_attribute(KeyValue::new(llm::CACHED_INPUT_TOKENS, v));
|
||||
}
|
||||
if let Some(v) = usage.cache_creation_tokens {
|
||||
otel_span.set_attribute(KeyValue::new(llm::CACHE_CREATION_TOKENS, v));
|
||||
}
|
||||
if let Some(v) = usage.reasoning_tokens {
|
||||
otel_span.set_attribute(KeyValue::new(llm::REASONING_TOKENS, v));
|
||||
}
|
||||
// Override `llm.model` with the model the upstream actually ran
|
||||
// (e.g. `openai-gpt-5.4` resolved from `router:software-engineering`).
|
||||
// Cost lookup keys off the real model, not the alias.
|
||||
if let Some(resolved) = usage.resolved_model.clone() {
|
||||
otel_span.set_attribute(KeyValue::new(llm::MODEL_NAME, resolved));
|
||||
}
|
||||
}
|
||||
|
||||
// Emit LLM upstream prometheus metrics (duration + tokens) if wired.
|
||||
// The upstream responded (we have a status), so status_class alone
|
||||
// carries the non-2xx signal — error_class stays "none".
|
||||
if let Some(ref ctx) = self.llm_metrics {
|
||||
bs_metrics::record_llm_upstream(
|
||||
&ctx.provider,
|
||||
&ctx.model,
|
||||
ctx.upstream_status,
|
||||
metric_labels::LLM_ERR_NONE,
|
||||
self.start_time.elapsed(),
|
||||
);
|
||||
if let Some(v) = usage.prompt_tokens {
|
||||
bs_metrics::record_llm_tokens(
|
||||
&ctx.provider,
|
||||
&ctx.model,
|
||||
metric_labels::TOKEN_KIND_PROMPT,
|
||||
v.max(0) as u64,
|
||||
);
|
||||
}
|
||||
if let Some(v) = usage.completion_tokens {
|
||||
bs_metrics::record_llm_tokens(
|
||||
&ctx.provider,
|
||||
&ctx.model,
|
||||
metric_labels::TOKEN_KIND_COMPLETION,
|
||||
v.max(0) as u64,
|
||||
);
|
||||
}
|
||||
if usage.prompt_tokens.is_none() && usage.completion_tokens.is_none() {
|
||||
bs_metrics::record_llm_tokens_usage_missing(&ctx.provider, &ctx.model);
|
||||
}
|
||||
self.metrics_recorded = true;
|
||||
}
|
||||
// Release the buffered bytes early; nothing downstream needs them.
|
||||
self.response_buffer.clear();
|
||||
self.response_buffer.shrink_to_fit();
|
||||
|
||||
// Analyze signals if messages are available and record as span
|
||||
// attributes + per-signal events. We dual-emit legacy aggregate keys
|
||||
// and the new layered taxonomy so existing dashboards keep working
|
||||
// while new consumers can opt into the richer hierarchy.
|
||||
if let Some(ref messages) = self.messages {
|
||||
let analyzer = SignalAnalyzer::default();
|
||||
let report = analyzer.analyze_openai(messages);
|
||||
|
||||
// Get the current OTel span to set signal attributes
|
||||
let span = tracing::Span::current();
|
||||
let otel_context = span.context();
|
||||
let otel_span = otel_context.span();
|
||||
|
||||
// Add overall quality
|
||||
otel_span.set_attribute(KeyValue::new(
|
||||
signal_constants::QUALITY,
|
||||
format!("{:?}", report.overall_quality),
|
||||
));
|
||||
|
||||
// Add repair/follow-up metrics if concerning
|
||||
if report.follow_up.is_concerning || report.follow_up.repair_count > 0 {
|
||||
otel_span.set_attribute(KeyValue::new(
|
||||
signal_constants::REPAIR_COUNT,
|
||||
report.follow_up.repair_count as i64,
|
||||
));
|
||||
otel_span.set_attribute(KeyValue::new(
|
||||
signal_constants::REPAIR_RATIO,
|
||||
format!("{:.3}", report.follow_up.repair_ratio),
|
||||
));
|
||||
}
|
||||
|
||||
// Add frustration metrics
|
||||
if report.frustration.has_frustration {
|
||||
otel_span.set_attribute(KeyValue::new(
|
||||
signal_constants::FRUSTRATION_COUNT,
|
||||
report.frustration.frustration_count as i64,
|
||||
));
|
||||
otel_span.set_attribute(KeyValue::new(
|
||||
signal_constants::FRUSTRATION_SEVERITY,
|
||||
report.frustration.severity as i64,
|
||||
));
|
||||
}
|
||||
|
||||
// Add repetition metrics
|
||||
if report.repetition.has_looping {
|
||||
otel_span.set_attribute(KeyValue::new(
|
||||
signal_constants::REPETITION_COUNT,
|
||||
report.repetition.repetition_count as i64,
|
||||
));
|
||||
}
|
||||
|
||||
// Add escalation metrics
|
||||
if report.escalation.escalation_requested {
|
||||
otel_span
|
||||
.set_attribute(KeyValue::new(signal_constants::ESCALATION_REQUESTED, true));
|
||||
}
|
||||
|
||||
// Add positive feedback metrics
|
||||
if report.positive_feedback.has_positive_feedback {
|
||||
otel_span.set_attribute(KeyValue::new(
|
||||
signal_constants::POSITIVE_FEEDBACK_COUNT,
|
||||
report.positive_feedback.positive_count as i64,
|
||||
));
|
||||
}
|
||||
|
||||
// Flag the span name if any concerning signal is detected
|
||||
let should_flag = report.frustration.has_frustration
|
||||
|| report.repetition.has_looping
|
||||
|| report.escalation.escalation_requested
|
||||
|| matches!(
|
||||
report.overall_quality,
|
||||
InteractionQuality::Poor | InteractionQuality::Severe
|
||||
);
|
||||
|
||||
let should_flag = emit_signals_to_span(&otel_span, &report);
|
||||
if should_flag {
|
||||
otel_span.update_name(format!("{} {}", self.operation_name, FLAG_MARKER));
|
||||
}
|
||||
|
|
@ -217,6 +401,18 @@ impl StreamProcessor for ObservableStreamProcessor {
|
|||
duration_ms = self.start_time.elapsed().as_millis(),
|
||||
"stream error"
|
||||
);
|
||||
if let Some(ref ctx) = self.llm_metrics {
|
||||
if !self.metrics_recorded {
|
||||
bs_metrics::record_llm_upstream(
|
||||
&ctx.provider,
|
||||
&ctx.model,
|
||||
ctx.upstream_status,
|
||||
metric_labels::LLM_ERR_STREAM,
|
||||
self.start_time.elapsed(),
|
||||
);
|
||||
self.metrics_recorded = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -404,3 +600,55 @@ pub fn truncate_message(message: &str, max_length: usize) -> String {
|
|||
message.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod usage_extraction_tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn non_streaming_openai_with_cached() {
|
||||
let body = br#"{"id":"x","model":"gpt-4o","choices":[],"usage":{"prompt_tokens":12,"completion_tokens":34,"total_tokens":46,"prompt_tokens_details":{"cached_tokens":5}}}"#;
|
||||
let u = extract_usage_from_bytes(body);
|
||||
assert_eq!(u.prompt_tokens, Some(12));
|
||||
assert_eq!(u.completion_tokens, Some(34));
|
||||
assert_eq!(u.total_tokens, Some(46));
|
||||
assert_eq!(u.cached_input_tokens, Some(5));
|
||||
assert_eq!(u.reasoning_tokens, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_streaming_anthropic_with_cache_creation() {
|
||||
let body = br#"{"id":"x","model":"claude","usage":{"input_tokens":100,"output_tokens":50,"cache_creation_input_tokens":20,"cache_read_input_tokens":30}}"#;
|
||||
let u = extract_usage_from_bytes(body);
|
||||
assert_eq!(u.prompt_tokens, Some(100));
|
||||
assert_eq!(u.completion_tokens, Some(50));
|
||||
assert_eq!(u.total_tokens, Some(150));
|
||||
assert_eq!(u.cached_input_tokens, Some(30));
|
||||
assert_eq!(u.cache_creation_tokens, Some(20));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn streaming_openai_final_chunk_has_usage() {
|
||||
let sse = b"data: {\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}
|
||||
|
||||
data: {\"choices\":[{\"delta\":{}, \"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":7,\"completion_tokens\":3,\"total_tokens\":10}}
|
||||
|
||||
data: [DONE]
|
||||
|
||||
";
|
||||
let u = extract_usage_from_bytes(sse);
|
||||
assert_eq!(u.prompt_tokens, Some(7));
|
||||
assert_eq!(u.completion_tokens, Some(3));
|
||||
assert_eq!(u.total_tokens, Some(10));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_returns_default() {
|
||||
assert!(extract_usage_from_bytes(b"").is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_usage_in_body_returns_default() {
|
||||
assert!(extract_usage_from_bytes(br#"{"ok":true}"#).is_empty());
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -80,6 +80,18 @@ pub mod llm {
|
|||
/// Total tokens used (prompt + completion)
|
||||
pub const TOTAL_TOKENS: &str = "llm.usage.total_tokens";
|
||||
|
||||
/// Tokens served from a prompt cache read
|
||||
/// (OpenAI `prompt_tokens_details.cached_tokens`, Anthropic `cache_read_input_tokens`,
|
||||
/// Google `cached_content_token_count`)
|
||||
pub const CACHED_INPUT_TOKENS: &str = "llm.usage.cached_input_tokens";
|
||||
|
||||
/// Tokens used to write a prompt cache entry (Anthropic `cache_creation_input_tokens`)
|
||||
pub const CACHE_CREATION_TOKENS: &str = "llm.usage.cache_creation_tokens";
|
||||
|
||||
/// Reasoning tokens for reasoning models
|
||||
/// (OpenAI `completion_tokens_details.reasoning_tokens`, Google `thoughts_token_count`)
|
||||
pub const REASONING_TOKENS: &str = "llm.usage.reasoning_tokens";
|
||||
|
||||
/// Temperature parameter used
|
||||
pub const TEMPERATURE: &str = "llm.temperature";
|
||||
|
||||
|
|
@ -119,6 +131,22 @@ pub mod routing {
|
|||
pub const SELECTION_REASON: &str = "routing.selection_reason";
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Span Attributes - Plano-specific
|
||||
// =============================================================================
|
||||
|
||||
/// Attributes specific to Plano (session affinity, routing decisions).
|
||||
pub mod plano {
|
||||
/// Session identifier propagated via the `x-model-affinity` header.
|
||||
/// Absent when the client did not send the header.
|
||||
pub const SESSION_ID: &str = "plano.session_id";
|
||||
|
||||
/// Matched route name from routing (e.g. "code", "summarization",
|
||||
/// "software-engineering"). Absent when the client routed directly
|
||||
/// to a concrete model.
|
||||
pub const ROUTE_NAME: &str = "plano.route.name";
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Span Attributes - Error Handling
|
||||
// =============================================================================
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ mod init;
|
|||
mod service_name_exporter;
|
||||
|
||||
pub use constants::{
|
||||
error, http, llm, operation_component, routing, signals, OperationNameBuilder,
|
||||
error, http, llm, operation_component, plano, routing, signals, OperationNameBuilder,
|
||||
};
|
||||
pub use custom_attributes::collect_custom_trace_attributes;
|
||||
pub use init::init_tracer;
|
||||
|
|
|
|||
|
|
@ -234,6 +234,7 @@ pub struct Overrides {
|
|||
pub llm_routing_model: Option<String>,
|
||||
pub agent_orchestration_model: Option<String>,
|
||||
pub orchestrator_model_context_length: Option<usize>,
|
||||
pub disable_signals: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
|
|
@ -395,6 +396,8 @@ pub enum LlmProviderType {
|
|||
Vercel,
|
||||
#[serde(rename = "openrouter")]
|
||||
OpenRouter,
|
||||
#[serde(rename = "digitalocean")]
|
||||
DigitalOcean,
|
||||
}
|
||||
|
||||
impl Display for LlmProviderType {
|
||||
|
|
@ -418,6 +421,7 @@ impl Display for LlmProviderType {
|
|||
LlmProviderType::Plano => write!(f, "plano"),
|
||||
LlmProviderType::Vercel => write!(f, "vercel"),
|
||||
LlmProviderType::OpenRouter => write!(f, "openrouter"),
|
||||
LlmProviderType::DigitalOcean => write!(f, "digitalocean"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -753,4 +757,29 @@ mod test {
|
|||
assert!(model_ids.contains(&"openai-gpt4".to_string()));
|
||||
assert!(!model_ids.contains(&"plano-orchestrator".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_overrides_disable_signals_default_none() {
|
||||
let overrides = super::Overrides::default();
|
||||
assert_eq!(overrides.disable_signals, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_overrides_disable_signals_deserialize() {
|
||||
let yaml = r#"
|
||||
disable_signals: true
|
||||
"#;
|
||||
let overrides: super::Overrides = serde_yaml::from_str(yaml).unwrap();
|
||||
assert_eq!(overrides.disable_signals, Some(true));
|
||||
|
||||
let yaml_false = r#"
|
||||
disable_signals: false
|
||||
"#;
|
||||
let overrides: super::Overrides = serde_yaml::from_str(yaml_false).unwrap();
|
||||
assert_eq!(overrides.disable_signals, Some(false));
|
||||
|
||||
let yaml_missing = "{}";
|
||||
let overrides: super::Overrides = serde_yaml::from_str(yaml_missing).unwrap();
|
||||
assert_eq!(overrides.disable_signals, None);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -435,6 +435,12 @@ impl TokenUsage for MessagesResponse {
|
|||
fn total_tokens(&self) -> usize {
|
||||
(self.usage.input_tokens + self.usage.output_tokens) as usize
|
||||
}
|
||||
fn cached_input_tokens(&self) -> Option<usize> {
|
||||
self.usage.cache_read_input_tokens.map(|t| t as usize)
|
||||
}
|
||||
fn cache_creation_tokens(&self) -> Option<usize> {
|
||||
self.usage.cache_creation_input_tokens.map(|t| t as usize)
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderResponse for MessagesResponse {
|
||||
|
|
|
|||
|
|
@ -596,6 +596,18 @@ impl TokenUsage for Usage {
|
|||
fn total_tokens(&self) -> usize {
|
||||
self.total_tokens as usize
|
||||
}
|
||||
|
||||
fn cached_input_tokens(&self) -> Option<usize> {
|
||||
self.prompt_tokens_details
|
||||
.as_ref()
|
||||
.and_then(|d| d.cached_tokens.map(|t| t as usize))
|
||||
}
|
||||
|
||||
fn reasoning_tokens(&self) -> Option<usize> {
|
||||
self.completion_tokens_details
|
||||
.as_ref()
|
||||
.and_then(|d| d.reasoning_tokens.map(|t| t as usize))
|
||||
}
|
||||
}
|
||||
|
||||
/// Implementation of ProviderRequest for ChatCompletionsRequest
|
||||
|
|
|
|||
|
|
@ -710,6 +710,18 @@ impl crate::providers::response::TokenUsage for ResponseUsage {
|
|||
fn total_tokens(&self) -> usize {
|
||||
self.total_tokens as usize
|
||||
}
|
||||
|
||||
fn cached_input_tokens(&self) -> Option<usize> {
|
||||
self.input_tokens_details
|
||||
.as_ref()
|
||||
.map(|d| d.cached_tokens.max(0) as usize)
|
||||
}
|
||||
|
||||
fn reasoning_tokens(&self) -> Option<usize> {
|
||||
self.output_tokens_details
|
||||
.as_ref()
|
||||
.map(|d| d.reasoning_tokens.max(0) as usize)
|
||||
}
|
||||
}
|
||||
|
||||
/// Token details
|
||||
|
|
|
|||
|
|
@ -1,6 +1,9 @@
|
|||
use crate::apis::anthropic::MessagesStreamEvent;
|
||||
use crate::apis::anthropic::{
|
||||
MessagesMessageDelta, MessagesStopReason, MessagesStreamEvent, MessagesUsage,
|
||||
};
|
||||
use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait};
|
||||
use crate::providers::streaming_response::ProviderStreamResponseType;
|
||||
use log::warn;
|
||||
use std::collections::HashSet;
|
||||
|
||||
/// SSE Stream Buffer for Anthropic Messages API streaming.
|
||||
|
|
@ -11,13 +14,24 @@ use std::collections::HashSet;
|
|||
///
|
||||
/// When converting from OpenAI to Anthropic format, this buffer injects the required
|
||||
/// ContentBlockStart and ContentBlockStop events to maintain proper Anthropic protocol.
|
||||
///
|
||||
/// Guarantees (Anthropic Messages API contract):
|
||||
/// 1. `message_stop` is never emitted unless a matching `message_start` was emitted first.
|
||||
/// 2. `message_stop` is emitted at most once per stream (no double-close).
|
||||
/// 3. If upstream terminates with no content (empty/filtered/errored response), a
|
||||
/// minimal but well-formed envelope is synthesized so the client's state machine
|
||||
/// stays consistent.
|
||||
pub struct AnthropicMessagesStreamBuffer {
|
||||
/// Buffered SSE events ready to be written to wire
|
||||
buffered_events: Vec<SseEvent>,
|
||||
|
||||
/// Track if we've seen a message_start event
|
||||
/// Track if we've emitted a message_start event
|
||||
message_started: bool,
|
||||
|
||||
/// Track if we've emitted a terminal message_stop event (for idempotency /
|
||||
/// double-close protection).
|
||||
message_stopped: bool,
|
||||
|
||||
/// Track content block indices that have received ContentBlockStart events
|
||||
content_block_start_indices: HashSet<i32>,
|
||||
|
||||
|
|
@ -42,6 +56,7 @@ impl AnthropicMessagesStreamBuffer {
|
|||
Self {
|
||||
buffered_events: Vec::new(),
|
||||
message_started: false,
|
||||
message_stopped: false,
|
||||
content_block_start_indices: HashSet::new(),
|
||||
needs_content_block_stop: false,
|
||||
seen_message_delta: false,
|
||||
|
|
@ -49,6 +64,66 @@ impl AnthropicMessagesStreamBuffer {
|
|||
}
|
||||
}
|
||||
|
||||
/// Inject a `message_start` event into the buffer if one hasn't been emitted yet.
|
||||
/// This is the single source of truth for opening a message — every handler
|
||||
/// that can legitimately be the first event on the wire must call this before
|
||||
/// pushing its own event.
|
||||
fn ensure_message_started(&mut self) {
|
||||
if self.message_started {
|
||||
return;
|
||||
}
|
||||
let model = self.model.as_deref().unwrap_or("unknown");
|
||||
let message_start = AnthropicMessagesStreamBuffer::create_message_start_event(model);
|
||||
self.buffered_events.push(message_start);
|
||||
self.message_started = true;
|
||||
}
|
||||
|
||||
/// Inject a synthetic `message_delta` with `end_turn` / zero usage.
|
||||
/// Used when we must close a message but upstream never produced a terminal
|
||||
/// event (e.g. `[DONE]` arrives with no prior `finish_reason`).
|
||||
fn push_synthetic_message_delta(&mut self) {
|
||||
let event = MessagesStreamEvent::MessageDelta {
|
||||
delta: MessagesMessageDelta {
|
||||
stop_reason: MessagesStopReason::EndTurn,
|
||||
stop_sequence: None,
|
||||
},
|
||||
usage: MessagesUsage {
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
},
|
||||
};
|
||||
let sse_string: String = event.clone().into();
|
||||
self.buffered_events.push(SseEvent {
|
||||
data: None,
|
||||
event: Some("message_delta".to_string()),
|
||||
raw_line: sse_string.clone(),
|
||||
sse_transformed_lines: sse_string,
|
||||
provider_stream_response: Some(ProviderStreamResponseType::MessagesStreamEvent(event)),
|
||||
});
|
||||
self.seen_message_delta = true;
|
||||
}
|
||||
|
||||
/// Inject a `message_stop` event into the buffer, marking the stream as closed.
|
||||
/// Idempotent — subsequent calls are no-ops.
|
||||
fn push_message_stop(&mut self) {
|
||||
if self.message_stopped {
|
||||
return;
|
||||
}
|
||||
let message_stop = MessagesStreamEvent::MessageStop;
|
||||
let sse_string: String = message_stop.into();
|
||||
self.buffered_events.push(SseEvent {
|
||||
data: None,
|
||||
event: Some("message_stop".to_string()),
|
||||
raw_line: sse_string.clone(),
|
||||
sse_transformed_lines: sse_string,
|
||||
provider_stream_response: None,
|
||||
});
|
||||
self.message_stopped = true;
|
||||
self.seen_message_delta = false;
|
||||
}
|
||||
|
||||
/// Check if a content_block_start event has been sent for the given index
|
||||
fn has_content_block_start_been_sent(&self, index: i32) -> bool {
|
||||
self.content_block_start_indices.contains(&index)
|
||||
|
|
@ -149,6 +224,27 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
|
|||
// We match on a reference first to determine the type, then move the event
|
||||
match &event.provider_stream_response {
|
||||
Some(ProviderStreamResponseType::MessagesStreamEvent(evt)) => {
|
||||
// If the message has already been closed, drop any trailing events
|
||||
// to avoid emitting data after `message_stop` (protocol violation).
|
||||
// This typically indicates a duplicate `[DONE]` from upstream or a
|
||||
// replay of previously-buffered bytes — worth surfacing so we can
|
||||
// spot misbehaving providers.
|
||||
if self.message_stopped {
|
||||
warn!(
|
||||
"anthropic stream buffer: dropping event after message_stop (variant={})",
|
||||
match evt {
|
||||
MessagesStreamEvent::MessageStart { .. } => "message_start",
|
||||
MessagesStreamEvent::ContentBlockStart { .. } => "content_block_start",
|
||||
MessagesStreamEvent::ContentBlockDelta { .. } => "content_block_delta",
|
||||
MessagesStreamEvent::ContentBlockStop { .. } => "content_block_stop",
|
||||
MessagesStreamEvent::MessageDelta { .. } => "message_delta",
|
||||
MessagesStreamEvent::MessageStop => "message_stop",
|
||||
MessagesStreamEvent::Ping => "ping",
|
||||
}
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
match evt {
|
||||
MessagesStreamEvent::MessageStart { .. } => {
|
||||
// Add the message_start event
|
||||
|
|
@ -157,14 +253,7 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
|
|||
}
|
||||
MessagesStreamEvent::ContentBlockStart { index, .. } => {
|
||||
let index = *index as i32;
|
||||
// Inject message_start if needed
|
||||
if !self.message_started {
|
||||
let model = self.model.as_deref().unwrap_or("unknown");
|
||||
let message_start =
|
||||
AnthropicMessagesStreamBuffer::create_message_start_event(model);
|
||||
self.buffered_events.push(message_start);
|
||||
self.message_started = true;
|
||||
}
|
||||
self.ensure_message_started();
|
||||
|
||||
// Add the content_block_start event (from tool calls or other sources)
|
||||
self.buffered_events.push(event);
|
||||
|
|
@ -173,14 +262,7 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
|
|||
}
|
||||
MessagesStreamEvent::ContentBlockDelta { index, .. } => {
|
||||
let index = *index as i32;
|
||||
// Inject message_start if needed
|
||||
if !self.message_started {
|
||||
let model = self.model.as_deref().unwrap_or("unknown");
|
||||
let message_start =
|
||||
AnthropicMessagesStreamBuffer::create_message_start_event(model);
|
||||
self.buffered_events.push(message_start);
|
||||
self.message_started = true;
|
||||
}
|
||||
self.ensure_message_started();
|
||||
|
||||
// Check if ContentBlockStart was sent for this index
|
||||
if !self.has_content_block_start_been_sent(index) {
|
||||
|
|
@ -196,6 +278,11 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
|
|||
self.buffered_events.push(event);
|
||||
}
|
||||
MessagesStreamEvent::MessageDelta { usage, .. } => {
|
||||
// `message_delta` is only meaningful inside an open message.
|
||||
// Upstream can send it with no prior content (empty completion,
|
||||
// content filter, etc.), so we must open a message first.
|
||||
self.ensure_message_started();
|
||||
|
||||
// Inject ContentBlockStop before message_delta
|
||||
if self.needs_content_block_stop {
|
||||
let content_block_stop =
|
||||
|
|
@ -230,15 +317,52 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
|
|||
}
|
||||
MessagesStreamEvent::ContentBlockStop { .. } => {
|
||||
// ContentBlockStop received from upstream (e.g., Bedrock)
|
||||
self.ensure_message_started();
|
||||
// Clear the flag so we don't inject another one
|
||||
self.needs_content_block_stop = false;
|
||||
self.buffered_events.push(event);
|
||||
}
|
||||
MessagesStreamEvent::MessageStop => {
|
||||
// MessageStop received from upstream (e.g., OpenAI via [DONE])
|
||||
// Clear the flag so we don't inject another one
|
||||
self.seen_message_delta = false;
|
||||
// MessageStop received from upstream (e.g., OpenAI via [DONE]).
|
||||
//
|
||||
// The Anthropic protocol requires the full envelope
|
||||
// message_start → [content blocks] → message_delta → message_stop
|
||||
// so we must not emit a bare `message_stop`. Synthesize whatever
|
||||
// is missing to keep the client's state machine consistent.
|
||||
self.ensure_message_started();
|
||||
|
||||
if self.needs_content_block_stop {
|
||||
let content_block_stop =
|
||||
AnthropicMessagesStreamBuffer::create_content_block_stop_event();
|
||||
self.buffered_events.push(content_block_stop);
|
||||
self.needs_content_block_stop = false;
|
||||
}
|
||||
|
||||
// If no message_delta has been emitted yet (empty/filtered upstream
|
||||
// response), synthesize a minimal one carrying `end_turn`.
|
||||
if !self.seen_message_delta {
|
||||
// If we also never opened a content block, open and close one
|
||||
// so clients that expect at least one block are happy.
|
||||
if self.content_block_start_indices.is_empty() {
|
||||
let content_block_start =
|
||||
AnthropicMessagesStreamBuffer::create_content_block_start_event(
|
||||
);
|
||||
self.buffered_events.push(content_block_start);
|
||||
self.set_content_block_start_sent(0);
|
||||
let content_block_stop =
|
||||
AnthropicMessagesStreamBuffer::create_content_block_stop_event(
|
||||
);
|
||||
self.buffered_events.push(content_block_stop);
|
||||
}
|
||||
self.push_synthetic_message_delta();
|
||||
}
|
||||
|
||||
// Push the upstream-provided message_stop and mark closed.
|
||||
// `push_message_stop` is idempotent but we want to reuse the
|
||||
// original SseEvent so raw passthrough semantics are preserved.
|
||||
self.buffered_events.push(event);
|
||||
self.message_stopped = true;
|
||||
self.seen_message_delta = false;
|
||||
}
|
||||
_ => {
|
||||
// Other Anthropic event types (Ping, etc.), just accumulate
|
||||
|
|
@ -254,24 +378,23 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
|
|||
}
|
||||
|
||||
fn to_bytes(&mut self) -> Vec<u8> {
|
||||
// Convert all accumulated events to bytes and clear buffer
|
||||
// Convert all accumulated events to bytes and clear buffer.
|
||||
//
|
||||
// NOTE: We do NOT inject ContentBlockStop here because it's injected when we see MessageDelta
|
||||
// or MessageStop. Injecting it here causes premature ContentBlockStop in the middle of streaming.
|
||||
|
||||
// Inject MessageStop after MessageDelta if we've seen one
|
||||
// This completes the Anthropic Messages API event sequence
|
||||
if self.seen_message_delta {
|
||||
let message_stop = MessagesStreamEvent::MessageStop;
|
||||
let sse_string: String = message_stop.into();
|
||||
let message_stop_event = SseEvent {
|
||||
data: None,
|
||||
event: Some("message_stop".to_string()),
|
||||
raw_line: sse_string.clone(),
|
||||
sse_transformed_lines: sse_string,
|
||||
provider_stream_response: None,
|
||||
};
|
||||
self.buffered_events.push(message_stop_event);
|
||||
self.seen_message_delta = false;
|
||||
//
|
||||
// Inject a synthetic `message_stop` only when:
|
||||
// 1. A `message_delta` has been seen (otherwise we'd violate the Anthropic
|
||||
// protocol by emitting `message_stop` without a preceding `message_delta`), AND
|
||||
// 2. We haven't already emitted `message_stop` (either synthetic from a
|
||||
// previous flush, or real from an upstream `[DONE]`).
|
||||
//
|
||||
// Without the `!message_stopped` guard, a stream whose `finish_reason` chunk
|
||||
// and `[DONE]` marker land in separate HTTP body chunks would receive two
|
||||
// `message_stop` events, triggering Claude Code's "Received message_stop
|
||||
// without a current message" error.
|
||||
if self.seen_message_delta && !self.message_stopped {
|
||||
self.push_message_stop();
|
||||
}
|
||||
|
||||
let mut buffer = Vec::new();
|
||||
|
|
@ -615,4 +738,133 @@ data: [DONE]"#;
|
|||
println!("✓ Stop reason: tool_use");
|
||||
println!("✓ Proper Anthropic tool_use protocol\n");
|
||||
}
|
||||
|
||||
/// Regression test for:
|
||||
/// Claude Code CLI error: "Received message_stop without a current message"
|
||||
///
|
||||
/// Reproduces the *double-close* scenario: OpenAI's final `finish_reason`
|
||||
/// chunk and the `[DONE]` marker arrive in **separate** HTTP body chunks, so
|
||||
/// `to_bytes()` is called between them. Before the fix, this produced two
|
||||
/// `message_stop` events on the wire (one synthetic, one from `[DONE]`).
|
||||
#[test]
|
||||
fn test_openai_to_anthropic_emits_single_message_stop_across_chunk_boundary() {
|
||||
let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages);
|
||||
let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
|
||||
let mut buffer = AnthropicMessagesStreamBuffer::new();
|
||||
|
||||
// --- HTTP chunk 1: content + finish_reason (no [DONE] yet) -----------
|
||||
let chunk_1 = r#"data: {"id":"c1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant","content":"Hi"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"c1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}"#;
|
||||
|
||||
for raw in SseStreamIter::try_from(chunk_1.as_bytes()).unwrap() {
|
||||
let e = SseEvent::try_from((raw, &client_api, &upstream_api)).unwrap();
|
||||
buffer.add_transformed_event(e);
|
||||
}
|
||||
let out_1 = String::from_utf8(buffer.to_bytes()).unwrap();
|
||||
|
||||
// --- HTTP chunk 2: just the [DONE] marker ----------------------------
|
||||
let chunk_2 = "data: [DONE]";
|
||||
for raw in SseStreamIter::try_from(chunk_2.as_bytes()).unwrap() {
|
||||
let e = SseEvent::try_from((raw, &client_api, &upstream_api)).unwrap();
|
||||
buffer.add_transformed_event(e);
|
||||
}
|
||||
let out_2 = String::from_utf8(buffer.to_bytes()).unwrap();
|
||||
|
||||
let combined = format!("{}{}", out_1, out_2);
|
||||
let start_count = combined.matches("event: message_start").count();
|
||||
let stop_count = combined.matches("event: message_stop").count();
|
||||
|
||||
assert_eq!(
|
||||
start_count, 1,
|
||||
"Must emit exactly one message_start across chunks, got {start_count}. Output:\n{combined}"
|
||||
);
|
||||
assert_eq!(
|
||||
stop_count, 1,
|
||||
"Must emit exactly one message_stop across chunks (no double-close), got {stop_count}. Output:\n{combined}"
|
||||
);
|
||||
// Every message_stop must be preceded by a message_start earlier in the stream.
|
||||
let start_pos = combined.find("event: message_start").unwrap();
|
||||
let stop_pos = combined.find("event: message_stop").unwrap();
|
||||
assert!(
|
||||
start_pos < stop_pos,
|
||||
"message_start must come before message_stop. Output:\n{combined}"
|
||||
);
|
||||
}
|
||||
|
||||
/// Regression test for:
|
||||
/// "Received message_stop without a current message" on empty upstream responses.
|
||||
///
|
||||
/// OpenAI returns only `[DONE]` with no content deltas and no `finish_reason`
|
||||
/// (this happens with content filters, truncated upstream streams, and some
|
||||
/// 5xx recoveries). Before the fix, the buffer emitted a bare `message_stop`
|
||||
/// with no preceding `message_start`. After the fix, it synthesizes a
|
||||
/// minimal but well-formed envelope.
|
||||
#[test]
|
||||
fn test_openai_done_only_stream_synthesizes_valid_envelope() {
|
||||
let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages);
|
||||
let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
|
||||
let mut buffer = AnthropicMessagesStreamBuffer::new();
|
||||
|
||||
let raw_input = "data: [DONE]";
|
||||
for raw in SseStreamIter::try_from(raw_input.as_bytes()).unwrap() {
|
||||
let e = SseEvent::try_from((raw, &client_api, &upstream_api)).unwrap();
|
||||
buffer.add_transformed_event(e);
|
||||
}
|
||||
let out = String::from_utf8(buffer.to_bytes()).unwrap();
|
||||
|
||||
assert!(
|
||||
out.contains("event: message_start"),
|
||||
"Empty upstream must still produce message_start. Output:\n{out}"
|
||||
);
|
||||
assert!(
|
||||
out.contains("event: message_delta"),
|
||||
"Empty upstream must produce a synthesized message_delta. Output:\n{out}"
|
||||
);
|
||||
assert_eq!(
|
||||
out.matches("event: message_stop").count(),
|
||||
1,
|
||||
"Empty upstream must produce exactly one message_stop. Output:\n{out}"
|
||||
);
|
||||
|
||||
// Protocol ordering: start < delta < stop.
|
||||
let p_start = out.find("event: message_start").unwrap();
|
||||
let p_delta = out.find("event: message_delta").unwrap();
|
||||
let p_stop = out.find("event: message_stop").unwrap();
|
||||
assert!(
|
||||
p_start < p_delta && p_delta < p_stop,
|
||||
"Bad ordering. Output:\n{out}"
|
||||
);
|
||||
}
|
||||
|
||||
/// Regression test: events arriving after `message_stop` (e.g. a stray `[DONE]`
|
||||
/// echo, or late-arriving deltas from a racing upstream) must be dropped
|
||||
/// rather than written after the terminal frame.
|
||||
#[test]
|
||||
fn test_events_after_message_stop_are_dropped() {
|
||||
let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages);
|
||||
let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
|
||||
let mut buffer = AnthropicMessagesStreamBuffer::new();
|
||||
|
||||
let first = r#"data: {"id":"c1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"ok"},"finish_reason":"stop"}]}
|
||||
|
||||
data: [DONE]"#;
|
||||
for raw in SseStreamIter::try_from(first.as_bytes()).unwrap() {
|
||||
let e = SseEvent::try_from((raw, &client_api, &upstream_api)).unwrap();
|
||||
buffer.add_transformed_event(e);
|
||||
}
|
||||
let _ = buffer.to_bytes();
|
||||
|
||||
// Simulate a duplicate / late `[DONE]` after the stream was already closed.
|
||||
let late = "data: [DONE]";
|
||||
for raw in SseStreamIter::try_from(late.as_bytes()).unwrap() {
|
||||
let e = SseEvent::try_from((raw, &client_api, &upstream_api)).unwrap();
|
||||
buffer.add_transformed_event(e);
|
||||
}
|
||||
let tail = String::from_utf8(buffer.to_bytes()).unwrap();
|
||||
assert!(
|
||||
tail.is_empty(),
|
||||
"No bytes should be emitted after message_stop, got: {tail:?}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -95,6 +95,7 @@ providers:
|
|||
anthropic:
|
||||
- anthropic/claude-sonnet-4-6
|
||||
- anthropic/claude-opus-4-6
|
||||
- anthropic/claude-opus-4-7
|
||||
- anthropic/claude-opus-4-5-20251101
|
||||
- anthropic/claude-opus-4-5
|
||||
- anthropic/claude-haiku-4-5-20251001
|
||||
|
|
@ -328,7 +329,53 @@ providers:
|
|||
- xiaomi/mimo-v2-flash
|
||||
- xiaomi/mimo-v2-omni
|
||||
- xiaomi/mimo-v2-pro
|
||||
digitalocean:
|
||||
- digitalocean/openai-gpt-4.1
|
||||
- digitalocean/openai-gpt-4o
|
||||
- digitalocean/openai-gpt-4o-mini
|
||||
- digitalocean/openai-gpt-5
|
||||
- digitalocean/openai-gpt-5-mini
|
||||
- digitalocean/openai-gpt-5-nano
|
||||
- digitalocean/openai-gpt-5.1-codex-max
|
||||
- digitalocean/openai-gpt-5.2
|
||||
- digitalocean/openai-gpt-5.2-pro
|
||||
- digitalocean/openai-gpt-5.3-codex
|
||||
- digitalocean/openai-gpt-5.4
|
||||
- digitalocean/openai-gpt-5.4-mini
|
||||
- digitalocean/openai-gpt-5.4-nano
|
||||
- digitalocean/openai-gpt-5.4-pro
|
||||
- digitalocean/openai-gpt-oss-120b
|
||||
- digitalocean/openai-gpt-oss-20b
|
||||
- digitalocean/openai-o1
|
||||
- digitalocean/openai-o3
|
||||
- digitalocean/openai-o3-mini
|
||||
- digitalocean/anthropic-claude-4.1-opus
|
||||
- digitalocean/anthropic-claude-4.5-sonnet
|
||||
- digitalocean/anthropic-claude-4.6-sonnet
|
||||
- digitalocean/anthropic-claude-haiku-4.5
|
||||
- digitalocean/anthropic-claude-opus-4
|
||||
- digitalocean/anthropic-claude-opus-4.5
|
||||
- digitalocean/anthropic-claude-opus-4.6
|
||||
- digitalocean/anthropic-claude-opus-4.7
|
||||
- digitalocean/anthropic-claude-sonnet-4
|
||||
- digitalocean/alibaba-qwen3-32b
|
||||
- digitalocean/arcee-trinity-large-thinking
|
||||
- digitalocean/deepseek-3.2
|
||||
- digitalocean/deepseek-r1-distill-llama-70b
|
||||
- digitalocean/gemma-4-31B-it
|
||||
- digitalocean/glm-5
|
||||
- digitalocean/kimi-k2.5
|
||||
- digitalocean/llama3.3-70b-instruct
|
||||
- digitalocean/minimax-m2.5
|
||||
- digitalocean/nvidia-nemotron-3-super-120b
|
||||
- digitalocean/qwen3-coder-flash
|
||||
- digitalocean/qwen3.5-397b-a17b
|
||||
- digitalocean/all-mini-lm-l6-v2
|
||||
- digitalocean/gte-large-en-v1.5
|
||||
- digitalocean/multi-qa-mpnet-base-dot-v1
|
||||
- digitalocean/qwen3-embedding-0.6b
|
||||
- digitalocean/router:software-engineering
|
||||
metadata:
|
||||
total_providers: 11
|
||||
total_models: 316
|
||||
last_updated: 2026-04-03T23:14:46.956158+00:00
|
||||
total_providers: 12
|
||||
total_models: 361
|
||||
last_updated: 2026-04-16T00:00:00.000000+00:00
|
||||
|
|
|
|||
|
|
@ -46,6 +46,7 @@ pub enum ProviderId {
|
|||
AmazonBedrock,
|
||||
Vercel,
|
||||
OpenRouter,
|
||||
DigitalOcean,
|
||||
}
|
||||
|
||||
impl TryFrom<&str> for ProviderId {
|
||||
|
|
@ -75,6 +76,9 @@ impl TryFrom<&str> for ProviderId {
|
|||
"amazon" => Ok(ProviderId::AmazonBedrock), // alias
|
||||
"vercel" => Ok(ProviderId::Vercel),
|
||||
"openrouter" => Ok(ProviderId::OpenRouter),
|
||||
"digitalocean" => Ok(ProviderId::DigitalOcean),
|
||||
"do" => Ok(ProviderId::DigitalOcean), // alias
|
||||
"do_ai" => Ok(ProviderId::DigitalOcean), // alias
|
||||
_ => Err(format!("Unknown provider: {}", value)),
|
||||
}
|
||||
}
|
||||
|
|
@ -99,6 +103,7 @@ impl ProviderId {
|
|||
ProviderId::Moonshotai => "moonshotai",
|
||||
ProviderId::Zhipu => "z-ai",
|
||||
ProviderId::Qwen => "qwen",
|
||||
ProviderId::DigitalOcean => "digitalocean",
|
||||
// Vercel and OpenRouter are open-ended gateways; model lists are unbounded.
|
||||
// Users configure these with wildcards (e.g. vercel/*); no static expansion needed.
|
||||
ProviderId::Vercel | ProviderId::OpenRouter => return Vec::new(),
|
||||
|
|
@ -157,7 +162,8 @@ impl ProviderId {
|
|||
| ProviderId::Zhipu
|
||||
| ProviderId::Qwen
|
||||
| ProviderId::Vercel
|
||||
| ProviderId::OpenRouter,
|
||||
| ProviderId::OpenRouter
|
||||
| ProviderId::DigitalOcean,
|
||||
SupportedAPIsFromClient::AnthropicMessagesAPI(_),
|
||||
) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
|
||||
|
|
@ -178,7 +184,8 @@ impl ProviderId {
|
|||
| ProviderId::Zhipu
|
||||
| ProviderId::Qwen
|
||||
| ProviderId::Vercel
|
||||
| ProviderId::OpenRouter,
|
||||
| ProviderId::OpenRouter
|
||||
| ProviderId::DigitalOcean,
|
||||
SupportedAPIsFromClient::OpenAIChatCompletions(_),
|
||||
) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
|
||||
|
|
@ -247,6 +254,7 @@ impl Display for ProviderId {
|
|||
ProviderId::AmazonBedrock => write!(f, "amazon_bedrock"),
|
||||
ProviderId::Vercel => write!(f, "vercel"),
|
||||
ProviderId::OpenRouter => write!(f, "openrouter"),
|
||||
ProviderId::DigitalOcean => write!(f, "digitalocean"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,6 +23,31 @@ pub trait TokenUsage {
|
|||
fn completion_tokens(&self) -> usize;
|
||||
fn prompt_tokens(&self) -> usize;
|
||||
fn total_tokens(&self) -> usize;
|
||||
/// Tokens served from a prompt cache read (OpenAI `prompt_tokens_details.cached_tokens`,
|
||||
/// Anthropic `cache_read_input_tokens`, Google `cached_content_token_count`).
|
||||
fn cached_input_tokens(&self) -> Option<usize> {
|
||||
None
|
||||
}
|
||||
/// Tokens used to write a cache entry (Anthropic `cache_creation_input_tokens`).
|
||||
fn cache_creation_tokens(&self) -> Option<usize> {
|
||||
None
|
||||
}
|
||||
/// Reasoning tokens for reasoning models (OpenAI `completion_tokens_details.reasoning_tokens`,
|
||||
/// Google `thoughts_token_count`).
|
||||
fn reasoning_tokens(&self) -> Option<usize> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Rich usage breakdown extracted from a provider response.
|
||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
|
||||
pub struct UsageDetails {
|
||||
pub prompt_tokens: usize,
|
||||
pub completion_tokens: usize,
|
||||
pub total_tokens: usize,
|
||||
pub cached_input_tokens: Option<usize>,
|
||||
pub cache_creation_tokens: Option<usize>,
|
||||
pub reasoning_tokens: Option<usize>,
|
||||
}
|
||||
|
||||
pub trait ProviderResponse: Send + Sync {
|
||||
|
|
@ -34,6 +59,18 @@ pub trait ProviderResponse: Send + Sync {
|
|||
self.usage()
|
||||
.map(|u| (u.prompt_tokens(), u.completion_tokens(), u.total_tokens()))
|
||||
}
|
||||
|
||||
/// Extract a rich usage breakdown including cached/cache-creation/reasoning tokens.
|
||||
fn extract_usage_details(&self) -> Option<UsageDetails> {
|
||||
self.usage().map(|u| UsageDetails {
|
||||
prompt_tokens: u.prompt_tokens(),
|
||||
completion_tokens: u.completion_tokens(),
|
||||
total_tokens: u.total_tokens(),
|
||||
cached_input_tokens: u.cached_input_tokens(),
|
||||
cache_creation_tokens: u.cache_creation_tokens(),
|
||||
reasoning_tokens: u.reasoning_tokens(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderResponse for ProviderResponseType {
|
||||
|
|
|
|||
|
|
@ -346,12 +346,10 @@ impl TryFrom<(SseEvent, &SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for S
|
|||
(
|
||||
SupportedAPIsFromClient::OpenAIChatCompletions(_),
|
||||
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
|
||||
) => {
|
||||
) if transformed_event.is_event_only() && transformed_event.event.is_some() => {
|
||||
// OpenAI clients don't expect separate event: lines
|
||||
// Suppress upstream Anthropic event-only lines
|
||||
if transformed_event.is_event_only() && transformed_event.event.is_some() {
|
||||
transformed_event.sse_transformed_lines = "\n".to_string();
|
||||
}
|
||||
transformed_event.sse_transformed_lines = "\n".to_string();
|
||||
}
|
||||
_ => {
|
||||
// Other cross-API combinations can be handled here as needed
|
||||
|
|
@ -371,12 +369,10 @@ impl TryFrom<(SseEvent, &SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for S
|
|||
| (
|
||||
SupportedAPIsFromClient::OpenAIResponsesAPI(_),
|
||||
SupportedUpstreamAPIs::OpenAIResponsesAPI(_),
|
||||
) => {
|
||||
if transformed_event.is_event_only() && transformed_event.event.is_some() {
|
||||
// Mark as should-skip by clearing sse_transformed_lines
|
||||
// The event line is already included when the data line is transformed
|
||||
transformed_event.sse_transformed_lines = String::new();
|
||||
}
|
||||
) if transformed_event.is_event_only() && transformed_event.event.is_some() => {
|
||||
// Mark as should-skip by clearing sse_transformed_lines
|
||||
// The event line is already included when the data line is transformed
|
||||
transformed_event.sse_transformed_lines = String::new();
|
||||
}
|
||||
_ => {
|
||||
// Other passthrough combinations (OpenAI ChatCompletions, etc.) don't have this issue
|
||||
|
|
|
|||
|
|
@ -188,14 +188,13 @@ pub fn convert_openai_message_to_anthropic_content(
|
|||
|
||||
// Handle regular content
|
||||
match &message.content {
|
||||
Some(MessageContent::Text(text)) => {
|
||||
if !text.is_empty() {
|
||||
blocks.push(MessagesContentBlock::Text {
|
||||
text: text.clone(),
|
||||
cache_control: None,
|
||||
});
|
||||
}
|
||||
Some(MessageContent::Text(text)) if !text.is_empty() => {
|
||||
blocks.push(MessagesContentBlock::Text {
|
||||
text: text.clone(),
|
||||
cache_control: None,
|
||||
});
|
||||
}
|
||||
Some(MessageContent::Text(_)) => {}
|
||||
Some(MessageContent::Parts(parts)) => {
|
||||
for part in parts {
|
||||
match part {
|
||||
|
|
|
|||
|
|
@ -354,10 +354,10 @@ impl TryFrom<MessagesMessage> for BedrockMessage {
|
|||
MessagesMessageContent::Blocks(blocks) => {
|
||||
for block in blocks {
|
||||
match block {
|
||||
crate::apis::anthropic::MessagesContentBlock::Text { text, .. } => {
|
||||
if !text.is_empty() {
|
||||
content_blocks.push(ContentBlock::Text { text });
|
||||
}
|
||||
crate::apis::anthropic::MessagesContentBlock::Text { text, .. }
|
||||
if !text.is_empty() =>
|
||||
{
|
||||
content_blocks.push(ContentBlock::Text { text });
|
||||
}
|
||||
crate::apis::anthropic::MessagesContentBlock::ToolUse {
|
||||
id,
|
||||
|
|
|
|||
|
|
@ -317,11 +317,10 @@ impl TryFrom<Message> for BedrockMessage {
|
|||
Role::User => {
|
||||
// Convert user message content to content blocks
|
||||
match message.content {
|
||||
Some(MessageContent::Text(text)) => {
|
||||
if !text.is_empty() {
|
||||
content_blocks.push(ContentBlock::Text { text });
|
||||
}
|
||||
Some(MessageContent::Text(text)) if !text.is_empty() => {
|
||||
content_blocks.push(ContentBlock::Text { text });
|
||||
}
|
||||
Some(MessageContent::Text(_)) => {}
|
||||
Some(MessageContent::Parts(parts)) => {
|
||||
// Convert OpenAI content parts to Bedrock ContentBlocks
|
||||
for part in parts {
|
||||
|
|
|
|||
|
|
@ -177,24 +177,33 @@ impl StreamContext {
|
|||
}
|
||||
|
||||
fn modify_auth_headers(&mut self) -> Result<(), ServerError> {
|
||||
if self.llm_provider().passthrough_auth == Some(true) {
|
||||
// Check if client provided an Authorization header
|
||||
if self.get_http_request_header("Authorization").is_none() {
|
||||
warn!(
|
||||
"request_id={}: passthrough_auth enabled but no authorization header present in client request",
|
||||
self.request_identifier()
|
||||
);
|
||||
} else {
|
||||
debug!(
|
||||
"request_id={}: preserving client authorization header for provider '{}'",
|
||||
self.request_identifier(),
|
||||
self.llm_provider().name
|
||||
);
|
||||
// Determine the credential to forward upstream. Either the client
|
||||
// supplied one (passthrough_auth) or it's configured on the provider.
|
||||
let credential: String = if self.llm_provider().passthrough_auth == Some(true) {
|
||||
// Client auth may arrive in either Anthropic-style (`x-api-key`)
|
||||
// or OpenAI-style (`Authorization: Bearer ...`). Accept both so
|
||||
// clients using Anthropic SDKs (which default to `x-api-key`)
|
||||
// work when the upstream is OpenAI-compatible, and vice versa.
|
||||
let authorization = self.get_http_request_header("Authorization");
|
||||
let x_api_key = self.get_http_request_header("x-api-key");
|
||||
match extract_client_credential(authorization.as_deref(), x_api_key.as_deref()) {
|
||||
Some(key) => {
|
||||
debug!(
|
||||
"request_id={}: forwarding client credential to provider '{}'",
|
||||
self.request_identifier(),
|
||||
self.llm_provider().name
|
||||
);
|
||||
key
|
||||
}
|
||||
None => {
|
||||
warn!(
|
||||
"request_id={}: passthrough_auth enabled but no Authorization / x-api-key header present in client request",
|
||||
self.request_identifier()
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let llm_provider_api_key_value =
|
||||
} else {
|
||||
self.llm_provider()
|
||||
.access_key
|
||||
.as_ref()
|
||||
|
|
@ -203,15 +212,19 @@ impl StreamContext {
|
|||
"No access key configured for selected LLM Provider \"{}\"",
|
||||
self.llm_provider()
|
||||
),
|
||||
})?;
|
||||
})?
|
||||
.clone()
|
||||
};
|
||||
|
||||
// Set API-specific headers based on the resolved upstream API
|
||||
// Normalize the credential into whichever header the upstream expects.
|
||||
// This lets an Anthropic-SDK client reach an OpenAI-compatible upstream
|
||||
// (and vice versa) without the caller needing to know what format the
|
||||
// upstream uses.
|
||||
match self.resolved_api.as_ref() {
|
||||
Some(SupportedUpstreamAPIs::AnthropicMessagesAPI(_)) => {
|
||||
// Anthropic API requires x-api-key and anthropic-version headers
|
||||
// Remove any existing Authorization header since Anthropic doesn't use it
|
||||
// Anthropic expects `x-api-key` + `anthropic-version`.
|
||||
self.remove_http_request_header("Authorization");
|
||||
self.set_http_request_header("x-api-key", Some(llm_provider_api_key_value));
|
||||
self.set_http_request_header("x-api-key", Some(&credential));
|
||||
self.set_http_request_header("anthropic-version", Some("2023-06-01"));
|
||||
}
|
||||
Some(
|
||||
|
|
@ -221,10 +234,9 @@ impl StreamContext {
|
|||
| SupportedUpstreamAPIs::OpenAIResponsesAPI(_),
|
||||
)
|
||||
| None => {
|
||||
// OpenAI and default: use Authorization Bearer token
|
||||
// Remove any existing x-api-key header since OpenAI doesn't use it
|
||||
// OpenAI (and default): `Authorization: Bearer ...`.
|
||||
self.remove_http_request_header("x-api-key");
|
||||
let authorization_header_value = format!("Bearer {}", llm_provider_api_key_value);
|
||||
let authorization_header_value = format!("Bearer {}", credential);
|
||||
self.set_http_request_header("Authorization", Some(&authorization_header_value));
|
||||
}
|
||||
}
|
||||
|
|
@ -1235,3 +1247,86 @@ fn current_time_ns() -> u128 {
|
|||
}
|
||||
|
||||
impl Context for StreamContext {}
|
||||
|
||||
/// Extract the credential a client sent in either an OpenAI-style
|
||||
/// `Authorization` header or an Anthropic-style `x-api-key` header.
|
||||
///
|
||||
/// Returns `None` when neither header is present or both are empty/whitespace.
|
||||
/// The `Bearer ` prefix on the `Authorization` value is stripped if present;
|
||||
/// otherwise the value is taken verbatim (some clients send a raw token).
|
||||
fn extract_client_credential(
|
||||
authorization: Option<&str>,
|
||||
x_api_key: Option<&str>,
|
||||
) -> Option<String> {
|
||||
// Strip the optional "Bearer " / "Bearer" prefix (case-sensitive, matches
|
||||
// OpenAI SDK behavior) and trim surrounding whitespace before validating
|
||||
// non-empty.
|
||||
let from_authorization = authorization
|
||||
.map(|v| {
|
||||
v.strip_prefix("Bearer ")
|
||||
.or_else(|| v.strip_prefix("Bearer"))
|
||||
.unwrap_or(v)
|
||||
.trim()
|
||||
.to_string()
|
||||
})
|
||||
.filter(|s| !s.is_empty());
|
||||
if from_authorization.is_some() {
|
||||
return from_authorization;
|
||||
}
|
||||
x_api_key
|
||||
.map(str::trim)
|
||||
.filter(|s| !s.is_empty())
|
||||
.map(|s| s.to_string())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::extract_client_credential;
|
||||
|
||||
#[test]
|
||||
fn authorization_bearer_strips_prefix() {
|
||||
assert_eq!(
|
||||
extract_client_credential(Some("Bearer sk-abc"), None),
|
||||
Some("sk-abc".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn authorization_raw_token_preserved() {
|
||||
// Some clients send the raw token without "Bearer " — accept it.
|
||||
assert_eq!(
|
||||
extract_client_credential(Some("sk-abc"), None),
|
||||
Some("sk-abc".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn x_api_key_used_when_authorization_absent() {
|
||||
assert_eq!(
|
||||
extract_client_credential(None, Some("sk-ant-api-key")),
|
||||
Some("sk-ant-api-key".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn authorization_wins_when_both_present() {
|
||||
// If a client is particularly exotic and sends both, prefer the
|
||||
// OpenAI-style Authorization header.
|
||||
assert_eq!(
|
||||
extract_client_credential(Some("Bearer openai-key"), Some("anthropic-key")),
|
||||
Some("openai-key".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn returns_none_when_neither_present() {
|
||||
assert!(extract_client_credential(None, None).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_and_whitespace_headers_are_ignored() {
|
||||
assert!(extract_client_credential(Some(""), None).is_none());
|
||||
assert!(extract_client_credential(Some("Bearer "), None).is_none());
|
||||
assert!(extract_client_credential(Some(" "), Some(" ")).is_none());
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue