mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 08:46:22 +02:00
278 lines
9.1 KiB
Python
278 lines
9.1 KiB
Python
"""
|
|
Image Generation Router Service for Load Balancing
|
|
|
|
This module provides a singleton LiteLLM Router for automatic load balancing
|
|
across multiple image generation deployments. It uses litellm.Router which
|
|
natively supports aimage_generation() for async image generation.
|
|
|
|
The router handles:
|
|
- Rate limit management with automatic cooldowns
|
|
- Automatic failover and retries
|
|
- Usage-based routing to distribute load evenly
|
|
|
|
Supported providers: OpenAI, Azure, Google AI Studio, Vertex AI,
|
|
AWS Bedrock, Recraft, OpenRouter, Xinference, Nscale.
|
|
"""
|
|
|
|
import logging
|
|
from typing import Any
|
|
|
|
from litellm import Router
|
|
from litellm.utils import ImageResponse
|
|
|
|
# Logger shared by everything in this module.
logger = logging.getLogger(__name__)

# Sentinel config ID that selects "Auto mode": requests are sent through the
# load-balancing router rather than to one specific configuration.
IMAGE_GEN_AUTO_MODE_ID = 0

# Upper-case provider identifier -> LiteLLM model-string prefix.
# Only providers that LiteLLM supports for image generation are listed; see
# https://docs.litellm.ai/docs/image_generation#supported-providers
IMAGE_GEN_PROVIDER_MAP = dict(
    OPENAI="openai",
    AZURE_OPENAI="azure",
    GOOGLE="gemini",  # Google AI Studio
    VERTEX_AI="vertex_ai",
    BEDROCK="bedrock",  # AWS Bedrock
    RECRAFT="recraft",
    OPENROUTER="openrouter",
    XINFERENCE="xinference",
    NSCALE="nscale",
)
|
|
|
|
|
|
class ImageGenRouterService:
    """
    Singleton service that owns the LiteLLM Router used for image generation.

    Every valid global image generation config is registered under the shared
    model alias ``"auto"``. This lets ``Router.aimage_generation()`` load
    balance, fail over and apply cooldowns across all configured deployments.

    State lives on the single instance returned by ``get_instance()``. The
    public entry points are classmethods that delegate to that instance.
    """

    _instance = None
    _router: Router | None = None
    # Class-level defaults only. ``initialize`` rebinds these on the instance
    # and never mutates them in place, so sharing the empty containers is safe.
    _model_list: list[dict] = []
    _router_settings: dict = {}
    _initialized: bool = False

    def __new__(cls):
        # Always return the same object, so that ``ImageGenRouterService()``
        # and ``get_instance()`` observe identical state.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @classmethod
    def get_instance(cls) -> "ImageGenRouterService":
        """Get the singleton instance of the router service, creating it if needed."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    @classmethod
    def initialize(
        cls,
        global_configs: list[dict],
        router_settings: dict | None = None,
    ) -> None:
        """
        Initialize the router with global image generation configurations.

        Once a router has been created successfully, later calls are no-ops.
        If no config yields a valid deployment, or Router construction fails,
        the service stays uninitialized and a later call may retry.

        Args:
            global_configs: List of global image gen config dictionaries
                (e.g. loaded from YAML).
            router_settings: Optional overrides for ``routing_strategy``,
                ``num_retries``, ``allowed_fails``, ``cooldown_time`` and
                ``retry_after``. Any other keys are stored but not forwarded
                to the Router.
        """
        instance = cls.get_instance()

        if instance._initialized:
            logger.debug("Image Generation Router already initialized, skipping")
            return

        # Build the deployment list; invalid configs come back as None and are dropped.
        model_list = [
            deployment
            for config in global_configs
            if (deployment := cls._config_to_deployment(config))
        ]

        if not model_list:
            logger.warning(
                "No valid image generation configs found for router initialization"
            )
            return

        settings = dict(router_settings or {})

        # Default router settings optimized for rate limit handling.
        # Caller-provided keys take precedence.
        default_settings = {
            "routing_strategy": "usage-based-routing",
            "num_retries": 3,
            "allowed_fails": 3,
            "cooldown_time": 60,
            "retry_after": 5,
        }
        final_settings = {**default_settings, **settings}

        try:
            router = Router(
                model_list=model_list,
                routing_strategy=final_settings["routing_strategy"],
                num_retries=final_settings["num_retries"],
                allowed_fails=final_settings["allowed_fails"],
                cooldown_time=final_settings["cooldown_time"],
                # Previously declared in the defaults but never passed through.
                retry_after=final_settings["retry_after"],
                set_verbose=False,
            )
        except Exception:
            # logger.exception keeps the traceback for diagnosing bad configs.
            logger.exception("Failed to initialize Image Generation Router")
            instance._router = None
            return

        # Only record state once a router actually exists. That way
        # get_model_count() never reports deployments for a router that
        # failed to build.
        instance._router = router
        instance._model_list = model_list
        instance._router_settings = settings
        instance._initialized = True
        logger.info(
            "Image Generation Router initialized with %d deployments, strategy: %s",
            len(model_list),
            final_settings["routing_strategy"],
        )

    @classmethod
    def _config_to_deployment(cls, config: dict) -> dict | None:
        """
        Convert a global image gen config to a router deployment entry.

        Args:
            config: Global image gen config dictionary. Reads ``model_name``,
                ``api_key``, ``custom_provider``, ``provider``, ``api_base``,
                ``api_version``, ``litellm_params`` and ``rpm``.

        Returns:
            Router deployment dictionary (aliased as ``"auto"``), or None if
            the config is missing required fields or cannot be converted.
        """
        try:
            model_name = config.get("model_name")
            api_key = config.get("api_key")

            # Skip if essential fields are missing. Log the skip instead of
            # dropping the entry silently; the key itself is never logged.
            if not model_name or not api_key:
                logger.warning(
                    "Skipping image generation config for model %r: "
                    "missing model_name or api_key",
                    model_name,
                )
                return None

            # Resolve the LiteLLM provider prefix.
            custom_provider = config.get("custom_provider")
            if custom_provider:
                prefix = custom_provider
            else:
                # `or ""` also covers an explicit `provider: null` in the config.
                provider = (config.get("provider") or "").upper()
                prefix = IMAGE_GEN_PROVIDER_MAP.get(provider, provider.lower())

            # With no provider at all, the old code produced "/<model>", which
            # is not a valid LiteLLM model string. Use the bare model name
            # instead, so LiteLLM can attempt to infer the provider.
            model_string = f"{prefix}/{model_name}" if prefix else model_name

            litellm_params: dict[str, Any] = {
                "model": model_string,
                "api_key": api_key,
            }

            # Optional connection settings; api_version is required for Azure.
            for optional_key in ("api_base", "api_version"):
                if config.get(optional_key):
                    litellm_params[optional_key] = config[optional_key]

            # Extra LiteLLM parameters are applied last, so they can override the above.
            if config.get("litellm_params"):
                litellm_params.update(config["litellm_params"])

            # All configs share the alias "auto" for unified routing.
            deployment: dict[str, Any] = {
                "model_name": "auto",
                "litellm_params": litellm_params,
            }

            # RPM limit only. TPM (tokens per minute) does not apply to image
            # generation, because image APIs are rate-limited by requests.
            if config.get("rpm"):
                deployment["rpm"] = config["rpm"]

            return deployment

        except Exception as e:
            # Best effort: one malformed config must not block the others.
            logger.warning("Failed to convert image gen config to deployment: %s", e)
            return None

    @classmethod
    def get_router(cls) -> Router | None:
        """Get the initialized router instance, or None if there is none."""
        instance = cls.get_instance()
        return instance._router

    @classmethod
    def is_initialized(cls) -> bool:
        """Check if the router has been initialized and is available."""
        instance = cls.get_instance()
        return instance._initialized and instance._router is not None

    @classmethod
    def get_model_count(cls) -> int:
        """Get the number of deployments registered in the router (0 if none)."""
        instance = cls.get_instance()
        return len(instance._model_list)

    @classmethod
    async def aimage_generation(
        cls,
        prompt: str,
        model: str = "auto",
        n: int | None = None,
        timeout: int = 600,
        **kwargs,
    ) -> ImageResponse:
        """
        Generate images using the router for load balancing.

        Uses Router.aimage_generation(), which distributes requests across
        the configured image generation deployments.

        Parameters like size, quality, style and response_format are
        intentionally omitted to keep the interface model-agnostic. Providers
        use their own defaults. If needed, pass them via **kwargs.

        Args:
            prompt: Text description of the desired image(s).
            model: Model alias (default "auto" for router routing).
            n: Number of images to generate; omitted from the request when None.
            timeout: Request timeout in seconds.
            **kwargs: Additional provider-specific params (size, quality, etc.).
                These are applied last and can override the arguments above.

        Returns:
            ImageResponse from litellm.

        Raises:
            ValueError: If the router is not initialized.
        """
        instance = cls.get_instance()
        if not instance._router:
            raise ValueError(
                "Image Generation Router not initialized. "
                "Ensure global_llm_config.yaml has global_image_generation_configs."
            )

        gen_kwargs: dict[str, Any] = {
            "model": model,
            "prompt": prompt,
            "timeout": timeout,
        }
        if n is not None:
            gen_kwargs["n"] = n
        gen_kwargs.update(kwargs)

        return await instance._router.aimage_generation(**gen_kwargs)
|
|
|
|
|
|
def is_image_gen_auto_mode(config_id: int | None) -> bool:
    """
    Tell whether a config ID selects Image Generation Auto mode.

    Args:
        config_id: Config ID to inspect; None means "no config selected".

    Returns:
        True only when config_id equals IMAGE_GEN_AUTO_MODE_ID, else False.
    """
    # None can never be the Auto-mode sentinel.
    if config_id is None:
        return False
    return config_id == IMAGE_GEN_AUTO_MODE_ID
|