mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-08 20:25:19 +02:00
chore: made generate_image more agnostic
This commit is contained in:
parent
19e2857343
commit
f85adefe5e
15 changed files with 176 additions and 112 deletions
|
|
@ -60,8 +60,15 @@ def upgrade() -> None:
|
|||
sa.Column(
|
||||
"provider",
|
||||
sa.Enum(
|
||||
"OPENAI", "AZURE_OPENAI", "GOOGLE", "VERTEX_AI", "BEDROCK",
|
||||
"RECRAFT", "OPENROUTER", "XINFERENCE", "NSCALE",
|
||||
"OPENAI",
|
||||
"AZURE_OPENAI",
|
||||
"GOOGLE",
|
||||
"VERTEX_AI",
|
||||
"BEDROCK",
|
||||
"RECRAFT",
|
||||
"OPENROUTER",
|
||||
"XINFERENCE",
|
||||
"NSCALE",
|
||||
name="imagegenprovider",
|
||||
create_type=False,
|
||||
),
|
||||
|
|
|
|||
|
|
@ -106,8 +106,6 @@ You have access to the following tools:
|
|||
- Trigger phrases: "generate an image of", "create a picture of", "draw me", "make an image", "design a logo", "create artwork"
|
||||
- Args:
|
||||
- prompt: A detailed text description of the image to generate. Be specific about subject, style, colors, composition, and mood.
|
||||
- size: Image size. Options: "1024x1024" (square, default), "1536x1024" (landscape), "1024x1536" (portrait), "1792x1024" (wide)
|
||||
- quality: Image quality. Options: "auto" (default), "high", "medium", "low"
|
||||
- n: Number of images to generate (1-4, default: 1)
|
||||
- Returns: A dictionary with the generated image URL in the "src" field, along with metadata.
|
||||
- CRITICAL: After calling generate_image, you MUST call `display_image` with the returned "src" URL
|
||||
|
|
@ -300,19 +298,19 @@ You have access to the following tools:
|
|||
- Then provide your explanation, referencing the displayed image
|
||||
|
||||
- User: "Generate an image of a cat"
|
||||
- Step 1: `generate_image(prompt="A fluffy orange tabby cat sitting on a windowsill, bathed in warm golden sunlight, soft bokeh background with green houseplants, photorealistic style, cozy atmosphere", size="1024x1024", quality="auto")`
|
||||
- Step 1: `generate_image(prompt="A fluffy orange tabby cat sitting on a windowsill, bathed in warm golden sunlight, soft bokeh background with green houseplants, photorealistic style, cozy atmosphere")`
|
||||
- Step 2: Use the returned "src" URL to display it: `display_image(src="<returned_url>", alt="A fluffy orange tabby cat on a windowsill", title="Generated Image")`
|
||||
|
||||
- User: "Create a landscape painting of mountains"
|
||||
- Step 1: `generate_image(prompt="Majestic snow-capped mountain range at sunset, dramatic orange and purple sky, alpine meadow with wildflowers in the foreground, oil painting style with visible brushstrokes, inspired by the Hudson River School art movement", size="1536x1024", quality="high")`
|
||||
- Step 1: `generate_image(prompt="Majestic snow-capped mountain range at sunset, dramatic orange and purple sky, alpine meadow with wildflowers in the foreground, oil painting style with visible brushstrokes, inspired by the Hudson River School art movement")`
|
||||
- Step 2: `display_image(src="<returned_url>", alt="Mountain landscape painting", title="Generated Image")`
|
||||
|
||||
- User: "Draw me a logo for a coffee shop called Bean Dream"
|
||||
- Step 1: `generate_image(prompt="Minimalist modern logo design for a coffee shop called 'Bean Dream', featuring a stylized coffee bean with dream-like swirls of steam, clean vector style, warm brown and cream color palette, white background, professional branding", size="1024x1024", quality="high")`
|
||||
- Step 1: `generate_image(prompt="Minimalist modern logo design for a coffee shop called 'Bean Dream', featuring a stylized coffee bean with dream-like swirls of steam, clean vector style, warm brown and cream color palette, white background, professional branding")`
|
||||
- Step 2: `display_image(src="<returned_url>", alt="Bean Dream coffee shop logo", title="Generated Image")`
|
||||
|
||||
- User: "Make a wide banner image for my blog about AI"
|
||||
- Step 1: `generate_image(prompt="Wide banner illustration for an AI technology blog, featuring abstract neural network patterns, glowing blue and purple connections, modern futuristic aesthetic, digital art style, clean and professional", size="1792x1024", quality="high")`
|
||||
- Step 1: `generate_image(prompt="Wide banner illustration for an AI technology blog, featuring abstract neural network patterns, glowing blue and purple connections, modern futuristic aesthetic, digital art style, clean and professional")`
|
||||
- Step 2: `display_image(src="<returned_url>", alt="AI blog banner", title="Generated Image")`
|
||||
</tool_call_examples>
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -21,12 +21,12 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.config import config
|
||||
from app.db import ImageGeneration, ImageGenerationConfig, SearchSpace
|
||||
from app.utils.signed_image_urls import generate_image_token
|
||||
from app.services.image_gen_router_service import (
|
||||
IMAGE_GEN_AUTO_MODE_ID,
|
||||
ImageGenRouterService,
|
||||
is_image_gen_auto_mode,
|
||||
)
|
||||
from app.utils.signed_image_urls import generate_image_token
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -76,8 +76,6 @@ def create_generate_image_tool(
|
|||
@tool
|
||||
async def generate_image(
|
||||
prompt: str,
|
||||
size: str = "1024x1024",
|
||||
quality: str = "auto",
|
||||
n: int = 1,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
|
|
@ -89,10 +87,6 @@ def create_generate_image_tool(
|
|||
Args:
|
||||
prompt: A detailed text description of the image to generate.
|
||||
Be specific about subject, style, colors, composition, and mood.
|
||||
size: Image size. Options: "1024x1024" (square), "1536x1024" (landscape),
|
||||
"1024x1536" (portrait), "1792x1024" (wide). Default: "1024x1024"
|
||||
quality: Image quality. Options: "auto" (default), "high", "medium", "low".
|
||||
Default: "auto"
|
||||
n: Number of images to generate (1-4). Default: 1
|
||||
|
||||
Returns:
|
||||
|
|
@ -112,18 +106,14 @@ def create_generate_image_tool(
|
|||
)
|
||||
|
||||
# Build generation kwargs
|
||||
# NOTE: 'style' is intentionally excluded from gen_kwargs because
|
||||
# it is only supported by DALL-E 3 and causes errors with other
|
||||
# models (e.g. gpt-image-1 rejects it as an unknown parameter).
|
||||
# Since we can't predict which model auto-mode will route to,
|
||||
# it's safest to omit it.
|
||||
# NOTE: size, quality, and style are intentionally NOT passed.
|
||||
# Different models support different values for these params
|
||||
# (e.g. DALL-E 3 wants "hd"/"standard" for quality while
|
||||
# gpt-image-1 wants "high"/"medium"/"low"; size options also
|
||||
# differ). Letting the model use its own defaults avoids errors.
|
||||
gen_kwargs: dict[str, Any] = {}
|
||||
if n is not None and n > 1:
|
||||
gen_kwargs["n"] = n
|
||||
if quality:
|
||||
gen_kwargs["quality"] = quality
|
||||
if size:
|
||||
gen_kwargs["size"] = size
|
||||
|
||||
# Call litellm based on config type
|
||||
if is_image_gen_auto_mode(config_id):
|
||||
|
|
@ -199,8 +189,6 @@ def create_generate_image_tool(
|
|||
prompt=prompt,
|
||||
model=getattr(response, "_hidden_params", {}).get("model"),
|
||||
n=n,
|
||||
quality=quality,
|
||||
size=size,
|
||||
image_generation_config_id=config_id,
|
||||
response_data=response_dict,
|
||||
search_space_id=search_space_id,
|
||||
|
|
|
|||
|
|
@ -108,7 +108,9 @@ class AirtableHistoryConnector:
|
|||
|
||||
# Final validation after decryption
|
||||
final_token = config_data.get("access_token")
|
||||
if not final_token or (isinstance(final_token, str) and not final_token.strip()):
|
||||
if not final_token or (
|
||||
isinstance(final_token, str) and not final_token.strip()
|
||||
):
|
||||
raise ValueError(
|
||||
"Airtable access token is invalid or empty. "
|
||||
"Please reconnect your Airtable account."
|
||||
|
|
|
|||
|
|
@ -128,7 +128,9 @@ class ConfluenceHistoryConnector:
|
|||
|
||||
# Final validation after decryption
|
||||
final_token = config_data.get("access_token")
|
||||
if not final_token or (isinstance(final_token, str) and not final_token.strip()):
|
||||
if not final_token or (
|
||||
isinstance(final_token, str) and not final_token.strip()
|
||||
):
|
||||
raise ValueError(
|
||||
"Confluence access token is invalid or empty. "
|
||||
"Please reconnect your Confluence account."
|
||||
|
|
|
|||
|
|
@ -129,7 +129,9 @@ class JiraHistoryConnector:
|
|||
|
||||
# Final validation after decryption
|
||||
final_token = config_data.get("access_token")
|
||||
if not final_token or (isinstance(final_token, str) and not final_token.strip()):
|
||||
if not final_token or (
|
||||
isinstance(final_token, str) and not final_token.strip()
|
||||
):
|
||||
raise ValueError(
|
||||
"Jira access token is invalid or empty. "
|
||||
"Please reconnect your Jira account."
|
||||
|
|
|
|||
|
|
@ -153,7 +153,9 @@ class LinearConnector:
|
|||
|
||||
# Final validation after decryption
|
||||
final_token = config_data.get("access_token")
|
||||
if not final_token or (isinstance(final_token, str) and not final_token.strip()):
|
||||
if not final_token or (
|
||||
isinstance(final_token, str) and not final_token.strip()
|
||||
):
|
||||
raise ValueError(
|
||||
"Linear access token is invalid or empty. "
|
||||
"Please reconnect your Linear account."
|
||||
|
|
|
|||
|
|
@ -100,7 +100,6 @@ class PodcastStatus(str, Enum):
|
|||
FAILED = "failed"
|
||||
|
||||
|
||||
|
||||
class LiteLLMProvider(str, Enum):
|
||||
"""
|
||||
Enum for LLM providers supported by LiteLLM.
|
||||
|
|
@ -941,7 +940,9 @@ class ImageGenerationConfig(BaseModel, TimestampMixin):
|
|||
search_space_id = Column(
|
||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
search_space = relationship("SearchSpace", back_populates="image_generation_configs")
|
||||
search_space = relationship(
|
||||
"SearchSpace", back_populates="image_generation_configs"
|
||||
)
|
||||
|
||||
|
||||
class ImageGeneration(BaseModel, TimestampMixin):
|
||||
|
|
@ -973,9 +974,7 @@ class ImageGeneration(BaseModel, TimestampMixin):
|
|||
String(50), nullable=True
|
||||
) # "1024x1024", "1536x1024", "1024x1536", etc.
|
||||
style = Column(String(50), nullable=True) # Model-specific style parameter
|
||||
response_format = Column(
|
||||
String(50), nullable=True
|
||||
) # "url" or "b64_json"
|
||||
response_format = Column(String(50), nullable=True) # "url" or "b64_json"
|
||||
|
||||
# Image generation config reference
|
||||
# 0 = Auto mode (router), negative IDs = global configs from YAML,
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ from .confluence_add_connector_route import router as confluence_add_connector_r
|
|||
from .discord_add_connector_route import router as discord_add_connector_router
|
||||
from .documents_routes import router as documents_router
|
||||
from .editor_routes import router as editor_router
|
||||
from .image_generation_routes import router as image_generation_router
|
||||
from .google_calendar_add_connector_route import (
|
||||
router as google_calendar_add_connector_router,
|
||||
)
|
||||
|
|
@ -21,6 +20,7 @@ from .google_drive_add_connector_route import (
|
|||
from .google_gmail_add_connector_route import (
|
||||
router as google_gmail_add_connector_router,
|
||||
)
|
||||
from .image_generation_routes import router as image_generation_router
|
||||
from .incentive_tasks_routes import router as incentive_tasks_router
|
||||
from .jira_add_connector_route import router as jira_add_connector_router
|
||||
from .linear_add_connector_route import router as linear_add_connector_router
|
||||
|
|
|
|||
|
|
@ -54,9 +54,9 @@ logger = logging.getLogger(__name__)
|
|||
_PROVIDER_MAP = {
|
||||
"OPENAI": "openai",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"GOOGLE": "gemini", # Google AI Studio
|
||||
"GOOGLE": "gemini", # Google AI Studio
|
||||
"VERTEX_AI": "vertex_ai",
|
||||
"BEDROCK": "bedrock", # AWS Bedrock
|
||||
"BEDROCK": "bedrock", # AWS Bedrock
|
||||
"RECRAFT": "recraft",
|
||||
"OPENROUTER": "openrouter",
|
||||
"XINFERENCE": "xinference",
|
||||
|
|
@ -82,7 +82,9 @@ def _get_global_image_gen_config(config_id: int) -> dict | None:
|
|||
return None
|
||||
|
||||
|
||||
def _build_model_string(provider: str, model_name: str, custom_provider: str | None) -> str:
|
||||
def _build_model_string(
|
||||
provider: str, model_name: str, custom_provider: str | None
|
||||
) -> str:
|
||||
"""Build a litellm model string from provider + model_name."""
|
||||
if custom_provider:
|
||||
return f"{custom_provider}/{model_name}"
|
||||
|
|
@ -210,38 +212,44 @@ async def get_global_image_gen_configs(
|
|||
safe_configs = []
|
||||
|
||||
if global_configs and len(global_configs) > 0:
|
||||
safe_configs.append({
|
||||
"id": 0,
|
||||
"name": "Auto (Load Balanced)",
|
||||
"description": "Automatically routes across available image generation providers.",
|
||||
"provider": "AUTO",
|
||||
"custom_provider": None,
|
||||
"model_name": "auto",
|
||||
"api_base": None,
|
||||
"api_version": None,
|
||||
"litellm_params": {},
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
})
|
||||
safe_configs.append(
|
||||
{
|
||||
"id": 0,
|
||||
"name": "Auto (Load Balanced)",
|
||||
"description": "Automatically routes across available image generation providers.",
|
||||
"provider": "AUTO",
|
||||
"custom_provider": None,
|
||||
"model_name": "auto",
|
||||
"api_base": None,
|
||||
"api_version": None,
|
||||
"litellm_params": {},
|
||||
"is_global": True,
|
||||
"is_auto_mode": True,
|
||||
}
|
||||
)
|
||||
|
||||
for cfg in global_configs:
|
||||
safe_configs.append({
|
||||
"id": cfg.get("id"),
|
||||
"name": cfg.get("name"),
|
||||
"description": cfg.get("description"),
|
||||
"provider": cfg.get("provider"),
|
||||
"custom_provider": cfg.get("custom_provider"),
|
||||
"model_name": cfg.get("model_name"),
|
||||
"api_base": cfg.get("api_base") or None,
|
||||
"api_version": cfg.get("api_version") or None,
|
||||
"litellm_params": cfg.get("litellm_params", {}),
|
||||
"is_global": True,
|
||||
})
|
||||
safe_configs.append(
|
||||
{
|
||||
"id": cfg.get("id"),
|
||||
"name": cfg.get("name"),
|
||||
"description": cfg.get("description"),
|
||||
"provider": cfg.get("provider"),
|
||||
"custom_provider": cfg.get("custom_provider"),
|
||||
"model_name": cfg.get("model_name"),
|
||||
"api_base": cfg.get("api_base") or None,
|
||||
"api_version": cfg.get("api_version") or None,
|
||||
"litellm_params": cfg.get("litellm_params", {}),
|
||||
"is_global": True,
|
||||
}
|
||||
)
|
||||
|
||||
return safe_configs
|
||||
except Exception as e:
|
||||
logger.exception("Failed to fetch global image generation configs")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to fetch configs: {e!s}") from e
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch configs: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
|
@ -258,7 +266,9 @@ async def create_image_gen_config(
|
|||
"""Create a new image generation config for a search space."""
|
||||
try:
|
||||
await check_permission(
|
||||
session, user, config_data.search_space_id,
|
||||
session,
|
||||
user,
|
||||
config_data.search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_CREATE.value,
|
||||
"You don't have permission to create image generation configs in this search space",
|
||||
)
|
||||
|
|
@ -274,7 +284,9 @@ async def create_image_gen_config(
|
|||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.exception("Failed to create ImageGenerationConfig")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to create config: {e!s}") from e
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to create config: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/image-generation-configs", response_model=list[ImageGenerationConfigRead])
|
||||
|
|
@ -288,7 +300,9 @@ async def list_image_gen_configs(
|
|||
"""List image generation configs for a search space."""
|
||||
try:
|
||||
await check_permission(
|
||||
session, user, search_space_id,
|
||||
session,
|
||||
user,
|
||||
search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_READ.value,
|
||||
"You don't have permission to view image generation configs in this search space",
|
||||
)
|
||||
|
|
@ -306,10 +320,14 @@ async def list_image_gen_configs(
|
|||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Failed to list ImageGenerationConfigs")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to fetch configs: {e!s}") from e
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch configs: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/image-generation-configs/{config_id}", response_model=ImageGenerationConfigRead)
|
||||
@router.get(
|
||||
"/image-generation-configs/{config_id}", response_model=ImageGenerationConfigRead
|
||||
)
|
||||
async def get_image_gen_config(
|
||||
config_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
|
|
@ -325,7 +343,9 @@ async def get_image_gen_config(
|
|||
raise HTTPException(status_code=404, detail="Config not found")
|
||||
|
||||
await check_permission(
|
||||
session, user, db_config.search_space_id,
|
||||
session,
|
||||
user,
|
||||
db_config.search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_READ.value,
|
||||
"You don't have permission to view image generation configs in this search space",
|
||||
)
|
||||
|
|
@ -335,10 +355,14 @@ async def get_image_gen_config(
|
|||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Failed to get ImageGenerationConfig")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to fetch config: {e!s}") from e
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch config: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.put("/image-generation-configs/{config_id}", response_model=ImageGenerationConfigRead)
|
||||
@router.put(
|
||||
"/image-generation-configs/{config_id}", response_model=ImageGenerationConfigRead
|
||||
)
|
||||
async def update_image_gen_config(
|
||||
config_id: int,
|
||||
update_data: ImageGenerationConfigUpdate,
|
||||
|
|
@ -355,7 +379,9 @@ async def update_image_gen_config(
|
|||
raise HTTPException(status_code=404, detail="Config not found")
|
||||
|
||||
await check_permission(
|
||||
session, user, db_config.search_space_id,
|
||||
session,
|
||||
user,
|
||||
db_config.search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_CREATE.value,
|
||||
"You don't have permission to update image generation configs in this search space",
|
||||
)
|
||||
|
|
@ -372,7 +398,9 @@ async def update_image_gen_config(
|
|||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.exception("Failed to update ImageGenerationConfig")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to update config: {e!s}") from e
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to update config: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.delete("/image-generation-configs/{config_id}", response_model=dict)
|
||||
|
|
@ -391,21 +419,28 @@ async def delete_image_gen_config(
|
|||
raise HTTPException(status_code=404, detail="Config not found")
|
||||
|
||||
await check_permission(
|
||||
session, user, db_config.search_space_id,
|
||||
session,
|
||||
user,
|
||||
db_config.search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_DELETE.value,
|
||||
"You don't have permission to delete image generation configs in this search space",
|
||||
)
|
||||
|
||||
await session.delete(db_config)
|
||||
await session.commit()
|
||||
return {"message": "Image generation config deleted successfully", "id": config_id}
|
||||
return {
|
||||
"message": "Image generation config deleted successfully",
|
||||
"id": config_id,
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.exception("Failed to delete ImageGenerationConfig")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to delete config: {e!s}") from e
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to delete config: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
|
@ -422,7 +457,9 @@ async def create_image_generation(
|
|||
"""Create and execute an image generation request."""
|
||||
try:
|
||||
await check_permission(
|
||||
session, user, data.search_space_id,
|
||||
session,
|
||||
user,
|
||||
data.search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_CREATE.value,
|
||||
"You don't have permission to create image generations in this search space",
|
||||
)
|
||||
|
|
@ -463,11 +500,15 @@ async def create_image_generation(
|
|||
raise
|
||||
except SQLAlchemyError:
|
||||
await session.rollback()
|
||||
raise HTTPException(status_code=500, detail="Database error during image generation") from None
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Database error during image generation"
|
||||
) from None
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.exception("Failed to create image generation")
|
||||
raise HTTPException(status_code=500, detail=f"Image generation failed: {e!s}") from e
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Image generation failed: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/image-generations", response_model=list[ImageGenerationListRead])
|
||||
|
|
@ -487,7 +528,9 @@ async def list_image_generations(
|
|||
try:
|
||||
if search_space_id is not None:
|
||||
await check_permission(
|
||||
session, user, search_space_id,
|
||||
session,
|
||||
user,
|
||||
search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_READ.value,
|
||||
"You don't have permission to read image generations in this search space",
|
||||
)
|
||||
|
|
@ -495,7 +538,8 @@ async def list_image_generations(
|
|||
select(ImageGeneration)
|
||||
.filter(ImageGeneration.search_space_id == search_space_id)
|
||||
.order_by(ImageGeneration.created_at.desc())
|
||||
.offset(skip).limit(limit)
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
else:
|
||||
result = await session.execute(
|
||||
|
|
@ -504,15 +548,21 @@ async def list_image_generations(
|
|||
.join(SearchSpaceMembership)
|
||||
.filter(SearchSpaceMembership.user_id == user.id)
|
||||
.order_by(ImageGeneration.created_at.desc())
|
||||
.offset(skip).limit(limit)
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
return [ImageGenerationListRead.from_orm_with_count(img) for img in result.scalars().all()]
|
||||
return [
|
||||
ImageGenerationListRead.from_orm_with_count(img)
|
||||
for img in result.scalars().all()
|
||||
]
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(status_code=500, detail="Database error fetching image generations") from None
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Database error fetching image generations"
|
||||
) from None
|
||||
|
||||
|
||||
@router.get("/image-generations/{image_gen_id}", response_model=ImageGenerationRead)
|
||||
|
|
@ -531,7 +581,9 @@ async def get_image_generation(
|
|||
raise HTTPException(status_code=404, detail="Image generation not found")
|
||||
|
||||
await check_permission(
|
||||
session, user, image_gen.search_space_id,
|
||||
session,
|
||||
user,
|
||||
image_gen.search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_READ.value,
|
||||
"You don't have permission to read image generations in this search space",
|
||||
)
|
||||
|
|
@ -540,7 +592,9 @@ async def get_image_generation(
|
|||
except HTTPException:
|
||||
raise
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(status_code=500, detail="Database error fetching image generation") from None
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Database error fetching image generation"
|
||||
) from None
|
||||
|
||||
|
||||
@router.delete("/image-generations/{image_gen_id}", response_model=dict)
|
||||
|
|
@ -559,7 +613,9 @@ async def delete_image_generation(
|
|||
raise HTTPException(status_code=404, detail="Image generation not found")
|
||||
|
||||
await check_permission(
|
||||
session, user, db_image_gen.search_space_id,
|
||||
session,
|
||||
user,
|
||||
db_image_gen.search_space_id,
|
||||
Permission.IMAGE_GENERATIONS_DELETE.value,
|
||||
"You don't have permission to delete image generations in this search space",
|
||||
)
|
||||
|
|
@ -572,13 +628,16 @@ async def delete_image_generation(
|
|||
raise
|
||||
except SQLAlchemyError:
|
||||
await session.rollback()
|
||||
raise HTTPException(status_code=500, detail="Database error deleting image generation") from None
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Database error deleting image generation"
|
||||
) from None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Image Serving (serves generated images from DB, protected by signed tokens)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.get("/image-generations/{image_gen_id}/image")
|
||||
async def serve_generated_image(
|
||||
image_gen_id: int,
|
||||
|
|
@ -616,13 +675,16 @@ async def serve_generated_image(
|
|||
|
||||
images = image_gen.response_data.get("data", [])
|
||||
if not images or index >= len(images):
|
||||
raise HTTPException(status_code=404, detail="Image not found at the specified index")
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Image not found at the specified index"
|
||||
)
|
||||
|
||||
image_entry = images[index]
|
||||
|
||||
# If there's a URL, redirect to it
|
||||
if image_entry.get("url"):
|
||||
from fastapi.responses import RedirectResponse
|
||||
|
||||
return RedirectResponse(url=image_entry["url"])
|
||||
|
||||
# If there's b64_json data, decode and serve it
|
||||
|
|
@ -643,4 +705,6 @@ async def serve_generated_image(
|
|||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Failed to serve generated image")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to serve image: {e!s}") from e
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to serve image: {e!s}"
|
||||
) from e
|
||||
|
|
|
|||
|
|
@ -113,10 +113,12 @@ __all__ = [
|
|||
"DriveItem",
|
||||
"ExtensionDocumentContent",
|
||||
"ExtensionDocumentMetadata",
|
||||
"GlobalImageGenConfigRead",
|
||||
"GlobalNewLLMConfigRead",
|
||||
"GoogleDriveIndexRequest",
|
||||
"GoogleDriveIndexingOptions",
|
||||
"GlobalImageGenConfigRead",
|
||||
# Base schemas
|
||||
"IDModel",
|
||||
# Image Generation Config schemas
|
||||
"ImageGenerationConfigCreate",
|
||||
"ImageGenerationConfigPublic",
|
||||
|
|
@ -126,8 +128,6 @@ __all__ = [
|
|||
"ImageGenerationCreate",
|
||||
"ImageGenerationListRead",
|
||||
"ImageGenerationRead",
|
||||
# Base schemas
|
||||
"IDModel",
|
||||
# RBAC schemas
|
||||
"InviteAcceptRequest",
|
||||
"InviteAcceptResponse",
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ from pydantic import BaseModel, ConfigDict, Field
|
|||
|
||||
from app.db import ImageGenProvider
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# ImageGenerationConfig CRUD Schemas
|
||||
# =============================================================================
|
||||
|
|
|
|||
|
|
@ -31,9 +31,9 @@ IMAGE_GEN_AUTO_MODE_ID = 0
|
|||
IMAGE_GEN_PROVIDER_MAP = {
|
||||
"OPENAI": "openai",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"GOOGLE": "gemini", # Google AI Studio
|
||||
"GOOGLE": "gemini", # Google AI Studio
|
||||
"VERTEX_AI": "vertex_ai",
|
||||
"BEDROCK": "bedrock", # AWS Bedrock
|
||||
"BEDROCK": "bedrock", # AWS Bedrock
|
||||
"RECRAFT": "recraft",
|
||||
"OPENROUTER": "openrouter",
|
||||
"XINFERENCE": "xinference",
|
||||
|
|
@ -156,9 +156,7 @@ class ImageGenRouterService:
|
|||
model_string = f"{config['custom_provider']}/{config['model_name']}"
|
||||
else:
|
||||
provider = config.get("provider", "").upper()
|
||||
provider_prefix = IMAGE_GEN_PROVIDER_MAP.get(
|
||||
provider, provider.lower()
|
||||
)
|
||||
provider_prefix = IMAGE_GEN_PROVIDER_MAP.get(provider, provider.lower())
|
||||
model_string = f"{provider_prefix}/{config['model_name']}"
|
||||
|
||||
# Build litellm params
|
||||
|
|
@ -194,9 +192,7 @@ class ImageGenRouterService:
|
|||
return deployment
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to convert image gen config to deployment: {e}"
|
||||
)
|
||||
logger.warning(f"Failed to convert image gen config to deployment: {e}")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -27,12 +27,12 @@ from app.agents.new_chat.llm_config import (
|
|||
load_llm_config_from_yaml,
|
||||
)
|
||||
from app.db import Document, SurfsenseDocsDocument
|
||||
from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
from app.schemas.new_chat import ChatAttachment
|
||||
from app.services.chat_session_state_service import (
|
||||
clear_ai_responding,
|
||||
set_ai_responding,
|
||||
)
|
||||
from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
from app.services.connector_service import ConnectorService
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
from app.utils.content_utils import bootstrap_history_from_db
|
||||
|
|
@ -1211,9 +1211,10 @@ async def stream_new_chat(
|
|||
|
||||
# Generate LLM title for new chats after first response
|
||||
# Check if this is the first assistant response by counting existing assistant messages
|
||||
from app.db import NewChatMessage, NewChatThread
|
||||
from sqlalchemy import func
|
||||
|
||||
from app.db import NewChatMessage, NewChatThread
|
||||
|
||||
assistant_count_result = await session.execute(
|
||||
select(func.count(NewChatMessage.id)).filter(
|
||||
NewChatMessage.thread_id == chat_id,
|
||||
|
|
@ -1231,10 +1232,12 @@ async def stream_new_chat(
|
|||
# Truncate inputs to avoid context length issues
|
||||
truncated_query = user_query[:500]
|
||||
truncated_response = accumulated_text[:1000]
|
||||
title_result = await title_chain.ainvoke({
|
||||
"user_query": truncated_query,
|
||||
"assistant_response": truncated_response,
|
||||
})
|
||||
title_result = await title_chain.ainvoke(
|
||||
{
|
||||
"user_query": truncated_query,
|
||||
"assistant_response": truncated_response,
|
||||
}
|
||||
)
|
||||
|
||||
# Extract and clean the title
|
||||
if title_result and hasattr(title_result, "content"):
|
||||
|
|
@ -1242,7 +1245,7 @@ async def stream_new_chat(
|
|||
# Validate the title (reasonable length)
|
||||
if raw_title and len(raw_title) <= 100:
|
||||
# Remove any quotes or extra formatting
|
||||
generated_title = raw_title.strip('"\'')
|
||||
generated_title = raw_title.strip("\"'")
|
||||
except Exception:
|
||||
generated_title = None
|
||||
|
||||
|
|
|
|||
|
|
@ -219,7 +219,9 @@ class CustomBearerTransport(BearerTransport):
|
|||
|
||||
# Decode JWT to get user_id for refresh token creation
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET, algorithms=["HS256"], options={"verify_aud": False})
|
||||
payload = jwt.decode(
|
||||
token, SECRET, algorithms=["HS256"], options={"verify_aud": False}
|
||||
)
|
||||
user_id = uuid.UUID(payload.get("sub"))
|
||||
refresh_token = await create_refresh_token(user_id)
|
||||
except Exception as e:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue