Add Ollama embeddings conversion calls for OpenAI endpoints
commit 6c9ffad834
parent bd21906687
1 changed file with 40 additions and 6 deletions

router.py (46 changed lines: +40, -6)
@@ -302,7 +302,25 @@ class rechunk:
             "response": chunk.choices[0].text
         }
         return rechunk
+
+    def openai_embeddings2ollama(chunk: dict):
+        rechunk = {"embedding": chunk.data[0].embedding}
+        return rechunk
+
+    def openai_embed2ollama(chunk: dict, model: str):
+        rechunk = { "model": model,
+                    "created_at": iso8601_ns(),
+                    "done": None,
+                    "done_reason": None,
+                    "total_duration": None,
+                    "load_duration": None,
+                    "prompt_eval_count": None,
+                    "prompt_eval_duration": None,
+                    "eval_count": None,
+                    "eval_duration": None,
+                    "embeddings": [chunk.data[0].embedding]
+                  }
+        return rechunk
 
 # ------------------------------------------------------------------
 # SSE Helpers
 # ------------------------------------------------------------------
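
The two converters added above reshape a single OpenAI embeddings response into the payloads Ollama's embeddings and embed APIs return. A minimal sketch of how they behave, under the assumptions that router.py is importable and defines iso8601_ns as used above; note the functions read .data[0].embedding as attributes (despite the "chunk: dict" annotation), which matches the OpenAI SDK's response object, stubbed here with SimpleNamespace:

from types import SimpleNamespace

from router import rechunk   # hypothetical import path

# SimpleNamespace stands in for openai.types.CreateEmbeddingResponse,
# reduced to the one field the converters touch: .data[0].embedding
stub = SimpleNamespace(data=[SimpleNamespace(embedding=[0.1, 0.2, 0.3])])

print(rechunk.openai_embeddings2ollama(stub))
# -> {'embedding': [0.1, 0.2, 0.3]}

print(rechunk.openai_embed2ollama(stub, "nomic-embed-text")["embeddings"])
# -> [[0.1, 0.2, 0.3]]   ("nomic-embed-text" is an illustrative model name)

Calling the methods on the class directly works because they are defined without self, so they behave as plain functions namespaced under rechunk.
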
@@ -639,13 +657,21 @@ async def embedding_proxy(request: Request):
     # 2. Endpoint logic
     endpoint = await choose_endpoint(model)
     await increment_usage(endpoint, model)
-    client = ollama.AsyncClient(host=endpoint)
+    is_openai_endpoint = "/v1" in endpoint
+    if is_openai_endpoint:
+        client = openai.AsyncOpenAI(base_url=endpoint, api_key=config.api_keys[endpoint])
+    else:
+        client = ollama.AsyncClient(host=endpoint)
 
     # 3. Async generator that streams embedding data and decrements the counter
     async def stream_embedding_response():
         try:
             # The chat method returns a generator of dicts (or GenerateResponse)
-            async_gen = await client.embeddings(model=model, prompt=prompt, options=options, keep_alive=keep_alive)
+            if is_openai_endpoint:
+                async_gen = await client.embeddings.create(input=[prompt], model=model)
+                async_gen = rechunk.openai_embeddings2ollama(async_gen)
+            else:
+                async_gen = await client.embeddings(model=model, prompt=prompt, options=options, keep_alive=keep_alive)
             if hasattr(async_gen, "model_dump_json"):
                 json_line = async_gen.model_dump_json()
             else:
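
Note that in the OpenAI branch client.embeddings.create(...) returns a single response object rather than a stream, so by the time execution reaches the hasattr check, async_gen holds the plain dict built by rechunk.openai_embeddings2ollama. That dict has no model_dump_json method, so serialization presumably falls through to the else branch; a sketch of that step (the else body sits outside this hunk's context and is an assumption):

import json

async_gen = {"embedding": [0.1, 0.2, 0.3]}   # as produced by the converter

if hasattr(async_gen, "model_dump_json"):
    # pydantic-style response objects from the ollama client take this path
    json_line = async_gen.model_dump_json()
else:
    # plain dicts from the rechunk converters take this one
    json_line = json.dumps(async_gen)

print(json_line)   # {"embedding": [0.1, 0.2, 0.3]}
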
@@ -676,7 +702,7 @@ async def embed_proxy(request: Request):
     payload = json.loads(body_bytes.decode("utf-8"))
 
     model = payload.get("model")
-    input = payload.get("input")
+    _input = payload.get("input")
     truncate = payload.get("truncate")
     options = payload.get("options")
     keep_alive = payload.get("keep_alive")
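
The input to _input rename here (continued in the next hunk) avoids shadowing Python's builtin input() inside embed_proxy. A toy illustration:

payload = {"input": "hello"}

input = payload.get("input")    # the builtin input() is now shadowed in this scope
_input = payload.get("input")   # the rename leaves the builtin untouched
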
@@ -685,7 +711,7 @@
         raise HTTPException(
             status_code=400, detail="Missing required field 'model'"
         )
-    if not input:
+    if not _input:
         raise HTTPException(
             status_code=400, detail="Missing required field 'input'"
         )
@@ -695,13 +721,21 @@ async def embed_proxy(request: Request):
     # 2. Endpoint logic
     endpoint = await choose_endpoint(model)
     await increment_usage(endpoint, model)
-    client = ollama.AsyncClient(host=endpoint)
+    is_openai_endpoint = "/v1" in endpoint
+    if is_openai_endpoint:
+        client = openai.AsyncOpenAI(base_url=endpoint, api_key=config.api_keys[endpoint])
+    else:
+        client = ollama.AsyncClient(host=endpoint)
 
     # 3. Async generator that streams embed data and decrements the counter
     async def stream_embedding_response():
         try:
             # The chat method returns a generator of dicts (or GenerateResponse)
-            async_gen = await client.embed(model=model, input=input, truncate=truncate, options=options, keep_alive=keep_alive)
+            if is_openai_endpoint:
+                async_gen = await client.embeddings.create(input=[_input], model=model)
+                async_gen = rechunk.openai_embed2ollama(async_gen, model)
+            else:
+                async_gen = await client.embed(model=model, input=_input, truncate=truncate, options=options, keep_alive=keep_alive)
             if hasattr(async_gen, "model_dump_json"):
                 json_line = async_gen.model_dump_json()
             else:
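
End to end, the proxy now answers Ollama-shaped embed requests even when the chosen backend speaks the OpenAI API. A hypothetical client call, assuming embed_proxy is mounted at Ollama's usual /api/embed route and the proxy listens on localhost:8000 (neither is shown in this diff):

import httpx

resp = httpx.post(
    "http://localhost:8000/api/embed",   # assumed host and route
    json={"model": "nomic-embed-text", "input": "hello world"},  # illustrative model
)
print(resp.json()["embeddings"][0][:3])  # Ollama-shaped response either way
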