adding fetch class and ollama client completions on openai endpoints
This commit is contained in:
parent 0a7fd8ca52
commit 9ea852f154

1 changed file with 151 additions and 120 deletions:

router.py | 271 (+151 −120)
@@ -119,111 +119,112 @@ async def _ensure_success(resp: aiohttp.ClientResponse) -> None:
         text = await resp.text()
         raise HTTPException(status_code=resp.status, detail=text)
 
-async def fetch_available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
-    """
-    Query <endpoint>/api/tags and return a set of all model names that the
-    endpoint *advertises* (i.e. is capable of serving). This endpoint lists
-    every model that is installed on the Ollama instance, regardless of
-    whether the model is currently loaded into memory.
-
-    If the request fails (e.g. timeout, 5xx, or malformed response), an empty
-    set is returned.
-    """
-    headers = None
-    if api_key is not None:
-        headers = {"Authorization": "Bearer " + api_key}
-
-    if endpoint in _models_cache:
-        models, cached_at = _models_cache[endpoint]
-        if _is_fresh(cached_at, 300):
-            return models
-        else:
-            # stale entry – drop it
-            del _models_cache[endpoint]
-
-    if endpoint in _error_cache:
-        if _is_fresh(_error_cache[endpoint], 1):
-            # Still within the short error TTL – pretend nothing is available
-            return set()
-        else:
-            # Error expired – remove it
-            del _error_cache[endpoint]
-
-    if "/v1" in endpoint:
-        endpoint_url = f"{endpoint}/models"
-        key = "data"
-    else:
-        endpoint_url = f"{endpoint}/api/tags"
-        key = "models"
-    client: aiohttp.ClientSession = app_state["session"]
-    try:
-        async with client.get(endpoint_url, headers=headers) as resp:
-            await _ensure_success(resp)
-            data = await resp.json()
-
-            items = data.get(key, [])
-            models = {item.get("id") or item.get("name") for item in items if item.get("id") or item.get("name")}
-
-            if models:
-                _models_cache[endpoint] = (models, time.time())
-                return models
-            else:
-                # Empty list – treat as “no models”, but still cache for 300s
-                _models_cache[endpoint] = (models, time.time())
-                return models
-    except Exception as e:
-        # Treat any error as if the endpoint offers no models
-        print(f"[fetch_available_models] {endpoint} error: {e}")
-        _error_cache[endpoint] = time.time()
-        return set()
-
-
-async def fetch_loaded_models(endpoint: str) -> Set[str]:
-    """
-    Query <endpoint>/api/ps and return a set of model names that are currently
-    loaded on that endpoint. If the request fails (e.g. timeout, 5xx), an empty
-    set is returned.
-    """
-    client: aiohttp.ClientSession = app_state["session"]
-    try:
-        async with client.get(f"{endpoint}/api/ps") as resp:
-            await _ensure_success(resp)
-            data = await resp.json()
-            # The response format is:
-            # {"models": [{"name": "model1"}, {"name": "model2"}]}
-            models = {m.get("name") for m in data.get("models", []) if m.get("name")}
-            return models
-    except Exception:
-        # If anything goes wrong we simply assume the endpoint has no models
-        return set()
-
-
-async def fetch_endpoint_details(endpoint: str, route: str, detail: str, api_key: Optional[str] = None) -> List[dict]:
-    """
-    Query <endpoint>/<route> to fetch <detail> and return a List of dicts with details
-    for the corresponding Ollama endpoint. If the request fails we respond with "N/A" for detail.
-    """
-    client: aiohttp.ClientSession = app_state["session"]
-    headers = None
-    if api_key is not None:
-        headers = {"Authorization": "Bearer " + api_key}
-
-    try:
-        async with client.get(f"{endpoint}{route}", headers=headers) as resp:
-            await _ensure_success(resp)
-            data = await resp.json()
-            detail = data.get(detail, [])
-            return detail
-    except Exception as e:
-        # If anything goes wrong we cannot reply details
-        print(e)
-        return []
-
-
-def ep2base(ep):
-    if "/v1" in ep:
-        base_url = ep
-    else:
-        base_url = ep+"/v1"
-    return base_url
+class fetch:
+    async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
+        """
+        Query <endpoint>/api/tags and return a set of all model names that the
+        endpoint *advertises* (i.e. is capable of serving). This endpoint lists
+        every model that is installed on the Ollama instance, regardless of
+        whether the model is currently loaded into memory.
+
+        If the request fails (e.g. timeout, 5xx, or malformed response), an empty
+        set is returned.
+        """
+        headers = None
+        if api_key is not None:
+            headers = {"Authorization": "Bearer " + api_key}
+
+        if endpoint in _models_cache:
+            models, cached_at = _models_cache[endpoint]
+            if _is_fresh(cached_at, 300):
+                return models
+            else:
+                # stale entry – drop it
+                del _models_cache[endpoint]
+
+        if endpoint in _error_cache:
+            if _is_fresh(_error_cache[endpoint], 1):
+                # Still within the short error TTL – pretend nothing is available
+                return set()
+            else:
+                # Error expired – remove it
+                del _error_cache[endpoint]
+
+        if "/v1" in endpoint:
+            endpoint_url = f"{endpoint}/models"
+            key = "data"
+        else:
+            endpoint_url = f"{endpoint}/api/tags"
+            key = "models"
+        client: aiohttp.ClientSession = app_state["session"]
+        try:
+            async with client.get(endpoint_url, headers=headers) as resp:
+                await _ensure_success(resp)
+                data = await resp.json()
+
+                items = data.get(key, [])
+                models = {item.get("id") or item.get("name") for item in items if item.get("id") or item.get("name")}
+
+                if models:
+                    _models_cache[endpoint] = (models, time.time())
+                    return models
+                else:
+                    # Empty list – treat as “no models”, but still cache for 300s
+                    _models_cache[endpoint] = (models, time.time())
+                    return models
+        except Exception as e:
+            # Treat any error as if the endpoint offers no models
+            print(f"[fetch.available_models] {endpoint} error: {e}")
+            _error_cache[endpoint] = time.time()
+            return set()
+
+    async def loaded_models(endpoint: str) -> Set[str]:
+        """
+        Query <endpoint>/api/ps and return a set of model names that are currently
+        loaded on that endpoint. If the request fails (e.g. timeout, 5xx), an empty
+        set is returned.
+        """
+        client: aiohttp.ClientSession = app_state["session"]
+        try:
+            async with client.get(f"{endpoint}/api/ps") as resp:
+                await _ensure_success(resp)
+                data = await resp.json()
+                # The response format is:
+                # {"models": [{"name": "model1"}, {"name": "model2"}]}
+                models = {m.get("name") for m in data.get("models", []) if m.get("name")}
+                return models
+        except Exception:
+            # If anything goes wrong we simply assume the endpoint has no models
+            return set()
+
+    async def endpoint_details(endpoint: str, route: str, detail: str, api_key: Optional[str] = None) -> List[dict]:
+        """
+        Query <endpoint>/<route> to fetch <detail> and return a List of dicts with details
+        for the corresponding Ollama endpoint. If the request fails we respond with "N/A" for detail.
+        """
+        client: aiohttp.ClientSession = app_state["session"]
+        headers = None
+        if api_key is not None:
+            headers = {"Authorization": "Bearer " + api_key}
+
+        try:
+            async with client.get(f"{endpoint}{route}", headers=headers) as resp:
+                await _ensure_success(resp)
+                data = await resp.json()
+                detail = data.get(detail, [])
+                return detail
+        except Exception as e:
+            # If anything goes wrong we cannot reply details
+            print(e)
+            return []
+
+
+def ep2base(ep):
+    if "/v1" in ep:
+        base_url = ep
+    else:
+        base_url = ep+"/v1"
+    return base_url
 
 def dedupe_on_keys(dicts, key_fields):
     """
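Note: the helpers keep their bodies; this commit only moves them into the `fetch` class as a namespace (`fetch_available_models` becomes `fetch.available_models`, and so on). The methods take no `self`, which works in Python 3 because a function looked up on the class itself is a plain function. A minimal, self-contained sketch of the pattern (toy names, not the router code):

import asyncio

class fetch_demo:
    async def available_models(endpoint: str) -> set:
        # stand-in for the real HTTP call to <endpoint>/api/tags
        return {"llama3", "mistral"} if endpoint else set()

async def main():
    # called on the class itself, no instance needed
    print(await fetch_demo.available_models("http://localhost:11434"))

asyncio.run(main())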
@@ -288,6 +289,19 @@ class rechunk:
         else:
             rechunk["message"] = {"role": chunk.choices[0].message.role, "content": chunk.choices[0].message.content, "thinking": None, "images": None, "tool_name": None, "tool_calls": None}
         return rechunk
 
+    def openai_completion2ollama(chunk: dict, stream: bool, start_ts: float):
+        rechunk = { "model": chunk.model,
+                    "created_at": iso8601_ns(),
+                    "load_duration": None,
+                    "done_reason": chunk.choices[0].finish_reason,
+                    "total_duration": None,
+                    "eval_duration": (int((time.perf_counter() - start_ts) * 1000) if chunk.usage is not None else None),
+                    "thinking": None,
+                    "context": None,
+                    "response": chunk.choices[0].text
+                  }
+        return rechunk
+
 # ------------------------------------------------------------------
 # SSE Helpers
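Note: `rechunk.openai_completion2ollama` maps an OpenAI `/v1/completions` response object onto the JSON shape that Ollama's `/api/generate` emits. A rough, self-contained sketch of the mapping, with `SimpleNamespace` standing in for the OpenAI response type and illustrative values:

import json, time
from types import SimpleNamespace

start_ts = time.perf_counter()
chunk = SimpleNamespace(
    model="llama3",
    usage=None,  # streamed deltas usually carry no usage block
    choices=[SimpleNamespace(finish_reason=None, text="Hello")],
)

# what openai_completion2ollama(chunk, stream=True, start_ts) would build:
rechunked = {
    "model": chunk.model,
    "created_at": "1970-01-01T00:00:00.000000000Z",  # iso8601_ns() stand-in
    "load_duration": None,
    "done_reason": chunk.choices[0].finish_reason,
    "total_duration": None,
    # elapsed time since start_ts, reported only when a usage block is present
    "eval_duration": (int((time.perf_counter() - start_ts) * 1000)
                      if chunk.usage is not None else None),
    "thinking": None,
    "context": None,
    "response": chunk.choices[0].text,
}
print(json.dumps(rechunked))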
@@ -350,8 +364,8 @@ async def choose_endpoint(model: str) -> str:
     6️⃣ If no endpoint advertises the model at all, raise an error.
     """
     # 1️⃣ Gather advertised‑model sets for all endpoints concurrently
-    tag_tasks = [fetch_available_models(ep) for ep in config.endpoints if "/v1" not in ep]
-    tag_tasks += [fetch_available_models(ep, config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
+    tag_tasks = [fetch.available_models(ep) for ep in config.endpoints if "/v1" not in ep]
+    tag_tasks += [fetch.available_models(ep, config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
     advertised_sets = await asyncio.gather(*tag_tasks)
 
     # 2️⃣ Filter endpoints that advertise the requested model
@@ -369,7 +383,7 @@ async def choose_endpoint(model: str) -> str:
 
     # 3️⃣ Among the candidates, find those that have the model *loaded*
     #    (concurrently, but only for the filtered list)
-    load_tasks = [fetch_loaded_models(ep) for ep in candidate_endpoints]
+    load_tasks = [fetch.loaded_models(ep) for ep in candidate_endpoints]
     loaded_sets = await asyncio.gather(*load_tasks)
 
     async with usage_lock:
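Note: the filter in step 2️⃣ sits outside these hunks. It presumably pairs each endpoint with its advertised set in the order the tasks were built (non-/v1 endpoints first, then /v1 endpoints). A hypothetical, self-contained sketch of that step, with made-up hosts:

endpoints = ["http://ollama-a:11434", "https://api.example.com/v1"]
advertised_sets = [{"llama3", "mistral"}, {"gpt-4o-mini"}]  # as asyncio.gather returns them
model = "llama3"

candidate_endpoints = [
    ep for ep, advertised in zip(endpoints, advertised_sets) if model in advertised
]
print(candidate_endpoints)  # ['http://ollama-a:11434']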
@@ -409,7 +423,6 @@ async def proxy(request: Request):
     """
     Proxy a generate request to Ollama and stream the response back to the client.
     """
     # 1. Parse and validate request
     try:
         body_bytes = await request.body()
         payload = json.loads(body_bytes.decode("utf-8"))
@@ -439,29 +452,50 @@ async def proxy(request: Request):
     except json.JSONDecodeError as e:
         raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
 
     # 2. Decide which endpoint to use
     endpoint = await choose_endpoint(model)
 
     # Increment usage counter for this endpoint‑model pair
     await increment_usage(endpoint, model)
+    is_openai_endpoint = "/v1" in endpoint
+    if is_openai_endpoint:
+        params = {
+            "prompt": prompt,
+            "model": model,
+        }
 
-    # 3. Create Ollama client instance
-    client = ollama.AsyncClient(host=endpoint)
+        optional_params = {
+            "stream": stream,
+        }
+        params.update({k: v for k, v in optional_params.items() if v is not None})
+        oclient = openai.AsyncOpenAI(base_url=endpoint, api_key=config.api_keys[endpoint])
+    else:
+        client = ollama.AsyncClient(host=endpoint)
 
     # 4. Async generator that streams data and decrements the counter
     async def stream_generate_response():
         try:
-            async_gen = await client.generate(model=model, prompt=prompt, suffix=suffix, system=system, template=template, context=context, stream=stream, think=think, raw=raw, format=_format, images=images, options=options, keep_alive=keep_alive)
+            if is_openai_endpoint:
+                start_ts = time.perf_counter()
+                async_gen = await oclient.completions.create(**params)
+            else:
+                async_gen = await client.generate(model=model, prompt=prompt, suffix=suffix, system=system, template=template, context=context, stream=stream, think=think, raw=raw, format=_format, images=images, options=options, keep_alive=keep_alive)
             if stream == True:
                 async for chunk in async_gen:
+                    if is_openai_endpoint:
+                        chunk = rechunk.openai_completion2ollama(chunk, stream, start_ts)
                     if hasattr(chunk, "model_dump_json"):
                         json_line = chunk.model_dump_json()
                     else:
                         json_line = json.dumps(chunk)
                     yield json_line.encode("utf-8") + b"\n"
             else:
+                if is_openai_endpoint:
+                    response = rechunk.openai_completion2ollama(async_gen, stream, start_ts)
+                    response = json.dumps(response)
+                else:
+                    response = async_gen.model_dump_json()
                 json_line = (
-                    async_gen.model_dump_json()
+                    response
                     if hasattr(async_gen, "model_dump_json")
                     else json.dumps(async_gen)
                 )
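Note: both branches converge on the same newline-delimited JSON framing (`json_line.encode("utf-8") + b"\n"`), so callers see Ollama-style output regardless of which backend served the request. A sketch of how a client might consume the stream (the URL, port, and route are assumptions; the hunk shows only the generator, not the mount point):

import asyncio, json
import aiohttp

async def consume():
    async with aiohttp.ClientSession() as session:
        payload = {"model": "llama3", "prompt": "Hi", "stream": True}
        async with session.post("http://localhost:8000/api/generate", json=payload) as resp:
            async for raw_line in resp.content:  # one JSON object per line
                if raw_line.strip():
                    print(json.loads(raw_line))

asyncio.run(consume())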
@@ -513,7 +547,8 @@ async def chat_proxy(request: Request):
     # 2. Endpoint logic
     endpoint = await choose_endpoint(model)
     await increment_usage(endpoint, model)
-    if "/v1" in endpoint:
+    is_openai_endpoint = "/v1" in endpoint
+    if is_openai_endpoint:
         params = {
             "messages": messages,
             "model": model,
@@ -530,7 +565,6 @@ async def chat_proxy(request: Request):
         client = ollama.AsyncClient(host=endpoint)
 
     # 3. Async generator that streams chat data and decrements the counter
-    is_openai_endpoint = "/v1" in endpoint
     async def stream_chat_response():
         try:
             # The chat method returns a generator of dicts (or GenerateResponse)
@@ -560,7 +594,6 @@ async def chat_proxy(request: Request):
                     if hasattr(async_gen, "model_dump_json")
                     else json.dumps(async_gen)
                 )
-                print(json_line)
                 yield json_line.encode("utf-8") + b"\n"
 
         finally:
@@ -941,7 +974,7 @@ async def version_proxy(request: Request):
 
     """
     # 1. Query all endpoints for version
-    tasks = [fetch_endpoint_details(ep, "/api/version", "version") for ep in config.endpoints if "/v1" not in ep]
+    tasks = [fetch.endpoint_details(ep, "/api/version", "version") for ep in config.endpoints if "/v1" not in ep]
     all_versions = await asyncio.gather(*tasks)
 
     def version_key(v):
@@ -964,8 +997,8 @@ async def tags_proxy(request: Request):
     """
 
     # 1. Query all endpoints for models
-    tasks = [fetch_endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep]
-    tasks += [fetch_endpoint_details(ep, "/models", "data", config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
+    tasks = [fetch.endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep]
+    tasks += [fetch.endpoint_details(ep, "/models", "data", config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
     all_models = await asyncio.gather(*tasks)
 
     models = {'models': []}
@@ -988,7 +1021,7 @@ async def ps_proxy(request: Request):
 
     """
     # 1. Query all endpoints for running models
-    tasks = [fetch_endpoint_details(ep, "/api/ps", "models") for ep in config.endpoints if "/v1" not in ep]
+    tasks = [fetch.endpoint_details(ep, "/api/ps", "models") for ep in config.endpoints if "/v1" not in ep]
     loaded_models = await asyncio.gather(*tasks)
 
     models = {'models': []}
@@ -1300,8 +1333,8 @@ async def openai_models_proxy(request: Request):
 
     """
     # 1. Query all endpoints for models
-    tasks = [fetch_endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep]
-    tasks += [fetch_endpoint_details(ep, "/models", "data", config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
+    tasks = [fetch.endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep]
+    tasks += [fetch.endpoint_details(ep, "/models", "data", config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
     all_models = await asyncio.gather(*tasks)
 
     models = {'data': []}
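Note: after the gather, this route presumably flattens the per-endpoint lists into one OpenAI-style payload. Illustrative shape only (field values made up):

# what the merged /v1/models payload looks like (illustrative values)
models = {
    "data": [
        {"id": "llama3:8b", "object": "model", "created": 0, "owned_by": "library"},
        {"id": "gpt-4o-mini", "object": "model", "created": 0, "owned_by": "system"},
    ]
}
print(len(models["data"]))  # 2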
@@ -1351,9 +1384,7 @@ async def health_proxy(request: Request):
     * The HTTP status code is 200 when everything is healthy, 503 otherwise.
     """
     # Run all health checks in parallel
-    tasks = [
-        fetch_endpoint_details(ep, "/api/version", "version") for ep in config.endpoints
-    ]
+    tasks = [fetch.endpoint_details(ep, "/api/version", "version") for ep in config.endpoints]
 
     results = await asyncio.gather(*tasks, return_exceptions=True)
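Note: `return_exceptions=True` makes `asyncio.gather` hand back exception objects in place of results instead of raising, so one unreachable endpoint cannot abort the whole health check. A minimal demonstration:

import asyncio

async def ok():
    return "0.5.1"

async def boom():
    raise RuntimeError("endpoint down")

async def main():
    results = await asyncio.gather(ok(), boom(), return_exceptions=True)
    healthy = [r for r in results if not isinstance(r, Exception)]
    print(results)   # ['0.5.1', RuntimeError('endpoint down')]
    print(healthy)   # ['0.5.1']

asyncio.run(main())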