compliance with ollama naming conventions and openai model['id']
This commit is contained in:
parent
733e215be2
commit
4b5834d7df
1 changed files with 21 additions and 7 deletions
28
router.py
28
router.py
|
|
@ -394,10 +394,19 @@ async def choose_endpoint(model: str) -> str:
|
||||||
|
|
||||||
# 6️⃣
|
# 6️⃣
|
||||||
if not candidate_endpoints:
|
if not candidate_endpoints:
|
||||||
raise RuntimeError(
|
if ":latest" in model: #ollama naming convention not applicable to openai
|
||||||
f"None of the configured endpoints ({', '.join(config.endpoints)}) "
|
model = model.split(":")
|
||||||
f"advertise the model '{model}'."
|
model = model[0]
|
||||||
)
|
print(model)
|
||||||
|
candidate_endpoints = [
|
||||||
|
ep for ep, models in zip(config.endpoints, advertised_sets)
|
||||||
|
if model in models
|
||||||
|
]
|
||||||
|
if not candidate_endpoints:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"None of the configured endpoints ({', '.join(config.endpoints)}) "
|
||||||
|
f"advertise the model '{model}'."
|
||||||
|
)
|
||||||
|
|
||||||
# 3️⃣ Among the candidates, find those that have the model *loaded*
|
# 3️⃣ Among the candidates, find those that have the model *loaded*
|
||||||
# (concurrently, but only for the filtered list)
|
# (concurrently, but only for the filtered list)
|
||||||
|
|
@ -472,9 +481,11 @@ async def proxy(request: Request):
|
||||||
|
|
||||||
|
|
||||||
endpoint = await choose_endpoint(model)
|
endpoint = await choose_endpoint(model)
|
||||||
await increment_usage(endpoint, model)
|
|
||||||
is_openai_endpoint = "/v1" in endpoint
|
is_openai_endpoint = "/v1" in endpoint
|
||||||
if is_openai_endpoint:
|
if is_openai_endpoint:
|
||||||
|
if ":latest" in model:
|
||||||
|
model = model.split(":")
|
||||||
|
model = model[0]
|
||||||
params = {
|
params = {
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
"model": model,
|
"model": model,
|
||||||
|
|
@ -488,6 +499,7 @@ async def proxy(request: Request):
|
||||||
oclient = openai.AsyncOpenAI(base_url=endpoint, api_key=config.api_keys[endpoint])
|
oclient = openai.AsyncOpenAI(base_url=endpoint, api_key=config.api_keys[endpoint])
|
||||||
else:
|
else:
|
||||||
client = ollama.AsyncClient(host=endpoint)
|
client = ollama.AsyncClient(host=endpoint)
|
||||||
|
await increment_usage(endpoint, model)
|
||||||
|
|
||||||
# 4. Async generator that streams data and decrements the counter
|
# 4. Async generator that streams data and decrements the counter
|
||||||
async def stream_generate_response():
|
async def stream_generate_response():
|
||||||
|
|
@ -564,9 +576,11 @@ async def chat_proxy(request: Request):
|
||||||
|
|
||||||
# 2. Endpoint logic
|
# 2. Endpoint logic
|
||||||
endpoint = await choose_endpoint(model)
|
endpoint = await choose_endpoint(model)
|
||||||
await increment_usage(endpoint, model)
|
|
||||||
is_openai_endpoint = "/v1" in endpoint
|
is_openai_endpoint = "/v1" in endpoint
|
||||||
if is_openai_endpoint:
|
if is_openai_endpoint:
|
||||||
|
if ":latest" in model:
|
||||||
|
model = model.split(":")
|
||||||
|
model = model[0]
|
||||||
params = {
|
params = {
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"model": model,
|
"model": model,
|
||||||
|
|
@ -581,7 +595,7 @@ async def chat_proxy(request: Request):
|
||||||
oclient = openai.AsyncOpenAI(base_url=endpoint, api_key=config.api_keys[endpoint])
|
oclient = openai.AsyncOpenAI(base_url=endpoint, api_key=config.api_keys[endpoint])
|
||||||
else:
|
else:
|
||||||
client = ollama.AsyncClient(host=endpoint)
|
client = ollama.AsyncClient(host=endpoint)
|
||||||
|
await increment_usage(endpoint, model)
|
||||||
# 3. Async generator that streams chat data and decrements the counter
|
# 3. Async generator that streams chat data and decrements the counter
|
||||||
async def stream_chat_response():
|
async def stream_chat_response():
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue