Adding OpenAI compatibility
New Endpoints New Requirements
This commit is contained in:
parent
cdb4485334
commit
516ec8b102
2 changed files with 203 additions and 6 deletions
|
|
@ -2,13 +2,16 @@ annotated-types==0.7.0
|
||||||
anyio==4.10.0
|
anyio==4.10.0
|
||||||
certifi==2025.8.3
|
certifi==2025.8.3
|
||||||
click==8.2.1
|
click==8.2.1
|
||||||
|
distro==1.9.0
|
||||||
exceptiongroup==1.3.0
|
exceptiongroup==1.3.0
|
||||||
fastapi==0.116.1
|
fastapi==0.116.1
|
||||||
h11==0.16.0
|
h11==0.16.0
|
||||||
httpcore==1.0.9
|
httpcore==1.0.9
|
||||||
httpx==0.28.1
|
httpx==0.28.1
|
||||||
idna==3.10
|
idna==3.10
|
||||||
|
jiter==0.10.0
|
||||||
ollama==0.5.3
|
ollama==0.5.3
|
||||||
|
openai==1.102.0
|
||||||
pydantic==2.11.7
|
pydantic==2.11.7
|
||||||
pydantic-settings==2.10.1
|
pydantic-settings==2.10.1
|
||||||
pydantic_core==2.33.2
|
pydantic_core==2.33.2
|
||||||
|
|
@ -16,6 +19,7 @@ python-dotenv==1.1.1
|
||||||
PyYAML==6.0.2
|
PyYAML==6.0.2
|
||||||
sniffio==1.3.1
|
sniffio==1.3.1
|
||||||
starlette==0.47.2
|
starlette==0.47.2
|
||||||
|
tqdm==4.67.1
|
||||||
typing-inspection==0.4.1
|
typing-inspection==0.4.1
|
||||||
typing_extensions==4.14.1
|
typing_extensions==4.14.1
|
||||||
uvicorn==0.35.0
|
uvicorn==0.35.0
|
||||||
|
|
|
||||||
203
router.py
203
router.py
|
|
@ -6,7 +6,7 @@ version: 0.1
|
||||||
license: AGPL
|
license: AGPL
|
||||||
"""
|
"""
|
||||||
# -------------------------------------------------------------
|
# -------------------------------------------------------------
|
||||||
import json, random, asyncio, yaml, httpx, ollama
|
import json, random, asyncio, yaml, httpx, ollama, openai
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Set, List
|
from typing import Dict, Set, List
|
||||||
from fastapi import FastAPI, Request, HTTPException
|
from fastapi import FastAPI, Request, HTTPException
|
||||||
|
|
@ -314,7 +314,7 @@ async def chat_proxy(request: Request):
|
||||||
)
|
)
|
||||||
|
|
||||||
# -------------------------------------------------------------
|
# -------------------------------------------------------------
|
||||||
# 8. API route – Embedding
|
# 8. API route – Embedding - deprecated
|
||||||
# -------------------------------------------------------------
|
# -------------------------------------------------------------
|
||||||
@app.post("/api/embeddings")
|
@app.post("/api/embeddings")
|
||||||
async def embedding_proxy(request: Request):
|
async def embedding_proxy(request: Request):
|
||||||
|
|
@ -742,12 +742,205 @@ async def ps_proxy(request: Request):
|
||||||
)
|
)
|
||||||
|
|
||||||
# -------------------------------------------------------------
|
# -------------------------------------------------------------
|
||||||
# 18. OpenAI API compatible endpoints #ToDo
|
# 18. API route – OpenAI compatible Embedding
|
||||||
|
# -------------------------------------------------------------
|
||||||
|
@app.post("/v1/embeddings")
async def openai_embedding_proxy(request: Request):
    """
    Proxy an OpenAI API compatible embedding request to Ollama and reply with embeddings.

    The JSON body must provide ``model`` and ``input``. A scalar ``input``
    (e.g. a single string) is wrapped in a one-element list; an ``input`` that
    is already a list is forwarded unchanged.

    Returns:
        The response object produced by ``oclient.embeddings.create``.

    Raises:
        HTTPException: 400 if the body is not valid JSON, is not a JSON
            object, or lacks ``model`` / ``input``.
    """
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        payload = json.loads(body_bytes.decode("utf-8"))
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # Valid JSON that is not an object (e.g. a bare list) has no ``.get``;
    # reject it as a client error instead of failing with a 500.
    if not isinstance(payload, dict):
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object"
        )

    model = payload.get("model")
    embedding_input = payload.get("input")

    if not model:
        raise HTTPException(
            status_code=400, detail="Missing required field 'model'"
        )
    if not embedding_input:
        raise HTTPException(
            status_code=400, detail="Missing required field 'input'"
        )

    # A batch request already carries a list; wrapping it again would send a
    # nested list upstream. Only scalar inputs are wrapped.
    if isinstance(embedding_input, list):
        batch = embedding_input
    else:
        batch = [embedding_input]

    # 2. Endpoint logic
    endpoint = await choose_endpoint(model)
    await increment_usage(endpoint, model)

    # 3. Forward the request; the usage counter is released even when the
    #    upstream call raises.
    try:
        oclient = openai.AsyncOpenAI(base_url=endpoint + "/v1", api_key="ollama")
        return await oclient.embeddings.create(input=batch, model=model)
    finally:
        await decrement_usage(endpoint, model)
|
||||||
|
|
||||||
|
# -------------------------------------------------------------
|
||||||
|
# 19. API route – OpenAI compatible Chat Completions
|
||||||
# -------------------------------------------------------------
|
# -------------------------------------------------------------
|
||||||
@app.post("/v1/chat/completions")
async def openai_chat_completions_proxy(request: Request):
    """
    Proxy an OpenAI API compatible chat completions request to Ollama.

    The JSON body must provide ``model`` and a list ``messages``. Optional
    sampling/formatting fields are forwarded only when the client supplied
    them. When ``stream`` is truthy the reply is a sequence of
    ``data: ...`` events terminated by ``data: [DONE]``; otherwise it is a
    single JSON document.

    Returns:
        StreamingResponse backed by a generator that decrements the usage
        counter once it finishes or fails.

    Raises:
        HTTPException: 400 if the body is not valid JSON, is not a JSON
            object, or ``model`` / ``messages`` are missing or malformed.
    """
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        payload = json.loads(body_bytes.decode("utf-8"))
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # Valid JSON that is not an object has no ``.get``; answer 400, not 500.
    if not isinstance(payload, dict):
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object"
        )

    model = payload.get("model")
    messages = payload.get("messages")

    if not model:
        raise HTTPException(
            status_code=400, detail="Missing required field 'model'"
        )
    if not isinstance(messages, list):
        raise HTTPException(
            status_code=400, detail="Missing required field 'messages' (must be a list)"
        )

    is_stream = bool(payload.get("stream"))

    # Forward only the optional fields the client actually sent, so absent
    # fields are omitted rather than passed as explicit None values.
    optional_fields = (
        "frequency_penalty",
        "presence_penalty",
        "response_format",
        "seed",
        "stop",
        "stream_options",
        "temperature",
        "top_p",
        "max_tokens",
        "tools",
    )
    extra_params = {
        key: payload[key]
        for key in optional_fields
        if payload.get(key) is not None
    }

    # 2. Endpoint logic
    endpoint = await choose_endpoint(model)
    await increment_usage(endpoint, model)
    oclient = openai.AsyncOpenAI(base_url=endpoint + "/v1", api_key="ollama")

    def _to_json(obj) -> str:
        # Response objects exposing model_dump_json are serialized with it;
        # anything else falls back to the json module.
        if hasattr(obj, "model_dump_json"):
            return obj.model_dump_json()
        return json.dumps(obj)

    # 3. Async generator that relays completion data and decrements the counter
    async def stream_ochat_response():
        try:
            result = await oclient.chat.completions.create(
                messages=messages,
                model=model,
                stream=is_stream,
                **extra_params,
            )
            if is_stream:
                async for chunk in result:
                    yield f"data: {_to_json(chunk)}\n\n".encode("utf-8")
                # Final DONE event
                yield b"data: [DONE]\n\n"
            else:
                yield _to_json(result).encode("utf-8") + b"\n"
        finally:
            # Ensure counter is decremented even if an exception occurs
            await decrement_usage(endpoint, model)

    # 4. Return a StreamingResponse backed by the generator. Streamed replies
    #    are "data:" event frames, so they are labelled as an event stream.
    return StreamingResponse(
        stream_ochat_response(),
        media_type="text/event-stream" if is_stream else "application/json",
    )
|
||||||
|
|
||||||
|
# -------------------------------------------------------------
|
||||||
|
# 20. API route – OpenAI compatible Completions
|
||||||
|
# -------------------------------------------------------------
|
||||||
@app.post("/v1/completions")
async def openai_completions_proxy(request: Request):
    """
    Proxy an OpenAI API compatible (text) completions request to Ollama.

    The JSON body must provide ``model`` and a non-empty ``prompt``. Optional
    sampling fields and ``suffix`` are forwarded only when the client
    supplied them. When ``stream`` is truthy the reply is a sequence of
    ``data: ...`` events terminated by ``data: [DONE]``; otherwise it is a
    single JSON document.

    Returns:
        StreamingResponse backed by a generator that decrements the usage
        counter once it finishes or fails.

    Raises:
        HTTPException: 400 if the body is not valid JSON, is not a JSON
            object, or ``model`` / ``prompt`` are missing.
    """
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        payload = json.loads(body_bytes.decode("utf-8"))
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # Valid JSON that is not an object has no ``.get``; answer 400, not 500.
    if not isinstance(payload, dict):
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object"
        )

    model = payload.get("model")
    prompt = payload.get("prompt")

    if not model:
        raise HTTPException(
            status_code=400, detail="Missing required field 'model'"
        )
    if not prompt:
        raise HTTPException(
            status_code=400, detail="Missing required field 'prompt'"
        )

    is_stream = bool(payload.get("stream"))

    # Forward only the optional fields the client actually sent, so absent
    # fields are omitted rather than passed as explicit None values.
    optional_fields = (
        "frequency_penalty",
        "presence_penalty",
        "seed",
        "stop",
        "stream_options",
        "temperature",
        "top_p",
        "max_tokens",
        "suffix",
    )
    extra_params = {
        key: payload[key]
        for key in optional_fields
        if payload.get(key) is not None
    }

    # 2. Endpoint logic
    endpoint = await choose_endpoint(model)
    await increment_usage(endpoint, model)
    oclient = openai.AsyncOpenAI(base_url=endpoint + "/v1", api_key="ollama")

    def _to_json(obj) -> str:
        # Response objects exposing model_dump_json are serialized with it;
        # anything else falls back to the json module.
        if hasattr(obj, "model_dump_json"):
            return obj.model_dump_json()
        return json.dumps(obj)

    # 3. Async generator that relays completion data and decrements the counter
    async def stream_ocompletions_response():
        try:
            result = await oclient.completions.create(
                model=model,
                prompt=prompt,
                stream=is_stream,
                **extra_params,
            )
            if is_stream:
                async for chunk in result:
                    yield f"data: {_to_json(chunk)}\n\n".encode("utf-8")
                # Final DONE event
                yield b"data: [DONE]\n\n"
            else:
                yield _to_json(result).encode("utf-8") + b"\n"
        finally:
            # Ensure counter is decremented even if an exception occurs
            await decrement_usage(endpoint, model)

    # 4. Return a StreamingResponse backed by the generator. Streamed replies
    #    are "data:" event frames, so they are labelled as an event stream.
    return StreamingResponse(
        stream_ocompletions_response(),
        media_type="text/event-stream" if is_stream else "application/json",
    )
|
||||||
|
|
||||||
|
# -------------------------------------------------------------
|
||||||
|
# 21. OpenAI API compatible endpoints #ToDo
|
||||||
|
# -------------------------------------------------------------
|
||||||
@app.post("/v1/models")
|
@app.post("/v1/models")
|
||||||
@app.post("/v1/embeddings")
|
|
||||||
async def not_implemented_yet(request: Request):
|
async def not_implemented_yet(request: Request):
|
||||||
|
|
||||||
return Response(
|
return Response(
|
||||||
|
|
@ -755,7 +948,7 @@ async def not_implemented_yet(request: Request):
|
||||||
)
|
)
|
||||||
|
|
||||||
# -------------------------------------------------------------
|
# -------------------------------------------------------------
|
||||||
# 19. FastAPI startup event – load configuration
|
# 22. FastAPI startup event – load configuration
|
||||||
# -------------------------------------------------------------
|
# -------------------------------------------------------------
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
async def startup_event() -> None:
|
async def startup_event() -> None:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue