diff --git a/router.py b/router.py index e38d98b..5169ce2 100644 --- a/router.py +++ b/router.py @@ -6,7 +6,7 @@ version: 0.3 license: AGPL """ # ------------------------------------------------------------- -import json, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, datetime, random +import json, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, datetime, random, base64 from pathlib import Path from typing import Dict, Set, List, Optional from fastapi import FastAPI, Request, HTTPException @@ -274,6 +274,38 @@ def iso8601_ns(): ) return iso8601_with_ns +def is_base64(image_string): + try: + if isinstance(image_string, str) and base64.b64encode(base64.b64decode(image_string)) == image_string.encode(): + return True + except Exception as e: + return False + +def transform_images_to_data_urls(message_list): + for message in message_list: + if "images" in message: + images = message.pop("images") + if not isinstance(images, list): + continue + new_content = [] + for image in images: #TODO: quality downsize if images are too big to fit into model context window size + if not is_base64(image): + raise ValueError(f"Image string is not a valid base64 encoded string.") + data_url = f"data:image/jpeg;base64,{image}" + #new_content.append({ + # "type": "text", + # "text": "" + #}) + new_content.append({ + "type": "image_url", + "image_url": { + "url": data_url + } + }) + message["content"] = new_content + + return message_list + class rechunk: def openai_chat_completion2ollama(chunk: dict, stream: bool, start_ts: float) -> ollama.ChatResponse: if chunk.choices == [] and chunk.usage is not None: @@ -608,7 +640,7 @@ async def chat_proxy(request: Request): _format = payload.get("format") keep_alive = payload.get("keep_alive") options = payload.get("options") - + if not model: raise HTTPException( status_code=400, detail="Missing required field 'model'" @@ -631,6 +663,8 @@ async def chat_proxy(request: Request): if ":latest" in model: model = model.split(":latest") model = model[0] + if messages: + messages = transform_images_to_data_urls(messages) params = { "messages": messages, "model": model, @@ -1314,7 +1348,6 @@ async def openai_chat_completions_proxy(request: Request): await increment_usage(endpoint, model) base_url = ep2base(endpoint) oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys[endpoint]) - # 3. Async generator that streams completions data and decrements the counter async def stream_ochat_response(): try: @@ -1329,8 +1362,7 @@ async def openai_chat_completions_proxy(request: Request): ) if chunk.choices[0].delta.content is not None: yield f"data: {data}\n\n".encode("utf-8") - # Final DONE event - #yield b"data: [DONE]\n\n" + yield b"data: [DONE]\n\n" else: json_line = ( async_gen.model_dump_json()