add OpenAI-to-Ollama embeddings conversion for OpenAI-compatible endpoints

Alpha Nerd 2025-09-15 11:47:55 +02:00
parent bd21906687
commit 6c9ffad834


@@ -302,7 +302,25 @@ class rechunk:
             "response": chunk.choices[0].text
         }
         return rechunk
+    def openai_embeddings2ollama(chunk: dict):
+        rechunk = {"embedding": chunk.data[0].embedding}
+        return rechunk
+    def openai_embed2ollama(chunk: dict, model: str):
+        rechunk = {
+            "model": model,
+            "created_at": iso8601_ns(),
+            "done": None,
+            "done_reason": None,
+            "total_duration": None,
+            "load_duration": None,
+            "prompt_eval_count": None,
+            "prompt_eval_duration": None,
+            "eval_count": None,
+            "eval_duration": None,
+            "embeddings": [chunk.data[0].embedding]
+        }
+        return rechunk
     # ------------------------------------------------------------------
     # SSE Helpers
     # ------------------------------------------------------------------
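
For reference, a minimal sketch of what the two new converters return. The SimpleNamespace stand-in below mimics the attribute layout of an OpenAI embeddings response (response.data[0].embedding); the model name is illustrative, and rechunk / iso8601_ns() come from this file:

from types import SimpleNamespace

# Stand-in for the OpenAI embeddings response object (illustrative values)
fake = SimpleNamespace(data=[SimpleNamespace(embedding=[0.12, -0.03, 0.88])])

# Ollama /api/embeddings shape: a single flat vector
print(rechunk.openai_embeddings2ollama(fake))
# {'embedding': [0.12, -0.03, 0.88]}

# Ollama /api/embed shape: a list of vectors plus Ollama's bookkeeping fields,
# which stay None because the OpenAI endpoint does not report them
print(rechunk.openai_embed2ollama(fake, "nomic-embed-text"))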
@@ -639,13 +657,21 @@ async def embedding_proxy(request: Request):
     # 2. Endpoint logic
     endpoint = await choose_endpoint(model)
     await increment_usage(endpoint, model)
-    client = ollama.AsyncClient(host=endpoint)
+    is_openai_endpoint = "/v1" in endpoint
+    if is_openai_endpoint:
+        client = openai.AsyncOpenAI(base_url=endpoint, api_key=config.api_keys[endpoint])
+    else:
+        client = ollama.AsyncClient(host=endpoint)
     # 3. Async generator that streams embedding data and decrements the counter
     async def stream_embedding_response():
         try:
             # The embeddings method returns a dict (or an EmbeddingsResponse)
-            async_gen = await client.embeddings(model=model, prompt=prompt, options=options, keep_alive=keep_alive)
+            if is_openai_endpoint:
+                async_gen = await client.embeddings.create(input=[prompt], model=model)
+                async_gen = rechunk.openai_embeddings2ollama(async_gen)
+            else:
+                async_gen = await client.embeddings(model=model, prompt=prompt, options=options, keep_alive=keep_alive)
             if hasattr(async_gen, "model_dump_json"):
                 json_line = async_gen.model_dump_json()
             else:
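
The endpoint dispatch above reappears verbatim in embed_proxy below; a sketch of how it could be factored into a shared helper (hypothetical make_embedding_client, not part of this commit), assuming config.api_keys maps endpoint URLs to API keys:

import ollama
import openai

def make_embedding_client(endpoint: str, api_keys: dict):
    # "/v1" in the URL marks an OpenAI-compatible upstream, e.g. "https://llm.example.com/v1";
    # anything else is treated as a plain Ollama host, e.g. "http://127.0.0.1:11434"
    is_openai_endpoint = "/v1" in endpoint
    if is_openai_endpoint:
        return openai.AsyncOpenAI(base_url=endpoint, api_key=api_keys[endpoint]), True
    return ollama.AsyncClient(host=endpoint), False

Both proxies would then reduce to: client, is_openai_endpoint = make_embedding_client(endpoint, config.api_keys)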
@@ -676,7 +702,7 @@ async def embed_proxy(request: Request):
     payload = json.loads(body_bytes.decode("utf-8"))
     model = payload.get("model")
-    input = payload.get("input")
+    _input = payload.get("input")
     truncate = payload.get("truncate")
     options = payload.get("options")
     keep_alive = payload.get("keep_alive")
@@ -685,7 +711,7 @@ async def embed_proxy(request: Request):
         raise HTTPException(
             status_code=400, detail="Missing required field 'model'"
         )
-    if not input:
+    if not _input:
        raise HTTPException(
            status_code=400, detail="Missing required field 'input'"
        )
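
The input to _input rename in the two hunks above is more than cosmetic: binding the name input shadows Python's built-in input() for the rest of the scope. A minimal illustration:

payload = {"model": "m", "input": "hello"}

input = payload.get("input")    # rebinds the name, hiding the builtin
# input("prompt> ")             # would now raise TypeError: 'str' object is not callable

_input = payload.get("input")   # the underscore-prefixed name leaves the builtin intact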
@@ -695,13 +721,21 @@ async def embed_proxy(request: Request):
     # 2. Endpoint logic
     endpoint = await choose_endpoint(model)
     await increment_usage(endpoint, model)
-    client = ollama.AsyncClient(host=endpoint)
+    is_openai_endpoint = "/v1" in endpoint
+    if is_openai_endpoint:
+        client = openai.AsyncOpenAI(base_url=endpoint, api_key=config.api_keys[endpoint])
+    else:
+        client = ollama.AsyncClient(host=endpoint)
     # 3. Async generator that streams embed data and decrements the counter
     async def stream_embedding_response():
         try:
             # The embed method returns a dict (or an EmbedResponse)
-            async_gen = await client.embed(model=model, input=input, truncate=truncate, options=options, keep_alive=keep_alive)
+            if is_openai_endpoint:
+                async_gen = await client.embeddings.create(input=[_input], model=model)
+                async_gen = rechunk.openai_embed2ollama(async_gen, model)
+            else:
+                async_gen = await client.embed(model=model, input=_input, truncate=truncate, options=options, keep_alive=keep_alive)
             if hasattr(async_gen, "model_dump_json"):
                 json_line = async_gen.model_dump_json()
             else:
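
After either branch, async_gen is a pydantic response object from the Ollama client or a plain dict produced by the rechunk converters, which is exactly what the hasattr check distinguishes. A sketch of that serialization step, assuming the else branch (truncated here) falls back to json.dumps:

import json

def to_json_line(resp) -> str:
    # Ollama client responses are pydantic models and expose model_dump_json();
    # the converted OpenAI results are plain dicts and take the json.dumps path.
    if hasattr(resp, "model_dump_json"):
        return resp.model_dump_json()
    return json.dumps(resp)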