Fix request closures during long-running streaming

This commit is contained in:
adilhafeez 2026-04-18 18:10:23 -07:00
parent ffea891dba
commit c8c6b87d1e
7 changed files with 637 additions and 149 deletions

View file

@ -7,6 +7,7 @@ Single-source: one fetch at startup, cached for the life of the process.
from __future__ import annotations
import logging
import re
import threading
from dataclasses import dataclass
from typing import Any
@ -123,13 +124,28 @@ class PricingCatalog:
return round(cost, 6)
_DATE_SUFFIX_RE = re.compile(r"-\d{8}$")
_PROVIDER_PREFIXES = ("anthropic", "openai", "google", "meta", "cohere", "mistral")
_ANTHROPIC_FAMILIES = {"opus", "sonnet", "haiku"}
def _model_key_candidates(model_name: str) -> list[str]:
"""Lookup-side variants of a Plano-emitted model name.
Plano resolves names like ``claude-haiku-4-5-20251001``; the catalog stores
them as ``anthropic-claude-haiku-4.5``. We strip the date suffix and the
``provider/`` prefix here; the catalog itself registers the dash/dot and
family-order aliases at parse time (see :func:`_expand_aliases`).
"""
base = model_name.strip()
out = [base]
if "/" in base:
out.append(base.split("/", 1)[1])
for k in list(out):
stripped = _DATE_SUFFIX_RE.sub("", k)
if stripped != k:
out.append(stripped)
out.extend([v.lower() for v in list(out)])
# Dedup while preserving order.
seen: set[str] = set()
uniq = []
for key in out:
@ -139,6 +155,54 @@ def _model_key_candidates(model_name: str) -> list[str]:
return uniq
def _expand_aliases(model_id: str) -> set[str]:
"""Catalog-side variants of a DO model id.
DO publishes Anthropic models under ids like ``anthropic-claude-opus-4.7``
or ``anthropic-claude-4.6-sonnet`` while Plano emits ``claude-opus-4-7`` /
``claude-sonnet-4-6``. Generate a set covering provider-prefix stripping,
dashdot in version segments, and familyversion word order so a single
catalog entry matches every name shape we'll see at lookup.
"""
aliases: set[str] = set()
def add(name: str) -> None:
if not name:
return
aliases.add(name)
aliases.add(name.lower())
add(model_id)
base = model_id
head, _, rest = base.partition("-")
if head.lower() in _PROVIDER_PREFIXES and rest:
add(rest)
base = rest
for key in list(aliases):
if "." in key:
add(key.replace(".", "-"))
parts = base.split("-")
if len(parts) >= 3 and parts[0].lower() == "claude":
rest_parts = parts[1:]
for i, p in enumerate(rest_parts):
if p.lower() in _ANTHROPIC_FAMILIES:
others = rest_parts[:i] + rest_parts[i + 1 :]
if not others:
break
family_last = "claude-" + "-".join(others) + "-" + p
family_first = "claude-" + p + "-" + "-".join(others)
add(family_last)
add(family_first)
add(family_last.replace(".", "-"))
add(family_first.replace(".", "-"))
break
return aliases
def _parse_do_pricing(data: Any) -> dict[str, ModelPrice]:
"""Parse DO catalog response into a ModelPrice map keyed by model id.
@ -204,11 +268,13 @@ def _parse_do_pricing(data: Any) -> dict[str, ModelPrice]:
# rates for promo/open-weight models.
if input_rate == 0 and output_rate == 0:
continue
prices[str(model_id)] = ModelPrice(
price = ModelPrice(
input_per_token_usd=input_rate,
output_per_token_usd=output_rate,
cached_input_per_token_usd=cached_rate,
)
for alias in _expand_aliases(str(model_id)):
prices.setdefault(alias, price)
return prices