Merge pull request #12 from nomyo-ai/dev-v0.4.x
Token usage counter for non-streaming OpenAI/Ollama endpoints, plus improvements
commit c6c1059ede
3 changed files with 84 additions and 17 deletions
router.py (50 changed lines)
@@ -123,6 +123,7 @@ default_headers={
 # 3. Global state: per‑endpoint per‑model active connection counters
 # -------------------------------------------------------------
 usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
+token_usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
 usage_lock = asyncio.Lock() # protects access to usage_counts
 
 # -------------------------------------------------------------
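Note on the counter shape: defaultdict(lambda: defaultdict(int)) lets the recorder bump an endpoint/model pair without any key-existence checks. A minimal sketch (endpoint and model names are made up):

from collections import defaultdict
from typing import Dict

counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
counts["http://localhost:11434"]["llama3"] += 46  # both keys spring into existence
counts["http://localhost:11434"]["llama3"] += 10  # and accumulate: now 56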
@@ -191,6 +192,13 @@ def is_ext_openai_endpoint(endpoint: str) -> bool:
 
     return True # It's an external OpenAI endpoint
 
+def record_token_usage(endpoint: str, model: str, prompt: int = 0, completion: int = 0) -> None:
+    async def _record():
+        async with usage_lock: # reuse the same lock that protects usage_counts
+            token_usage_counts[endpoint][model] += (prompt + completion)
+            await publish_snapshot() # immediately broadcast the new totals
+    asyncio.create_task(_record())
+
 class fetch:
     async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
         """
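Review note: record_token_usage is a plain synchronous function that hands the real work to a task, so callers on the request path never await usage_lock themselves. A runnable sketch of the same fire-and-forget pattern (values are made up; publish_snapshot omitted):

import asyncio
from collections import defaultdict

token_usage_counts = defaultdict(lambda: defaultdict(int))
usage_lock = asyncio.Lock()

def record_token_usage(endpoint, model, prompt=0, completion=0):
    async def _record():
        async with usage_lock:
            token_usage_counts[endpoint][model] += (prompt + completion)
    # create_task requires a running event loop, which holds inside async handlers.
    asyncio.create_task(_record())

async def main():
    record_token_usage("http://localhost:11434", "llama3", prompt=12, completion=34)
    await asyncio.sleep(0)  # give the scheduled task one turn of the loop
    assert token_usage_counts["http://localhost:11434"]["llama3"] == 46

asyncio.run(main())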
@@ -336,15 +344,14 @@ async def decrement_usage(endpoint: str, model: str) -> None:
         await publish_snapshot()
 
 def iso8601_ns():
-    ns_since_epoch = time.time_ns()
-    dt = datetime.datetime.fromtimestamp(
-        ns_since_epoch / 1_000_000_000, # seconds
-        tz=datetime.timezone.utc
+    ns = time.time_ns()
+    sec, ns_rem = divmod(ns, 1_000_000_000)
+    dt = datetime.datetime.fromtimestamp(sec, tz=datetime.timezone.utc)
+    return (
+        f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d}T"
+        f"{dt.hour:02d}:{dt.minute:02d}:{dt.second:02d}."
+        f"{ns_rem:09d}Z"
     )
-    iso8601_with_ns = (
-        dt.strftime("%Y-%m-%dT%H:%M:%S.") + f"{ns_since_epoch % 1_000_000_000:09d}Z"
-    )
-    return iso8601_with_ns
 
 def is_base64(image_string):
     try:
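Why the iso8601_ns rewrite: the old version pushed time_ns() through float division, and a 53-bit float significand cannot represent today's epoch nanosecond counts (about 61 bits), so the seconds prefix and the nanosecond suffix could disagree. divmod keeps the split exact:

import time

ns = time.time_ns()

# Old approach: the round trip through a float usually does not survive.
print(int(ns / 1_000_000_000 * 1_000_000_000) == ns)  # typically False

# New approach: integer split, exact by construction.
sec, ns_rem = divmod(ns, 1_000_000_000)
print(sec * 1_000_000_000 + ns_rem == ns)  # always True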
@@ -507,7 +514,9 @@ class rechunk:
 # ------------------------------------------------------------------
 async def publish_snapshot():
     async with usage_lock:
-        snapshot = json.dumps({"usage_counts": usage_counts}, sort_keys=True)
+        snapshot = json.dumps({"usage_counts": usage_counts,
+                               "token_usage_counts": token_usage_counts,
+                               }, sort_keys=True)
     async with _subscribers_lock:
         for q in _subscribers:
             # If the queue is full, drop the message to avoid back‑pressure.
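The broadcast snapshot now carries both counters, and sort_keys=True gives subscribers a stable key order. A shape sketch with made-up values:

import json

usage_counts = {"http://localhost:11434": {"llama3": 1}}  # hypothetical state
token_usage_counts = {"http://localhost:11434": {"llama3": 46}}

snapshot = json.dumps({"usage_counts": usage_counts,
                       "token_usage_counts": token_usage_counts,
                       }, sort_keys=True)
print(snapshot)
# {"token_usage_counts": {"http://localhost:11434": {"llama3": 46}}, "usage_counts": {"http://localhost:11434": {"llama3": 1}}}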
@@ -710,6 +719,9 @@ async def proxy(request: Request):
             async for chunk in async_gen:
                 if is_openai_endpoint:
                     chunk = rechunk.openai_completion2ollama(chunk, stream, start_ts)
+                    prompt_tok = chunk.prompt_eval_count or 0
+                    comp_tok = chunk.eval_count or 0
+                    record_token_usage(endpoint, model, prompt_tok, comp_tok)
                 if hasattr(chunk, "model_dump_json"):
                     json_line = chunk.model_dump_json()
                 else:
@@ -721,6 +733,9 @@ async def proxy(request: Request):
                     response = response.model_dump_json()
                 else:
                     response = async_gen.model_dump_json()
+                prompt_tok = async_gen.prompt_eval_count or 0
+                comp_tok = async_gen.eval_count or 0
+                record_token_usage(endpoint, model, prompt_tok, comp_tok)
                 json_line = (
                     response
                     if hasattr(async_gen, "model_dump_json")
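These two hunks (and the matching chat_proxy hunks below) read token counts from the Ollama-style fields, while the passthrough OpenAI routes near the end read the usage object instead. The mapping they rely on, as a small normalizer sketch (a hypothetical helper, not part of this commit):

def extract_token_counts(resp):
    """Return (prompt_tokens, completion_tokens) for either response shape."""
    usage = getattr(resp, "usage", None)
    if usage is not None:  # OpenAI shape
        return (usage.prompt_tokens or 0, usage.completion_tokens or 0)
    # Ollama shape, as produced by the rechunk converters
    return (getattr(resp, "prompt_eval_count", 0) or 0,
            getattr(resp, "eval_count", 0) or 0)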
@@ -791,7 +806,7 @@ async def chat_proxy(request: Request):
     optional_params = {
         "tools": tools,
         "stream": stream,
-        "stream_options": {"include_usage": True} if stream is not None else None,
+        "stream_options": {"include_usage": True} if stream else None,
         "max_tokens": options.get("num_predict") if options and "num_predict" in options else None,
         "frequency_penalty": options.get("frequency_penalty") if options and "frequency_penalty" in options else None,
         "presence_penalty": options.get("presence_penalty") if options and "presence_penalty" in options else None,
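The stream_options one-liner matters because stream=False is not None, so the old condition attached stream_options to non-streaming requests, which the OpenAI API rejects (stream_options is only valid with stream=true). The three cases:

for stream in (None, False, True):
    old = {"include_usage": True} if stream is not None else None
    new = {"include_usage": True} if stream else None
    print(stream, old, new)
# stream=None:  both omit stream_options
# stream=False: old sent it (the bug), new omits it
# stream=True:  both send it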
@@ -820,6 +835,9 @@ async def chat_proxy(request: Request):
                 if is_openai_endpoint:
                     chunk = rechunk.openai_chat_completion2ollama(chunk, stream, start_ts)
                     # `chunk` can be a dict or a pydantic model – dump to JSON safely
+                    prompt_tok = chunk.prompt_eval_count or 0
+                    comp_tok = chunk.eval_count or 0
+                    record_token_usage(endpoint, model, prompt_tok, comp_tok)
                 if hasattr(chunk, "model_dump_json"):
                     json_line = chunk.model_dump_json()
                 else:
@@ -831,6 +849,9 @@ async def chat_proxy(request: Request):
                     response = response.model_dump_json()
                 else:
                     response = async_gen.model_dump_json()
+                prompt_tok = async_gen.prompt_eval_count or 0
+                comp_tok = async_gen.eval_count or 0
+                record_token_usage(endpoint, model, prompt_tok, comp_tok)
                 json_line = (
                     response
                     if hasattr(async_gen, "model_dump_json")
@@ -1315,7 +1336,8 @@ async def usage_proxy(request: Request):
     Return a snapshot of the usage counter for each endpoint.
     Useful for debugging / monitoring.
     """
-    return {"usage_counts": usage_counts}
+    return {"usage_counts": usage_counts,
+            "token_usage_counts": token_usage_counts}
 
 # -------------------------------------------------------------
 # 20. Proxy config route – for monitoring and frontent usage
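Monitoring sketch: a client polling the usage route now gets both maps in one JSON document. Host, port, and route path below are assumptions for illustration, not taken from this diff:

import json
import urllib.request

with urllib.request.urlopen("http://localhost:8000/api/usage") as r:  # hypothetical URL
    data = json.load(r)

print(data["usage_counts"])        # active connections per endpoint/model
print(data["token_usage_counts"])  # cumulative prompt + completion tokens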
@@ -1485,6 +1507,9 @@ async def openai_chat_completions_proxy(request: Request):
                     yield f"data: {data}\n\n".encode("utf-8")
                 yield b"data: [DONE]\n\n"
             else:
+                prompt_tok = async_gen.usage.prompt_tokens or 0
+                comp_tok = async_gen.usage.completion_tokens or 0
+                record_token_usage(endpoint, payload.get("model"), prompt_tok, comp_tok)
                 json_line = (
                     async_gen.model_dump_json()
                     if hasattr(async_gen, "model_dump_json")
@@ -1588,6 +1613,9 @@ async def openai_completions_proxy(request: Request):
                 # Final DONE event
                 yield b"data: [DONE]\n\n"
             else:
+                prompt_tok = async_gen.usage.prompt_tokens or 0
+                comp_tok = async_gen.usage.completion_tokens or 0
+                record_token_usage(endpoint, payload.get("model"), prompt_tok, comp_tok)
                 json_line = (
                     async_gen.model_dump_json()
                     if hasattr(async_gen, "model_dump_json")