starting an openai2ollama client translation layer with rechunking class
This commit is contained in:
parent
25b287eba6
commit
6381dd09c3
1 changed files with 50 additions and 4 deletions
54
router.py
54
router.py
|
|
@ -6,7 +6,7 @@ version: 0.3
|
|||
license: AGPL
|
||||
"""
|
||||
# -------------------------------------------------------------
|
||||
import json, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl
|
||||
import json, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, Set, List, Optional
|
||||
from fastapi import FastAPI, Request, HTTPException
|
||||
|
|
@ -257,6 +257,31 @@ async def decrement_usage(endpoint: str, model: str) -> None:
|
|||
# usage_counts.pop(endpoint, None)
|
||||
await publish_snapshot()
|
||||
|
||||
def iso8601_ns(ns_since_epoch: Optional[int] = None) -> str:
    """Return a UTC ISO-8601 timestamp with nanosecond precision.

    The result has the form ``YYYY-MM-DDTHH:MM:SS.nnnnnnnnnZ`` (nine
    fractional digits followed by a literal ``Z``).

    Args:
        ns_since_epoch: Nanoseconds since the Unix epoch. When omitted, the
            current time from ``time.time_ns()`` is used.

    Returns:
        The formatted timestamp string.
    """
    if ns_since_epoch is None:
        ns_since_epoch = time.time_ns()

    # Split with integer arithmetic. Converting the nanosecond count to float
    # seconds loses precision at present-day epoch values, so a remainder close
    # to a full second could round the whole-second part up while the
    # fractional digits (taken from the integer remainder) stay unchanged.
    seconds, nanos = divmod(ns_since_epoch, 1_000_000_000)

    dt = datetime.datetime.fromtimestamp(seconds, tz=datetime.timezone.utc)
    return f"{dt.strftime('%Y-%m-%dT%H:%M:%S')}.{nanos:09d}Z"
|
||||
|
||||
class rechunk:
    """Namespace for helpers that reshape OpenAI-style chunks into Ollama-style dicts."""

    @staticmethod
    def openai_chat_completion2ollama(chunk):
        """Translate one OpenAI chat-completion stream chunk into an Ollama-style dict.

        Args:
            chunk: An object exposing ``model`` and ``choices``; each choice is
                expected to expose ``finish_reason`` and a ``delta`` with
                ``role`` and ``content`` attributes.

        Returns:
            A plain ``dict`` with the keys ``model``, ``created_at``,
            ``done_reason``, the timing/count fields (always ``None`` here)
            and ``message``. ``created_at`` is generated locally at
            translation time via ``iso8601_ns()``.
        """
        # A chunk may carry no choices at all; fall back to None values
        # instead of raising IndexError on ``choices[0]``.
        choice = chunk.choices[0] if chunk.choices else None
        delta = choice.delta if choice is not None else None

        # NOTE(review): no "done" flag is emitted; confirm whether downstream
        # consumers of the Ollama-style stream rely on one.
        return {
            "model": chunk.model,
            "created_at": iso8601_ns(),
            "done_reason": choice.finish_reason if choice is not None else None,
            "load_duration": None,
            "prompt_eval_count": None,
            "prompt_eval_duration": None,
            "eval_count": None,
            "eval_duration": None,
            "message": {
                "role": delta.role if delta is not None else None,
                "content": delta.content if delta is not None else None,
                "thinking": None,
                "images": None,
                "tool_name": None,
                "tool_calls": None,
            },
        }
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# SSE Helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
|
@ -473,7 +498,7 @@ async def chat_proxy(request: Request):
|
|||
)
|
||||
if not isinstance(messages, list):
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing or invalid 'message' field (must be a list)"
|
||||
status_code=400, detail="Missing or invalid 'messages' field (must be a list)"
|
||||
)
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||
|
|
@ -481,20 +506,41 @@ async def chat_proxy(request: Request):
|
|||
# 2. Endpoint logic
|
||||
endpoint = await choose_endpoint(model)
|
||||
await increment_usage(endpoint, model)
|
||||
client = ollama.AsyncClient(host=endpoint)
|
||||
if "/v1" in endpoint:
|
||||
params = {
|
||||
"messages": messages,
|
||||
"model": model,
|
||||
}
|
||||
|
||||
optional_params = {
|
||||
"tools": tools,
|
||||
"stream": stream,
|
||||
}
|
||||
|
||||
params.update({k: v for k, v in optional_params.items() if v is not None})
|
||||
oclient = openai.AsyncOpenAI(base_url=endpoint, api_key=config.api_keys[endpoint])
|
||||
else:
|
||||
client = ollama.AsyncClient(host=endpoint)
|
||||
|
||||
# 3. Async generator that streams chat data and decrements the counter
|
||||
async def stream_chat_response():
|
||||
try:
|
||||
# The chat method returns a generator of dicts (or GenerateResponse)
|
||||
async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=format, options=options, keep_alive=keep_alive)
|
||||
if "/v1" in endpoint:
|
||||
async_gen = await oclient.chat.completions.create(**params)
|
||||
else:
|
||||
async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=format, options=options, keep_alive=keep_alive)
|
||||
if stream == True:
|
||||
async for chunk in async_gen:
|
||||
if "/v1" in endpoint:
|
||||
print(chunk)
|
||||
chunk = rechunk.openai_chat_completion2ollama(chunk)
|
||||
# `chunk` can be a dict or a pydantic model – dump to JSON safely
|
||||
if hasattr(chunk, "model_dump_json"):
|
||||
json_line = chunk.model_dump_json()
|
||||
else:
|
||||
json_line = json.dumps(chunk)
|
||||
print(json_line)
|
||||
yield json_line.encode("utf-8") + b"\n"
|
||||
else:
|
||||
json_line = (
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue