Add Ollama embeddings conversion calls to the OpenAI endpoint
parent bd21906687
commit 6c9ffad834

1 changed file with 40 additions and 6 deletions:

router.py
@@ -302,7 +302,25 @@ class rechunk:
             "response": chunk.choices[0].text
         }
         return rechunk
 
+    def openai_embeddings2ollama(chunk: dict):
+        rechunk = {"embedding": chunk.data[0].embedding}
+        return rechunk
+
+    def openai_embed2ollama(chunk: dict, model: str):
+        rechunk = { "model": model,
+                    "created_at": iso8601_ns(),
+                    "done": None,
+                    "done_reason": None,
+                    "total_duration": None,
+                    "load_duration": None,
+                    "prompt_eval_count": None,
+                    "prompt_eval_duration": None,
+                    "eval_count": None,
+                    "eval_duration": None,
+                    "embeddings": [chunk.data[0].embedding]
+                  }
+        return rechunk
 # ------------------------------------------------------------------
 # SSE Helpers
 # ------------------------------------------------------------------
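The two new helpers mirror the shapes of Ollama's two embedding routes: the legacy /api/embeddings response carries a single flat vector under "embedding", while /api/embed carries response metadata plus a list of vectors under "embeddings". A minimal sketch of what each converter produces, assuming iso8601_ns() (defined elsewhere in router.py) returns a timestamp string, and using SimpleNamespace to stand in for the OpenAI response object:

# Illustration only: SimpleNamespace mimics openai's CreateEmbeddingResponse,
# and iso8601_ns is stubbed because its real definition lives elsewhere in router.py.
from datetime import datetime, timezone
from types import SimpleNamespace

def iso8601_ns() -> str:
    return datetime.now(timezone.utc).isoformat()  # stand-in timestamp helper

chunk = SimpleNamespace(data=[SimpleNamespace(embedding=[0.1, -0.2, 0.3])])

# openai_embeddings2ollama -> legacy /api/embeddings shape (one flat vector)
print({"embedding": chunk.data[0].embedding})
# {'embedding': [0.1, -0.2, 0.3]}

# openai_embed2ollama -> /api/embed shape (metadata plus a list of vectors)
print({"model": "demo-model", "created_at": iso8601_ns(),
       "embeddings": [chunk.data[0].embedding]})
# {'model': 'demo-model', 'created_at': '...', 'embeddings': [[0.1, -0.2, 0.3]]}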
@@ -639,13 +657,21 @@ async def embedding_proxy(request: Request):
     # 2. Endpoint logic
     endpoint = await choose_endpoint(model)
     await increment_usage(endpoint, model)
-    client = ollama.AsyncClient(host=endpoint)
+    is_openai_endpoint = "/v1" in endpoint
+    if is_openai_endpoint:
+        client = openai.AsyncOpenAI(base_url=endpoint, api_key=config.api_keys[endpoint])
+    else:
+        client = ollama.AsyncClient(host=endpoint)
 
     # 3. Async generator that streams embedding data and decrements the counter
     async def stream_embedding_response():
         try:
             # The chat method returns a generator of dicts (or GenerateResponse)
-            async_gen = await client.embeddings(model=model, prompt=prompt, options=options, keep_alive=keep_alive)
+            if is_openai_endpoint:
+                async_gen = await client.embeddings.create(input=[prompt], model=model)
+                async_gen = rechunk.openai_embeddings2ollama(async_gen)
+            else:
+                async_gen = await client.embeddings(model=model, prompt=prompt, options=options, keep_alive=keep_alive)
             if hasattr(async_gen, "model_dump_json"):
                 json_line = async_gen.model_dump_json()
             else:
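Both proxy handlers now pick the client with the same substring test: an endpoint URL containing "/v1" is treated as OpenAI-compatible and gets an AsyncOpenAI client with its key looked up from config.api_keys; anything else falls back to the native ollama.AsyncClient. A quick sketch of the heuristic (both URLs are made up for illustration):

# Sketch of the endpoint-detection heuristic; the URLs are illustrative.
def is_openai_endpoint(endpoint: str) -> bool:
    return "/v1" in endpoint

print(is_openai_endpoint("https://api.example.com/v1"))  # True  -> openai.AsyncOpenAI
print(is_openai_endpoint("http://localhost:11434"))      # False -> ollama.AsyncClient

The check is deliberately cheap, though it would also match any Ollama host that happens to include "/v1" somewhere in its path.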
@@ -676,7 +702,7 @@ async def embed_proxy(request: Request):
     payload = json.loads(body_bytes.decode("utf-8"))
 
     model = payload.get("model")
-    input = payload.get("input")
+    _input = payload.get("input")
     truncate = payload.get("truncate")
     options = payload.get("options")
     keep_alive = payload.get("keep_alive")
@@ -685,7 +711,7 @@
         raise HTTPException(
             status_code=400, detail="Missing required field 'model'"
         )
-    if not input:
+    if not _input:
        raise HTTPException(
             status_code=400, detail="Missing required field 'input'"
        )
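The rename from input to _input presumably avoids shadowing Python's builtin input(): with the old binding, the builtin became uncallable for the rest of the handler. A small illustration of the hazard (handler and its payload are hypothetical):

# Hypothetical illustration: rebinding "input" hides builtins.input in this scope.
def handler(payload: dict) -> None:
    input = payload.get("input")   # shadows the builtin
    try:
        input("prompt> ")          # fails: input is now a str, not a function
    except TypeError as err:
        print(err)                 # 'str' object is not callable

handler({"input": "hello"})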
@@ -695,13 +721,21 @@ async def embed_proxy(request: Request):
     # 2. Endpoint logic
     endpoint = await choose_endpoint(model)
     await increment_usage(endpoint, model)
-    client = ollama.AsyncClient(host=endpoint)
+    is_openai_endpoint = "/v1" in endpoint
+    if is_openai_endpoint:
+        client = openai.AsyncOpenAI(base_url=endpoint, api_key=config.api_keys[endpoint])
+    else:
+        client = ollama.AsyncClient(host=endpoint)
 
     # 3. Async generator that streams embed data and decrements the counter
     async def stream_embedding_response():
         try:
             # The chat method returns a generator of dicts (or GenerateResponse)
-            async_gen = await client.embed(model=model, input=input, truncate=truncate, options=options, keep_alive=keep_alive)
+            if is_openai_endpoint:
+                async_gen = await client.embeddings.create(input=[_input], model=model)
+                async_gen = rechunk.openai_embed2ollama(async_gen, model)
+            else:
+                async_gen = await client.embed(model=model, input=_input, truncate=truncate, options=options, keep_alive=keep_alive)
             if hasattr(async_gen, "model_dump_json"):
                 json_line = async_gen.model_dump_json()
             else:
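For reference, the object the OpenAI branches convert from: embeddings.create returns a CreateEmbeddingResponse whose .data list holds one Embedding per input string. A standalone sketch of the call (base_url, api_key, and the model name are placeholders, not values from router.py):

# Placeholder values throughout; this only shows the response shape being converted.
import asyncio
import openai

async def main() -> None:
    client = openai.AsyncOpenAI(base_url="https://api.example.com/v1",
                                api_key="sk-placeholder")
    resp = await client.embeddings.create(input=["hello world"], model="text-embedding-3-small")
    print(len(resp.data))               # 1, one Embedding per input string
    print(len(resp.data[0].embedding))  # dimensionality of the returned vector

asyncio.run(main())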