Adding OpenAI compatibility
New Endpoints New Requirements
This commit is contained in:
parent
cdb4485334
commit
516ec8b102
2 changed files with 203 additions and 6 deletions
|
|
@ -2,13 +2,16 @@ annotated-types==0.7.0
|
|||
anyio==4.10.0
|
||||
certifi==2025.8.3
|
||||
click==8.2.1
|
||||
distro==1.9.0
|
||||
exceptiongroup==1.3.0
|
||||
fastapi==0.116.1
|
||||
h11==0.16.0
|
||||
httpcore==1.0.9
|
||||
httpx==0.28.1
|
||||
idna==3.10
|
||||
jiter==0.10.0
|
||||
ollama==0.5.3
|
||||
openai==1.102.0
|
||||
pydantic==2.11.7
|
||||
pydantic-settings==2.10.1
|
||||
pydantic_core==2.33.2
|
||||
|
|
@ -16,6 +19,7 @@ python-dotenv==1.1.1
|
|||
PyYAML==6.0.2
|
||||
sniffio==1.3.1
|
||||
starlette==0.47.2
|
||||
tqdm==4.67.1
|
||||
typing-inspection==0.4.1
|
||||
typing_extensions==4.14.1
|
||||
uvicorn==0.35.0
|
||||
|
|
|
|||
205
router.py
205
router.py
|
|
@ -6,7 +6,7 @@ version: 0.1
|
|||
license: AGPL
|
||||
"""
|
||||
# -------------------------------------------------------------
|
||||
import json, random, asyncio, yaml, httpx, ollama
|
||||
import json, random, asyncio, yaml, httpx, ollama, openai
|
||||
from pathlib import Path
|
||||
from typing import Dict, Set, List
|
||||
from fastapi import FastAPI, Request, HTTPException
|
||||
|
|
@ -314,7 +314,7 @@ async def chat_proxy(request: Request):
|
|||
)
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# 8. API route – Embedding
|
||||
# 8. API route – Embedding - deprecated
|
||||
# -------------------------------------------------------------
|
||||
@app.post("/api/embeddings")
|
||||
async def embedding_proxy(request: Request):
|
||||
|
|
@ -742,12 +742,205 @@ async def ps_proxy(request: Request):
|
|||
)
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# 18. OpenAI API compatible endpoints #ToDo
|
||||
# 18. API route – OpenAI compatible Embedding
|
||||
# -------------------------------------------------------------
|
||||
@app.post("/v1/embeddings")
|
||||
async def openai_embedding_proxy(request: Request):
|
||||
"""
|
||||
Proxy an OpenAI API compatible embedding request to Ollama and reply with embeddings.
|
||||
|
||||
"""
|
||||
# 1. Parse and validate request
|
||||
try:
|
||||
body_bytes = await request.body()
|
||||
payload = json.loads(body_bytes.decode("utf-8"))
|
||||
|
||||
model = payload.get("model")
|
||||
input = payload.get("input")
|
||||
|
||||
if not model:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing required field 'model'"
|
||||
)
|
||||
if not input:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing required field 'input'"
|
||||
)
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||
|
||||
# 2. Endpoint logic
|
||||
endpoint = await choose_endpoint(model)
|
||||
await increment_usage(endpoint, model)
|
||||
oclient = openai.AsyncOpenAI(base_url=endpoint+"/v1", api_key="ollama")
|
||||
|
||||
# 3. Async generator that streams embedding data and decrements the counter
|
||||
async_gen = await oclient.embeddings.create(input = [input], model=model)
|
||||
|
||||
await decrement_usage(endpoint, model)
|
||||
|
||||
# 5. Return a StreamingResponse backed by the generator
|
||||
return async_gen
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# 19. API route – OpenAI compatible Chat Completions
|
||||
# -------------------------------------------------------------
|
||||
@app.post("/v1/chat/completions")
|
||||
async def openai_chat_completions_proxy(request: Request):
|
||||
"""
|
||||
Proxy an OpenAI API compatible chat completions request to Ollama and reply with a streaming response.
|
||||
|
||||
"""
|
||||
# 1. Parse and validate request
|
||||
try:
|
||||
body_bytes = await request.body()
|
||||
payload = json.loads(body_bytes.decode("utf-8"))
|
||||
|
||||
model = payload.get("model")
|
||||
messages = payload.get("messages")
|
||||
frequency_penalty = payload.get("frequency_penalty")
|
||||
presence_penalty = payload.get("presence_penalty")
|
||||
response_format = payload.get("response_format")
|
||||
seed = payload.get("seed")
|
||||
stop = payload.get("stop")
|
||||
stream = payload.get("stream")
|
||||
stream_options = payload.get("stream_options")
|
||||
temperature = payload.get("temperature")
|
||||
top_p = payload.get("top_p")
|
||||
max_tokens = payload.get("max_tokens")
|
||||
tools =payload.get("tools")
|
||||
|
||||
|
||||
if not model:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing required field 'model'"
|
||||
)
|
||||
if not isinstance(messages, list):
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing required field 'messages' (must be a list)"
|
||||
)
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||
|
||||
# 2. Endpoint logic
|
||||
endpoint = await choose_endpoint(model)
|
||||
await increment_usage(endpoint, model)
|
||||
oclient = openai.AsyncOpenAI(base_url=endpoint+"/v1", api_key="ollama")
|
||||
|
||||
# 3. Async generator that streams completions data and decrements the counter
|
||||
async def stream_ochat_response():
|
||||
try:
|
||||
# The chat method returns a generator of dicts (or GenerateResponse)
|
||||
async_gen = await oclient.chat.completions.create(messages=messages, model=model, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, response_format=response_format, seed=seed, stop=stop, stream=stream, stream_options=stream_options, temperature=temperature, top_p=top_p, max_tokens=max_tokens, tools=tools)
|
||||
if stream == True:
|
||||
async for chunk in async_gen:
|
||||
data = (
|
||||
chunk.model_dump_json()
|
||||
if hasattr(chunk, "model_dump_json")
|
||||
else json.dumps(chunk)
|
||||
)
|
||||
yield f"data: {data}\n\n".encode("utf-8")
|
||||
# Final DONE event
|
||||
yield b"data: [DONE]\n\n"
|
||||
else:
|
||||
json_line = (
|
||||
async_gen.model_dump_json()
|
||||
if hasattr(async_gen, "model_dump_json")
|
||||
else json.dumps(async_gen)
|
||||
)
|
||||
yield json_line.encode("utf-8") + b"\n"
|
||||
|
||||
finally:
|
||||
# Ensure counter is decremented even if an exception occurs
|
||||
await decrement_usage(endpoint, model)
|
||||
|
||||
# 4. Return a StreamingResponse backed by the generator
|
||||
return StreamingResponse(
|
||||
stream_ochat_response(),
|
||||
media_type="application/json",
|
||||
)
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# 20. API route – OpenAI compatible Completions
|
||||
# -------------------------------------------------------------
|
||||
@app.post("/v1/completions")
|
||||
async def openai_completions_proxy(request: Request):
|
||||
"""
|
||||
Proxy an OpenAI API compatible chat completions request to Ollama and reply with a streaming response.
|
||||
|
||||
"""
|
||||
# 1. Parse and validate request
|
||||
try:
|
||||
body_bytes = await request.body()
|
||||
payload = json.loads(body_bytes.decode("utf-8"))
|
||||
|
||||
model = payload.get("model")
|
||||
prompt = payload.get("prompt")
|
||||
frequency_penalty = payload.get("frequency_penalty")
|
||||
presence_penalty = payload.get("presence_penalty")
|
||||
seed = payload.get("seed")
|
||||
stop = payload.get("stop")
|
||||
stream = payload.get("stream")
|
||||
stream_options = payload.get("stream_options")
|
||||
temperature = payload.get("temperature")
|
||||
top_p = payload.get("top_p")
|
||||
max_tokens = payload.get("max_tokens")
|
||||
suffix =payload.get("suffix")
|
||||
|
||||
|
||||
if not model:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing required field 'model'"
|
||||
)
|
||||
if not prompt:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing required field 'prompt'"
|
||||
)
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||
|
||||
# 2. Endpoint logic
|
||||
endpoint = await choose_endpoint(model)
|
||||
await increment_usage(endpoint, model)
|
||||
oclient = openai.AsyncOpenAI(base_url=endpoint+"/v1", api_key="ollama")
|
||||
|
||||
# 3. Async generator that streams completions data and decrements the counter
|
||||
async def stream_ocompletions_response():
|
||||
try:
|
||||
# The chat method returns a generator of dicts (or GenerateResponse)
|
||||
async_gen = await oclient.completions.create(model=model, prompt=prompt, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, seed=seed, stop=stop, stream=stream, stream_options=stream_options, temperature=temperature, top_p=top_p, max_tokens=max_tokens, suffix=suffix)
|
||||
if stream == True:
|
||||
async for chunk in async_gen:
|
||||
data = (
|
||||
chunk.model_dump_json()
|
||||
if hasattr(chunk, "model_dump_json")
|
||||
else json.dumps(chunk)
|
||||
)
|
||||
yield f"data: {data}\n\n".encode("utf-8")
|
||||
# Final DONE event
|
||||
yield b"data: [DONE]\n\n"
|
||||
else:
|
||||
json_line = (
|
||||
async_gen.model_dump_json()
|
||||
if hasattr(async_gen, "model_dump_json")
|
||||
else json.dumps(async_gen)
|
||||
)
|
||||
yield json_line.encode("utf-8") + b"\n"
|
||||
|
||||
finally:
|
||||
# Ensure counter is decremented even if an exception occurs
|
||||
await decrement_usage(endpoint, model)
|
||||
|
||||
# 4. Return a StreamingResponse backed by the generator
|
||||
return StreamingResponse(
|
||||
stream_ocompletions_response(),
|
||||
media_type="application/json",
|
||||
)
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# 21. OpenAI API compatible endpoints #ToDo
|
||||
# -------------------------------------------------------------
|
||||
@app.post("/v1/models")
|
||||
@app.post("/v1/embeddings")
|
||||
async def not_implemented_yet(request: Request):
|
||||
|
||||
return Response(
|
||||
|
|
@ -755,7 +948,7 @@ async def not_implemented_yet(request: Request):
|
|||
)
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# 19. FastAPI startup event – load configuration
|
||||
# 22. FastAPI startup event – load configuration
|
||||
# -------------------------------------------------------------
|
||||
@app.on_event("startup")
|
||||
async def startup_event() -> None:
|
||||
|
|
@ -763,4 +956,4 @@ async def startup_event() -> None:
|
|||
# Load YAML config (or use defaults if not present)
|
||||
config = Config.from_yaml(Path("config.yaml"))
|
||||
print(f"Loaded configuration:\n endpoints={config.endpoints},\n "
|
||||
f"max_concurrent_connections={config.max_concurrent_connections}")
|
||||
f"max_concurrent_connections={config.max_concurrent_connections}")
|
||||
Loading…
Add table
Add a link
Reference in a new issue