"""
ChatGPT subscription OAuth device-flow authentication.

Implements the device code flow used by OpenAI Codex CLI to authenticate
with a ChatGPT Plus/Pro subscription. Tokens are stored locally in
~/.plano/chatgpt/auth.json and auto-refreshed when expired.
"""

import base64
import json
import os
import time
from typing import Any, Dict, Optional, Tuple
from urllib.parse import urlencode

import requests

from planoai.consts import PLANO_HOME

# OAuth + API constants (derived from openai/codex)
CHATGPT_AUTH_BASE = "https://auth.openai.com"
CHATGPT_DEVICE_CODE_URL = f"{CHATGPT_AUTH_BASE}/api/accounts/deviceauth/usercode"
CHATGPT_DEVICE_TOKEN_URL = f"{CHATGPT_AUTH_BASE}/api/accounts/deviceauth/token"
CHATGPT_OAUTH_TOKEN_URL = f"{CHATGPT_AUTH_BASE}/oauth/token"
CHATGPT_DEVICE_VERIFY_URL = f"{CHATGPT_AUTH_BASE}/codex/device"
CHATGPT_API_BASE = "https://chatgpt.com/backend-api/codex"
CHATGPT_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"

# Local storage
CHATGPT_AUTH_DIR = os.path.join(PLANO_HOME, "chatgpt")
CHATGPT_AUTH_FILE = os.path.join(CHATGPT_AUTH_DIR, "auth.json")

# Timeouts
TOKEN_EXPIRY_SKEW_SECONDS = 60
DEVICE_CODE_TIMEOUT_SECONDS = 15 * 60
DEVICE_CODE_POLL_SECONDS = 5
# Per-request network timeout. Without this, requests blocks indefinitely on a
# stalled connection and the CLI hangs.
HTTP_TIMEOUT_SECONDS = 30


def _ensure_auth_dir():
    """Create the credential directory if it does not exist."""
    os.makedirs(CHATGPT_AUTH_DIR, exist_ok=True)


def load_auth() -> Optional[Dict[str, Any]]:
    """Load auth data from disk, or None if missing/corrupt."""
    try:
        with open(CHATGPT_AUTH_FILE, "r") as f:
            return json.load(f)
    except (IOError, json.JSONDecodeError):
        return None


def save_auth(data: Dict[str, Any]):
    """Save auth data to disk with 0600 permissions (tokens are secrets)."""
    _ensure_auth_dir()
    fd = os.open(CHATGPT_AUTH_FILE, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w") as f:
        json.dump(data, f, indent=2)


def delete_auth():
    """Remove stored credentials. A no-op when none exist."""
    try:
        os.remove(CHATGPT_AUTH_FILE)
    except FileNotFoundError:
        pass


def _decode_jwt_claims(token: str) -> Dict[str, Any]:
    """Decode JWT payload without verification.

    Only used to read local metadata (expiry, account id) from tokens we
    already hold; signature verification is the server's job.
    """
    try:
        parts = token.split(".")
        if len(parts) < 2:
            return {}
        payload_b64 = parts[1]
        # Re-pad: JWT base64url segments are emitted without '=' padding.
        payload_b64 += "=" * (-len(payload_b64) % 4)
        return json.loads(base64.urlsafe_b64decode(payload_b64).decode("utf-8"))
    except Exception:
        return {}


def _get_expires_at(token: str) -> Optional[int]:
    """Extract the `exp` (unix seconds) claim from a JWT, if present."""
    claims = _decode_jwt_claims(token)
    exp = claims.get("exp")
    return int(exp) if isinstance(exp, (int, float)) else None


def _extract_account_id(token: Optional[str]) -> Optional[str]:
    """Extract the ChatGPT account ID from the OpenAI auth claims of a JWT."""
    if not token:
        return None
    claims = _decode_jwt_claims(token)
    auth_claims = claims.get("https://api.openai.com/auth")
    if isinstance(auth_claims, dict):
        account_id = auth_claims.get("chatgpt_account_id")
        if isinstance(account_id, str) and account_id:
            return account_id
    return None


def _is_token_expired(auth_data: Dict[str, Any]) -> bool:
    """Check whether the stored access token is (about to be) expired.

    Applies TOKEN_EXPIRY_SKEW_SECONDS of slack so we refresh slightly early.
    Backfills and persists `expires_at` for records saved before that field
    existed. Unknown expiry is treated as expired (forces a refresh attempt).
    """
    expires_at = auth_data.get("expires_at")
    if expires_at is None:
        access_token = auth_data.get("access_token")
        if access_token:
            expires_at = _get_expires_at(access_token)
            if expires_at:
                auth_data["expires_at"] = expires_at
                save_auth(auth_data)
    if expires_at is None:
        return True
    return time.time() >= float(expires_at) - TOKEN_EXPIRY_SKEW_SECONDS


def _refresh_tokens(refresh_token: str) -> Dict[str, str]:
    """Refresh the access token using the refresh token.

    Raises requests.HTTPError on a non-2xx response and RuntimeError when the
    response body lacks the expected fields.
    """
    resp = requests.post(
        CHATGPT_OAUTH_TOKEN_URL,
        json={
            "client_id": CHATGPT_CLIENT_ID,
            "grant_type": "refresh_token",
            "refresh_token": refresh_token,
            "scope": "openid profile email",
        },
        timeout=HTTP_TIMEOUT_SECONDS,
    )
    resp.raise_for_status()
    data = resp.json()

    access_token = data.get("access_token")
    id_token = data.get("id_token")
    if not access_token or not id_token:
        raise RuntimeError(f"Refresh response missing fields: {data}")

    return {
        "access_token": access_token,
        # Some providers rotate the refresh token; keep the old one otherwise.
        "refresh_token": data.get("refresh_token", refresh_token),
        "id_token": id_token,
    }


def _build_auth_record(tokens: Dict[str, str]) -> Dict[str, Any]:
    """Build the auth record to persist from a token bundle."""
    access_token = tokens.get("access_token")
    id_token = tokens.get("id_token")
    expires_at = _get_expires_at(access_token) if access_token else None
    # The account id usually lives in the id_token claims; fall back to the
    # access token for robustness.
    account_id = _extract_account_id(id_token or access_token)
    return {
        "access_token": access_token,
        "refresh_token": tokens.get("refresh_token"),
        "id_token": id_token,
        "expires_at": expires_at,
        "account_id": account_id,
    }


def request_device_code() -> Dict[str, str]:
    """Request a device code from OpenAI's device auth endpoint.

    Returns a dict with `device_auth_id`, `user_code`, and the suggested
    polling `interval` (as a string, defaulting to "5").
    """
    resp = requests.post(
        CHATGPT_DEVICE_CODE_URL,
        json={"client_id": CHATGPT_CLIENT_ID},
        timeout=HTTP_TIMEOUT_SECONDS,
    )
    resp.raise_for_status()
    data = resp.json()

    device_auth_id = data.get("device_auth_id")
    # The endpoint has used both spellings across versions.
    user_code = data.get("user_code") or data.get("usercode")
    interval = data.get("interval")
    if not device_auth_id or not user_code:
        raise RuntimeError(f"Device code response missing fields: {data}")

    return {
        "device_auth_id": device_auth_id,
        "user_code": user_code,
        "interval": str(interval or "5"),
    }


def poll_for_authorization(device_code: Dict[str, str]) -> Dict[str, str]:
    """Poll until the user completes authorization. 
 Returns code_data.

    403/404 mean "authorization still pending" for this endpoint, so both are
    retried; any other HTTP error aborts. Raises RuntimeError on timeout.
    """
    interval = int(device_code.get("interval", "5"))
    start_time = time.time()

    while time.time() - start_time < DEVICE_CODE_TIMEOUT_SECONDS:
        try:
            resp = requests.post(
                CHATGPT_DEVICE_TOKEN_URL,
                json={
                    "device_auth_id": device_code["device_auth_id"],
                    "user_code": device_code["user_code"],
                },
                timeout=HTTP_TIMEOUT_SECONDS,
            )
            if resp.status_code == 200:
                data = resp.json()
                if all(
                    key in data
                    for key in ("authorization_code", "code_challenge", "code_verifier")
                ):
                    return data
            if resp.status_code in (403, 404):
                time.sleep(max(interval, DEVICE_CODE_POLL_SECONDS))
                continue
            resp.raise_for_status()
        except requests.HTTPError as exc:
            if exc.response is not None and exc.response.status_code in (403, 404):
                time.sleep(max(interval, DEVICE_CODE_POLL_SECONDS))
                continue
            raise RuntimeError(f"Polling failed: {exc}") from exc

        time.sleep(max(interval, DEVICE_CODE_POLL_SECONDS))

    raise RuntimeError("Timed out waiting for device authorization")


def exchange_code_for_tokens(code_data: Dict[str, str]) -> Dict[str, str]:
    """Exchange the authorization code for access/refresh/id tokens."""
    redirect_uri = f"{CHATGPT_AUTH_BASE}/deviceauth/callback"
    # Properly urlencode the form body instead of concatenating raw values:
    # reserved characters in the code/verifier would otherwise corrupt it.
    body = urlencode(
        {
            "grant_type": "authorization_code",
            "code": code_data["authorization_code"],
            "redirect_uri": redirect_uri,
            "client_id": CHATGPT_CLIENT_ID,
            "code_verifier": code_data["code_verifier"],
        }
    )
    resp = requests.post(
        CHATGPT_OAUTH_TOKEN_URL,
        headers={"Content-Type": "application/x-www-form-urlencoded"},
        data=body,
        timeout=HTTP_TIMEOUT_SECONDS,
    )
    resp.raise_for_status()
    data = resp.json()

    if not all(key in data for key in ("access_token", "refresh_token", "id_token")):
        raise RuntimeError(f"Token exchange response missing fields: {data}")

    return {
        "access_token": data["access_token"],
        "refresh_token": data["refresh_token"],
        "id_token": data["id_token"],
    }


def login() -> Dict[str, Any]:
    """Run the full device code login flow. 
 Returns the auth record.

    Credentials are only persisted after the flow succeeds, so an aborted or
    failed login no longer clobbers previously stored, still-valid tokens.
    """
    device_code = request_device_code()

    print(
        "\nSign in with your ChatGPT account:\n"
        f"  1) Visit: {CHATGPT_DEVICE_VERIFY_URL}\n"
        f"  2) Enter code: {device_code['user_code']}\n\n"
        "Device codes are a common phishing target. Never share this code.\n",
        flush=True,
    )

    code_data = poll_for_authorization(device_code)
    tokens = exchange_code_for_tokens(code_data)
    auth_record = _build_auth_record(tokens)
    save_auth(auth_record)
    return auth_record


def get_access_token() -> Tuple[str, Optional[str]]:
    """
    Get a valid access token and account ID.
    Refreshes automatically if expired. Raises if no auth data exists.
    Returns (access_token, account_id).
    """
    auth_data = load_auth()
    if not auth_data:
        raise RuntimeError(
            "No ChatGPT credentials found. Run 'planoai chatgpt login' first."
        )

    access_token = auth_data.get("access_token")
    if access_token and not _is_token_expired(auth_data):
        return access_token, auth_data.get("account_id")

    # Try refresh; surface a single actionable error whether the refresh
    # itself fails or no refresh token is stored.
    refresh_token = auth_data.get("refresh_token")
    if refresh_token:
        try:
            tokens = _refresh_tokens(refresh_token)
        except Exception as exc:
            raise RuntimeError(
                "ChatGPT token expired and refresh failed. Run 'planoai chatgpt login' again."
            ) from exc
        auth_record = _build_auth_record(tokens)
        save_auth(auth_record)
        return auth_record["access_token"], auth_record.get("account_id")

    raise RuntimeError(
        "ChatGPT token expired and refresh failed. Run 'planoai chatgpt login' again."
    )
"""
CLI commands for ChatGPT subscription management.

Usage:
    planoai chatgpt login   - Authenticate with ChatGPT via device code flow
    planoai chatgpt status  - Check authentication status
    planoai chatgpt logout  - Remove stored credentials
"""

import datetime

import click
from rich.console import Console

from planoai import chatgpt_auth

console = Console()


@click.group()
def chatgpt():
    """ChatGPT subscription management."""
    pass


@chatgpt.command()
def login():
    """Authenticate with your ChatGPT subscription using device code flow."""
    try:
        auth_record = chatgpt_auth.login()
        account_id = auth_record.get("account_id", "unknown")
        console.print(
            f"\n[green]Successfully authenticated with ChatGPT![/green]"
            f"\nAccount ID: {account_id}"
            f"\nCredentials saved to: {chatgpt_auth.CHATGPT_AUTH_FILE}"
        )
    except Exception as e:
        console.print(f"\n[red]Authentication failed:[/red] {e}")
        raise SystemExit(1)


@chatgpt.command()
def status():
    """Check ChatGPT authentication status."""
    auth_data = chatgpt_auth.load_auth()
    if not auth_data or not auth_data.get("access_token"):
        console.print(
            "[yellow]Not authenticated.[/yellow] Run 'planoai chatgpt login'."
        )
        return

    account_id = auth_data.get("account_id", "unknown")
    expires_at = auth_data.get("expires_at")

    if expires_at:
        expiry_time = datetime.datetime.fromtimestamp(
            expires_at, tz=datetime.timezone.utc
        )
        now = datetime.datetime.now(tz=datetime.timezone.utc)
        if expiry_time > now:
            remaining = expiry_time - now
            # Use total_seconds(): timedelta.seconds is only the sub-day
            # component and under-reports when expiry is >= 24h away.
            remaining_minutes = int(remaining.total_seconds() // 60)
            console.print(
                f"[green]Authenticated[/green]"
                f"\n  Account ID: {account_id}"
                f"\n  Token expires: {expiry_time.strftime('%Y-%m-%d %H:%M:%S UTC')}"
                f" ({remaining_minutes}m remaining)"
            )
        else:
            console.print(
                f"[yellow]Token expired[/yellow]"
                f"\n  Account ID: {account_id}"
                f"\n  Expired at: {expiry_time.strftime('%Y-%m-%d %H:%M:%S UTC')}"
                f"\n  Will auto-refresh on next use, or run 'planoai chatgpt login'."
            )
    else:
        console.print(
            f"[green]Authenticated[/green] (no expiry info)"
            f"\n  Account ID: {account_id}"
        )


@chatgpt.command()
def logout():
    """Remove stored ChatGPT credentials."""
    chatgpt_auth.delete_auth()
    console.print("[green]ChatGPT credentials removed.[/green]")
model_provider.get("base_url", None): diff --git a/cli/planoai/main.py b/cli/planoai/main.py index 5686b0ff..8e766cf8 100644 --- a/cli/planoai/main.py +++ b/cli/planoai/main.py @@ -37,6 +37,7 @@ from planoai.core import ( ) from planoai.init_cmd import init as init_cmd from planoai.trace_cmd import trace as trace_cmd, start_trace_listener_background +from planoai.chatgpt_cmd import chatgpt as chatgpt_cmd from planoai.obs_cmd import obs as obs_cmd from planoai.consts import ( DEFAULT_OTEL_TRACING_GRPC_ENDPOINT, @@ -125,6 +126,28 @@ def _temporary_cli_log_level(level: str | None): set_log_level(current_level) +def _inject_chatgpt_tokens_if_needed(config, env, console): + """If config uses chatgpt providers, resolve tokens from ~/.plano/chatgpt/auth.json.""" + providers = config.get("model_providers") or config.get("llm_providers") or [] + has_chatgpt = any(str(p.get("model", "")).startswith("chatgpt/") for p in providers) + if not has_chatgpt: + return + + try: + from planoai.chatgpt_auth import get_access_token + + access_token, account_id = get_access_token() + env["CHATGPT_ACCESS_TOKEN"] = access_token + if account_id: + env["CHATGPT_ACCOUNT_ID"] = account_id + except Exception as e: + console.print( + f"\n[red]ChatGPT auth error:[/red] {e}\n" + f"[dim]Run 'planoai chatgpt login' to authenticate.[/dim]\n" + ) + sys.exit(1) + + def _print_missing_keys(console, missing_keys: list[str]) -> None: console.print(f"\n[red]✗[/red] [red]Missing API keys![/red]\n") for key in missing_keys: @@ -418,6 +441,14 @@ def up( env = os.environ.copy() env.pop("PATH", None) + import yaml + + with open(plano_config_file, "r") as f: + plano_config = yaml.safe_load(f) + + # Inject ChatGPT tokens from ~/.plano/chatgpt/auth.json if any provider needs them + _inject_chatgpt_tokens_if_needed(plano_config, env, console) + # Check access keys access_keys = get_llm_provider_access_keys(plano_config_file=plano_config_file) access_keys = set(access_keys) @@ -715,6 +746,7 @@ 
main.add_command(cli_agent) main.add_command(generate_prompt_targets) main.add_command(init_cmd, name="init") main.add_command(trace_cmd, name="trace") +main.add_command(chatgpt_cmd, name="chatgpt") main.add_command(obs_cmd, name="obs") if __name__ == "__main__": diff --git a/cli/planoai/native_runner.py b/cli/planoai/native_runner.py index bbbbfd3e..1b58b36d 100644 --- a/cli/planoai/native_runner.py +++ b/cli/planoai/native_runner.py @@ -253,6 +253,7 @@ def start_native( log.info("Plano is running (native mode)") for port in gateway_ports: log.info(f" http://localhost:{port}") + break # Check if processes are still alive @@ -367,8 +368,11 @@ def _kill_pid(pid): pass -def stop_native(): - """Stop natively-running Envoy and brightstaff processes. +def stop_native(skip_pids: set | None = None): + """Stop natively-running Envoy, brightstaff, and watchdog processes. + + Args: + skip_pids: Set of PIDs to skip (used by the watchdog to avoid self-termination). Returns: bool: True if at least one process was running and received a stop signal, @@ -385,7 +389,12 @@ def stop_native(): brightstaff_pid = pids.get("brightstaff_pid") had_running_process = False - for name, pid in [("envoy", envoy_pid), ("brightstaff", brightstaff_pid)]: + for name, pid in [ + ("envoy", envoy_pid), + ("brightstaff", brightstaff_pid), + ]: + if skip_pids and pid in skip_pids: + continue if pid is None: continue try: diff --git a/config/grafana/brightstaff_dashboard.json b/config/grafana/brightstaff_dashboard.json new file mode 100644 index 00000000..4b54721f --- /dev/null +++ b/config/grafana/brightstaff_dashboard.json @@ -0,0 +1,541 @@ +{ + "annotations": { + "list": [ + { + "builtIn": 1, + "datasource": "-- Grafana --", + "enable": true, + "hide": true, + "iconColor": "rgba(0, 211, 255, 1)", + "name": "Annotations & Alerts", + "type": "dashboard" + } + ] + }, + "description": "RED, LLM upstream, routing service, and process metrics for brightstaff. 
Pair with Envoy admin metrics from cluster=bright_staff.", + "editable": true, + "fiscalYearStartMonth": 0, + "graphTooltip": 1, + "id": null, + "links": [], + "liveNow": false, + "panels": [ + { + "collapsed": false, + "gridPos": { "h": 1, "w": 24, "x": 0, "y": 0 }, + "id": 100, + "panels": [], + "title": "HTTP RED", + "type": "row" + }, + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "fieldConfig": { + "defaults": { + "color": { "mode": "palette-classic" }, + "custom": { + "axisLabel": "req/s", + "drawStyle": "line", + "fillOpacity": 10, + "lineWidth": 1, + "showPoints": "never" + }, + "unit": "reqps" + } + }, + "gridPos": { "h": 8, "w": 12, "x": 0, "y": 1 }, + "id": 1, + "options": { + "legend": { "displayMode": "table", "placement": "bottom", "showLegend": true }, + "tooltip": { "mode": "multi" } + }, + "targets": [ + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "expr": "sum by (handler) (rate(brightstaff_http_requests_total[1m]))", + "legendFormat": "{{handler}}", + "refId": "A" + } + ], + "title": "Rate — brightstaff RPS by handler", + "type": "timeseries" + }, + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "description": "5xx fraction over 5m. 
Page-worthy when sustained above ~1%.", + "fieldConfig": { + "defaults": { + "color": { "mode": "thresholds" }, + "thresholds": { + "mode": "absolute", + "steps": [ + { "color": "green", "value": null }, + { "color": "yellow", "value": 0.01 }, + { "color": "red", "value": 0.05 } + ] + }, + "unit": "percentunit" + } + }, + "gridPos": { "h": 8, "w": 12, "x": 12, "y": 1 }, + "id": 2, + "options": { + "colorMode": "background", + "graphMode": "area", + "reduceOptions": { "calcs": ["lastNotNull"], "fields": "", "values": false } + }, + "targets": [ + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "expr": "sum(rate(brightstaff_http_requests_total{status_class=\"5xx\"}[5m])) / clamp_min(sum(rate(brightstaff_http_requests_total[5m])), 1)", + "legendFormat": "5xx rate", + "refId": "A" + } + ], + "title": "Errors — brightstaff 5xx rate", + "type": "stat" + }, + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "description": "p50/p95/p99 by handler, computed from histogram buckets over 5m.", + "fieldConfig": { + "defaults": { + "color": { "mode": "palette-classic" }, + "custom": { "drawStyle": "line", "fillOpacity": 5, "lineWidth": 1, "showPoints": "never" }, + "unit": "s" + } + }, + "gridPos": { "h": 9, "w": 24, "x": 0, "y": 9 }, + "id": 3, + "options": { + "legend": { "displayMode": "table", "placement": "bottom", "showLegend": true }, + "tooltip": { "mode": "multi" } + }, + "targets": [ + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "expr": "histogram_quantile(0.50, sum by (le, handler) (rate(brightstaff_http_request_duration_seconds_bucket[5m])))", + "legendFormat": "p50 {{handler}}", + "refId": "A" + }, + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "expr": "histogram_quantile(0.95, sum by (le, handler) (rate(brightstaff_http_request_duration_seconds_bucket[5m])))", + "legendFormat": "p95 {{handler}}", + "refId": "B" + }, + { + "datasource": { "type": "prometheus", 
"uid": "${DS_PROMETHEUS}" }, + "expr": "histogram_quantile(0.99, sum by (le, handler) (rate(brightstaff_http_request_duration_seconds_bucket[5m])))", + "legendFormat": "p99 {{handler}}", + "refId": "C" + } + ], + "title": "Duration — p50 / p95 / p99 by handler", + "type": "timeseries" + }, + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "description": "In-flight requests by handler. Climbs before latency does when brightstaff is saturated.", + "fieldConfig": { + "defaults": { + "color": { "mode": "palette-classic" }, + "custom": { "drawStyle": "line", "fillOpacity": 10, "lineWidth": 1, "showPoints": "never" }, + "unit": "short" + } + }, + "gridPos": { "h": 8, "w": 24, "x": 0, "y": 18 }, + "id": 4, + "options": { + "legend": { "displayMode": "table", "placement": "bottom", "showLegend": true }, + "tooltip": { "mode": "multi" } + }, + "targets": [ + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "expr": "sum by (handler) (brightstaff_http_in_flight_requests)", + "legendFormat": "{{handler}}", + "refId": "A" + } + ], + "title": "In-flight requests by handler", + "type": "timeseries" + }, + { + "collapsed": false, + "gridPos": { "h": 1, "w": 24, "x": 0, "y": 26 }, + "id": 200, + "panels": [], + "title": "LLM upstream", + "type": "row" + }, + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "fieldConfig": { + "defaults": { + "color": { "mode": "palette-classic" }, + "custom": { "drawStyle": "line", "fillOpacity": 5, "lineWidth": 1, "showPoints": "never" }, + "unit": "s" + } + }, + "gridPos": { "h": 9, "w": 12, "x": 0, "y": 27 }, + "id": 5, + "options": { + "legend": { "displayMode": "table", "placement": "bottom", "showLegend": true }, + "tooltip": { "mode": "multi" } + }, + "targets": [ + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "expr": "histogram_quantile(0.95, sum by (le, provider, model) (rate(brightstaff_llm_upstream_duration_seconds_bucket[5m])))", + 
"legendFormat": "p95 {{provider}}/{{model}}", + "refId": "A" + } + ], + "title": "LLM upstream p95 by provider/model", + "type": "timeseries" + }, + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "description": "All non-success error classes. timeout/connect = network, 5xx/429 = provider, parse = body shape mismatch, stream = mid-stream disconnect.", + "fieldConfig": { + "defaults": { + "color": { "mode": "palette-classic" }, + "custom": { "drawStyle": "line", "fillOpacity": 30, "lineWidth": 1, "showPoints": "never", "stacking": { "mode": "normal" } }, + "unit": "reqps" + } + }, + "gridPos": { "h": 9, "w": 12, "x": 12, "y": 27 }, + "id": 6, + "options": { + "legend": { "displayMode": "table", "placement": "bottom", "showLegend": true }, + "tooltip": { "mode": "multi" } + }, + "targets": [ + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "expr": "sum by (provider, error_class) (rate(brightstaff_llm_upstream_requests_total{error_class!=\"none\"}[5m]))", + "legendFormat": "{{provider}} / {{error_class}}", + "refId": "A" + } + ], + "title": "LLM upstream errors by provider / class", + "type": "timeseries" + }, + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "description": "Streaming only. 
Empty if the route never streams.", + "fieldConfig": { + "defaults": { + "color": { "mode": "palette-classic" }, + "custom": { "drawStyle": "line", "fillOpacity": 5, "lineWidth": 1, "showPoints": "never" }, + "unit": "s" + } + }, + "gridPos": { "h": 9, "w": 12, "x": 0, "y": 36 }, + "id": 7, + "options": { + "legend": { "displayMode": "table", "placement": "bottom", "showLegend": true }, + "tooltip": { "mode": "multi" } + }, + "targets": [ + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "expr": "histogram_quantile(0.95, sum by (le, provider, model) (rate(brightstaff_llm_time_to_first_token_seconds_bucket[5m])))", + "legendFormat": "p95 {{provider}}/{{model}}", + "refId": "A" + } + ], + "title": "Time-to-first-token p95 (streaming)", + "type": "timeseries" + }, + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "description": "Tokens/sec by provider/model/kind — proxy for cost. Stacked.", + "fieldConfig": { + "defaults": { + "color": { "mode": "palette-classic" }, + "custom": { "drawStyle": "line", "fillOpacity": 30, "lineWidth": 1, "showPoints": "never", "stacking": { "mode": "normal" } }, + "unit": "tokens/s" + } + }, + "gridPos": { "h": 9, "w": 12, "x": 12, "y": 36 }, + "id": 8, + "options": { + "legend": { "displayMode": "table", "placement": "bottom", "showLegend": true }, + "tooltip": { "mode": "multi" } + }, + "targets": [ + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "expr": "sum by (provider, model, kind) (rate(brightstaff_llm_tokens_total[5m]))", + "legendFormat": "{{provider}}/{{model}} {{kind}}", + "refId": "A" + } + ], + "title": "Token throughput by provider / model / kind", + "type": "timeseries" + }, + { + "collapsed": false, + "gridPos": { "h": 1, "w": 24, "x": 0, "y": 45 }, + "id": 300, + "panels": [], + "title": "Routing service", + "type": "row" + }, + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "description": "Which models the orchestrator 
picked over the last 15 minutes.", + "fieldConfig": { + "defaults": { + "color": { "mode": "palette-classic" }, + "unit": "short" + } + }, + "gridPos": { "h": 9, "w": 12, "x": 0, "y": 46 }, + "id": 9, + "options": { + "displayMode": "gradient", + "orientation": "horizontal", + "reduceOptions": { "calcs": ["lastNotNull"], "fields": "", "values": false } + }, + "targets": [ + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "expr": "sum by (selected_model) (increase(brightstaff_router_decisions_total[15m]))", + "legendFormat": "{{selected_model}}", + "refId": "A" + } + ], + "title": "Model selection distribution (last 15m)", + "type": "bargauge" + }, + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "description": "Fraction of decisions that fell back (orchestrator returned `none` or errored). High = router can't classify intent or no candidates configured.", + "fieldConfig": { + "defaults": { + "color": { "mode": "palette-classic" }, + "custom": { "drawStyle": "line", "fillOpacity": 10, "lineWidth": 1, "showPoints": "never" }, + "unit": "percentunit" + } + }, + "gridPos": { "h": 9, "w": 12, "x": 12, "y": 46 }, + "id": 10, + "options": { + "legend": { "displayMode": "table", "placement": "bottom", "showLegend": true }, + "tooltip": { "mode": "multi" } + }, + "targets": [ + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "expr": "sum by (route) (rate(brightstaff_router_decisions_total{fallback=\"true\"}[5m])) / clamp_min(sum by (route) (rate(brightstaff_router_decisions_total[5m])), 1)", + "legendFormat": "{{route}}", + "refId": "A" + } + ], + "title": "Fallback rate by route", + "type": "timeseries" + }, + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "fieldConfig": { + "defaults": { + "color": { "mode": "palette-classic" }, + "custom": { "drawStyle": "line", "fillOpacity": 5, "lineWidth": 1, "showPoints": "never" }, + "unit": "s" + } + }, + "gridPos": { "h": 8, "w": 
12, "x": 0, "y": 55 }, + "id": 11, + "options": { + "legend": { "displayMode": "table", "placement": "bottom", "showLegend": true }, + "tooltip": { "mode": "multi" } + }, + "targets": [ + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "expr": "histogram_quantile(0.95, sum by (le, route) (rate(brightstaff_router_decision_duration_seconds_bucket[5m])))", + "legendFormat": "p95 {{route}}", + "refId": "A" + } + ], + "title": "Router decision p95 latency", + "type": "timeseries" + }, + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "description": "Hit / (hit + miss). Low ratio = sessions aren't being reused or TTL too short.", + "fieldConfig": { + "defaults": { + "color": { "mode": "thresholds" }, + "thresholds": { + "mode": "absolute", + "steps": [ + { "color": "red", "value": null }, + { "color": "yellow", "value": 0.5 }, + { "color": "green", "value": 0.8 } + ] + }, + "unit": "percentunit", + "min": 0, + "max": 1 + } + }, + "gridPos": { "h": 8, "w": 6, "x": 12, "y": 55 }, + "id": 12, + "options": { + "colorMode": "background", + "graphMode": "area", + "reduceOptions": { "calcs": ["lastNotNull"], "fields": "", "values": false } + }, + "targets": [ + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "expr": "sum(rate(brightstaff_session_cache_events_total{outcome=\"hit\"}[5m])) / clamp_min(sum(rate(brightstaff_session_cache_events_total{outcome=~\"hit|miss\"}[5m])), 1)", + "legendFormat": "hit rate", + "refId": "A" + } + ], + "title": "Session cache hit rate", + "type": "stat" + }, + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "description": "decision_served = a real model picked. no_candidates = sentinel `none` returned. 
policy_error = orchestrator failed.", + "fieldConfig": { + "defaults": { + "color": { "mode": "palette-classic" }, + "custom": { "drawStyle": "line", "fillOpacity": 30, "lineWidth": 1, "showPoints": "never", "stacking": { "mode": "normal" } }, + "unit": "reqps" + } + }, + "gridPos": { "h": 8, "w": 6, "x": 18, "y": 55 }, + "id": 13, + "options": { + "legend": { "displayMode": "list", "placement": "bottom", "showLegend": true }, + "tooltip": { "mode": "multi" } + }, + "targets": [ + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "expr": "sum by (outcome) (rate(brightstaff_routing_service_requests_total[5m]))", + "legendFormat": "{{outcome}}", + "refId": "A" + } + ], + "title": "/routing/* outcomes", + "type": "timeseries" + }, + { + "collapsed": false, + "gridPos": { "h": 1, "w": 24, "x": 0, "y": 63 }, + "id": 400, + "panels": [], + "title": "Process & Envoy link", + "type": "row" + }, + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "description": "Compare to brightstaff RPS (panel 1) — sustained gap = network or Envoy queueing.", + "fieldConfig": { + "defaults": { + "color": { "mode": "palette-classic" }, + "custom": { "drawStyle": "line", "fillOpacity": 10, "lineWidth": 1, "showPoints": "never" }, + "unit": "reqps" + } + }, + "gridPos": { "h": 8, "w": 12, "x": 0, "y": 64 }, + "id": 14, + "options": { + "legend": { "displayMode": "table", "placement": "bottom", "showLegend": true }, + "tooltip": { "mode": "multi" } + }, + "targets": [ + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "expr": "sum(rate(envoy_cluster_upstream_rq_total{envoy_cluster_name=\"bright_staff\"}[1m]))", + "legendFormat": "envoy → bright_staff", + "refId": "A" + }, + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "expr": "sum(rate(brightstaff_http_requests_total[1m]))", + "legendFormat": "brightstaff served", + "refId": "B" + } + ], + "title": "Envoy → brightstaff link health", + "type": 
"timeseries" + }, + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "fieldConfig": { + "defaults": { + "color": { "mode": "palette-classic" }, + "custom": { "drawStyle": "line", "fillOpacity": 10, "lineWidth": 1, "showPoints": "never" } + }, + "overrides": [ + { + "matcher": { "id": "byName", "options": "RSS" }, + "properties": [{ "id": "unit", "value": "bytes" }] + }, + { + "matcher": { "id": "byName", "options": "CPU" }, + "properties": [{ "id": "unit", "value": "percentunit" }] + } + ] + }, + "gridPos": { "h": 8, "w": 12, "x": 12, "y": 64 }, + "id": 15, + "options": { + "legend": { "displayMode": "table", "placement": "bottom", "showLegend": true }, + "tooltip": { "mode": "multi" } + }, + "targets": [ + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "expr": "process_resident_memory_bytes{job=\"brightstaff\"}", + "legendFormat": "RSS", + "refId": "A" + }, + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "expr": "rate(process_cpu_seconds_total{job=\"brightstaff\"}[1m])", + "legendFormat": "CPU", + "refId": "B" + } + ], + "title": "Brightstaff process RSS / CPU", + "type": "timeseries" + } + ], + "refresh": "30s", + "schemaVersion": 39, + "tags": ["plano", "brightstaff", "llm"], + "templating": { + "list": [ + { + "name": "DS_PROMETHEUS", + "label": "Prometheus", + "type": "datasource", + "query": "prometheus", + "current": { "selected": false, "text": "Prometheus", "value": "DS_PROMETHEUS" }, + "hide": 0, + "refresh": 1, + "regex": "", + "skipUrlSync": false, + "includeAll": false, + "multi": false + } + ] + }, + "time": { "from": "now-1h", "to": "now" }, + "timepicker": {}, + "timezone": "browser", + "title": "Brightstaff (Plano dataplane)", + "uid": "brightstaff", + "version": 1, + "weekStart": "" +} diff --git a/config/grafana/docker-compose.yaml b/config/grafana/docker-compose.yaml new file mode 100644 index 00000000..33238073 --- /dev/null +++ b/config/grafana/docker-compose.yaml @@ 
-0,0 +1,43 @@ +# One-command Prometheus + Grafana stack for observing a locally-running +# Plano (Envoy admin :9901 + brightstaff :9092 on the host). +# +# cd config/grafana +# docker compose up -d +# open http://localhost:3000 (admin / admin) +# +# Grafana is preloaded with: +# - Prometheus datasource (uid=DS_PROMETHEUS) → http://prometheus:9090 +# - Brightstaff dashboard (auto-imported from brightstaff_dashboard.json) +# +# Prometheus scrapes the host's :9092 and :9901 via host.docker.internal. +# On Linux this works because of the `extra_hosts: host-gateway` mapping +# below. On Mac/Win it works natively. + +services: + prometheus: + image: prom/prometheus:latest + container_name: plano-prometheus + ports: + - "9090:9090" + volumes: + - ./prometheus_scrape.yaml:/etc/prometheus/prometheus.yml:ro + extra_hosts: + - "host.docker.internal:host-gateway" + restart: unless-stopped + + grafana: + image: grafana/grafana:latest + container_name: plano-grafana + ports: + - "3000:3000" + environment: + GF_SECURITY_ADMIN_USER: admin + GF_SECURITY_ADMIN_PASSWORD: admin + GF_AUTH_ANONYMOUS_ENABLED: "true" + GF_AUTH_ANONYMOUS_ORG_ROLE: Viewer + volumes: + - ./provisioning:/etc/grafana/provisioning:ro + - ./brightstaff_dashboard.json:/var/lib/grafana/dashboards/brightstaff_dashboard.json:ro + depends_on: + - prometheus + restart: unless-stopped diff --git a/config/grafana/prometheus_scrape.yaml b/config/grafana/prometheus_scrape.yaml new file mode 100644 index 00000000..b4041287 --- /dev/null +++ b/config/grafana/prometheus_scrape.yaml @@ -0,0 +1,44 @@ +# Prometheus config that scrapes Plano (Envoy admin + brightstaff). This is +# a complete Prometheus config — mount it directly at +# /etc/prometheus/prometheus.yml. The included docker-compose.yaml does this +# for you. +# +# Targets: +# - envoy:9901 Envoy admin → envoy_cluster_*, envoy_http_*, envoy_server_*. +# - brightstaff:9092 Native dataplane → brightstaff_http_*, brightstaff_llm_*, +# brightstaff_router_*, process_*. 
+# +# Hostname `host.docker.internal` works on Docker Desktop (Mac/Win) and on +# Linux when the container is started with `--add-host=host.docker.internal: +# host-gateway` (the included compose does this). If Plano runs *inside* +# Docker on the same network as Prometheus, replace it with the container +# name (e.g. `plano:9092`). +# +# This file is unrelated to demos/llm_routing/model_routing_service/prometheus.yaml, +# which scrapes a fake metrics service to feed the routing engine. + +global: + scrape_interval: 15s + scrape_timeout: 10s + evaluation_interval: 15s + +scrape_configs: + - job_name: envoy + honor_timestamps: true + metrics_path: /stats + params: + format: ["prometheus"] + static_configs: + - targets: + - host.docker.internal:9901 + labels: + service: plano + + - job_name: brightstaff + honor_timestamps: true + metrics_path: /metrics + static_configs: + - targets: + - host.docker.internal:9092 + labels: + service: plano diff --git a/config/grafana/provisioning/dashboards/brightstaff.yaml b/config/grafana/provisioning/dashboards/brightstaff.yaml new file mode 100644 index 00000000..271e4a9b --- /dev/null +++ b/config/grafana/provisioning/dashboards/brightstaff.yaml @@ -0,0 +1,15 @@ +# Auto-load the brightstaff dashboard JSON on Grafana startup. + +apiVersion: 1 + +providers: + - name: brightstaff + orgId: 1 + folder: Plano + type: file + disableDeletion: false + updateIntervalSeconds: 30 + allowUiUpdates: true + options: + path: /var/lib/grafana/dashboards + foldersFromFilesStructure: false diff --git a/config/grafana/provisioning/datasources/prometheus.yaml b/config/grafana/provisioning/datasources/prometheus.yaml new file mode 100644 index 00000000..2e3170ec --- /dev/null +++ b/config/grafana/provisioning/datasources/prometheus.yaml @@ -0,0 +1,14 @@ +# Auto-provision the Prometheus datasource so the bundled dashboard wires up +# without any clicks. The `uid: DS_PROMETHEUS` matches the templated input in +# brightstaff_dashboard.json. 
+ +apiVersion: 1 + +datasources: + - name: Prometheus + uid: DS_PROMETHEUS + type: prometheus + access: proxy + url: http://prometheus:9090 + isDefault: true + editable: true diff --git a/config/plano_config_schema.yaml b/config/plano_config_schema.yaml index bdde05d4..2f9eea63 100644 --- a/config/plano_config_schema.yaml +++ b/config/plano_config_schema.yaml @@ -190,9 +190,15 @@ properties: - openai - xiaomi - gemini + - chatgpt - digitalocean - vercel - openrouter + headers: + type: object + additionalProperties: + type: string + description: "Additional headers to send with upstream requests (e.g., ChatGPT-Account-Id, originator)." routing_preferences: type: array items: @@ -241,9 +247,15 @@ properties: - openai - xiaomi - gemini + - chatgpt - digitalocean - vercel - openrouter + headers: + type: object + additionalProperties: + type: string + description: "Additional headers to send with upstream requests (e.g., ChatGPT-Account-Id, originator)." routing_preferences: type: array items: @@ -282,6 +294,9 @@ properties: type: boolean use_agent_orchestrator: type: boolean + disable_signals: + type: boolean + description: "Disable agentic signal analysis (frustration, repetition, escalation, etc.) on LLM responses to save CPU. Default false." upstream_connect_timeout: type: string description: "Connect timeout for upstream provider clusters (e.g., '5s', '10s'). Default is '5s'." 
diff --git a/crates/Cargo.lock b/crates/Cargo.lock index e07b47ee..39261d67 100644 --- a/crates/Cargo.lock +++ b/crates/Cargo.lock @@ -23,6 +23,18 @@ version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e8fd72866655d1904d6b0997d0b07ba561047d070fbe29de039031c641b61217" +[[package]] +name = "ahash" +version = "0.8.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" +dependencies = [ + "cfg-if", + "once_cell", + "version_check", + "zerocopy", +] + [[package]] name = "aho-corasick" version = "1.1.4" @@ -257,6 +269,24 @@ dependencies = [ "vsimd", ] +[[package]] +name = "bindgen" +version = "0.72.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "993776b509cfb49c750f11b8f07a46fa23e0a1386ffc01fb1e7d343efc387895" +dependencies = [ + "bitflags", + "cexpr", + "clang-sys", + "itertools 0.13.0", + "proc-macro2", + "quote", + "regex", + "rustc-hash 2.1.2", + "shlex", + "syn 2.0.117", +] + [[package]] name = "bit-set" version = "0.5.3" @@ -316,6 +346,9 @@ dependencies = [ "hyper 1.9.0", "hyper-util", "lru", + "metrics 0.23.1", + "metrics-exporter-prometheus", + "metrics-process", "mockito", "opentelemetry", "opentelemetry-http", @@ -325,6 +358,7 @@ dependencies = [ "pretty_assertions", "rand 0.9.4", "redis", + "regex", "reqwest", "serde", "serde_json", @@ -332,6 +366,8 @@ dependencies = [ "serde_yaml", "strsim", "thiserror 2.0.18", + "tikv-jemalloc-ctl", + "tikv-jemallocator", "time", "tokio", "tokio-postgres", @@ -391,6 +427,15 @@ dependencies = [ "shlex", ] +[[package]] +name = "cexpr" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" +dependencies = [ + "nom", +] + [[package]] name = "cfg-if" version = "1.0.4" @@ -428,6 +473,17 @@ dependencies = [ "windows-link", ] +[[package]] +name = "clang-sys" 
+version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" +dependencies = [ + "glob", + "libc", + "libloading", +] + [[package]] name = "cmov" version = "0.5.3" @@ -574,6 +630,21 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + [[package]] name = "crypto-common" version = "0.1.7" @@ -1070,6 +1141,12 @@ dependencies = [ "wasip3", ] +[[package]] +name = "glob" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" + [[package]] name = "governor" version = "0.6.3" @@ -1128,7 +1205,7 @@ version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e91b62f79061a0bc2e046024cb7ba44b08419ed238ecbd9adbd787434b9e8c25" dependencies = [ - "ahash", + "ahash 0.3.8", "autocfg", ] @@ -1138,6 +1215,15 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +dependencies = [ + "ahash 0.8.12", +] + [[package]] name = "hashbrown" version = "0.15.5" @@ -1189,6 +1275,12 @@ dependencies = [ "uuid", ] +[[package]] +name = "hermit-abi" +version = "0.5.2" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" + [[package]] name = "hex" version = "0.4.3" @@ -1665,6 +1757,27 @@ version = "0.2.185" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "52ff2c0fe9bc6cb6b14a0592c2ff4fa9ceb83eea9db979b0487cd054946a2b8f" +[[package]] +name = "libloading" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7c4b02199fee7c5d21a5ae7d8cfa79a6ef5bb2fc834d6e9058e89c825efdc55" +dependencies = [ + "cfg-if", + "windows-link", +] + +[[package]] +name = "libproc" +version = "0.14.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a54ad7278b8bc5301d5ffd2a94251c004feb971feba96c971ea4063645990757" +dependencies = [ + "bindgen", + "errno", + "libc", +] + [[package]] name = "libredox" version = "0.1.16" @@ -1745,6 +1858,12 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" +[[package]] +name = "mach2" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dae608c151f68243f2b000364e1f7b186d9c29845f7d2d85bd31b9ad77ad552b" + [[package]] name = "matchers" version = "0.2.0" @@ -1782,6 +1901,77 @@ version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" +[[package]] +name = "metrics" +version = "0.23.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3045b4193fbdc5b5681f32f11070da9be3609f189a79f3390706d42587f46bb5" +dependencies = [ + "ahash 0.8.12", + "portable-atomic", +] + +[[package]] +name = "metrics" +version = "0.24.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d5312e9ba3771cfa961b585728215e3d972c950a3eed9252aa093d6301277e8" +dependencies 
= [ + "ahash 0.8.12", + "portable-atomic", +] + +[[package]] +name = "metrics-exporter-prometheus" +version = "0.15.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4f0c8427b39666bf970460908b213ec09b3b350f20c0c2eabcbba51704a08e6" +dependencies = [ + "base64 0.22.1", + "http-body-util", + "hyper 1.9.0", + "hyper-util", + "indexmap 2.14.0", + "ipnet", + "metrics 0.23.1", + "metrics-util", + "quanta", + "thiserror 1.0.69", + "tokio", + "tracing", +] + +[[package]] +name = "metrics-process" +version = "2.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4268d87f64a752f5a651314fc683f04da10be65701ea3e721ba4d74f79163cac" +dependencies = [ + "libc", + "libproc", + "mach2", + "metrics 0.24.3", + "once_cell", + "procfs", + "rlimit", + "windows", +] + +[[package]] +name = "metrics-util" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4259040465c955f9f2f1a4a8a16dc46726169bca0f88e8fb2dbeced487c3e828" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", + "hashbrown 0.14.5", + "metrics 0.23.1", + "num_cpus", + "quanta", + "sketches-ddsketch", +] + [[package]] name = "mime" version = "0.3.17" @@ -1935,6 +2125,16 @@ dependencies = [ "autocfg", ] +[[package]] +name = "num_cpus" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91df4bbde75afed763b708b7eee1e8e7651e02d97f6d5dd763e89367e957b23b" +dependencies = [ + "hermit-abi", + "libc", +] + [[package]] name = "objc2-core-foundation" version = "0.3.2" @@ -2125,6 +2325,12 @@ dependencies = [ "windows-link", ] +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + [[package]] name = "percent-encoding" version = "2.3.2" @@ -2278,6 +2484,27 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "procfs" +version = "0.18.0" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25485360a54d6861439d60facef26de713b1e126bf015ec8f98239467a2b82f7" +dependencies = [ + "bitflags", + "procfs-core", + "rustix", +] + +[[package]] +name = "procfs-core" +version = "0.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6401bf7b6af22f78b563665d15a22e9aef27775b79b149a66ca022468a4e405" +dependencies = [ + "bitflags", + "hex", +] + [[package]] name = "prompt_gateway" version = "0.1.0" @@ -2333,6 +2560,21 @@ dependencies = [ "log", ] +[[package]] +name = "quanta" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3ab5a9d756f0d97bdc89019bd2e4ea098cf9cde50ee7564dde6b81ccc8f06c7" +dependencies = [ + "crossbeam-utils", + "libc", + "once_cell", + "raw-cpuid", + "wasi 0.11.1+wasi-snapshot-preview1", + "web-sys", + "winapi", +] + [[package]] name = "quinn" version = "0.11.9" @@ -2485,6 +2727,15 @@ version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "63b8176103e19a2643978565ca18b50549f6101881c443590420e4dc998a3c69" +[[package]] +name = "raw-cpuid" +version = "11.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186" +dependencies = [ + "bitflags", +] + [[package]] name = "redis" version = "0.27.6" @@ -2646,6 +2897,15 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rlimit" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f35ee2729c56bb610f6dba436bf78135f728b7373bdffae2ec815b2d3eb98cc3" +dependencies = [ + "libc", +] + [[package]] name = "rustc-hash" version = "1.1.0" @@ -3098,6 +3358,12 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b2aa850e253778c88a04c3d7323b043aeda9d3e30d5971937c1855769763678e" +[[package]] +name = "sketches-ddsketch" +version = 
"0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85636c14b73d81f541e525f585c0a2109e6744e1565b5c1668e31c70c10ed65c" + [[package]] name = "slab" version = "0.4.12" @@ -3308,6 +3574,37 @@ dependencies = [ "rustc-hash 1.1.0", ] +[[package]] +name = "tikv-jemalloc-ctl" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "661f1f6a57b3a36dc9174a2c10f19513b4866816e13425d3e418b11cc37bc24c" +dependencies = [ + "libc", + "paste", + "tikv-jemalloc-sys", +] + +[[package]] +name = "tikv-jemalloc-sys" +version = "0.6.1+5.3.0-1-ge13ca993e8ccb9ba9847cc330696e02839f328f7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd8aa5b2ab86a2cefa406d889139c162cbb230092f7d1d7cbc1716405d852a3b" +dependencies = [ + "cc", + "libc", +] + +[[package]] +name = "tikv-jemallocator" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0359b4327f954e0567e69fb191cf1436617748813819c94b8cd4a431422d053a" +dependencies = [ + "libc", + "tikv-jemalloc-sys", +] + [[package]] name = "time" version = "0.3.47" @@ -4003,6 +4300,49 @@ dependencies = [ "web-sys", ] +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + +[[package]] +name = "windows" +version = "0.62.2" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "527fadee13e0c05939a6a05d5bd6eec6cd2e3dbd648b9f8e447c6518133d8580" +dependencies = [ + "windows-collections", + "windows-core", + "windows-future", + "windows-numerics", +] + +[[package]] +name = "windows-collections" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23b2d95af1a8a14a3c7367e1ed4fc9c20e0a26e79551b1454d72583c97cc6610" +dependencies = [ + "windows-core", +] + [[package]] name = "windows-core" version = "0.62.2" @@ -4016,6 +4356,17 @@ dependencies = [ "windows-strings", ] +[[package]] +name = "windows-future" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1d6f90251fe18a279739e78025bd6ddc52a7e22f921070ccdc67dde84c605cb" +dependencies = [ + "windows-core", + "windows-link", + "windows-threading", +] + [[package]] name = "windows-implement" version = "0.60.2" @@ -4044,6 +4395,16 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" +[[package]] +name = "windows-numerics" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e2e40844ac143cdb44aead537bbf727de9b044e107a0f1220392177d15b0f26" +dependencies = [ + "windows-core", + "windows-link", +] + [[package]] name = "windows-registry" version = "0.6.1" @@ -4133,6 +4494,15 @@ dependencies = [ "windows_x86_64_msvc 0.53.1", ] +[[package]] +name = "windows-threading" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3949bd5b99cafdf1c7ca86b43ca564028dfe27d66958f2470940f73d86d75b37" +dependencies = [ + "windows-link", +] + [[package]] name = "windows_aarch64_gnullvm" version = "0.52.6" diff --git a/crates/brightstaff/Cargo.toml b/crates/brightstaff/Cargo.toml index f88ed918..d2635963 100644 --- a/crates/brightstaff/Cargo.toml +++ 
b/crates/brightstaff/Cargo.toml @@ -3,6 +3,18 @@ name = "brightstaff" version = "0.1.0" edition = "2021" +[features] +default = ["jemalloc"] +jemalloc = ["tikv-jemallocator", "tikv-jemalloc-ctl"] + +[[bin]] +name = "brightstaff" +path = "src/main.rs" + +[[bin]] +name = "signals_replay" +path = "src/bin/signals_replay.rs" + [dependencies] async-openai = "0.30.1" async-trait = "0.1" @@ -26,7 +38,11 @@ opentelemetry-stdout = "0.31" opentelemetry_sdk = { version = "0.31", features = ["rt-tokio"] } pretty_assertions = "1.4.1" rand = "0.9.2" +regex = "1.10" lru = "0.12" +metrics = "0.23" +metrics-exporter-prometheus = { version = "0.15", default-features = false, features = ["http-listener"] } +metrics-process = "2.1" redis = { version = "0.27", features = ["tokio-comp"] } reqwest = { version = "0.12.15", features = ["stream"] } serde = { version = "1.0.219", features = ["derive"] } @@ -35,6 +51,8 @@ serde_with = "3.13.0" strsim = "0.11" serde_yaml = "0.9.34" thiserror = "2.0.12" +tikv-jemallocator = { version = "0.6", optional = true } +tikv-jemalloc-ctl = { version = "0.6", features = ["stats"], optional = true } tokio = { version = "1.44.2", features = ["full"] } tokio-postgres = { version = "0.7", features = ["with-serde_json-1"] } tokio-stream = "0.1" diff --git a/crates/brightstaff/src/app_state.rs b/crates/brightstaff/src/app_state.rs index e585d2db..1d534e89 100644 --- a/crates/brightstaff/src/app_state.rs +++ b/crates/brightstaff/src/app_state.rs @@ -24,4 +24,7 @@ pub struct AppState { /// Shared HTTP client for upstream LLM requests (connection pooling / keep-alive). pub http_client: reqwest::Client, pub filter_pipeline: Arc, + /// When false, agentic signal analysis is skipped on LLM responses to save CPU. + /// Controlled by `overrides.disable_signals` in plano config. 
+ pub signals_enabled: bool, } diff --git a/crates/brightstaff/src/bin/signals_replay.rs b/crates/brightstaff/src/bin/signals_replay.rs new file mode 100644 index 00000000..41879ac1 --- /dev/null +++ b/crates/brightstaff/src/bin/signals_replay.rs @@ -0,0 +1,175 @@ +//! `signals-replay` — batch driver for the `brightstaff` signal analyzer. +//! +//! Reads JSONL conversations from stdin (one per line) and emits matching +//! JSONL reports on stdout, one per input conversation, in the same order. +//! +//! Input shape (per line): +//! ```json +//! {"id": "convo-42", "messages": [{"from": "human", "value": "..."}, ...]} +//! ``` +//! +//! Output shape (per line, success): +//! ```json +//! {"id": "convo-42", "report": { ...python-compatible SignalReport dict... }} +//! ``` +//! +//! On per-line failure (parse / analyzer error), emits: +//! ```json +//! {"id": "convo-42", "error": "..."} +//! ``` +//! +//! The output report dict is shaped to match the Python reference's +//! `SignalReport.to_dict()` byte-for-byte so the parity comparator can do a +//! direct structural diff. 
+ +use std::io::{self, BufRead, BufWriter, Write}; + +use serde::Deserialize; +use serde_json::{json, Map, Value}; + +use brightstaff::signals::{SignalAnalyzer, SignalGroup, SignalReport}; + +#[derive(Debug, Deserialize)] +struct InputLine { + id: Value, + messages: Vec, +} + +#[derive(Debug, Deserialize)] +struct MessageRow { + #[serde(default)] + from: String, + #[serde(default)] + value: String, +} + +fn main() { + let stdin = io::stdin(); + let stdout = io::stdout(); + let mut out = BufWriter::new(stdout.lock()); + let analyzer = SignalAnalyzer::default(); + + for line in stdin.lock().lines() { + let line = match line { + Ok(l) => l, + Err(e) => { + eprintln!("read error: {e}"); + std::process::exit(1); + } + }; + let trimmed = line.trim(); + if trimmed.is_empty() { + continue; + } + let result = process_line(&analyzer, trimmed); + // Always emit one line per input line so id ordering stays aligned. + if let Err(e) = writeln!(out, "{result}") { + eprintln!("write error: {e}"); + std::process::exit(1); + } + // Flush periodically isn't strictly needed — BufWriter handles it, + // and the parent process reads the whole stream when we're done. + } + let _ = out.flush(); +} + +fn process_line(analyzer: &SignalAnalyzer, line: &str) -> Value { + let parsed: InputLine = match serde_json::from_str(line) { + Ok(p) => p, + Err(e) => { + return json!({ + "id": Value::Null, + "error": format!("input parse: {e}"), + }); + } + }; + + let id = parsed.id.clone(); + + let view: Vec> = parsed + .messages + .iter() + .map(|m| brightstaff::signals::analyzer::ShareGptMessage { + from: m.from.as_str(), + value: m.value.as_str(), + }) + .collect(); + + let report = analyzer.analyze_sharegpt(&view); + let report_dict = report_to_python_dict(&report); + json!({ + "id": id, + "report": report_dict, + }) +} + +/// Convert a `SignalReport` into the Python reference's `to_dict()` shape. 
+/// +/// Ordering of category keys in each layer dict follows the Python source +/// exactly so even string-equality comparisons behave deterministically. +fn report_to_python_dict(r: &SignalReport) -> Value { + let mut interaction = Map::new(); + interaction.insert( + "misalignment".to_string(), + signal_group_to_python(&r.interaction.misalignment), + ); + interaction.insert( + "stagnation".to_string(), + signal_group_to_python(&r.interaction.stagnation), + ); + interaction.insert( + "disengagement".to_string(), + signal_group_to_python(&r.interaction.disengagement), + ); + interaction.insert( + "satisfaction".to_string(), + signal_group_to_python(&r.interaction.satisfaction), + ); + + let mut execution = Map::new(); + execution.insert( + "failure".to_string(), + signal_group_to_python(&r.execution.failure), + ); + execution.insert( + "loops".to_string(), + signal_group_to_python(&r.execution.loops), + ); + + let mut environment = Map::new(); + environment.insert( + "exhaustion".to_string(), + signal_group_to_python(&r.environment.exhaustion), + ); + + json!({ + "interaction_signals": Value::Object(interaction), + "execution_signals": Value::Object(execution), + "environment_signals": Value::Object(environment), + "overall_quality": r.overall_quality.as_str(), + "summary": r.summary, + }) +} + +fn signal_group_to_python(g: &SignalGroup) -> Value { + let signals: Vec = g + .signals + .iter() + .map(|s| { + json!({ + "signal_type": s.signal_type.as_str(), + "message_index": s.message_index, + "snippet": s.snippet, + "confidence": s.confidence, + "metadata": s.metadata, + }) + }) + .collect(); + + json!({ + "category": g.category, + "count": g.count, + "severity": g.severity, + "signals": signals, + }) +} diff --git a/crates/brightstaff/src/handlers/debug.rs b/crates/brightstaff/src/handlers/debug.rs new file mode 100644 index 00000000..58fbecd2 --- /dev/null +++ b/crates/brightstaff/src/handlers/debug.rs @@ -0,0 +1,53 @@ +use bytes::Bytes; +use 
http_body_util::combinators::BoxBody; +use hyper::{Response, StatusCode}; + +use super::full; + +#[derive(serde::Serialize)] +struct MemStats { + allocated_bytes: usize, + resident_bytes: usize, + #[serde(skip_serializing_if = "Option::is_none")] + error: Option, +} + +/// Returns jemalloc memory statistics as JSON. +/// Falls back to a stub when the jemalloc feature is disabled. +pub async fn memstats() -> Result>, hyper::Error> { + let stats = get_jemalloc_stats(); + let json = serde_json::to_string(&stats).unwrap(); + Ok(Response::builder() + .status(StatusCode::OK) + .header("Content-Type", "application/json") + .body(full(json)) + .unwrap()) +} + +#[cfg(feature = "jemalloc")] +fn get_jemalloc_stats() -> MemStats { + use tikv_jemalloc_ctl::{epoch, stats}; + + if let Err(e) = epoch::advance() { + return MemStats { + allocated_bytes: 0, + resident_bytes: 0, + error: Some(format!("failed to advance jemalloc epoch: {e}")), + }; + } + + MemStats { + allocated_bytes: stats::allocated::read().unwrap_or(0), + resident_bytes: stats::resident::read().unwrap_or(0), + error: None, + } +} + +#[cfg(not(feature = "jemalloc"))] +fn get_jemalloc_stats() -> MemStats { + MemStats { + allocated_bytes: 0, + resident_bytes: 0, + error: Some("jemalloc feature not enabled".to_string()), + } +} diff --git a/crates/brightstaff/src/handlers/function_calling.rs b/crates/brightstaff/src/handlers/function_calling.rs index ca4def32..3e2543bc 100644 --- a/crates/brightstaff/src/handlers/function_calling.rs +++ b/crates/brightstaff/src/handlers/function_calling.rs @@ -441,10 +441,8 @@ impl ArchFunctionHandler { } } // Handle str/string conversions - "str" | "string" => { - if !value.is_string() { - return Ok(json!(value.to_string())); - } + "str" | "string" if !value.is_string() => { + return Ok(json!(value.to_string())); } _ => {} } diff --git a/crates/brightstaff/src/handlers/llm/mod.rs b/crates/brightstaff/src/handlers/llm/mod.rs index 719c048d..3336209f 100644 --- 
a/crates/brightstaff/src/handlers/llm/mod.rs +++ b/crates/brightstaff/src/handlers/llm/mod.rs @@ -24,13 +24,14 @@ use crate::app_state::AppState; use crate::handlers::agents::pipeline::PipelineProcessor; use crate::handlers::extract_request_id; use crate::handlers::full; +use crate::metrics as bs_metrics; use crate::state::response_state_processor::ResponsesStateProcessor; use crate::state::{ extract_input_items, retrieve_and_combine_input, StateStorage, StateStorageError, }; use crate::streaming::{ create_streaming_response, create_streaming_response_with_output_filter, truncate_message, - ObservableStreamProcessor, StreamProcessor, + LlmMetricsCtx, ObservableStreamProcessor, StreamProcessor, }; use crate::tracing::{ collect_custom_trace_attributes, llm as tracing_llm, operation_component, @@ -142,6 +143,7 @@ async fn llm_chat_inner( &request_path, &state.model_aliases, &state.llm_providers, + state.signals_enabled, ) .await { @@ -253,7 +255,15 @@ async fn llm_chat_inner( if let Some(ref client_api_kind) = client_api { let upstream_api = provider_id.compatible_api_for_client(client_api_kind, is_streaming_request); - client_request.normalize_for_upstream(provider_id, &upstream_api); + if let Err(e) = client_request.normalize_for_upstream(provider_id, &upstream_api) { + warn!( + "request_id={}: normalize_for_upstream failed: {}", + request_id, e + ); + let mut bad_request = Response::new(full(e.message)); + *bad_request.status_mut() = StatusCode::BAD_REQUEST; + return Ok(bad_request); + } } // --- Phase 2: Resolve conversation state (v1/responses API) --- @@ -407,6 +417,7 @@ async fn parse_and_validate_request( request_path: &str, model_aliases: &Option>, llm_providers: &Arc>, + signals_enabled: bool, ) -> Result>> { let raw_bytes = request .collect() @@ -485,7 +496,11 @@ async fn parse_and_validate_request( let user_message_preview = client_request .get_recent_user_message() .map(|msg| truncate_message(&msg, 50)); - let messages_for_signals = 
Some(client_request.get_messages()); + let messages_for_signals = if signals_enabled { + Some(client_request.get_messages()) + } else { + None + }; // Set the upstream model name and strip routing metadata client_request.set_model(model_name_only.clone()); @@ -686,6 +701,13 @@ async fn send_upstream( let request_start_time = std::time::Instant::now(); + // Labels for LLM upstream metrics. We prefer `resolved_model` (post-routing) + // and derive the provider from its `provider/model` prefix. This matches the + // same model id the cost/latency router keys off. + let (metric_provider_raw, metric_model_raw) = bs_metrics::split_provider_model(resolved_model); + let metric_provider = metric_provider_raw.to_string(); + let metric_model = metric_model_raw.to_string(); + let llm_response = match http_client .post(upstream_url) .headers(request_headers.clone()) @@ -695,6 +717,14 @@ async fn send_upstream( { Ok(res) => res, Err(err) => { + let err_class = bs_metrics::llm_error_class_from_reqwest(&err); + bs_metrics::record_llm_upstream( + &metric_provider, + &metric_model, + 0, + err_class, + request_start_time.elapsed(), + ); let err_msg = format!("Failed to send request: {}", err); let mut internal_error = Response::new(full(err_msg)); *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; @@ -750,7 +780,12 @@ async fn send_upstream( span_name, request_start_time, messages_for_signals, - ); + ) + .with_llm_metrics(LlmMetricsCtx { + provider: metric_provider.clone(), + model: metric_model.clone(), + upstream_status: upstream_status.as_u16(), + }); let output_filter_request_headers = if filter_pipeline.has_output_filters() { Some(request_headers.clone()) diff --git a/crates/brightstaff/src/handlers/llm/model_selection.rs b/crates/brightstaff/src/handlers/llm/model_selection.rs index 1b4315e7..a1378d86 100644 --- a/crates/brightstaff/src/handlers/llm/model_selection.rs +++ b/crates/brightstaff/src/handlers/llm/model_selection.rs @@ -5,10 +5,24 @@ use 
hyper::StatusCode; use std::sync::Arc; use tracing::{debug, info, warn}; +use crate::metrics as bs_metrics; +use crate::metrics::labels as metric_labels; use crate::router::orchestrator::OrchestratorService; use crate::streaming::truncate_message; use crate::tracing::routing; +/// Classify a request path (already stripped of `/agents` or `/routing` by +/// the caller) into the fixed `route` label used on routing metrics. +fn route_label_for_path(request_path: &str) -> &'static str { + if request_path.starts_with("/agents") { + metric_labels::ROUTE_AGENT + } else if request_path.starts_with("/routing") { + metric_labels::ROUTE_ROUTING + } else { + metric_labels::ROUTE_LLM + } +} + pub struct RoutingResult { /// Primary model to use (first in the ranked list). pub model_name: String, @@ -106,15 +120,23 @@ pub async fn router_chat_get_upstream_model( ) .await; - let determination_ms = routing_start_time.elapsed().as_millis() as i64; + let determination_elapsed = routing_start_time.elapsed(); + let determination_ms = determination_elapsed.as_millis() as i64; let current_span = tracing::Span::current(); current_span.record(routing::ROUTE_DETERMINATION_MS, determination_ms); + let route_label = route_label_for_path(request_path); match routing_result { Ok(route) => match route { Some((route_name, ranked_models)) => { let model_name = ranked_models.first().cloned().unwrap_or_default(); current_span.record("route.selected_model", model_name.as_str()); + bs_metrics::record_router_decision( + route_label, + &model_name, + false, + determination_elapsed, + ); Ok(RoutingResult { model_name, models: ranked_models, @@ -126,6 +148,12 @@ pub async fn router_chat_get_upstream_model( // This signals to llm.rs to use the original validated request model current_span.record("route.selected_model", "none"); info!("no route determined, using default model"); + bs_metrics::record_router_decision( + route_label, + "none", + true, + determination_elapsed, + ); Ok(RoutingResult { 
model_name: "none".to_string(), @@ -136,6 +164,7 @@ pub async fn router_chat_get_upstream_model( }, Err(err) => { current_span.record("route.selected_model", "unknown"); + bs_metrics::record_router_decision(route_label, "unknown", true, determination_elapsed); Err(RoutingError::internal_error(format!( "Failed to determine route: {}", err diff --git a/crates/brightstaff/src/handlers/mod.rs b/crates/brightstaff/src/handlers/mod.rs index 485a0438..4e851264 100644 --- a/crates/brightstaff/src/handlers/mod.rs +++ b/crates/brightstaff/src/handlers/mod.rs @@ -1,4 +1,5 @@ pub mod agents; +pub mod debug; pub mod function_calling; pub mod llm; pub mod models; diff --git a/crates/brightstaff/src/handlers/routing_service.rs b/crates/brightstaff/src/handlers/routing_service.rs index 5fc0d3b9..b93b1422 100644 --- a/crates/brightstaff/src/handlers/routing_service.rs +++ b/crates/brightstaff/src/handlers/routing_service.rs @@ -12,6 +12,8 @@ use tracing::{debug, info, info_span, warn, Instrument}; use super::extract_or_generate_traceparent; use crate::handlers::llm::model_selection::router_chat_get_upstream_model; +use crate::metrics as bs_metrics; +use crate::metrics::labels as metric_labels; use crate::router::orchestrator::OrchestratorService; use crate::tracing::{collect_custom_trace_attributes, operation_component, set_service_name}; @@ -230,6 +232,17 @@ async fn routing_decision_inner( pinned: false, }; + // Distinguish "decision served" (a concrete model picked) from + // "no_candidates" (the sentinel "none" returned when nothing + // matched). The handler still responds 200 in both cases, so RED + // metrics alone can't tell them apart. 
+ let outcome = if response.models.first().map(|m| m == "none").unwrap_or(true) { + metric_labels::ROUTING_SVC_NO_CANDIDATES + } else { + metric_labels::ROUTING_SVC_DECISION_SERVED + }; + bs_metrics::record_routing_service_outcome(outcome); + info!( primary_model = %response.models.first().map(|s| s.as_str()).unwrap_or("none"), total_models = response.models.len(), @@ -249,6 +262,7 @@ async fn routing_decision_inner( .unwrap()) } Err(err) => { + bs_metrics::record_routing_service_outcome(metric_labels::ROUTING_SVC_POLICY_ERROR); warn!(error = %err.message, "routing decision failed"); Ok(BrightStaffError::InternalServerError(err.message).into_response()) } diff --git a/crates/brightstaff/src/lib.rs b/crates/brightstaff/src/lib.rs index a0ba5f43..66c6eadf 100644 --- a/crates/brightstaff/src/lib.rs +++ b/crates/brightstaff/src/lib.rs @@ -1,5 +1,6 @@ pub mod app_state; pub mod handlers; +pub mod metrics; pub mod router; pub mod session_cache; pub mod signals; diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs index 40ac429d..b1e17e42 100644 --- a/crates/brightstaff/src/main.rs +++ b/crates/brightstaff/src/main.rs @@ -1,10 +1,17 @@ +#[cfg(feature = "jemalloc")] +#[global_allocator] +static ALLOC: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; + use brightstaff::app_state::AppState; use brightstaff::handlers::agents::orchestrator::agent_chat; +use brightstaff::handlers::debug; use brightstaff::handlers::empty; use brightstaff::handlers::function_calling::function_calling_chat_handler; use brightstaff::handlers::llm::llm_chat; use brightstaff::handlers::models::list_models; use brightstaff::handlers::routing_service::routing_decision; +use brightstaff::metrics as bs_metrics; +use brightstaff::metrics::labels as metric_labels; use brightstaff::router::model_metrics::ModelMetricsService; use brightstaff::router::orchestrator::OrchestratorService; use brightstaff::session_cache::init_session_cache; @@ -326,6 +333,8 @@ async fn 
init_app_state( .as_ref() .and_then(|tracing| tracing.span_attributes.clone()); + let signals_enabled = !overrides.disable_signals.unwrap_or(false); + Ok(AppState { orchestrator_service, model_aliases: config.model_aliases.clone(), @@ -337,6 +346,7 @@ async fn init_app_state( span_attributes, http_client: reqwest::Client::new(), filter_pipeline, + signals_enabled, }) } @@ -384,10 +394,79 @@ async fn init_state_storage( // Request routing // --------------------------------------------------------------------------- +/// Normalized method label — limited set so we never emit a free-form string. +fn method_label(method: &Method) -> &'static str { + match *method { + Method::GET => "GET", + Method::POST => "POST", + Method::PUT => "PUT", + Method::DELETE => "DELETE", + Method::PATCH => "PATCH", + Method::HEAD => "HEAD", + Method::OPTIONS => "OPTIONS", + _ => "OTHER", + } +} + +/// Compute the fixed `handler` metric label from the request's path+method. +/// Returning `None` for fall-through means `route()` will hand the request to +/// the catch-all 404 branch. 
+fn handler_label_for(method: &Method, path: &str) -> &'static str { + if let Some(stripped) = path.strip_prefix("/agents") { + if matches!( + stripped, + CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH + ) { + return metric_labels::HANDLER_AGENT_CHAT; + } + } + if let Some(stripped) = path.strip_prefix("/routing") { + if matches!( + stripped, + CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH + ) { + return metric_labels::HANDLER_ROUTING_DECISION; + } + } + match (method, path) { + (&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH) => { + metric_labels::HANDLER_LLM_CHAT + } + (&Method::POST, "/function_calling") => metric_labels::HANDLER_FUNCTION_CALLING, + (&Method::GET, "/v1/models" | "/agents/v1/models") => metric_labels::HANDLER_LIST_MODELS, + (&Method::OPTIONS, "/v1/models" | "/agents/v1/models") => { + metric_labels::HANDLER_CORS_PREFLIGHT + } + _ => metric_labels::HANDLER_NOT_FOUND, + } +} + /// Route an incoming HTTP request to the appropriate handler. async fn route( req: Request, state: Arc, +) -> Result>, hyper::Error> { + let handler = handler_label_for(req.method(), req.uri().path()); + let method = method_label(req.method()); + let started = std::time::Instant::now(); + let _in_flight = bs_metrics::InFlightGuard::new(handler); + + let result = dispatch(req, state).await; + + let status = match &result { + Ok(resp) => resp.status().as_u16(), + // hyper::Error here means the body couldn't be produced; conventionally 500. + Err(_) => 500, + }; + bs_metrics::record_http(handler, method, status, started); + result +} + +/// Inner dispatcher split out so `route()` can wrap it with metrics without +/// duplicating the match tree. 
+async fn dispatch( + req: Request, + state: Arc, ) -> Result>, hyper::Error> { let parent_cx = global::get_text_map_propagator(|p| p.extract(&HeaderExtractor(req.headers()))); let path = req.uri().path().to_string(); @@ -439,6 +518,7 @@ async fn route( Ok(list_models(Arc::clone(&state.llm_providers)).await) } (&Method::OPTIONS, "/v1/models" | "/agents/v1/models") => cors_preflight(), + (&Method::GET, "/debug/memstats") => debug::memstats().await, _ => { debug!(method = %req.method(), path = %path, "no route found"); let mut not_found = Response::new(empty()); @@ -503,6 +583,7 @@ async fn run_server(state: Arc) -> Result<(), Box Result<(), Box> { let config = load_config()?; let _tracer_provider = init_tracer(config.tracing.as_ref()); + bs_metrics::init(); info!("loaded plano_config.yaml"); let state = Arc::new(init_app_state(&config).await?); run_server(state).await diff --git a/crates/brightstaff/src/metrics/labels.rs b/crates/brightstaff/src/metrics/labels.rs new file mode 100644 index 00000000..4eaf3e59 --- /dev/null +++ b/crates/brightstaff/src/metrics/labels.rs @@ -0,0 +1,38 @@ +//! Fixed label-value constants so callers never emit free-form strings +//! (which would blow up cardinality). + +// Handler enum — derived from the path+method match in `route()`. +pub const HANDLER_AGENT_CHAT: &str = "agent_chat"; +pub const HANDLER_ROUTING_DECISION: &str = "routing_decision"; +pub const HANDLER_LLM_CHAT: &str = "llm_chat"; +pub const HANDLER_FUNCTION_CALLING: &str = "function_calling"; +pub const HANDLER_LIST_MODELS: &str = "list_models"; +pub const HANDLER_CORS_PREFLIGHT: &str = "cors_preflight"; +pub const HANDLER_NOT_FOUND: &str = "not_found"; + +// Router "route" class — which brightstaff endpoint prompted the decision. +pub const ROUTE_AGENT: &str = "agent"; +pub const ROUTE_ROUTING: &str = "routing"; +pub const ROUTE_LLM: &str = "llm"; + +// Token kind for brightstaff_llm_tokens_total. 
+pub const TOKEN_KIND_PROMPT: &str = "prompt"; +pub const TOKEN_KIND_COMPLETION: &str = "completion"; + +// LLM error_class values (match docstring in metrics/mod.rs). +pub const LLM_ERR_NONE: &str = "none"; +pub const LLM_ERR_TIMEOUT: &str = "timeout"; +pub const LLM_ERR_CONNECT: &str = "connect"; +pub const LLM_ERR_PARSE: &str = "parse"; +pub const LLM_ERR_OTHER: &str = "other"; +pub const LLM_ERR_STREAM: &str = "stream"; + +// Routing service outcome values. +pub const ROUTING_SVC_DECISION_SERVED: &str = "decision_served"; +pub const ROUTING_SVC_NO_CANDIDATES: &str = "no_candidates"; +pub const ROUTING_SVC_POLICY_ERROR: &str = "policy_error"; + +// Session cache outcome values. +pub const SESSION_CACHE_HIT: &str = "hit"; +pub const SESSION_CACHE_MISS: &str = "miss"; +pub const SESSION_CACHE_STORE: &str = "store"; diff --git a/crates/brightstaff/src/metrics/mod.rs b/crates/brightstaff/src/metrics/mod.rs new file mode 100644 index 00000000..34679cca --- /dev/null +++ b/crates/brightstaff/src/metrics/mod.rs @@ -0,0 +1,377 @@ +//! Prometheus metrics for brightstaff. +//! +//! Installs the `metrics` global recorder backed by +//! `metrics-exporter-prometheus` and exposes a `/metrics` HTTP endpoint on a +//! dedicated admin port (default `0.0.0.0:9092`, overridable via +//! `METRICS_BIND_ADDRESS`). +//! +//! Emitted metric families (see `describe_all` for full list): +//! - HTTP RED: `brightstaff_http_requests_total`, +//! `brightstaff_http_request_duration_seconds`, +//! `brightstaff_http_in_flight_requests`. +//! - LLM upstream: `brightstaff_llm_upstream_requests_total`, +//! `brightstaff_llm_upstream_duration_seconds`, +//! `brightstaff_llm_time_to_first_token_seconds`, +//! `brightstaff_llm_tokens_total`, +//! `brightstaff_llm_tokens_usage_missing_total`. +//! - Routing: `brightstaff_router_decisions_total`, +//! `brightstaff_router_decision_duration_seconds`, +//! `brightstaff_routing_service_requests_total`, +//! `brightstaff_session_cache_events_total`. +//! 
- Process: via `metrics-process`. +//! - Build: `brightstaff_build_info`. + +use std::net::SocketAddr; +use std::sync::OnceLock; +use std::time::{Duration, Instant}; + +use metrics::{counter, describe_counter, describe_gauge, describe_histogram, gauge, histogram}; +use metrics_exporter_prometheus::{Matcher, PrometheusBuilder}; +use tracing::{info, warn}; + +pub mod labels; + +/// Guard flag so tests don't re-install the global recorder. +static INIT: OnceLock<()> = OnceLock::new(); + +const DEFAULT_METRICS_BIND: &str = "0.0.0.0:9092"; + +/// HTTP request duration buckets (seconds). Capped at 60s. +const HTTP_BUCKETS: &[f64] = &[ + 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 30.0, 60.0, +]; + +/// LLM upstream / TTFT buckets (seconds). Capped at 120s because provider +/// completions routinely run that long. +const LLM_BUCKETS: &[f64] = &[0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 30.0, 60.0, 120.0]; + +/// Router decision buckets (seconds). The orchestrator call itself is usually +/// sub-second but bucketed generously in case of upstream slowness. +const ROUTER_BUCKETS: &[f64] = &[ + 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 30.0, +]; + +/// Install the global recorder and spawn the `/metrics` HTTP listener. +/// +/// Safe to call more than once; subsequent calls are no-ops so tests that +/// construct their own recorder still work. 
pub fn init() {
    // NOTE(review): `INIT.get()` + later `INIT.set(())` is not an atomic
    // check-and-set; two racing callers could both attempt `install()` (the
    // loser logs a warning). Harmless at startup, but confirm if init() can
    // ever be called concurrently.
    if INIT.get().is_some() {
        return;
    }

    // Resolve the admin listener address; a malformed override falls back to
    // the compile-time default rather than aborting startup.
    let bind: SocketAddr = std::env::var("METRICS_BIND_ADDRESS")
        .unwrap_or_else(|_| DEFAULT_METRICS_BIND.to_string())
        .parse()
        .unwrap_or_else(|err| {
            warn!(error = %err, default = DEFAULT_METRICS_BIND, "invalid METRICS_BIND_ADDRESS, falling back to default");
            DEFAULT_METRICS_BIND.parse().expect("default bind parses")
        });

    // Attach custom histogram buckets: an exact match for the HTTP and router
    // duration metrics, and a prefix match covering all `brightstaff_llm_*`
    // histograms.
    let builder = PrometheusBuilder::new()
        .with_http_listener(bind)
        .set_buckets_for_metric(
            Matcher::Full("brightstaff_http_request_duration_seconds".to_string()),
            HTTP_BUCKETS,
        )
        .and_then(|b| {
            b.set_buckets_for_metric(Matcher::Prefix("brightstaff_llm_".to_string()), LLM_BUCKETS)
        })
        .and_then(|b| {
            b.set_buckets_for_metric(
                Matcher::Full("brightstaff_router_decision_duration_seconds".to_string()),
                ROUTER_BUCKETS,
            )
        });

    // Bucket configuration failure is non-fatal: fall back to the exporter's
    // default buckets rather than losing metrics entirely.
    let builder = match builder {
        Ok(b) => b,
        Err(err) => {
            warn!(error = %err, "failed to configure metrics buckets, using defaults");
            PrometheusBuilder::new().with_http_listener(bind)
        }
    };

    // Install failure (e.g. a recorder already installed, or the port taken)
    // degrades gracefully: the process runs without metrics.
    if let Err(err) = builder.install() {
        warn!(error = %err, "failed to install Prometheus recorder; metrics disabled");
        return;
    }

    let _ = INIT.set(());

    describe_all();
    emit_build_info();

    // Register process-level collector (RSS, CPU, FDs).
    let collector = metrics_process::Collector::default();
    collector.describe();
    // Prime once at startup; subsequent scrapes refresh via the exporter's
    // per-scrape render, so we additionally refresh on a short interval to
    // keep gauges moving between scrapes without requiring client pull.
    collector.collect();
    tokio::spawn(async move {
        let mut tick = tokio::time::interval(Duration::from_secs(10));
        // Skip (don't burst) missed ticks if the runtime stalls.
        tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
        loop {
            tick.tick().await;
            collector.collect();
        }
    });

    info!(address = %bind, "metrics listener started");
}

/// Register help text for every metric family this module emits, so the
/// Prometheus exposition includes `# HELP` lines from the first scrape.
fn describe_all() {
    describe_counter!(
        "brightstaff_http_requests_total",
        "Total HTTP requests served by brightstaff, by handler and status class."
    );
    describe_histogram!(
        "brightstaff_http_request_duration_seconds",
        "Wall-clock duration of HTTP requests served by brightstaff, by handler."
    );
    describe_gauge!(
        "brightstaff_http_in_flight_requests",
        "Number of HTTP requests currently being served by brightstaff, by handler."
    );

    describe_counter!(
        "brightstaff_llm_upstream_requests_total",
        "LLM upstream request outcomes, by provider, model, status class and error class."
    );
    describe_histogram!(
        "brightstaff_llm_upstream_duration_seconds",
        "Wall-clock duration of LLM upstream calls (stream close for streaming), by provider and model."
    );
    describe_histogram!(
        "brightstaff_llm_time_to_first_token_seconds",
        "Time from request start to first streamed byte, by provider and model (streaming only)."
    );
    describe_counter!(
        "brightstaff_llm_tokens_total",
        "Tokens reported in the provider `usage` field, by provider, model and kind (prompt/completion)."
    );
    describe_counter!(
        "brightstaff_llm_tokens_usage_missing_total",
        "LLM responses that completed without a usable `usage` block (so token counts are unknown)."
    );

    describe_counter!(
        "brightstaff_router_decisions_total",
        "Routing decisions made by the orchestrator, by route, selected model, and whether a fallback was used."
    );
    describe_histogram!(
        "brightstaff_router_decision_duration_seconds",
        "Time spent in the orchestrator deciding a route, by route."
    );
    describe_counter!(
        "brightstaff_routing_service_requests_total",
        "Outcomes of /routing/* decision requests: decision_served, no_candidates, policy_error."
    );
    describe_counter!(
        "brightstaff_session_cache_events_total",
        "Session affinity cache lookups and stores, by outcome."
    );

    describe_gauge!(
        "brightstaff_build_info",
        "Build metadata. Always 1; labels carry version and git SHA."
    );
}

/// Emit the constant `brightstaff_build_info` gauge: value is always 1 and
/// the interesting data rides in the `version` / `git_sha` labels.
fn emit_build_info() {
    let version = env!("CARGO_PKG_VERSION");
    // GIT_SHA is injected at build time; absent in local builds.
    let git_sha = option_env!("GIT_SHA").unwrap_or("unknown");
    gauge!(
        "brightstaff_build_info",
        "version" => version.to_string(),
        "git_sha" => git_sha.to_string(),
    )
    .set(1.0);
}

/// Split a provider-qualified model id like `"openai/gpt-4o"` into
/// `(provider, model)`. Returns `("unknown", raw)` when there is no `/`.
/// Only the first `/` splits, so `"openai/ft/abc"` yields `("openai", "ft/abc")`.
pub fn split_provider_model(full: &str) -> (&str, &str) {
    match full.split_once('/') {
        Some((p, m)) => (p, m),
        None => ("unknown", full),
    }
}

/// Bucket an HTTP status code into `"2xx"` / `"4xx"` / `"5xx"` / `"1xx"` / `"3xx"`.
/// Codes outside 100-599 (including the sentinel 0) map to `"other"`.
pub fn status_class(status: u16) -> &'static str {
    match status {
        100..=199 => "1xx",
        200..=299 => "2xx",
        300..=399 => "3xx",
        400..=499 => "4xx",
        500..=599 => "5xx",
        _ => "other",
    }
}

// ---------------------------------------------------------------------------
// HTTP RED helpers
// ---------------------------------------------------------------------------

/// RAII guard that increments the in-flight gauge on construction and
/// decrements on drop. Pair with [`record_http`] in the `route()` wrapper so
/// the gauge drops even on error paths.
+pub struct InFlightGuard { + handler: &'static str, +} + +impl InFlightGuard { + pub fn new(handler: &'static str) -> Self { + gauge!( + "brightstaff_http_in_flight_requests", + "handler" => handler, + ) + .increment(1.0); + Self { handler } + } +} + +impl Drop for InFlightGuard { + fn drop(&mut self) { + gauge!( + "brightstaff_http_in_flight_requests", + "handler" => self.handler, + ) + .decrement(1.0); + } +} + +/// Record the HTTP request counter + duration histogram. +pub fn record_http(handler: &'static str, method: &'static str, status: u16, started: Instant) { + let class = status_class(status); + counter!( + "brightstaff_http_requests_total", + "handler" => handler, + "method" => method, + "status_class" => class, + ) + .increment(1); + histogram!( + "brightstaff_http_request_duration_seconds", + "handler" => handler, + ) + .record(started.elapsed().as_secs_f64()); +} + +// --------------------------------------------------------------------------- +// LLM upstream helpers +// --------------------------------------------------------------------------- + +/// Classify an outcome of an LLM upstream call for the `error_class` label. +pub fn llm_error_class_from_reqwest(err: &reqwest::Error) -> &'static str { + if err.is_timeout() { + "timeout" + } else if err.is_connect() { + "connect" + } else if err.is_decode() { + "parse" + } else { + "other" + } +} + +/// Record the outcome of an LLM upstream call. `status` is the HTTP status +/// the upstream returned (0 if the call never produced one, e.g. send failure). +/// `error_class` is `"none"` on success, or a discriminated error label. 
+pub fn record_llm_upstream( + provider: &str, + model: &str, + status: u16, + error_class: &str, + duration: Duration, +) { + let class = if status == 0 { + "error" + } else { + status_class(status) + }; + counter!( + "brightstaff_llm_upstream_requests_total", + "provider" => provider.to_string(), + "model" => model.to_string(), + "status_class" => class, + "error_class" => error_class.to_string(), + ) + .increment(1); + histogram!( + "brightstaff_llm_upstream_duration_seconds", + "provider" => provider.to_string(), + "model" => model.to_string(), + ) + .record(duration.as_secs_f64()); +} + +pub fn record_llm_ttft(provider: &str, model: &str, ttft: Duration) { + histogram!( + "brightstaff_llm_time_to_first_token_seconds", + "provider" => provider.to_string(), + "model" => model.to_string(), + ) + .record(ttft.as_secs_f64()); +} + +pub fn record_llm_tokens(provider: &str, model: &str, kind: &'static str, count: u64) { + counter!( + "brightstaff_llm_tokens_total", + "provider" => provider.to_string(), + "model" => model.to_string(), + "kind" => kind, + ) + .increment(count); +} + +pub fn record_llm_tokens_usage_missing(provider: &str, model: &str) { + counter!( + "brightstaff_llm_tokens_usage_missing_total", + "provider" => provider.to_string(), + "model" => model.to_string(), + ) + .increment(1); +} + +// --------------------------------------------------------------------------- +// Router helpers +// --------------------------------------------------------------------------- + +pub fn record_router_decision( + route: &'static str, + selected_model: &str, + fallback: bool, + duration: Duration, +) { + counter!( + "brightstaff_router_decisions_total", + "route" => route, + "selected_model" => selected_model.to_string(), + "fallback" => if fallback { "true" } else { "false" }, + ) + .increment(1); + histogram!( + "brightstaff_router_decision_duration_seconds", + "route" => route, + ) + .record(duration.as_secs_f64()); +} + +pub fn 
record_routing_service_outcome(outcome: &'static str) { + counter!( + "brightstaff_routing_service_requests_total", + "outcome" => outcome, + ) + .increment(1); +} + +pub fn record_session_cache_event(outcome: &'static str) { + counter!( + "brightstaff_session_cache_events_total", + "outcome" => outcome, + ) + .increment(1); +} diff --git a/crates/brightstaff/src/router/mod.rs b/crates/brightstaff/src/router/mod.rs index 2ef0d11a..0f48c090 100644 --- a/crates/brightstaff/src/router/mod.rs +++ b/crates/brightstaff/src/router/mod.rs @@ -3,3 +3,5 @@ pub mod model_metrics; pub mod orchestrator; pub mod orchestrator_model; pub mod orchestrator_model_v1; +#[cfg(test)] +mod stress_tests; diff --git a/crates/brightstaff/src/router/orchestrator.rs b/crates/brightstaff/src/router/orchestrator.rs index 7aaf70a2..2d7b25de 100644 --- a/crates/brightstaff/src/router/orchestrator.rs +++ b/crates/brightstaff/src/router/orchestrator.rs @@ -15,6 +15,8 @@ use super::http::{self, post_and_extract_content}; use super::model_metrics::ModelMetricsService; use super::orchestrator_model::OrchestratorModel; +use crate::metrics as bs_metrics; +use crate::metrics::labels as metric_labels; use crate::router::orchestrator_model_v1; use crate::session_cache::SessionCache; @@ -130,7 +132,13 @@ impl OrchestratorService { tenant_id: Option<&str>, ) -> Option { let cache = self.session_cache.as_ref()?; - cache.get(&Self::session_key(tenant_id, session_id)).await + let result = cache.get(&Self::session_key(tenant_id, session_id)).await; + bs_metrics::record_session_cache_event(if result.is_some() { + metric_labels::SESSION_CACHE_HIT + } else { + metric_labels::SESSION_CACHE_MISS + }); + result } pub async fn cache_route( @@ -151,6 +159,7 @@ impl OrchestratorService { self.session_ttl, ) .await; + bs_metrics::record_session_cache_event(metric_labels::SESSION_CACHE_STORE); } } diff --git a/crates/brightstaff/src/router/stress_tests.rs b/crates/brightstaff/src/router/stress_tests.rs new file mode 100644 
index 00000000..63c4112f --- /dev/null +++ b/crates/brightstaff/src/router/stress_tests.rs @@ -0,0 +1,264 @@ +#[cfg(test)] +mod tests { + use crate::router::orchestrator::OrchestratorService; + use crate::session_cache::memory::MemorySessionCache; + use common::configuration::{SelectionPolicy, SelectionPreference, TopLevelRoutingPreference}; + use hermesllm::apis::openai::{Message, MessageContent, Role}; + use std::sync::Arc; + + fn make_messages(n: usize) -> Vec { + (0..n) + .map(|i| Message { + role: if i % 2 == 0 { + Role::User + } else { + Role::Assistant + }, + content: Some(MessageContent::Text(format!( + "This is message number {i} with some padding text to make it realistic." + ))), + name: None, + tool_calls: None, + tool_call_id: None, + }) + .collect() + } + + fn make_routing_prefs() -> Vec { + vec![ + TopLevelRoutingPreference { + name: "code_generation".to_string(), + description: "Code generation and debugging tasks".to_string(), + models: vec![ + "openai/gpt-4o".to_string(), + "openai/gpt-4o-mini".to_string(), + ], + selection_policy: SelectionPolicy { + prefer: SelectionPreference::None, + }, + }, + TopLevelRoutingPreference { + name: "summarization".to_string(), + description: "Summarizing documents and text".to_string(), + models: vec![ + "anthropic/claude-3-sonnet".to_string(), + "openai/gpt-4o-mini".to_string(), + ], + selection_policy: SelectionPolicy { + prefer: SelectionPreference::None, + }, + }, + ] + } + + /// Stress test: exercise the full routing code path N times using a mock + /// HTTP server and measure jemalloc allocated bytes before/after. 
+ /// + /// This catches: + /// - Memory leaks in generate_request / parse_response + /// - Leaks in reqwest connection handling + /// - String accumulation in the orchestrator model + /// - Fragmentation (jemalloc allocated vs resident) + #[tokio::test] + async fn stress_test_routing_determine_route() { + let mut server = mockito::Server::new_async().await; + let router_url = format!("{}/v1/chat/completions", server.url()); + + let mock_response = serde_json::json!({ + "id": "chatcmpl-mock", + "object": "chat.completion", + "created": 1234567890, + "model": "plano-orchestrator", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "{\"route\": \"code_generation\"}" + }, + "finish_reason": "stop" + }], + "usage": {"prompt_tokens": 100, "completion_tokens": 10, "total_tokens": 110} + }); + + let _mock = server + .mock("POST", "/v1/chat/completions") + .with_status(200) + .with_header("content-type", "application/json") + .with_body(mock_response.to_string()) + .expect_at_least(1) + .create_async() + .await; + + let prefs = make_routing_prefs(); + let session_cache = Arc::new(MemorySessionCache::new(1000)); + let orchestrator_service = Arc::new(OrchestratorService::with_routing( + router_url, + "Plano-Orchestrator".to_string(), + "plano-orchestrator".to_string(), + Some(prefs.clone()), + None, + None, + session_cache, + None, + 2048, + )); + + // Warm up: a few requests to stabilize allocator state + for _ in 0..10 { + let msgs = make_messages(5); + let _ = orchestrator_service + .determine_route(&msgs, None, "warmup") + .await; + } + + // Snapshot memory after warmup + let baseline = get_allocated(); + + let num_iterations = 2000; + + for i in 0..num_iterations { + let msgs = make_messages(5 + (i % 10)); + let inline = if i % 3 == 0 { + Some(make_routing_prefs()) + } else { + None + }; + let _ = orchestrator_service + .determine_route(&msgs, inline, &format!("req-{i}")) + .await; + } + + let after = get_allocated(); + + let growth = 
after.saturating_sub(baseline); + let growth_mb = growth as f64 / (1024.0 * 1024.0); + let per_request = if num_iterations > 0 { + growth / num_iterations + } else { + 0 + }; + + eprintln!("=== Routing Stress Test Results ==="); + eprintln!(" Iterations: {num_iterations}"); + eprintln!(" Baseline alloc: {} bytes", baseline); + eprintln!(" Final alloc: {} bytes", after); + eprintln!(" Growth: {} bytes ({growth_mb:.2} MB)", growth); + eprintln!(" Per-request: {} bytes", per_request); + + // Allow up to 256 bytes per request of retained growth (connection pool, etc.) + // A true leak would show thousands of bytes per request. + assert!( + per_request < 256, + "Possible memory leak: {per_request} bytes/request retained after {num_iterations} iterations" + ); + } + + /// Stress test with high concurrency: many parallel determine_route calls. + #[tokio::test] + async fn stress_test_routing_concurrent() { + let mut server = mockito::Server::new_async().await; + let router_url = format!("{}/v1/chat/completions", server.url()); + + let mock_response = serde_json::json!({ + "id": "chatcmpl-mock", + "object": "chat.completion", + "created": 1234567890, + "model": "plano-orchestrator", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "{\"route\": \"summarization\"}" + }, + "finish_reason": "stop" + }], + "usage": {"prompt_tokens": 100, "completion_tokens": 10, "total_tokens": 110} + }); + + let _mock = server + .mock("POST", "/v1/chat/completions") + .with_status(200) + .with_header("content-type", "application/json") + .with_body(mock_response.to_string()) + .expect_at_least(1) + .create_async() + .await; + + let prefs = make_routing_prefs(); + let session_cache = Arc::new(MemorySessionCache::new(1000)); + let orchestrator_service = Arc::new(OrchestratorService::with_routing( + router_url, + "Plano-Orchestrator".to_string(), + "plano-orchestrator".to_string(), + Some(prefs), + None, + None, + session_cache, + None, + 2048, + )); + + // Warm up 
+ for _ in 0..20 { + let msgs = make_messages(3); + let _ = orchestrator_service + .determine_route(&msgs, None, "warmup") + .await; + } + + let baseline = get_allocated(); + + let concurrency = 50; + let requests_per_task = 100; + let total = concurrency * requests_per_task; + + let mut handles = vec![]; + for t in 0..concurrency { + let svc = Arc::clone(&orchestrator_service); + let handle = tokio::spawn(async move { + for r in 0..requests_per_task { + let msgs = make_messages(3 + (r % 8)); + let _ = svc + .determine_route(&msgs, None, &format!("req-{t}-{r}")) + .await; + } + }); + handles.push(handle); + } + + for h in handles { + h.await.unwrap(); + } + + let after = get_allocated(); + let growth = after.saturating_sub(baseline); + let per_request = growth / total; + + eprintln!("=== Concurrent Routing Stress Test Results ==="); + eprintln!(" Tasks: {concurrency} x {requests_per_task} = {total}"); + eprintln!(" Baseline: {} bytes", baseline); + eprintln!(" Final: {} bytes", after); + eprintln!( + " Growth: {} bytes ({:.2} MB)", + growth, + growth as f64 / 1_048_576.0 + ); + eprintln!(" Per-request: {} bytes", per_request); + + assert!( + per_request < 512, + "Possible memory leak under concurrency: {per_request} bytes/request retained after {total} requests" + ); + } + + #[cfg(feature = "jemalloc")] + fn get_allocated() -> usize { + tikv_jemalloc_ctl::epoch::advance().unwrap(); + tikv_jemalloc_ctl::stats::allocated::read().unwrap_or(0) + } + + #[cfg(not(feature = "jemalloc"))] + fn get_allocated() -> usize { + 0 + } +} diff --git a/crates/brightstaff/src/signals/analyzer.rs b/crates/brightstaff/src/signals/analyzer.rs index 8dffdd96..433bfe04 100644 --- a/crates/brightstaff/src/signals/analyzer.rs +++ b/crates/brightstaff/src/signals/analyzer.rs @@ -1,3255 +1,571 @@ -//! Agentic Signals - Behavioral quality indicators for agent interactions +//! Top-level signal analyzer. //! -//! 
This module implements various signals that serve as early warning indicators -//! of brilliant successes or failures in agentic interactions. These signals are -//! derived from conversation patterns and can be computed algorithmically from -//! message arrays. - -use serde::{Deserialize, Serialize}; -use std::collections::{HashMap, HashSet}; -use std::sync::LazyLock; +//! Direct port of `signals/analyzer.py`. Orchestrates all detectors across +//! the three layers (interaction / execution / environment) and produces a +//! `SignalReport`. use hermesllm::apis::openai::{Message, Role}; +use hermesllm::transforms::ExtractText; -// ============================================================================ -// Constants -// ============================================================================ +use super::environment::exhaustion::analyze_exhaustion; +use super::execution::failure::analyze_failure; +use super::execution::loops::analyze_loops; +use super::interaction::disengagement::analyze_disengagement; +use super::interaction::misalignment::analyze_misalignment; +use super::interaction::satisfaction::analyze_satisfaction; +use super::interaction::stagnation::{analyze_stagnation, ShareGptMsg}; +use super::schemas::{ + EnvironmentSignals, ExecutionSignals, InteractionQuality, InteractionSignals, SignalReport, + SignalType, TurnMetrics, +}; +use super::text_processing::NormalizedMessage; -/// Flag emoji for marking spans/operations worth investigating -pub const FLAG_MARKER: &str = "\u{1F6A9}"; +/// Marker appended to the span operation name when concerning signals are +/// detected. Kept in sync with the previous implementation for backward +/// compatibility with downstream consumers. +pub const FLAG_MARKER: &str = "[!]"; -/// Size of character n-grams for similarity matching (3 = trigrams) -const NGRAM_SIZE: usize = 3; +/// ShareGPT-shaped row used as the canonical input to the analyzer's +/// detectors. 
`from` is one of `"human"`, `"gpt"`, `"function_call"`, +/// `"observation"`. `value` is the raw message body. +#[derive(Debug, Clone, Copy)] +pub struct ShareGptMessage<'a> { + pub from: &'a str, + pub value: &'a str, +} -// ============================================================================ -// Normalized Message Processing -// ============================================================================ - -/// Pre-processed message with normalized text and tokens for efficient matching +/// Configuration knobs for the analyzer. Defaults match +/// `signals/analyzer.py:SignalAnalyzer.__init__`. #[derive(Debug, Clone)] -struct NormalizedMessage { - /// Original raw text - raw: String, - /// Tokens (words) extracted from the message - tokens: Vec, - /// Token set for fast lookup - token_set: HashSet, - /// Bigram set for fast similarity computation - bigram_set: HashSet, - /// Character ngram set for robust similarity matching - char_ngram_set: HashSet, - /// Token frequency map for multiset cosine similarity - token_frequency: HashMap, +pub struct SignalAnalyzerConfig { + pub baseline_turns: usize, + pub char_ngram_threshold: f32, + pub token_cosine_threshold: f32, + pub max_message_length: usize, + pub max_messages: usize, } -impl NormalizedMessage { - #[allow(dead_code)] // Used in tests for algorithm validation - fn from_text(text: &str) -> Self { - Self::from_text_with_limit(text, usize::MAX) - } - - fn from_text_with_limit(text: &str, max_length: usize) -> Self { - // Truncate to max_length characters to prevent unbounded computation - // Keep head (20%) + tail (80%) to preserve both context and intent - - let char_count = text.chars().count(); - - let raw = if char_count <= max_length { - text.to_string() - } else { - // Split: 20% head, 79% tail, 1 char space delimiter - let head_len = max_length / 5; - let tail_len = max_length - head_len - 1; - - let head: String = text.chars().take(head_len).collect(); - let tail: String = 
text.chars().skip(char_count - tail_len).collect(); - - format!("{} {}", head, tail) - }; - - // Normalize unicode punctuation to ASCII equivalents - let normalized_unicode = raw - .replace(['\u{2019}', '\u{2018}'], "'") // U+2019/U+2018 SINGLE QUOTATION MARKs - .replace(['\u{201C}', '\u{201D}'], "\"") // U+201C/U+201D DOUBLE QUOTATION MARKs - .replace(['\u{2013}', '\u{2014}'], "-"); // U+2013/U+2014 EN/EM DASHes - - // Normalize: lowercase, collapse whitespace - let normalized = normalized_unicode - .to_lowercase() - .split_whitespace() - .collect::>() - .join(" "); - - // Tokenize: split on whitespace and strip punctuation from boundaries - let tokens: Vec = normalized - .split_whitespace() - .map(|word| { - // Strip leading/trailing punctuation but keep internal punctuation - word.trim_matches(|c: char| c.is_ascii_punctuation()) - .to_string() - }) - .filter(|w| !w.is_empty()) - .collect(); - - let token_set: HashSet = tokens.iter().cloned().collect(); - - // Generate bigram set directly for similarity matching - let bigram_set: HashSet = tokens - .windows(2) - .map(|w| format!("{} {}", w[0], w[1])) - .collect(); - - // Generate character ngram set for robust similarity matching - // Uses tokens (with punctuation stripped) for consistency with pattern matching - let tokens_text = tokens.join(" "); - let char_ngram_set: HashSet = tokens_text - .chars() - .collect::>() - .windows(NGRAM_SIZE) - .map(|w| w.iter().collect::()) - .collect(); - - // Compute token frequency map for cosine similarity - let mut token_frequency: HashMap = HashMap::new(); - for token in &tokens { - *token_frequency.entry(token.clone()).or_insert(0) += 1; - } - - Self { - raw, - tokens, - token_set, - bigram_set, - char_ngram_set, - token_frequency, - } - } - - /// Check if a single token exists in the message (word boundary aware) - fn contains_token(&self, token: &str) -> bool { - self.token_set.contains(token) - } - - /// Check if a phrase (sequence of tokens) exists in the message - fn 
contains_phrase(&self, phrase: &str) -> bool { - let phrase_tokens: Vec<&str> = phrase.split_whitespace().collect(); - if phrase_tokens.is_empty() { - return false; - } - - if phrase_tokens.len() == 1 { - return self.contains_token(phrase_tokens[0]); - } - - // Multi-word phrase: check for sequence in tokens - self.tokens.windows(phrase_tokens.len()).any(|window| { - window - .iter() - .zip(phrase_tokens.iter()) - .all(|(token, phrase_token)| token == phrase_token) - }) - } - - /// Calculate character ngram similarity between this message and a pattern - /// Returns a similarity score between 0.0 and 1.0 - /// This is robust to typos, small edits, and word insertions - #[allow(dead_code)] // Used in tests for algorithm validation - fn char_ngram_similarity(&self, pattern: &str) -> f64 { - // Normalize the pattern: lowercase and remove ALL punctuation - // This makes "doesn't" → "doesnt" for robust typo matching - let normalized_pattern = pattern - .to_lowercase() - .chars() - .filter(|c| c.is_alphanumeric() || c.is_whitespace()) - .collect::() - .split_whitespace() - .collect::>() - .join(" "); - - // Generate ngrams for the pattern - let pattern_ngrams: HashSet = normalized_pattern - .chars() - .collect::>() - .windows(NGRAM_SIZE) - .map(|w| w.iter().collect::()) - .collect(); - - if self.char_ngram_set.is_empty() && pattern_ngrams.is_empty() { - return 1.0; // Both empty = identical - } - - if self.char_ngram_set.is_empty() || pattern_ngrams.is_empty() { - return 0.0; - } - - // Compute Jaccard similarity (intersection / union) - let intersection = self.char_ngram_set.intersection(&pattern_ngrams).count(); - let union = self.char_ngram_set.union(&pattern_ngrams).count(); - - if union == 0 { - return 0.0; - } - - intersection as f64 / union as f64 - } - - /// Calculate token-based cosine similarity using term frequencies - /// Returns a similarity score between 0.0 and 1.0 - /// This handles word frequency and is stable for longer messages - #[allow(dead_code)] // 
Used in tests for algorithm validation - fn token_cosine_similarity(&self, pattern: &str) -> f64 { - // Tokenize and compute frequencies for the pattern - let pattern_tokens: Vec = pattern - .to_lowercase() - .split_whitespace() - .map(|word| { - word.trim_matches(|c: char| c.is_ascii_punctuation()) - .to_string() - }) - .filter(|w| !w.is_empty()) - .collect(); - - let mut pattern_frequency: HashMap = HashMap::new(); - for token in &pattern_tokens { - *pattern_frequency.entry(token.clone()).or_insert(0) += 1; - } - - if self.token_frequency.is_empty() && pattern_frequency.is_empty() { - return 1.0; - } - - if self.token_frequency.is_empty() || pattern_frequency.is_empty() { - return 0.0; - } - - // Compute cosine similarity - // cosine_sim = dot_product / (norm1 * norm2) - - let mut dot_product = 0.0; - let mut norm1_squared = 0.0; - let mut norm2_squared = 0.0; - - // Collect all unique tokens from both sets - let all_tokens: HashSet = self - .token_frequency - .keys() - .chain(pattern_frequency.keys()) - .cloned() - .collect(); - - for token in all_tokens { - let freq1 = *self.token_frequency.get(&token).unwrap_or(&0) as f64; - let freq2 = *pattern_frequency.get(&token).unwrap_or(&0) as f64; - - dot_product += freq1 * freq2; - norm1_squared += freq1 * freq1; - norm2_squared += freq2 * freq2; - } - - let norm1 = norm1_squared.sqrt(); - let norm2 = norm2_squared.sqrt(); - - if norm1 == 0.0 || norm2 == 0.0 { - return 0.0; - } - - dot_product / (norm1 * norm2) - } - - /// Layered phrase matching: exact → character ngram → token cosine - /// Returns true if the pattern matches using any layer - #[allow(dead_code)] // Kept for reference; production uses matches_normalized_pattern - fn layered_contains_phrase( - &self, - pattern: &str, - char_ngram_threshold: f64, - token_cosine_threshold: f64, - ) -> bool { - // Layer 0: Exact phrase match (fastest) - if self.contains_phrase(pattern) { - return true; - } - - // Layer 1: Character ngram similarity (typo/edit robustness) 
- // Check whole message first (for short messages) - if self.char_ngram_similarity(pattern) >= char_ngram_threshold { - return true; - } - - // ngram containment check for patterns buried in longer messages - // If ALL of the pattern's ngrams exist in the message, the pattern must be - // present (possibly with minor variations like missing apostrophes). - // This is O(pattern_ngrams) lookups vs expensive window sliding. - if self.char_ngram_containment(pattern) >= 1.0 { - return true; - } - - // Layer 2: Token cosine similarity (semantic stability for long messages) - if self.token_cosine_similarity(pattern) >= token_cosine_threshold { - return true; - } - - false - } - - fn char_ngram_containment(&self, pattern: &str) -> f64 { - // Normalize the pattern the same way as char_ngram_similarity - let normalized_pattern = pattern - .to_lowercase() - .chars() - .filter(|c| c.is_alphanumeric() || c.is_whitespace()) - .collect::() - .split_whitespace() - .collect::>() - .join(" "); - - // Generate ngrams for the pattern - let pattern_ngrams: HashSet = normalized_pattern - .chars() - .collect::>() - .windows(NGRAM_SIZE) - .map(|w| w.iter().collect::()) - .collect(); - - if pattern_ngrams.is_empty() { - return 0.0; - } - - // Count how many pattern ngrams exist in the message - let contained = pattern_ngrams - .iter() - .filter(|t| self.char_ngram_set.contains(*t)) - .count(); - - contained as f64 / pattern_ngrams.len() as f64 - } - - /// Fast matching against a pre-normalized pattern - /// This avoids re-normalizing and re-computing ngrams for each pattern - fn matches_normalized_pattern( - &self, - pattern: &NormalizedPattern, - char_ngram_threshold: f64, - token_cosine_threshold: f64, - ) -> bool { - // Layer 0: Exact phrase match (fastest) - if self.contains_phrase(&pattern.raw) { - return true; - } - - // Layer 1: Character ngram similarity using pre-computed ngrams - if !self.char_ngram_set.is_empty() && !pattern.char_ngram_set.is_empty() { - let intersection = self 
- .char_ngram_set - .intersection(&pattern.char_ngram_set) - .count(); - let union = self.char_ngram_set.union(&pattern.char_ngram_set).count(); - if union > 0 { - let similarity = intersection as f64 / union as f64; - if similarity >= char_ngram_threshold { - return true; - } - } - } - - // Ngram containment check using pre-computed ngrams - if !pattern.char_ngram_set.is_empty() { - let contained = pattern - .char_ngram_set - .iter() - .filter(|t| self.char_ngram_set.contains(*t)) - .count(); - let containment = contained as f64 / pattern.char_ngram_set.len() as f64; - if containment >= 1.0 { - return true; - } - } - - // Layer 2: Token cosine similarity using pre-computed frequencies - if !self.token_frequency.is_empty() && !pattern.token_frequency.is_empty() { - let mut dot_product = 0.0; - let mut norm1_squared = 0.0; - let mut norm2_squared = 0.0; - - // Iterate over pattern tokens (usually smaller set) - for (token, &freq2) in &pattern.token_frequency { - let freq1 = *self.token_frequency.get(token).unwrap_or(&0) as f64; - let freq2 = freq2 as f64; - dot_product += freq1 * freq2; - norm2_squared += freq2 * freq2; - } - - // Add self tokens not in pattern for norm1 - for &freq1 in self.token_frequency.values() { - norm1_squared += (freq1 as f64) * (freq1 as f64); - } - - let norm1 = norm1_squared.sqrt(); - let norm2 = norm2_squared.sqrt(); - - if norm1 > 0.0 && norm2 > 0.0 { - let similarity = dot_product / (norm1 * norm2); - if similarity >= token_cosine_threshold { - return true; - } - } - } - - false - } -} - -// ============================================================================ -// Normalized Pattern (pre-computed for performance) -// ============================================================================ - -/// Pre-processed pattern with normalized text and pre-computed ngrams/tokens -/// This avoids redundant computation when matching against many messages -#[derive(Debug, Clone)] -struct NormalizedPattern { - /// Original raw pattern text 
- raw: String, - /// Character ngram set for similarity matching - char_ngram_set: HashSet, - /// Token frequency map for cosine similarity - token_frequency: HashMap, -} - -impl NormalizedPattern { - fn new(pattern: &str) -> Self { - // Normalize: lowercase and remove ALL punctuation - let normalized = pattern - .to_lowercase() - .chars() - .filter(|c| c.is_alphanumeric() || c.is_whitespace()) - .collect::() - .split_whitespace() - .collect::>() - .join(" "); - - // Generate ngrams - let char_ngram_set: HashSet = normalized - .chars() - .collect::>() - .windows(NGRAM_SIZE) - .map(|w| w.iter().collect::()) - .collect(); - - // Compute token frequency map - let tokens: Vec = normalized - .split_whitespace() - .map(|s| s.to_string()) - .collect(); - let mut token_frequency: HashMap = HashMap::new(); - for token in tokens { - *token_frequency.entry(token).or_insert(0) += 1; - } - - Self { - raw: pattern.to_string(), - char_ngram_set, - token_frequency, - } - } -} - -/// Helper to create a static slice of normalized patterns -fn normalize_patterns(patterns: &[&str]) -> Vec { - patterns.iter().map(|p| NormalizedPattern::new(p)).collect() -} - -// ============================================================================ -// Pre-computed Pattern Caches (initialized once at startup) -// ============================================================================ - -static REPAIR_PATTERNS: LazyLock> = LazyLock::new(|| { - normalize_patterns(&[ - // Explicit corrections - "i meant", - "i mean", - "sorry, i meant", - "what i meant was", - "what i actually meant", - "i was trying to say", - "let me correct that", - "correction", - "i misspoke", - // Negations and disagreements - "no, i", - "no i", - "nah i", - "nope i", - "not what i", - "that's not", - "that's not what", - "that isn't what", - "not quite", - "not exactly", - // Rephrasing indicators - "let me rephrase", - "let me try again", - "let me clarify", - "to clarify", - "to be clear", - "let me explain", - "what 
i'm trying to", - "what i'm saying", - "in other words", - // Actual/really emphasis - "actually i", - "actually no", - "what i actually", - "i actually", - "i really meant", - // Mistake acknowledgment - "i was wrong", - "my mistake", - "my bad", - "i should have said", - "i should clarify", - // Wait/hold indicators - "wait, i", - "wait no", - "hold on", - "hang on", - ]) -}); - -static COMPLAINT_PATTERNS: LazyLock> = LazyLock::new(|| { - normalize_patterns(&[ - // Useless/unhelpful (multi-word only) - "this is useless", - "not helpful", - "doesn't help", - "not helping", - "you're not helping", - "no help", - "unhelpful", - // Not working - "this doesn't work", - "doesn't work", - "not working", - "isn't working", - "won't work", - "still doesn't work", - "still not working", - // Not fixing/solving - "doesn't fix", - "not fixing", - "doesn't solve", - "doesn't seem to work", - "doesn't seem to fix", - "not resolving", - // Waste/pointless - "waste of time", - "wasting my time", - // Ridiculous/absurd - "this is ridiculous", - "ridiculous", - "this is absurd", - "absurd", - "this is insane", - "insane", - // Stupid/dumb (as adjectives, not as standalone tokens) - "this is stupid", - "this is dumb", - // Quality complaints (multi-word) - "this sucks", - "not good enough", - // Capability questions - "why can't you", - "can't you", - // Frustration - "this is frustrating", - "frustrated", - "incomplete", - "overwhelm", - "overwhelmed", - "overwhelming", - "exhausted", - "struggled", - // same issue - "same issue", - // polite dissatisfaction - "i'm disappointed", - "thanks, but", - "appreciate it, but", - "good, but", - // Fed up/done - "i give up", - "give up", - "fed up", - "had enough", - "can't take", - // Bot-specific complaints - "useless bot", - "dumb bot", - "stupid bot", - ]) -}); - -static CONFUSION_PATTERNS: LazyLock> = LazyLock::new(|| { - normalize_patterns(&[ - // Don't understand - "i don't understand", - "don't understand", - "not understanding", - 
"can't understand", - "don't get it", - "don't follow", - // Confused state - "i'm confused", - "so confused", - // Makes no sense - "makes no sense", - "doesn't make sense", - "not making sense", - // What do you mean (keep multi-word) - "what do you mean", - "what does that mean", - "what are you saying", - // Lost/unclear - "i'm lost", - "totally lost", - "lost me", - // No clue - "no clue", - "no idea", - // Come again - "come again", - "say that again", - "repeat that", - ]) -}); - -static GRATITUDE_PATTERNS: LazyLock> = LazyLock::new(|| { - normalize_patterns(&[ - // Standard gratitude - "thank you", - "thanks", - "thank u", - "thankyou", - "thx", - "ty", - "tyvm", - "tysm", - "thnx", - "thnks", - // Strong gratitude - "thanks so much", - "thank you so much", - "thanks a lot", - "thanks a bunch", - "much appreciated", - "really appreciate", - "greatly appreciate", - "appreciate it", - "appreciate that", - "i appreciate", - "grateful", - "so grateful", - // Helpfulness acknowledgment - "that's helpful", - "very helpful", - "super helpful", - "really helpful", - "that helps", - "this helps", - "helpful", - // Perfection expressions - "perfect", - "that's perfect", - "just perfect", - "exactly what i needed", - "exactly right", - "just what i needed", - "that's exactly", - // Informal positive - "you're the best", - "you rock", - "you're awesome", - "awesome sauce", - "legend", - ]) -}); - -static SATISFACTION_PATTERNS: LazyLock> = LazyLock::new(|| { - normalize_patterns(&[ - // Works/functions - "that works", - "this works", - "works great", - "works perfectly", - "works for me", - // Great variations - "that's great", - "that's amazing", - "this is great", - "sounds great", - "looks great", - "great job", - // Excellent/perfect - "excellent", - "outstanding", - "superb", - "spectacular", - // Awesome/amazing - "awesome", - "that's awesome", - "amazing", - "incredible", - // Love expressions - "love it", - "love this", - "i love", - "loving it", - "love that", 
- // Brilliant/wonderful - "brilliant", - "wonderful", - "fantastic", - "fabulous", - "marvelous", - ]) -}); - -static SUCCESS_PATTERNS: LazyLock> = LazyLock::new(|| { - normalize_patterns(&[ - // Understanding confirmation - "got it", - "i got it", - "understand", - "understood", - "i understand", - "makes sense", - "clear now", - "i see", - // Success/completion - "success", - "successful", - "it worked", - "that worked", - "this worked", - "worked", - // Problem resolution - "solved", - "resolved", - "fixed", - "fixed it", - "issue resolved", - "problem solved", - // Working state - "working now", - "it's working", - "works now", - "working fine", - "working great", - // Completion - "all set", - "all good", - "we're good", - "i'm good", - "all done", - "done", - "complete", - "finished", - // Perfect fit - "spot on", - "nailed it", - "bingo", - "exactly", - "just right", - ]) -}); - -static HUMAN_AGENT_PATTERNS: LazyLock> = LazyLock::new(|| { - normalize_patterns(&[ - // Speak to human - "speak to a human", - "speak to human", - "speak with a human", - "speak with human", - "talk to a human", - "talk to human", - "talk to a person", - "talk to person", - "talk to someone", - // Human/real agent - "human agent", - "real agent", - "actual agent", - "live agent", - "human support", - // Real/actual person - "real person", - "actual person", - "real human", - "actual human", - "someone real", - // Need/want human - "need a human", - "need human", - "want a human", - "want human", - "get me a human", - "get me human", - "get me someone", - // Transfer/connect - "transfer me", - "connect me", - "escalate this", - // Representative (removed standalone "rep" - too many false positives) - "representative", - "customer service rep", - "customer service representative", - // Not a bot - "not a bot", - "not talking to a bot", - "tired of bots", - ]) -}); - -static SUPPORT_PATTERNS: LazyLock> = LazyLock::new(|| { - normalize_patterns(&[ - // Contact support - "contact 
support", - "call support", - "reach support", - "get support", - // Customer support - "customer support", - "customer service", - "tech support", - "technical support", - // Help desk - "help desk", - "helpdesk", - "support desk", - // Talk to support - "talk to support", - "speak to support", - "speak with support", - "chat with support", - // Need help - "need real help", - "need actual help", - "help me now", - ]) -}); - -static QUIT_PATTERNS: LazyLock> = LazyLock::new(|| { - normalize_patterns(&[ - // Give up - "i give up", - "give up", - "giving up", - // Quit/leaving - "i'm going to quit", - "i quit", - "quitting", - "i'm leaving", - "i'm done", - "i'm out", - // Forget it - "forget it", - "forget this", - "screw it", - "screw this", - // Never mind - "never mind", - "nevermind", - "don't bother", - "not worth it", - // Hopeless - "this is hopeless", - // Going elsewhere - "going elsewhere", - "try somewhere else", - "look elsewhere", - "find another", - ]) -}); - -// ============================================================================ -// Core Signal Types -// ============================================================================ - -/// Overall quality assessment for an agent interaction session -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub enum InteractionQuality { - /// Excellent interaction with strong positive signals - Excellent, - /// Good interaction with mostly positive signals - Good, - /// Neutral interaction with mixed signals - Neutral, - /// Poor interaction with concerning signals - Poor, - /// Critical interaction with severe negative signals - Severe, -} - -/// Container for all computed signals for a conversation -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct SignalReport { - /// Turn count and efficiency metrics - pub turn_count: TurnCountSignal, - /// Follow-up and repair frequency - pub follow_up: FollowUpSignal, - /// User frustration indicators - pub frustration: FrustrationSignal, - /// 
Repetition and looping behavior - pub repetition: RepetitionSignal, - /// Positive feedback indicators - pub positive_feedback: PositiveFeedbackSignal, - /// User escalation requests - pub escalation: EscalationSignal, - /// Overall quality assessment - pub overall_quality: InteractionQuality, - /// Human-readable summary - pub summary: String, -} - -// ============================================================================ -// Individual Signal Types -// ============================================================================ - -/// Turn count and efficiency metrics -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct TurnCountSignal { - /// Total number of turns (user-agent exchanges) - pub total_turns: usize, - /// Number of user messages - pub user_turns: usize, - /// Number of assistant messages - pub assistant_turns: usize, - /// Whether the turn count is concerning (> 7) - pub is_concerning: bool, - /// Whether the turn count is excessive (> 12) - pub is_excessive: bool, - /// Efficiency score (0.0-1.0, lower turns = higher score) - pub efficiency_score: f64, -} - -/// Follow-up and repair frequency signal -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct FollowUpSignal { - /// Number of detected repair attempts - pub repair_count: usize, - /// Ratio of repairs to total user turns - pub repair_ratio: f64, - /// Whether repair ratio is concerning (> 0.3) - pub is_concerning: bool, - /// List of detected repair phrases - pub repair_phrases: Vec, -} - -/// User frustration indicators -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct FrustrationSignal { - /// Number of frustration indicators detected - pub frustration_count: usize, - /// Whether frustration is detected - pub has_frustration: bool, - /// Severity level (0-3: none, mild, moderate, severe) - pub severity: u8, - /// List of detected frustration indicators - pub indicators: Vec, -} - -/// Individual frustration indicator -#[derive(Debug, Clone, Serialize, 
Deserialize)] -pub struct FrustrationIndicator { - /// Type of frustration detected - pub indicator_type: FrustrationType, - /// Message index where detected - pub message_index: usize, - /// Relevant text snippet - pub snippet: String, -} - -/// Types of frustration indicators -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub enum FrustrationType { - /// Negative sentiment detected - NegativeSentiment, - /// All caps typing - AllCaps, - /// Excessive punctuation - ExcessivePunctuation, - /// Profanity detected - Profanity, - /// Direct complaint - DirectComplaint, - /// Expression of confusion - Confusion, -} - -/// Repetition and looping behavior signal -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct RepetitionSignal { - /// Number of repetitions detected - pub repetition_count: usize, - /// Whether significant looping detected (> 2 repetitions) - pub has_looping: bool, - /// Severity level (0-3: none, mild, moderate, severe) - pub severity: u8, - /// List of detected repetitions - pub repetitions: Vec, -} - -/// Individual repetition instance -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct RepetitionInstance { - /// Message indices involved in repetition - pub message_indices: Vec, - /// Similarity score (0.0-1.0) - pub similarity: f64, - /// Type of repetition - pub repetition_type: RepetitionType, -} - -/// Types of repetition -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub enum RepetitionType { - /// Exact repetition - Exact, - /// Near-duplicate (high similarity) - NearDuplicate, - /// Semantic repetition (similar meaning) - Semantic, -} - -/// Positive feedback indicators -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct PositiveFeedbackSignal { - /// Number of positive indicators detected - pub positive_count: usize, - /// Whether positive feedback is present - pub has_positive_feedback: bool, - /// Confidence score (0.0-1.0) - pub confidence: f64, - /// List of detected positive 
indicators - pub indicators: Vec, -} - -/// Individual positive indicator -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct PositiveIndicator { - /// Type of positive feedback - pub indicator_type: PositiveType, - /// Message index where detected - pub message_index: usize, - /// Relevant text snippet - pub snippet: String, -} - -/// Types of positive indicators -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub enum PositiveType { - /// Expression of gratitude - Gratitude, - /// Explicit satisfaction - Satisfaction, - /// Confirmation of success - Success, - /// Positive sentiment - PositiveSentiment, - /// Natural topic transition - TopicTransition, -} - -/// User escalation signal -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct EscalationSignal { - /// Whether escalation was requested - pub escalation_requested: bool, - /// Number of escalation requests - pub escalation_count: usize, - /// List of detected escalation requests - pub requests: Vec, -} - -/// Individual escalation request -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct EscalationRequest { - /// Message index where detected - pub message_index: usize, - /// Relevant text snippet - pub snippet: String, - /// Type of escalation - pub escalation_type: EscalationType, -} - -/// Types of escalation -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub enum EscalationType { - /// Request for human agent - HumanAgent, - /// Request for support - Support, - /// Threat to quit/leave - ThreatToQuit, - /// General help request - HelpRequest, -} - -// ============================================================================ -// Signal Analyzer -// ============================================================================ - -/// Trait for analyzing conversation signals -pub trait SignalAnalyzer { - /// Analyze a conversation and generate a complete signal report - fn analyze(&self, messages: &[Message]) -> SignalReport; -} - -/// Text-based 
implementation of signal analyzer that computes all signals from a message array -pub struct TextBasedSignalAnalyzer { - /// Baseline expected turns for normal interactions - baseline_turns: usize, - /// Threshold for character ngram similarity (0.0-1.0) - char_ngram_threshold: f64, - /// Threshold for token cosine similarity (0.0-1.0) - token_cosine_threshold: f64, - /// Maximum message length in characters (prevents unbounded computation) - max_message_length: usize, - /// Maximum number of messages to process (prevents unbounded computation) - max_messages: usize, - /// Maximum window size for repetition detection (prevents O(n²) explosion) - max_repetition_window: usize, -} - -impl TextBasedSignalAnalyzer { - /// Extract text content from MessageContent, skipping non-text content - fn extract_text(content: &Option) -> Option { - match content { - Some(hermesllm::apis::openai::MessageContent::Text(text)) => Some(text.clone()), - // Tool calls and other structured content are skipped - _ => None, - } - } - - /// Create a new signal analyzer with default settings - pub fn new() -> Self { +impl Default for SignalAnalyzerConfig { + fn default() -> Self { Self { baseline_turns: 5, - char_ngram_threshold: 0.50, // Lowered to handle typos and small edits realistically - token_cosine_threshold: 0.60, // Lowered for better semantic match in varied contexts - max_message_length: 2000, // Prevent unbounded ngram generation - max_messages: 100, // Prevent unbounded message processing - max_repetition_window: 20, // Prevent O(n²) explosion in repetition detection - } - } - - /// Create a new signal analyzer with custom baseline - pub fn with_baseline(baseline_turns: usize) -> Self { - Self { - baseline_turns, - char_ngram_threshold: 0.50, + char_ngram_threshold: 0.65, token_cosine_threshold: 0.60, max_message_length: 2000, max_messages: 100, - max_repetition_window: 20, } } - - /// Create a new signal analyzer with custom settings - /// - /// # Arguments - /// * 
`baseline_turns` - Expected baseline turns for normal interactions - /// * `char_ngram_threshold` - Threshold for character ngram similarity (0.0-1.0) - /// * `token_cosine_threshold` - Threshold for token cosine similarity (0.0-1.0) - pub fn with_settings( - baseline_turns: usize, - char_ngram_threshold: f64, - token_cosine_threshold: f64, - ) -> Self { - Self { - baseline_turns, - char_ngram_threshold, - token_cosine_threshold, - max_message_length: 2000, - max_messages: 100, - max_repetition_window: 20, - } - } - - /// Create a new signal analyzer with full custom settings including computation limits - /// - /// # Arguments - /// * `baseline_turns` - Expected baseline turns for normal interactions - /// * `char_ngram_threshold` - Threshold for character ngram similarity (0.0-1.0) - /// * `token_cosine_threshold` - Threshold for token cosine similarity (0.0-1.0) - /// * `max_message_length` - Maximum characters per message to process - /// * `max_messages` - Maximum number of messages to process - /// * `max_repetition_window` - Maximum messages to compare for repetition detection - pub fn with_full_settings( - baseline_turns: usize, - char_ngram_threshold: f64, - token_cosine_threshold: f64, - max_message_length: usize, - max_messages: usize, - max_repetition_window: usize, - ) -> Self { - Self { - baseline_turns, - char_ngram_threshold, - token_cosine_threshold, - max_message_length, - max_messages, - max_repetition_window, - } - } - - // ======================================================================== - // Individual Signal Analyzers - // ======================================================================== - - /// Analyze turn count and efficiency - fn analyze_turn_count(&self, messages: &[Message]) -> TurnCountSignal { - let mut user_turns = 0; - let mut assistant_turns = 0; - - for message in messages { - match message.role { - Role::User => user_turns += 1, - Role::Assistant => assistant_turns += 1, - _ => {} - } - } - - let total_turns = 
user_turns + assistant_turns; - let is_concerning = total_turns > 7; - let is_excessive = total_turns > 12; - - // Calculate efficiency score (exponential decay after baseline) - let efficiency_score = if total_turns == 0 || total_turns <= self.baseline_turns { - 1.0 - } else { - let excess = total_turns - self.baseline_turns; - 1.0 / (1.0 + (excess as f64 * 0.3)) - }; - - TurnCountSignal { - total_turns, - user_turns, - assistant_turns, - is_concerning, - is_excessive, - efficiency_score, - } - } - - /// Analyze follow-up and repair frequency - fn analyze_follow_up( - &self, - normalized_messages: &[(usize, Role, NormalizedMessage)], - ) -> FollowUpSignal { - let mut repair_count = 0; - let mut repair_phrases = Vec::new(); - let mut user_turn_count = 0; - - for (pos, (i, role, norm_msg)) in normalized_messages.iter().enumerate() { - if *role != Role::User { - continue; - } - - user_turn_count += 1; - - // Use per-turn boolean to prevent double-counting - let mut found_in_turn = false; - - // Use pre-computed patterns for fast matching - for pattern in REPAIR_PATTERNS.iter() { - if norm_msg.matches_normalized_pattern( - pattern, - self.char_ngram_threshold, - self.token_cosine_threshold, - ) { - repair_count += 1; - repair_phrases.push(format!("Turn {}: '{}'", i + 1, pattern.raw)); - found_in_turn = true; - break; - } - } - - // Only check for semantic similarity if no pattern matched. Walk - // backwards through the *normalized* list (not the original - // conversation indices, which may be non-contiguous because - // messages without extractable text are filtered out) to find the - // most recent prior user message. 
- if !found_in_turn && pos >= 1 { - for j in (0..pos).rev() { - let (_, prev_role, prev_norm_msg) = &normalized_messages[j]; - if *prev_role == Role::User { - if self.is_similar_rephrase(norm_msg, prev_norm_msg) { - repair_count += 1; - repair_phrases - .push(format!("Turn {}: Similar rephrase detected", i + 1)); - } - break; - } - } - } - } - - let repair_ratio = if user_turn_count == 0 { - 0.0 - } else { - repair_count as f64 / user_turn_count as f64 - }; - - let is_concerning = repair_ratio > 0.3; - - FollowUpSignal { - repair_count, - repair_ratio, - is_concerning, - repair_phrases, - } - } - - /// Analyze user frustration indicators - fn analyze_frustration( - &self, - normalized_messages: &[(usize, Role, NormalizedMessage)], - ) -> FrustrationSignal { - let mut indicators = Vec::new(); - - // Profanity list - only as standalone tokens, not substrings - let profanity_tokens = [ - "damn", "damnit", "crap", "wtf", "ffs", "bullshit", "shit", "fuck", "fucking", - ]; - - for (i, role, norm_msg) in normalized_messages { - if *role != Role::User { - continue; - } - - let text = &norm_msg.raw; - - // Check for all caps (at least 10 chars and 80% uppercase) - let alpha_chars: String = text.chars().filter(|c| c.is_alphabetic()).collect(); - if alpha_chars.len() >= 10 { - let upper_count = alpha_chars.chars().filter(|c| c.is_uppercase()).count(); - let upper_ratio = upper_count as f64 / alpha_chars.len() as f64; - if upper_ratio >= 0.8 { - indicators.push(FrustrationIndicator { - indicator_type: FrustrationType::AllCaps, - message_index: *i, - snippet: text.chars().take(50).collect(), - }); - } - } - - // Check for excessive punctuation - let question_marks = text.matches('?').count(); - let exclamation_marks = text.matches('!').count(); - if question_marks >= 3 || exclamation_marks >= 3 { - indicators.push(FrustrationIndicator { - indicator_type: FrustrationType::ExcessivePunctuation, - message_index: *i, - snippet: text.chars().take(50).collect(), - }); - } - - // 
Check for complaint patterns using pre-computed patterns - for pattern in COMPLAINT_PATTERNS.iter() { - if norm_msg.matches_normalized_pattern( - pattern, - self.char_ngram_threshold, - self.token_cosine_threshold, - ) { - indicators.push(FrustrationIndicator { - indicator_type: FrustrationType::DirectComplaint, - message_index: *i, - snippet: pattern.raw.clone(), - }); - break; - } - } - - // Check for confusion patterns using pre-computed patterns - for pattern in CONFUSION_PATTERNS.iter() { - if norm_msg.matches_normalized_pattern( - pattern, - self.char_ngram_threshold, - self.token_cosine_threshold, - ) { - indicators.push(FrustrationIndicator { - indicator_type: FrustrationType::Confusion, - message_index: *i, - snippet: pattern.raw.clone(), - }); - break; - } - } - - // Check for profanity (token-based, not substring) - for token in &profanity_tokens { - if norm_msg.contains_token(token) { - indicators.push(FrustrationIndicator { - indicator_type: FrustrationType::Profanity, - message_index: *i, - snippet: token.to_string(), - }); - break; - } - } - } - - let frustration_count = indicators.len(); - let has_frustration = frustration_count > 0; - - // Calculate severity - let severity = if frustration_count == 0 { - 0 - } else if frustration_count <= 2 { - 1 - } else if frustration_count <= 4 { - 2 - } else { - 3 - }; - - FrustrationSignal { - frustration_count, - has_frustration, - severity, - indicators, - } - } - - /// Analyze repetition and looping behavior - fn analyze_repetition( - &self, - normalized_messages: &[(usize, Role, NormalizedMessage)], - ) -> RepetitionSignal { - let mut repetitions = Vec::new(); - - // Collect assistant messages with normalized content - let assistant_messages: Vec<(usize, &NormalizedMessage)> = normalized_messages - .iter() - .filter(|(_, role, _)| *role == Role::Assistant) - .map(|(i, _, norm_msg)| (*i, norm_msg)) - .collect(); - - // Limit the window size to prevent O(n²) explosion - // Only compare messages within the 
max_repetition_window - let window_size = self.max_repetition_window.min(assistant_messages.len()); - - // Check for exact or near-duplicate responses using bigram similarity - // Only compare within the sliding window - for i in 0..assistant_messages.len() { - let window_start = i + 1; - let window_end = (i + 1 + window_size).min(assistant_messages.len()); - - for j in window_start..window_end { - let (idx_i, norm_msg_i) = &assistant_messages[i]; - let (idx_j, norm_msg_j) = &assistant_messages[j]; - - // Skip if messages are too short - if norm_msg_i.tokens.len() < 5 || norm_msg_j.tokens.len() < 5 { - continue; - } - - // Calculate bigram-based similarity (more accurate for near-duplicates) - let similarity = self.calculate_bigram_similarity(norm_msg_i, norm_msg_j); - - // Exact match - lowered from 0.95 to 0.85 for bigram similarity - if similarity >= 0.85 { - repetitions.push(RepetitionInstance { - message_indices: vec![*idx_i, *idx_j], - similarity, - repetition_type: RepetitionType::Exact, - }); - } - // Near duplicate - lowered from 0.75 to 0.50 to catch subtle repetitions - else if similarity >= 0.50 { - repetitions.push(RepetitionInstance { - message_indices: vec![*idx_i, *idx_j], - similarity, - repetition_type: RepetitionType::NearDuplicate, - }); - } - } - } - - let repetition_count = repetitions.len(); - let has_looping = repetition_count > 2; - - let severity = if repetition_count == 0 { - 0 - } else if repetition_count <= 2 { - 1 - } else if repetition_count <= 4 { - 2 - } else { - 3 - }; - - RepetitionSignal { - repetition_count, - has_looping, - severity, - repetitions, - } - } - - /// Calculate bigram similarity using cached bigram sets - fn calculate_bigram_similarity( - &self, - norm_msg1: &NormalizedMessage, - norm_msg2: &NormalizedMessage, - ) -> f64 { - // Use pre-cached bigram sets for O(1) lookups - let set1 = &norm_msg1.bigram_set; - let set2 = &norm_msg2.bigram_set; - - if set1.is_empty() && set2.is_empty() { - return 1.0; // Both empty = 
identical - } - - if set1.is_empty() || set2.is_empty() { - return 0.0; - } - - let intersection = set1.intersection(set2).count(); - let union = set1.union(set2).count(); - - if union == 0 { - return 0.0; - } - - intersection as f64 / union as f64 - } - - /// Analyze positive feedback indicators - fn analyze_positive_feedback( - &self, - normalized_messages: &[(usize, Role, NormalizedMessage)], - ) -> PositiveFeedbackSignal { - let mut indicators = Vec::new(); - - for (i, role, norm_msg) in normalized_messages { - if *role != Role::User { - continue; - } - - // Use per-turn boolean to prevent double-counting - let mut found_in_turn = false; - - // Check gratitude using pre-computed patterns - for pattern in GRATITUDE_PATTERNS.iter() { - if norm_msg.matches_normalized_pattern( - pattern, - self.char_ngram_threshold, - self.token_cosine_threshold, - ) { - indicators.push(PositiveIndicator { - indicator_type: PositiveType::Gratitude, - message_index: *i, - snippet: pattern.raw.clone(), - }); - found_in_turn = true; - break; - } - } - - if found_in_turn { - continue; - } - - // Check satisfaction using pre-computed patterns - for pattern in SATISFACTION_PATTERNS.iter() { - if norm_msg.matches_normalized_pattern( - pattern, - self.char_ngram_threshold, - self.token_cosine_threshold, - ) { - indicators.push(PositiveIndicator { - indicator_type: PositiveType::Satisfaction, - message_index: *i, - snippet: pattern.raw.clone(), - }); - found_in_turn = true; - break; - } - } - - if found_in_turn { - continue; - } - - // Check success confirmation using pre-computed patterns - for pattern in SUCCESS_PATTERNS.iter() { - if norm_msg.matches_normalized_pattern( - pattern, - self.char_ngram_threshold, - self.token_cosine_threshold, - ) { - indicators.push(PositiveIndicator { - indicator_type: PositiveType::Success, - message_index: *i, - snippet: pattern.raw.clone(), - }); - break; - } - } - } - - let positive_count = indicators.len(); - let has_positive_feedback = positive_count 
> 0; - - // Calculate confidence based on number and diversity of indicators - let confidence = if positive_count == 0 { - 0.0 - } else if positive_count == 1 { - 0.6 - } else if positive_count == 2 { - 0.8 - } else { - 0.95 - }; - - PositiveFeedbackSignal { - positive_count, - has_positive_feedback, - confidence, - indicators, - } - } - - /// Analyze user escalation requests - fn analyze_escalation( - &self, - normalized_messages: &[(usize, Role, NormalizedMessage)], - ) -> EscalationSignal { - let mut requests = Vec::new(); - - for (i, role, norm_msg) in normalized_messages { - if *role != Role::User { - continue; - } - - let mut found_human_agent = false; - - // Check for human agent request using pre-computed patterns - for pattern in HUMAN_AGENT_PATTERNS.iter() { - if norm_msg.matches_normalized_pattern( - pattern, - self.char_ngram_threshold, - self.token_cosine_threshold, - ) { - requests.push(EscalationRequest { - message_index: *i, - snippet: pattern.raw.clone(), - escalation_type: EscalationType::HumanAgent, - }); - found_human_agent = true; - break; - } - } - - // Check for support request (only if no human agent request found) - // HumanAgent and Support are too similar and often match the same phrase - if !found_human_agent { - for pattern in SUPPORT_PATTERNS.iter() { - if norm_msg.matches_normalized_pattern( - pattern, - self.char_ngram_threshold, - self.token_cosine_threshold, - ) { - requests.push(EscalationRequest { - message_index: *i, - snippet: pattern.raw.clone(), - escalation_type: EscalationType::Support, - }); - break; - } - } - } - - // Check for quit threats (independent of HumanAgent/Support) - // A message can contain both "give up" (quit) and "speak to human" (escalation) - for pattern in QUIT_PATTERNS.iter() { - if norm_msg.matches_normalized_pattern( - pattern, - self.char_ngram_threshold, - self.token_cosine_threshold, - ) { - requests.push(EscalationRequest { - message_index: *i, - snippet: pattern.raw.clone(), - escalation_type: 
EscalationType::ThreatToQuit, - }); - break; - } - } - } - - let escalation_count = requests.len(); - let escalation_requested = escalation_count > 0; - - EscalationSignal { - escalation_requested, - escalation_count, - requests, - } - } - - // ======================================================================== - // Helper Methods - // ======================================================================== - - /// Check if two messages are similar rephrases - fn is_similar_rephrase( - &self, - norm_msg1: &NormalizedMessage, - norm_msg2: &NormalizedMessage, - ) -> bool { - // Skip if too short - if norm_msg1.tokens.len() < 3 || norm_msg2.tokens.len() < 3 { - return false; - } - - // Common stopwords to downweight - let stopwords: HashSet<&str> = [ - "i", "me", "my", "you", "the", "a", "an", "is", "are", "was", "were", "to", "with", - "for", "of", "at", "by", "in", "on", "it", "this", "that", "can", "could", "do", - "does", "did", "will", "would", "should", "be", - ] - .iter() - .cloned() - .collect(); - - // Filter out stopwords for meaningful overlap - let tokens1: HashSet<_> = norm_msg1 - .tokens - .iter() - .filter(|t| !stopwords.contains(t.as_str())) - .collect(); - let tokens2: HashSet<_> = norm_msg2 - .tokens - .iter() - .filter(|t| !stopwords.contains(t.as_str())) - .collect(); - - // Need at least 2 non-stopword tokens - if tokens1.len() < 2 || tokens2.len() < 2 { - return false; - } - - let intersection = tokens1.intersection(&tokens2).count(); - let min_size = tokens1.len().min(tokens2.len()); - - // High overlap suggests rephrase - let overlap_ratio = intersection as f64 / min_size as f64; - overlap_ratio >= 0.6 - } - - /// Assess overall interaction quality based on all signals - fn assess_overall_quality( - &self, - turn_count: &TurnCountSignal, - follow_up: &FollowUpSignal, - frustration: &FrustrationSignal, - repetition: &RepetitionSignal, - positive: &PositiveFeedbackSignal, - escalation: &EscalationSignal, - ) -> InteractionQuality { - // 
Critical conditions - immediate fail - if escalation.escalation_requested - || frustration.severity >= 3 - || repetition.severity >= 3 - || turn_count.is_excessive - { - return InteractionQuality::Severe; - } - - // Calculate quality score - let mut score = 50.0; // Start at neutral - - // Positive factors - if positive.has_positive_feedback { - score += 20.0 * positive.confidence; - } - score += turn_count.efficiency_score * 10.0; - - // Negative factors - if frustration.has_frustration { - score -= frustration.severity as f64 * 10.00; - } - if follow_up.is_concerning { - score -= 15.0; - } - if repetition.has_looping { - score -= repetition.severity as f64 * 8.0; - } - if turn_count.is_concerning { - score -= 10.0; - } - - // Map score to quality level - if score >= 75.0 { - InteractionQuality::Excellent - } else if score >= 60.0 { - InteractionQuality::Good - } else if score >= 40.0 { - InteractionQuality::Neutral - } else if score >= 25.0 { - InteractionQuality::Poor - } else { - InteractionQuality::Severe - } - } - - /// Generate human-readable summary - #[allow(clippy::too_many_arguments)] - fn generate_summary( - &self, - turn_count: &TurnCountSignal, - follow_up: &FollowUpSignal, - frustration: &FrustrationSignal, - repetition: &RepetitionSignal, - positive: &PositiveFeedbackSignal, - escalation: &EscalationSignal, - quality: &InteractionQuality, - ) -> String { - let mut summary_parts = Vec::new(); - - summary_parts.push(format!("Overall Quality: {:?}", quality)); - - summary_parts.push(format!( - "Turn Count: {} turns (efficiency: {:.1}%)", - turn_count.total_turns, - turn_count.efficiency_score * 100.0 - )); - - if follow_up.is_concerning { - summary_parts.push(format!( - "⚠️ High repair rate: {:.1}% of user turns", - follow_up.repair_ratio * 100.0 - )); - } - - if frustration.has_frustration { - summary_parts.push(format!( - "⚠️ Frustration detected: {} indicators (severity: {})", - frustration.frustration_count, frustration.severity - )); - } - - if 
repetition.has_looping { - summary_parts.push(format!( - "⚠️ Looping detected: {} repetitions", - repetition.repetition_count - )); - } - - if positive.has_positive_feedback { - summary_parts.push(format!( - "✓ Positive feedback: {} indicators", - positive.positive_count - )); - } - - if escalation.escalation_requested { - summary_parts.push(format!( - "⚠️ Escalation requested: {} requests", - escalation.escalation_count - )); - } - - summary_parts.join(" | ") - } } -impl SignalAnalyzer for TextBasedSignalAnalyzer { - fn analyze(&self, messages: &[Message]) -> SignalReport { - // Limit the number of messages to process (take most recent messages) - let messages_to_process = if messages.len() > self.max_messages { - &messages[messages.len() - self.max_messages..] +/// Top-level analyzer. +pub struct SignalAnalyzer { + cfg: SignalAnalyzerConfig, +} + +impl Default for SignalAnalyzer { + fn default() -> Self { + Self::new(SignalAnalyzerConfig::default()) + } +} + +impl SignalAnalyzer { + pub fn new(cfg: SignalAnalyzerConfig) -> Self { + Self { cfg } + } + + /// Run the full multi-layer analysis on a ShareGPT-shaped conversation. + pub fn analyze_sharegpt(&self, messages: &[ShareGptMessage<'_>]) -> SignalReport { + // Truncate to the last `max_messages` (last-N is what the Python does). + let slice: &[ShareGptMessage<'_>] = if messages.len() > self.cfg.max_messages { + &messages[messages.len() - self.cfg.max_messages..] } else { messages }; + let offset = messages.len().saturating_sub(slice.len()); - // Preprocess all messages once, filtering out non-text content (tool calls, etc.) - // and truncating long messages - let normalized_messages: Vec<(usize, Role, NormalizedMessage)> = messages_to_process + // Preprocess to absolute-indexed normalized human/gpt messages. 
+ let normalized_owned: Vec<(usize, &str, NormalizedMessage)> = slice .iter() .enumerate() - .filter_map(|(i, msg)| { - Self::extract_text(&msg.content).map(|text| { - ( - i, - msg.role.clone(), - NormalizedMessage::from_text_with_limit(&text, self.max_message_length), - ) - }) + .filter_map(|(i, m)| { + if (m.from == "human" || m.from == "gpt") && !m.value.is_empty() { + Some(( + offset + i, + m.from, + NormalizedMessage::from_text(m.value, self.cfg.max_message_length), + )) + } else { + None + } }) .collect(); - let turn_count = self.analyze_turn_count(messages_to_process); - let follow_up = self.analyze_follow_up(&normalized_messages); - let frustration = self.analyze_frustration(&normalized_messages); - let repetition = self.analyze_repetition(&normalized_messages); - let positive_feedback = self.analyze_positive_feedback(&normalized_messages); - let escalation = self.analyze_escalation(&normalized_messages); - - let overall_quality = self.assess_overall_quality( - &turn_count, - &follow_up, - &frustration, - &repetition, - &positive_feedback, - &escalation, + let misalignment = analyze_misalignment( + &normalized_owned, + self.cfg.char_ngram_threshold, + self.cfg.token_cosine_threshold, ); - let summary = self.generate_summary( - &turn_count, - &follow_up, - &frustration, - &repetition, - &positive_feedback, - &escalation, - &overall_quality, + let stagnation_input: Vec> = + slice.iter().map(|m| ShareGptMsg { from: m.from }).collect(); + let (mut stagnation, turn_metrics) = analyze_stagnation( + &stagnation_input, + &normalized_owned, + self.cfg.baseline_turns, + ); + + let disengagement = analyze_disengagement( + &normalized_owned, + self.cfg.char_ngram_threshold, + self.cfg.token_cosine_threshold, + ); + + let satisfaction = analyze_satisfaction( + &normalized_owned, + self.cfg.char_ngram_threshold, + self.cfg.token_cosine_threshold, + ); + + let failure = analyze_failure(slice); + let loops = analyze_loops(slice); + let exhaustion = 
analyze_exhaustion(slice); + + // Bias the dragging signal's message_index back into absolute coords. + for s in &mut stagnation.signals { + s.message_index = offset + s.message_index.min(slice.len().saturating_sub(1)); + } + + let interaction = InteractionSignals { + misalignment, + stagnation, + disengagement, + satisfaction, + }; + let execution = ExecutionSignals { failure, loops }; + let environment = EnvironmentSignals { exhaustion }; + + let (overall_quality, score) = assess_quality( + &interaction, + &execution, + &environment, + turn_metrics.user_turns, + ); + let summary = generate_summary( + &turn_metrics, + &interaction, + &execution, + &environment, + overall_quality, ); SignalReport { - turn_count, - follow_up, - frustration, - repetition, - positive_feedback, - escalation, + interaction, + execution, + environment, overall_quality, + quality_score: score, + turn_metrics, summary, } } -} -impl Default for TextBasedSignalAnalyzer { - fn default() -> Self { - Self::new() + /// Convenience entry point: convert OpenAI-shaped chat `Message`s into the + /// ShareGPT format the detectors operate on, then run analysis. + pub fn analyze_openai(&self, messages: &[Message]) -> SignalReport { + let owned = messages_to_sharegpt(messages); + let view: Vec> = owned + .iter() + .map(|(role, value)| ShareGptMessage { + from: role.as_str(), + value: value.as_str(), + }) + .collect(); + self.analyze_sharegpt(&view) } } -// ============================================================================ -// Tests -// ============================================================================ +/// Convert OpenAI-shaped messages to a sequence of ShareGPT +/// `(role, value)` pairs. 
+/// +/// Mapping (preserves original message order; tool calls are emitted as a +/// separate `function_call` row immediately after the assistant text): +/// +/// - `User` -> `("human", text)` +/// - `Assistant` -> `("gpt", text)`, then one `("function_call", json)` per tool call +/// - `Tool` -> `("observation", text)` +/// - `System` / `Developer` -> dropped (not analyzed) +pub fn messages_to_sharegpt(messages: &[Message]) -> Vec<(String, String)> { + let mut out: Vec<(String, String)> = Vec::with_capacity(messages.len()); + for m in messages { + match m.role { + Role::User => { + let text = m.content.extract_text(); + out.push(("human".to_string(), text)); + } + Role::Assistant => { + let text = m.content.extract_text(); + if !text.is_empty() { + out.push(("gpt".to_string(), text)); + } + if let Some(calls) = &m.tool_calls { + for call in calls { + let payload = serde_json::json!({ + "name": call.function.name, + "arguments": call.function.arguments, + }); + out.push(("function_call".to_string(), payload.to_string())); + } + } + } + Role::Tool => { + let text = m.content.extract_text(); + out.push(("observation".to_string(), text)); + } + Role::System | Role::Developer => {} + } + } + out +} + +// --------------------------------------------------------------------------- +// Quality scoring (mirrors `_assess_quality` in the reference) +// --------------------------------------------------------------------------- + +fn assess_quality( + interaction: &InteractionSignals, + execution: &ExecutionSignals, + environment: &EnvironmentSignals, + user_turns: usize, +) -> (InteractionQuality, f32) { + // Critical: explicit escalation/quit OR severe disengagement OR severe stagnation. 
+ let has_escalation_or_quit = interaction.disengagement.signals.iter().any(|s| { + matches!( + s.signal_type, + SignalType::DisengagementEscalation | SignalType::DisengagementQuit + ) + }); + if (interaction.disengagement.count > 0 && has_escalation_or_quit) + || interaction.disengagement.severity >= 3 + || interaction.stagnation.severity >= 3 + { + return (InteractionQuality::Severe, 0.0); + } + + let mut score: f32 = 50.0; + + if interaction.satisfaction.count > 0 { + let confidence = match interaction.satisfaction.count { + 1 => 0.6, + 2 => 0.8, + _ => 0.95, + }; + score += 20.0 * confidence; + } + + if interaction.disengagement.count > 0 { + score -= interaction.disengagement.severity as f32 * 10.0; + } + if interaction.misalignment.severity > 0 && interaction.misalignment_ratio(user_turns) > 0.3 { + score -= 15.0; + } + if interaction.stagnation.count > 2 { + score -= interaction.stagnation.severity as f32 * 8.0; + } + + if execution.failure.count > 0 { + score -= execution.failure.count as f32 * 8.0; + } + if execution.loops.count > 0 { + score -= execution.loops.count as f32 * 5.0; + } + if environment.exhaustion.count > 0 { + score -= environment.exhaustion.count as f32 * 3.0; + } + + score = score.clamp(0.0, 100.0); + + let quality = if score >= 75.0 { + InteractionQuality::Excellent + } else if score >= 60.0 { + InteractionQuality::Good + } else if score >= 40.0 { + InteractionQuality::Neutral + } else if score >= 25.0 { + InteractionQuality::Poor + } else { + InteractionQuality::Severe + }; + (quality, score) +} + +/// Render the per-conversation summary string. 
+/// +/// Output is structurally grouped by the paper taxonomy so a reader can see +/// at a glance which layer fired: +/// +/// ```text +/// Overall Quality: severe | Turns: 7 (efficiency: 71.4%) +/// | Interaction — misalignment: 2 (sev 1), stagnation: 0, disengagement: 2 (sev 1), satisfaction: 0 +/// | Execution — failure: 0, loops: 0 +/// | Environment — exhaustion: 0 +/// | High misalignment rate: 50.0% of user turns +/// | Escalation requested: 1 +/// ``` +/// +/// Layer headers are always present (even when their counts are all zero) so +/// the taxonomy is visible by inspection. Quality-driving callouts — +/// "high misalignment rate", "looping detected", "escalation requested" — +/// are appended after the layer summary as a separate "alerts" tail. +fn generate_summary( + turn_metrics: &TurnMetrics, + interaction: &InteractionSignals, + execution: &ExecutionSignals, + environment: &EnvironmentSignals, + quality: InteractionQuality, +) -> String { + let mut parts: Vec = Vec::new(); + parts.push(format!("Overall Quality: {}", quality.as_str())); + parts.push(format!( + "Turns: {} (efficiency: {:.1}%)", + turn_metrics.total_turns, + turn_metrics.efficiency_score * 100.0 + )); + + parts.push(format!( + "Interaction \u{2014} {}, {}, {}, {}", + fmt_group("misalignment", &interaction.misalignment), + fmt_group("stagnation", &interaction.stagnation), + fmt_group("disengagement", &interaction.disengagement), + fmt_group("satisfaction", &interaction.satisfaction), + )); + parts.push(format!( + "Execution \u{2014} {}, {}", + fmt_group("failure", &execution.failure), + fmt_group("loops", &execution.loops), + )); + parts.push(format!( + "Environment \u{2014} {}", + fmt_group("exhaustion", &environment.exhaustion), + )); + + if interaction.misalignment.count > 0 { + let misalignment_ratio = interaction.misalignment_ratio(turn_metrics.user_turns); + if misalignment_ratio > 0.3 { + parts.push(format!( + "High misalignment rate: {:.1}% of user turns", + misalignment_ratio 
* 100.0 + )); + } + } + if interaction.stagnation.count > 2 { + parts.push(format!( + "Looping detected: {} repetitions", + interaction.stagnation.count + )); + } + let escalation_count = interaction + .disengagement + .signals + .iter() + .filter(|s| matches!(s.signal_type, SignalType::DisengagementEscalation)) + .count(); + if escalation_count > 0 { + parts.push(format!("Escalation requested: {}", escalation_count)); + } + + parts.join(" | ") +} + +/// Render `": (sev )"`, dropping the severity suffix +/// when the count is zero (keeps the summary readable for clean conversations). +fn fmt_group(name: &str, group: &super::schemas::SignalGroup) -> String { + if group.count == 0 { + format!("{}: 0", name) + } else { + format!("{}: {} (sev {})", name, group.count, group.severity) + } +} #[cfg(test)] mod tests { use super::*; - use hermesllm::apis::openai::MessageContent; - use hermesllm::transforms::lib::ExtractText; - use std::time::Instant; + use hermesllm::apis::openai::{Message, MessageContent, Role}; + #[allow(unused_imports)] + use hermesllm::transforms::ExtractText; - fn create_message(role: Role, content: &str) -> Message { + fn user(t: &str) -> Message { Message { - role, - content: Some(MessageContent::Text(content.to_string())), + role: Role::User, + content: Some(MessageContent::Text(t.to_string())), + name: None, + tool_calls: None, + tool_call_id: None, + } + } + fn assistant(t: &str) -> Message { + Message { + role: Role::Assistant, + content: Some(MessageContent::Text(t.to_string())), name: None, tool_calls: None, tool_call_id: None, } } - // ======================================================================== - // Tests for New Similarity Methods - // ======================================================================== - #[test] - fn test_char_ngram_similarity_exact_match() { - let msg = NormalizedMessage::from_text("thank you very much"); - let similarity = msg.char_ngram_similarity("thank you very much"); - assert!( - similarity > 0.95, - 
"Exact match should have very high similarity" - ); + fn report_quality_neutral_for_short_clean_chat() { + let msgs = vec![ + user("Hello, can you help me with a question?"), + assistant("Of course, what's your question?"), + user("How does X work?"), + assistant("X works by ..."), + ]; + let r = SignalAnalyzer::default().analyze_openai(&msgs); + assert!(matches!( + r.overall_quality, + InteractionQuality::Neutral | InteractionQuality::Good | InteractionQuality::Excellent + )); + assert!(r.summary.starts_with("Overall Quality:")); } #[test] - fn test_char_ngram_similarity_typo() { - let msg = NormalizedMessage::from_text("thank you very much"); - // Common typo: "thnks" instead of "thanks" - let similarity = msg.char_ngram_similarity("thnks you very much"); - assert!( - similarity > 0.50, - "Should handle single-character typo with decent similarity: {}", - similarity - ); - } - - #[test] - fn test_char_ngram_similarity_small_edit() { - let msg = NormalizedMessage::from_text("this doesn't work"); - let similarity = msg.char_ngram_similarity("this doesnt work"); - assert!( - similarity > 0.70, - "Should handle punctuation removal gracefully: {}", - similarity - ); - } - - #[test] - fn test_char_ngram_similarity_word_insertion() { - let msg = NormalizedMessage::from_text("i don't understand"); - let similarity = msg.char_ngram_similarity("i really don't understand"); - assert!( - similarity > 0.40, - "Should be robust to word insertions: {}", - similarity - ); - } - - #[test] - fn test_token_cosine_similarity_exact_match() { - let msg = NormalizedMessage::from_text("this is not helpful"); - let similarity = msg.token_cosine_similarity("this is not helpful"); - assert!( - (similarity - 1.0).abs() < 0.01, - "Exact match should have cosine similarity of 1.0" - ); - } - - #[test] - fn test_token_cosine_similarity_word_order() { - let msg = NormalizedMessage::from_text("not helpful at all"); - let similarity = msg.token_cosine_similarity("helpful not at all"); - assert!( 
- similarity > 0.95, - "Should be robust to word order changes: {}", - similarity - ); - } - - #[test] - fn test_token_cosine_similarity_frequency() { - let msg = NormalizedMessage::from_text("help help help please"); - let similarity = msg.token_cosine_similarity("help please"); - assert!( - similarity > 0.7 && similarity < 1.0, - "Should account for frequency differences: {}", - similarity - ); - } - - #[test] - fn test_token_cosine_similarity_long_message_with_context() { - let msg = NormalizedMessage::from_text( - "I've been trying to set up my account for the past hour \ - and the verification email never arrived. I checked my spam folder \ - and still nothing. This is really frustrating and not helpful at all.", - ); - let similarity = msg.token_cosine_similarity("not helpful"); - assert!( - similarity > 0.15 && similarity < 0.7, - "Should detect pattern in long message with lower but non-zero similarity: {}", - similarity - ); - } - - #[test] - fn test_layered_matching_exact_hit() { - let msg = NormalizedMessage::from_text("thank you so much"); - assert!( - msg.layered_contains_phrase("thank you", 0.50, 0.60), - "Should match exact phrase in Layer 0" - ); - } - - #[test] - fn test_layered_matching_typo_hit() { - // Test that shows layered matching is more robust than exact matching alone - let msg = NormalizedMessage::from_text("it doesnt work for me"); - - // "doesnt work" should match "doesn't work" via character ngrams (high overlap) - assert!( - msg.layered_contains_phrase("doesn't work", 0.50, 0.60), - "Should match 'doesnt work' to 'doesn't work' via character ngrams" - ); - } - - #[test] - fn test_layered_matching_word_order_hit() { - let msg = NormalizedMessage::from_text("helpful not very"); - assert!( - msg.layered_contains_phrase("not helpful", 0.50, 0.60), - "Should match reordered words via token cosine in Layer 2" - ); - } - - #[test] - fn test_layered_matching_long_message_with_pattern() { - let msg = NormalizedMessage::from_text( - "I've 
tried everything and followed all the instructions \ - but this is not helpful at all and I'm getting frustrated", - ); - assert!( - msg.layered_contains_phrase("not helpful", 0.50, 0.60), - "Should detect pattern buried in long message" - ); - } - - #[test] - fn test_layered_matching_no_match() { - let msg = NormalizedMessage::from_text("everything is working perfectly"); - assert!( - !msg.layered_contains_phrase("not helpful", 0.50, 0.60), - "Should not match completely different content" - ); - } - - #[test] - fn test_char_ngram_vs_token_cosine_tradeoffs() { - // Character ngrams handle character-level changes well - let msg1 = NormalizedMessage::from_text("this doesnt work"); - let char_sim1 = msg1.char_ngram_similarity("this doesn't work"); - assert!( - char_sim1 > 0.70, - "Character ngrams should handle punctuation: {}", - char_sim1 - ); - - // Token cosine is better for word order and long messages with semantic overlap - let msg2 = - NormalizedMessage::from_text("I really appreciate all your help with this issue today"); - let token_sim2 = msg2.token_cosine_similarity("thank you for help"); - assert!( - token_sim2 > 0.15, - "Token cosine should detect semantic overlap: {}", - token_sim2 - ); - } - - // ======================================================================== - // Existing Tests - // ======================================================================== - - fn preprocess_messages(messages: &[Message]) -> Vec<(usize, Role, NormalizedMessage)> { - messages + fn report_severe_when_user_escalates() { + let msgs = vec![ + user("This isn't helpful at all"), + assistant("I'm sorry, can you tell me more?"), + user("Get me a human, this is useless"), + ]; + let r = SignalAnalyzer::default().analyze_openai(&msgs); + assert_eq!(r.overall_quality, InteractionQuality::Severe); + assert!(r + .interaction + .disengagement + .signals .iter() - .enumerate() - .map(|(i, msg)| { - let text = msg.content.extract_text(); - (i, msg.role.clone(), 
NormalizedMessage::from_text(&text)) - }) - .collect() + .any(|s| matches!(s.signal_type, SignalType::DisengagementEscalation))); } #[test] - fn test_turn_count_efficient() { - let start = Instant::now(); - let analyzer = TextBasedSignalAnalyzer::new(); - let messages = vec![ - create_message(Role::User, "Hello"), - create_message(Role::Assistant, "Hi! How can I help?"), - create_message(Role::User, "Thanks!"), + fn report_excellent_when_user_satisfied() { + let msgs = vec![ + user("Can you summarize this report?"), + assistant("Here's a summary: ..."), + user("That's perfect, exactly what I needed, you're awesome!"), ]; - - let signal = analyzer.analyze_turn_count(&messages); - assert_eq!(signal.total_turns, 3); - assert_eq!(signal.user_turns, 2); - assert_eq!(signal.assistant_turns, 1); - assert!(!signal.is_concerning); - assert!(!signal.is_excessive); - assert!(signal.efficiency_score > 0.9); - println!("test_turn_count_efficient took: {:?}", start.elapsed()); + let r = SignalAnalyzer::default().analyze_openai(&msgs); + assert!(r.interaction.satisfaction.count > 0); + assert!(matches!( + r.overall_quality, + InteractionQuality::Good | InteractionQuality::Excellent + )); } #[test] - fn test_turn_count_excessive() { - let start = Instant::now(); - let analyzer = TextBasedSignalAnalyzer::new(); - let mut messages = Vec::new(); - for i in 0..15 { - messages.push(create_message( - if i % 2 == 0 { - Role::User - } else { - Role::Assistant - }, - &format!("Message {}", i), - )); - } - - let signal = analyzer.analyze_turn_count(&messages); - assert_eq!(signal.total_turns, 15); - assert!(signal.is_concerning); - assert!(signal.is_excessive); - assert!(signal.efficiency_score < 0.5); - println!("test_turn_count_excessive took: {:?}", start.elapsed()); - } - - #[test] - fn test_follow_up_detection() { - let start = Instant::now(); - let analyzer = TextBasedSignalAnalyzer::new(); - let messages = vec![ - create_message(Role::User, "Show me restaurants"), - 
create_message(Role::Assistant, "Here are some options"), - create_message(Role::User, "No, I meant Italian restaurants"), - create_message(Role::Assistant, "Here are Italian restaurants"), + fn repro_gratitude_does_not_trigger_misalignment() { + let msgs = vec![ + user("What is the weather in Istanbul?"), + assistant("Istanbul is 14C and partly cloudy."), + user("That worked, exactly what I needed. Thanks, that is perfect!"), ]; - - let normalized_messages = preprocess_messages(&messages); - let signal = analyzer.analyze_follow_up(&normalized_messages); - assert_eq!(signal.repair_count, 1); - assert!(signal.repair_ratio > 0.0); - println!("test_follow_up_detection took: {:?}", start.elapsed()); - } - - #[test] - fn test_follow_up_does_not_panic_with_filtered_messages() { - // Regression test: the preprocessing pipeline filters out messages - // without extractable text (tool calls, tool results, empty content). - // The stored tuple index `i` is the ORIGINAL-conversation index, so - // once anything is filtered out, `i` no longer matches the position - // inside `normalized_messages`. The old code used `*i` to index into - // `normalized_messages`, which panicked with "index out of bounds" - // when a user message appeared after filtered entries. - let analyzer = TextBasedSignalAnalyzer::new(); - let messages = vec![ - Message { - role: Role::User, - content: Some(hermesllm::apis::openai::MessageContent::Text( - "first question".to_string(), - )), - name: None, - tool_calls: None, - tool_call_id: None, - }, - // Assistant message with no text content (e.g. tool call) — filtered out. - Message { - role: Role::Assistant, - content: None, - name: None, - tool_calls: None, - tool_call_id: None, - }, - // Tool-role message with no extractable text — filtered out. 
- Message { - role: Role::Tool, - content: None, - name: None, - tool_calls: None, - tool_call_id: None, - }, - Message { - role: Role::Assistant, - content: Some(hermesllm::apis::openai::MessageContent::Text( - "some answer".to_string(), - )), - name: None, - tool_calls: None, - tool_call_id: None, - }, - // Rephrased user turn — original index 4, but after filtering - // only 3 messages remain in `normalized_messages` before it. - Message { - role: Role::User, - content: Some(hermesllm::apis::openai::MessageContent::Text( - "first question please".to_string(), - )), - name: None, - tool_calls: None, - tool_call_id: None, - }, - ]; - - // Must not panic — exercises the full analyze pipeline. - let _report = analyzer.analyze(&messages); - } - - #[test] - fn test_frustration_detection() { - let start = Instant::now(); - let analyzer = TextBasedSignalAnalyzer::new(); - let messages = vec![ - create_message(Role::User, "THIS IS RIDICULOUS!!!"), - create_message(Role::Assistant, "I apologize for the frustration"), - create_message(Role::User, "This doesn't work at all"), - ]; - - let normalized_messages = preprocess_messages(&messages); - let signal = analyzer.analyze_frustration(&normalized_messages); - assert!(signal.has_frustration); - assert!(signal.frustration_count >= 2); - assert!(signal.severity > 0); - println!("test_frustration_detection took: {:?}", start.elapsed()); - } - - #[test] - fn test_positive_feedback_detection() { - let start = Instant::now(); - let analyzer = TextBasedSignalAnalyzer::new(); - let messages = vec![ - create_message(Role::User, "Can you help me?"), - create_message(Role::Assistant, "Sure!"), - create_message(Role::User, "Thank you! 
That's exactly what I needed."), - ]; - - let normalized_messages = preprocess_messages(&messages); - let signal = analyzer.analyze_positive_feedback(&normalized_messages); - assert!(signal.has_positive_feedback); - assert!(signal.positive_count >= 1); - assert!(signal.confidence > 0.5); - println!( - "test_positive_feedback_detection took: {:?}", - start.elapsed() - ); - } - - #[test] - fn test_escalation_detection() { - let start = Instant::now(); - let analyzer = TextBasedSignalAnalyzer::new(); - let messages = vec![ - create_message(Role::User, "This isn't working"), - create_message(Role::Assistant, "Let me help"), - create_message(Role::User, "I need to speak to a human agent"), - ]; - - let normalized_messages = preprocess_messages(&messages); - let signal = analyzer.analyze_escalation(&normalized_messages); - assert!(signal.escalation_requested); - assert_eq!(signal.escalation_count, 1); - println!("test_escalation_detection took: {:?}", start.elapsed()); - } - - #[test] - fn test_repetition_detection() { - let start = Instant::now(); - let analyzer = TextBasedSignalAnalyzer::new(); - let messages = vec![ - create_message(Role::User, "What's the weather?"), - create_message( - Role::Assistant, - "I can help you with the weather information", - ), - create_message(Role::User, "Show me the forecast"), - create_message(Role::Assistant, "Sure, I can help you with the forecast"), - create_message(Role::User, "Stop repeating yourself"), - ]; - - let normalized_messages = preprocess_messages(&messages); - let signal = analyzer.analyze_repetition(&normalized_messages); - - for rep in &signal.repetitions { - println!( - " - Messages {:?}, similarity: {:.3}, type: {:?}", - rep.message_indices, rep.similarity, rep.repetition_type + let r = SignalAnalyzer::default().analyze_openai(&msgs); + for s in &r.interaction.misalignment.signals { + eprintln!( + "misalignment fired: type={:?} idx={} snippet={:?} meta={:?}", + s.signal_type, s.message_index, s.snippet, s.metadata 
); } - - assert!(signal.repetition_count > 0, - "Should detect the subtle repetition between 'I can help you with the weather information' \ - and 'Sure, I can help you with the forecast'"); - println!("test_repetition_detection took: {:?}", start.elapsed()); - } - - #[test] - fn test_full_analysis_excellent() { - let start = Instant::now(); - let analyzer = TextBasedSignalAnalyzer::new(); - let messages = vec![ - create_message(Role::User, "I need to book a flight"), - create_message(Role::Assistant, "Sure! Where would you like to go?"), - create_message(Role::User, "New York"), - create_message(Role::Assistant, "Great! I found several options."), - create_message(Role::User, "Perfect!"), - ]; - - let report = analyzer.analyze(&messages); - assert!(matches!( - report.overall_quality, - InteractionQuality::Excellent | InteractionQuality::Good - )); - assert!(report.positive_feedback.has_positive_feedback); - assert!(!report.frustration.has_frustration); - println!("test_full_analysis_excellent took: {:?}", start.elapsed()); - } - - #[test] - fn test_full_analysis_poor() { - let start = Instant::now(); - let analyzer = TextBasedSignalAnalyzer::new(); - let messages = vec![ - create_message(Role::User, "Help me"), - create_message(Role::Assistant, "How can I assist?"), - create_message(Role::User, "No, I meant something else"), - create_message(Role::Assistant, "What do you need?"), - create_message(Role::User, "THIS DOESN'T WORK!!!"), - create_message(Role::Assistant, "I apologize"), - create_message(Role::User, "Let me speak to a human"), - ]; - - let report = analyzer.analyze(&messages); - assert!(matches!( - report.overall_quality, - InteractionQuality::Poor | InteractionQuality::Severe - )); - assert!(report.frustration.has_frustration); - assert!(report.escalation.escalation_requested); - println!("test_full_analysis_poor took: {:?}", start.elapsed()); - } - - #[test] - fn test_fuzzy_matching_gratitude() { - let start = Instant::now(); - let analyzer = 
TextBasedSignalAnalyzer::new(); - let messages = vec![ - create_message(Role::User, "Can you help me?"), - create_message(Role::Assistant, "Sure!"), - create_message(Role::User, "thnaks! that's exactly what i needed."), - ]; - - let normalized_messages = preprocess_messages(&messages); - let signal = analyzer.analyze_positive_feedback(&normalized_messages); - assert!(signal.has_positive_feedback); - assert!(signal.positive_count >= 1); - println!("test_fuzzy_matching_gratitude took: {:?}", start.elapsed()); - } - - #[test] - fn test_fuzzy_matching_escalation() { - let start = Instant::now(); - let analyzer = TextBasedSignalAnalyzer::new(); - let messages = vec![ - create_message(Role::User, "This isn't working"), - create_message(Role::Assistant, "Let me help"), - create_message(Role::User, "i need to speek to a human agnet"), - ]; - - let normalized_messages = preprocess_messages(&messages); - let signal = analyzer.analyze_escalation(&normalized_messages); - assert!(signal.escalation_requested); - assert_eq!(signal.escalation_count, 1); - println!("test_fuzzy_matching_escalation took: {:?}", start.elapsed()); - } - - #[test] - fn test_fuzzy_matching_repair() { - let start = Instant::now(); - let analyzer = TextBasedSignalAnalyzer::new(); - let messages = vec![ - create_message(Role::User, "Show me restaurants"), - create_message(Role::Assistant, "Here are some options"), - create_message(Role::User, "no i ment Italian restaurants"), - create_message(Role::Assistant, "Here are Italian restaurants"), - ]; - - let normalized_messages = preprocess_messages(&messages); - let signal = analyzer.analyze_follow_up(&normalized_messages); - assert!(signal.repair_count >= 1); - println!("test_fuzzy_matching_repair took: {:?}", start.elapsed()); - } - - #[test] - fn test_fuzzy_matching_complaint() { - let start = Instant::now(); - let analyzer = TextBasedSignalAnalyzer::new(); - // Use a complaint that should match - "doesnt work" is close enough to "doesn't work" - let 
messages = vec![ - create_message(Role::User, "this doesnt work at all"), // Common typo: missing apostrophe - create_message(Role::Assistant, "I apologize"), - ]; - - let normalized_messages = preprocess_messages(&messages); - let signal = analyzer.analyze_frustration(&normalized_messages); - - // The layered matching should catch this via character ngrams or token cosine - // "doesnt work" has high character-level similarity to "doesn't work" - assert!( - signal.has_frustration, - "Should detect frustration from complaint pattern" - ); - assert!(signal.frustration_count >= 1); - println!("test_fuzzy_matching_complaint took: {:?}", start.elapsed()); - } - - #[test] - fn test_exact_match_priority() { - let start = Instant::now(); - let analyzer = TextBasedSignalAnalyzer::new(); - let messages = vec![create_message(Role::User, "thank you so much")]; - - let normalized_messages = preprocess_messages(&messages); - let signal = analyzer.analyze_positive_feedback(&normalized_messages); - assert!(signal.has_positive_feedback); - // Should detect exact match, not fuzzy - assert!(signal.indicators[0].snippet.contains("thank you")); - assert!(!signal.indicators[0].snippet.contains("fuzzy")); - println!("test_exact_match_priority took: {:?}", start.elapsed()); - } - - // ======================================================================== - // Anti-Tests: Verify fixes stay fixed - // ======================================================================== - - #[test] - fn test_hello_not_profanity() { - let analyzer = TextBasedSignalAnalyzer::new(); - let messages = vec![create_message(Role::User, "hello there")]; - - let normalized_messages = preprocess_messages(&messages); - let signal = analyzer.analyze_frustration(&normalized_messages); - assert!( - !signal.has_frustration, - "\"hello\" should not trigger profanity detection" - ); - } - - #[test] - fn test_prepare_not_escalation() { - let analyzer = TextBasedSignalAnalyzer::new(); - let messages = vec![create_message( 
- Role::User, - "Can you help me prepare for the meeting?", - )]; - - let normalized_messages = preprocess_messages(&messages); - let signal = analyzer.analyze_escalation(&normalized_messages); - assert!( - !signal.escalation_requested, - "\"prepare\" should not trigger escalation (rep pattern removed)" - ); - } - - #[test] - fn test_unicode_apostrophe_confusion() { - let analyzer = TextBasedSignalAnalyzer::new(); - let messages = vec![ - create_message(Role::User, "I'm confused"), // Unicode apostrophe - ]; - - let normalized_messages = preprocess_messages(&messages); - let signal = analyzer.analyze_frustration(&normalized_messages); - assert!( - signal.has_frustration, - "Unicode apostrophe 'I'm confused' should trigger confusion" - ); - } - - #[test] - fn test_unicode_quotes_work() { - let analyzer = TextBasedSignalAnalyzer::new(); - let messages = vec![create_message( - Role::User, - "\u{201C}doesn\u{2019}t work\u{201D} with unicode quotes", - )]; - - let normalized_messages = preprocess_messages(&messages); - let signal = analyzer.analyze_frustration(&normalized_messages); - assert!( - signal.has_frustration, - "Unicode quotes should be normalized and match patterns" - ); - } - - #[test] - fn test_absolute_not_profanity() { - let analyzer = TextBasedSignalAnalyzer::new(); - let messages = vec![create_message(Role::User, "That's absolute nonsense")]; - - let normalized_messages = preprocess_messages(&messages); - let signal = analyzer.analyze_frustration(&normalized_messages); - // Should match on "nonsense" logic, not on "bs" substring - let has_bs_match = signal - .indicators - .iter() - .any(|ind| ind.snippet.contains("bs")); - assert!( - !has_bs_match, - "\"absolute\" should not trigger 'bs' profanity match" - ); - } - - #[test] - fn test_stopwords_not_rephrase() { - let analyzer = TextBasedSignalAnalyzer::new(); - let messages = vec![ - create_message(Role::User, "Help me with X"), - create_message(Role::Assistant, "Sure"), - create_message(Role::User, 
"Help me with Y"), - ]; - - let normalized_messages = preprocess_messages(&messages); - let signal = analyzer.analyze_follow_up(&normalized_messages); - // Should not detect as rephrase since only stopwords overlap assert_eq!( - signal.repair_count, 0, - "Messages with only stopword overlap should not be rephrases" + r.interaction.misalignment.count, 0, + "a pure gratitude message should not trigger repair/misalignment" ); + assert!(r.interaction.satisfaction.count > 0); } #[test] - fn test_frustrated_user_with_legitimate_repair() { - let start = Instant::now(); - let analyzer = TextBasedSignalAnalyzer::new(); - - use hermesllm::apis::openai::{FunctionCall, ToolCall}; - - // Helper to create a message with tool calls - let create_assistant_with_tools = - |content: &str, tool_id: &str, tool_name: &str, args: &str| -> Message { - Message { - role: Role::Assistant, - content: Some(MessageContent::Text(content.to_string())), - name: None, - tool_calls: Some(vec![ToolCall { - id: tool_id.to_string(), - call_type: "function".to_string(), - function: FunctionCall { - name: tool_name.to_string(), - arguments: args.to_string(), - }, - }]), - tool_call_id: None, - } - }; - - // Helper to create a tool response message - let create_tool_message = |tool_call_id: &str, content: &str| -> Message { - Message { - role: Role::Tool, - content: Some(MessageContent::Text(content.to_string())), - name: None, - tool_calls: None, - tool_call_id: Some(tool_call_id.to_string()), - } - }; - - // Scenario: User DOES mention New York in first message, making "I already told you" legitimate - let messages = vec![ - create_message( - Role::User, - "I need to book a flight from New York to Paris for December 20th", - ), - create_assistant_with_tools( - "I'll help you search for flights to Paris.", - "call_123", - "search_flights", - r#"{"origin": "NYC", "destination": "Paris", "date": "2025-12-20"}"#, - ), - create_tool_message("call_123", r#"{"flights": []}"#), - create_message( - 
Role::Assistant, - "I couldn't find any flights. Could you provide your departure city?", - ), - create_message(Role::User, "I already told you, from New York!"), - create_assistant_with_tools( - "Let me try again.", - "call_456", - "search_flights", - r#"{"origin": "New York", "destination": "Paris", "date": "2025-12-20"}"#, - ), - create_tool_message("call_456", r#"{"flights": []}"#), - create_message( - Role::Assistant, - "I'm still not finding results. Let me check the system.", - ), - create_message( - Role::User, - "THIS IS RIDICULOUS!!! The tool doesn't work at all. Why do you keep calling it?", - ), - create_message( - Role::Assistant, - "I sincerely apologize for the frustration with the search tool.", - ), - create_message( - Role::User, - "Forget it. I need to speak to a human agent. This is a waste of time.", - ), + fn summary_groups_signals_by_taxonomy() { + // Even on a clean conversation the summary should expose the three + // layer headers so the taxonomy is visible. + let msgs = vec![ + user("Hello"), + assistant("Hi! 
How can I help?"), + user("What's 2 + 2?"), + assistant("4"), ]; - - let report = analyzer.analyze(&messages); - - // Tool messages should be filtered out, so we should only analyze text messages - // That's 4 user messages + 5 assistant text messages = 9 turns - assert_eq!( - report.turn_count.total_turns, 9, - "Should count 9 text messages (tool messages filtered out)" + let r = SignalAnalyzer::default().analyze_openai(&msgs); + assert!( + r.summary.contains("Interaction \u{2014}"), + "missing Interaction header in: {}", + r.summary ); assert!( - report.turn_count.is_concerning, - "Should flag concerning turn count" - ); - - // Should detect frustration (all caps, complaints) - assert!( - report.frustration.has_frustration, - "Should detect frustration" + r.summary.contains("Execution \u{2014}"), + "missing Execution header in: {}", + r.summary ); assert!( - report.frustration.frustration_count >= 2, - "Should detect multiple frustration indicators" - ); - assert!( - report.frustration.severity >= 2, - "Should have moderate or higher frustration severity" - ); - - // Should detect escalation request - assert!( - report.escalation.escalation_requested, - "Should detect escalation to human agent" - ); - assert!( - report.escalation.escalation_count >= 1, - "Should detect at least one escalation" - ); - - // Overall quality should be Poor or Severe - assert!( - matches!( - report.overall_quality, - InteractionQuality::Poor | InteractionQuality::Severe - ), - "Quality should be Poor or Severe, got {:?}", - report.overall_quality - ); - - println!( - "test_frustrated_user_with_legitimate_repair took: {:?}", - start.elapsed() + r.summary.contains("Environment \u{2014}"), + "missing Environment header in: {}", + r.summary ); + assert!(r.summary.contains("misalignment: 0")); + assert!(r.summary.contains("loops: 0")); + assert!(r.summary.contains("exhaustion: 0")); } #[test] - fn test_frustrated_user_false_claim() { - let start = Instant::now(); - let analyzer = 
TextBasedSignalAnalyzer::new(); - - use hermesllm::apis::openai::{FunctionCall, ToolCall}; - - // Helper to create a message with tool calls - let create_assistant_with_tools = - |content: &str, tool_id: &str, tool_name: &str, args: &str| -> Message { - Message { - role: Role::Assistant, - content: Some(MessageContent::Text(content.to_string())), - name: None, - tool_calls: Some(vec![ToolCall { - id: tool_id.to_string(), - call_type: "function".to_string(), - function: FunctionCall { - name: tool_name.to_string(), - arguments: args.to_string(), - }, - }]), - tool_call_id: None, - } - }; - - // Helper to create a tool response message - let create_tool_message = |tool_call_id: &str, content: &str| -> Message { - Message { - role: Role::Tool, - content: Some(MessageContent::Text(content.to_string())), - name: None, - tool_calls: None, - tool_call_id: Some(tool_call_id.to_string()), - } - }; - - // Scenario: User NEVER mentions New York in first message but claims "I already told you" - // This represents realistic frustrated user behavior - exaggeration/misremembering - let messages = vec![ - create_message( - Role::User, - "I need to book a flight to Paris for December 20th", - ), - create_assistant_with_tools( - "I'll help you search for flights to Paris.", - "call_123", - "search_flights", - r#"{"destination": "Paris", "date": "2025-12-20"}"#, - ), - create_tool_message("call_123", r#"{"error": "origin required"}"#), - create_message( - Role::Assistant, - "I couldn't find any flights. Could you provide your departure city?", - ), - create_message(Role::User, "I already told you, from New York!"), // False claim - never mentioned it - create_assistant_with_tools( - "Let me try again.", - "call_456", - "search_flights", - r#"{"origin": "New York", "destination": "Paris", "date": "2025-12-20"}"#, - ), - create_tool_message("call_456", r#"{"flights": []}"#), - create_message( - Role::Assistant, - "I'm still not finding results. 
Let me check the system.", - ), - create_message( - Role::User, - "THIS IS RIDICULOUS!!! The tool doesn't work at all. Why do you keep calling it?", - ), - create_message( - Role::Assistant, - "I sincerely apologize for the frustration with the search tool.", - ), - create_message( - Role::User, - "Forget it. I need to speak to a human agent. This is a waste of time.", - ), + fn summary_includes_severity_when_signals_fire() { + let msgs = vec![ + user("This isn't helpful at all"), + assistant("I'm sorry, can you tell me more?"), + user("Get me a human, this is useless"), ]; - - let report = analyzer.analyze(&messages); - - // Tool messages should be filtered out, so we should only analyze text messages - // That's 4 user messages + 5 assistant text messages = 9 turns - assert_eq!( - report.turn_count.total_turns, 9, - "Should count 9 text messages (tool messages filtered out)" + let r = SignalAnalyzer::default().analyze_openai(&msgs); + // Disengagement fires; should render with `(sev N)` and the + // escalation-requested alert tail. 
+ assert!( + r.summary.contains("disengagement:") && r.summary.contains("(sev "), + "expected severity rendered for disengagement: {}", + r.summary ); assert!( - report.turn_count.is_concerning, - "Should flag concerning turn count" - ); - - // Should detect frustration (all caps, complaints, false claims) - assert!( - report.frustration.has_frustration, - "Should detect frustration" - ); - assert!( - report.frustration.frustration_count >= 2, - "Should detect multiple frustration indicators" - ); - assert!( - report.frustration.severity >= 2, - "Should have moderate or higher frustration severity" - ); - - // Should detect escalation request - assert!( - report.escalation.escalation_requested, - "Should detect escalation to human agent" - ); - assert!( - report.escalation.escalation_count >= 1, - "Should detect at least one escalation" - ); - - // Note: May detect false positive "positive feedback" due to fuzzy matching - // e.g., "I already told YOU" matches "you rock", "THIS is RIDICULOUS" matches "this helps" - // However, the overall quality should still be Poor/Severe due to frustration+escalation - - // Overall quality should be Poor or Severe (frustration + escalation indicates poor interaction) - assert!( - matches!( - report.overall_quality, - InteractionQuality::Poor | InteractionQuality::Severe - ), - "Quality should be Poor or Severe for frustrated user with false claims, got {:?}", - report.overall_quality - ); - - println!( - "test_frustrated_user_false_claim took: {:?}", - start.elapsed() + r.summary.contains("Escalation requested:"), + "expected escalation alert in: {}", + r.summary ); } - // false negative tests #[test] - fn test_dissatisfaction_polite_not_working_for_me() { - let analyzer = TextBasedSignalAnalyzer::new(); - let messages = vec![ - create_message(Role::User, "Thanks, but this still isn't working for me."), // Polite dissatisfaction, e.g., I appreciate it, but this isn't what I was looking for. 
- create_message(Role::Assistant, "Sorry—what error do you see?"), + fn execution_failures_lower_quality() { + let msgs = vec![ShareGptMessage { + from: "human", + value: "do the thing", + }]; + let _ = msgs; + // Build a synthetic ShareGPT input with multiple tool failures. + let convo = vec![ + ShareGptMessage { + from: "human", + value: "create a user", + }, + ShareGptMessage { + from: "function_call", + value: r#"{"name":"create_user","arguments":{"age":"twelve"}}"#, + }, + ShareGptMessage { + from: "observation", + value: "Error: validation failed - expected integer got string", + }, + ShareGptMessage { + from: "function_call", + value: r#"{"name":"create_user","arguments":{}}"#, + }, + ShareGptMessage { + from: "observation", + value: "missing required field: name", + }, ]; - let normalized = preprocess_messages(&messages); - let signal = analyzer.analyze_frustration(&normalized); - assert!( - signal.has_frustration, - "Polite dissatisfaction should be detected" - ); - } - - #[test] - fn test_dissatisfaction_giving_up_without_escalation() { - let analyzer = TextBasedSignalAnalyzer::new(); - let messages = vec![create_message( - Role::User, - "Never mind, I'll figure it out myself.", - )]; - let normalized = preprocess_messages(&messages); - let signal = analyzer.analyze_escalation(&normalized); - assert!( - signal.escalation_requested, - "Giving up should count as escalation/quit intent" - ); - } - - #[test] - fn test_dissatisfaction_same_problem_again() { - let analyzer = TextBasedSignalAnalyzer::new(); - let messages = vec![create_message( - Role::User, - "I'm running into the same issue again.", - )]; - let normalized = preprocess_messages(&messages); - let signal = analyzer.analyze_frustration(&normalized); - assert!( - signal.has_frustration, - "'same issue again' should be detected" - ); - } - - #[test] - fn test_unsatisfied_incomplete() { - let analyzer = TextBasedSignalAnalyzer::new(); - let messages = vec![create_message(Role::User, "This feels 
incomplete.")]; - let normalized = preprocess_messages(&messages); - let signal = analyzer.analyze_frustration(&normalized); - assert!( - signal.has_frustration, - "Should detect 'incomplete' dissatisfaction" - ); - } - - #[test] - fn test_low_mood_overwhelming() { - let analyzer = TextBasedSignalAnalyzer::new(); - let messages = vec![create_message( - Role::User, - "This is overwhelming and I'm not sure what to do.", - )]; - let normalized = preprocess_messages(&messages); - let signal = analyzer.analyze_frustration(&normalized); - assert!(signal.has_frustration, "Should detect overwhelmed language"); - } - - #[test] - fn test_low_mood_exhausted_trying() { - let analyzer = TextBasedSignalAnalyzer::new(); - let messages = vec![create_message( - Role::User, - "I'm exhausted trying to get this working.", - )]; - let normalized = preprocess_messages(&messages); - let signal = analyzer.analyze_frustration(&normalized); - assert!( - signal.has_frustration, - "Should detect exhaustion/struggle language" - ); - } - - #[test] - fn test_common_polite_unresolved_dissatisfaction() { - let analyzer = TextBasedSignalAnalyzer::new(); - let messages = vec![ - create_message(Role::User, "I'm trying to set up SSH keys for GitHub."), - create_message( - Role::Assistant, - "Sure. First generate a key using ssh-keygen.", - ), - create_message(Role::User, "I did that already."), - create_message( - Role::Assistant, - "Then add the key to your GitHub account settings.", - ), - create_message(Role::User, "I've done that too."), - create_message( - Role::Assistant, - "After that, make sure your SSH agent is running.", - ), - create_message( - Role::User, - "Okay, but this still doesn't seem to fix the issue.", - ), - create_message(Role::Assistant, "What error message are you seeing?"), - create_message(Role::User, "It's just not connecting the way I expected."), - ]; - - let report = analyzer.analyze(&messages); - - // This is a common false negative if you only look for caps/profanity. 
- // Desired: detect dissatisfaction/frustration (or at least not rate as Excellent). - assert!( - report.frustration.has_frustration || report.follow_up.repair_count >= 1, - "Should detect polite unresolved dissatisfaction via frustration or follow-up indicators" - ); - - assert!( - !matches!(report.overall_quality, InteractionQuality::Excellent), - "Should not classify unresolved dissatisfaction as Excellent" - ); - } - - #[test] - fn test_common_resigned_giving_up_quietly() { - let analyzer = TextBasedSignalAnalyzer::new(); - let messages = vec![ - create_message( - Role::User, - "Can you explain how to deploy this with Docker?", - ), - create_message( - Role::Assistant, - "You need to write a Dockerfile and build an image.", - ), - create_message(Role::User, "I tried that."), - create_message(Role::Assistant, "Then you can run docker-compose up."), - create_message(Role::User, "I did, but it didn’t really help."), - create_message(Role::Assistant, "What error are you getting?"), - create_message( - Role::User, - "Honestly, never mind. I’ll just try something else.", - ), - ]; - - let report = analyzer.analyze(&messages); - - // Many systems miss "never mind / I'll try something else" if they only look for "human agent". 
- assert!( - report.escalation.escalation_requested || report.frustration.has_frustration, - "Resigned quitting language should trigger escalation or frustration" - ); - - assert!( - matches!( - report.overall_quality, - InteractionQuality::Poor | InteractionQuality::Severe - ) || report.escalation.escalation_requested - || report.frustration.has_frustration, - "Giving up should not be classified as a high-quality interaction" - ); - } - - #[test] - fn test_common_discouraged_overwhelmed_low_mood() { - let analyzer = TextBasedSignalAnalyzer::new(); - let messages = vec![ - create_message(Role::User, "I'm trying to understand backpropagation."), - create_message( - Role::Assistant, - "It's a way to compute gradients efficiently.", - ), - create_message(Role::User, "I’ve read that explanation already."), - create_message(Role::Assistant, "Would you like a mathematical derivation?"), - create_message(Role::User, "Maybe, but I’m still having trouble following."), - create_message(Role::Assistant, "I can walk through a simple example."), - create_message( - Role::User, - "That might help, but honestly this is pretty overwhelming.", - ), - create_message(Role::Assistant, "Let’s slow it down step by step."), - create_message( - Role::User, - "Yeah… I’m just feeling kind of discouraged right now.", - ), - ]; - - let report = analyzer.analyze(&messages); - - // This is negative affect without caps/profanity. Should still count as frustration/negative signal. 
- assert!( - report.frustration.has_frustration, - "Overwhelmed/discouraged language should be detected as negative sentiment/frustration" - ); - - assert!( - !matches!(report.overall_quality, InteractionQuality::Excellent), - "Low-mood discouragement should not be classified as Excellent" - ); - } - - #[test] - fn test_common_misalignment_not_what_i_asked() { - let analyzer = TextBasedSignalAnalyzer::new(); - let messages = vec![ - create_message(Role::User, "How do I optimize this SQL query?"), - create_message( - Role::Assistant, - "You can add indexes to improve performance.", - ), - create_message(Role::User, "I already have indexes."), - create_message(Role::Assistant, "Then you could consider query caching."), - create_message(Role::User, "That’s not really what I was asking about."), - create_message( - Role::Assistant, - "What specifically are you trying to optimize?", - ), - create_message( - Role::User, - "The execution plan — this answer doesn’t address that.", - ), - ]; - - let report = analyzer.analyze(&messages); - - // Misalignment often shows as follow-up repair or frustration. 
- assert!( - report.follow_up.repair_count >= 1 || report.frustration.has_frustration, - "Misalignment ('not what I asked') should trigger repair or frustration signals" - ); - - assert!( - !matches!(report.overall_quality, InteractionQuality::Excellent), - "Misalignment should not be rated as Excellent" - ); - } - - #[test] - fn test_common_false_negative_polite_disappointment_complexity() { - let analyzer = TextBasedSignalAnalyzer::new(); - let messages = vec![ - create_message(Role::User, "Can you help me write a regex for this?"), - create_message(Role::Assistant, "Sure, try this pattern: ^[a-z]+$"), - create_message(Role::User, "I tested it."), - create_message(Role::Assistant, "Did it work?"), - create_message(Role::User, "Not quite — it matches more than it should."), - create_message(Role::Assistant, "You can refine it with a lookahead."), - create_message( - Role::User, - "I see… this is more complicated than I expected.", - ), - ]; - - let report = analyzer.analyze(&messages); - - // Polite disappointment often becomes a false negative. - assert!( - report.frustration.has_frustration || report.follow_up.repair_count >= 1, - "Polite dissatisfaction ('not quite', 'more complicated than expected') should trigger a negative signal" - ); - - assert!( - !matches!(report.overall_quality, InteractionQuality::Excellent), - "Polite disappointment should not be classified as Excellent" - ); - } - - #[test] - fn test_catastrophic_failure_looping_assistant() { - let analyzer = TextBasedSignalAnalyzer::new(); - - // Catastrophic failure: assistant stuck in loop, user increasingly frustrated - let messages = vec![ - // Turn 1: Initial request - create_message(Role::User, "I need to cancel my subscription"), - create_message(Role::Assistant, "I can help you with account management. 
What would you like to do today?"), - - // Turn 3: First repair - repetition clarification - create_message(Role::User, "I just said I want to cancel my subscription"), - // Turn 4: Exact repetition from assistant - create_message(Role::Assistant, "I can help you with account management. What would you like to do today?"), - - // Turn 5: Repair (negation) + frustration (ALL CAPS) - create_message(Role::User, "No, I meant CANCEL. My. Subscription."), - // Turn 6: Near-duplicate repetition - create_message(Role::Assistant, "I understand you need help with your account. Let me assist you with that."), - - // Turn 7: Severe frustration (complaint + excessive punctuation + ALL CAPS) - create_message(Role::User, "This is not helpful!! I want to CANCEL my subscription!!!"), - // Turn 8: Exact repetition again (looping) - create_message(Role::Assistant, "I can help you with account management. What would you like to do today?"), - - // Turn 9: Repair (explicit rephrase) + frustration (complaint) - create_message(Role::User, "This doesnt work. Let me rephrase: I need to terminate my account and stop being charged"), - // Turn 10: Near-duplicate repetition (looping continues) - create_message(Role::Assistant, "I understand you need help with your account. Let me assist you with that."), - - // Turn 11: Escalation (multiple types) + frustration - create_message(Role::User, "I give up. Can I speak to a real person? 
This is a waste of time"), - ]; - - let report = analyzer.analyze(&messages); - - // Validate turn count - assert_eq!( - report.turn_count.total_turns, 11, - "Should have 11 total turns" - ); - assert_eq!(report.turn_count.user_turns, 6, "Should have 6 user turns"); - assert_eq!( - report.turn_count.assistant_turns, 5, - "Should have 5 assistant turns" - ); - assert!( - report.turn_count.is_concerning, - "11 turns should be concerning (>7)" - ); - assert!( - !report.turn_count.is_excessive, - "11 turns should not be excessive (<=12)" - ); - assert!( - report.turn_count.efficiency_score < 0.5, - "Efficiency should be low" - ); - - // Validate repair detection (USER signals - query reformulation) - // Detected repairs: - // 1. "I just said I want to cancel..." - pattern: "I just said" - // 2. "No, I meant CANCEL..." - pattern: "No, I meant" - // 3. "Let me rephrase: I need to terminate..." - pattern: "let me rephrase" - // Note: "This is not helpful!!" is frustration (not repair) - // Note: "I give up..." is escalation (not repair) - assert_eq!( - report.follow_up.repair_count, 3, - "Should detect exactly 3 repair attempts from user messages" - ); - assert_eq!( - report.follow_up.repair_ratio, 0.5, - "Repair ratio should be 0.5 (3 repairs / 6 user messages)" - ); - assert!( - report.follow_up.is_concerning, - "50% repair ratio should be highly concerning (threshold is 30%)" - ); - - // Validate frustration detection - assert!( - report.frustration.has_frustration, - "Should detect frustration" - ); - assert!( - report.frustration.frustration_count >= 4, - "Should detect multiple frustration indicators: found {}", - report.frustration.frustration_count - ); - assert!( - report.frustration.severity >= 2, - "Should be at least moderate frustration" - ); - - // Validate repetition/looping detection (ASSISTANT signals - not following instructions) - // The assistant repeats the same unhelpful responses multiple times: - // 1. "I can help you with account management..." 
appears 3 times (exact repetition) - // 2. "I understand you need help with your account..." appears 2 times (near-duplicate) - assert!( - report.repetition.repetition_count >= 4, - "Should detect at least 4 assistant repetitions (exact + near-duplicates)" - ); - assert!( - report.repetition.has_looping, - "Should detect looping (>2 repetitions indicates stuck agent)" - ); - assert!( - report.repetition.severity >= 2, - "Should be moderate to severe looping (assistant not adapting)" - ); - - // Validate escalation detection - assert!( - report.escalation.escalation_requested, - "Should detect escalation request" - ); - assert!( - report.escalation.escalation_count >= 2, - "Should detect multiple escalation indicators: 'give up' + 'speak to a real person'" - ); - - // Validate overall quality - assert_eq!(report.overall_quality, InteractionQuality::Severe, "Should be classified as Severe due to escalation + excessive frustration + looping + high repair ratio"); + let r = SignalAnalyzer::default().analyze_sharegpt(&convo); + assert!(r.execution.failure.count >= 1); + assert!(r.quality_score < 50.0); } } diff --git a/crates/brightstaff/src/signals/environment/exhaustion.rs b/crates/brightstaff/src/signals/environment/exhaustion.rs new file mode 100644 index 00000000..142e7d6e --- /dev/null +++ b/crates/brightstaff/src/signals/environment/exhaustion.rs @@ -0,0 +1,347 @@ +//! Environment exhaustion detector. Direct port of +//! `signals/environment/exhaustion.py`. 
+ +use std::sync::OnceLock; + +use regex::Regex; +use serde_json::json; + +use crate::signals::analyzer::ShareGptMessage; +use crate::signals::schemas::{SignalGroup, SignalInstance, SignalType}; + +pub const API_ERROR_PATTERNS: &[&str] = &[ + r"500\s*(internal\s+)?server\s+error", + r"502\s*bad\s+gateway", + r"503\s*service\s+unavailable", + r"504\s*gateway\s+timeout", + r"internal\s+server\s+error", + r"service\s+unavailable", + r"server\s+error", + r"backend\s+error", + r"upstream\s+error", + r"service\s+temporarily\s+unavailable", + r"maintenance\s+mode", + r"under\s+maintenance", + r"try\s+again\s+later", + r"temporarily\s+unavailable", + r"system\s+error", + r"unexpected\s+error", + r"unhandled\s+exception", +]; + +pub const TIMEOUT_PATTERNS: &[&str] = &[ + r"timeout", + r"timed?\s*out", + r"etimedout", + r"connection\s+timed?\s*out", + r"read\s+timed?\s*out", + r"request\s+timed?\s*out", + r"gateway\s+timeout", + r"deadline\s+exceeded", + r"took\s+too\s+long", + r"operation\s+timed?\s*out", + r"socket\s+timeout", +]; + +pub const RATE_LIMIT_PATTERNS: &[&str] = &[ + r"rate\s+limit", + r"rate.limited", + r"(status|error|http)\s*:?\s*429", + r"429\s+(too\s+many|rate|limit)", + r"too\s+many\s+requests?", + r"quota\s+exceeded", + r"quota\s+limit", + r"throttl(ed|ing)", + r"request\s+limit", + r"api\s+limit", + r"calls?\s+per\s+(second|minute|hour|day)", + r"exceeded\s+.*\s+limit", + r"slow\s+down", + r"retry\s+after", + r"requests?\s+exceeded", +]; + +pub const NETWORK_PATTERNS: &[&str] = &[ + r"connection\s+refused", + r"econnrefused", + r"econnreset", + r"connection\s+reset", + r"enotfound", + r"dns\s+(error|failure|lookup)", + r"host\s+not\s+found", + r"network\s+(error|failure|unreachable)", + r"no\s+route\s+to\s+host", + r"socket\s+error", + r"connection\s+failed", + r"unable\s+to\s+connect", + r"cannot\s+connect", + r"could\s+not\s+connect", + r"connect\s+error", + r"ssl\s+(error|handshake|certificate)", + r"certificate\s+(error|invalid|expired)", +]; + +pub 
const MALFORMED_PATTERNS: &[&str] = &[ + r"json\s+parse\s+error", + r"invalid\s+json", + r"unexpected\s+token", + r"syntax\s+error.*json", + r"malformed\s+(response|json|data)", + r"unexpected\s+end\s+of", + r"parse\s+error", + r"parsing\s+failed", + r"invalid\s+response", + r"unexpected\s+response", + r"response\s+format", + r"missing\s+field.*response", + r"unexpected\s+schema", + r"schema\s+validation", + r"deserialization\s+error", + r"failed\s+to\s+decode", +]; + +pub const CONTEXT_OVERFLOW_PATTERNS: &[&str] = &[ + r"context\s+(length|limit|overflow|exceeded)", + r"token\s+(limit|overflow|exceeded)", + r"max(imum)?\s+tokens?", + r"input\s+too\s+(long|large)", + r"exceeds?\s+(context|token|character|input)\s+limit", + r"message\s+too\s+(long|large)", + r"content\s+too\s+(long|large)", + r"truncat(ed|ion)\s+(due\s+to|because|for)\s+(length|size|limit)", + r"maximum\s+context", + r"prompt\s+too\s+(long|large)", +]; + +fn compile(patterns: &[&str]) -> Regex { + let combined = patterns + .iter() + .map(|p| format!("({})", p)) + .collect::>() + .join("|"); + Regex::new(&format!("(?i){}", combined)).expect("exhaustion pattern regex must compile") +} + +fn api_error_re() -> &'static Regex { + static R: OnceLock = OnceLock::new(); + R.get_or_init(|| compile(API_ERROR_PATTERNS)) +} +fn timeout_re() -> &'static Regex { + static R: OnceLock = OnceLock::new(); + R.get_or_init(|| compile(TIMEOUT_PATTERNS)) +} +fn rate_limit_re() -> &'static Regex { + static R: OnceLock = OnceLock::new(); + R.get_or_init(|| compile(RATE_LIMIT_PATTERNS)) +} +fn network_re() -> &'static Regex { + static R: OnceLock = OnceLock::new(); + R.get_or_init(|| compile(NETWORK_PATTERNS)) +} +fn malformed_re() -> &'static Regex { + static R: OnceLock = OnceLock::new(); + R.get_or_init(|| compile(MALFORMED_PATTERNS)) +} +fn context_overflow_re() -> &'static Regex { + static R: OnceLock = OnceLock::new(); + R.get_or_init(|| compile(CONTEXT_OVERFLOW_PATTERNS)) +} + +fn snippet_around(text: &str, m: 
regex::Match<'_>, context: usize) -> String { + let start = m.start().saturating_sub(context); + let end = (m.end() + context).min(text.len()); + let start = align_char_boundary(text, start, false); + let end = align_char_boundary(text, end, true); + let mut snippet = String::new(); + if start > 0 { + snippet.push_str("..."); + } + snippet.push_str(&text[start..end]); + if end < text.len() { + snippet.push_str("..."); + } + snippet +} + +fn align_char_boundary(s: &str, mut idx: usize, forward: bool) -> usize { + if idx >= s.len() { + return s.len(); + } + while !s.is_char_boundary(idx) { + if forward { + idx += 1; + } else if idx == 0 { + break; + } else { + idx -= 1; + } + } + idx +} + +pub fn analyze_exhaustion(messages: &[ShareGptMessage<'_>]) -> SignalGroup { + let mut group = SignalGroup::new("exhaustion"); + + for (i, msg) in messages.iter().enumerate() { + if msg.from != "observation" { + continue; + } + let value = msg.value; + let lower = value.to_lowercase(); + + if let Some(m) = rate_limit_re().find(&lower) { + group.add_signal(emit( + SignalType::EnvironmentExhaustionRateLimit, + i, + snippet_around(value, m, 50), + 0.95, + "rate_limit", + m.as_str(), + )); + continue; + } + + if let Some(m) = api_error_re().find(&lower) { + group.add_signal(emit( + SignalType::EnvironmentExhaustionApiError, + i, + snippet_around(value, m, 50), + 0.9, + "api_error", + m.as_str(), + )); + continue; + } + + if let Some(m) = timeout_re().find(&lower) { + group.add_signal(emit( + SignalType::EnvironmentExhaustionTimeout, + i, + snippet_around(value, m, 50), + 0.9, + "timeout", + m.as_str(), + )); + continue; + } + + if let Some(m) = network_re().find(&lower) { + group.add_signal(emit( + SignalType::EnvironmentExhaustionNetwork, + i, + snippet_around(value, m, 50), + 0.9, + "network", + m.as_str(), + )); + continue; + } + + if let Some(m) = malformed_re().find(&lower) { + group.add_signal(emit( + SignalType::EnvironmentExhaustionMalformed, + i, + snippet_around(value, m, 
50), + 0.85, + "malformed_response", + m.as_str(), + )); + continue; + } + + if let Some(m) = context_overflow_re().find(&lower) { + group.add_signal(emit( + SignalType::EnvironmentExhaustionContextOverflow, + i, + snippet_around(value, m, 50), + 0.9, + "context_overflow", + m.as_str(), + )); + } + } + + group +} + +fn emit( + t: SignalType, + idx: usize, + snippet: String, + confidence: f32, + kind: &str, + matched: &str, +) -> SignalInstance { + SignalInstance::new(t, idx, snippet) + .with_confidence(confidence) + .with_metadata(json!({ + "exhaustion_type": kind, + "matched": matched, + })) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn obs(value: &str) -> ShareGptMessage<'_> { + ShareGptMessage { + from: "observation", + value, + } + } + + #[test] + fn detects_rate_limit() { + let g = analyze_exhaustion(&[obs("HTTP 429: too many requests, retry after 30s")]); + assert!(g + .signals + .iter() + .any(|s| matches!(s.signal_type, SignalType::EnvironmentExhaustionRateLimit))); + } + + #[test] + fn detects_api_error() { + let g = analyze_exhaustion(&[obs("503 service unavailable - try again later")]); + assert!(g + .signals + .iter() + .any(|s| matches!(s.signal_type, SignalType::EnvironmentExhaustionApiError))); + } + + #[test] + fn detects_timeout() { + let g = analyze_exhaustion(&[obs("Connection timed out after 30 seconds")]); + assert!(g + .signals + .iter() + .any(|s| matches!(s.signal_type, SignalType::EnvironmentExhaustionTimeout))); + } + + #[test] + fn detects_network_failure() { + let g = analyze_exhaustion(&[obs("ECONNREFUSED: connection refused by remote host")]); + assert!(g + .signals + .iter() + .any(|s| matches!(s.signal_type, SignalType::EnvironmentExhaustionNetwork))); + } + + #[test] + fn detects_malformed_response() { + let g = analyze_exhaustion(&[obs("Invalid JSON: unexpected token at position 42")]); + assert!(g + .signals + .iter() + .any(|s| matches!(s.signal_type, SignalType::EnvironmentExhaustionMalformed))); + } + + #[test] + fn 
detects_context_overflow() { + let g = analyze_exhaustion(&[obs("Maximum context length exceeded for this model")]); + assert!(g.signals.iter().any(|s| matches!( + s.signal_type, + SignalType::EnvironmentExhaustionContextOverflow + ))); + } +} diff --git a/crates/brightstaff/src/signals/environment/mod.rs b/crates/brightstaff/src/signals/environment/mod.rs new file mode 100644 index 00000000..97d9b300 --- /dev/null +++ b/crates/brightstaff/src/signals/environment/mod.rs @@ -0,0 +1,3 @@ +//! Environment signals: exhaustion (external system failures and constraints). + +pub mod exhaustion; diff --git a/crates/brightstaff/src/signals/execution/failure.rs b/crates/brightstaff/src/signals/execution/failure.rs new file mode 100644 index 00000000..3e171446 --- /dev/null +++ b/crates/brightstaff/src/signals/execution/failure.rs @@ -0,0 +1,388 @@ +//! Execution failure detector. Direct port of `signals/execution/failure.py`. + +use std::sync::OnceLock; + +use regex::Regex; +use serde_json::json; + +use crate::signals::analyzer::ShareGptMessage; +use crate::signals::schemas::{SignalGroup, SignalInstance, SignalType}; + +pub const INVALID_ARGS_PATTERNS: &[&str] = &[ + r"invalid\s+argument", + r"invalid\s+parameter", + r"invalid\s+type", + r"type\s*error", + r"expected\s+\w+\s*,?\s*got\s+\w+", + r"required\s+field", + r"required\s+parameter", + r"missing\s+required", + r"missing\s+argument", + r"validation\s+failed", + r"validation\s+error", + r"invalid\s+value", + r"invalid\s+format", + r"must\s+be\s+(a|an)\s+\w+", + r"cannot\s+be\s+(null|empty|none)", + r"is\s+not\s+valid", + r"does\s+not\s+match", + r"out\s+of\s+range", + r"invalid\s+date", + r"invalid\s+json", + r"malformed\s+request", +]; + +pub const BAD_QUERY_PATTERNS: &[&str] = &[ + r"invalid\s+query", + r"query\s+syntax\s+error", + r"malformed\s+query", + r"unknown\s+field", + r"invalid\s+field", + r"invalid\s+filter", + r"invalid\s+search", + r"unknown\s+id", + r"invalid\s+id", + r"id\s+format\s+error", + 
r"invalid\s+identifier", + r"query\s+failed", + r"search\s+error", + r"invalid\s+operator", + r"unsupported\s+query", +]; + +pub const TOOL_NOT_FOUND_PATTERNS: &[&str] = &[ + r"unknown\s+function", + r"unknown\s+tool", + r"function\s+not\s+found", + r"tool\s+not\s+found", + r"no\s+such\s+function", + r"no\s+such\s+tool", + r"undefined\s+function", + r"action\s+not\s+supported", + r"invalid\s+tool", + r"invalid\s+function", + r"unrecognized\s+function", +]; + +pub const AUTH_MISUSE_PATTERNS: &[&str] = &[ + r"\bunauthorized\b", + r"(status|error|http|code)\s*:?\s*401", + r"401\s+unauthorized", + r"403\s+forbidden", + r"permission\s+denied", + r"access\s+denied", + r"authentication\s+required", + r"invalid\s+credentials", + r"invalid\s+token", + r"token\s+expired", + r"missing\s+authorization", + r"\bforbidden\b", + r"not\s+authorized", + r"insufficient\s+permissions?", +]; + +pub const STATE_ERROR_PATTERNS: &[&str] = &[ + r"invalid\s+state", + r"illegal\s+state", + r"must\s+call\s+\w+\s+first", + r"must\s+\w+\s+before", + r"cannot\s+\w+\s+before", + r"already\s+(exists?|created|started|finished)", + r"not\s+initialized", + r"not\s+started", + r"already\s+in\s+progress", + r"operation\s+in\s+progress", + r"sequence\s+error", + r"precondition\s+failed", + r"(status|error|http)\s*:?\s*409", + r"409\s+conflict", + r"\bconflict\b", +]; + +fn compile(patterns: &[&str]) -> Regex { + // Use `(?i)` flag for case-insensitive matching, matching Python's `re.IGNORECASE`. 
+ let combined = patterns + .iter() + .map(|p| format!("({})", p)) + .collect::>() + .join("|"); + Regex::new(&format!("(?i){}", combined)).expect("failure pattern regex must compile") +} + +fn invalid_args_re() -> &'static Regex { + static R: OnceLock = OnceLock::new(); + R.get_or_init(|| compile(INVALID_ARGS_PATTERNS)) +} +fn bad_query_re() -> &'static Regex { + static R: OnceLock = OnceLock::new(); + R.get_or_init(|| compile(BAD_QUERY_PATTERNS)) +} +fn tool_not_found_re() -> &'static Regex { + static R: OnceLock = OnceLock::new(); + R.get_or_init(|| compile(TOOL_NOT_FOUND_PATTERNS)) +} +fn auth_misuse_re() -> &'static Regex { + static R: OnceLock = OnceLock::new(); + R.get_or_init(|| compile(AUTH_MISUSE_PATTERNS)) +} +fn state_error_re() -> &'static Regex { + static R: OnceLock = OnceLock::new(); + R.get_or_init(|| compile(STATE_ERROR_PATTERNS)) +} + +/// Pull tool name + args from a `function_call` message. Mirrors +/// `_extract_tool_info` in the reference. +pub(crate) fn extract_tool_info(value: &str) -> (String, String) { + if let Ok(parsed) = serde_json::from_str::(value) { + if let Some(obj) = parsed.as_object() { + let name = obj + .get("name") + .or_else(|| obj.get("function")) + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + .unwrap_or_else(|| "unknown".to_string()); + let args = match obj.get("arguments").or_else(|| obj.get("args")) { + Some(serde_json::Value::Object(o)) => { + serde_json::to_string(&serde_json::Value::Object(o.clone())).unwrap_or_default() + } + Some(other) => other + .as_str() + .map(|s| s.to_string()) + .unwrap_or_else(|| serde_json::to_string(other).unwrap_or_default()), + None => String::new(), + }; + return (name, args); + } + } + let mut snippet: String = value.chars().take(200).collect(); + snippet.shrink_to_fit(); + ("unknown".to_string(), snippet) +} + +/// Build a context-window snippet around a regex match, with leading/trailing +/// ellipses when truncated. Mirrors `_get_snippet`. 
+fn snippet_around(text: &str, m: regex::Match<'_>, context: usize) -> String { + let start = m.start().saturating_sub(context); + let end = (m.end() + context).min(text.len()); + // Ensure we cut on UTF-8 boundaries. + let start = align_char_boundary(text, start, false); + let end = align_char_boundary(text, end, true); + let mut snippet = String::new(); + if start > 0 { + snippet.push_str("..."); + } + snippet.push_str(&text[start..end]); + if end < text.len() { + snippet.push_str("..."); + } + snippet +} + +fn align_char_boundary(s: &str, mut idx: usize, forward: bool) -> usize { + if idx >= s.len() { + return s.len(); + } + while !s.is_char_boundary(idx) { + if forward { + idx += 1; + } else if idx == 0 { + break; + } else { + idx -= 1; + } + } + idx +} + +pub fn analyze_failure(messages: &[ShareGptMessage<'_>]) -> SignalGroup { + let mut group = SignalGroup::new("failure"); + let mut last_call: Option<(usize, String, String)> = None; + + for (i, msg) in messages.iter().enumerate() { + match msg.from { + "function_call" => { + let (name, args) = extract_tool_info(msg.value); + last_call = Some((i, name, args)); + continue; + } + "observation" => {} + _ => continue, + } + + let value = msg.value; + let lower = value.to_lowercase(); + let (call_index, tool_name) = match &last_call { + Some((idx, name, _)) => (*idx, name.clone()), + None => (i.saturating_sub(1), "unknown".to_string()), + }; + + if let Some(m) = invalid_args_re().find(&lower) { + group.add_signal( + SignalInstance::new( + SignalType::ExecutionFailureInvalidArgs, + i, + snippet_around(value, m, 50), + ) + .with_confidence(0.9) + .with_metadata(json!({ + "tool_name": tool_name, + "call_index": call_index, + "error_type": "invalid_args", + "matched": m.as_str(), + })), + ); + continue; + } + + if let Some(m) = tool_not_found_re().find(&lower) { + group.add_signal( + SignalInstance::new( + SignalType::ExecutionFailureToolNotFound, + i, + snippet_around(value, m, 50), + ) + .with_confidence(0.95) + 
.with_metadata(json!({ + "tool_name": tool_name, + "call_index": call_index, + "error_type": "tool_not_found", + "matched": m.as_str(), + })), + ); + continue; + } + + if let Some(m) = auth_misuse_re().find(&lower) { + group.add_signal( + SignalInstance::new( + SignalType::ExecutionFailureAuthMisuse, + i, + snippet_around(value, m, 50), + ) + .with_confidence(0.8) + .with_metadata(json!({ + "tool_name": tool_name, + "call_index": call_index, + "error_type": "auth_misuse", + "matched": m.as_str(), + })), + ); + continue; + } + + if let Some(m) = state_error_re().find(&lower) { + group.add_signal( + SignalInstance::new( + SignalType::ExecutionFailureStateError, + i, + snippet_around(value, m, 50), + ) + .with_confidence(0.85) + .with_metadata(json!({ + "tool_name": tool_name, + "call_index": call_index, + "error_type": "state_error", + "matched": m.as_str(), + })), + ); + continue; + } + + if let Some(m) = bad_query_re().find(&lower) { + let confidence = if ["error", "invalid", "failed"] + .iter() + .any(|w| lower.contains(w)) + { + 0.8 + } else { + 0.6 + }; + group.add_signal( + SignalInstance::new( + SignalType::ExecutionFailureBadQuery, + i, + snippet_around(value, m, 50), + ) + .with_confidence(confidence) + .with_metadata(json!({ + "tool_name": tool_name, + "call_index": call_index, + "error_type": "bad_query", + "matched": m.as_str(), + })), + ); + } + } + + group +} + +#[cfg(test)] +mod tests { + use super::*; + + fn fc(value: &str) -> ShareGptMessage<'_> { + ShareGptMessage { + from: "function_call", + value, + } + } + fn obs(value: &str) -> ShareGptMessage<'_> { + ShareGptMessage { + from: "observation", + value, + } + } + + #[test] + fn detects_invalid_args() { + let msgs = vec![ + fc(r#"{"name":"create_user","arguments":{"age":"twelve"}}"#), + obs("Error: validation failed - expected integer got string for field age"), + ]; + let g = analyze_failure(&msgs); + assert!(g + .signals + .iter() + .any(|s| matches!(s.signal_type, 
SignalType::ExecutionFailureInvalidArgs))); + } + + #[test] + fn detects_tool_not_found() { + let msgs = vec![ + fc(r#"{"name":"send_thought","arguments":{}}"#), + obs("Error: unknown function 'send_thought'"), + ]; + let g = analyze_failure(&msgs); + assert!(g + .signals + .iter() + .any(|s| matches!(s.signal_type, SignalType::ExecutionFailureToolNotFound))); + } + + #[test] + fn detects_auth_misuse() { + let msgs = vec![ + fc(r#"{"name":"get_secret","arguments":{}}"#), + obs("HTTP 401 Unauthorized"), + ]; + let g = analyze_failure(&msgs); + assert!(g + .signals + .iter() + .any(|s| matches!(s.signal_type, SignalType::ExecutionFailureAuthMisuse))); + } + + #[test] + fn detects_state_error() { + let msgs = vec![ + fc(r#"{"name":"commit_tx","arguments":{}}"#), + obs("must call begin_tx first"), + ]; + let g = analyze_failure(&msgs); + assert!(g + .signals + .iter() + .any(|s| matches!(s.signal_type, SignalType::ExecutionFailureStateError))); + } +} diff --git a/crates/brightstaff/src/signals/execution/loops.rs b/crates/brightstaff/src/signals/execution/loops.rs new file mode 100644 index 00000000..70b90e83 --- /dev/null +++ b/crates/brightstaff/src/signals/execution/loops.rs @@ -0,0 +1,433 @@ +//! Execution loops detector. Direct port of `signals/execution/loops.py`. + +use serde_json::json; + +use crate::signals::analyzer::ShareGptMessage; +use crate::signals::schemas::{SignalGroup, SignalInstance, SignalType}; + +pub const RETRY_THRESHOLD: usize = 3; +pub const PARAMETER_DRIFT_THRESHOLD: usize = 3; +pub const OSCILLATION_CYCLES_THRESHOLD: usize = 3; + +#[derive(Debug, Clone)] +pub struct ToolCall { + pub index: usize, + pub name: String, + /// Canonical JSON string of arguments (sorted keys when parseable). 
+ pub args: String, + pub args_dict: Option>, +} + +impl ToolCall { + pub fn args_equal(&self, other: &ToolCall) -> bool { + match (&self.args_dict, &other.args_dict) { + (Some(a), Some(b)) => a == b, + _ => self.args == other.args, + } + } +} + +fn parse_tool_call(index: usize, msg: &ShareGptMessage<'_>) -> Option { + if msg.from != "function_call" { + return None; + } + let value = msg.value; + + if let Ok(parsed) = serde_json::from_str::(value) { + if let Some(obj) = parsed.as_object() { + let name = obj + .get("name") + .or_else(|| obj.get("function")) + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + .unwrap_or_else(|| "unknown".to_string()); + let raw_args = obj.get("arguments").or_else(|| obj.get("args")); + let (args_str, args_dict) = match raw_args { + Some(serde_json::Value::Object(o)) => { + let mut keys: Vec<&String> = o.keys().collect(); + keys.sort(); + let mut canon = serde_json::Map::new(); + for k in keys { + canon.insert(k.clone(), o[k].clone()); + } + ( + serde_json::to_string(&serde_json::Value::Object(canon.clone())) + .unwrap_or_default(), + Some(canon), + ) + } + Some(other) => ( + other + .as_str() + .map(|s| s.to_string()) + .unwrap_or_else(|| serde_json::to_string(other).unwrap_or_default()), + None, + ), + None => (String::new(), None), + }; + return Some(ToolCall { + index, + name, + args: args_str, + args_dict, + }); + } + } + + if let Some(paren) = value.find('(') { + if paren > 0 { + let name = value[..paren].trim().to_string(); + let args_part = &value[paren..]; + if args_part.starts_with('(') && args_part.ends_with(')') { + let inner = args_part[1..args_part.len() - 1].trim(); + if let Ok(serde_json::Value::Object(o)) = + serde_json::from_str::(inner) + { + let mut keys: Vec<&String> = o.keys().collect(); + keys.sort(); + let mut canon = serde_json::Map::new(); + for k in keys { + canon.insert(k.clone(), o[k].clone()); + } + return Some(ToolCall { + index, + name, + args: 
serde_json::to_string(&serde_json::Value::Object(canon.clone())) + .unwrap_or_default(), + args_dict: Some(canon), + }); + } + return Some(ToolCall { + index, + name, + args: inner.to_string(), + args_dict: None, + }); + } + return Some(ToolCall { + index, + name, + args: args_part.to_string(), + args_dict: None, + }); + } + } + + Some(ToolCall { + index, + name: value.trim().to_string(), + args: String::new(), + args_dict: None, + }) +} + +fn extract_tool_calls(messages: &[ShareGptMessage<'_>]) -> Vec { + let mut out = Vec::new(); + for (i, msg) in messages.iter().enumerate() { + if let Some(c) = parse_tool_call(i, msg) { + out.push(c); + } + } + out +} + +fn detect_retry(calls: &[ToolCall]) -> Vec<(usize, usize, String)> { + if calls.len() < RETRY_THRESHOLD { + return Vec::new(); + } + let mut patterns = Vec::new(); + let mut i = 0; + while i < calls.len() { + let current = &calls[i]; + let mut j = i + 1; + let mut run_length = 1; + while j < calls.len() { + if calls[j].name == current.name && calls[j].args_equal(current) { + run_length += 1; + j += 1; + } else { + break; + } + } + if run_length >= RETRY_THRESHOLD { + patterns.push((calls[i].index, calls[j - 1].index, current.name.clone())); + i = j; + } else { + i += 1; + } + } + patterns +} + +fn detect_parameter_drift(calls: &[ToolCall]) -> Vec<(usize, usize, String, usize)> { + if calls.len() < PARAMETER_DRIFT_THRESHOLD { + return Vec::new(); + } + let mut patterns = Vec::new(); + let mut i = 0; + while i < calls.len() { + let current_name = calls[i].name.clone(); + let mut seen_args: Vec = vec![calls[i].args.clone()]; + let mut unique_args = 1; + let mut j = i + 1; + while j < calls.len() { + if calls[j].name != current_name { + break; + } + if !seen_args.iter().any(|a| a == &calls[j].args) { + seen_args.push(calls[j].args.clone()); + unique_args += 1; + } + j += 1; + } + let run_length = j - i; + if run_length >= PARAMETER_DRIFT_THRESHOLD && unique_args >= 2 { + patterns.push(( + calls[i].index, + calls[j - 
1].index, + current_name, + unique_args, + )); + i = j; + } else { + i += 1; + } + } + patterns +} + +fn detect_oscillation(calls: &[ToolCall]) -> Vec<(usize, usize, Vec, usize)> { + let min_calls = 2 * OSCILLATION_CYCLES_THRESHOLD; + if calls.len() < min_calls { + return Vec::new(); + } + let mut patterns = Vec::new(); + let mut i: usize = 0; + while i + min_calls <= calls.len() { + let max_pat_len = (5usize).min(calls.len() - i); + let mut found_for_i = false; + for pat_len in 2..=max_pat_len { + let pattern_names: Vec = + (0..pat_len).map(|k| calls[i + k].name.clone()).collect(); + let unique: std::collections::HashSet<&String> = pattern_names.iter().collect(); + if unique.len() < 2 { + continue; + } + let mut cycles = 1; + let mut pos = i + pat_len; + while pos + pat_len <= calls.len() { + let mut all_match = true; + for k in 0..pat_len { + if calls[pos + k].name != pattern_names[k] { + all_match = false; + break; + } + } + if all_match { + cycles += 1; + pos += pat_len; + } else { + break; + } + } + if cycles >= OSCILLATION_CYCLES_THRESHOLD { + let end_idx_in_calls = i + (cycles * pat_len) - 1; + patterns.push(( + calls[i].index, + calls[end_idx_in_calls].index, + pattern_names, + cycles, + )); + // Mirror Python: `i = end_idx + 1 - pattern_len`. We set `i` so that + // the next outer iteration begins after we account for overlap. + i = end_idx_in_calls + 1 - pat_len; + found_for_i = true; + break; + } + } + if !found_for_i { + i += 1; + } else { + // Match Python's `i = end_idx + 1 - pattern_len; break` then loop. + // We'll continue; the outer while re-checks i. 
+ } + } + if patterns.len() > 1 { + patterns = deduplicate_patterns(patterns); + } + patterns +} + +fn deduplicate_patterns( + mut patterns: Vec<(usize, usize, Vec, usize)>, +) -> Vec<(usize, usize, Vec, usize)> { + if patterns.is_empty() { + return patterns; + } + patterns.sort_by(|a, b| { + let ord = a.0.cmp(&b.0); + if ord != std::cmp::Ordering::Equal { + ord + } else { + (b.1 - b.0).cmp(&(a.1 - a.0)) + } + }); + let mut result = Vec::new(); + let mut last_end: i64 = -1; + for p in patterns { + if (p.0 as i64) > last_end { + last_end = p.1 as i64; + result.push(p); + } + } + result +} + +pub fn analyze_loops(messages: &[ShareGptMessage<'_>]) -> SignalGroup { + let mut group = SignalGroup::new("loops"); + let calls = extract_tool_calls(messages); + if calls.len() < RETRY_THRESHOLD { + return group; + } + + let retries = detect_retry(&calls); + for (start_idx, end_idx, tool_name) in &retries { + let call_count = calls + .iter() + .filter(|c| *start_idx <= c.index && c.index <= *end_idx) + .count(); + group.add_signal( + SignalInstance::new( + SignalType::ExecutionLoopsRetry, + *start_idx, + format!( + "Tool '{}' called {} times with identical arguments", + tool_name, call_count + ), + ) + .with_confidence(0.95) + .with_metadata(json!({ + "tool_name": tool_name, + "start_index": start_idx, + "end_index": end_idx, + "call_count": call_count, + "loop_type": "retry", + })), + ); + } + + let drifts = detect_parameter_drift(&calls); + for (start_idx, end_idx, tool_name, variation_count) in &drifts { + let overlaps_retry = retries + .iter() + .any(|r| !(*end_idx < r.0 || *start_idx > r.1)); + if overlaps_retry { + continue; + } + let call_count = calls + .iter() + .filter(|c| *start_idx <= c.index && c.index <= *end_idx) + .count(); + group.add_signal( + SignalInstance::new( + SignalType::ExecutionLoopsParameterDrift, + *start_idx, + format!( + "Tool '{}' called {} times with {} different argument variations", + tool_name, call_count, variation_count + ), + ) + 
.with_confidence(0.85) + .with_metadata(json!({ + "tool_name": tool_name, + "start_index": start_idx, + "end_index": end_idx, + "call_count": call_count, + "variation_count": variation_count, + "loop_type": "parameter_drift", + })), + ); + } + + let oscillations = detect_oscillation(&calls); + for (start_idx, end_idx, tool_names, cycle_count) in &oscillations { + let pattern_str = tool_names.join(" \u{2192} "); + group.add_signal( + SignalInstance::new( + SignalType::ExecutionLoopsOscillation, + *start_idx, + format!( + "Oscillation pattern [{}] repeated {} times", + pattern_str, cycle_count + ), + ) + .with_confidence(0.9) + .with_metadata(json!({ + "pattern": tool_names, + "start_index": start_idx, + "end_index": end_idx, + "cycle_count": cycle_count, + "loop_type": "oscillation", + })), + ); + } + + group +} + +#[cfg(test)] +mod tests { + use super::*; + + fn fc(value: &str) -> ShareGptMessage<'_> { + ShareGptMessage { + from: "function_call", + value, + } + } + + #[test] + fn detects_retry_loop() { + let arg = r#"{"name":"check_status","arguments":{"id":"abc"}}"#; + let msgs = vec![fc(arg), fc(arg), fc(arg), fc(arg)]; + let g = analyze_loops(&msgs); + assert!(g + .signals + .iter() + .any(|s| matches!(s.signal_type, SignalType::ExecutionLoopsRetry))); + } + + #[test] + fn detects_parameter_drift() { + let msgs = vec![ + fc(r#"{"name":"search","arguments":{"q":"a"}}"#), + fc(r#"{"name":"search","arguments":{"q":"ab"}}"#), + fc(r#"{"name":"search","arguments":{"q":"abc"}}"#), + fc(r#"{"name":"search","arguments":{"q":"abcd"}}"#), + ]; + let g = analyze_loops(&msgs); + assert!(g + .signals + .iter() + .any(|s| matches!(s.signal_type, SignalType::ExecutionLoopsParameterDrift))); + } + + #[test] + fn detects_oscillation() { + let a = r#"{"name":"toolA","arguments":{}}"#; + let b = r#"{"name":"toolB","arguments":{}}"#; + let msgs = vec![fc(a), fc(b), fc(a), fc(b), fc(a), fc(b)]; + let g = analyze_loops(&msgs); + assert!(g + .signals + .iter() + .any(|s| 
matches!(s.signal_type, SignalType::ExecutionLoopsOscillation))); + } + + #[test] + fn no_signals_when_few_calls() { + let msgs = vec![fc(r#"{"name":"only_once","arguments":{}}"#)]; + let g = analyze_loops(&msgs); + assert!(g.signals.is_empty()); + } +} diff --git a/crates/brightstaff/src/signals/execution/mod.rs b/crates/brightstaff/src/signals/execution/mod.rs new file mode 100644 index 00000000..87dc28c4 --- /dev/null +++ b/crates/brightstaff/src/signals/execution/mod.rs @@ -0,0 +1,5 @@ +//! Execution signals: failure (agent-caused tool errors) and loops +//! (repetitive tool-call behavior). + +pub mod failure; +pub mod loops; diff --git a/crates/brightstaff/src/signals/interaction/constants.rs b/crates/brightstaff/src/signals/interaction/constants.rs new file mode 100644 index 00000000..2301395c --- /dev/null +++ b/crates/brightstaff/src/signals/interaction/constants.rs @@ -0,0 +1,193 @@ +//! Shared constants for the interaction layer detectors. +//! +//! Direct port of `signals/interaction/constants.py`. 
+
+use std::collections::HashSet;
+use std::sync::OnceLock;
+
+/// Prefixes marking a user message as positive/affirmative; used by other
+/// detectors to suppress false positives (e.g. excited punctuation).
+pub const POSITIVE_PREFIXES: &[&str] = &[
+    "yes", "yeah", "yep", "yup", "sure", "ok", "okay", "great", "awesome",
+    "perfect", "thanks", "thank", "wonderful", "excellent", "amazing",
+    "nice", "good", "cool", "absolutely", "definitely", "please",
+];
+
+/// Prefixes marking a user message as a plain confirmation.
+pub const CONFIRMATION_PREFIXES: &[&str] = &[
+    "yes", "yeah", "yep", "yup", "correct", "right",
+    "that's correct", "thats correct", "that's right", "thats right",
+    "that is correct", "that is right",
+];
+
+/// English stopword list used when comparing content-token sets.
+const STOPWORD_LIST: &[&str] = &[
+    "a", "about", "above", "after", "again", "against", "all", "am", "an",
+    "and", "any", "are", "as", "at", "be", "because", "been", "before",
+    "being", "below", "between", "both", "but", "by", "can", "could", "did",
+    "do", "does", "doing", "down", "during", "each", "few", "for", "from",
+    "further", "had", "has", "have", "having", "he", "her", "here", "hers",
+    "herself", "him", "himself", "his", "how", "i", "if", "in", "into",
+    "is", "it", "its", "itself", "just", "me", "more", "most", "my",
+    "myself", "no", "nor", "not", "now", "of", "off", "on", "once", "only",
+    "or", "other", "our", "ours", "ourselves", "out", "over", "own", "same",
+    "she", "should", "so", "some", "such", "than", "that", "the", "their",
+    "theirs", "them", "themselves", "then", "there", "these", "they",
+    "this", "those", "through", "to", "too", "under", "until", "up",
+    "very", "was", "we", "were", "what", "when", "where", "which", "while",
+    "who", "whom", "why", "with", "would", "you", "your", "yours",
+    "yourself", "yourselves",
+];
+
+/// Lazily-built stopword set shared by all detectors.
+pub fn stopwords() -> &'static HashSet<&'static str> {
+    // NOTE(review): the static's type parameters were lost in extraction;
+    // restored as `OnceLock<HashSet<&'static str>>` from the return type.
+    static SET: OnceLock<HashSet<&'static str>> = OnceLock::new();
+    SET.get_or_init(||
STOPWORD_LIST.iter().copied().collect()) +} + +/// Returns true if `text` (case-insensitive, trimmed) starts with any of the +/// given prefixes treated as **whole tokens or token sequences**. This matches +/// the Python's `text_lower.startswith(prefix)` plus the natural intent that +/// `"please"` shouldn't fire on `"pleased"`. +pub fn starts_with_prefix(text: &str, prefixes: &[&str]) -> bool { + let lowered = text.to_lowercase(); + let trimmed = lowered.trim_start(); + for prefix in prefixes { + if trimmed.starts_with(prefix) { + return true; + } + } + false +} diff --git a/crates/brightstaff/src/signals/interaction/disengagement.rs b/crates/brightstaff/src/signals/interaction/disengagement.rs new file mode 100644 index 00000000..28711d18 --- /dev/null +++ b/crates/brightstaff/src/signals/interaction/disengagement.rs @@ -0,0 +1,445 @@ +//! Disengagement signals: escalation, quit, negative stance. +//! +//! Direct port of `signals/interaction/disengagement.py`. + +use std::sync::OnceLock; + +use regex::Regex; +use serde_json::json; + +use super::constants::{starts_with_prefix, POSITIVE_PREFIXES}; +use crate::signals::schemas::{SignalGroup, SignalInstance, SignalType}; +use crate::signals::text_processing::{normalize_patterns, NormalizedMessage, NormalizedPattern}; + +const ESCALATION_PATTERN_TEXTS: &[&str] = &[ + // Human requests + "speak to a human", + "talk to a human", + "connect me to a human", + "connect me with a human", + "transfer me to a human", + "get me a human", + "chat with a human", + // Person requests + "speak to a person", + "talk to a person", + "connect me to a person", + "connect me with a person", + "transfer me to a person", + "get me a person", + "chat with a person", + // Real person requests + "speak to a real person", + "talk to a real person", + "connect me to a real person", + "connect me with a real person", + "transfer me to a real person", + "get me a real person", + "chat with a real person", + // Actual person requests + "speak 
to an actual person", + "talk to an actual person", + "connect me to an actual person", + "connect me with an actual person", + "transfer me to an actual person", + "get me an actual person", + "chat with an actual person", + // Supervisor requests + "speak to a supervisor", + "talk to a supervisor", + "connect me to a supervisor", + "connect me with a supervisor", + "transfer me to a supervisor", + "get me a supervisor", + "chat with a supervisor", + // Manager requests + "speak to a manager", + "talk to a manager", + "connect me to a manager", + "connect me with a manager", + "transfer me to a manager", + "get me a manager", + "chat with a manager", + // Customer service requests + "speak to customer service", + "talk to customer service", + "connect me to customer service", + "connect me with customer service", + "transfer me to customer service", + "get me customer service", + "chat with customer service", + // Customer support requests + "speak to customer support", + "talk to customer support", + "connect me to customer support", + "connect me with customer support", + "transfer me to customer support", + "get me customer support", + "chat with customer support", + // Support requests + "speak to support", + "talk to support", + "connect me to support", + "connect me with support", + "transfer me to support", + "get me support", + "chat with support", + // Tech support requests + "speak to tech support", + "talk to tech support", + "connect me to tech support", + "connect me with tech support", + "transfer me to tech support", + "get me tech support", + "chat with tech support", + // Help desk requests + "speak to help desk", + "talk to help desk", + "connect me to help desk", + "connect me with help desk", + "transfer me to help desk", + "get me help desk", + "chat with help desk", + // Explicit escalation + "escalate this", +]; + +const QUIT_PATTERN_TEXTS: &[&str] = &[ + "i give up", + "i'm giving up", + "im giving up", + "i'm going to quit", + "i quit", + 
"forget it", + "forget this", + "screw it", + "screw this", + "don't bother trying", + "don't bother with this", + "don't bother with it", + "don't even bother", + "why bother", + "not worth it", + "this is hopeless", + "going elsewhere", + "try somewhere else", + "look elsewhere", +]; + +const NEGATIVE_STANCE_PATTERN_TEXTS: &[&str] = &[ + "this is useless", + "not helpful", + "doesn't help", + "not helping", + "you're not helping", + "youre not helping", + "this doesn't work", + "this doesnt work", + "this isn't working", + "this isnt working", + "still doesn't work", + "still doesnt work", + "still not working", + "still isn't working", + "still isnt working", + "waste of time", + "wasting my time", + "this is ridiculous", + "this is absurd", + "this is insane", + "this is stupid", + "this is dumb", + "this sucks", + "this is frustrating", + "not good enough", + "why can't you", + "why cant you", + "same issue", + "did that already", + "done that already", + "tried that already", + "already tried that", + "i've done that", + "ive done that", + "i've tried that", + "ive tried that", + "i'm disappointed", + "im disappointed", + "disappointed with you", + "disappointed in you", + "useless bot", + "dumb bot", + "stupid bot", +]; + +const AGENT_DIRECTED_PROFANITY_PATTERN_TEXTS: &[&str] = &[ + "this is bullshit", + "what bullshit", + "such bullshit", + "total bullshit", + "complete bullshit", + "this is crap", + "what crap", + "this is shit", + "what the hell is wrong with you", + "what the fuck is wrong with you", + "you're fucking useless", + "youre fucking useless", + "you are fucking useless", + "fucking useless", + "this bot is shit", + "this bot is crap", + "damn bot", + "fucking bot", + "stupid fucking", + "are you fucking kidding", + "wtf is wrong with you", + "wtf is this", + "ffs just", + "for fucks sake", + "for fuck's sake", + "what the f**k", + "what the f*ck", + "what the f***", + "that's bullsh*t", + "thats bullsh*t", + "that's bull***t", + "thats 
bull***t", + "that's bs", + "thats bs", + "this is bullsh*t", + "this is bull***t", + "this is bs", +]; + +fn escalation_patterns() -> &'static Vec { + static PATS: OnceLock> = OnceLock::new(); + PATS.get_or_init(|| normalize_patterns(ESCALATION_PATTERN_TEXTS)) +} + +fn quit_patterns() -> &'static Vec { + static PATS: OnceLock> = OnceLock::new(); + PATS.get_or_init(|| normalize_patterns(QUIT_PATTERN_TEXTS)) +} + +fn negative_stance_patterns() -> &'static Vec { + static PATS: OnceLock> = OnceLock::new(); + PATS.get_or_init(|| normalize_patterns(NEGATIVE_STANCE_PATTERN_TEXTS)) +} + +fn profanity_patterns() -> &'static Vec { + static PATS: OnceLock> = OnceLock::new(); + PATS.get_or_init(|| normalize_patterns(AGENT_DIRECTED_PROFANITY_PATTERN_TEXTS)) +} + +fn re_consecutive_q() -> &'static Regex { + static R: OnceLock = OnceLock::new(); + R.get_or_init(|| Regex::new(r"\?{2,}").unwrap()) +} +fn re_consecutive_e() -> &'static Regex { + static R: OnceLock = OnceLock::new(); + R.get_or_init(|| Regex::new(r"!{2,}").unwrap()) +} +fn re_mixed_punct() -> &'static Regex { + static R: OnceLock = OnceLock::new(); + R.get_or_init(|| Regex::new(r"[?!]{3,}").unwrap()) +} + +pub fn analyze_disengagement( + normalized_messages: &[(usize, &str, NormalizedMessage)], + char_ngram_threshold: f32, + token_cosine_threshold: f32, +) -> SignalGroup { + let mut group = SignalGroup::new("disengagement"); + + for (idx, role, norm_msg) in normalized_messages { + if *role != "human" { + continue; + } + + let text = &norm_msg.raw; + + // All-caps shouting check. 
+ let alpha_chars: String = text.chars().filter(|c| c.is_alphabetic()).collect(); + if alpha_chars.chars().count() >= 10 { + let upper_count = alpha_chars.chars().filter(|c| c.is_uppercase()).count(); + let upper_ratio = upper_count as f32 / alpha_chars.chars().count() as f32; + if upper_ratio >= 0.8 { + let snippet: String = text.chars().take(50).collect(); + group.add_signal( + SignalInstance::new(SignalType::DisengagementNegativeStance, *idx, snippet) + .with_metadata(json!({ + "indicator_type": "all_caps", + "upper_ratio": upper_ratio, + })), + ); + } + } + + // Excessive consecutive punctuation. + let starts_with_positive = starts_with_prefix(text, POSITIVE_PREFIXES); + let cq = re_consecutive_q().find_iter(text).count(); + let ce = re_consecutive_e().find_iter(text).count(); + let mixed = re_mixed_punct().find_iter(text).count(); + if !starts_with_positive && (cq >= 1 || ce >= 1 || mixed >= 1) { + let snippet: String = text.chars().take(50).collect(); + group.add_signal( + SignalInstance::new(SignalType::DisengagementNegativeStance, *idx, snippet) + .with_metadata(json!({ + "indicator_type": "excessive_punctuation", + "consecutive_questions": cq, + "consecutive_exclamations": ce, + "mixed_punctuation": mixed, + })), + ); + } + + // Escalation patterns. + let mut found_escalation = false; + for pattern in escalation_patterns() { + if norm_msg.matches_normalized_pattern( + pattern, + char_ngram_threshold, + token_cosine_threshold, + ) { + group.add_signal( + SignalInstance::new( + SignalType::DisengagementEscalation, + *idx, + pattern.raw.clone(), + ) + .with_metadata(json!({"pattern_type": "escalation"})), + ); + found_escalation = true; + break; + } + } + + // Quit patterns (independent of escalation). 
+ for pattern in quit_patterns() { + if norm_msg.matches_normalized_pattern( + pattern, + char_ngram_threshold, + token_cosine_threshold, + ) { + group.add_signal( + SignalInstance::new(SignalType::DisengagementQuit, *idx, pattern.raw.clone()) + .with_metadata(json!({"pattern_type": "quit"})), + ); + break; + } + } + + // Profanity (more specific) before generic negative stance. + let mut found_profanity = false; + for pattern in profanity_patterns() { + if norm_msg.matches_normalized_pattern( + pattern, + char_ngram_threshold, + token_cosine_threshold, + ) { + group.add_signal( + SignalInstance::new( + SignalType::DisengagementNegativeStance, + *idx, + pattern.raw.clone(), + ) + .with_metadata(json!({ + "indicator_type": "profanity", + "pattern": pattern.raw, + })), + ); + found_profanity = true; + break; + } + } + + if !found_escalation && !found_profanity { + for pattern in negative_stance_patterns() { + if norm_msg.matches_normalized_pattern( + pattern, + char_ngram_threshold, + token_cosine_threshold, + ) { + group.add_signal( + SignalInstance::new( + SignalType::DisengagementNegativeStance, + *idx, + pattern.raw.clone(), + ) + .with_metadata(json!({ + "indicator_type": "complaint", + "pattern": pattern.raw, + })), + ); + break; + } + } + } + } + + group +} + +#[cfg(test)] +mod tests { + use super::*; + + fn nm(s: &str) -> NormalizedMessage { + NormalizedMessage::from_text(s, 2000) + } + + #[test] + fn detects_human_escalation_request() { + let msgs = vec![( + 0usize, + "human", + nm("This is taking forever, get me a human"), + )]; + let g = analyze_disengagement(&msgs, 0.65, 0.6); + assert!(g + .signals + .iter() + .any(|s| matches!(s.signal_type, SignalType::DisengagementEscalation))); + } + + #[test] + fn detects_quit_intent() { + let msgs = vec![(0usize, "human", nm("Forget it, I give up"))]; + let g = analyze_disengagement(&msgs, 0.65, 0.6); + assert!(g + .signals + .iter() + .any(|s| matches!(s.signal_type, SignalType::DisengagementQuit))); + } + + 
#[test] + fn detects_negative_stance_complaint() { + let msgs = vec![(0usize, "human", nm("This is useless"))]; + let g = analyze_disengagement(&msgs, 0.65, 0.6); + assert!(g + .signals + .iter() + .any(|s| matches!(s.signal_type, SignalType::DisengagementNegativeStance))); + } + + #[test] + fn detects_excessive_punctuation_as_negative_stance() { + let msgs = vec![(0usize, "human", nm("WHY isn't this working???"))]; + let g = analyze_disengagement(&msgs, 0.65, 0.6); + assert!(g + .signals + .iter() + .any(|s| matches!(s.signal_type, SignalType::DisengagementNegativeStance))); + } + + #[test] + fn positive_excitement_is_not_disengagement() { + let msgs = vec![(0usize, "human", nm("Yes!! That's perfect!!!"))]; + let g = analyze_disengagement(&msgs, 0.65, 0.6); + assert!(g + .signals + .iter() + .all(|s| !matches!(s.signal_type, SignalType::DisengagementNegativeStance))); + } +} diff --git a/crates/brightstaff/src/signals/interaction/misalignment.rs b/crates/brightstaff/src/signals/interaction/misalignment.rs new file mode 100644 index 00000000..3dcf3ddd --- /dev/null +++ b/crates/brightstaff/src/signals/interaction/misalignment.rs @@ -0,0 +1,338 @@ +//! Misalignment signals: corrections, rephrases, clarifications. +//! +//! Direct port of `signals/interaction/misalignment.py`. 
+
+use std::sync::OnceLock;
+
+use serde_json::json;
+
+use super::constants::{stopwords, CONFIRMATION_PREFIXES};
+use crate::signals::schemas::{SignalGroup, SignalInstance, SignalType};
+use crate::signals::text_processing::{normalize_patterns, NormalizedMessage, NormalizedPattern};
+
+/// Explicit correction phrases ("no, i meant ...").
+const CORRECTION_PATTERN_TEXTS: &[&str] = &[
+    "no, i meant", "no i meant", "no, i said", "no i said",
+    "no, i asked", "no i asked", "nah, i meant", "nope, i meant",
+    "not what i said", "not what i asked",
+    "that's not what i said", "that's not what i asked",
+    "that's not what i meant", "thats not what i said",
+    "thats not what i asked", "thats not what i meant",
+    "that's not what you", "no that's not what i", "no, that's not what i",
+    "you're not quite right", "youre not quite right",
+    "you're not exactly right", "youre not exactly right",
+    "you're wrong about", "youre wrong about",
+    "i just said", "i already said", "i already told you",
+];
+
+/// Explicit rephrase markers.
+const REPHRASE_PATTERN_TEXTS: &[&str] = &[
+    "let me rephrase", "let me explain again",
+    "what i'm trying to say", "what i'm saying is", "in other words",
+];
+
+/// Requests for clarification / statements of confusion.
+const CLARIFICATION_PATTERN_TEXTS: &[&str] = &[
+    "i don't understand", "don't understand", "not understanding",
+    "can't understand", "don't get it", "don't follow", "i'm confused",
+    "so confused", "makes no sense", "doesn't make sense",
+    "not making sense", "what do you mean", "what does that mean",
+    "what are you saying", "i'm lost", "totally lost", "lost me",
+    "no clue what you", "no idea what you", "no clue what that",
+    "no idea what that", "come again", "say that again", "repeat that",
+    "trouble following", "hard to follow", "can't follow",
+];
+
+/// Cached normalized correction patterns.
+fn correction_patterns() -> &'static Vec<NormalizedPattern> {
+    // NOTE(review): OnceLock type parameters in these helpers were stripped
+    // during extraction; restored from the return types.
+    static PATS: OnceLock<Vec<NormalizedPattern>> = OnceLock::new();
+    PATS.get_or_init(|| normalize_patterns(CORRECTION_PATTERN_TEXTS))
+}
+
+/// Cached normalized rephrase patterns.
+fn rephrase_patterns() -> &'static Vec<NormalizedPattern> {
+    static PATS: OnceLock<Vec<NormalizedPattern>> = OnceLock::new();
+    PATS.get_or_init(|| normalize_patterns(REPHRASE_PATTERN_TEXTS))
+}
+
+/// Cached normalized clarification patterns.
+fn clarification_patterns() -> &'static Vec<NormalizedPattern> {
+    static PATS: OnceLock<Vec<NormalizedPattern>> = OnceLock::new();
+    PATS.get_or_init(|| normalize_patterns(CLARIFICATION_PATTERN_TEXTS))
+}
+
+/// True if the message reads as a plain confirmation ("yes", "correct", ...).
+fn is_confirmation_message(text: &str) -> bool {
+    let lowered = text.to_lowercase();
+    let trimmed = lowered.trim();
+    CONFIRMATION_PREFIXES.iter().any(|p| trimmed.starts_with(p))
+}
+
+/// Detect whether two user messages appear to be rephrases of each other.
+pub fn is_similar_rephrase(
+    norm_msg1: &NormalizedMessage,
+    norm_msg2: &NormalizedMessage,
+    overlap_threshold: f32,
+    min_meaningful_tokens: usize,
+    max_new_content_ratio: f32,
+) -> bool {
+    // Very short messages and plain confirmations are not rephrases.
+    if norm_msg1.tokens.len() < 3 || norm_msg2.tokens.len() < 3 {
+        return false;
+    }
+    if is_confirmation_message(&norm_msg1.raw) {
+        return false;
+    }
+
+    // Compare only content words (stopwords removed).
+    let stops = stopwords();
+    let tokens1: std::collections::HashSet<&str> = norm_msg1
+        .tokens
+        .iter()
+        .filter(|t| !stops.contains(t.as_str()))
+        .map(|s| s.as_str())
+        .collect();
+    let tokens2: std::collections::HashSet<&str> = norm_msg2
+        .tokens
+        .iter()
+        .filter(|t| !stops.contains(t.as_str()))
+        .map(|s| s.as_str())
+        .collect();
+
+    if tokens1.len() < min_meaningful_tokens || tokens2.len() < min_meaningful_tokens {
+        return false;
+    }
+
+    // A rephrase mostly reuses the earlier message's vocabulary.
+    let new_tokens: std::collections::HashSet<&&str> = tokens1.difference(&tokens2).collect();
+    let new_content_ratio = if tokens1.is_empty() {
+        0.0
+    } else {
+        new_tokens.len() as f32 / tokens1.len() as f32
+    };
+    if new_content_ratio > max_new_content_ratio {
+        return false;
+    }
+
+    let intersection = tokens1.intersection(&tokens2).count();
+    let min_size = tokens1.len().min(tokens2.len());
+    if min_size == 0 {
+        return false;
+    }
+    let overlap_ratio = intersection as f32 / min_size as f32;
+    overlap_ratio >= overlap_threshold
+}
+
+/// Analyze user messages for misalignment signals.
+pub fn analyze_misalignment(
+    normalized_messages: &[(usize, &str, NormalizedMessage)],
+    char_ngram_threshold: f32,
+    token_cosine_threshold: f32,
+) -> SignalGroup {
+    let mut group = SignalGroup::new("misalignment");
+
+    // Most recent prior user turn, for the semantic-rephrase comparison.
+    // NOTE(review): the Option's type parameter was stripped in extraction;
+    // restored as `Option<usize>` from its uses (index arithmetic below).
+    let mut prev_user_idx: Option<usize> = None;
+    let mut prev_user_msg: Option<&NormalizedMessage> = None;
+
+    for (idx, role, norm_msg) in normalized_messages {
+        if *role != "human" {
+            continue;
+        }
+
+        let mut found_in_turn = false;
+
+        // Corrections take precedence over rephrases and clarifications;
+        // at most one signal is emitted per user turn.
+        for pattern in correction_patterns() {
+            if norm_msg.matches_normalized_pattern(
+                pattern,
+                char_ngram_threshold,
+                token_cosine_threshold,
+            ) {
+                group.add_signal(
+                    SignalInstance::new(
+                        SignalType::MisalignmentCorrection,
+                        *idx,
+                        pattern.raw.clone(),
+                    )
+                    .with_metadata(json!({"pattern_type": "correction"})),
+                );
+                found_in_turn = true;
+                break;
+            }
+        }
+
+        if found_in_turn {
+            prev_user_idx = Some(*idx);
+            prev_user_msg = Some(norm_msg);
+            continue;
+        }
+
+        for pattern in rephrase_patterns() {
+            if norm_msg.matches_normalized_pattern(
+                pattern,
+                char_ngram_threshold,
+                token_cosine_threshold,
+            ) {
+                group.add_signal(
+                    SignalInstance::new(
+                        SignalType::MisalignmentRephrase,
+                        *idx,
+                        pattern.raw.clone(),
+                    )
+                    .with_metadata(json!({"pattern_type": "rephrase"})),
+                );
+                found_in_turn = true;
+                break;
+            }
+        }
+
+        if found_in_turn {
+            prev_user_idx = Some(*idx);
+            prev_user_msg = Some(norm_msg);
+            continue;
+        }
+
+        for pattern in clarification_patterns() {
+            if norm_msg.matches_normalized_pattern(
+                pattern,
+                char_ngram_threshold,
+                token_cosine_threshold,
+            ) {
+                group.add_signal(
+                    SignalInstance::new(
+                        SignalType::MisalignmentClarification,
+                        *idx,
+                        pattern.raw.clone(),
+                    )
+                    .with_metadata(json!({"pattern_type": "clarification"})),
+                );
+                found_in_turn = true;
+                break;
+            }
+        }
+
+        if found_in_turn {
+            prev_user_idx = Some(*idx);
+            prev_user_msg = Some(norm_msg);
+            continue;
+        }
+
+        // Semantic rephrase vs the previous user message (recent only).
+ if let (Some(prev_idx), Some(prev_msg)) = (prev_user_idx, prev_user_msg) { + let turns_between = idx.saturating_sub(prev_idx); + if turns_between <= 3 && is_similar_rephrase(norm_msg, prev_msg, 0.75, 4, 0.5) { + group.add_signal( + SignalInstance::new( + SignalType::MisalignmentRephrase, + *idx, + "[similar rephrase detected]", + ) + .with_confidence(0.8) + .with_metadata(json!({ + "pattern_type": "semantic_rephrase", + "compared_to": prev_idx, + })), + ); + } + } + + prev_user_idx = Some(*idx); + prev_user_msg = Some(norm_msg); + } + + group +} + +#[cfg(test)] +mod tests { + use super::*; + + fn nm(s: &str) -> NormalizedMessage { + NormalizedMessage::from_text(s, 2000) + } + + fn make(items: &[(&'static str, &str)]) -> Vec<(usize, &'static str, NormalizedMessage)> { + items + .iter() + .enumerate() + .map(|(i, (role, text))| (i, *role, nm(text))) + .collect() + } + + #[test] + fn detects_explicit_correction() { + let msgs = make(&[ + ("human", "Show me my orders"), + ("gpt", "Sure, here are your invoices"), + ("human", "No, I meant my recent orders"), + ]); + let g = analyze_misalignment(&msgs, 0.65, 0.6); + assert!(g + .signals + .iter() + .any(|s| matches!(s.signal_type, SignalType::MisalignmentCorrection))); + } + + #[test] + fn detects_rephrase_marker() { + let msgs = make(&[ + ("human", "Show me X"), + ("gpt", "Sure"), + ("human", "Let me rephrase: I want X grouped by date"), + ]); + let g = analyze_misalignment(&msgs, 0.65, 0.6); + assert!(g + .signals + .iter() + .any(|s| matches!(s.signal_type, SignalType::MisalignmentRephrase))); + } + + #[test] + fn detects_clarification_request() { + let msgs = make(&[ + ("human", "Run the report"), + ("gpt", "Foobar quux baz."), + ("human", "I don't understand what you mean"), + ]); + let g = analyze_misalignment(&msgs, 0.65, 0.6); + assert!(g + .signals + .iter() + .any(|s| matches!(s.signal_type, SignalType::MisalignmentClarification))); + } + + #[test] + fn confirmation_is_not_a_rephrase() { + let m1 = nm("Yes, 
that's correct, please proceed with the order"); + let m2 = nm("please proceed with the order for the same product"); + assert!(!is_similar_rephrase(&m1, &m2, 0.75, 4, 0.5)); + } +} diff --git a/crates/brightstaff/src/signals/interaction/mod.rs b/crates/brightstaff/src/signals/interaction/mod.rs new file mode 100644 index 00000000..b60a6748 --- /dev/null +++ b/crates/brightstaff/src/signals/interaction/mod.rs @@ -0,0 +1,10 @@ +//! Interaction signals: misalignment, stagnation, disengagement, satisfaction. +//! +//! These signals capture how the dialogue itself unfolds (semantic alignment, +//! progress, engagement, closure) independent of tool execution outcomes. + +pub mod constants; +pub mod disengagement; +pub mod misalignment; +pub mod satisfaction; +pub mod stagnation; diff --git a/crates/brightstaff/src/signals/interaction/satisfaction.rs b/crates/brightstaff/src/signals/interaction/satisfaction.rs new file mode 100644 index 00000000..ad719960 --- /dev/null +++ b/crates/brightstaff/src/signals/interaction/satisfaction.rs @@ -0,0 +1,177 @@ +//! Satisfaction signals: gratitude, confirmation, success. +//! +//! Direct port of `signals/interaction/satisfaction.py`. 
+
+use std::sync::OnceLock;
+
+use serde_json::json;
+
+use crate::signals::schemas::{SignalGroup, SignalInstance, SignalType};
+use crate::signals::text_processing::{normalize_patterns, NormalizedMessage, NormalizedPattern};
+
+/// Gratitude phrases ("appreciate it", "exactly what i needed", ...).
+const GRATITUDE_PATTERN_TEXTS: &[&str] = &[
+    "that's helpful", "that helps", "this helps", "appreciate it",
+    "appreciate that", "that's perfect", "exactly what i needed",
+    "just what i needed", "you're the best", "you rock",
+    "you're awesome", "you're amazing", "you're great",
+];
+
+/// Positive confirmation phrases ("that works", "love it", ...).
+const CONFIRMATION_PATTERN_TEXTS: &[&str] = &[
+    "that works", "this works", "that's great", "that's amazing",
+    "this is great", "that's awesome", "love it", "love this", "love that",
+];
+
+/// Explicit success reports ("it worked", ...).
+const SUCCESS_PATTERN_TEXTS: &[&str] = &[
+    "it worked", "that worked", "this worked",
+    "it's working", "that's working", "this is working",
+];
+
+/// Cached normalized gratitude patterns.
+fn gratitude_patterns() -> &'static Vec<NormalizedPattern> {
+    // NOTE(review): OnceLock type parameters in these helpers were stripped
+    // during extraction; restored from the return types.
+    static PATS: OnceLock<Vec<NormalizedPattern>> = OnceLock::new();
+    PATS.get_or_init(|| normalize_patterns(GRATITUDE_PATTERN_TEXTS))
+}
+
+/// Cached normalized confirmation patterns.
+fn confirmation_patterns() -> &'static Vec<NormalizedPattern> {
+    static PATS: OnceLock<Vec<NormalizedPattern>> = OnceLock::new();
+    PATS.get_or_init(|| normalize_patterns(CONFIRMATION_PATTERN_TEXTS))
+}
+
+/// Cached normalized success patterns.
+fn success_patterns() -> &'static Vec<NormalizedPattern> {
+    static PATS: OnceLock<Vec<NormalizedPattern>> = OnceLock::new();
+    PATS.get_or_init(|| normalize_patterns(SUCCESS_PATTERN_TEXTS))
+}
+
+/// Scan user turns for satisfaction markers and collect them into a
+/// "satisfaction" signal group. At most one signal per user turn, in
+/// priority order: gratitude, then confirmation, then success.
+pub fn analyze_satisfaction(
+    normalized_messages: &[(usize, &str, NormalizedMessage)],
+    char_ngram_threshold: f32,
+    token_cosine_threshold: f32,
+) -> SignalGroup {
+    let mut group = SignalGroup::new("satisfaction");
+
+    for (idx, role, norm_msg) in normalized_messages {
+        if *role != "human" {
+            continue;
+        }
+
+        let mut found = false;
+
+        for pattern in gratitude_patterns() {
+            if norm_msg.matches_normalized_pattern(
+                pattern,
+                char_ngram_threshold,
+                token_cosine_threshold,
+            ) {
+                group.add_signal(
+                    SignalInstance::new(
+                        SignalType::SatisfactionGratitude,
+                        *idx,
+                        pattern.raw.clone(),
+                    )
.with_metadata(json!({"pattern_type": "gratitude"})), + ); + found = true; + break; + } + } + if found { + continue; + } + + for pattern in confirmation_patterns() { + if norm_msg.matches_normalized_pattern( + pattern, + char_ngram_threshold, + token_cosine_threshold, + ) { + group.add_signal( + SignalInstance::new( + SignalType::SatisfactionConfirmation, + *idx, + pattern.raw.clone(), + ) + .with_metadata(json!({"pattern_type": "confirmation"})), + ); + found = true; + break; + } + } + if found { + continue; + } + + for pattern in success_patterns() { + if norm_msg.matches_normalized_pattern( + pattern, + char_ngram_threshold, + token_cosine_threshold, + ) { + group.add_signal( + SignalInstance::new(SignalType::SatisfactionSuccess, *idx, pattern.raw.clone()) + .with_metadata(json!({"pattern_type": "success"})), + ); + break; + } + } + } + + group +} + +#[cfg(test)] +mod tests { + use super::*; + + fn nm(s: &str) -> NormalizedMessage { + NormalizedMessage::from_text(s, 2000) + } + + #[test] + fn detects_gratitude() { + let msgs = vec![(0usize, "human", nm("That's perfect, appreciate it!"))]; + let g = analyze_satisfaction(&msgs, 0.65, 0.6); + assert!(g + .signals + .iter() + .any(|s| matches!(s.signal_type, SignalType::SatisfactionGratitude))); + } + + #[test] + fn detects_confirmation() { + let msgs = vec![(0usize, "human", nm("That works for me, thanks"))]; + let g = analyze_satisfaction(&msgs, 0.65, 0.6); + assert!(g + .signals + .iter() + .any(|s| matches!(s.signal_type, SignalType::SatisfactionConfirmation))); + } + + #[test] + fn detects_success() { + let msgs = vec![(0usize, "human", nm("Great, it worked!"))]; + let g = analyze_satisfaction(&msgs, 0.65, 0.6); + assert!(g + .signals + .iter() + .any(|s| matches!(s.signal_type, SignalType::SatisfactionSuccess))); + } +} diff --git a/crates/brightstaff/src/signals/interaction/stagnation.rs b/crates/brightstaff/src/signals/interaction/stagnation.rs new file mode 100644 index 00000000..d7d03c80 --- /dev/null +++ 
b/crates/brightstaff/src/signals/interaction/stagnation.rs @@ -0,0 +1,241 @@ +//! Stagnation signals: dragging (turn-count efficiency) and repetition. +//! +//! Direct port of `signals/interaction/stagnation.py`. + +use serde_json::json; + +use super::constants::{starts_with_prefix, POSITIVE_PREFIXES}; +use crate::signals::schemas::{SignalGroup, SignalInstance, SignalType, TurnMetrics}; +use crate::signals::text_processing::NormalizedMessage; + +/// Adapter row used by stagnation::dragging detector. Mirrors the ShareGPT +/// `{"from": role, "value": text}` shape used in the Python reference. +pub struct ShareGptMsg<'a> { + pub from: &'a str, +} + +pub fn analyze_dragging( + messages: &[ShareGptMsg<'_>], + baseline_turns: usize, + efficiency_threshold: f32, +) -> (SignalGroup, TurnMetrics) { + let mut group = SignalGroup::new("stagnation"); + + let mut user_turns: usize = 0; + let mut assistant_turns: usize = 0; + for m in messages { + match m.from { + "human" => user_turns += 1, + "gpt" => assistant_turns += 1, + _ => {} + } + } + + let total_turns = user_turns; + let efficiency_score: f32 = if total_turns == 0 || total_turns <= baseline_turns { + 1.0 + } else { + let excess = (total_turns - baseline_turns) as f32; + 1.0 / (1.0 + excess * 0.25) + }; + + let is_dragging = efficiency_score < efficiency_threshold; + let metrics = TurnMetrics { + total_turns, + user_turns, + assistant_turns, + is_dragging, + efficiency_score, + }; + + if is_dragging { + let last_idx = messages.len().saturating_sub(1); + group.add_signal( + SignalInstance::new( + SignalType::StagnationDragging, + last_idx, + format!( + "Conversation dragging: {} turns (efficiency: {:.2})", + total_turns, efficiency_score + ), + ) + .with_confidence(1.0 - efficiency_score) + .with_metadata(json!({ + "total_turns": total_turns, + "efficiency_score": efficiency_score, + "baseline_turns": baseline_turns, + })), + ); + } + + (group, metrics) +} + +pub fn analyze_repetition( + normalized_messages: &[(usize, 
&str, NormalizedMessage)], + lookback: usize, + exact_threshold: f32, + near_duplicate_threshold: f32, +) -> SignalGroup { + let mut group = SignalGroup::new("stagnation"); + + // We keep references into `normalized_messages`. Since `normalized_messages` + // is borrowed for the whole function, this avoids cloning. + let mut prev_human: Vec<(usize, &NormalizedMessage)> = Vec::new(); + let mut prev_gpt: Vec<(usize, &NormalizedMessage)> = Vec::new(); + + for (idx, role, norm_msg) in normalized_messages { + if *role != "human" && *role != "gpt" { + continue; + } + + // Skip human positive-prefix messages; they're naturally repetitive. + if *role == "human" && starts_with_prefix(&norm_msg.raw, POSITIVE_PREFIXES) { + prev_human.push((*idx, norm_msg)); + continue; + } + + if norm_msg.tokens.len() < 5 { + if *role == "human" { + prev_human.push((*idx, norm_msg)); + } else { + prev_gpt.push((*idx, norm_msg)); + } + continue; + } + + let prev = if *role == "human" { + &prev_human + } else { + &prev_gpt + }; + let start = prev.len().saturating_sub(lookback); + let mut matched = false; + for (prev_idx, prev_msg) in &prev[start..] 
{ + if prev_msg.tokens.len() < 5 { + continue; + } + let similarity = norm_msg.ngram_similarity_with_message(prev_msg); + if similarity >= exact_threshold { + group.add_signal( + SignalInstance::new( + SignalType::StagnationRepetition, + *idx, + format!("Exact repetition with message {}", prev_idx), + ) + .with_confidence(similarity) + .with_metadata(json!({ + "repetition_type": "exact", + "compared_to": prev_idx, + "similarity": similarity, + "role": role, + })), + ); + matched = true; + break; + } else if similarity >= near_duplicate_threshold { + group.add_signal( + SignalInstance::new( + SignalType::StagnationRepetition, + *idx, + format!("Near-duplicate with message {}", prev_idx), + ) + .with_confidence(similarity) + .with_metadata(json!({ + "repetition_type": "near_duplicate", + "compared_to": prev_idx, + "similarity": similarity, + "role": role, + })), + ); + matched = true; + break; + } + } + let _ = matched; + + if *role == "human" { + prev_human.push((*idx, norm_msg)); + } else { + prev_gpt.push((*idx, norm_msg)); + } + } + + group +} + +/// Combined stagnation analyzer: dragging + repetition. 
+pub fn analyze_stagnation( + messages: &[ShareGptMsg<'_>], + normalized_messages: &[(usize, &str, NormalizedMessage)], + baseline_turns: usize, +) -> (SignalGroup, TurnMetrics) { + let (dragging_group, metrics) = analyze_dragging(messages, baseline_turns, 0.5); + let repetition_group = analyze_repetition(normalized_messages, 2, 0.95, 0.85); + + let mut combined = SignalGroup::new("stagnation"); + for s in dragging_group.signals.iter().cloned() { + combined.add_signal(s); + } + for s in repetition_group.signals.iter().cloned() { + combined.add_signal(s); + } + (combined, metrics) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn nm(s: &str) -> NormalizedMessage { + NormalizedMessage::from_text(s, 2000) + } + + #[test] + fn dragging_after_many_user_turns() { + let msgs: Vec<_> = (0..15) + .flat_map(|_| [ShareGptMsg { from: "human" }, ShareGptMsg { from: "gpt" }]) + .collect(); + let (g, m) = analyze_dragging(&msgs, 5, 0.5); + assert!(m.is_dragging); + assert!(g + .signals + .iter() + .any(|s| matches!(s.signal_type, SignalType::StagnationDragging))); + } + + #[test] + fn no_dragging_below_baseline() { + let msgs = vec![ + ShareGptMsg { from: "human" }, + ShareGptMsg { from: "gpt" }, + ShareGptMsg { from: "human" }, + ShareGptMsg { from: "gpt" }, + ]; + let (g, m) = analyze_dragging(&msgs, 5, 0.5); + assert!(!m.is_dragging); + assert!(g.signals.is_empty()); + } + + #[test] + fn detects_exact_repetition_in_user_messages() { + let n = vec![ + ( + 0usize, + "human", + nm("This widget is broken and needs repair right now"), + ), + (1, "gpt", nm("Sorry to hear that. 
Let me look into it.")), + ( + 2, + "human", + nm("This widget is broken and needs repair right now"), + ), + ]; + let g = analyze_repetition(&n, 2, 0.95, 0.85); + assert!(g + .signals + .iter() + .any(|s| matches!(s.signal_type, SignalType::StagnationRepetition))); + } +} diff --git a/crates/brightstaff/src/signals/mod.rs b/crates/brightstaff/src/signals/mod.rs index 83db943e..d96d3bf0 100644 --- a/crates/brightstaff/src/signals/mod.rs +++ b/crates/brightstaff/src/signals/mod.rs @@ -1,3 +1,26 @@ -mod analyzer; +//! Plano signals: behavioral quality indicators for agent interactions. +//! +//! This is a Rust port of the paper-aligned Python reference implementation at +//! `https://github.com/katanemo/signals` (or `/Users/shashmi/repos/signals`). +//! +//! Three layers of signals are detected from a conversation transcript: +//! +//! - **Interaction**: misalignment, stagnation, disengagement, satisfaction +//! - **Execution**: failure, loops +//! - **Environment**: exhaustion +//! +//! See `SignalType` for the full hierarchy. -pub use analyzer::*; +pub mod analyzer; +pub mod environment; +pub mod execution; +pub mod interaction; +pub mod otel; +pub mod schemas; +pub mod text_processing; + +pub use analyzer::{SignalAnalyzer, FLAG_MARKER}; +pub use schemas::{ + EnvironmentSignals, ExecutionSignals, InteractionQuality, InteractionSignals, SignalGroup, + SignalInstance, SignalLayer, SignalReport, SignalType, TurnMetrics, +}; diff --git a/crates/brightstaff/src/signals/otel.rs b/crates/brightstaff/src/signals/otel.rs new file mode 100644 index 00000000..deb3c1b5 --- /dev/null +++ b/crates/brightstaff/src/signals/otel.rs @@ -0,0 +1,241 @@ +//! Helpers for emitting `SignalReport` data to OpenTelemetry spans. +//! +//! Two sets of attributes are emitted: +//! +//! - **Legacy** keys under `signals.*` (e.g. `signals.frustration.count`), +//! computed from the new layered counts. Preserved for one release for +//! backward compatibility with existing dashboards. +//! 
- **New** layered keys (e.g. `signals.interaction.misalignment.count`), +//! one set of `count`/`severity` attributes per category, plus per-instance +//! span events named `signal.`. + +use opentelemetry::trace::SpanRef; +use opentelemetry::KeyValue; + +use crate::signals::schemas::{SignalGroup, SignalReport, SignalType}; + +/// Emit both legacy and layered OTel attributes/events for a `SignalReport`. +/// +/// Returns `true` if any "concerning" signal was found, mirroring the previous +/// behavior used to flag the span operation name. +pub fn emit_signals_to_span(span: &SpanRef<'_>, report: &SignalReport) -> bool { + emit_overall(span, report); + emit_layered_attributes(span, report); + emit_legacy_attributes(span, report); + emit_signal_events(span, report); + + is_concerning(report) +} + +fn emit_overall(span: &SpanRef<'_>, report: &SignalReport) { + span.set_attribute(KeyValue::new( + "signals.quality", + report.overall_quality.as_str().to_string(), + )); + span.set_attribute(KeyValue::new( + "signals.quality_score", + report.quality_score as f64, + )); + span.set_attribute(KeyValue::new( + "signals.turn_count", + report.turn_metrics.total_turns as i64, + )); + span.set_attribute(KeyValue::new( + "signals.efficiency_score", + report.turn_metrics.efficiency_score as f64, + )); +} + +fn emit_group(span: &SpanRef<'_>, prefix: &str, group: &SignalGroup) { + if group.count == 0 { + return; + } + span.set_attribute(KeyValue::new( + format!("{}.count", prefix), + group.count as i64, + )); + span.set_attribute(KeyValue::new( + format!("{}.severity", prefix), + group.severity as i64, + )); +} + +fn emit_layered_attributes(span: &SpanRef<'_>, report: &SignalReport) { + emit_group( + span, + "signals.interaction.misalignment", + &report.interaction.misalignment, + ); + emit_group( + span, + "signals.interaction.stagnation", + &report.interaction.stagnation, + ); + emit_group( + span, + "signals.interaction.disengagement", + &report.interaction.disengagement, + ); + 
emit_group( + span, + "signals.interaction.satisfaction", + &report.interaction.satisfaction, + ); + emit_group(span, "signals.execution.failure", &report.execution.failure); + emit_group(span, "signals.execution.loops", &report.execution.loops); + emit_group( + span, + "signals.environment.exhaustion", + &report.environment.exhaustion, + ); +} + +fn count_of(report: &SignalReport, t: SignalType) -> usize { + report.iter_signals().filter(|s| s.signal_type == t).count() +} + +/// Emit the legacy attribute keys consumed by existing dashboards. These are +/// derived from the new `SignalReport` so no detector contract is broken. +fn emit_legacy_attributes(span: &SpanRef<'_>, report: &SignalReport) { + use crate::tracing::signals as legacy; + + // signals.follow_up.repair.{count,ratio} - misalignment proxies repairs. + let repair_count = report.interaction.misalignment.count; + let user_turns = report.turn_metrics.user_turns.max(1) as f32; + if repair_count > 0 { + span.set_attribute(KeyValue::new(legacy::REPAIR_COUNT, repair_count as i64)); + let ratio = repair_count as f32 / user_turns; + span.set_attribute(KeyValue::new(legacy::REPAIR_RATIO, format!("{:.3}", ratio))); + } + + // signals.frustration.{count,severity} - disengagement.negative_stance is + // the closest legacy analog of "frustration". + let frustration_count = count_of(report, SignalType::DisengagementNegativeStance); + if frustration_count > 0 { + span.set_attribute(KeyValue::new( + legacy::FRUSTRATION_COUNT, + frustration_count as i64, + )); + let severity = match frustration_count { + 0 => 0, + 1..=2 => 1, + 3..=4 => 2, + _ => 3, + }; + span.set_attribute(KeyValue::new(legacy::FRUSTRATION_SEVERITY, severity as i64)); + } + + // signals.repetition.count - stagnation (repetition + dragging). 
+ if report.interaction.stagnation.count > 0 { + span.set_attribute(KeyValue::new( + legacy::REPETITION_COUNT, + report.interaction.stagnation.count as i64, + )); + } + + // signals.escalation.requested - any escalation/quit signal. + let escalated = report.interaction.disengagement.signals.iter().any(|s| { + matches!( + s.signal_type, + SignalType::DisengagementEscalation | SignalType::DisengagementQuit + ) + }); + if escalated { + span.set_attribute(KeyValue::new(legacy::ESCALATION_REQUESTED, true)); + } + + // signals.positive_feedback.count - satisfaction signals. + if report.interaction.satisfaction.count > 0 { + span.set_attribute(KeyValue::new( + legacy::POSITIVE_FEEDBACK_COUNT, + report.interaction.satisfaction.count as i64, + )); + } +} + +fn emit_signal_events(span: &SpanRef<'_>, report: &SignalReport) { + for sig in report.iter_signals() { + let event_name = format!("signal.{}", sig.signal_type.as_str()); + let mut attrs: Vec = vec![ + KeyValue::new("signal.type", sig.signal_type.as_str().to_string()), + KeyValue::new("signal.message_index", sig.message_index as i64), + KeyValue::new("signal.confidence", sig.confidence as f64), + ]; + if !sig.snippet.is_empty() { + attrs.push(KeyValue::new("signal.snippet", sig.snippet.clone())); + } + if !sig.metadata.is_null() { + attrs.push(KeyValue::new("signal.metadata", sig.metadata.to_string())); + } + span.add_event(event_name, attrs); + } +} + +fn is_concerning(report: &SignalReport) -> bool { + use crate::signals::schemas::InteractionQuality; + if matches!( + report.overall_quality, + InteractionQuality::Poor | InteractionQuality::Severe + ) { + return true; + } + if report.interaction.disengagement.count > 0 { + return true; + } + if report.interaction.stagnation.count > 2 { + return true; + } + if report.execution.failure.count > 0 || report.execution.loops.count > 0 { + return true; + } + false +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::signals::schemas::{ + EnvironmentSignals, 
ExecutionSignals, InteractionQuality, InteractionSignals, SignalGroup, + SignalInstance, SignalReport, SignalType, TurnMetrics, + }; + + fn report_with_escalation() -> SignalReport { + let mut diseng = SignalGroup::new("disengagement"); + diseng.add_signal(SignalInstance::new( + SignalType::DisengagementEscalation, + 3, + "get me a human", + )); + SignalReport { + interaction: InteractionSignals { + disengagement: diseng, + ..InteractionSignals::default() + }, + execution: ExecutionSignals::default(), + environment: EnvironmentSignals::default(), + overall_quality: InteractionQuality::Severe, + quality_score: 0.0, + turn_metrics: TurnMetrics { + total_turns: 3, + user_turns: 2, + assistant_turns: 1, + is_dragging: false, + efficiency_score: 1.0, + }, + summary: String::new(), + } + } + + #[test] + fn is_concerning_flags_disengagement() { + let r = report_with_escalation(); + assert!(is_concerning(&r)); + } + + #[test] + fn count_of_returns_per_type_count() { + let r = report_with_escalation(); + assert_eq!(count_of(&r, SignalType::DisengagementEscalation), 1); + assert_eq!(count_of(&r, SignalType::DisengagementNegativeStance), 0); + } +} diff --git a/crates/brightstaff/src/signals/schemas.rs b/crates/brightstaff/src/signals/schemas.rs new file mode 100644 index 00000000..47ea0836 --- /dev/null +++ b/crates/brightstaff/src/signals/schemas.rs @@ -0,0 +1,431 @@ +//! Data shapes for the signal analyzer. +//! +//! Mirrors `signals/schemas.py` from the reference implementation. Where the +//! Python library exposes a `Dict[str, SignalGroup]` partitioned by category, +//! the Rust port uses strongly-typed sub-structs (`InteractionSignals`, +//! `ExecutionSignals`, `EnvironmentSignals`) for the same partitioning. + +use serde::{Deserialize, Serialize}; + +/// Hierarchical signal type. The 20 leaf variants mirror the paper taxonomy +/// and the Python reference's `SignalType` string enum. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub enum SignalType { + // Interaction > Misalignment + MisalignmentCorrection, + MisalignmentRephrase, + MisalignmentClarification, + + // Interaction > Stagnation + StagnationDragging, + StagnationRepetition, + + // Interaction > Disengagement + DisengagementEscalation, + DisengagementQuit, + DisengagementNegativeStance, + + // Interaction > Satisfaction + SatisfactionGratitude, + SatisfactionConfirmation, + SatisfactionSuccess, + + // Execution > Failure + ExecutionFailureInvalidArgs, + ExecutionFailureBadQuery, + ExecutionFailureToolNotFound, + ExecutionFailureAuthMisuse, + ExecutionFailureStateError, + + // Execution > Loops + ExecutionLoopsRetry, + ExecutionLoopsParameterDrift, + ExecutionLoopsOscillation, + + // Environment > Exhaustion + EnvironmentExhaustionApiError, + EnvironmentExhaustionTimeout, + EnvironmentExhaustionRateLimit, + EnvironmentExhaustionNetwork, + EnvironmentExhaustionMalformed, + EnvironmentExhaustionContextOverflow, +} + +impl SignalType { + /// Dotted hierarchical string identifier, e.g. + /// `"interaction.misalignment.correction"`. Matches the Python reference's + /// `SignalType` enum *value* strings byte-for-byte. 
+ pub fn as_str(&self) -> &'static str { + match self { + SignalType::MisalignmentCorrection => "interaction.misalignment.correction", + SignalType::MisalignmentRephrase => "interaction.misalignment.rephrase", + SignalType::MisalignmentClarification => "interaction.misalignment.clarification", + SignalType::StagnationDragging => "interaction.stagnation.dragging", + SignalType::StagnationRepetition => "interaction.stagnation.repetition", + SignalType::DisengagementEscalation => "interaction.disengagement.escalation", + SignalType::DisengagementQuit => "interaction.disengagement.quit", + SignalType::DisengagementNegativeStance => "interaction.disengagement.negative_stance", + SignalType::SatisfactionGratitude => "interaction.satisfaction.gratitude", + SignalType::SatisfactionConfirmation => "interaction.satisfaction.confirmation", + SignalType::SatisfactionSuccess => "interaction.satisfaction.success", + SignalType::ExecutionFailureInvalidArgs => "execution.failure.invalid_args", + SignalType::ExecutionFailureBadQuery => "execution.failure.bad_query", + SignalType::ExecutionFailureToolNotFound => "execution.failure.tool_not_found", + SignalType::ExecutionFailureAuthMisuse => "execution.failure.auth_misuse", + SignalType::ExecutionFailureStateError => "execution.failure.state_error", + SignalType::ExecutionLoopsRetry => "execution.loops.retry", + SignalType::ExecutionLoopsParameterDrift => "execution.loops.parameter_drift", + SignalType::ExecutionLoopsOscillation => "execution.loops.oscillation", + SignalType::EnvironmentExhaustionApiError => "environment.exhaustion.api_error", + SignalType::EnvironmentExhaustionTimeout => "environment.exhaustion.timeout", + SignalType::EnvironmentExhaustionRateLimit => "environment.exhaustion.rate_limit", + SignalType::EnvironmentExhaustionNetwork => "environment.exhaustion.network", + SignalType::EnvironmentExhaustionMalformed => { + "environment.exhaustion.malformed_response" + } + SignalType::EnvironmentExhaustionContextOverflow 
=> { + "environment.exhaustion.context_overflow" + } + } + } + + pub fn layer(&self) -> SignalLayer { + match self { + SignalType::MisalignmentCorrection + | SignalType::MisalignmentRephrase + | SignalType::MisalignmentClarification + | SignalType::StagnationDragging + | SignalType::StagnationRepetition + | SignalType::DisengagementEscalation + | SignalType::DisengagementQuit + | SignalType::DisengagementNegativeStance + | SignalType::SatisfactionGratitude + | SignalType::SatisfactionConfirmation + | SignalType::SatisfactionSuccess => SignalLayer::Interaction, + SignalType::ExecutionFailureInvalidArgs + | SignalType::ExecutionFailureBadQuery + | SignalType::ExecutionFailureToolNotFound + | SignalType::ExecutionFailureAuthMisuse + | SignalType::ExecutionFailureStateError + | SignalType::ExecutionLoopsRetry + | SignalType::ExecutionLoopsParameterDrift + | SignalType::ExecutionLoopsOscillation => SignalLayer::Execution, + SignalType::EnvironmentExhaustionApiError + | SignalType::EnvironmentExhaustionTimeout + | SignalType::EnvironmentExhaustionRateLimit + | SignalType::EnvironmentExhaustionNetwork + | SignalType::EnvironmentExhaustionMalformed + | SignalType::EnvironmentExhaustionContextOverflow => SignalLayer::Environment, + } + } + + /// Category name within the layer (e.g. `"misalignment"`, `"failure"`). + pub fn category(&self) -> &'static str { + // Strip the layer prefix and take everything before the next dot. 
+ let s = self.as_str(); + let after_layer = s.split_once('.').map(|(_, rest)| rest).unwrap_or(s); + after_layer + .split_once('.') + .map(|(c, _)| c) + .unwrap_or(after_layer) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub enum SignalLayer { + Interaction, + Execution, + Environment, +} + +impl SignalLayer { + pub fn as_str(&self) -> &'static str { + match self { + SignalLayer::Interaction => "interaction", + SignalLayer::Execution => "execution", + SignalLayer::Environment => "environment", + } + } +} + +/// Overall quality assessment for an agent interaction session. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum InteractionQuality { + Excellent, + Good, + Neutral, + Poor, + Severe, +} + +impl InteractionQuality { + pub fn as_str(&self) -> &'static str { + match self { + InteractionQuality::Excellent => "excellent", + InteractionQuality::Good => "good", + InteractionQuality::Neutral => "neutral", + InteractionQuality::Poor => "poor", + InteractionQuality::Severe => "severe", + } + } +} + +/// A single detected signal instance. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SignalInstance { + pub signal_type: SignalType, + /// Absolute index into the original conversation `Vec`. + pub message_index: usize, + pub snippet: String, + pub confidence: f32, + /// Free-form metadata payload mirroring the Python `Dict[str, Any]`. + /// Stored as a JSON object so we can faithfully reproduce the reference's + /// flexible per-detector metadata. 
+ #[serde(default)] + pub metadata: serde_json::Value, +} + +impl SignalInstance { + pub fn new(signal_type: SignalType, message_index: usize, snippet: impl Into) -> Self { + Self { + signal_type, + message_index, + snippet: snippet.into(), + confidence: 1.0, + metadata: serde_json::Value::Object(serde_json::Map::new()), + } + } + + pub fn with_confidence(mut self, c: f32) -> Self { + self.confidence = c; + self + } + + pub fn with_metadata(mut self, m: serde_json::Value) -> Self { + self.metadata = m; + self + } +} + +/// Aggregated signals for a specific category. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SignalGroup { + pub category: String, + pub count: usize, + pub signals: Vec, + /// Severity level (0-3: none, mild, moderate, severe). + pub severity: u8, +} + +impl SignalGroup { + pub fn new(category: impl Into) -> Self { + Self { + category: category.into(), + count: 0, + signals: Vec::new(), + severity: 0, + } + } + + pub fn add_signal(&mut self, signal: SignalInstance) { + self.signals.push(signal); + self.count = self.signals.len(); + self.update_severity(); + } + + fn update_severity(&mut self) { + self.severity = match self.count { + 0 => 0, + 1..=2 => 1, + 3..=4 => 2, + _ => 3, + }; + } +} + +/// Turn count and efficiency metrics, used by stagnation.dragging. 
+#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct TurnMetrics { + pub total_turns: usize, + pub user_turns: usize, + pub assistant_turns: usize, + pub is_dragging: bool, + pub efficiency_score: f32, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct InteractionSignals { + pub misalignment: SignalGroup, + pub stagnation: SignalGroup, + pub disengagement: SignalGroup, + pub satisfaction: SignalGroup, +} + +impl Default for InteractionSignals { + fn default() -> Self { + Self { + misalignment: SignalGroup::new("misalignment"), + stagnation: SignalGroup::new("stagnation"), + disengagement: SignalGroup::new("disengagement"), + satisfaction: SignalGroup::new("satisfaction"), + } + } +} + +impl InteractionSignals { + /// Ratio of misalignment instances to user turns. Used as a quality + /// scoring input and as a threshold for the "high misalignment rate" + /// summary callout. Mirrors `misalignment.count / max(user_turns, 1)` + /// from the Python reference's `_assess_quality` and `_generate_summary`. + pub fn misalignment_ratio(&self, user_turns: usize) -> f32 { + let denom = user_turns.max(1) as f32; + self.misalignment.count as f32 / denom + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ExecutionSignals { + pub failure: SignalGroup, + pub loops: SignalGroup, +} + +impl Default for ExecutionSignals { + fn default() -> Self { + Self { + failure: SignalGroup::new("failure"), + loops: SignalGroup::new("loops"), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct EnvironmentSignals { + pub exhaustion: SignalGroup, +} + +impl Default for EnvironmentSignals { + fn default() -> Self { + Self { + exhaustion: SignalGroup::new("exhaustion"), + } + } +} + +/// Complete signal analysis report for a conversation. 
+#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SignalReport { + pub interaction: InteractionSignals, + pub execution: ExecutionSignals, + pub environment: EnvironmentSignals, + pub overall_quality: InteractionQuality, + pub quality_score: f32, + pub turn_metrics: TurnMetrics, + pub summary: String, +} + +impl Default for SignalReport { + fn default() -> Self { + Self { + interaction: InteractionSignals::default(), + execution: ExecutionSignals::default(), + environment: EnvironmentSignals::default(), + overall_quality: InteractionQuality::Neutral, + quality_score: 50.0, + turn_metrics: TurnMetrics::default(), + summary: String::new(), + } + } +} + +impl SignalReport { + /// Iterate over every `SignalInstance` across all layers and groups. + pub fn iter_signals(&self) -> impl Iterator { + self.interaction + .misalignment + .signals + .iter() + .chain(self.interaction.stagnation.signals.iter()) + .chain(self.interaction.disengagement.signals.iter()) + .chain(self.interaction.satisfaction.signals.iter()) + .chain(self.execution.failure.signals.iter()) + .chain(self.execution.loops.signals.iter()) + .chain(self.environment.exhaustion.signals.iter()) + } + + pub fn has_signal_type(&self, t: SignalType) -> bool { + self.iter_signals().any(|s| s.signal_type == t) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn signal_type_strings_match_paper_taxonomy() { + assert_eq!( + SignalType::MisalignmentCorrection.as_str(), + "interaction.misalignment.correction" + ); + assert_eq!( + SignalType::ExecutionFailureInvalidArgs.as_str(), + "execution.failure.invalid_args" + ); + assert_eq!( + SignalType::EnvironmentExhaustionMalformed.as_str(), + "environment.exhaustion.malformed_response" + ); + } + + #[test] + fn signal_type_layer_and_category() { + assert_eq!( + SignalType::MisalignmentRephrase.layer(), + SignalLayer::Interaction + ); + assert_eq!(SignalType::MisalignmentRephrase.category(), "misalignment"); + assert_eq!( + 
SignalType::ExecutionLoopsRetry.layer(), + SignalLayer::Execution + ); + assert_eq!(SignalType::ExecutionLoopsRetry.category(), "loops"); + assert_eq!( + SignalType::EnvironmentExhaustionTimeout.layer(), + SignalLayer::Environment + ); + assert_eq!( + SignalType::EnvironmentExhaustionTimeout.category(), + "exhaustion" + ); + } + + #[test] + fn signal_group_severity_buckets_match_python() { + let mut g = SignalGroup::new("misalignment"); + assert_eq!(g.severity, 0); + for n in 1..=2 { + g.add_signal(SignalInstance::new( + SignalType::MisalignmentCorrection, + n, + "x", + )); + } + assert_eq!(g.severity, 1); + for n in 3..=4 { + g.add_signal(SignalInstance::new( + SignalType::MisalignmentCorrection, + n, + "x", + )); + } + assert_eq!(g.severity, 2); + for n in 5..=6 { + g.add_signal(SignalInstance::new( + SignalType::MisalignmentCorrection, + n, + "x", + )); + } + assert_eq!(g.severity, 3); + } +} diff --git a/crates/brightstaff/src/signals/text_processing.rs b/crates/brightstaff/src/signals/text_processing.rs new file mode 100644 index 00000000..a1d463cc --- /dev/null +++ b/crates/brightstaff/src/signals/text_processing.rs @@ -0,0 +1,401 @@ +//! Text normalization and similarity primitives. +//! +//! Direct Rust port of `signals/text_processing.py` from the reference. The +//! shapes (`NormalizedMessage`, `NormalizedPattern`) and similarity formulas +//! match the Python implementation exactly so that pattern matching produces +//! the same results on the same inputs. + +use std::collections::{HashMap, HashSet}; + +/// Size of character n-grams used for fuzzy similarity (3 = trigrams). +pub const NGRAM_SIZE: usize = 3; + +const PUNCT_TRIM: &[char] = &[ + '!', '"', '#', '$', '%', '&', '\'', '(', ')', '*', '+', ',', '-', '.', '/', ':', ';', '<', '=', + '>', '?', '@', '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~', +]; + +/// Pre-processed message with normalized text and tokens for efficient matching. 
+#[derive(Debug, Clone, Default)] +pub struct NormalizedMessage { + pub raw: String, + pub tokens: Vec, + pub token_set: HashSet, + pub bigram_set: HashSet, + pub char_ngram_set: HashSet, + pub token_frequency: HashMap, +} + +impl NormalizedMessage { + /// Create a normalized message from raw text. Mirrors + /// `NormalizedMessage.from_text` in the reference, including the + /// head-20%/tail-80% truncation strategy when text exceeds `max_length`. + pub fn from_text(text: &str, max_length: usize) -> Self { + let char_count = text.chars().count(); + + let raw: String = if char_count <= max_length { + text.to_string() + } else { + let head_len = max_length / 5; + // Reserve one char for the joining space. + let tail_len = max_length.saturating_sub(head_len + 1); + let head: String = text.chars().take(head_len).collect(); + let tail: String = text + .chars() + .skip(char_count.saturating_sub(tail_len)) + .collect(); + format!("{} {}", head, tail) + }; + + // Normalize unicode punctuation to ASCII equivalents. + let normalized_unicode = raw + .replace(['\u{2019}', '\u{2018}'], "'") + .replace(['\u{201c}', '\u{201d}'], "\"") + .replace(['\u{2013}', '\u{2014}'], "-"); + + // Lowercase + collapse whitespace (matches Python's `" ".join(s.split())`). 
+ let normalized: String = normalized_unicode + .to_lowercase() + .split_whitespace() + .collect::>() + .join(" "); + + let mut tokens: Vec = Vec::new(); + for word in normalized.split_whitespace() { + let stripped: String = word.trim_matches(PUNCT_TRIM).to_string(); + if !stripped.is_empty() { + tokens.push(stripped); + } + } + + let token_set: HashSet = tokens.iter().cloned().collect(); + + let mut bigram_set: HashSet = HashSet::new(); + for i in 0..tokens.len().saturating_sub(1) { + bigram_set.insert(format!("{} {}", tokens[i], tokens[i + 1])); + } + + let tokens_text = tokens.join(" "); + let char_ngram_set = char_ngrams(&tokens_text, NGRAM_SIZE); + + let mut token_frequency: HashMap = HashMap::new(); + for t in &tokens { + *token_frequency.entry(t.clone()).or_insert(0) += 1; + } + + Self { + raw, + tokens, + token_set, + bigram_set, + char_ngram_set, + token_frequency, + } + } + + pub fn contains_token(&self, token: &str) -> bool { + self.token_set.contains(token) + } + + pub fn contains_phrase(&self, phrase: &str) -> bool { + let phrase_tokens: Vec<&str> = phrase.split_whitespace().collect(); + if phrase_tokens.is_empty() { + return false; + } + if phrase_tokens.len() == 1 { + return self.contains_token(phrase_tokens[0]); + } + if phrase_tokens.len() > self.tokens.len() { + return false; + } + let n = phrase_tokens.len(); + for i in 0..=self.tokens.len() - n { + if self.tokens[i..i + n] + .iter() + .zip(phrase_tokens.iter()) + .all(|(a, b)| a == b) + { + return true; + } + } + false + } + + /// Character n-gram (Jaccard) similarity vs another normalized message. + pub fn ngram_similarity_with_message(&self, other: &NormalizedMessage) -> f32 { + jaccard(&self.char_ngram_set, &other.char_ngram_set) + } + + /// Character n-gram (Jaccard) similarity vs a raw pattern string. 
+ pub fn ngram_similarity_with_pattern(&self, pattern: &str) -> f32 { + let normalized = strip_non_word_chars(&pattern.to_lowercase()); + let pattern_ngrams = char_ngrams(&normalized, NGRAM_SIZE); + jaccard(&self.char_ngram_set, &pattern_ngrams) + } + + /// Fraction of pattern's ngrams contained in this message's ngram set. + pub fn char_ngram_containment(&self, pattern: &str) -> f32 { + let normalized = strip_non_word_chars(&pattern.to_lowercase()); + let pattern_ngrams = char_ngrams(&normalized, NGRAM_SIZE); + if pattern_ngrams.is_empty() { + return 0.0; + } + let contained = pattern_ngrams + .iter() + .filter(|ng| self.char_ngram_set.contains(*ng)) + .count(); + contained as f32 / pattern_ngrams.len() as f32 + } + + /// Token-frequency cosine similarity vs a raw pattern string. + pub fn token_cosine_similarity(&self, pattern: &str) -> f32 { + let mut pattern_freq: HashMap = HashMap::new(); + for word in pattern.to_lowercase().split_whitespace() { + let stripped = word.trim_matches(PUNCT_TRIM); + if !stripped.is_empty() { + *pattern_freq.entry(stripped.to_string()).or_insert(0) += 1; + } + } + cosine_freq(&self.token_frequency, &pattern_freq) + } + + /// Layered match against a pre-normalized pattern. Mirrors + /// `matches_normalized_pattern` from the reference: exact phrase -> + /// char-ngram Jaccard -> token cosine. + pub fn matches_normalized_pattern( + &self, + pattern: &NormalizedPattern, + char_ngram_threshold: f32, + token_cosine_threshold: f32, + ) -> bool { + // Layer 0: exact phrase match using pre-tokenized message. + let plen = pattern.tokens.len(); + let slen = self.tokens.len(); + if plen > 0 && plen <= slen { + for i in 0..=slen - plen { + if self.tokens[i..i + plen] == pattern.tokens[..] { + return true; + } + } + } + + // Layer 1: character n-gram Jaccard similarity. 
+ if !self.char_ngram_set.is_empty() && !pattern.char_ngram_set.is_empty() { + let inter = self + .char_ngram_set + .intersection(&pattern.char_ngram_set) + .count(); + let union = self.char_ngram_set.union(&pattern.char_ngram_set).count(); + if union > 0 { + let sim = inter as f32 / union as f32; + if sim >= char_ngram_threshold { + return true; + } + } + } + + // Layer 2: token frequency cosine similarity. + if !self.token_frequency.is_empty() && !pattern.token_frequency.is_empty() { + let sim = cosine_freq(&self.token_frequency, &pattern.token_frequency); + if sim >= token_cosine_threshold { + return true; + } + } + + false + } +} + +/// Pre-processed pattern with normalized text and pre-computed n-grams/tokens. +#[derive(Debug, Clone, Default)] +pub struct NormalizedPattern { + pub raw: String, + pub tokens: Vec, + pub char_ngram_set: HashSet, + pub token_frequency: HashMap, +} + +impl NormalizedPattern { + pub fn from_text(pattern: &str) -> Self { + let normalized = pattern + .to_lowercase() + .replace(['\u{2019}', '\u{2018}'], "'") + .replace(['\u{201c}', '\u{201d}'], "\"") + .replace(['\u{2013}', '\u{2014}'], "-"); + let normalized: String = normalized.split_whitespace().collect::>().join(" "); + + // Tokenize the same way as NormalizedMessage (trim boundary punctuation, + // keep internal punctuation). + let mut tokens: Vec = Vec::new(); + for word in normalized.split_whitespace() { + let stripped = word.trim_matches(PUNCT_TRIM); + if !stripped.is_empty() { + tokens.push(stripped.to_string()); + } + } + + // For ngrams + cosine, strip ALL punctuation (matches Python's + // `re.sub(r"[^\w\s]", "", normalized)`). 
+ let normalized_for_ngrams = strip_non_word_chars(&normalized); + let char_ngram_set = char_ngrams(&normalized_for_ngrams, NGRAM_SIZE); + + let tokens_no_punct: Vec<&str> = normalized_for_ngrams.split_whitespace().collect(); + let mut token_frequency: HashMap = HashMap::new(); + for t in &tokens_no_punct { + *token_frequency.entry((*t).to_string()).or_insert(0) += 1; + } + + Self { + raw: pattern.to_string(), + tokens, + char_ngram_set, + token_frequency, + } + } +} + +/// Convenience: normalize a list of raw pattern strings into `NormalizedPattern`s. +pub fn normalize_patterns(patterns: &[&str]) -> Vec { + patterns + .iter() + .map(|p| NormalizedPattern::from_text(p)) + .collect() +} + +// --------------------------------------------------------------------------- +// Similarity primitives +// --------------------------------------------------------------------------- + +fn char_ngrams(s: &str, n: usize) -> HashSet { + // Python iterates by character index, not byte; mirror that with .chars(). 
+ let chars: Vec = s.chars().collect(); + let mut out: HashSet = HashSet::new(); + if chars.len() < n { + return out; + } + for i in 0..=chars.len() - n { + out.insert(chars[i..i + n].iter().collect()); + } + out +} + +fn jaccard(a: &HashSet, b: &HashSet) -> f32 { + if a.is_empty() && b.is_empty() { + return 1.0; + } + if a.is_empty() || b.is_empty() { + return 0.0; + } + let inter = a.intersection(b).count(); + let union = a.union(b).count(); + if union == 0 { + 0.0 + } else { + inter as f32 / union as f32 + } +} + +fn cosine_freq(a: &HashMap, b: &HashMap) -> f32 { + if a.is_empty() && b.is_empty() { + return 1.0; + } + if a.is_empty() || b.is_empty() { + return 0.0; + } + let mut dot: f64 = 0.0; + let mut n1_sq: f64 = 0.0; + let mut n2_sq: f64 = 0.0; + for (token, &freq2) in b { + let freq1 = *a.get(token).unwrap_or(&0); + dot += (freq1 * freq2) as f64; + n2_sq += (freq2 * freq2) as f64; + } + for &freq1 in a.values() { + n1_sq += (freq1 * freq1) as f64; + } + let n1 = n1_sq.sqrt(); + let n2 = n2_sq.sqrt(); + if n1 == 0.0 || n2 == 0.0 { + 0.0 + } else { + (dot / (n1 * n2)) as f32 + } +} + +/// Python equivalent: `re.sub(r"[^\w\s]", "", text)` followed by whitespace +/// collapse. Python's `\w` is `[A-Za-z0-9_]` plus unicode word characters; we +/// use Rust's `char::is_alphanumeric()` plus `_` for an equivalent definition. 
+fn strip_non_word_chars(text: &str) -> String { + let mut out = String::with_capacity(text.len()); + for c in text.chars() { + if c.is_alphanumeric() || c == '_' || c.is_whitespace() { + out.push(c); + } + } + out.split_whitespace().collect::>().join(" ") +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn normalize_lowercases_and_strips_punctuation() { + let m = NormalizedMessage::from_text("Hello, World!", 2000); + assert_eq!(m.tokens, vec!["hello".to_string(), "world".to_string()]); + } + + #[test] + fn normalizes_smart_quotes() { + let m = NormalizedMessage::from_text("don\u{2019}t", 2000); + assert!(m.tokens.contains(&"don't".to_string())); + } + + #[test] + fn truncates_long_text_with_head_tail() { + let long = "a".repeat(3000); + let m = NormalizedMessage::from_text(&long, 2000); + // raw should be ~ 2000 chars (head + space + tail) + assert!(m.raw.chars().count() <= 2001); + assert!(m.raw.starts_with("aa")); + assert!(m.raw.ends_with("aa")); + } + + #[test] + fn contains_phrase_matches_consecutive_tokens() { + let m = NormalizedMessage::from_text("I think this is great work", 2000); + assert!(m.contains_phrase("this is great")); + assert!(!m.contains_phrase("great this")); + } + + #[test] + fn matches_pattern_via_exact_phrase() { + let m = NormalizedMessage::from_text("No, I meant the second one", 2000); + let p = NormalizedPattern::from_text("no i meant"); + assert!(m.matches_normalized_pattern(&p, 0.65, 0.6)); + } + + #[test] + fn matches_pattern_via_char_ngram_fuzziness() { + // Typo in "meant" -> "ment" so layer 0 (exact phrase) cannot match, + // forcing the matcher to fall back to layer 1 (char n-gram Jaccard). 
+ let m = NormalizedMessage::from_text("No I ment", 2000); + let p = NormalizedPattern::from_text("no i meant"); + assert!(m.matches_normalized_pattern(&p, 0.4, 0.6)); + } + + #[test] + fn jaccard_identical_sets_is_one() { + let a: HashSet = ["abc", "bcd"].iter().map(|s| s.to_string()).collect(); + assert!((jaccard(&a, &a) - 1.0).abs() < 1e-6); + } + + #[test] + fn cosine_freq_orthogonal_is_zero() { + let mut a: HashMap = HashMap::new(); + a.insert("hello".to_string(), 1); + let mut b: HashMap = HashMap::new(); + b.insert("world".to_string(), 1); + assert_eq!(cosine_freq(&a, &b), 0.0); + } +} diff --git a/crates/brightstaff/src/streaming.rs b/crates/brightstaff/src/streaming.rs index 40cbbe7c..26af8672 100644 --- a/crates/brightstaff/src/streaming.rs +++ b/crates/brightstaff/src/streaming.rs @@ -20,8 +20,11 @@ const STREAM_BUFFER_SIZE: usize = 16; /// Most chat responses are well under this; pathological ones are dropped without /// affecting pass-through streaming to the client. const USAGE_BUFFER_MAX: usize = 2 * 1024 * 1024; -use crate::signals::{InteractionQuality, SignalAnalyzer, TextBasedSignalAnalyzer, FLAG_MARKER}; -use crate::tracing::{llm, set_service_name, signals as signal_constants}; +use crate::metrics as bs_metrics; +use crate::metrics::labels as metric_labels; +use crate::signals::otel::emit_signals_to_span; +use crate::signals::{SignalAnalyzer, FLAG_MARKER}; +use crate::tracing::{llm, set_service_name}; use hermesllm::apis::openai::Message; /// Parsed usage + resolved-model details from a provider response. @@ -172,6 +175,18 @@ impl StreamProcessor for Box { } } +/// Optional Prometheus-metric context for an LLM upstream call. When present, +/// [`ObservableStreamProcessor`] emits `brightstaff_llm_*` metrics at +/// first-byte / complete / error callbacks. +#[derive(Debug, Clone)] +pub struct LlmMetricsCtx { + pub provider: String, + pub model: String, + /// HTTP status of the upstream response. 
Used to pick `status_class` and + /// `error_class` on `on_complete`. + pub upstream_status: u16, +} + /// A processor that tracks streaming metrics pub struct ObservableStreamProcessor { service_name: String, @@ -185,6 +200,8 @@ pub struct ObservableStreamProcessor { /// on `on_complete`. Capped at `USAGE_BUFFER_MAX`; excess chunks are dropped /// from the buffer (they still pass through to the client). response_buffer: Vec, + llm_metrics: Option, + metrics_recorded: bool, } impl ObservableStreamProcessor { @@ -219,8 +236,17 @@ impl ObservableStreamProcessor { time_to_first_token: None, messages, response_buffer: Vec::new(), + llm_metrics: None, + metrics_recorded: false, } } + + /// Attach LLM upstream metric context so the processor emits + /// `brightstaff_llm_*` metrics on first-byte / complete / error. + pub fn with_llm_metrics(mut self, ctx: LlmMetricsCtx) -> Self { + self.llm_metrics = Some(ctx); + self + } } impl StreamProcessor for ObservableStreamProcessor { @@ -240,7 +266,11 @@ impl StreamProcessor for ObservableStreamProcessor { fn on_first_bytes(&mut self) { // Record time to first token (only for streaming) if self.time_to_first_token.is_none() { - self.time_to_first_token = Some(self.start_time.elapsed().as_millis()); + let elapsed = self.start_time.elapsed(); + self.time_to_first_token = Some(elapsed.as_millis()); + if let Some(ref ctx) = self.llm_metrics { + bs_metrics::record_llm_ttft(&ctx.provider, &ctx.model, elapsed); + } } } @@ -299,81 +329,56 @@ impl StreamProcessor for ObservableStreamProcessor { otel_span.set_attribute(KeyValue::new(llm::MODEL_NAME, resolved)); } } + + // Emit LLM upstream prometheus metrics (duration + tokens) if wired. + // The upstream responded (we have a status), so status_class alone + // carries the non-2xx signal — error_class stays "none". 
+ if let Some(ref ctx) = self.llm_metrics { + bs_metrics::record_llm_upstream( + &ctx.provider, + &ctx.model, + ctx.upstream_status, + metric_labels::LLM_ERR_NONE, + self.start_time.elapsed(), + ); + if let Some(v) = usage.prompt_tokens { + bs_metrics::record_llm_tokens( + &ctx.provider, + &ctx.model, + metric_labels::TOKEN_KIND_PROMPT, + v.max(0) as u64, + ); + } + if let Some(v) = usage.completion_tokens { + bs_metrics::record_llm_tokens( + &ctx.provider, + &ctx.model, + metric_labels::TOKEN_KIND_COMPLETION, + v.max(0) as u64, + ); + } + if usage.prompt_tokens.is_none() && usage.completion_tokens.is_none() { + bs_metrics::record_llm_tokens_usage_missing(&ctx.provider, &ctx.model); + } + self.metrics_recorded = true; + } // Release the buffered bytes early; nothing downstream needs them. self.response_buffer.clear(); self.response_buffer.shrink_to_fit(); - // Analyze signals if messages are available and record as span attributes + // Analyze signals if messages are available and record as span + // attributes + per-signal events. We dual-emit legacy aggregate keys + // and the new layered taxonomy so existing dashboards keep working + // while new consumers can opt into the richer hierarchy. 
if let Some(ref messages) = self.messages { - let analyzer: Box = Box::new(TextBasedSignalAnalyzer::new()); - let report = analyzer.analyze(messages); + let analyzer = SignalAnalyzer::default(); + let report = analyzer.analyze_openai(messages); - // Get the current OTel span to set signal attributes let span = tracing::Span::current(); let otel_context = span.context(); let otel_span = otel_context.span(); - // Add overall quality - otel_span.set_attribute(KeyValue::new( - signal_constants::QUALITY, - format!("{:?}", report.overall_quality), - )); - - // Add repair/follow-up metrics if concerning - if report.follow_up.is_concerning || report.follow_up.repair_count > 0 { - otel_span.set_attribute(KeyValue::new( - signal_constants::REPAIR_COUNT, - report.follow_up.repair_count as i64, - )); - otel_span.set_attribute(KeyValue::new( - signal_constants::REPAIR_RATIO, - format!("{:.3}", report.follow_up.repair_ratio), - )); - } - - // Add frustration metrics - if report.frustration.has_frustration { - otel_span.set_attribute(KeyValue::new( - signal_constants::FRUSTRATION_COUNT, - report.frustration.frustration_count as i64, - )); - otel_span.set_attribute(KeyValue::new( - signal_constants::FRUSTRATION_SEVERITY, - report.frustration.severity as i64, - )); - } - - // Add repetition metrics - if report.repetition.has_looping { - otel_span.set_attribute(KeyValue::new( - signal_constants::REPETITION_COUNT, - report.repetition.repetition_count as i64, - )); - } - - // Add escalation metrics - if report.escalation.escalation_requested { - otel_span - .set_attribute(KeyValue::new(signal_constants::ESCALATION_REQUESTED, true)); - } - - // Add positive feedback metrics - if report.positive_feedback.has_positive_feedback { - otel_span.set_attribute(KeyValue::new( - signal_constants::POSITIVE_FEEDBACK_COUNT, - report.positive_feedback.positive_count as i64, - )); - } - - // Flag the span name if any concerning signal is detected - let should_flag = report.frustration.has_frustration 
- || report.repetition.has_looping - || report.escalation.escalation_requested - || matches!( - report.overall_quality, - InteractionQuality::Poor | InteractionQuality::Severe - ); - + let should_flag = emit_signals_to_span(&otel_span, &report); if should_flag { otel_span.update_name(format!("{} {}", self.operation_name, FLAG_MARKER)); } @@ -396,6 +401,18 @@ impl StreamProcessor for ObservableStreamProcessor { duration_ms = self.start_time.elapsed().as_millis(), "stream error" ); + if let Some(ref ctx) = self.llm_metrics { + if !self.metrics_recorded { + bs_metrics::record_llm_upstream( + &ctx.provider, + &ctx.model, + ctx.upstream_status, + metric_labels::LLM_ERR_STREAM, + self.start_time.elapsed(), + ); + self.metrics_recorded = true; + } + } } } diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index 028c8046..86aa331d 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -234,6 +234,7 @@ pub struct Overrides { pub llm_routing_model: Option, pub agent_orchestration_model: Option, pub orchestrator_model_context_length: Option, + pub disable_signals: Option, } #[derive(Debug, Clone, Serialize, Deserialize, Default)] @@ -391,6 +392,8 @@ pub enum LlmProviderType { AmazonBedrock, #[serde(rename = "plano")] Plano, + #[serde(rename = "chatgpt")] + ChatGPT, #[serde(rename = "digitalocean")] DigitalOcean, } @@ -414,6 +417,7 @@ impl Display for LlmProviderType { LlmProviderType::Qwen => write!(f, "qwen"), LlmProviderType::AmazonBedrock => write!(f, "amazon_bedrock"), LlmProviderType::Plano => write!(f, "plano"), + LlmProviderType::ChatGPT => write!(f, "chatgpt"), LlmProviderType::DigitalOcean => write!(f, "digitalocean"), } } @@ -481,6 +485,7 @@ pub struct LlmProvider { pub base_url_path_prefix: Option, pub internal: Option, pub passthrough_auth: Option, + pub headers: Option>, } pub trait IntoModels { @@ -524,6 +529,7 @@ impl Default for LlmProvider { base_url_path_prefix: None, internal: None, 
passthrough_auth: None, + headers: None, } } } @@ -750,4 +756,29 @@ mod test { assert!(model_ids.contains(&"openai-gpt4".to_string())); assert!(!model_ids.contains(&"plano-orchestrator".to_string())); } + + #[test] + fn test_overrides_disable_signals_default_none() { + let overrides = super::Overrides::default(); + assert_eq!(overrides.disable_signals, None); + } + + #[test] + fn test_overrides_disable_signals_deserialize() { + let yaml = r#" +disable_signals: true +"#; + let overrides: super::Overrides = serde_yaml::from_str(yaml).unwrap(); + assert_eq!(overrides.disable_signals, Some(true)); + + let yaml_false = r#" +disable_signals: false +"#; + let overrides: super::Overrides = serde_yaml::from_str(yaml_false).unwrap(); + assert_eq!(overrides.disable_signals, Some(false)); + + let yaml_missing = "{}"; + let overrides: super::Overrides = serde_yaml::from_str(yaml_missing).unwrap(); + assert_eq!(overrides.disable_signals, None); + } } diff --git a/crates/common/src/llm_providers.rs b/crates/common/src/llm_providers.rs index b5c03b30..b4355a2f 100644 --- a/crates/common/src/llm_providers.rs +++ b/crates/common/src/llm_providers.rs @@ -277,6 +277,7 @@ mod tests { internal: None, stream: None, passthrough_auth: None, + headers: None, } } diff --git a/crates/hermesllm/src/bin/provider_models.yaml b/crates/hermesllm/src/bin/provider_models.yaml index d07e265d..2e9e0a9b 100644 --- a/crates/hermesllm/src/bin/provider_models.yaml +++ b/crates/hermesllm/src/bin/provider_models.yaml @@ -329,6 +329,10 @@ providers: - xiaomi/mimo-v2-flash - xiaomi/mimo-v2-omni - xiaomi/mimo-v2-pro + chatgpt: + - chatgpt/gpt-5.4 + - chatgpt/gpt-5.3-codex + - chatgpt/gpt-5.2 digitalocean: - digitalocean/openai-gpt-4.1 - digitalocean/openai-gpt-4o @@ -376,6 +380,6 @@ providers: - digitalocean/qwen3-embedding-0.6b - digitalocean/router:software-engineering metadata: - total_providers: 12 - total_models: 361 - last_updated: 2026-04-16T00:00:00.000000+00:00 + total_providers: 13 + total_models: 
364 + last_updated: 2026-04-20T00:00:00.000000+00:00 diff --git a/crates/hermesllm/src/clients/endpoints.rs b/crates/hermesllm/src/clients/endpoints.rs index 67a60def..eeef8856 100644 --- a/crates/hermesllm/src/clients/endpoints.rs +++ b/crates/hermesllm/src/clients/endpoints.rs @@ -194,9 +194,10 @@ impl SupportedAPIsFromClient { // For Responses API, check if provider supports it, otherwise translate to chat/completions match provider_id { // Providers that support /v1/responses natively - ProviderId::OpenAI | ProviderId::XAI | ProviderId::Vercel => { - route_by_provider("/responses") - } + ProviderId::OpenAI + | ProviderId::XAI + | ProviderId::ChatGPT + | ProviderId::Vercel => route_by_provider("/responses"), // All other providers: translate to /chat/completions _ => route_by_provider("/chat/completions"), } @@ -722,4 +723,36 @@ mod tests { "/v1/responses" ); } + + #[test] + fn test_responses_api_targets_chatgpt_native_responses_endpoint() { + let api = SupportedAPIsFromClient::OpenAIResponsesAPI(OpenAIApi::Responses); + assert_eq!( + api.target_endpoint_for_provider( + &ProviderId::ChatGPT, + "/v1/responses", + "gpt-5.4", + false, + None, + false + ), + "/v1/responses" + ); + } + + #[test] + fn test_responses_api_targets_vercel_native_responses_endpoint() { + let api = SupportedAPIsFromClient::OpenAIResponsesAPI(OpenAIApi::Responses); + assert_eq!( + api.target_endpoint_for_provider( + &ProviderId::Vercel, + "/v1/responses", + "gpt-5.4", + false, + None, + false + ), + "/v1/responses" + ); + } } diff --git a/crates/hermesllm/src/providers/id.rs b/crates/hermesllm/src/providers/id.rs index 1b90ae53..4fa7d19d 100644 --- a/crates/hermesllm/src/providers/id.rs +++ b/crates/hermesllm/src/providers/id.rs @@ -44,6 +44,7 @@ pub enum ProviderId { Zhipu, Qwen, AmazonBedrock, + ChatGPT, DigitalOcean, Vercel, OpenRouter, @@ -74,6 +75,7 @@ impl TryFrom<&str> for ProviderId { "qwen" => Ok(ProviderId::Qwen), "amazon_bedrock" => Ok(ProviderId::AmazonBedrock), "amazon" => 
Ok(ProviderId::AmazonBedrock), // alias + "chatgpt" => Ok(ProviderId::ChatGPT), "digitalocean" => Ok(ProviderId::DigitalOcean), "do" => Ok(ProviderId::DigitalOcean), // alias "do_ai" => Ok(ProviderId::DigitalOcean), // alias @@ -103,6 +105,7 @@ impl ProviderId { ProviderId::Moonshotai => "moonshotai", ProviderId::Zhipu => "z-ai", ProviderId::Qwen => "qwen", + ProviderId::ChatGPT => "chatgpt", ProviderId::DigitalOcean => "digitalocean", _ => return Vec::new(), }; @@ -170,7 +173,8 @@ impl ProviderId { | ProviderId::Zhipu | ProviderId::Qwen | ProviderId::DigitalOcean - | ProviderId::OpenRouter, + | ProviderId::OpenRouter + | ProviderId::ChatGPT, SupportedAPIsFromClient::AnthropicMessagesAPI(_), ) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), @@ -191,13 +195,14 @@ impl ProviderId { | ProviderId::Zhipu | ProviderId::Qwen | ProviderId::DigitalOcean - | ProviderId::OpenRouter, + | ProviderId::OpenRouter + | ProviderId::ChatGPT, SupportedAPIsFromClient::OpenAIChatCompletions(_), ) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), - // OpenAI Responses API - OpenAI and xAI support this natively + // OpenAI Responses API - OpenAI, xAI, and ChatGPT support this natively ( - ProviderId::OpenAI | ProviderId::XAI, + ProviderId::OpenAI | ProviderId::XAI | ProviderId::ChatGPT, SupportedAPIsFromClient::OpenAIResponsesAPI(_), ) => SupportedUpstreamAPIs::OpenAIResponsesAPI(OpenAIApi::Responses), @@ -258,6 +263,7 @@ impl Display for ProviderId { ProviderId::Zhipu => write!(f, "zhipu"), ProviderId::Qwen => write!(f, "qwen"), ProviderId::AmazonBedrock => write!(f, "amazon_bedrock"), + ProviderId::ChatGPT => write!(f, "chatgpt"), ProviderId::DigitalOcean => write!(f, "digitalocean"), ProviderId::Vercel => write!(f, "vercel"), ProviderId::OpenRouter => write!(f, "openrouter"), @@ -447,4 +453,16 @@ mod tests { SupportedUpstreamAPIs::OpenAIResponsesAPI(OpenAIApi::Responses) )); } + + #[test] + fn 
test_chatgpt_uses_responses_api_for_responses_clients() { + use crate::clients::endpoints::{SupportedAPIsFromClient, SupportedUpstreamAPIs}; + + let client_api = SupportedAPIsFromClient::OpenAIResponsesAPI(OpenAIApi::Responses); + let upstream = ProviderId::ChatGPT.compatible_api_for_client(&client_api, false); + assert!(matches!( + upstream, + SupportedUpstreamAPIs::OpenAIResponsesAPI(OpenAIApi::Responses) + )); + } } diff --git a/crates/hermesllm/src/providers/request.rs b/crates/hermesllm/src/providers/request.rs index 92688133..aa100a17 100644 --- a/crates/hermesllm/src/providers/request.rs +++ b/crates/hermesllm/src/providers/request.rs @@ -77,7 +77,7 @@ impl ProviderRequestType { &mut self, provider_id: ProviderId, upstream_api: &SupportedUpstreamAPIs, - ) { + ) -> Result<(), ProviderRequestError> { if provider_id == ProviderId::XAI && matches!( upstream_api, @@ -89,6 +89,48 @@ impl ProviderRequestType { req.web_search_options = None; } } + + // ChatGPT requires instructions, store=false, and input as a list + if provider_id == ProviderId::ChatGPT { + if let Self::ResponsesAPIRequest(req) = self { + use crate::apis::openai_responses::{ + InputItem, InputMessage, InputParam, MessageContent, MessageRole, + }; + + const CHATGPT_BASE_INSTRUCTIONS: &str = + "You are Codex, based on GPT-5. You are running as a coding agent in the Codex CLI on a user's computer."; + match &req.instructions { + Some(existing) if existing.contains(CHATGPT_BASE_INSTRUCTIONS) => {} + Some(existing) => { + req.instructions = + Some(format!("{}\n\n{}", CHATGPT_BASE_INSTRUCTIONS, existing)); + } + None => { + req.instructions = Some(CHATGPT_BASE_INSTRUCTIONS.to_string()); + } + } + req.store = Some(false); + if req.stream == Some(false) { + return Err(ProviderRequestError { + message: "Non-streaming requests are not supported for the ChatGPT Codex provider. 
Set stream=true or omit the stream field.".to_string(), + source: None, + }); + } + req.stream = Some(true); + + // ChatGPT backend requires input to be a list, not a plain string + if let InputParam::Text(text) = &req.input { + req.input = InputParam::Items(vec![InputItem::Message(InputMessage { + role: MessageRole::User, + content: MessageContent::Text(text.clone()), + })]); + } + if let InputParam::SingleItem(item) = &req.input { + req.input = InputParam::Items(vec![item.clone()]); + } + } + } + Ok(()) } } @@ -824,10 +866,12 @@ mod tests { ..Default::default() }); - request.normalize_for_upstream( - ProviderId::XAI, - &SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), - ); + request + .normalize_for_upstream( + ProviderId::XAI, + &SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), + ) + .unwrap(); let ProviderRequestType::ChatCompletionsRequest(req) = request else { panic!("expected chat request"); @@ -852,10 +896,12 @@ mod tests { ..Default::default() }); - request.normalize_for_upstream( - ProviderId::OpenAI, - &SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), - ); + request + .normalize_for_upstream( + ProviderId::OpenAI, + &SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), + ) + .unwrap(); let ProviderRequestType::ChatCompletionsRequest(req) = request else { panic!("expected chat request"); diff --git a/crates/hermesllm/src/providers/streaming_response.rs b/crates/hermesllm/src/providers/streaming_response.rs index 66ccc735..8d06dfcf 100644 --- a/crates/hermesllm/src/providers/streaming_response.rs +++ b/crates/hermesllm/src/providers/streaming_response.rs @@ -346,12 +346,10 @@ impl TryFrom<(SseEvent, &SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for S ( SupportedAPIsFromClient::OpenAIChatCompletions(_), SupportedUpstreamAPIs::AnthropicMessagesAPI(_), - ) => { + ) if transformed_event.is_event_only() && transformed_event.event.is_some() => { // OpenAI clients 
don't expect separate event: lines // Suppress upstream Anthropic event-only lines - if transformed_event.is_event_only() && transformed_event.event.is_some() { - transformed_event.sse_transformed_lines = "\n".to_string(); - } + transformed_event.sse_transformed_lines = "\n".to_string(); } _ => { // Other cross-API combinations can be handled here as needed @@ -371,12 +369,10 @@ impl TryFrom<(SseEvent, &SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for S | ( SupportedAPIsFromClient::OpenAIResponsesAPI(_), SupportedUpstreamAPIs::OpenAIResponsesAPI(_), - ) => { - if transformed_event.is_event_only() && transformed_event.event.is_some() { - // Mark as should-skip by clearing sse_transformed_lines - // The event line is already included when the data line is transformed - transformed_event.sse_transformed_lines = String::new(); - } + ) if transformed_event.is_event_only() && transformed_event.event.is_some() => { + // Mark as should-skip by clearing sse_transformed_lines + // The event line is already included when the data line is transformed + transformed_event.sse_transformed_lines = String::new(); } _ => { // Other passthrough combinations (OpenAI ChatCompletions, etc.) 
don't have this issue diff --git a/crates/hermesllm/src/transforms/lib.rs b/crates/hermesllm/src/transforms/lib.rs index 115f061c..5308cc47 100644 --- a/crates/hermesllm/src/transforms/lib.rs +++ b/crates/hermesllm/src/transforms/lib.rs @@ -188,14 +188,13 @@ pub fn convert_openai_message_to_anthropic_content( // Handle regular content match &message.content { - Some(MessageContent::Text(text)) => { - if !text.is_empty() { - blocks.push(MessagesContentBlock::Text { - text: text.clone(), - cache_control: None, - }); - } + Some(MessageContent::Text(text)) if !text.is_empty() => { + blocks.push(MessagesContentBlock::Text { + text: text.clone(), + cache_control: None, + }); } + Some(MessageContent::Text(_)) => {} Some(MessageContent::Parts(parts)) => { for part in parts { match part { diff --git a/crates/hermesllm/src/transforms/request/from_anthropic.rs b/crates/hermesllm/src/transforms/request/from_anthropic.rs index 82dbe547..dba17dde 100644 --- a/crates/hermesllm/src/transforms/request/from_anthropic.rs +++ b/crates/hermesllm/src/transforms/request/from_anthropic.rs @@ -354,10 +354,10 @@ impl TryFrom for BedrockMessage { MessagesMessageContent::Blocks(blocks) => { for block in blocks { match block { - crate::apis::anthropic::MessagesContentBlock::Text { text, .. } => { - if !text.is_empty() { - content_blocks.push(ContentBlock::Text { text }); - } + crate::apis::anthropic::MessagesContentBlock::Text { text, .. 
} + if !text.is_empty() => + { + content_blocks.push(ContentBlock::Text { text }); } crate::apis::anthropic::MessagesContentBlock::ToolUse { id, diff --git a/crates/hermesllm/src/transforms/request/from_openai.rs b/crates/hermesllm/src/transforms/request/from_openai.rs index 70e69cb8..b673af38 100644 --- a/crates/hermesllm/src/transforms/request/from_openai.rs +++ b/crates/hermesllm/src/transforms/request/from_openai.rs @@ -317,11 +317,10 @@ impl TryFrom for BedrockMessage { Role::User => { // Convert user message content to content blocks match message.content { - Some(MessageContent::Text(text)) => { - if !text.is_empty() { - content_blocks.push(ContentBlock::Text { text }); - } + Some(MessageContent::Text(text)) if !text.is_empty() => { + content_blocks.push(ContentBlock::Text { text }); } + Some(MessageContent::Text(_)) => {} Some(MessageContent::Parts(parts)) => { // Convert OpenAI content parts to Bedrock ContentBlocks for part in parts { diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index e7763ee0..fa9964dd 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -241,6 +241,14 @@ impl StreamContext { } } + // Apply any extra headers configured on the provider (e.g., ChatGPT-Account-Id, originator) + let headers = self.llm_provider().headers.clone(); + if let Some(headers) = headers { + for (key, value) in &headers { + self.set_http_request_header(key, Some(value)); + } + } + Ok(()) } @@ -1060,7 +1068,20 @@ impl HttpContext for StreamContext { match ProviderRequestType::try_from((deserialized_client_request, upstream)) { Ok(mut request) => { - request.normalize_for_upstream(self.get_provider_id(), upstream); + if let Err(e) = + request.normalize_for_upstream(self.get_provider_id(), upstream) + { + warn!( + "request_id={}: normalize_for_upstream failed: {}", + self.request_identifier(), + e + ); + self.send_server_error( + ServerError::LogicError(e.message), + 
Some(StatusCode::BAD_REQUEST), + ); + return Action::Pause; + } debug!( "request_id={}: upstream request payload: {}", self.request_identifier(), diff --git a/demos/llm_routing/chatgpt_subscription/README.md b/demos/llm_routing/chatgpt_subscription/README.md new file mode 100644 index 00000000..d091155a --- /dev/null +++ b/demos/llm_routing/chatgpt_subscription/README.md @@ -0,0 +1,61 @@ +# ChatGPT Subscription Routing + +Route requests through your ChatGPT Plus/Pro subscription using Plano. Uses the OpenAI Responses API under the hood, targeting `chatgpt.com/backend-api/codex/responses`. + +## Setup + +### 1. Authenticate with ChatGPT + +```bash +planoai chatgpt login +``` + +This opens a device code flow — visit the URL shown and enter the code. Tokens are saved to `~/.plano/chatgpt/auth.json`. + +### 2. Start Plano + +```bash +planoai up config.yaml +``` + +### 3. Send a request + +```bash +curl http://localhost:12000/v1/responses \ + -H "Content-Type: application/json" \ + -d '{ + "model": "gpt-5.2", + "input": "Hello, what model are you?" 
+ }' +``` + +Or use the test script: + +```bash +bash test_chatgpt.sh +``` + +## How it works + +- `chatgpt/gpt-5.2` in the config tells Plano to use the ChatGPT subscription provider +- Plano reads OAuth tokens from `~/.plano/chatgpt/auth.json` (auto-refreshes if expired) +- Requests are proxied to `https://chatgpt.com/backend-api/codex/responses` with the required headers: + - `Authorization: Bearer ` + - `ChatGPT-Account-Id: ` + - `originator: codex_cli_rs` + - `session_id: ` + +## Available models + +``` +chatgpt/gpt-5.4 +chatgpt/gpt-5.3-codex +chatgpt/gpt-5.2 +``` + +## Managing credentials + +```bash +planoai chatgpt status # Check auth status +planoai chatgpt logout # Remove stored credentials +``` diff --git a/demos/llm_routing/chatgpt_subscription/chat.py b/demos/llm_routing/chatgpt_subscription/chat.py new file mode 100644 index 00000000..3c6b8ae3 --- /dev/null +++ b/demos/llm_routing/chatgpt_subscription/chat.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +"""Interactive chat with a model through Plano using the OpenAI SDK.""" + +import sys +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:12000/v1", api_key="unused") + + +def run_chat(model): + print(f"Chatting with {model} via Plano (Ctrl+C to quit)\n") + history = [] + while True: + try: + user_input = input("you> ") + except (KeyboardInterrupt, EOFError): + print("\nbye") + break + if not user_input.strip(): + continue + + history.append({"role": "user", "content": user_input}) + + stream = client.responses.create(model=model, input=history, stream=True) + print(f"{model}> ", end="", flush=True) + full = "" + for event in stream: + if event.type == "response.output_text.delta": + print(event.delta, end="", flush=True) + full += event.delta + print() + + history.append({"role": "assistant", "content": full}) + + +if __name__ == "__main__": + model = sys.argv[1] if len(sys.argv) > 1 else "gpt-5.2" + run_chat(model) diff --git a/demos/llm_routing/chatgpt_subscription/config.yaml 
b/demos/llm_routing/chatgpt_subscription/config.yaml new file mode 100644 index 00000000..a7137b3d --- /dev/null +++ b/demos/llm_routing/chatgpt_subscription/config.yaml @@ -0,0 +1,9 @@ +version: v0.3.0 + +listeners: + - type: model + name: model_listener + port: 12000 + +model_providers: + - model: chatgpt/* diff --git a/demos/llm_routing/chatgpt_subscription/test_chatgpt.sh b/demos/llm_routing/chatgpt_subscription/test_chatgpt.sh new file mode 100755 index 00000000..5544049d --- /dev/null +++ b/demos/llm_routing/chatgpt_subscription/test_chatgpt.sh @@ -0,0 +1,18 @@ +#!/bin/bash +# Test ChatGPT subscription routing through Plano +# Prerequisites: planoai chatgpt login && planoai up config.yaml + +set -e + +echo "Testing ChatGPT subscription via Plano Responses API..." +echo "" + +curl -s http://localhost:12000/v1/responses \ + -H "Content-Type: application/json" \ + -d '{ + "model": "gpt-5.2", + "input": "What is 2 + 2? Reply in one word." + }' | python3 -m json.tool + +echo "" +echo "Done." diff --git a/docs/source/guides/observability/monitoring.rst b/docs/source/guides/observability/monitoring.rst index 736e0a64..d28d25ca 100644 --- a/docs/source/guides/observability/monitoring.rst +++ b/docs/source/guides/observability/monitoring.rst @@ -75,3 +75,54 @@ are some sample configuration files for both, respectively. isDefault: true access: proxy editable: true + +Brightstaff metrics +~~~~~~~~~~~~~~~~~~~ + +In addition to Envoy's stats on ``:9901``, the brightstaff dataplane +process exposes its own Prometheus endpoint on ``0.0.0.0:9092`` (override +with ``METRICS_BIND_ADDRESS``). It publishes: + +* HTTP RED — ``brightstaff_http_requests_total``, + ``brightstaff_http_request_duration_seconds``, + ``brightstaff_http_in_flight_requests`` (labels: ``handler``, ``method``, + ``status_class``). 
+* LLM upstream — ``brightstaff_llm_upstream_requests_total``, + ``brightstaff_llm_upstream_duration_seconds``, + ``brightstaff_llm_time_to_first_token_seconds``, + ``brightstaff_llm_tokens_total`` (labels: ``provider``, ``model``, + ``error_class``, ``kind``). +* Routing — ``brightstaff_router_decisions_total``, + ``brightstaff_router_decision_duration_seconds``, + ``brightstaff_routing_service_requests_total``, + ``brightstaff_session_cache_events_total``. +* Process & build — ``process_resident_memory_bytes``, + ``process_cpu_seconds_total``, ``brightstaff_build_info``. + +A self-contained Prometheus + Grafana stack is shipped under +``config/grafana/``. With Plano already running on the host, bring it up +with one command: + +.. code-block:: bash + + cd config/grafana + docker compose up -d + open http://localhost:3000 # admin / admin (anonymous viewer also enabled) + +Grafana auto-loads the Prometheus datasource and the brightstaff +dashboard (look under the *Plano* folder). Prometheus scrapes the host's +``:9092`` and ``:9901`` via ``host.docker.internal``. + +Files: + +* ``config/grafana/docker-compose.yaml`` — one-command Prom + Grafana + stack with provisioning. +* ``config/grafana/prometheus_scrape.yaml`` — complete Prometheus config + with ``envoy`` and ``brightstaff`` scrape jobs (mounted by the + compose). +* ``config/grafana/brightstaff_dashboard.json`` — 19-panel dashboard + across HTTP RED, LLM upstream, Routing service, and Process & Envoy + link rows. Auto-provisioned by the compose; can also be imported by + hand via *Dashboards → New → Import*. +* ``config/grafana/provisioning/`` — Grafana provisioning files for the + datasource and dashboard provider. 
diff --git a/docs/source/resources/includes/plano_config_full_reference.yaml b/docs/source/resources/includes/plano_config_full_reference.yaml index 1d544727..808d0a98 100644 --- a/docs/source/resources/includes/plano_config_full_reference.yaml +++ b/docs/source/resources/includes/plano_config_full_reference.yaml @@ -173,6 +173,9 @@ overrides: llm_routing_model: Plano-Orchestrator # Model used for agent orchestration (must be listed in model_providers) agent_orchestration_model: Plano-Orchestrator + # Disable agentic signal analysis (frustration, repetition, escalation, etc.) + # on LLM responses to save CPU. Default: false. + disable_signals: false # Model affinity — pin routing decisions for agentic loops routing: diff --git a/docs/source/resources/includes/plano_config_full_reference_rendered.yaml b/docs/source/resources/includes/plano_config_full_reference_rendered.yaml index 4992ce3b..a0603221 100644 --- a/docs/source/resources/includes/plano_config_full_reference_rendered.yaml +++ b/docs/source/resources/includes/plano_config_full_reference_rendered.yaml @@ -170,6 +170,7 @@ model_providers: provider_interface: plano overrides: agent_orchestration_model: Plano-Orchestrator + disable_signals: false llm_routing_model: Plano-Orchestrator optimize_context_window: true prompt_target_intent_matching_threshold: 0.7 diff --git a/tests/parity/signals/.gitignore b/tests/parity/signals/.gitignore new file mode 100644 index 00000000..3a7e0d4f --- /dev/null +++ b/tests/parity/signals/.gitignore @@ -0,0 +1,4 @@ +out/ +.venv/ +__pycache__/ +*.pyc diff --git a/tests/parity/signals/README.md b/tests/parity/signals/README.md new file mode 100644 index 00000000..67193d60 --- /dev/null +++ b/tests/parity/signals/README.md @@ -0,0 +1,98 @@ +# Signals Parity Harness + +Validates that `crates/brightstaff/src/signals/` (Rust port) produces the same +`SignalReport` as the Python reference implementation +on a fixed sample of `lmsys/lmsys-chat-1m` conversations. 
+ +This harness is **not** part of normal CI. It downloads several GB and is run +on demand to gate releases of the signals subsystem (or to investigate +regressions reported in production). + +## What gets compared + +For each conversation, both analyzers emit a `SignalReport`. The comparator +classifies any divergence into three tiers: + +| Tier | Field | Action on divergence | +|------|------------------------------------------------|----------------------| +| A | set of `SignalType` present, per-type counts, `overall_quality` | Fail the run | +| B | per-instance `message_index`, instance counts per type | Log + collect, do not fail | +| C | metadata, snippet text, summary | Information only | + +Quality buckets are compared by string (`excellent` / `good` / ...). + +## What this harness does *not* cover + +`lmsys-chat-1m` is plain user/assistant chat. It exercises the **interaction** +layer well (misalignment, stagnation, disengagement, satisfaction) but does +**not** exercise: + +- `execution.failure.*` +- `execution.loops.*` +- `environment.exhaustion.*` + +Those signals require `function_call` / `observation` ShareGPT roles. They are +covered by the Rust unit tests and the Python repo's own test fixtures, both +of which run on every PR. A synthetic tool-trace dataset for full coverage is +deferred to a follow-up. + +## One-time setup + +```bash +# 1. Build the Rust replay binary. +cd ../../../crates && cargo build --release -p brightstaff --bin signals_replay + +# 2. Set up the Python environment for the harness driver. +cd ../tests/parity/signals +python3 -m venv .venv && source .venv/bin/activate +pip install -r requirements.txt + +# 3. Install the Python signals reference. 
+# Either point at a local checkout:
+pip install -e /path/to/signals
+# or pull from git (pin a tag or commit for reproducibility):
+pip install 'signals @ git+https://github.com/katanemo/signals@<git-ref>'
+```
+
+## Running
+
+```bash
+source .venv/bin/activate
+
+python run_parity.py \
+    --num-samples 2000 \
+    --seed 42 \
+    --dataset-revision <dataset-commit-sha> \
+    --rust-binary ../../../crates/target/release/signals_replay \
+    --output-dir out/
+
+python compare.py --output-dir out/
+```
+
+`run_parity.py` will:
+
+1. Download `lmsys/lmsys-chat-1m` (cached in `~/.cache/huggingface`).
+2. Pick `--num-samples` rows under `--seed`.
+3. Convert each to ShareGPT, write `out/conversations.jsonl`.
+4. Run the Rust binary as a subprocess → `out/rust_reports.jsonl`.
+5. Run the Python analyzer in-process → `out/python_reports.jsonl`.
+
+`compare.py` reads both report files and writes:
+
+- `out/diffs.jsonl` — one record per mismatched conversation, with tier + structural diff
+- `out/metrics.json` — agreement %, per-`SignalType` confusion matrix, quality-bucket confusion matrix
+- `out/summary.md` — human-readable PR-ready report
+
+Exit code is non-zero iff any Tier-A divergence is observed.
+
+## Reproducibility
+
+Every run pins:
+
+- `dataset_revision` — the HF dataset commit
+- `seed` — RNG seed for sampling
+- `signals_python_version` — `pip show signals` version
+- `plano_git_sha` — `git rev-parse HEAD` of this repo
+- `signals_replay_binary_sha256` — the hash of the Rust bin
+
+All are stamped into `metrics.json`. diff --git a/tests/parity/signals/_smoke_test.py b/tests/parity/signals/_smoke_test.py new file mode 100644 index 00000000..68c6e879 --- /dev/null +++ b/tests/parity/signals/_smoke_test.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python3 +""" +Local smoke test for the parity harness — runs both runners on a tiny +hand-picked set of conversations without touching the lmsys dataset. 
+ +Run from this directory: + python _smoke_test.py --rust-binary +""" + +from __future__ import annotations + +import argparse +import json +import subprocess +import sys +from pathlib import Path + +from signals.analyzer import SignalAnalyzer + +SAMPLES = [ + { + "id": "smoke-gratitude", + "messages": [ + {"from": "human", "value": "What is the weather in Istanbul?"}, + {"from": "gpt", "value": "Istanbul is 14C and partly cloudy."}, + {"from": "human", "value": "That worked, exactly what I needed. Thanks!"}, + ], + }, + { + "id": "smoke-escalation", + "messages": [ + {"from": "human", "value": "This isn't helpful at all"}, + {"from": "gpt", "value": "I'm sorry, can you tell me more?"}, + {"from": "human", "value": "Get me a human, this is useless"}, + ], + }, + { + "id": "smoke-correction", + "messages": [ + {"from": "human", "value": "Book me a flight to NYC for tomorrow"}, + {"from": "gpt", "value": "Sure, here are flights to NYC for Friday."}, + { + "from": "human", + "value": "No, I meant flights for Saturday, not tomorrow", + }, + ], + }, + { + "id": "smoke-clean", + "messages": [ + {"from": "human", "value": "Hi"}, + {"from": "gpt", "value": "Hello, how can I help?"}, + ], + }, + { + "id": "smoke-rephrase", + "messages": [ + {"from": "human", "value": "Can you summarize the news please"}, + {"from": "gpt", "value": "Sure, here is a summary."}, + {"from": "human", "value": "Could you please summarize the news"}, + ], + }, +] + + +def main() -> int: + p = argparse.ArgumentParser() + p.add_argument("--rust-binary", required=True, type=Path) + args = p.parse_args() + + out_dir = Path("out_smoke") + out_dir.mkdir(exist_ok=True) + conv_path = out_dir / "conversations.jsonl" + rust_path = out_dir / "rust_reports.jsonl" + py_path = out_dir / "python_reports.jsonl" + + with conv_path.open("w") as f: + for s in SAMPLES: + f.write(json.dumps(s) + "\n") + + with conv_path.open("rb") as fin, rust_path.open("wb") as fout: + proc = subprocess.run( + 
[str(args.rust_binary)], stdin=fin, stdout=fout, stderr=subprocess.PIPE + ) + if proc.returncode != 0: + sys.stderr.write(proc.stderr.decode("utf-8", errors="replace")) + return 2 + + analyzer = SignalAnalyzer() + with conv_path.open() as fin, py_path.open("w") as fout: + for line in fin: + obj = json.loads(line) + r = analyzer.analyze(obj["messages"]) + fout.write(json.dumps({"id": obj["id"], "report": r.to_dict()}) + "\n") + + rc = subprocess.call( + [sys.executable, "compare.py", "--output-dir", str(out_dir)], + ) + return rc + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/parity/signals/compare.py b/tests/parity/signals/compare.py new file mode 100644 index 00000000..80f56295 --- /dev/null +++ b/tests/parity/signals/compare.py @@ -0,0 +1,333 @@ +#!/usr/bin/env python3 +""" +Diff Rust vs Python signal reports produced by run_parity.py. + +See README.md for the tier definitions. Exits non-zero iff any Tier-A +divergence is found. +""" + +from __future__ import annotations + +import argparse +import json +import sys +from collections import Counter, defaultdict +from pathlib import Path +from typing import Any, Dict, List, Tuple + +CATEGORIES_BY_LAYER = { + "interaction_signals": [ + "misalignment", + "stagnation", + "disengagement", + "satisfaction", + ], + "execution_signals": ["failure", "loops"], + "environment_signals": ["exhaustion"], +} + + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser(description=__doc__) + p.add_argument("--output-dir", type=Path, default=Path("out")) + return p.parse_args() + + +def load_jsonl(path: Path) -> Dict[str, Dict[str, Any]]: + """Load a JSONL file keyed by `id`. 
Lines with errors are still indexed.""" + out: Dict[str, Dict[str, Any]] = {} + with path.open() as f: + for line in f: + line = line.strip() + if not line: + continue + obj = json.loads(line) + out[str(obj.get("id"))] = obj + return out + + +def per_type_counts(report: Dict[str, Any]) -> Dict[str, int]: + """Return {signal_type: count} across all groups in a report dict.""" + counts: Counter[str] = Counter() + for layer in CATEGORIES_BY_LAYER: + groups = report.get(layer, {}) or {} + for category in CATEGORIES_BY_LAYER[layer]: + group = groups.get(category) + if not group: + continue + for sig in group.get("signals", []) or []: + counts[sig["signal_type"]] += 1 + return dict(counts) + + +def per_type_indices(report: Dict[str, Any]) -> Dict[str, List[int]]: + out: Dict[str, List[int]] = defaultdict(list) + for layer in CATEGORIES_BY_LAYER: + groups = report.get(layer, {}) or {} + for category in CATEGORIES_BY_LAYER[layer]: + group = groups.get(category) + if not group: + continue + for sig in group.get("signals", []) or []: + out[sig["signal_type"]].append(sig.get("message_index")) + for k in out: + out[k].sort(key=lambda x: (x is None, x)) + return dict(out) + + +def diff_counts(a: Dict[str, int], b: Dict[str, int]) -> List[Tuple[str, int, int]]: + """Return [(signal_type, a_count, b_count)] for entries that differ.""" + keys = set(a) | set(b) + out = [] + for k in sorted(keys): + ac = a.get(k, 0) + bc = b.get(k, 0) + if ac != bc: + out.append((k, ac, bc)) + return out + + +def diff_indices( + a: Dict[str, List[int]], b: Dict[str, List[int]] +) -> List[Tuple[str, List[int], List[int]]]: + keys = set(a) | set(b) + out = [] + for k in sorted(keys): + ai = a.get(k, []) + bi = b.get(k, []) + if ai != bi: + out.append((k, ai, bi)) + return out + + +def compare_one( + convo_id: str, py: Dict[str, Any], rust: Dict[str, Any] +) -> Dict[str, Any] | None: + """Compare a single conversation. 
Return diff record, or None if identical.""" + if "error" in py or "error" in rust: + return { + "id": convo_id, + "tier": "A", + "kind": "error_in_runner", + "python_error": py.get("error"), + "rust_error": rust.get("error"), + } + py_report = py["report"] + rust_report = rust["report"] + + py_counts = per_type_counts(py_report) + rust_counts = per_type_counts(rust_report) + count_diff = diff_counts(py_counts, rust_counts) + + py_quality = py_report.get("overall_quality") + rust_quality = rust_report.get("overall_quality") + quality_mismatch = py_quality != rust_quality + + if count_diff or quality_mismatch: + return { + "id": convo_id, + "tier": "A", + "kind": "signal_or_quality_mismatch", + "quality": {"python": py_quality, "rust": rust_quality}, + "count_diff": [ + {"signal_type": st, "python": pc, "rust": rc} + for (st, pc, rc) in count_diff + ], + } + + py_idx = per_type_indices(py_report) + rust_idx = per_type_indices(rust_report) + idx_diff = diff_indices(py_idx, rust_idx) + if idx_diff: + return { + "id": convo_id, + "tier": "B", + "kind": "instance_index_mismatch", + "diff": [ + {"signal_type": st, "python_indices": pi, "rust_indices": ri} + for (st, pi, ri) in idx_diff + ], + } + + return None + + +def confusion_matrix( + pairs: List[Tuple[str, str]], labels: List[str] +) -> Dict[str, Dict[str, int]]: + cm: Dict[str, Dict[str, int]] = {a: {b: 0 for b in labels} for a in labels} + for py, rust in pairs: + if py not in cm: + cm[py] = {b: 0 for b in labels} + if rust not in cm[py]: + cm[py][rust] = 0 + cm[py][rust] += 1 + return cm + + +def main() -> int: + args = parse_args() + out_dir = args.output_dir + + py_reports = load_jsonl(out_dir / "python_reports.jsonl") + rust_reports = load_jsonl(out_dir / "rust_reports.jsonl") + + common_ids = sorted(set(py_reports) & set(rust_reports)) + only_py = sorted(set(py_reports) - set(rust_reports)) + only_rust = sorted(set(rust_reports) - set(py_reports)) + + diffs: List[Dict[str, Any]] = [] + quality_pairs: 
List[Tuple[str, str]] = [] + per_type_total = Counter() + per_type_disagree = Counter() + + tier_a = 0 + tier_b = 0 + for cid in common_ids: + d = compare_one(cid, py_reports[cid], rust_reports[cid]) + if d is None: + quality_pairs.append( + ( + py_reports[cid]["report"]["overall_quality"], + rust_reports[cid]["report"]["overall_quality"], + ) + ) + for st, _ in per_type_counts(py_reports[cid]["report"]).items(): + per_type_total[st] += 1 + else: + diffs.append(d) + if d["tier"] == "A": + tier_a += 1 + elif d["tier"] == "B": + tier_b += 1 + if "report" in py_reports[cid] and "report" in rust_reports[cid]: + quality_pairs.append( + ( + py_reports[cid]["report"].get("overall_quality", "?"), + rust_reports[cid]["report"].get("overall_quality", "?"), + ) + ) + for cd in d.get("count_diff", []) or []: + per_type_disagree[cd["signal_type"]] += 1 + per_type_total[cd["signal_type"]] += 1 + + n_total = len(common_ids) + n_match = n_total - len(diffs) + agreement = (n_match / n_total) if n_total else 0.0 + + quality_labels = ["excellent", "good", "neutral", "poor", "severe"] + cm = confusion_matrix(quality_pairs, quality_labels) + + metrics = { + "n_python_reports": len(py_reports), + "n_rust_reports": len(rust_reports), + "n_common": n_total, + "n_only_python": len(only_py), + "n_only_rust": len(only_rust), + "n_full_match": n_match, + "agreement_pct": round(100.0 * agreement, 4), + "tier_a_divergences": tier_a, + "tier_b_divergences": tier_b, + "quality_confusion_matrix": cm, + "per_signal_type_total": dict(per_type_total), + "per_signal_type_disagree": dict(per_type_disagree), + } + + # Pull in run metadata if present. 
+ rm_path = out_dir / "run_metadata.json" + if rm_path.exists(): + metrics["run_metadata"] = json.loads(rm_path.read_text()) + + (out_dir / "metrics.json").write_text(json.dumps(metrics, indent=2)) + with (out_dir / "diffs.jsonl").open("w") as f: + for d in diffs: + f.write(json.dumps(d, ensure_ascii=False)) + f.write("\n") + + write_summary_md(out_dir / "summary.md", metrics, diffs[:20]) + + print( + json.dumps( + {k: v for k, v in metrics.items() if k != "quality_confusion_matrix"}, + indent=2, + ) + ) + print(f"\ndiffs: {out_dir / 'diffs.jsonl'} metrics: {out_dir / 'metrics.json'}") + print(f"summary: {out_dir / 'summary.md'}") + + if tier_a > 0: + print(f"\nFAIL: {tier_a} Tier-A divergence(s) detected.", file=sys.stderr) + return 1 + return 0 + + +def write_summary_md( + path: Path, metrics: Dict[str, Any], sample_diffs: List[Dict[str, Any]] +) -> None: + lines: List[str] = [] + lines.append("# Signals Parity Report") + lines.append("") + rm = metrics.get("run_metadata", {}) + if rm: + lines.append("## Run metadata") + lines.append("") + for k in ( + "dataset_name", + "dataset_revision", + "seed", + "num_samples_actual", + "plano_git_sha", + "signals_python_version", + "rust_binary_sha256", + ): + if k in rm: + lines.append(f"- **{k}**: `{rm[k]}`") + lines.append("") + + lines.append("## Summary") + lines.append("") + lines.append(f"- Conversations compared: **{metrics['n_common']}**") + lines.append(f"- Full matches: **{metrics['n_full_match']}**") + lines.append(f"- Agreement: **{metrics['agreement_pct']}%**") + lines.append(f"- Tier-A divergences: **{metrics['tier_a_divergences']}**") + lines.append(f"- Tier-B divergences: **{metrics['tier_b_divergences']}**") + lines.append("") + + lines.append("## Per-signal-type disagreement") + lines.append("") + lines.append("| Signal type | Total reports | Disagreements |") + lines.append("|---|---:|---:|") + totals = metrics["per_signal_type_total"] + disagrees = metrics["per_signal_type_disagree"] + for k in 
sorted(set(totals) | set(disagrees)): + lines.append(f"| `{k}` | {totals.get(k, 0)} | {disagrees.get(k, 0)} |") + lines.append("") + + lines.append("## Quality bucket confusion matrix (rows = python, cols = rust)") + lines.append("") + cm = metrics["quality_confusion_matrix"] + labels = list(cm.keys()) + lines.append("| | " + " | ".join(labels) + " |") + lines.append("|---|" + "|".join(["---:"] * len(labels)) + "|") + for r in labels: + lines.append( + f"| {r} | " + " | ".join(str(cm[r].get(c, 0)) for c in labels) + " |" + ) + lines.append("") + + if sample_diffs: + lines.append("## Sample divergences (first 20)") + lines.append("") + for d in sample_diffs: + lines.append(f"### `{d['id']}` — tier {d['tier']} — {d['kind']}") + lines.append("") + lines.append("```json") + lines.append(json.dumps(d, indent=2)) + lines.append("```") + lines.append("") + + path.write_text("\n".join(lines)) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/parity/signals/requirements.txt b/tests/parity/signals/requirements.txt new file mode 100644 index 00000000..7b25f179 --- /dev/null +++ b/tests/parity/signals/requirements.txt @@ -0,0 +1,3 @@ +huggingface_hub>=0.25 +pyarrow>=15 +tqdm>=4.65 diff --git a/tests/parity/signals/run_parity.py b/tests/parity/signals/run_parity.py new file mode 100644 index 00000000..1d14630e --- /dev/null +++ b/tests/parity/signals/run_parity.py @@ -0,0 +1,332 @@ +#!/usr/bin/env python3 +""" +Parity harness driver. + +Samples conversations from `lmsys/lmsys-chat-1m`, runs both the Python +reference analyzer (in-process) and the Rust port (subprocess), writes both +reports to disk for `compare.py` to diff. 
+
+Usage:
+    python run_parity.py \\
+        --num-samples 2000 \\
+        --seed 42 \\
+        --dataset-revision <dataset-commit-sha> \\
+        --rust-binary ../../../crates/target/release/signals_replay \\
+        --output-dir out/
+"""
+
+from __future__ import annotations
+
+import argparse
+import hashlib
+import json
+import random
+import subprocess
+import sys
+import time
+from pathlib import Path
+from typing import Any, Dict, Iterator, List
+
+try:
+    import pyarrow.parquet as pq
+    from huggingface_hub import hf_hub_download, list_repo_files
+except ImportError:
+    print(
+        "error: install dependencies first: pip install -r requirements.txt",
+        file=sys.stderr,
+    )
+    sys.exit(2)
+
+try:
+    from signals.analyzer import SignalAnalyzer
+except ImportError:
+    print(
+        "error: the python `signals` package is not installed. "
+        "install it from your local checkout: pip install -e /path/to/signals",
+        file=sys.stderr,
+    )
+    sys.exit(2)
+
+try:
+    from tqdm import tqdm
+except ImportError:
+
+    def tqdm(it, **_kwargs):  # type: ignore[no-redef]
+        return it
+
+
+DATASET_NAME = "lmsys/lmsys-chat-1m"
+
+
+def parse_args() -> argparse.Namespace:
+    p = argparse.ArgumentParser(description=__doc__)
+    p.add_argument("--num-samples", type=int, default=2000)
+    p.add_argument("--seed", type=int, default=42)
+    p.add_argument(
+        "--dataset-revision",
+        default=None,
+        help="HF dataset revision to pin (default: latest, NOT recommended for reproducibility)",
+    )
+    p.add_argument(
+        "--rust-binary",
+        type=Path,
+        required=True,
+        help="path to the `signals_replay` binary built from crates/brightstaff",
+    )
+    p.add_argument(
+        "--output-dir",
+        type=Path,
+        default=Path("out"),
+        help="directory to write the conversations + both runners' outputs",
+    )
+    p.add_argument(
+        "--max-conv-messages",
+        type=int,
+        default=200,
+        help="drop conversations with more than this many messages (the analyzer "
+        "truncates to last 100 anyway; this is a sanity cap on input parsing)",
+    )
+    return p.parse_args()
+
+
+def 
lmsys_to_sharegpt(conversation: List[Dict[str, str]]) -> List[Dict[str, str]]: + """Convert lmsys-chat-1m's `[{role, content}]` to ShareGPT's `[{from, value}]`. + + lmsys uses `user` / `assistant` (no tools, no system role in `conversation`). + """ + out = [] + for m in conversation: + role = m.get("role", "") + content = m.get("content", "") + if not isinstance(content, str): + content = str(content) if content is not None else "" + if role == "user": + from_ = "human" + elif role == "assistant": + from_ = "gpt" + else: + # lmsys is human/assistant only; skip anything else defensively. + continue + out.append({"from": from_, "value": content}) + return out + + +def _list_parquet_files(revision: str | None) -> List[str]: + """Return the list of parquet shard paths in the dataset repo.""" + files = list_repo_files(DATASET_NAME, repo_type="dataset", revision=revision) + return sorted(f for f in files if f.endswith(".parquet")) + + +def _download_shards(paths: List[str], revision: str | None) -> List[Path]: + """Download each parquet shard to the HF cache, return local paths.""" + local: List[Path] = [] + for rel in tqdm(paths, desc="downloading shards", unit="shard"): + p = hf_hub_download( + DATASET_NAME, + filename=rel, + repo_type="dataset", + revision=revision, + ) + local.append(Path(p)) + return local + + +def sample_conversations( + *, + num_samples: int, + seed: int, + revision: str | None, + max_conv_messages: int, +) -> Iterator[Dict[str, Any]]: + """Yield `num_samples` conversations sampled uniformly across the dataset. + + We bypass the `datasets` loader (which has a Python 3.14 pickle issue) + and read the parquet shards directly via pyarrow. 
+ """ + print( + f"listing {DATASET_NAME}" + f"{' @ ' + revision if revision else ' (no revision pinned!)'}", + file=sys.stderr, + ) + shard_paths = _list_parquet_files(revision) + if not shard_paths: + raise SystemExit(f"no parquet shards found for {DATASET_NAME}") + local_paths = _download_shards(shard_paths, revision) + + # Collect row counts without reading data. + shard_row_counts: List[int] = [] + for p in local_paths: + pf = pq.ParquetFile(str(p)) + shard_row_counts.append(pf.metadata.num_rows) + total_rows = sum(shard_row_counts) + print( + f"dataset has {total_rows:,} rows across {len(local_paths)} shards", + file=sys.stderr, + ) + + rng = random.Random(seed) + global_indices = sorted(rng.sample(range(total_rows), num_samples)) + + # Bucket indices by shard. + by_shard: Dict[int, List[int]] = {} + cumulative = 0 + shard_offsets = [] + for c in shard_row_counts: + shard_offsets.append(cumulative) + cumulative += c + for gi in global_indices: + # Find which shard this index belongs to. 
+ for si, off in enumerate(shard_offsets): + if gi < off + shard_row_counts[si]: + by_shard.setdefault(si, []).append(gi - off) + break + + yielded = 0 + for si in sorted(by_shard.keys()): + local_rows = by_shard[si] + pf = pq.ParquetFile(str(local_paths[si])) + table = pf.read(columns=["conversation"]) + conv_col = table.column("conversation") + for local_idx in local_rows: + raw = conv_col[local_idx].as_py() + if not raw: + continue + conversation = raw if isinstance(raw, list) else raw.get("conversation", []) + if len(conversation) > max_conv_messages: + continue + messages = lmsys_to_sharegpt(conversation) + if not messages: + continue + global_idx = shard_offsets[si] + local_idx + yield { + "id": f"lmsys-{global_idx}", + "messages": messages, + } + yielded += 1 + print(f"yielded {yielded} conversations after filtering", file=sys.stderr) + + +def write_conversations(out_path: Path, samples: Iterator[Dict[str, Any]]) -> int: + n = 0 + with out_path.open("w") as f: + for s in tqdm(samples, desc="sampling", unit="convo"): + f.write(json.dumps(s, ensure_ascii=False)) + f.write("\n") + n += 1 + return n + + +def run_rust(rust_binary: Path, conv_path: Path, out_path: Path) -> None: + print(f"running rust analyzer: {rust_binary}", file=sys.stderr) + t0 = time.monotonic() + with conv_path.open("rb") as fin, out_path.open("wb") as fout: + proc = subprocess.run( + [str(rust_binary)], + stdin=fin, + stdout=fout, + stderr=subprocess.PIPE, + check=False, + ) + if proc.returncode != 0: + sys.stderr.write(proc.stderr.decode("utf-8", errors="replace")) + raise SystemExit(f"rust runner exited {proc.returncode}") + elapsed = time.monotonic() - t0 + print(f" rust runner: {elapsed:.1f}s", file=sys.stderr) + + +def run_python(conv_path: Path, out_path: Path) -> None: + print("running python analyzer...", file=sys.stderr) + t0 = time.monotonic() + analyzer = SignalAnalyzer() + with conv_path.open() as fin, out_path.open("w") as fout: + for line in tqdm(fin, desc="python", 
unit="convo"): + line = line.strip() + if not line: + continue + try: + obj = json.loads(line) + report = analyzer.analyze(obj["messages"]) + fout.write( + json.dumps( + {"id": obj["id"], "report": report.to_dict()}, + ensure_ascii=False, + ) + ) + except Exception as e: + fout.write(json.dumps({"id": obj.get("id"), "error": str(e)})) + fout.write("\n") + elapsed = time.monotonic() - t0 + print(f" python runner: {elapsed:.1f}s", file=sys.stderr) + + +def stamp_metadata(args: argparse.Namespace, output_dir: Path, n_samples: int) -> None: + """Write the input metadata so compare.py can include it in the report.""" + binary_sha = hashlib.sha256(args.rust_binary.read_bytes()).hexdigest() + try: + plano_sha = ( + subprocess.check_output( + ["git", "rev-parse", "HEAD"], cwd=Path(__file__).parent + ) + .decode() + .strip() + ) + except Exception: + plano_sha = "unknown" + try: + signals_version = subprocess.check_output( + [sys.executable, "-m", "pip", "show", "signals"] + ).decode() + signals_version = next( + ( + l.split(":", 1)[1].strip() + for l in signals_version.splitlines() + if l.startswith("Version") + ), + "unknown", + ) + except Exception: + signals_version = "unknown" + + meta = { + "dataset_name": DATASET_NAME, + "dataset_revision": args.dataset_revision, + "seed": args.seed, + "num_samples_requested": args.num_samples, + "num_samples_actual": n_samples, + "rust_binary": str(args.rust_binary.resolve()), + "rust_binary_sha256": binary_sha, + "plano_git_sha": plano_sha, + "signals_python_version": signals_version, + "max_conv_messages": args.max_conv_messages, + } + (output_dir / "run_metadata.json").write_text(json.dumps(meta, indent=2)) + print(f"wrote {output_dir / 'run_metadata.json'}", file=sys.stderr) + + +def main() -> None: + args = parse_args() + args.output_dir.mkdir(parents=True, exist_ok=True) + if not args.rust_binary.exists(): + raise SystemExit(f"rust binary not found at {args.rust_binary}") + + conv_path = args.output_dir / "conversations.jsonl" 
+ rust_path = args.output_dir / "rust_reports.jsonl" + py_path = args.output_dir / "python_reports.jsonl" + + samples = sample_conversations( + num_samples=args.num_samples, + seed=args.seed, + revision=args.dataset_revision, + max_conv_messages=args.max_conv_messages, + ) + n = write_conversations(conv_path, samples) + print(f"wrote {n} conversations to {conv_path}", file=sys.stderr) + + run_rust(args.rust_binary, conv_path, rust_path) + run_python(conv_path, py_path) + stamp_metadata(args, args.output_dir, n) + print("done. now run: python compare.py --output-dir " + str(args.output_dir)) + + +if __name__ == "__main__": + main()