Merge remote-tracking branch 'upstream/dev' into feat/unified-etl-pipeline

This commit is contained in:
Anish Sarkar 2026-04-06 22:04:51 +05:30
commit 63a75052ca
76 changed files with 3041 additions and 376 deletions

View file

@ -1351,6 +1351,9 @@ class SearchSpace(BaseModel, TimestampMixin):
image_generation_config_id = Column(
Integer, nullable=True, default=0
) # For image generation, defaults to Auto mode
vision_llm_id = Column(
Integer, nullable=True, default=0
) # For vision/screenshot analysis, defaults to Auto mode
user_id = Column(
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False

View file

@ -3,6 +3,7 @@ from fastapi import APIRouter
from .airtable_add_connector_route import (
router as airtable_add_connector_router,
)
from .autocomplete_routes import router as autocomplete_router
from .chat_comments_routes import router as chat_comments_router
from .circleback_webhook_route import router as circleback_webhook_router
from .clickup_add_connector_route import router as clickup_add_connector_router
@ -95,3 +96,4 @@ router.include_router(incentive_tasks_router) # Incentive tasks for earning fre
router.include_router(stripe_router) # Stripe checkout for additional page packs
router.include_router(youtube_router) # YouTube playlist resolution
router.include_router(prompts_router)
router.include_router(autocomplete_router) # Lightweight autocomplete with KB context

View file

@ -1,7 +1,5 @@
import base64
import hashlib
import logging
import secrets
from datetime import UTC, datetime, timedelta
from uuid import UUID
@ -26,7 +24,11 @@ from app.utils.connector_naming import (
check_duplicate_connector,
generate_unique_connector_name,
)
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
from app.utils.oauth_security import (
OAuthStateManager,
TokenEncryption,
generate_pkce_pair,
)
logger = logging.getLogger(__name__)
@ -75,28 +77,6 @@ def make_basic_auth_header(client_id: str, client_secret: str) -> str:
return f"Basic {b64}"
def generate_pkce_pair() -> tuple[str, str]:
"""
Generate PKCE code verifier and code challenge.
Returns:
Tuple of (code_verifier, code_challenge)
"""
# Generate code verifier (43-128 characters)
code_verifier = (
base64.urlsafe_b64encode(secrets.token_bytes(32)).decode("utf-8").rstrip("=")
)
# Generate code challenge (SHA256 hash of verifier, base64url encoded)
code_challenge = (
base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode("utf-8")).digest())
.decode("utf-8")
.rstrip("=")
)
return code_verifier, code_challenge
@router.get("/auth/airtable/connector/add")
async def connect_airtable(space_id: int, user: User = Depends(current_active_user)):
"""

View file

@ -0,0 +1,42 @@
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import User, get_async_session
from app.services.new_streaming_service import VercelStreamingService
from app.services.vision_autocomplete_service import stream_vision_autocomplete
from app.users import current_active_user
from app.utils.rbac import check_search_space_access
router = APIRouter(prefix="/autocomplete", tags=["autocomplete"])
MAX_SCREENSHOT_SIZE = 20 * 1024 * 1024 # 20 MB base64 ceiling
class VisionAutocompleteRequest(BaseModel):
screenshot: str = Field(..., max_length=MAX_SCREENSHOT_SIZE)
search_space_id: int
app_name: str = ""
window_title: str = ""
@router.post("/vision/stream")
async def vision_autocomplete_stream(
body: VisionAutocompleteRequest,
user: User = Depends(current_active_user),
session: AsyncSession = Depends(get_async_session),
):
await check_search_space_access(session, user, body.search_space_id)
return StreamingResponse(
stream_vision_autocomplete(
body.screenshot, body.search_space_id, session,
app_name=body.app_name, window_title=body.window_title,
),
media_type="text/event-stream",
headers={
**VercelStreamingService.get_response_headers(),
"X-Accel-Buffering": "no",
},
)

View file

@ -28,7 +28,11 @@ from app.utils.connector_naming import (
check_duplicate_connector,
generate_unique_connector_name,
)
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
from app.utils.oauth_security import (
OAuthStateManager,
TokenEncryption,
generate_code_verifier,
)
logger = logging.getLogger(__name__)
@ -96,9 +100,14 @@ async def connect_calendar(space_id: int, user: User = Depends(current_active_us
flow = get_google_flow()
# Generate secure state parameter with HMAC signature
code_verifier = generate_code_verifier()
flow.code_verifier = code_verifier
# Generate secure state parameter with HMAC signature (includes PKCE code_verifier)
state_manager = get_state_manager()
state_encoded = state_manager.generate_secure_state(space_id, user.id)
state_encoded = state_manager.generate_secure_state(
space_id, user.id, code_verifier=code_verifier
)
auth_url, _ = flow.authorization_url(
access_type="offline",
@ -146,8 +155,11 @@ async def reauth_calendar(
flow = get_google_flow()
code_verifier = generate_code_verifier()
flow.code_verifier = code_verifier
state_manager = get_state_manager()
extra: dict = {"connector_id": connector_id}
extra: dict = {"connector_id": connector_id, "code_verifier": code_verifier}
if return_url and return_url.startswith("/"):
extra["return_url"] = return_url
state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)
@ -225,6 +237,7 @@ async def calendar_callback(
user_id = UUID(data["user_id"])
space_id = data["space_id"]
code_verifier = data.get("code_verifier")
# Validate redirect URI (security: ensure it matches configured value)
if not config.GOOGLE_CALENDAR_REDIRECT_URI:
@ -233,6 +246,7 @@ async def calendar_callback(
)
flow = get_google_flow()
flow.code_verifier = code_verifier
flow.fetch_token(code=code)
creds = flow.credentials

View file

@ -41,7 +41,11 @@ from app.utils.connector_naming import (
check_duplicate_connector,
generate_unique_connector_name,
)
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
from app.utils.oauth_security import (
OAuthStateManager,
TokenEncryption,
generate_code_verifier,
)
# Relax token scope validation for Google OAuth
os.environ["OAUTHLIB_RELAX_TOKEN_SCOPE"] = "1"
@ -127,14 +131,19 @@ async def connect_drive(space_id: int, user: User = Depends(current_active_user)
flow = get_google_flow()
# Generate secure state parameter with HMAC signature
code_verifier = generate_code_verifier()
flow.code_verifier = code_verifier
# Generate secure state parameter with HMAC signature (includes PKCE code_verifier)
state_manager = get_state_manager()
state_encoded = state_manager.generate_secure_state(space_id, user.id)
state_encoded = state_manager.generate_secure_state(
space_id, user.id, code_verifier=code_verifier
)
# Generate authorization URL
auth_url, _ = flow.authorization_url(
access_type="offline", # Get refresh token
prompt="consent", # Force consent screen to get refresh token
access_type="offline",
prompt="consent",
include_granted_scopes="true",
state=state_encoded,
)
@ -193,8 +202,11 @@ async def reauth_drive(
flow = get_google_flow()
code_verifier = generate_code_verifier()
flow.code_verifier = code_verifier
state_manager = get_state_manager()
extra: dict = {"connector_id": connector_id}
extra: dict = {"connector_id": connector_id, "code_verifier": code_verifier}
if return_url and return_url.startswith("/"):
extra["return_url"] = return_url
state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)
@ -285,6 +297,7 @@ async def drive_callback(
space_id = data["space_id"]
reauth_connector_id = data.get("connector_id")
reauth_return_url = data.get("return_url")
code_verifier = data.get("code_verifier")
logger.info(
f"Processing Google Drive callback for user {user_id}, space {space_id}"
@ -296,8 +309,9 @@ async def drive_callback(
status_code=500, detail="GOOGLE_DRIVE_REDIRECT_URI not configured"
)
# Exchange authorization code for tokens
# Exchange authorization code for tokens (restore PKCE code_verifier from state)
flow = get_google_flow()
flow.code_verifier = code_verifier
flow.fetch_token(code=code)
creds = flow.credentials

View file

@ -28,7 +28,11 @@ from app.utils.connector_naming import (
check_duplicate_connector,
generate_unique_connector_name,
)
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
from app.utils.oauth_security import (
OAuthStateManager,
TokenEncryption,
generate_code_verifier,
)
logger = logging.getLogger(__name__)
@ -109,9 +113,14 @@ async def connect_gmail(space_id: int, user: User = Depends(current_active_user)
flow = get_google_flow()
# Generate secure state parameter with HMAC signature
code_verifier = generate_code_verifier()
flow.code_verifier = code_verifier
# Generate secure state parameter with HMAC signature (includes PKCE code_verifier)
state_manager = get_state_manager()
state_encoded = state_manager.generate_secure_state(space_id, user.id)
state_encoded = state_manager.generate_secure_state(
space_id, user.id, code_verifier=code_verifier
)
auth_url, _ = flow.authorization_url(
access_type="offline",
@ -164,8 +173,11 @@ async def reauth_gmail(
flow = get_google_flow()
code_verifier = generate_code_verifier()
flow.code_verifier = code_verifier
state_manager = get_state_manager()
extra: dict = {"connector_id": connector_id}
extra: dict = {"connector_id": connector_id, "code_verifier": code_verifier}
if return_url and return_url.startswith("/"):
extra["return_url"] = return_url
state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)
@ -256,6 +268,7 @@ async def gmail_callback(
user_id = UUID(data["user_id"])
space_id = data["space_id"]
code_verifier = data.get("code_verifier")
# Validate redirect URI (security: ensure it matches configured value)
if not config.GOOGLE_GMAIL_REDIRECT_URI:
@ -264,6 +277,7 @@ async def gmail_callback(
)
flow = get_google_flow()
flow.code_verifier = code_verifier
flow.fetch_token(code=code)
creds = flow.credentials

View file

@ -522,14 +522,17 @@ async def get_llm_preferences(
image_generation_config = await _get_image_gen_config_by_id(
session, search_space.image_generation_config_id
)
vision_llm = await _get_llm_config_by_id(session, search_space.vision_llm_id)
return LLMPreferencesRead(
agent_llm_id=search_space.agent_llm_id,
document_summary_llm_id=search_space.document_summary_llm_id,
image_generation_config_id=search_space.image_generation_config_id,
vision_llm_id=search_space.vision_llm_id,
agent_llm=agent_llm,
document_summary_llm=document_summary_llm,
image_generation_config=image_generation_config,
vision_llm=vision_llm,
)
except HTTPException:
@ -589,14 +592,17 @@ async def update_llm_preferences(
image_generation_config = await _get_image_gen_config_by_id(
session, search_space.image_generation_config_id
)
vision_llm = await _get_llm_config_by_id(session, search_space.vision_llm_id)
return LLMPreferencesRead(
agent_llm_id=search_space.agent_llm_id,
document_summary_llm_id=search_space.document_summary_llm_id,
image_generation_config_id=search_space.image_generation_config_id,
vision_llm_id=search_space.vision_llm_id,
agent_llm=agent_llm,
document_summary_llm=document_summary_llm,
image_generation_config=image_generation_config,
vision_llm=vision_llm,
)
except HTTPException:

View file

@ -182,6 +182,9 @@ class LLMPreferencesRead(BaseModel):
image_generation_config_id: int | None = Field(
None, description="ID of the image generation config to use"
)
vision_llm_id: int | None = Field(
None, description="ID of the LLM config to use for vision/screenshot analysis"
)
agent_llm: dict[str, Any] | None = Field(
None, description="Full config for agent LLM"
)
@ -191,6 +194,9 @@ class LLMPreferencesRead(BaseModel):
image_generation_config: dict[str, Any] | None = Field(
None, description="Full config for image generation"
)
vision_llm: dict[str, Any] | None = Field(
None, description="Full config for vision LLM"
)
model_config = ConfigDict(from_attributes=True)
@ -207,3 +213,6 @@ class LLMPreferencesUpdate(BaseModel):
image_generation_config_id: int | None = Field(
None, description="ID of the image generation config to use"
)
vision_llm_id: int | None = Field(
None, description="ID of the LLM config to use for vision/screenshot analysis"
)

View file

@ -32,6 +32,7 @@ logger = logging.getLogger(__name__)
class LLMRole:
AGENT = "agent" # For agent/chat operations
DOCUMENT_SUMMARY = "document_summary" # For document summarization
VISION = "vision" # For vision/screenshot analysis
def get_global_llm_config(llm_config_id: int) -> dict | None:
@ -187,7 +188,7 @@ async def get_search_space_llm_instance(
Args:
session: Database session
search_space_id: Search Space ID
role: LLM role ('agent' or 'document_summary')
role: LLM role ('agent', 'document_summary', or 'vision')
Returns:
ChatLiteLLM or ChatLiteLLMRouter instance, or None if not found
@ -209,6 +210,8 @@ async def get_search_space_llm_instance(
llm_config_id = search_space.agent_llm_id
elif role == LLMRole.DOCUMENT_SUMMARY:
llm_config_id = search_space.document_summary_llm_id
elif role == LLMRole.VISION:
llm_config_id = search_space.vision_llm_id
else:
logger.error(f"Invalid LLM role: {role}")
return None
@ -405,6 +408,13 @@ async def get_document_summary_llm(
)
async def get_vision_llm(
session: AsyncSession, search_space_id: int
) -> ChatLiteLLM | ChatLiteLLMRouter | None:
"""Get the search space's vision LLM instance for screenshot analysis."""
return await get_search_space_llm_instance(session, search_space_id, LLMRole.VISION)
# Backward-compatible alias (LLM preferences are now per-search-space, not per-user)
async def get_user_long_context_llm(
session: AsyncSession,

View file

@ -3,7 +3,7 @@ Service for managing user page limits for ETL services.
"""
import os
from pathlib import Path
from pathlib import Path, PurePosixPath
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
@ -223,10 +223,155 @@ class PageLimitService:
# Estimate ~2000 characters per page
return max(1, content_length // 2000)
@staticmethod
def estimate_pages_from_metadata(
file_name_or_ext: str, file_size: int | str | None = None
) -> int:
"""Size-based page estimation from file name/extension and byte size.
Pure function no file I/O, no database access. Used by cloud
connectors (which only have API metadata) and as the internal
fallback for :meth:`estimate_pages_before_processing`.
``file_name_or_ext`` can be a full filename (``"report.pdf"``) or
a bare extension (``".pdf"``). ``file_size`` may be an int, a
stringified int from a cloud API, or *None*.
"""
if file_size is not None:
try:
file_size = int(file_size)
except (ValueError, TypeError):
file_size = 0
else:
file_size = 0
if file_size <= 0:
return 1
ext = PurePosixPath(file_name_or_ext).suffix.lower() if file_name_or_ext else ""
if not ext and file_name_or_ext.startswith("."):
ext = file_name_or_ext.lower()
file_ext = ext
if file_ext == ".pdf":
return max(1, file_size // (100 * 1024))
if file_ext in {
".doc",
".docx",
".docm",
".dot",
".dotm",
".odt",
".ott",
".sxw",
".stw",
".uot",
".rtf",
".pages",
".wpd",
".wps",
".abw",
".zabw",
".cwk",
".hwp",
".lwp",
".mcw",
".mw",
".sdw",
".vor",
}:
return max(1, file_size // (50 * 1024))
if file_ext in {
".ppt",
".pptx",
".pptm",
".pot",
".potx",
".odp",
".otp",
".sxi",
".sti",
".uop",
".key",
".sda",
".sdd",
".sdp",
}:
return max(1, file_size // (200 * 1024))
if file_ext in {
".xls",
".xlsx",
".xlsm",
".xlsb",
".xlw",
".xlr",
".ods",
".ots",
".fods",
".numbers",
".123",
".wk1",
".wk2",
".wk3",
".wk4",
".wks",
".wb1",
".wb2",
".wb3",
".wq1",
".wq2",
".csv",
".tsv",
".slk",
".sylk",
".dif",
".dbf",
".prn",
".qpw",
".602",
".et",
".eth",
}:
return max(1, file_size // (100 * 1024))
if file_ext in {".epub"}:
return max(1, file_size // (50 * 1024))
if file_ext in {".txt", ".log", ".md", ".markdown", ".htm", ".html", ".xml"}:
return max(1, file_size // 3000)
if file_ext in {
".jpg",
".jpeg",
".png",
".gif",
".bmp",
".tiff",
".webp",
".svg",
".cgm",
".odg",
".pbd",
}:
return 1
if file_ext in {".mp3", ".m4a", ".wav", ".mpga"}:
return max(1, file_size // (1024 * 1024))
if file_ext in {".mp4", ".mpeg", ".webm"}:
return max(1, file_size // (5 * 1024 * 1024))
return max(1, file_size // (80 * 1024))
def estimate_pages_before_processing(self, file_path: str) -> int:
"""
Estimate page count from file before processing (to avoid unnecessary API calls).
This is called BEFORE sending to ETL services to prevent cost on rejected files.
Estimate page count from a local file before processing.
For PDFs, attempts to read the actual page count via pypdf.
For everything else, delegates to :meth:`estimate_pages_from_metadata`.
Args:
file_path: Path to the file
@ -240,7 +385,6 @@ class PageLimitService:
file_ext = Path(file_path).suffix.lower()
file_size = os.path.getsize(file_path)
# PDF files - try to get actual page count
if file_ext == ".pdf":
try:
import pypdf
@ -249,153 +393,6 @@ class PageLimitService:
pdf_reader = pypdf.PdfReader(f)
return len(pdf_reader.pages)
except Exception:
# If PDF reading fails, fall back to size estimation
# Typical PDF: ~100KB per page (conservative estimate)
return max(1, file_size // (100 * 1024))
pass # fall through to size-based estimation
# Word Processing Documents
# Microsoft Word, LibreOffice Writer, WordPerfect, Pages, etc.
elif file_ext in [
".doc",
".docx",
".docm",
".dot",
".dotm", # Microsoft Word
".odt",
".ott",
".sxw",
".stw",
".uot", # OpenDocument/StarOffice Writer
".rtf", # Rich Text Format
".pages", # Apple Pages
".wpd",
".wps", # WordPerfect, Microsoft Works
".abw",
".zabw", # AbiWord
".cwk",
".hwp",
".lwp",
".mcw",
".mw",
".sdw",
".vor", # Other word processors
]:
# Typical word document: ~50KB per page (conservative)
return max(1, file_size // (50 * 1024))
# Presentation Documents
# PowerPoint, Impress, Keynote, etc.
elif file_ext in [
".ppt",
".pptx",
".pptm",
".pot",
".potx", # Microsoft PowerPoint
".odp",
".otp",
".sxi",
".sti",
".uop", # OpenDocument/StarOffice Impress
".key", # Apple Keynote
".sda",
".sdd",
".sdp", # StarOffice Draw/Impress
]:
# Typical presentation: ~200KB per slide (conservative)
return max(1, file_size // (200 * 1024))
# Spreadsheet Documents
# Excel, Calc, Numbers, Lotus, etc.
elif file_ext in [
".xls",
".xlsx",
".xlsm",
".xlsb",
".xlw",
".xlr", # Microsoft Excel
".ods",
".ots",
".fods", # OpenDocument Spreadsheet
".numbers", # Apple Numbers
".123",
".wk1",
".wk2",
".wk3",
".wk4",
".wks", # Lotus 1-2-3
".wb1",
".wb2",
".wb3",
".wq1",
".wq2", # Quattro Pro
".csv",
".tsv",
".slk",
".sylk",
".dif",
".dbf",
".prn",
".qpw", # Data formats
".602",
".et",
".eth", # Other spreadsheets
]:
# Spreadsheets typically have 1 sheet = 1 page for ETL
# Conservative: ~100KB per sheet
return max(1, file_size // (100 * 1024))
# E-books
elif file_ext in [".epub"]:
# E-books vary widely, estimate by size
# Typical e-book: ~50KB per page
return max(1, file_size // (50 * 1024))
# Plain Text and Markup Files
elif file_ext in [
".txt",
".log", # Plain text
".md",
".markdown", # Markdown
".htm",
".html",
".xml", # Markup
]:
# Plain text: ~3000 bytes per page
return max(1, file_size // 3000)
# Image Files
# Each image is typically processed as 1 page
elif file_ext in [
".jpg",
".jpeg", # JPEG
".png", # PNG
".gif", # GIF
".bmp", # Bitmap
".tiff", # TIFF
".webp", # WebP
".svg", # SVG
".cgm", # Computer Graphics Metafile
".odg",
".pbd", # OpenDocument Graphics
]:
# Each image = 1 page
return 1
# Audio Files (transcription = typically 1 page per minute)
# Note: These should be handled by audio transcription flow, not ETL
elif file_ext in [".mp3", ".m4a", ".wav", ".mpga"]:
# Audio files: estimate based on duration
# Fallback: ~1MB per minute of audio, 1 page per minute transcript
return max(1, file_size // (1024 * 1024))
# Video Files (typically not processed for pages, but just in case)
elif file_ext in [".mp4", ".mpeg", ".webm"]:
# Video files: very rough estimate
# Typically wouldn't be page-based, but use conservative estimate
return max(1, file_size // (5 * 1024 * 1024))
# Other/Unknown Document Types
else:
# Conservative estimate: ~80KB per page
# This catches: .sgl, .sxg, .uof, .uos1, .uos2, .web, and any future formats
return max(1, file_size // (80 * 1024))
return self.estimate_pages_from_metadata(file_ext, file_size)

View file

@ -0,0 +1,225 @@
import logging
from typing import AsyncGenerator
from langchain_core.messages import HumanMessage, SystemMessage
from sqlalchemy.ext.asyncio import AsyncSession
from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever
from app.services.llm_service import get_vision_llm
from app.services.new_streaming_service import VercelStreamingService
logger = logging.getLogger(__name__)
KB_TOP_K = 5
KB_MAX_CHARS = 4000
EXTRACT_QUERY_PROMPT = """Look at this screenshot and describe in 1-2 short sentences what the user is working on and what topic they need to write about. Be specific about the subject matter. Output ONLY the description, nothing else."""
EXTRACT_QUERY_PROMPT_WITH_APP = """The user is currently in the application "{app_name}" with the window titled "{window_title}".
Look at this screenshot and describe in 1-2 short sentences what the user is working on and what topic they need to write about. Be specific about the subject matter. Output ONLY the description, nothing else."""
VISION_SYSTEM_PROMPT = """You are a smart writing assistant that analyzes the user's screen to draft or complete text.
You will receive a screenshot of the user's screen. Your job:
1. Analyze the ENTIRE screenshot to understand what the user is working on (email thread, chat conversation, document, code editor, form, etc.).
2. Identify the text area where the user will type.
3. Based on the full visual context, generate the text the user most likely wants to write.
Key behavior:
- If the text area is EMPTY, draft a full response or message based on what you see on screen (e.g., reply to an email, respond to a chat message, continue a document).
- If the text area already has text, continue it naturally.
Rules:
- Output ONLY the text to be inserted. No quotes, no explanations, no meta-commentary.
- Be concise but complete a full thought, not a fragment.
- Match the tone and formality of the surrounding context.
- If the screen shows code, write code. If it shows a casual chat, be casual. If it shows a formal email, be formal.
- Do NOT describe the screenshot or explain your reasoning.
- If you cannot determine what to write, output nothing."""
APP_CONTEXT_BLOCK = """
The user is currently working in "{app_name}" (window: "{window_title}"). Use this to understand the type of application and adapt your tone and format accordingly."""
KB_CONTEXT_BLOCK = """
You also have access to the user's knowledge base documents below. Use them to write more accurate, informed, and contextually relevant text. Do NOT cite or reference the documents explicitly — just let the knowledge inform your writing naturally.
<knowledge_base>
{kb_context}
</knowledge_base>"""
def _build_system_prompt(app_name: str, window_title: str, kb_context: str) -> str:
"""Assemble the system prompt from optional context blocks."""
prompt = VISION_SYSTEM_PROMPT
if app_name:
prompt += APP_CONTEXT_BLOCK.format(app_name=app_name, window_title=window_title)
if kb_context:
prompt += KB_CONTEXT_BLOCK.format(kb_context=kb_context)
return prompt
def _is_vision_unsupported_error(e: Exception) -> bool:
"""Check if an exception indicates the model doesn't support vision/images."""
msg = str(e).lower()
return "content must be a string" in msg or "does not support image" in msg
async def _extract_query_from_screenshot(
llm, screenshot_data_url: str,
app_name: str = "", window_title: str = "",
) -> str | None:
"""Ask the Vision LLM to describe what the user is working on.
Raises vision-unsupported errors so the caller can return a
friendly message immediately instead of retrying with astream.
"""
if app_name:
prompt_text = EXTRACT_QUERY_PROMPT_WITH_APP.format(
app_name=app_name, window_title=window_title,
)
else:
prompt_text = EXTRACT_QUERY_PROMPT
try:
response = await llm.ainvoke([
HumanMessage(content=[
{"type": "text", "text": prompt_text},
{"type": "image_url", "image_url": {"url": screenshot_data_url}},
]),
])
query = response.content.strip() if hasattr(response, "content") else ""
return query if query else None
except Exception as e:
if _is_vision_unsupported_error(e):
raise
logger.warning(f"Failed to extract query from screenshot: {e}")
return None
async def _search_knowledge_base(
session: AsyncSession, search_space_id: int, query: str
) -> str:
"""Search the KB and return formatted context string."""
try:
retriever = ChucksHybridSearchRetriever(session)
results = await retriever.hybrid_search(
query_text=query,
top_k=KB_TOP_K,
search_space_id=search_space_id,
)
if not results:
return ""
parts: list[str] = []
char_count = 0
for doc in results:
title = doc.get("document", {}).get("title", "Untitled")
for chunk in doc.get("chunks", []):
content = chunk.get("content", "").strip()
if not content:
continue
entry = f"[{title}]\n{content}"
if char_count + len(entry) > KB_MAX_CHARS:
break
parts.append(entry)
char_count += len(entry)
if char_count >= KB_MAX_CHARS:
break
return "\n\n---\n\n".join(parts)
except Exception as e:
logger.warning(f"KB search failed, proceeding without context: {e}")
return ""
async def stream_vision_autocomplete(
screenshot_data_url: str,
search_space_id: int,
session: AsyncSession,
*,
app_name: str = "",
window_title: str = "",
) -> AsyncGenerator[str, None]:
"""Analyze a screenshot with the vision LLM and stream a text completion.
Pipeline:
1. Extract a search query from the screenshot (non-streaming)
2. Search the knowledge base for relevant context
3. Stream the final completion with screenshot + KB + app context
"""
streaming = VercelStreamingService()
vision_error_msg = (
"The selected model does not support vision. "
"Please set a vision-capable model (e.g. GPT-4o, Gemini) in your search space settings."
)
llm = await get_vision_llm(session, search_space_id)
if not llm:
yield streaming.format_message_start()
yield streaming.format_error("No Vision LLM configured for this search space")
yield streaming.format_done()
return
kb_context = ""
try:
query = await _extract_query_from_screenshot(
llm, screenshot_data_url, app_name=app_name, window_title=window_title,
)
except Exception as e:
logger.warning(f"Vision autocomplete: selected model does not support vision: {e}")
yield streaming.format_message_start()
yield streaming.format_error(vision_error_msg)
yield streaming.format_done()
return
if query:
kb_context = await _search_knowledge_base(session, search_space_id, query)
system_prompt = _build_system_prompt(app_name, window_title, kb_context)
messages = [
SystemMessage(content=system_prompt),
HumanMessage(content=[
{
"type": "text",
"text": "Analyze this screenshot. Understand the full context of what the user is working on, then generate the text they most likely want to write in the active text area.",
},
{
"type": "image_url",
"image_url": {"url": screenshot_data_url},
},
]),
]
text_started = False
text_id = ""
try:
yield streaming.format_message_start()
text_id = streaming.generate_text_id()
yield streaming.format_text_start(text_id)
text_started = True
async for chunk in llm.astream(messages):
token = chunk.content if hasattr(chunk, "content") else str(chunk)
if token:
yield streaming.format_text_delta(text_id, token)
yield streaming.format_text_end(text_id)
yield streaming.format_finish()
yield streaming.format_done()
except Exception as e:
if text_started:
yield streaming.format_text_end(text_id)
if _is_vision_unsupported_error(e):
logger.warning(f"Vision autocomplete: selected model does not support vision: {e}")
yield streaming.format_error(vision_error_msg)
else:
logger.error(f"Vision autocomplete streaming error: {e}", exc_info=True)
yield streaming.format_error("Autocomplete failed. Please try again.")
yield streaming.format_done()

View file

@ -28,6 +28,7 @@ from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import compute_identifier_hash
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
from app.services.llm_service import get_user_long_context_llm
from app.services.page_limit_service import PageLimitService
from app.services.task_logging_service import TaskLoggingService
from app.tasks.connector_indexers.base import (
check_document_by_unique_identifier,
@ -396,6 +397,12 @@ async def _index_full_scan(
},
)
page_limit_service = PageLimitService(session)
pages_used, pages_limit = await page_limit_service.get_page_usage(user_id)
remaining_quota = pages_limit - pages_used
batch_estimated_pages = 0
page_limit_reached = False
renamed_count = 0
skipped = 0
files_to_download: list[dict] = []
@ -425,6 +432,21 @@ async def _index_full_scan(
elif skip_item(file):
skipped += 1
continue
file_pages = PageLimitService.estimate_pages_from_metadata(
file.get("name", ""), file.get("size")
)
if batch_estimated_pages + file_pages > remaining_quota:
if not page_limit_reached:
logger.warning(
"Page limit reached during Dropbox full scan, "
"skipping remaining files"
)
page_limit_reached = True
skipped += 1
continue
batch_estimated_pages += file_pages
files_to_download.append(file)
batch_indexed, failed = await _download_and_index(
@ -438,6 +460,14 @@ async def _index_full_scan(
on_heartbeat=on_heartbeat_callback,
)
if batch_indexed > 0 and files_to_download and batch_estimated_pages > 0:
pages_to_deduct = max(
1, batch_estimated_pages * batch_indexed // len(files_to_download)
)
await page_limit_service.update_page_usage(
user_id, pages_to_deduct, allow_exceed=True
)
indexed = renamed_count + batch_indexed
logger.info(
f"Full scan complete: {indexed} indexed, {skipped} skipped, {failed} failed"
@ -458,6 +488,11 @@ async def _index_selected_files(
on_heartbeat: HeartbeatCallbackType | None = None,
) -> tuple[int, int, list[str]]:
"""Index user-selected files using the parallel pipeline."""
page_limit_service = PageLimitService(session)
pages_used, pages_limit = await page_limit_service.get_page_usage(user_id)
remaining_quota = pages_limit - pages_used
batch_estimated_pages = 0
files_to_download: list[dict] = []
errors: list[str] = []
renamed_count = 0
@ -482,6 +517,15 @@ async def _index_selected_files(
skipped += 1
continue
file_pages = PageLimitService.estimate_pages_from_metadata(
file.get("name", ""), file.get("size")
)
if batch_estimated_pages + file_pages > remaining_quota:
display = file_name or file_path
errors.append(f"File '{display}': page limit would be exceeded")
continue
batch_estimated_pages += file_pages
files_to_download.append(file)
batch_indexed, _failed = await _download_and_index(
@ -495,6 +539,14 @@ async def _index_selected_files(
on_heartbeat=on_heartbeat,
)
if batch_indexed > 0 and files_to_download and batch_estimated_pages > 0:
pages_to_deduct = max(
1, batch_estimated_pages * batch_indexed // len(files_to_download)
)
await page_limit_service.update_page_usage(
user_id, pages_to_deduct, allow_exceed=True
)
return renamed_count + batch_indexed, skipped, errors

View file

@ -34,6 +34,7 @@ from app.indexing_pipeline.indexing_pipeline_service import (
PlaceholderInfo,
)
from app.services.llm_service import get_user_long_context_llm
from app.services.page_limit_service import PageLimitService
from app.services.task_logging_service import TaskLoggingService
from app.tasks.connector_indexers.base import (
check_document_by_unique_identifier,
@ -327,6 +328,12 @@ async def _process_single_file(
return 1, 0, 0
return 0, 1, 0
page_limit_service = PageLimitService(session)
estimated_pages = PageLimitService.estimate_pages_from_metadata(
file_name, file.get("size")
)
await page_limit_service.check_page_limit(user_id, estimated_pages)
markdown, drive_metadata, error = await download_and_extract_content(
drive_client, file
)
@ -363,6 +370,9 @@ async def _process_single_file(
)
await pipeline.index(document, connector_doc, user_llm)
await page_limit_service.update_page_usage(
user_id, estimated_pages, allow_exceed=True
)
logger.info(f"Successfully indexed Google Drive file: {file_name}")
return 1, 0, 0
@ -466,6 +476,11 @@ async def _index_selected_files(
Returns (indexed_count, skipped_count, errors).
"""
page_limit_service = PageLimitService(session)
pages_used, pages_limit = await page_limit_service.get_page_usage(user_id)
remaining_quota = pages_limit - pages_used
batch_estimated_pages = 0
files_to_download: list[dict] = []
errors: list[str] = []
renamed_count = 0
@ -486,6 +501,15 @@ async def _index_selected_files(
skipped += 1
continue
file_pages = PageLimitService.estimate_pages_from_metadata(
file.get("name", ""), file.get("size")
)
if batch_estimated_pages + file_pages > remaining_quota:
display = file_name or file_id
errors.append(f"File '{display}': page limit would be exceeded")
continue
batch_estimated_pages += file_pages
files_to_download.append(file)
await _create_drive_placeholders(
@ -507,6 +531,14 @@ async def _index_selected_files(
on_heartbeat=on_heartbeat,
)
if batch_indexed > 0 and files_to_download and batch_estimated_pages > 0:
pages_to_deduct = max(
1, batch_estimated_pages * batch_indexed // len(files_to_download)
)
await page_limit_service.update_page_usage(
user_id, pages_to_deduct, allow_exceed=True
)
return renamed_count + batch_indexed, skipped, errors
@ -545,6 +577,12 @@ async def _index_full_scan(
# ------------------------------------------------------------------
# Phase 1 (serial): collect files, run skip checks, track renames
# ------------------------------------------------------------------
page_limit_service = PageLimitService(session)
pages_used, pages_limit = await page_limit_service.get_page_usage(user_id)
remaining_quota = pages_limit - pages_used
batch_estimated_pages = 0
page_limit_reached = False
renamed_count = 0
skipped = 0
files_processed = 0
@ -593,6 +631,20 @@ async def _index_full_scan(
skipped += 1
continue
file_pages = PageLimitService.estimate_pages_from_metadata(
file.get("name", ""), file.get("size")
)
if batch_estimated_pages + file_pages > remaining_quota:
if not page_limit_reached:
logger.warning(
"Page limit reached during Google Drive full scan, "
"skipping remaining files"
)
page_limit_reached = True
skipped += 1
continue
batch_estimated_pages += file_pages
files_to_download.append(file)
page_token = next_token
@ -636,6 +688,14 @@ async def _index_full_scan(
on_heartbeat=on_heartbeat_callback,
)
if batch_indexed > 0 and files_to_download and batch_estimated_pages > 0:
pages_to_deduct = max(
1, batch_estimated_pages * batch_indexed // len(files_to_download)
)
await page_limit_service.update_page_usage(
user_id, pages_to_deduct, allow_exceed=True
)
indexed = renamed_count + batch_indexed
logger.info(
f"Full scan complete: {indexed} indexed, {skipped} skipped, {failed} failed"
@ -686,6 +746,12 @@ async def _index_with_delta_sync(
# ------------------------------------------------------------------
# Phase 1 (serial): handle removals, collect files for download
# ------------------------------------------------------------------
page_limit_service = PageLimitService(session)
pages_used, pages_limit = await page_limit_service.get_page_usage(user_id)
remaining_quota = pages_limit - pages_used
batch_estimated_pages = 0
page_limit_reached = False
renamed_count = 0
skipped = 0
files_to_download: list[dict] = []
@ -715,6 +781,20 @@ async def _index_with_delta_sync(
skipped += 1
continue
file_pages = PageLimitService.estimate_pages_from_metadata(
file.get("name", ""), file.get("size")
)
if batch_estimated_pages + file_pages > remaining_quota:
if not page_limit_reached:
logger.warning(
"Page limit reached during Google Drive delta sync, "
"skipping remaining files"
)
page_limit_reached = True
skipped += 1
continue
batch_estimated_pages += file_pages
files_to_download.append(file)
# ------------------------------------------------------------------
@ -742,6 +822,14 @@ async def _index_with_delta_sync(
on_heartbeat=on_heartbeat_callback,
)
if batch_indexed > 0 and files_to_download and batch_estimated_pages > 0:
pages_to_deduct = max(
1, batch_estimated_pages * batch_indexed // len(files_to_download)
)
await page_limit_service.update_page_usage(
user_id, pages_to_deduct, allow_exceed=True
)
indexed = renamed_count + batch_indexed
logger.info(
f"Delta sync complete: {indexed} indexed, {skipped} skipped, {failed} failed"

View file

@ -79,6 +79,7 @@ def _compute_final_pages(
actual = page_limit_service.estimate_pages_from_content_length(content_length)
return max(estimated_pages, actual)
DEFAULT_EXCLUDE_PATTERNS = [
".git",
"node_modules",

View file

@ -28,6 +28,7 @@ from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import compute_identifier_hash
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
from app.services.llm_service import get_user_long_context_llm
from app.services.page_limit_service import PageLimitService
from app.services.task_logging_service import TaskLoggingService
from app.tasks.connector_indexers.base import (
check_document_by_unique_identifier,
@ -291,6 +292,11 @@ async def _index_selected_files(
on_heartbeat: HeartbeatCallbackType | None = None,
) -> tuple[int, int, list[str]]:
"""Index user-selected files using the parallel pipeline."""
page_limit_service = PageLimitService(session)
pages_used, pages_limit = await page_limit_service.get_page_usage(user_id)
remaining_quota = pages_limit - pages_used
batch_estimated_pages = 0
files_to_download: list[dict] = []
errors: list[str] = []
renamed_count = 0
@ -311,6 +317,15 @@ async def _index_selected_files(
skipped += 1
continue
file_pages = PageLimitService.estimate_pages_from_metadata(
file.get("name", ""), file.get("size")
)
if batch_estimated_pages + file_pages > remaining_quota:
display = file_name or file_id
errors.append(f"File '{display}': page limit would be exceeded")
continue
batch_estimated_pages += file_pages
files_to_download.append(file)
batch_indexed, _failed = await _download_and_index(
@ -324,6 +339,14 @@ async def _index_selected_files(
on_heartbeat=on_heartbeat,
)
if batch_indexed > 0 and files_to_download and batch_estimated_pages > 0:
pages_to_deduct = max(
1, batch_estimated_pages * batch_indexed // len(files_to_download)
)
await page_limit_service.update_page_usage(
user_id, pages_to_deduct, allow_exceed=True
)
return renamed_count + batch_indexed, skipped, errors
@ -358,6 +381,12 @@ async def _index_full_scan(
},
)
page_limit_service = PageLimitService(session)
pages_used, pages_limit = await page_limit_service.get_page_usage(user_id)
remaining_quota = pages_limit - pages_used
batch_estimated_pages = 0
page_limit_reached = False
renamed_count = 0
skipped = 0
files_to_download: list[dict] = []
@ -383,6 +412,21 @@ async def _index_full_scan(
else:
skipped += 1
continue
file_pages = PageLimitService.estimate_pages_from_metadata(
file.get("name", ""), file.get("size")
)
if batch_estimated_pages + file_pages > remaining_quota:
if not page_limit_reached:
logger.warning(
"Page limit reached during OneDrive full scan, "
"skipping remaining files"
)
page_limit_reached = True
skipped += 1
continue
batch_estimated_pages += file_pages
files_to_download.append(file)
batch_indexed, failed = await _download_and_index(
@ -396,6 +440,14 @@ async def _index_full_scan(
on_heartbeat=on_heartbeat_callback,
)
if batch_indexed > 0 and files_to_download and batch_estimated_pages > 0:
pages_to_deduct = max(
1, batch_estimated_pages * batch_indexed // len(files_to_download)
)
await page_limit_service.update_page_usage(
user_id, pages_to_deduct, allow_exceed=True
)
indexed = renamed_count + batch_indexed
logger.info(
f"Full scan complete: {indexed} indexed, {skipped} skipped, {failed} failed"
@ -441,6 +493,12 @@ async def _index_with_delta_sync(
logger.info(f"Processing {len(changes)} delta changes")
page_limit_service = PageLimitService(session)
pages_used, pages_limit = await page_limit_service.get_page_usage(user_id)
remaining_quota = pages_limit - pages_used
batch_estimated_pages = 0
page_limit_reached = False
renamed_count = 0
skipped = 0
files_to_download: list[dict] = []
@ -471,6 +529,20 @@ async def _index_with_delta_sync(
skipped += 1
continue
file_pages = PageLimitService.estimate_pages_from_metadata(
change.get("name", ""), change.get("size")
)
if batch_estimated_pages + file_pages > remaining_quota:
if not page_limit_reached:
logger.warning(
"Page limit reached during OneDrive delta sync, "
"skipping remaining files"
)
page_limit_reached = True
skipped += 1
continue
batch_estimated_pages += file_pages
files_to_download.append(change)
batch_indexed, failed = await _download_and_index(
@ -484,6 +556,14 @@ async def _index_with_delta_sync(
on_heartbeat=on_heartbeat_callback,
)
if batch_indexed > 0 and files_to_download and batch_estimated_pages > 0:
pages_to_deduct = max(
1, batch_estimated_pages * batch_indexed // len(files_to_download)
)
await page_limit_service.update_page_usage(
user_id, pages_to_deduct, allow_exceed=True
)
indexed = renamed_count + batch_indexed
logger.info(
f"Delta sync complete: {indexed} indexed, {skipped} skipped, {failed} failed"

View file

@ -11,6 +11,8 @@ import hmac
import json
import logging
import time
from random import SystemRandom
from string import ascii_letters, digits
from uuid import UUID
from cryptography.fernet import Fernet
@ -18,6 +20,25 @@ from fastapi import HTTPException
logger = logging.getLogger(__name__)
_PKCE_CHARS = ascii_letters + digits + "-._~"
_PKCE_RNG = SystemRandom()
def generate_code_verifier(length: int = 128) -> str:
"""Generate a PKCE code_verifier (RFC 7636, 43-128 unreserved chars)."""
return "".join(_PKCE_RNG.choice(_PKCE_CHARS) for _ in range(length))
def generate_pkce_pair(length: int = 128) -> tuple[str, str]:
"""Generate a PKCE code_verifier and its S256 code_challenge."""
verifier = generate_code_verifier(length)
challenge = (
base64.urlsafe_b64encode(hashlib.sha256(verifier.encode()).digest())
.decode()
.rstrip("=")
)
return verifier, challenge
class OAuthStateManager:
"""Manages secure OAuth state parameters with HMAC signatures."""