record and display total token usage on ollama endpoints using ollama client
This commit is contained in:
parent
9007f686c2
commit
4c9ec5b1b2
2 changed files with 70 additions and 8 deletions
29
router.py
29
router.py
|
|
@ -110,6 +110,7 @@ default_headers={
|
|||
# 3. Global state: per‑endpoint per‑model active connection counters
|
||||
# -------------------------------------------------------------
|
||||
usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
|
||||
token_usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
|
||||
usage_lock = asyncio.Lock() # protects access to usage_counts and token_usage_counts
|
||||
|
||||
# -------------------------------------------------------------
|
||||
|
|
@ -137,6 +138,13 @@ def is_ext_openai_endpoint(endpoint: str) -> bool:
|
|||
|
||||
return True # It's an external OpenAI endpoint
|
||||
|
||||
# Strong references to in-flight recorder tasks. asyncio holds only weak
# references to tasks, so a bare create_task() result may be garbage-collected
# before it runs; keeping them here guarantees each update completes.
_token_usage_tasks: set = set()


def record_token_usage(endpoint: str, model: str, prompt: int = 0, completion: int = 0) -> None:
    """Add ``prompt + completion`` tokens to the per-endpoint/per-model total.

    Fire-and-forget: the update is scheduled as a background task on the
    running event loop so streaming callers are never blocked on the lock.

    Args:
        endpoint: Upstream endpoint URL the tokens were consumed on.
        model: Model name the tokens were consumed by.
        prompt: Prompt ("input") token count reported by the backend.
        completion: Completion ("output") token count reported by the backend.
    """

    async def _record() -> None:
        # Reuse the lock that protects usage_counts so snapshots stay consistent.
        async with usage_lock:
            token_usage_counts[endpoint][model] += prompt + completion
        # Must run AFTER the lock is released: publish_snapshot() re-acquires
        # usage_lock, and asyncio.Lock is not reentrant.
        await publish_snapshot()  # immediately broadcast the new totals

    task = asyncio.create_task(_record())
    _token_usage_tasks.add(task)
    task.add_done_callback(_token_usage_tasks.discard)
|
||||
|
||||
class fetch:
|
||||
async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
|
||||
"""
|
||||
|
|
@ -447,7 +455,9 @@ class rechunk:
|
|||
# ------------------------------------------------------------------
|
||||
async def publish_snapshot():
|
||||
async with usage_lock:
|
||||
snapshot = json.dumps({"usage_counts": usage_counts}, sort_keys=True)
|
||||
snapshot = json.dumps({"usage_counts": usage_counts,
|
||||
"token_usage_counts": token_usage_counts,
|
||||
}, sort_keys=True)
|
||||
async with _subscribers_lock:
|
||||
for q in _subscribers:
|
||||
# If the queue is full, drop the message to avoid back‑pressure.
|
||||
|
|
@ -650,6 +660,9 @@ async def proxy(request: Request):
|
|||
async for chunk in async_gen:
|
||||
if is_openai_endpoint:
|
||||
chunk = rechunk.openai_completion2ollama(chunk, stream, start_ts)
|
||||
prompt_tok = chunk.prompt_eval_count or 0
|
||||
comp_tok = chunk.eval_count or 0
|
||||
record_token_usage(endpoint, model, prompt_tok, comp_tok)
|
||||
if hasattr(chunk, "model_dump_json"):
|
||||
json_line = chunk.model_dump_json()
|
||||
else:
|
||||
|
|
@ -661,6 +674,9 @@ async def proxy(request: Request):
|
|||
response = response.model_dump_json()
|
||||
else:
|
||||
response = async_gen.model_dump_json()
|
||||
prompt_tok = async_gen.prompt_eval_count or 0
|
||||
comp_tok = async_gen.eval_count or 0
|
||||
record_token_usage(endpoint, model, prompt_tok, comp_tok)
|
||||
json_line = (
|
||||
response
|
||||
if hasattr(async_gen, "model_dump_json")
|
||||
|
|
@ -731,7 +747,7 @@ async def chat_proxy(request: Request):
|
|||
optional_params = {
|
||||
"tools": tools,
|
||||
"stream": stream,
|
||||
"stream_options": {"include_usage": True} if stream is not None else None,
|
||||
"stream_options": {"include_usage": True} if stream else None,
|
||||
"max_tokens": options.get("num_predict") if options and "num_predict" in options else None,
|
||||
"frequency_penalty": options.get("frequency_penalty") if options and "frequency_penalty" in options else None,
|
||||
"presence_penalty": options.get("presence_penalty") if options and "presence_penalty" in options else None,
|
||||
|
|
@ -760,6 +776,9 @@ async def chat_proxy(request: Request):
|
|||
if is_openai_endpoint:
|
||||
chunk = rechunk.openai_chat_completion2ollama(chunk, stream, start_ts)
|
||||
# `chunk` can be a dict or a pydantic model – dump to JSON safely
|
||||
prompt_tok = chunk.prompt_eval_count or 0
|
||||
comp_tok = chunk.eval_count or 0
|
||||
record_token_usage(endpoint, model, prompt_tok, comp_tok)
|
||||
if hasattr(chunk, "model_dump_json"):
|
||||
json_line = chunk.model_dump_json()
|
||||
else:
|
||||
|
|
@ -771,6 +790,9 @@ async def chat_proxy(request: Request):
|
|||
response = response.model_dump_json()
|
||||
else:
|
||||
response = async_gen.model_dump_json()
|
||||
prompt_tok = async_gen.prompt_eval_count or 0
|
||||
comp_tok = async_gen.eval_count or 0
|
||||
record_token_usage(endpoint, model, prompt_tok, comp_tok)
|
||||
json_line = (
|
||||
response
|
||||
if hasattr(async_gen, "model_dump_json")
|
||||
|
|
@ -1255,7 +1277,8 @@ async def usage_proxy(request: Request):
|
|||
Return a snapshot of the usage counter for each endpoint.
|
||||
Useful for debugging / monitoring.
|
||||
"""
|
||||
return {"usage_counts": usage_counts}
|
||||
return {"usage_counts": usage_counts,
|
||||
"token_usage_counts": token_usage_counts}
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# 20. Proxy config route – for monitoring and frontend usage
|
||||
|
|
|
|||
|
|
@ -267,11 +267,12 @@
|
|||
<th>Quant</th>
|
||||
<th>Ctx</th>
|
||||
<th>Digest</th>
|
||||
<th>Token</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody id="ps-body">
|
||||
<tr>
|
||||
<td colspan="5" class="loading">Loading…</td>
|
||||
<td colspan="6" class="loading">Loading…</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
|
@ -299,6 +300,7 @@
|
|||
</div>
|
||||
|
||||
<script>
|
||||
let psRows = new Map();
|
||||
/* ---------- Utility ---------- */
|
||||
async function fetchJSON(url) {
|
||||
const resp = await fetch(url);
|
||||
|
|
@ -435,11 +437,28 @@
|
|||
const data = await fetchJSON("/api/ps");
|
||||
const body = document.getElementById("ps-body");
|
||||
body.innerHTML = data.models
|
||||
.map(
|
||||
(m) =>
|
||||
`<tr><td class="model">${m.name}</td><td>${m.details.parameter_size}</td><td>${m.details.quantization_level}</td><td>${m.context_length}</td><td>${m.digest}</td></tr>`,
|
||||
)
|
||||
.map(m => {
|
||||
const existingRow = psRows.get(m.name);
|
||||
const tokenValue = existingRow
|
||||
? existingRow.querySelector(".token-usage")?.textContent ?? 0
|
||||
: 0;
|
||||
return `<tr data-model="${m.name}">
|
||||
<td class="model">${m.name}</td>
|
||||
<td>${m.details.parameter_size}</td>
|
||||
<td>${m.details.quantization_level}</td>
|
||||
<td>${m.context_length}</td>
|
||||
<td>${m.digest}</td>
|
||||
<td class="token-usage">${tokenValue}</td>
|
||||
</tr>`;
|
||||
})
|
||||
.join("");
|
||||
psRows.clear();
|
||||
document
|
||||
.querySelectorAll("#ps-body tr[data-model]")
|
||||
.forEach((row) => {
|
||||
const model = row.dataset.model;
|
||||
if (model) psRows.set(model, row);
|
||||
});
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
}
|
||||
|
|
@ -502,6 +521,26 @@
|
|||
try {
|
||||
const payload = JSON.parse(e.data); // SSE sends plain text
|
||||
renderChart(payload);
|
||||
const usage = payload.usage_counts || {};
|
||||
const tokens = payload.token_usage_counts || {};
|
||||
|
||||
psRows.forEach((row, model) => {
|
||||
/* regular usage count – optional if you want to keep it */
|
||||
let total = 0;
|
||||
for (const ep in usage) {
|
||||
total += usage[ep][model] || 0;
|
||||
}
|
||||
const usageCell = row.querySelector(".usage");
|
||||
if (usageCell) usageCell.textContent = total;
|
||||
|
||||
/* token usage */
|
||||
let tokenTotal = 0;
|
||||
for (const ep in tokens) {
|
||||
tokenTotal += tokens[ep][model] || 0;
|
||||
}
|
||||
const tokenCell = row.querySelector(".token-usage");
|
||||
if (tokenCell) tokenCell.textContent = tokenTotal;
|
||||
});
|
||||
} catch (err) {
|
||||
console.error("Failed to parse SSE payload", err);
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue