diff --git a/surfsense_backend/alembic/versions/93_add_image_generations_table.py b/surfsense_backend/alembic/versions/93_add_image_generations_table.py index 151208229..f24adc68f 100644 --- a/surfsense_backend/alembic/versions/93_add_image_generations_table.py +++ b/surfsense_backend/alembic/versions/93_add_image_generations_table.py @@ -60,8 +60,15 @@ def upgrade() -> None: sa.Column( "provider", sa.Enum( - "OPENAI", "AZURE_OPENAI", "GOOGLE", "VERTEX_AI", "BEDROCK", - "RECRAFT", "OPENROUTER", "XINFERENCE", "NSCALE", + "OPENAI", + "AZURE_OPENAI", + "GOOGLE", + "VERTEX_AI", + "BEDROCK", + "RECRAFT", + "OPENROUTER", + "XINFERENCE", + "NSCALE", name="imagegenprovider", create_type=False, ), diff --git a/surfsense_backend/app/agents/new_chat/system_prompt.py b/surfsense_backend/app/agents/new_chat/system_prompt.py index 38bae230d..01c762197 100644 --- a/surfsense_backend/app/agents/new_chat/system_prompt.py +++ b/surfsense_backend/app/agents/new_chat/system_prompt.py @@ -106,8 +106,6 @@ You have access to the following tools: - Trigger phrases: "generate an image of", "create a picture of", "draw me", "make an image", "design a logo", "create artwork" - Args: - prompt: A detailed text description of the image to generate. Be specific about subject, style, colors, composition, and mood. - - size: Image size. Options: "1024x1024" (square, default), "1536x1024" (landscape), "1024x1536" (portrait), "1792x1024" (wide) - - quality: Image quality. Options: "auto" (default), "high", "medium", "low" - n: Number of images to generate (1-4, default: 1) - Returns: A dictionary with the generated image URL in the "src" field, along with metadata. - CRITICAL: After calling generate_image, you MUST call `display_image` with the returned "src" URL @@ -300,19 +298,19 @@ You have access to the following tools: - Then provide your explanation, referencing the displayed image - User: "Generate an image of a cat" - - Step 1: `generate_image(prompt="A fluffy orange tabby cat sitting on a windowsill, bathed in warm golden sunlight, soft bokeh background with green houseplants, photorealistic style, cozy atmosphere", size="1024x1024", quality="auto")` + - Step 1: `generate_image(prompt="A fluffy orange tabby cat sitting on a windowsill, bathed in warm golden sunlight, soft bokeh background with green houseplants, photorealistic style, cozy atmosphere")` - Step 2: Use the returned "src" URL to display it: `display_image(src="", alt="A fluffy orange tabby cat on a windowsill", title="Generated Image")` - User: "Create a landscape painting of mountains" - - Step 1: `generate_image(prompt="Majestic snow-capped mountain range at sunset, dramatic orange and purple sky, alpine meadow with wildflowers in the foreground, oil painting style with visible brushstrokes, inspired by the Hudson River School art movement", size="1536x1024", quality="high")` + - Step 1: `generate_image(prompt="Majestic snow-capped mountain range at sunset, dramatic orange and purple sky, alpine meadow with wildflowers in the foreground, oil painting style with visible brushstrokes, inspired by the Hudson River School art movement")` - Step 2: `display_image(src="", alt="Mountain landscape painting", title="Generated Image")` - User: "Draw me a logo for a coffee shop called Bean Dream" - - Step 1: `generate_image(prompt="Minimalist modern logo design for a coffee shop called 'Bean Dream', featuring a stylized coffee bean with dream-like swirls of steam, clean vector style, warm brown and cream color palette, white background, professional branding", size="1024x1024", quality="high")` + - Step 1: `generate_image(prompt="Minimalist modern logo design for a coffee shop called 'Bean Dream', featuring a stylized coffee bean with dream-like swirls of steam, clean vector style, warm brown and cream color palette, white background, professional branding")` - Step 2: `display_image(src="", alt="Bean Dream coffee shop logo", title="Generated Image")` - User: "Make a wide banner image for my blog about AI" - - Step 1: `generate_image(prompt="Wide banner illustration for an AI technology blog, featuring abstract neural network patterns, glowing blue and purple connections, modern futuristic aesthetic, digital art style, clean and professional", size="1792x1024", quality="high")` + - Step 1: `generate_image(prompt="Wide banner illustration for an AI technology blog, featuring abstract neural network patterns, glowing blue and purple connections, modern futuristic aesthetic, digital art style, clean and professional")` - Step 2: `display_image(src="", alt="AI blog banner", title="Generated Image")` """ diff --git a/surfsense_backend/app/agents/new_chat/tools/generate_image.py b/surfsense_backend/app/agents/new_chat/tools/generate_image.py index 091fb122f..8ffa4ecde 100644 --- a/surfsense_backend/app/agents/new_chat/tools/generate_image.py +++ b/surfsense_backend/app/agents/new_chat/tools/generate_image.py @@ -21,12 +21,12 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.config import config from app.db import ImageGeneration, ImageGenerationConfig, SearchSpace -from app.utils.signed_image_urls import generate_image_token from app.services.image_gen_router_service import ( IMAGE_GEN_AUTO_MODE_ID, ImageGenRouterService, is_image_gen_auto_mode, ) +from app.utils.signed_image_urls import generate_image_token logger = logging.getLogger(__name__) @@ -76,8 +76,6 @@ def create_generate_image_tool( @tool async def generate_image( prompt: str, - size: str = "1024x1024", - quality: str = "auto", n: int = 1, ) -> dict[str, Any]: """ @@ -89,10 +87,6 @@ def create_generate_image_tool( Args: prompt: A detailed text description of the image to generate. Be specific about subject, style, colors, composition, and mood. - size: Image size. Options: "1024x1024" (square), "1536x1024" (landscape), - "1024x1536" (portrait), "1792x1024" (wide). Default: "1024x1024" - quality: Image quality. Options: "auto" (default), "high", "medium", "low". - Default: "auto" n: Number of images to generate (1-4). Default: 1 Returns: @@ -112,18 +106,14 @@ def create_generate_image_tool( ) # Build generation kwargs - # NOTE: 'style' is intentionally excluded from gen_kwargs because - # it is only supported by DALL-E 3 and causes errors with other - # models (e.g. gpt-image-1 rejects it as an unknown parameter). - # Since we can't predict which model auto-mode will route to, - # it's safest to omit it. + # NOTE: size, quality, and style are intentionally NOT passed. + # Different models support different values for these params + # (e.g. DALL-E 3 wants "hd"/"standard" for quality while + # gpt-image-1 wants "high"/"medium"/"low"; size options also + # differ). Letting the model use its own defaults avoids errors. gen_kwargs: dict[str, Any] = {} if n is not None and n > 1: gen_kwargs["n"] = n - if quality: - gen_kwargs["quality"] = quality - if size: - gen_kwargs["size"] = size # Call litellm based on config type if is_image_gen_auto_mode(config_id): @@ -199,8 +189,6 @@ def create_generate_image_tool( prompt=prompt, model=getattr(response, "_hidden_params", {}).get("model"), n=n, - quality=quality, - size=size, image_generation_config_id=config_id, response_data=response_dict, search_space_id=search_space_id, diff --git a/surfsense_backend/app/connectors/airtable_history.py b/surfsense_backend/app/connectors/airtable_history.py index 092485f77..49c2fcbdd 100644 --- a/surfsense_backend/app/connectors/airtable_history.py +++ b/surfsense_backend/app/connectors/airtable_history.py @@ -108,7 +108,9 @@ class AirtableHistoryConnector: # Final validation after decryption final_token = config_data.get("access_token") - if not final_token or (isinstance(final_token, str) and not final_token.strip()): + if not final_token or ( + isinstance(final_token, str) and not final_token.strip() + ): raise ValueError( "Airtable access token is invalid or empty. " "Please reconnect your Airtable account." diff --git a/surfsense_backend/app/connectors/confluence_history.py b/surfsense_backend/app/connectors/confluence_history.py index 908f532db..5d19edc54 100644 --- a/surfsense_backend/app/connectors/confluence_history.py +++ b/surfsense_backend/app/connectors/confluence_history.py @@ -128,7 +128,9 @@ class ConfluenceHistoryConnector: # Final validation after decryption final_token = config_data.get("access_token") - if not final_token or (isinstance(final_token, str) and not final_token.strip()): + if not final_token or ( + isinstance(final_token, str) and not final_token.strip() + ): raise ValueError( "Confluence access token is invalid or empty. " "Please reconnect your Confluence account." diff --git a/surfsense_backend/app/connectors/jira_history.py b/surfsense_backend/app/connectors/jira_history.py index 46a28324d..e9f28a2c4 100644 --- a/surfsense_backend/app/connectors/jira_history.py +++ b/surfsense_backend/app/connectors/jira_history.py @@ -129,7 +129,9 @@ class JiraHistoryConnector: # Final validation after decryption final_token = config_data.get("access_token") - if not final_token or (isinstance(final_token, str) and not final_token.strip()): + if not final_token or ( + isinstance(final_token, str) and not final_token.strip() + ): raise ValueError( "Jira access token is invalid or empty. " "Please reconnect your Jira account." diff --git a/surfsense_backend/app/connectors/linear_connector.py b/surfsense_backend/app/connectors/linear_connector.py index 6500b9027..534d70b89 100644 --- a/surfsense_backend/app/connectors/linear_connector.py +++ b/surfsense_backend/app/connectors/linear_connector.py @@ -153,7 +153,9 @@ class LinearConnector: # Final validation after decryption final_token = config_data.get("access_token") - if not final_token or (isinstance(final_token, str) and not final_token.strip()): + if not final_token or ( + isinstance(final_token, str) and not final_token.strip() + ): raise ValueError( "Linear access token is invalid or empty. " "Please reconnect your Linear account." diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 320ff6d8d..a82c18470 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -100,7 +100,6 @@ class PodcastStatus(str, Enum): FAILED = "failed" - class LiteLLMProvider(str, Enum): """ Enum for LLM providers supported by LiteLLM. @@ -941,7 +940,9 @@ class ImageGenerationConfig(BaseModel, TimestampMixin): search_space_id = Column( Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False ) - search_space = relationship("SearchSpace", back_populates="image_generation_configs") + search_space = relationship( + "SearchSpace", back_populates="image_generation_configs" + ) class ImageGeneration(BaseModel, TimestampMixin): @@ -973,9 +974,7 @@ class ImageGeneration(BaseModel, TimestampMixin): String(50), nullable=True ) # "1024x1024", "1536x1024", "1024x1536", etc. style = Column(String(50), nullable=True) # Model-specific style parameter - response_format = Column( - String(50), nullable=True - ) # "url" or "b64_json" + response_format = Column(String(50), nullable=True) # "url" or "b64_json" # Image generation config reference # 0 = Auto mode (router), negative IDs = global configs from YAML, diff --git a/surfsense_backend/app/routes/__init__.py b/surfsense_backend/app/routes/__init__.py index 683f3548b..d9353284c 100644 --- a/surfsense_backend/app/routes/__init__.py +++ b/surfsense_backend/app/routes/__init__.py @@ -11,7 +11,6 @@ from .confluence_add_connector_route import router as confluence_add_connector_r from .discord_add_connector_route import router as discord_add_connector_router from .documents_routes import router as documents_router from .editor_routes import router as editor_router -from .image_generation_routes import router as image_generation_router from .google_calendar_add_connector_route import ( router as google_calendar_add_connector_router, ) @@ -21,6 +20,7 @@ from .google_drive_add_connector_route import ( from .google_gmail_add_connector_route import ( router as google_gmail_add_connector_router, ) +from .image_generation_routes import router as image_generation_router from .incentive_tasks_routes import router as incentive_tasks_router from .jira_add_connector_route import router as jira_add_connector_router from .linear_add_connector_route import router as linear_add_connector_router diff --git a/surfsense_backend/app/routes/image_generation_routes.py b/surfsense_backend/app/routes/image_generation_routes.py index 9b79771eb..9406867c6 100644 --- a/surfsense_backend/app/routes/image_generation_routes.py +++ b/surfsense_backend/app/routes/image_generation_routes.py @@ -54,9 +54,9 @@ logger = logging.getLogger(__name__) _PROVIDER_MAP = { "OPENAI": "openai", "AZURE_OPENAI": "azure", - "GOOGLE": "gemini", # Google AI Studio + "GOOGLE": "gemini", # Google AI Studio "VERTEX_AI": "vertex_ai", - "BEDROCK": "bedrock", # AWS Bedrock + "BEDROCK": "bedrock", # AWS Bedrock "RECRAFT": "recraft", "OPENROUTER": "openrouter", "XINFERENCE": "xinference", @@ -82,7 +82,9 @@ def _get_global_image_gen_config(config_id: int) -> dict | None: return None -def _build_model_string(provider: str, model_name: str, custom_provider: str | None) -> str: +def _build_model_string( + provider: str, model_name: str, custom_provider: str | None +) -> str: """Build a litellm model string from provider + model_name.""" if custom_provider: return f"{custom_provider}/{model_name}" @@ -210,38 +212,44 @@ async def get_global_image_gen_configs( safe_configs = [] if global_configs and len(global_configs) > 0: - safe_configs.append({ - "id": 0, - "name": "Auto (Load Balanced)", - "description": "Automatically routes across available image generation providers.", - "provider": "AUTO", - "custom_provider": None, - "model_name": "auto", - "api_base": None, - "api_version": None, - "litellm_params": {}, - "is_global": True, - "is_auto_mode": True, - }) + safe_configs.append( + { + "id": 0, + "name": "Auto (Load Balanced)", + "description": "Automatically routes across available image generation providers.", + "provider": "AUTO", + "custom_provider": None, + "model_name": "auto", + "api_base": None, + "api_version": None, + "litellm_params": {}, + "is_global": True, + "is_auto_mode": True, + } + ) for cfg in global_configs: - safe_configs.append({ - "id": cfg.get("id"), - "name": cfg.get("name"), - "description": cfg.get("description"), - "provider": cfg.get("provider"), - "custom_provider": cfg.get("custom_provider"), - "model_name": cfg.get("model_name"), - "api_base": cfg.get("api_base") or None, - "api_version": cfg.get("api_version") or None, - "litellm_params": cfg.get("litellm_params", {}), - "is_global": True, - }) + safe_configs.append( + { + "id": cfg.get("id"), + "name": cfg.get("name"), + "description": cfg.get("description"), + "provider": cfg.get("provider"), + "custom_provider": cfg.get("custom_provider"), + "model_name": cfg.get("model_name"), + "api_base": cfg.get("api_base") or None, + "api_version": cfg.get("api_version") or None, + "litellm_params": cfg.get("litellm_params", {}), + "is_global": True, + } + ) return safe_configs except Exception as e: logger.exception("Failed to fetch global image generation configs") - raise HTTPException(status_code=500, detail=f"Failed to fetch configs: {e!s}") from e + raise HTTPException( + status_code=500, detail=f"Failed to fetch configs: {e!s}" + ) from e # ============================================================================= @@ -258,7 +266,9 @@ async def create_image_gen_config( """Create a new image generation config for a search space.""" try: await check_permission( - session, user, config_data.search_space_id, + session, + user, + config_data.search_space_id, Permission.IMAGE_GENERATIONS_CREATE.value, "You don't have permission to create image generation configs in this search space", ) @@ -274,7 +284,9 @@ async def create_image_gen_config( except Exception as e: await session.rollback() logger.exception("Failed to create ImageGenerationConfig") - raise HTTPException(status_code=500, detail=f"Failed to create config: {e!s}") from e + raise HTTPException( + status_code=500, detail=f"Failed to create config: {e!s}" + ) from e @router.get("/image-generation-configs", response_model=list[ImageGenerationConfigRead]) @@ -288,7 +300,9 @@ async def list_image_gen_configs( """List image generation configs for a search space.""" try: await check_permission( - session, user, search_space_id, + session, + user, + search_space_id, Permission.IMAGE_GENERATIONS_READ.value, "You don't have permission to view image generation configs in this search space", ) @@ -306,10 +320,14 @@ async def list_image_gen_configs( raise except Exception as e: logger.exception("Failed to list ImageGenerationConfigs") - raise HTTPException(status_code=500, detail=f"Failed to fetch configs: {e!s}") from e + raise HTTPException( + status_code=500, detail=f"Failed to fetch configs: {e!s}" + ) from e -@router.get("/image-generation-configs/{config_id}", response_model=ImageGenerationConfigRead) +@router.get( + "/image-generation-configs/{config_id}", response_model=ImageGenerationConfigRead +) async def get_image_gen_config( config_id: int, session: AsyncSession = Depends(get_async_session), @@ -325,7 +343,9 @@ async def get_image_gen_config( raise HTTPException(status_code=404, detail="Config not found") await check_permission( - session, user, db_config.search_space_id, + session, + user, + db_config.search_space_id, Permission.IMAGE_GENERATIONS_READ.value, "You don't have permission to view image generation configs in this search space", ) @@ -335,10 +355,14 @@ async def get_image_gen_config( raise except Exception as e: logger.exception("Failed to get ImageGenerationConfig") - raise HTTPException(status_code=500, detail=f"Failed to fetch config: {e!s}") from e + raise HTTPException( + status_code=500, detail=f"Failed to fetch config: {e!s}" + ) from e -@router.put("/image-generation-configs/{config_id}", response_model=ImageGenerationConfigRead) +@router.put( + "/image-generation-configs/{config_id}", response_model=ImageGenerationConfigRead +) async def update_image_gen_config( config_id: int, update_data: ImageGenerationConfigUpdate, @@ -355,7 +379,9 @@ async def update_image_gen_config( raise HTTPException(status_code=404, detail="Config not found") await check_permission( - session, user, db_config.search_space_id, + session, + user, + db_config.search_space_id, Permission.IMAGE_GENERATIONS_CREATE.value, "You don't have permission to update image generation configs in this search space", ) @@ -372,7 +398,9 @@ async def update_image_gen_config( except Exception as e: await session.rollback() logger.exception("Failed to update ImageGenerationConfig") - raise HTTPException(status_code=500, detail=f"Failed to update config: {e!s}") from e + raise HTTPException( + status_code=500, detail=f"Failed to update config: {e!s}" + ) from e @router.delete("/image-generation-configs/{config_id}", response_model=dict) @@ -391,21 +419,28 @@ async def delete_image_gen_config( raise HTTPException(status_code=404, detail="Config not found") await check_permission( - session, user, db_config.search_space_id, + session, + user, + db_config.search_space_id, Permission.IMAGE_GENERATIONS_DELETE.value, "You don't have permission to delete image generation configs in this search space", ) await session.delete(db_config) await session.commit() - return {"message": "Image generation config deleted successfully", "id": config_id} + return { + "message": "Image generation config deleted successfully", + "id": config_id, + } except HTTPException: raise except Exception as e: await session.rollback() logger.exception("Failed to delete ImageGenerationConfig") - raise HTTPException(status_code=500, detail=f"Failed to delete config: {e!s}") from e + raise HTTPException( + status_code=500, detail=f"Failed to delete config: {e!s}" + ) from e # ============================================================================= @@ -422,7 +457,9 @@ async def create_image_generation( """Create and execute an image generation request.""" try: await check_permission( - session, user, data.search_space_id, + session, + user, + data.search_space_id, Permission.IMAGE_GENERATIONS_CREATE.value, "You don't have permission to create image generations in this search space", ) @@ -463,11 +500,15 @@ async def create_image_generation( raise except SQLAlchemyError: await session.rollback() - raise HTTPException(status_code=500, detail="Database error during image generation") from None + raise HTTPException( + status_code=500, detail="Database error during image generation" + ) from None except Exception as e: await session.rollback() logger.exception("Failed to create image generation") - raise HTTPException(status_code=500, detail=f"Image generation failed: {e!s}") from e + raise HTTPException( + status_code=500, detail=f"Image generation failed: {e!s}" + ) from e @router.get("/image-generations", response_model=list[ImageGenerationListRead]) @@ -487,7 +528,9 @@ async def list_image_generations( try: if search_space_id is not None: await check_permission( - session, user, search_space_id, + session, + user, + search_space_id, Permission.IMAGE_GENERATIONS_READ.value, "You don't have permission to read image generations in this search space", ) @@ -495,7 +538,8 @@ async def list_image_generations( select(ImageGeneration) .filter(ImageGeneration.search_space_id == search_space_id) .order_by(ImageGeneration.created_at.desc()) - .offset(skip).limit(limit) + .offset(skip) + .limit(limit) ) else: result = await session.execute( @@ -504,15 +548,21 @@ async def list_image_generations( .join(SearchSpaceMembership) .filter(SearchSpaceMembership.user_id == user.id) .order_by(ImageGeneration.created_at.desc()) - .offset(skip).limit(limit) + .offset(skip) + .limit(limit) ) - return [ImageGenerationListRead.from_orm_with_count(img) for img in result.scalars().all()] + return [ + ImageGenerationListRead.from_orm_with_count(img) + for img in result.scalars().all() + ] except HTTPException: raise except SQLAlchemyError: - raise HTTPException(status_code=500, detail="Database error fetching image generations") from None + raise HTTPException( + status_code=500, detail="Database error fetching image generations" + ) from None @router.get("/image-generations/{image_gen_id}", response_model=ImageGenerationRead) @@ -531,7 +581,9 @@ async def get_image_generation( raise HTTPException(status_code=404, detail="Image generation not found") await check_permission( - session, user, image_gen.search_space_id, + session, + user, + image_gen.search_space_id, Permission.IMAGE_GENERATIONS_READ.value, "You don't have permission to read image generations in this search space", ) @@ -540,7 +592,9 @@ async def get_image_generation( except HTTPException: raise except SQLAlchemyError: - raise HTTPException(status_code=500, detail="Database error fetching image generation") from None + raise HTTPException( + status_code=500, detail="Database error fetching image generation" + ) from None @router.delete("/image-generations/{image_gen_id}", response_model=dict) @@ -559,7 +613,9 @@ async def delete_image_generation( raise HTTPException(status_code=404, detail="Image generation not found") await check_permission( - session, user, db_image_gen.search_space_id, + session, + user, + db_image_gen.search_space_id, Permission.IMAGE_GENERATIONS_DELETE.value, "You don't have permission to delete image generations in this search space", ) @@ -572,13 +628,16 @@ async def delete_image_generation( raise except SQLAlchemyError: await session.rollback() - raise HTTPException(status_code=500, detail="Database error deleting image generation") from None + raise HTTPException( + status_code=500, detail="Database error deleting image generation" + ) from None # ============================================================================= # Image Serving (serves generated images from DB, protected by signed tokens) # ============================================================================= + @router.get("/image-generations/{image_gen_id}/image") async def serve_generated_image( image_gen_id: int, @@ -616,13 +675,16 @@ async def serve_generated_image( images = image_gen.response_data.get("data", []) if not images or index >= len(images): - raise HTTPException(status_code=404, detail="Image not found at the specified index") + raise HTTPException( + status_code=404, detail="Image not found at the specified index" + ) image_entry = images[index] # If there's a URL, redirect to it if image_entry.get("url"): from fastapi.responses import RedirectResponse + return RedirectResponse(url=image_entry["url"]) # If there's b64_json data, decode and serve it @@ -643,4 +705,6 @@ async def serve_generated_image( raise except Exception as e: logger.exception("Failed to serve generated image") - raise HTTPException(status_code=500, detail=f"Failed to serve image: {e!s}") from e + raise HTTPException( + status_code=500, detail=f"Failed to serve image: {e!s}" + ) from e diff --git a/surfsense_backend/app/schemas/__init__.py b/surfsense_backend/app/schemas/__init__.py index 332af55fd..ad5abf777 100644 --- a/surfsense_backend/app/schemas/__init__.py +++ b/surfsense_backend/app/schemas/__init__.py @@ -113,10 +113,12 @@ __all__ = [ "DriveItem", "ExtensionDocumentContent", "ExtensionDocumentMetadata", + "GlobalImageGenConfigRead", "GlobalNewLLMConfigRead", "GoogleDriveIndexRequest", "GoogleDriveIndexingOptions", - "GlobalImageGenConfigRead", + # Base schemas + "IDModel", # Image Generation Config schemas "ImageGenerationConfigCreate", "ImageGenerationConfigPublic", @@ -126,8 +128,6 @@ __all__ = [ "ImageGenerationCreate", "ImageGenerationListRead", "ImageGenerationRead", - # Base schemas - "IDModel", # RBAC schemas "InviteAcceptRequest", "InviteAcceptResponse", diff --git a/surfsense_backend/app/schemas/image_generation.py b/surfsense_backend/app/schemas/image_generation.py index 367a35a77..6ef4feff8 100644 --- a/surfsense_backend/app/schemas/image_generation.py +++ b/surfsense_backend/app/schemas/image_generation.py @@ -13,7 +13,6 @@ from pydantic import BaseModel, ConfigDict, Field from app.db import ImageGenProvider - # ============================================================================= # ImageGenerationConfig CRUD Schemas # ============================================================================= diff --git a/surfsense_backend/app/services/image_gen_router_service.py b/surfsense_backend/app/services/image_gen_router_service.py index 3b8a15d2a..eb6936efd 100644 --- a/surfsense_backend/app/services/image_gen_router_service.py +++ b/surfsense_backend/app/services/image_gen_router_service.py @@ -31,9 +31,9 @@ IMAGE_GEN_AUTO_MODE_ID = 0 IMAGE_GEN_PROVIDER_MAP = { "OPENAI": "openai", "AZURE_OPENAI": "azure", - "GOOGLE": "gemini", # Google AI Studio + "GOOGLE": "gemini", # Google AI Studio "VERTEX_AI": "vertex_ai", - "BEDROCK": "bedrock", # AWS Bedrock + "BEDROCK": "bedrock", # AWS Bedrock "RECRAFT": "recraft", "OPENROUTER": "openrouter", "XINFERENCE": "xinference", @@ -156,9 +156,7 @@ class ImageGenRouterService: model_string = f"{config['custom_provider']}/{config['model_name']}" else: provider = config.get("provider", "").upper() - provider_prefix = IMAGE_GEN_PROVIDER_MAP.get( - provider, provider.lower() - ) + provider_prefix = IMAGE_GEN_PROVIDER_MAP.get(provider, provider.lower()) model_string = f"{provider_prefix}/{config['model_name']}" # Build litellm params @@ -194,9 +192,7 @@ class ImageGenRouterService: return deployment except Exception as e: - logger.warning( - f"Failed to convert image gen config to deployment: {e}" - ) + logger.warning(f"Failed to convert image gen config to deployment: {e}") return None @classmethod diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index a9751e5d1..685f77e39 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -27,12 +27,12 @@ from app.agents.new_chat.llm_config import ( load_llm_config_from_yaml, ) from app.db import Document, SurfsenseDocsDocument +from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE from app.schemas.new_chat import ChatAttachment from app.services.chat_session_state_service import ( clear_ai_responding, set_ai_responding, ) -from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE from app.services.connector_service import ConnectorService from app.services.new_streaming_service import VercelStreamingService from app.utils.content_utils import bootstrap_history_from_db @@ -1211,9 +1211,10 @@ async def stream_new_chat( # Generate LLM title for new chats after first response # Check if this is the first assistant response by counting existing assistant messages - from app.db import NewChatMessage, NewChatThread from sqlalchemy import func + from app.db import NewChatMessage, NewChatThread + assistant_count_result = await session.execute( select(func.count(NewChatMessage.id)).filter( NewChatMessage.thread_id == chat_id, @@ -1231,10 +1232,12 @@ async def stream_new_chat( # Truncate inputs to avoid context length issues truncated_query = user_query[:500] truncated_response = accumulated_text[:1000] - title_result = await title_chain.ainvoke({ - "user_query": truncated_query, - "assistant_response": truncated_response, - }) + title_result = await title_chain.ainvoke( + { + "user_query": truncated_query, + "assistant_response": truncated_response, + } + ) # Extract and clean the title if title_result and hasattr(title_result, "content"): @@ -1242,7 +1245,7 @@ async def stream_new_chat( # Validate the title (reasonable length) if raw_title and len(raw_title) <= 100: # Remove any quotes or extra formatting - generated_title = raw_title.strip('"\'') + generated_title = raw_title.strip("\"'") except Exception: generated_title = None diff --git a/surfsense_backend/app/users.py b/surfsense_backend/app/users.py index 696cdf25e..ee07ba88f 100644 --- a/surfsense_backend/app/users.py +++ b/surfsense_backend/app/users.py @@ -219,7 +219,9 @@ class CustomBearerTransport(BearerTransport): # Decode JWT to get user_id for refresh token creation try: - payload = jwt.decode(token, SECRET, algorithms=["HS256"], options={"verify_aud": False}) + payload = jwt.decode( + token, SECRET, algorithms=["HS256"], options={"verify_aud": False} + ) user_id = uuid.UUID(payload.get("sub")) refresh_token = await create_refresh_token(user_id) except Exception as e: