Merge pull request #12 from nomyo-ai/dev-v0.4.x

token usage counter for non-stream openai ollama endpoints and improvements
This commit is contained in:
Alpha Nerd 2025-11-08 11:54:33 +01:00 committed by GitHub
commit c6c1059ede
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 84 additions and 17 deletions

View file

@ -41,7 +41,7 @@ source .venv/router/bin/activate
pip3 install -r requirements.txt pip3 install -r requirements.txt
``` ```
on the shell do: [optional] on the shell do:
``` ```
export OPENAI_KEY=YOUR_SECRET_API_KEY export OPENAI_KEY=YOUR_SECRET_API_KEY

View file

@ -123,6 +123,7 @@ default_headers={
# 3. Global state: perendpoint permodel active connection counters # 3. Global state: perendpoint permodel active connection counters
# ------------------------------------------------------------- # -------------------------------------------------------------
usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int)) usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
token_usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
usage_lock = asyncio.Lock() # protects access to usage_counts usage_lock = asyncio.Lock() # protects access to usage_counts
# ------------------------------------------------------------- # -------------------------------------------------------------
@ -191,6 +192,13 @@ def is_ext_openai_endpoint(endpoint: str) -> bool:
return True # It's an external OpenAI endpoint return True # It's an external OpenAI endpoint
def record_token_usage(endpoint: str, model: str, prompt: int = 0, completion: int = 0) -> None:
async def _record():
async with usage_lock: # reuse the same lock that protects usage_counts
token_usage_counts[endpoint][model] += (prompt + completion)
await publish_snapshot() # immediately broadcast the new totals
asyncio.create_task(_record())
class fetch: class fetch:
async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]: async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
""" """
@ -336,15 +344,14 @@ async def decrement_usage(endpoint: str, model: str) -> None:
await publish_snapshot() await publish_snapshot()
def iso8601_ns(): def iso8601_ns():
ns_since_epoch = time.time_ns() ns = time.time_ns()
dt = datetime.datetime.fromtimestamp( sec, ns_rem = divmod(ns, 1_000_000_000)
ns_since_epoch / 1_000_000_000, # seconds dt = datetime.datetime.fromtimestamp(sec, tz=datetime.timezone.utc)
tz=datetime.timezone.utc return (
f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d}T"
f"{dt.hour:02d}:{dt.minute:02d}:{dt.second:02d}."
f"{ns_rem:09d}Z"
) )
iso8601_with_ns = (
dt.strftime("%Y-%m-%dT%H:%M:%S.") + f"{ns_since_epoch % 1_000_000_000:09d}Z"
)
return iso8601_with_ns
def is_base64(image_string): def is_base64(image_string):
try: try:
@ -507,7 +514,9 @@ class rechunk:
# ------------------------------------------------------------------ # ------------------------------------------------------------------
async def publish_snapshot(): async def publish_snapshot():
async with usage_lock: async with usage_lock:
snapshot = json.dumps({"usage_counts": usage_counts}, sort_keys=True) snapshot = json.dumps({"usage_counts": usage_counts,
"token_usage_counts": token_usage_counts,
}, sort_keys=True)
async with _subscribers_lock: async with _subscribers_lock:
for q in _subscribers: for q in _subscribers:
# If the queue is full, drop the message to avoid backpressure. # If the queue is full, drop the message to avoid backpressure.
@ -710,6 +719,9 @@ async def proxy(request: Request):
async for chunk in async_gen: async for chunk in async_gen:
if is_openai_endpoint: if is_openai_endpoint:
chunk = rechunk.openai_completion2ollama(chunk, stream, start_ts) chunk = rechunk.openai_completion2ollama(chunk, stream, start_ts)
prompt_tok = chunk.prompt_eval_count or 0
comp_tok = chunk.eval_count or 0
record_token_usage(endpoint, model, prompt_tok, comp_tok)
if hasattr(chunk, "model_dump_json"): if hasattr(chunk, "model_dump_json"):
json_line = chunk.model_dump_json() json_line = chunk.model_dump_json()
else: else:
@ -721,6 +733,9 @@ async def proxy(request: Request):
response = response.model_dump_json() response = response.model_dump_json()
else: else:
response = async_gen.model_dump_json() response = async_gen.model_dump_json()
prompt_tok = async_gen.prompt_eval_count or 0
comp_tok = async_gen.eval_count or 0
record_token_usage(endpoint, model, prompt_tok, comp_tok)
json_line = ( json_line = (
response response
if hasattr(async_gen, "model_dump_json") if hasattr(async_gen, "model_dump_json")
@ -791,7 +806,7 @@ async def chat_proxy(request: Request):
optional_params = { optional_params = {
"tools": tools, "tools": tools,
"stream": stream, "stream": stream,
"stream_options": {"include_usage": True} if stream is not None else None, "stream_options": {"include_usage": True} if stream else None,
"max_tokens": options.get("num_predict") if options and "num_predict" in options else None, "max_tokens": options.get("num_predict") if options and "num_predict" in options else None,
"frequency_penalty": options.get("frequency_penalty") if options and "frequency_penalty" in options else None, "frequency_penalty": options.get("frequency_penalty") if options and "frequency_penalty" in options else None,
"presence_penalty": options.get("presence_penalty") if options and "presence_penalty" in options else None, "presence_penalty": options.get("presence_penalty") if options and "presence_penalty" in options else None,
@ -820,6 +835,9 @@ async def chat_proxy(request: Request):
if is_openai_endpoint: if is_openai_endpoint:
chunk = rechunk.openai_chat_completion2ollama(chunk, stream, start_ts) chunk = rechunk.openai_chat_completion2ollama(chunk, stream, start_ts)
# `chunk` can be a dict or a pydantic model dump to JSON safely # `chunk` can be a dict or a pydantic model dump to JSON safely
prompt_tok = chunk.prompt_eval_count or 0
comp_tok = chunk.eval_count or 0
record_token_usage(endpoint, model, prompt_tok, comp_tok)
if hasattr(chunk, "model_dump_json"): if hasattr(chunk, "model_dump_json"):
json_line = chunk.model_dump_json() json_line = chunk.model_dump_json()
else: else:
@ -831,6 +849,9 @@ async def chat_proxy(request: Request):
response = response.model_dump_json() response = response.model_dump_json()
else: else:
response = async_gen.model_dump_json() response = async_gen.model_dump_json()
prompt_tok = async_gen.prompt_eval_count or 0
comp_tok = async_gen.eval_count or 0
record_token_usage(endpoint, model, prompt_tok, comp_tok)
json_line = ( json_line = (
response response
if hasattr(async_gen, "model_dump_json") if hasattr(async_gen, "model_dump_json")
@ -1315,7 +1336,8 @@ async def usage_proxy(request: Request):
Return a snapshot of the usage counter for each endpoint. Return a snapshot of the usage counter for each endpoint.
Useful for debugging / monitoring. Useful for debugging / monitoring.
""" """
return {"usage_counts": usage_counts} return {"usage_counts": usage_counts,
"token_usage_counts": token_usage_counts}
# ------------------------------------------------------------- # -------------------------------------------------------------
# 20. Proxy config route for monitoring and frontent usage # 20. Proxy config route for monitoring and frontent usage
@ -1485,6 +1507,9 @@ async def openai_chat_completions_proxy(request: Request):
yield f"data: {data}\n\n".encode("utf-8") yield f"data: {data}\n\n".encode("utf-8")
yield b"data: [DONE]\n\n" yield b"data: [DONE]\n\n"
else: else:
prompt_tok = async_gen.usage.prompt_tokens or 0
comp_tok = async_gen.usage.completion_tokens or 0
record_token_usage(endpoint, payload.get("model"), prompt_tok, comp_tok)
json_line = ( json_line = (
async_gen.model_dump_json() async_gen.model_dump_json()
if hasattr(async_gen, "model_dump_json") if hasattr(async_gen, "model_dump_json")
@ -1588,6 +1613,9 @@ async def openai_completions_proxy(request: Request):
# Final DONE event # Final DONE event
yield b"data: [DONE]\n\n" yield b"data: [DONE]\n\n"
else: else:
prompt_tok = async_gen.usage.prompt_tokens or 0
comp_tok = async_gen.usage.completion_tokens or 0
record_token_usage(endpoint, payload.get("model"), prompt_tok, comp_tok)
json_line = ( json_line = (
async_gen.model_dump_json() async_gen.model_dump_json()
if hasattr(async_gen, "model_dump_json") if hasattr(async_gen, "model_dump_json")

View file

@ -267,11 +267,12 @@
<th>Quant</th> <th>Quant</th>
<th>Ctx</th> <th>Ctx</th>
<th>Digest</th> <th>Digest</th>
<th>Token</th>
</tr> </tr>
</thead> </thead>
<tbody id="ps-body"> <tbody id="ps-body">
<tr> <tr>
<td colspan="5" class="loading">Loading…</td> <td colspan="6" class="loading">Loading…</td>
</tr> </tr>
</tbody> </tbody>
</table> </table>
@ -299,6 +300,7 @@
</div> </div>
<script> <script>
let psRows = new Map();
/* ---------- Utility ---------- */ /* ---------- Utility ---------- */
async function fetchJSON(url) { async function fetchJSON(url) {
const resp = await fetch(url); const resp = await fetch(url);
@ -435,11 +437,28 @@
const data = await fetchJSON("/api/ps"); const data = await fetchJSON("/api/ps");
const body = document.getElementById("ps-body"); const body = document.getElementById("ps-body");
body.innerHTML = data.models body.innerHTML = data.models
.map( .map(m => {
(m) => const existingRow = psRows.get(m.name);
`<tr><td class="model">${m.name}</td><td>${m.details.parameter_size}</td><td>${m.details.quantization_level}</td><td>${m.context_length}</td><td>${m.digest}</td></tr>`, const tokenValue = existingRow
) ? existingRow.querySelector(".token-usage")?.textContent ?? 0
: 0;
return `<tr data-model="${m.name}">
<td class="model">${m.name}</td>
<td>${m.details.parameter_size}</td>
<td>${m.details.quantization_level}</td>
<td>${m.context_length}</td>
<td>${m.digest}</td>
<td class="token-usage">${tokenValue}</td>
</tr>`;
})
.join(""); .join("");
psRows.clear();
document
.querySelectorAll("#ps-body tr[data-model]")
.forEach((row) => {
const model = row.dataset.model;
if (model) psRows.set(model, row);
});
} catch (e) { } catch (e) {
console.error(e); console.error(e);
} }
@ -502,6 +521,26 @@
try { try {
const payload = JSON.parse(e.data); // SSE sends plain text const payload = JSON.parse(e.data); // SSE sends plain text
renderChart(payload); renderChart(payload);
const usage = payload.usage_counts || {};
const tokens = payload.token_usage_counts || {};
psRows.forEach((row, model) => {
/* regular usage count optional if you want to keep it */
let total = 0;
for (const ep in usage) {
total += usage[ep][model] || 0;
}
const usageCell = row.querySelector(".usage");
if (usageCell) usageCell.textContent = total;
/* token usage */
let tokenTotal = 0;
for (const ep in tokens) {
tokenTotal += tokens[ep][model] || 0;
}
const tokenCell = row.querySelector(".token-usage");
if (tokenCell) tokenCell.textContent = tokenTotal;
});
} catch (err) { } catch (err) {
console.error("Failed to parse SSE payload", err); console.error("Failed to parse SSE payload", err);
} }