proposal: use global truststore ctx for all connections
This commit is contained in:
parent dd30ab9422
commit b649dcd8d6

1 changed file with 38 additions and 35 deletions
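Rationale: instead of constructing a truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT) at every call site and threading it into each OpenAI and Ollama client, the router now calls truststore.inject_into_ssl() once at import time. The injection swaps ssl.SSLContext for truststore's implementation process-wide, so every client created afterwards (the httpx client inside openai.AsyncOpenAI as well as ollama.AsyncClient) verifies server certificates against the operating system trust store without any per-connection wiring. Wrapping the import in try/except ImportError keeps the router working, on the default CA bundle, when truststore is not installed. A minimal sketch of the pattern (illustrative, not part of the diff; key is a placeholder name):

import ssl

try:
    import truststore
    truststore.inject_into_ssl()  # must run before any SSLContext or client is created
except ImportError:
    pass  # without truststore, clients fall back to the default CA bundle

# Old per-connection wiring, removed throughout this commit:
#   ctx = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
#   oclient = openai.AsyncOpenAI(..., http_client=DefaultAsyncHttpxClient(verify=ctx))
#   client = ollama.AsyncClient(host=endpoint, verify=ctx)
#
# New call sites need no TLS arguments at all:
#   oclient = openai.AsyncOpenAI(base_url=endpoint, api_key=key)
#   client = ollama.AsyncClient(host=endpoint)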
router.py (73 changed lines)
@@ -6,8 +6,11 @@ version: 0.6
 license: AGPL
 """
 # -------------------------------------------------------------
-import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, random, base64, io, enhance, secrets, truststore
-from openai import DefaultAsyncHttpxClient
+import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, random, base64, io, enhance, secrets
+try:
+    import truststore; truststore.inject_into_ssl()
+except ImportError:
+    pass
 from datetime import datetime, timezone
 from pathlib import Path
 
@@ -747,7 +750,7 @@ async def _make_chat_request(endpoint: str, model: str, messages: list, tools=No
     Handles endpoint selection, client creation, usage tracking, and request execution.
     """
     is_openai_endpoint = "/v1" in endpoint
-    ctx = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+
     if is_openai_endpoint:
         if ":latest" in model:
             model = model.split(":latest")[0]
@@ -771,9 +774,9 @@ async def _make_chat_request(endpoint: str, model: str, messages: list, tools=No
             "response_format": {"type": "json_schema", "json_schema": format} if format is not None else None
         }
         params.update({k: v for k, v in optional_params.items() if v is not None})
-        oclient = openai.AsyncOpenAI(base_url=endpoint, default_headers=default_headers, api_key=config.api_keys[endpoint], http_client=DefaultAsyncHttpxClient(verify=ctx))
+        oclient = openai.AsyncOpenAI(base_url=endpoint, default_headers=default_headers, api_key=config.api_keys[endpoint])
     else:
-        client = ollama.AsyncClient(host=endpoint, verify=ctx)
+        client = ollama.AsyncClient(host=endpoint)
 
     await increment_usage(endpoint, model)
 
@@ -1264,7 +1267,7 @@ async def proxy(request: Request):
 
     endpoint = await choose_endpoint(model)
     is_openai_endpoint = "/v1" in endpoint
-    ctx = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+
     if is_openai_endpoint:
         if ":latest" in model:
             model = model.split(":latest")
@@ -1286,9 +1289,9 @@ async def proxy(request: Request):
             "suffix": suffix,
         }
         params.update({k: v for k, v in optional_params.items() if v is not None})
-        oclient = openai.AsyncOpenAI(base_url=endpoint, default_headers=default_headers, api_key=config.api_keys[endpoint], http_client=DefaultAsyncHttpxClient(verify=ctx))
+        oclient = openai.AsyncOpenAI(base_url=endpoint, default_headers=default_headers, api_key=config.api_keys[endpoint])
     else:
-        client = ollama.AsyncClient(host=endpoint, verify=ctx)
+        client = ollama.AsyncClient(host=endpoint)
     await increment_usage(endpoint, model)
 
     # 4. Async generator that streams data and decrements the counter
@@ -1384,7 +1387,7 @@ async def chat_proxy(request: Request):
     opt = False
     endpoint = await choose_endpoint(model)
     is_openai_endpoint = "/v1" in endpoint
-    ctx = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+
     if is_openai_endpoint:
         if ":latest" in model:
             model = model.split(":latest")
@@ -1409,9 +1412,9 @@ async def chat_proxy(request: Request):
             "response_format": {"type": "json_schema", "json_schema": _format} if _format is not None else None
         }
         params.update({k: v for k, v in optional_params.items() if v is not None})
-        oclient = openai.AsyncOpenAI(base_url=endpoint, default_headers=default_headers, api_key=config.api_keys[endpoint], http_client=DefaultAsyncHttpxClient(verify=ctx))
+        oclient = openai.AsyncOpenAI(base_url=endpoint, default_headers=default_headers, api_key=config.api_keys[endpoint])
     else:
-        client = ollama.AsyncClient(host=endpoint, verify=ctx)
+        client = ollama.AsyncClient(host=endpoint)
     await increment_usage(endpoint, model)
     # 3. Async generator that streams chat data and decrements the counter
     async def stream_chat_response():
@@ -1501,14 +1504,14 @@ async def embedding_proxy(request: Request):
     # 2. Endpoint logic
     endpoint = await choose_endpoint(model)
     is_openai_endpoint = "/v1" in endpoint
-    ctx = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+
     if is_openai_endpoint:
         if ":latest" in model:
             model = model.split(":latest")
             model = model[0]
-        client = openai.AsyncOpenAI(base_url=endpoint, api_key=config.api_keys[endpoint], http_client=DefaultAsyncHttpxClient(verify=ctx))
+        client = openai.AsyncOpenAI(base_url=endpoint, api_key=config.api_keys[endpoint])
     else:
-        client = ollama.AsyncClient(host=endpoint, verify=ctx)
+        client = ollama.AsyncClient(host=endpoint)
     await increment_usage(endpoint, model)
     # 3. Async generator that streams embedding data and decrements the counter
     async def stream_embedding_response():
@@ -1568,14 +1571,14 @@ async def embed_proxy(request: Request):
     # 2. Endpoint logic
     endpoint = await choose_endpoint(model)
     is_openai_endpoint = is_ext_openai_endpoint(endpoint) #"/v1" in endpoint
-    ctx = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+
     if is_openai_endpoint:
         if ":latest" in model:
             model = model.split(":latest")
             model = model[0]
-        client = openai.AsyncOpenAI(base_url=endpoint, api_key=config.api_keys[endpoint], http_client=DefaultAsyncHttpxClient(verify=ctx))
+        client = openai.AsyncOpenAI(base_url=endpoint, api_key=config.api_keys[endpoint])
     else:
-        client = ollama.AsyncClient(host=endpoint, verify=ctx)
+        client = ollama.AsyncClient(host=endpoint)
     await increment_usage(endpoint, model)
     # 3. Async generator that streams embed data and decrements the counter
     async def stream_embedding_response():
@@ -1636,9 +1639,9 @@ async def create_proxy(request: Request):
         raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
 
     status_lists = []
-    ctx = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+
     for endpoint in config.endpoints:
-        client = ollama.AsyncClient(host=endpoint, verify=ctx)
+        client = ollama.AsyncClient(host=endpoint)
         create = await client.create(model=model, quantize=quantize, from_=from_, files=files, adapters=adapters, template=template, license=license, system=system, parameters=parameters, messages=messages, stream=False)
         status_lists.append(create)
 
@@ -1676,8 +1679,8 @@ async def show_proxy(request: Request, model: Optional[str] = None):
     # 2. Endpoint logic
     endpoint = await choose_endpoint(model)
     #await increment_usage(endpoint, model)
-    ctx = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
-    client = ollama.AsyncClient(host=endpoint, verify=ctx)
+
+    client = ollama.AsyncClient(host=endpoint)
 
     # 3. Proxy a simple show request
     show = await client.show(model=model)
@@ -1810,10 +1813,10 @@ async def copy_proxy(request: Request, source: Optional[str] = None, destination
 
     # 3. Iterate over all endpoints to copy the model on each endpoint
     status_list = []
-    ctx = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+
     for endpoint in config.endpoints:
         if "/v1" not in endpoint:
-            client = ollama.AsyncClient(host=endpoint, verify=ctx)
+            client = ollama.AsyncClient(host=endpoint)
             # 4. Proxy a simple copy request
             copy = await client.copy(source=src, destination=dst)
             status_list.append(copy.status)
@@ -1847,10 +1850,10 @@ async def delete_proxy(request: Request, model: Optional[str] = None):
 
     # 2. Iterate over all endpoints to delete the model on each endpoint
     status_list = []
-    ctx = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+
     for endpoint in config.endpoints:
         if "/v1" not in endpoint:
-            client = ollama.AsyncClient(host=endpoint, verify=ctx)
+            client = ollama.AsyncClient(host=endpoint)
             # 3. Proxy a simple copy request
             copy = await client.delete(model=model)
             status_list.append(copy.status)
@@ -1886,10 +1889,10 @@ async def pull_proxy(request: Request, model: Optional[str] = None):
 
     # 2. Iterate over all endpoints to pull the model
     status_list = []
-    ctx = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+
     for endpoint in config.endpoints:
         if "/v1" not in endpoint:
-            client = ollama.AsyncClient(host=endpoint, verify=ctx)
+            client = ollama.AsyncClient(host=endpoint)
             # 3. Proxy a simple pull request
             pull = await client.pull(model=model, insecure=insecure, stream=False)
             status_list.append(pull)
@@ -1928,9 +1931,9 @@ async def push_proxy(request: Request):
 
     # 2. Iterate over all endpoints
    status_list = []
-    ctx = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+
     for endpoint in config.endpoints:
-        client = ollama.AsyncClient(host=endpoint, verify=ctx)
+        client = ollama.AsyncClient(host=endpoint)
         # 3. Proxy a simple push request
         push = await client.push(model=model, insecure=insecure, stream=False)
         status_list.append(push)
@@ -2131,8 +2134,8 @@ async def openai_embedding_proxy(request: Request):
     else:
         api_key = "ollama"
     base_url = ep2base(endpoint)
-    ctx = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
-    oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=api_key, http_client=DefaultAsyncHttpxClient(verify=ctx))
+
+    oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=api_key)
 
     # 3. Async generator that streams embedding data and decrements the counter
     async_gen = await oclient.embeddings.create(input=doc, model=model)
@@ -2212,8 +2215,8 @@ async def openai_chat_completions_proxy(request: Request):
     endpoint = await choose_endpoint(model)
     await increment_usage(endpoint, model)
     base_url = ep2base(endpoint)
-    ctx = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
-    oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys[endpoint], http_client=DefaultAsyncHttpxClient(verify=ctx))
+
+    oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys[endpoint])
     # 3. Async generator that streams completions data and decrements the counter
     async def stream_ochat_response():
         try:
@@ -2338,8 +2341,8 @@ async def openai_completions_proxy(request: Request):
     endpoint = await choose_endpoint(model)
     await increment_usage(endpoint, model)
     base_url = ep2base(endpoint)
-    ctx = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
-    oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys[endpoint], http_client=DefaultAsyncHttpxClient(verify=ctx))
+
+    oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys[endpoint])
 
     # 3. Async generator that streams completions data and decrements the counter
     async def stream_ocompletions_response(model=model):
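Note on ordering: inject_into_ssl() only affects SSLContext objects created after it runs, which is why the call sits in the import block at the top of router.py rather than inside the request handlers. A quick interpreter check of the injection, assuming truststore's documented behavior of replacing the class in the ssl module:

import ssl
import truststore

truststore.inject_into_ssl()
# After injection, anything that calls ssl.SSLContext(...) gets truststore's
# class and therefore OS trust-store certificate verification.
assert ssl.SSLContext is truststore.SSLContext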