add session pinning to llm_chat handler and rewrite session pinning demo

- extend brightstaff llm_chat_inner to extract X-Session-Id, check the
  session cache before routing, and cache the result afterward — same
  pattern as routing_service.rs (sketched below)
- replace old urllib-based demo with a real FastAPI research agent that
  runs 3 independent tool-calling tasks with alternating intents so
  Plano routes to different models; demo.py is a pure httpx client that
  shows the routing trace side-by-side with and without session pinning
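
A rough sketch of the handler change in the first bullet. Everything here
except the X-Session-Id header and the check-before-routing / cache-afterward
order is an assumption: the cache type, method names, and the `route` hook
are illustrative stand-ins, not the actual brightstaff or routing_service.rs
API.

    use std::collections::HashMap;
    use tokio::sync::Mutex;

    // Hypothetical stand-in for the session cache described above: maps a
    // session ID to the model selected on that session's first routed request.
    struct SessionCache(Mutex<HashMap<String, String>>);

    impl SessionCache {
        async fn get(&self, id: &str) -> Option<String> {
            self.0.lock().await.get(id).cloned()
        }
        async fn insert(&self, id: String, model: String) {
            self.0.lock().await.insert(id, model);
        }
    }

    // The check-then-cache order used inside llm_chat_inner; the function
    // name, signature, and `route` hook are illustrative, not the real handler.
    async fn resolve_model(
        x_session_id: Option<&str>,
        cache: &SessionCache,
        route: impl std::future::Future<Output = String>,
    ) -> String {
        // Reuse the model pinned by an earlier turn on this session, if any.
        if let Some(id) = x_session_id {
            if let Some(model) = cache.get(id).await {
                return model;
            }
        }
        // Otherwise run the normal routing decision and pin it for this session.
        let model = route.await;
        if let Some(id) = x_session_id {
            cache.insert(id.to_string(), model.clone()).await;
        }
        model
    }

Pinning on X-Session-Id means a client opts in per conversation simply by
sending the same header value on every request, which is exactly what the
rewritten demo.py exercises below.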
Adil Hafeez 2026-03-26 16:44:05 -07:00
parent 71437d2b2c
commit 0105897692
7 changed files with 771 additions and 200 deletions


@@ -1,150 +1,174 @@
#!/usr/bin/env python3
#!/usr/bin/env -S uv run --script
# /// script
# requires-python = ">=3.12"
# dependencies = ["httpx>=0.27"]
# ///
"""
Session Pinning Demo — Iterative Research Agent
Session Pinning Demo — Research Agent client
Demonstrates how session pinning ensures consistent model selection
across multiple iterations of an agentic loop. Runs the same 5-step
research workflow twice:
Sends the same query to the Research Agent twice, once without a session ID
and once with one, and compares the routing traces to show how session pinning
keeps the model consistent across the LLM's tool-calling loop.
1) Without session pinning, models may switch between iterations
2) With session pinning, the first iteration pins the model for all subsequent ones
Requires the agent to already be running (start it with ./start_agents.sh).
Uses the /routing/v1/chat/completions endpoint (routing decisions only, no LLM calls).
Usage:
uv run demo.py
AGENT_URL=http://localhost:8000 uv run demo.py
"""
import json
import asyncio
import os
import urllib.request
import uuid
PLANO_URL = os.environ.get("PLANO_URL", "http://localhost:12000")
import httpx
# Simulates an iterative research agent building a task management app.
# Prompts deliberately alternate between code_generation and complex_reasoning
# intents so that without pinning, different models get selected per step.
RESEARCH_STEPS = [
"Design a REST API schema for a task management app with users, projects, and tasks",
"Analyze the trade-offs between SQL and NoSQL databases for this task management system",
"Write the database models and ORM setup in Python using SQLAlchemy",
"Review the API design for security vulnerabilities and suggest improvements",
"Implement the authentication middleware with JWT tokens",
]
AGENT_URL = os.environ.get("AGENT_URL", "http://localhost:8000")
QUERY = (
"Should we use PostgreSQL or MongoDB for a high-traffic e-commerce backend "
"that needs strong consistency for orders but flexible schemas for products?"
)
STEP_LABELS = [
"Design REST API schema",
"Analyze SQL vs NoSQL trade-offs",
"Write SQLAlchemy database models",
"Review API security vulnerabilities",
"Implement JWT auth middleware",
]
# ---------------------------------------------------------------------------
# Client helpers
# ---------------------------------------------------------------------------
def run_research_loop(session_id=None):
"""Run the research agent loop, optionally with session pinning."""
results = []
async def wait_for_agent(timeout: int = 30) -> bool:
async with httpx.AsyncClient() as client:
for _ in range(timeout * 2):
try:
r = await client.get(f"{AGENT_URL}/health", timeout=1.0)
if r.status_code == 200:
return True
except Exception:
pass
await asyncio.sleep(0.5)
return False
for i, prompt in enumerate(RESEARCH_STEPS, 1):
headers = {"Content-Type": "application/json"}
if session_id:
headers["X-Session-Id"] = session_id
payload = {
"model": "gpt-4o-mini",
"messages": [{"role": "user", "content": prompt}],
}
async def ask_agent(query: str, session_id: str | None = None) -> dict:
headers: dict[str, str] = {}
if session_id:
headers["X-Session-Id"] = session_id
resp = urllib.request.urlopen(
urllib.request.Request(
f"{PLANO_URL}/routing/v1/chat/completions",
data=json.dumps(payload).encode(),
headers=headers,
),
timeout=10,
async with httpx.AsyncClient(timeout=120.0) as client:
r = await client.post(
f"{AGENT_URL}/v1/chat/completions",
headers=headers,
json={"messages": [{"role": "user", "content": query}]},
)
data = json.loads(resp.read())
model = data.get("model", "unknown")
route = data.get("route") or "none"
pinned = data.get("pinned")
results.append({"step": i, "model": model, "route": route, "pinned": pinned})
return results
r.raise_for_status()
return r.json()
def print_results_table(results):
"""Print results as a compact aligned table."""
label_width = max(len(l) for l in STEP_LABELS)
for r in results:
step = r["step"]
label = STEP_LABELS[step - 1]
model = r["model"]
pinned = r["pinned"]
# Shorten model names for readability
short_model = model.replace("anthropic/", "").replace("openai/", "")
pin_indicator = ""
if pinned is True:
pin_indicator = " ◀ pinned"
elif pinned is False:
pin_indicator = " ◀ routed"
print(f" {step}. {label:<{label_width}}{short_model}{pin_indicator}")
# ---------------------------------------------------------------------------
# Display helpers
# ---------------------------------------------------------------------------
def print_summary(label, results):
"""Print a one-line summary of model consistency."""
models = [r["model"] for r in results]
unique = sorted(set(models))
def _short(model: str) -> str:
return model.split("/")[-1] if "/" in model else model
def _print_trace(result: dict) -> None:
trace = result.get("routing_trace", [])
if not trace:
print(" (no trace)")
return
prev: str | None = None
for t in trace:
short = _short(t["model"])
switch = " ← switched" if (prev and t["model"] != prev) else ""
prev = t["model"]
print(f" {t['task']:<26} [{short}]{switch}")
def _print_summary(label: str, result: dict) -> None:
models = [t["model"] for t in result.get("routing_trace", [])]
if not models:
print(f" ? {label}: no routing data")
return
unique = set(models)
if len(unique) == 1:
short = models[0].replace("anthropic/", "").replace("openai/", "")
print(f"{label}: All 5 steps → {short}")
print(f"{label}: {_short(next(iter(unique)))} for all {len(models)} turns")
else:
short = [m.replace("anthropic/", "").replace("openai/", "") for m in unique]
print(f"{label}: Models varied → {', '.join(short)}")
switched = sum(1 for a, b in zip(models, models[1:]) if a != b)
names = ", ".join(sorted(_short(m) for m in unique))
print(f"{label}: model switched {switched} time(s) — {names}")
def main():
# ---------------------------------------------------------------------------
# Demo
# ---------------------------------------------------------------------------
async def main() -> None:
print()
print(" ╔══════════════════════════════════════════════════════════════╗")
print(" ║ Session Pinning Demo — Iterative Research Agent ║")
print("Session Pinning Demo — Research Agent ")
print(" ╚══════════════════════════════════════════════════════════════╝")
print()
print(" An agent builds a task management app in 5 steps.")
print(" Each step asks Plano's router which model to use.")
print(f" Agent : {AGENT_URL}")
print(f" Query : \"{QUERY[:72]}\"")
print()
print(" The agent uses a tool-calling loop (get_db_benchmarks,")
print(" get_case_studies, check_feature_support) to research the")
print(" question. Each LLM turn hits Plano's preference-based router.")
print()
# --- Run 1: Without session pinning ---
print(" ┌──────────────────────────────────────────────────────────────┐")
print(" │ Run 1: WITHOUT Session Pinning │")
print(" └──────────────────────────────────────────────────────────────┘")
print()
results_no_pin = run_research_loop(session_id=None)
print_results_table(results_no_pin)
print(f" Waiting for agent at {AGENT_URL}", end=" ", flush=True)
if not await wait_for_agent():
print("FAILED — agent did not respond within 30 s")
return
print("ready.")
print()
# --- Run 2: With session pinning ---
session_id = str(uuid.uuid4())
short_sid = session_id[:8]
print(f" ┌──────────────────────────────────────────────────────────────┐")
print(f" │ Run 2: WITH Session Pinning (session: {short_sid}…) │")
print(f" └──────────────────────────────────────────────────────────────┘")
sid = str(uuid.uuid4())
print(" Sending queries (running concurrently)…")
print()
results_pinned = run_research_loop(session_id=session_id)
print_results_table(results_pinned)
without, with_pin = await asyncio.gather(
ask_agent(QUERY, session_id=None),
ask_agent(QUERY, session_id=sid),
)
# ── Run 1 ────────────────────────────────────────────────────────────
print(" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
print(" Run 1: WITHOUT Session Pinning")
print(" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
print()
print(" LLM turns inside the agent loop:")
print()
_print_trace(without)
print()
_print_summary("Without pinning", without)
print()
# --- Summary ---
print(" ┌──────────────────────────────────────────────────────────────┐")
print(" │ Summary │")
print(" └──────────────────────────────────────────────────────────────┘")
# ── Run 2 ────────────────────────────────────────────────────────────
print(" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
print(f" Run 2: WITH Session Pinning (X-Session-Id: {sid[:8]}…)")
print(" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
print()
print_summary("Without pinning", results_no_pin)
print_summary("With pinning ", results_pinned)
print(" LLM turns inside the agent loop:")
print()
_print_trace(with_pin)
print()
_print_summary("With pinning ", with_pin)
print()
# ── Final answer ─────────────────────────────────────────────────────
answer = with_pin["choices"][0]["message"]["content"]
print(" ══ Agent recommendation (pinned session) ═════════════════════")
print()
for line in answer.splitlines():
print(f" {line}")
print()
print(" ══════════════════════════════════════════════════════════════")
print()
if __name__ == "__main__":
main()
asyncio.run(main())