diff --git a/router.py b/router.py
index d7a4cbd..51e0137 100644
--- a/router.py
+++ b/router.py
@@ -302,7 +302,25 @@ class rechunk:
             "response": chunk.choices[0].text
         }
         return rechunk
+
+    def openai_embeddings2ollama(chunk: dict):
+        rechunk = {"embedding": chunk.data[0].embedding}
+        return rechunk
+    def openai_embed2ollama(chunk: dict, model: str):
+        rechunk = { "model": model,
+                    "created_at": iso8601_ns(),
+                    "done": None,
+                    "done_reason": None,
+                    "total_duration": None,
+                    "load_duration": None,
+                    "prompt_eval_count": None,
+                    "prompt_eval_duration": None,
+                    "eval_count": None,
+                    "eval_duration": None,
+                    "embeddings": [chunk.data[0].embedding]
+                    }
+        return rechunk
 
 # ------------------------------------------------------------------
 # SSE Helpser
 # ------------------------------------------------------------------
@@ -639,13 +657,21 @@ async def embedding_proxy(request: Request):
     # 2. Endpoint logic
     endpoint = await choose_endpoint(model)
     await increment_usage(endpoint, model)
-    client = ollama.AsyncClient(host=endpoint)
+    is_openai_endpoint = "/v1" in endpoint
+    if is_openai_endpoint:
+        client = openai.AsyncOpenAI(base_url=endpoint, api_key=config.api_keys[endpoint])
+    else:
+        client = ollama.AsyncClient(host=endpoint)
 
     # 3. Async generator that streams embedding data and decrements the counter
     async def stream_embedding_response():
         try:
             # The chat method returns a generator of dicts (or GenerateResponse)
-            async_gen = await client.embeddings(model=model, prompt=prompt, options=options, keep_alive=keep_alive)
+            if is_openai_endpoint:
+                async_gen = await client.embeddings.create(input=[prompt], model=model)
+                async_gen = rechunk.openai_embeddings2ollama(async_gen)
+            else:
+                async_gen = await client.embeddings(model=model, prompt=prompt, options=options, keep_alive=keep_alive)
             if hasattr(async_gen, "model_dump_json"):
                 json_line = async_gen.model_dump_json()
             else:
@@ -676,7 +702,7 @@ async def embed_proxy(request: Request):
     payload = json.loads(body_bytes.decode("utf-8"))
 
     model = payload.get("model")
-    input = payload.get("input")
+    _input = payload.get("input")
     truncate = payload.get("truncate")
     options = payload.get("options")
     keep_alive = payload.get("keep_alive")
@@ -685,7 +711,7 @@
         raise HTTPException(
             status_code=400, detail="Missing required field 'model'"
         )
-    if not input:
+    if not _input:
         raise HTTPException(
             status_code=400, detail="Missing required field 'input'"
         )
@@ -695,13 +721,21 @@
     # 2. Endpoint logic
     endpoint = await choose_endpoint(model)
     await increment_usage(endpoint, model)
-    client = ollama.AsyncClient(host=endpoint)
+    is_openai_endpoint = "/v1" in endpoint
+    if is_openai_endpoint:
+        client = openai.AsyncOpenAI(base_url=endpoint, api_key=config.api_keys[endpoint])
+    else:
+        client = ollama.AsyncClient(host=endpoint)
 
     # 3. Async generator that streams embed data and decrements the counter
     async def stream_embedding_response():
         try:
             # The chat method returns a generator of dicts (or GenerateResponse)
-            async_gen = await client.embed(model=model, input=input, truncate=truncate, options=options, keep_alive=keep_alive)
+            if is_openai_endpoint:
+                async_gen = await client.embeddings.create(input=[_input], model=model)
+                async_gen = rechunk.openai_embed2ollama(async_gen, model)
+            else:
+                async_gen = await client.embed(model=model, input=_input, truncate=truncate, options=options, keep_alive=keep_alive)
             if hasattr(async_gen, "model_dump_json"):
                 json_line = async_gen.model_dump_json()
             else:
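
Quick offline sanity check for the two new converters (illustrative sketch, not part of the patch). It assumes router.py imports cleanly as a plain module and that iso8601_ns() is the module-level timestamp helper the class already relies on; the SimpleNamespace object only mimics the attribute access the helpers perform (resp.data[0].embedding), and the model name is an arbitrary example.

    # sketch: exercise the OpenAI -> Ollama embedding converters without a backend
    from types import SimpleNamespace

    from router import rechunk  # assumption: router.py is importable as a module

    # Stand-in for the OpenAI embeddings response; only .data[0].embedding is read
    fake_response = SimpleNamespace(
        data=[SimpleNamespace(embedding=[0.1, 0.2, 0.3])]
    )

    # /api/embeddings shape: {"embedding": [0.1, 0.2, 0.3]}
    print(rechunk.openai_embeddings2ollama(fake_response))

    # /api/embed shape: {"model": ..., "embeddings": [[0.1, 0.2, 0.3]], ...}
    print(rechunk.openai_embed2ollama(fake_response, "nomic-embed-text"))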