mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 00:36:31 +02:00
Compare commits
43 commits
92d75ad622
...
ee043df942
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ee043df942 | ||
|
|
d95bce1bf1 | ||
|
|
1e0cf21d4a | ||
|
|
4231efc076 | ||
|
|
3d2e25cdf3 | ||
|
|
74bf3df880 | ||
|
|
13e302676e | ||
|
|
6b5b45d08d | ||
|
|
a180bf5576 | ||
|
|
d7315e7f27 | ||
|
|
18103417bb | ||
|
|
46e8134b23 | ||
|
|
82d4d3e272 | ||
|
|
e814540727 | ||
|
|
8e6b1c77ea | ||
|
|
09008c8f1a | ||
|
|
a2b3541046 | ||
|
|
0d2acc665d | ||
|
|
ce40da80ea | ||
|
|
0cd997f673 | ||
|
|
960b8fc012 | ||
|
|
080acf5e0a | ||
|
|
c5aa869adb | ||
|
|
af5977691b | ||
|
|
aeb3f13f91 | ||
|
|
ced7f7562a | ||
|
|
339ff7fdf4 | ||
|
|
482238e5d4 | ||
|
|
bb84bb25a3 | ||
|
|
fb20b0444f | ||
|
|
8ba571566d | ||
|
|
fc84dcffb0 | ||
|
|
c2bd2bc935 | ||
|
|
a99d999a36 | ||
|
|
3e68d4aa3e | ||
|
|
9c1d9357c4 | ||
|
|
6899134a20 | ||
|
|
b2706b00a1 | ||
|
|
eaabad38fc | ||
|
|
ec2b7851b6 | ||
|
|
bcc227a4dd | ||
|
|
fbd033d0a4 | ||
|
|
f4d197f702 |
60 changed files with 2763 additions and 260 deletions
|
|
@ -0,0 +1,39 @@
|
|||
"""119_add_vision_llm_id_to_search_spaces
|
||||
|
||||
Revision ID: 119
|
||||
Revises: 118
|
||||
|
||||
Adds vision_llm_id column to search_spaces for vision/screenshot analysis
|
||||
LLM role assignment. Defaults to 0 (Auto mode), same convention as
|
||||
agent_llm_id and document_summary_llm_id.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "119"
|
||||
down_revision: str | None = "118"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
existing_columns = [
|
||||
col["name"] for col in sa.inspect(conn).get_columns("searchspaces")
|
||||
]
|
||||
|
||||
if "vision_llm_id" not in existing_columns:
|
||||
op.add_column(
|
||||
"searchspaces",
|
||||
sa.Column("vision_llm_id", sa.Integer(), nullable=True, server_default="0"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("searchspaces", "vision_llm_id")
|
||||
|
|
@ -1351,6 +1351,9 @@ class SearchSpace(BaseModel, TimestampMixin):
|
|||
image_generation_config_id = Column(
|
||||
Integer, nullable=True, default=0
|
||||
) # For image generation, defaults to Auto mode
|
||||
vision_llm_id = Column(
|
||||
Integer, nullable=True, default=0
|
||||
) # For vision/screenshot analysis, defaults to Auto mode
|
||||
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from fastapi import APIRouter
|
|||
from .airtable_add_connector_route import (
|
||||
router as airtable_add_connector_router,
|
||||
)
|
||||
from .autocomplete_routes import router as autocomplete_router
|
||||
from .chat_comments_routes import router as chat_comments_router
|
||||
from .circleback_webhook_route import router as circleback_webhook_router
|
||||
from .clickup_add_connector_route import router as clickup_add_connector_router
|
||||
|
|
@ -95,3 +96,4 @@ router.include_router(incentive_tasks_router) # Incentive tasks for earning fre
|
|||
router.include_router(stripe_router) # Stripe checkout for additional page packs
|
||||
router.include_router(youtube_router) # YouTube playlist resolution
|
||||
router.include_router(prompts_router)
|
||||
router.include_router(autocomplete_router) # Lightweight autocomplete with KB context
|
||||
|
|
|
|||
|
|
@ -1,7 +1,5 @@
|
|||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from uuid import UUID
|
||||
|
||||
|
|
@ -26,7 +24,11 @@ from app.utils.connector_naming import (
|
|||
check_duplicate_connector,
|
||||
generate_unique_connector_name,
|
||||
)
|
||||
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
|
||||
from app.utils.oauth_security import (
|
||||
OAuthStateManager,
|
||||
TokenEncryption,
|
||||
generate_pkce_pair,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -75,28 +77,6 @@ def make_basic_auth_header(client_id: str, client_secret: str) -> str:
|
|||
return f"Basic {b64}"
|
||||
|
||||
|
||||
def generate_pkce_pair() -> tuple[str, str]:
|
||||
"""
|
||||
Generate PKCE code verifier and code challenge.
|
||||
|
||||
Returns:
|
||||
Tuple of (code_verifier, code_challenge)
|
||||
"""
|
||||
# Generate code verifier (43-128 characters)
|
||||
code_verifier = (
|
||||
base64.urlsafe_b64encode(secrets.token_bytes(32)).decode("utf-8").rstrip("=")
|
||||
)
|
||||
|
||||
# Generate code challenge (SHA256 hash of verifier, base64url encoded)
|
||||
code_challenge = (
|
||||
base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode("utf-8")).digest())
|
||||
.decode("utf-8")
|
||||
.rstrip("=")
|
||||
)
|
||||
|
||||
return code_verifier, code_challenge
|
||||
|
||||
|
||||
@router.get("/auth/airtable/connector/add")
|
||||
async def connect_airtable(space_id: int, user: User = Depends(current_active_user)):
|
||||
"""
|
||||
|
|
|
|||
42
surfsense_backend/app/routes/autocomplete_routes.py
Normal file
42
surfsense_backend/app/routes/autocomplete_routes.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import User, get_async_session
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
from app.services.vision_autocomplete_service import stream_vision_autocomplete
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_search_space_access
|
||||
|
||||
router = APIRouter(prefix="/autocomplete", tags=["autocomplete"])
|
||||
|
||||
MAX_SCREENSHOT_SIZE = 20 * 1024 * 1024 # 20 MB base64 ceiling
|
||||
|
||||
|
||||
class VisionAutocompleteRequest(BaseModel):
|
||||
screenshot: str = Field(..., max_length=MAX_SCREENSHOT_SIZE)
|
||||
search_space_id: int
|
||||
app_name: str = ""
|
||||
window_title: str = ""
|
||||
|
||||
|
||||
@router.post("/vision/stream")
|
||||
async def vision_autocomplete_stream(
|
||||
body: VisionAutocompleteRequest,
|
||||
user: User = Depends(current_active_user),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
await check_search_space_access(session, user, body.search_space_id)
|
||||
|
||||
return StreamingResponse(
|
||||
stream_vision_autocomplete(
|
||||
body.screenshot, body.search_space_id, session,
|
||||
app_name=body.app_name, window_title=body.window_title,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
**VercelStreamingService.get_response_headers(),
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
|
@ -28,7 +28,11 @@ from app.utils.connector_naming import (
|
|||
check_duplicate_connector,
|
||||
generate_unique_connector_name,
|
||||
)
|
||||
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
|
||||
from app.utils.oauth_security import (
|
||||
OAuthStateManager,
|
||||
TokenEncryption,
|
||||
generate_code_verifier,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -96,9 +100,14 @@ async def connect_calendar(space_id: int, user: User = Depends(current_active_us
|
|||
|
||||
flow = get_google_flow()
|
||||
|
||||
# Generate secure state parameter with HMAC signature
|
||||
code_verifier = generate_code_verifier()
|
||||
flow.code_verifier = code_verifier
|
||||
|
||||
# Generate secure state parameter with HMAC signature (includes PKCE code_verifier)
|
||||
state_manager = get_state_manager()
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id)
|
||||
state_encoded = state_manager.generate_secure_state(
|
||||
space_id, user.id, code_verifier=code_verifier
|
||||
)
|
||||
|
||||
auth_url, _ = flow.authorization_url(
|
||||
access_type="offline",
|
||||
|
|
@ -146,8 +155,11 @@ async def reauth_calendar(
|
|||
|
||||
flow = get_google_flow()
|
||||
|
||||
code_verifier = generate_code_verifier()
|
||||
flow.code_verifier = code_verifier
|
||||
|
||||
state_manager = get_state_manager()
|
||||
extra: dict = {"connector_id": connector_id}
|
||||
extra: dict = {"connector_id": connector_id, "code_verifier": code_verifier}
|
||||
if return_url and return_url.startswith("/"):
|
||||
extra["return_url"] = return_url
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)
|
||||
|
|
@ -225,6 +237,7 @@ async def calendar_callback(
|
|||
|
||||
user_id = UUID(data["user_id"])
|
||||
space_id = data["space_id"]
|
||||
code_verifier = data.get("code_verifier")
|
||||
|
||||
# Validate redirect URI (security: ensure it matches configured value)
|
||||
if not config.GOOGLE_CALENDAR_REDIRECT_URI:
|
||||
|
|
@ -233,6 +246,7 @@ async def calendar_callback(
|
|||
)
|
||||
|
||||
flow = get_google_flow()
|
||||
flow.code_verifier = code_verifier
|
||||
flow.fetch_token(code=code)
|
||||
|
||||
creds = flow.credentials
|
||||
|
|
|
|||
|
|
@ -41,7 +41,11 @@ from app.utils.connector_naming import (
|
|||
check_duplicate_connector,
|
||||
generate_unique_connector_name,
|
||||
)
|
||||
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
|
||||
from app.utils.oauth_security import (
|
||||
OAuthStateManager,
|
||||
TokenEncryption,
|
||||
generate_code_verifier,
|
||||
)
|
||||
|
||||
# Relax token scope validation for Google OAuth
|
||||
os.environ["OAUTHLIB_RELAX_TOKEN_SCOPE"] = "1"
|
||||
|
|
@ -127,14 +131,19 @@ async def connect_drive(space_id: int, user: User = Depends(current_active_user)
|
|||
|
||||
flow = get_google_flow()
|
||||
|
||||
# Generate secure state parameter with HMAC signature
|
||||
code_verifier = generate_code_verifier()
|
||||
flow.code_verifier = code_verifier
|
||||
|
||||
# Generate secure state parameter with HMAC signature (includes PKCE code_verifier)
|
||||
state_manager = get_state_manager()
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id)
|
||||
state_encoded = state_manager.generate_secure_state(
|
||||
space_id, user.id, code_verifier=code_verifier
|
||||
)
|
||||
|
||||
# Generate authorization URL
|
||||
auth_url, _ = flow.authorization_url(
|
||||
access_type="offline", # Get refresh token
|
||||
prompt="consent", # Force consent screen to get refresh token
|
||||
access_type="offline",
|
||||
prompt="consent",
|
||||
include_granted_scopes="true",
|
||||
state=state_encoded,
|
||||
)
|
||||
|
|
@ -193,8 +202,11 @@ async def reauth_drive(
|
|||
|
||||
flow = get_google_flow()
|
||||
|
||||
code_verifier = generate_code_verifier()
|
||||
flow.code_verifier = code_verifier
|
||||
|
||||
state_manager = get_state_manager()
|
||||
extra: dict = {"connector_id": connector_id}
|
||||
extra: dict = {"connector_id": connector_id, "code_verifier": code_verifier}
|
||||
if return_url and return_url.startswith("/"):
|
||||
extra["return_url"] = return_url
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)
|
||||
|
|
@ -285,6 +297,7 @@ async def drive_callback(
|
|||
space_id = data["space_id"]
|
||||
reauth_connector_id = data.get("connector_id")
|
||||
reauth_return_url = data.get("return_url")
|
||||
code_verifier = data.get("code_verifier")
|
||||
|
||||
logger.info(
|
||||
f"Processing Google Drive callback for user {user_id}, space {space_id}"
|
||||
|
|
@ -296,8 +309,9 @@ async def drive_callback(
|
|||
status_code=500, detail="GOOGLE_DRIVE_REDIRECT_URI not configured"
|
||||
)
|
||||
|
||||
# Exchange authorization code for tokens
|
||||
# Exchange authorization code for tokens (restore PKCE code_verifier from state)
|
||||
flow = get_google_flow()
|
||||
flow.code_verifier = code_verifier
|
||||
flow.fetch_token(code=code)
|
||||
|
||||
creds = flow.credentials
|
||||
|
|
|
|||
|
|
@ -28,7 +28,11 @@ from app.utils.connector_naming import (
|
|||
check_duplicate_connector,
|
||||
generate_unique_connector_name,
|
||||
)
|
||||
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
|
||||
from app.utils.oauth_security import (
|
||||
OAuthStateManager,
|
||||
TokenEncryption,
|
||||
generate_code_verifier,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -109,9 +113,14 @@ async def connect_gmail(space_id: int, user: User = Depends(current_active_user)
|
|||
|
||||
flow = get_google_flow()
|
||||
|
||||
# Generate secure state parameter with HMAC signature
|
||||
code_verifier = generate_code_verifier()
|
||||
flow.code_verifier = code_verifier
|
||||
|
||||
# Generate secure state parameter with HMAC signature (includes PKCE code_verifier)
|
||||
state_manager = get_state_manager()
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id)
|
||||
state_encoded = state_manager.generate_secure_state(
|
||||
space_id, user.id, code_verifier=code_verifier
|
||||
)
|
||||
|
||||
auth_url, _ = flow.authorization_url(
|
||||
access_type="offline",
|
||||
|
|
@ -164,8 +173,11 @@ async def reauth_gmail(
|
|||
|
||||
flow = get_google_flow()
|
||||
|
||||
code_verifier = generate_code_verifier()
|
||||
flow.code_verifier = code_verifier
|
||||
|
||||
state_manager = get_state_manager()
|
||||
extra: dict = {"connector_id": connector_id}
|
||||
extra: dict = {"connector_id": connector_id, "code_verifier": code_verifier}
|
||||
if return_url and return_url.startswith("/"):
|
||||
extra["return_url"] = return_url
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)
|
||||
|
|
@ -256,6 +268,7 @@ async def gmail_callback(
|
|||
|
||||
user_id = UUID(data["user_id"])
|
||||
space_id = data["space_id"]
|
||||
code_verifier = data.get("code_verifier")
|
||||
|
||||
# Validate redirect URI (security: ensure it matches configured value)
|
||||
if not config.GOOGLE_GMAIL_REDIRECT_URI:
|
||||
|
|
@ -264,6 +277,7 @@ async def gmail_callback(
|
|||
)
|
||||
|
||||
flow = get_google_flow()
|
||||
flow.code_verifier = code_verifier
|
||||
flow.fetch_token(code=code)
|
||||
|
||||
creds = flow.credentials
|
||||
|
|
|
|||
|
|
@ -522,14 +522,17 @@ async def get_llm_preferences(
|
|||
image_generation_config = await _get_image_gen_config_by_id(
|
||||
session, search_space.image_generation_config_id
|
||||
)
|
||||
vision_llm = await _get_llm_config_by_id(session, search_space.vision_llm_id)
|
||||
|
||||
return LLMPreferencesRead(
|
||||
agent_llm_id=search_space.agent_llm_id,
|
||||
document_summary_llm_id=search_space.document_summary_llm_id,
|
||||
image_generation_config_id=search_space.image_generation_config_id,
|
||||
vision_llm_id=search_space.vision_llm_id,
|
||||
agent_llm=agent_llm,
|
||||
document_summary_llm=document_summary_llm,
|
||||
image_generation_config=image_generation_config,
|
||||
vision_llm=vision_llm,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
|
|
@ -589,14 +592,17 @@ async def update_llm_preferences(
|
|||
image_generation_config = await _get_image_gen_config_by_id(
|
||||
session, search_space.image_generation_config_id
|
||||
)
|
||||
vision_llm = await _get_llm_config_by_id(session, search_space.vision_llm_id)
|
||||
|
||||
return LLMPreferencesRead(
|
||||
agent_llm_id=search_space.agent_llm_id,
|
||||
document_summary_llm_id=search_space.document_summary_llm_id,
|
||||
image_generation_config_id=search_space.image_generation_config_id,
|
||||
vision_llm_id=search_space.vision_llm_id,
|
||||
agent_llm=agent_llm,
|
||||
document_summary_llm=document_summary_llm,
|
||||
image_generation_config=image_generation_config,
|
||||
vision_llm=vision_llm,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
|
|
|
|||
|
|
@ -182,6 +182,9 @@ class LLMPreferencesRead(BaseModel):
|
|||
image_generation_config_id: int | None = Field(
|
||||
None, description="ID of the image generation config to use"
|
||||
)
|
||||
vision_llm_id: int | None = Field(
|
||||
None, description="ID of the LLM config to use for vision/screenshot analysis"
|
||||
)
|
||||
agent_llm: dict[str, Any] | None = Field(
|
||||
None, description="Full config for agent LLM"
|
||||
)
|
||||
|
|
@ -191,6 +194,9 @@ class LLMPreferencesRead(BaseModel):
|
|||
image_generation_config: dict[str, Any] | None = Field(
|
||||
None, description="Full config for image generation"
|
||||
)
|
||||
vision_llm: dict[str, Any] | None = Field(
|
||||
None, description="Full config for vision LLM"
|
||||
)
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
|
@ -207,3 +213,6 @@ class LLMPreferencesUpdate(BaseModel):
|
|||
image_generation_config_id: int | None = Field(
|
||||
None, description="ID of the image generation config to use"
|
||||
)
|
||||
vision_llm_id: int | None = Field(
|
||||
None, description="ID of the LLM config to use for vision/screenshot analysis"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ logger = logging.getLogger(__name__)
|
|||
class LLMRole:
|
||||
AGENT = "agent" # For agent/chat operations
|
||||
DOCUMENT_SUMMARY = "document_summary" # For document summarization
|
||||
VISION = "vision" # For vision/screenshot analysis
|
||||
|
||||
|
||||
def get_global_llm_config(llm_config_id: int) -> dict | None:
|
||||
|
|
@ -187,7 +188,7 @@ async def get_search_space_llm_instance(
|
|||
Args:
|
||||
session: Database session
|
||||
search_space_id: Search Space ID
|
||||
role: LLM role ('agent' or 'document_summary')
|
||||
role: LLM role ('agent', 'document_summary', or 'vision')
|
||||
|
||||
Returns:
|
||||
ChatLiteLLM or ChatLiteLLMRouter instance, or None if not found
|
||||
|
|
@ -209,6 +210,8 @@ async def get_search_space_llm_instance(
|
|||
llm_config_id = search_space.agent_llm_id
|
||||
elif role == LLMRole.DOCUMENT_SUMMARY:
|
||||
llm_config_id = search_space.document_summary_llm_id
|
||||
elif role == LLMRole.VISION:
|
||||
llm_config_id = search_space.vision_llm_id
|
||||
else:
|
||||
logger.error(f"Invalid LLM role: {role}")
|
||||
return None
|
||||
|
|
@ -405,6 +408,13 @@ async def get_document_summary_llm(
|
|||
)
|
||||
|
||||
|
||||
async def get_vision_llm(
|
||||
session: AsyncSession, search_space_id: int
|
||||
) -> ChatLiteLLM | ChatLiteLLMRouter | None:
|
||||
"""Get the search space's vision LLM instance for screenshot analysis."""
|
||||
return await get_search_space_llm_instance(session, search_space_id, LLMRole.VISION)
|
||||
|
||||
|
||||
# Backward-compatible alias (LLM preferences are now per-search-space, not per-user)
|
||||
async def get_user_long_context_llm(
|
||||
session: AsyncSession,
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ Service for managing user page limits for ETL services.
|
|||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from pathlib import Path, PurePosixPath
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
|
@ -223,10 +223,155 @@ class PageLimitService:
|
|||
# Estimate ~2000 characters per page
|
||||
return max(1, content_length // 2000)
|
||||
|
||||
@staticmethod
|
||||
def estimate_pages_from_metadata(
|
||||
file_name_or_ext: str, file_size: int | str | None = None
|
||||
) -> int:
|
||||
"""Size-based page estimation from file name/extension and byte size.
|
||||
|
||||
Pure function — no file I/O, no database access. Used by cloud
|
||||
connectors (which only have API metadata) and as the internal
|
||||
fallback for :meth:`estimate_pages_before_processing`.
|
||||
|
||||
``file_name_or_ext`` can be a full filename (``"report.pdf"``) or
|
||||
a bare extension (``".pdf"``). ``file_size`` may be an int, a
|
||||
stringified int from a cloud API, or *None*.
|
||||
"""
|
||||
if file_size is not None:
|
||||
try:
|
||||
file_size = int(file_size)
|
||||
except (ValueError, TypeError):
|
||||
file_size = 0
|
||||
else:
|
||||
file_size = 0
|
||||
|
||||
if file_size <= 0:
|
||||
return 1
|
||||
|
||||
ext = PurePosixPath(file_name_or_ext).suffix.lower() if file_name_or_ext else ""
|
||||
if not ext and file_name_or_ext.startswith("."):
|
||||
ext = file_name_or_ext.lower()
|
||||
file_ext = ext
|
||||
|
||||
if file_ext == ".pdf":
|
||||
return max(1, file_size // (100 * 1024))
|
||||
|
||||
if file_ext in {
|
||||
".doc",
|
||||
".docx",
|
||||
".docm",
|
||||
".dot",
|
||||
".dotm",
|
||||
".odt",
|
||||
".ott",
|
||||
".sxw",
|
||||
".stw",
|
||||
".uot",
|
||||
".rtf",
|
||||
".pages",
|
||||
".wpd",
|
||||
".wps",
|
||||
".abw",
|
||||
".zabw",
|
||||
".cwk",
|
||||
".hwp",
|
||||
".lwp",
|
||||
".mcw",
|
||||
".mw",
|
||||
".sdw",
|
||||
".vor",
|
||||
}:
|
||||
return max(1, file_size // (50 * 1024))
|
||||
|
||||
if file_ext in {
|
||||
".ppt",
|
||||
".pptx",
|
||||
".pptm",
|
||||
".pot",
|
||||
".potx",
|
||||
".odp",
|
||||
".otp",
|
||||
".sxi",
|
||||
".sti",
|
||||
".uop",
|
||||
".key",
|
||||
".sda",
|
||||
".sdd",
|
||||
".sdp",
|
||||
}:
|
||||
return max(1, file_size // (200 * 1024))
|
||||
|
||||
if file_ext in {
|
||||
".xls",
|
||||
".xlsx",
|
||||
".xlsm",
|
||||
".xlsb",
|
||||
".xlw",
|
||||
".xlr",
|
||||
".ods",
|
||||
".ots",
|
||||
".fods",
|
||||
".numbers",
|
||||
".123",
|
||||
".wk1",
|
||||
".wk2",
|
||||
".wk3",
|
||||
".wk4",
|
||||
".wks",
|
||||
".wb1",
|
||||
".wb2",
|
||||
".wb3",
|
||||
".wq1",
|
||||
".wq2",
|
||||
".csv",
|
||||
".tsv",
|
||||
".slk",
|
||||
".sylk",
|
||||
".dif",
|
||||
".dbf",
|
||||
".prn",
|
||||
".qpw",
|
||||
".602",
|
||||
".et",
|
||||
".eth",
|
||||
}:
|
||||
return max(1, file_size // (100 * 1024))
|
||||
|
||||
if file_ext in {".epub"}:
|
||||
return max(1, file_size // (50 * 1024))
|
||||
|
||||
if file_ext in {".txt", ".log", ".md", ".markdown", ".htm", ".html", ".xml"}:
|
||||
return max(1, file_size // 3000)
|
||||
|
||||
if file_ext in {
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".png",
|
||||
".gif",
|
||||
".bmp",
|
||||
".tiff",
|
||||
".webp",
|
||||
".svg",
|
||||
".cgm",
|
||||
".odg",
|
||||
".pbd",
|
||||
}:
|
||||
return 1
|
||||
|
||||
if file_ext in {".mp3", ".m4a", ".wav", ".mpga"}:
|
||||
return max(1, file_size // (1024 * 1024))
|
||||
|
||||
if file_ext in {".mp4", ".mpeg", ".webm"}:
|
||||
return max(1, file_size // (5 * 1024 * 1024))
|
||||
|
||||
return max(1, file_size // (80 * 1024))
|
||||
|
||||
def estimate_pages_before_processing(self, file_path: str) -> int:
|
||||
"""
|
||||
Estimate page count from file before processing (to avoid unnecessary API calls).
|
||||
This is called BEFORE sending to ETL services to prevent cost on rejected files.
|
||||
Estimate page count from a local file before processing.
|
||||
|
||||
For PDFs, attempts to read the actual page count via pypdf.
|
||||
For everything else, delegates to :meth:`estimate_pages_from_metadata`.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file
|
||||
|
|
@ -240,7 +385,6 @@ class PageLimitService:
|
|||
file_ext = Path(file_path).suffix.lower()
|
||||
file_size = os.path.getsize(file_path)
|
||||
|
||||
# PDF files - try to get actual page count
|
||||
if file_ext == ".pdf":
|
||||
try:
|
||||
import pypdf
|
||||
|
|
@ -249,153 +393,6 @@ class PageLimitService:
|
|||
pdf_reader = pypdf.PdfReader(f)
|
||||
return len(pdf_reader.pages)
|
||||
except Exception:
|
||||
# If PDF reading fails, fall back to size estimation
|
||||
# Typical PDF: ~100KB per page (conservative estimate)
|
||||
return max(1, file_size // (100 * 1024))
|
||||
pass # fall through to size-based estimation
|
||||
|
||||
# Word Processing Documents
|
||||
# Microsoft Word, LibreOffice Writer, WordPerfect, Pages, etc.
|
||||
elif file_ext in [
|
||||
".doc",
|
||||
".docx",
|
||||
".docm",
|
||||
".dot",
|
||||
".dotm", # Microsoft Word
|
||||
".odt",
|
||||
".ott",
|
||||
".sxw",
|
||||
".stw",
|
||||
".uot", # OpenDocument/StarOffice Writer
|
||||
".rtf", # Rich Text Format
|
||||
".pages", # Apple Pages
|
||||
".wpd",
|
||||
".wps", # WordPerfect, Microsoft Works
|
||||
".abw",
|
||||
".zabw", # AbiWord
|
||||
".cwk",
|
||||
".hwp",
|
||||
".lwp",
|
||||
".mcw",
|
||||
".mw",
|
||||
".sdw",
|
||||
".vor", # Other word processors
|
||||
]:
|
||||
# Typical word document: ~50KB per page (conservative)
|
||||
return max(1, file_size // (50 * 1024))
|
||||
|
||||
# Presentation Documents
|
||||
# PowerPoint, Impress, Keynote, etc.
|
||||
elif file_ext in [
|
||||
".ppt",
|
||||
".pptx",
|
||||
".pptm",
|
||||
".pot",
|
||||
".potx", # Microsoft PowerPoint
|
||||
".odp",
|
||||
".otp",
|
||||
".sxi",
|
||||
".sti",
|
||||
".uop", # OpenDocument/StarOffice Impress
|
||||
".key", # Apple Keynote
|
||||
".sda",
|
||||
".sdd",
|
||||
".sdp", # StarOffice Draw/Impress
|
||||
]:
|
||||
# Typical presentation: ~200KB per slide (conservative)
|
||||
return max(1, file_size // (200 * 1024))
|
||||
|
||||
# Spreadsheet Documents
|
||||
# Excel, Calc, Numbers, Lotus, etc.
|
||||
elif file_ext in [
|
||||
".xls",
|
||||
".xlsx",
|
||||
".xlsm",
|
||||
".xlsb",
|
||||
".xlw",
|
||||
".xlr", # Microsoft Excel
|
||||
".ods",
|
||||
".ots",
|
||||
".fods", # OpenDocument Spreadsheet
|
||||
".numbers", # Apple Numbers
|
||||
".123",
|
||||
".wk1",
|
||||
".wk2",
|
||||
".wk3",
|
||||
".wk4",
|
||||
".wks", # Lotus 1-2-3
|
||||
".wb1",
|
||||
".wb2",
|
||||
".wb3",
|
||||
".wq1",
|
||||
".wq2", # Quattro Pro
|
||||
".csv",
|
||||
".tsv",
|
||||
".slk",
|
||||
".sylk",
|
||||
".dif",
|
||||
".dbf",
|
||||
".prn",
|
||||
".qpw", # Data formats
|
||||
".602",
|
||||
".et",
|
||||
".eth", # Other spreadsheets
|
||||
]:
|
||||
# Spreadsheets typically have 1 sheet = 1 page for ETL
|
||||
# Conservative: ~100KB per sheet
|
||||
return max(1, file_size // (100 * 1024))
|
||||
|
||||
# E-books
|
||||
elif file_ext in [".epub"]:
|
||||
# E-books vary widely, estimate by size
|
||||
# Typical e-book: ~50KB per page
|
||||
return max(1, file_size // (50 * 1024))
|
||||
|
||||
# Plain Text and Markup Files
|
||||
elif file_ext in [
|
||||
".txt",
|
||||
".log", # Plain text
|
||||
".md",
|
||||
".markdown", # Markdown
|
||||
".htm",
|
||||
".html",
|
||||
".xml", # Markup
|
||||
]:
|
||||
# Plain text: ~3000 bytes per page
|
||||
return max(1, file_size // 3000)
|
||||
|
||||
# Image Files
|
||||
# Each image is typically processed as 1 page
|
||||
elif file_ext in [
|
||||
".jpg",
|
||||
".jpeg", # JPEG
|
||||
".png", # PNG
|
||||
".gif", # GIF
|
||||
".bmp", # Bitmap
|
||||
".tiff", # TIFF
|
||||
".webp", # WebP
|
||||
".svg", # SVG
|
||||
".cgm", # Computer Graphics Metafile
|
||||
".odg",
|
||||
".pbd", # OpenDocument Graphics
|
||||
]:
|
||||
# Each image = 1 page
|
||||
return 1
|
||||
|
||||
# Audio Files (transcription = typically 1 page per minute)
|
||||
# Note: These should be handled by audio transcription flow, not ETL
|
||||
elif file_ext in [".mp3", ".m4a", ".wav", ".mpga"]:
|
||||
# Audio files: estimate based on duration
|
||||
# Fallback: ~1MB per minute of audio, 1 page per minute transcript
|
||||
return max(1, file_size // (1024 * 1024))
|
||||
|
||||
# Video Files (typically not processed for pages, but just in case)
|
||||
elif file_ext in [".mp4", ".mpeg", ".webm"]:
|
||||
# Video files: very rough estimate
|
||||
# Typically wouldn't be page-based, but use conservative estimate
|
||||
return max(1, file_size // (5 * 1024 * 1024))
|
||||
|
||||
# Other/Unknown Document Types
|
||||
else:
|
||||
# Conservative estimate: ~80KB per page
|
||||
# This catches: .sgl, .sxg, .uof, .uos1, .uos2, .web, and any future formats
|
||||
return max(1, file_size // (80 * 1024))
|
||||
return self.estimate_pages_from_metadata(file_ext, file_size)
|
||||
|
|
|
|||
225
surfsense_backend/app/services/vision_autocomplete_service.py
Normal file
225
surfsense_backend/app/services/vision_autocomplete_service.py
Normal file
|
|
@ -0,0 +1,225 @@
|
|||
import logging
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever
|
||||
from app.services.llm_service import get_vision_llm
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
KB_TOP_K = 5
|
||||
KB_MAX_CHARS = 4000
|
||||
|
||||
EXTRACT_QUERY_PROMPT = """Look at this screenshot and describe in 1-2 short sentences what the user is working on and what topic they need to write about. Be specific about the subject matter. Output ONLY the description, nothing else."""
|
||||
|
||||
EXTRACT_QUERY_PROMPT_WITH_APP = """The user is currently in the application "{app_name}" with the window titled "{window_title}".
|
||||
|
||||
Look at this screenshot and describe in 1-2 short sentences what the user is working on and what topic they need to write about. Be specific about the subject matter. Output ONLY the description, nothing else."""
|
||||
|
||||
VISION_SYSTEM_PROMPT = """You are a smart writing assistant that analyzes the user's screen to draft or complete text.
|
||||
|
||||
You will receive a screenshot of the user's screen. Your job:
|
||||
1. Analyze the ENTIRE screenshot to understand what the user is working on (email thread, chat conversation, document, code editor, form, etc.).
|
||||
2. Identify the text area where the user will type.
|
||||
3. Based on the full visual context, generate the text the user most likely wants to write.
|
||||
|
||||
Key behavior:
|
||||
- If the text area is EMPTY, draft a full response or message based on what you see on screen (e.g., reply to an email, respond to a chat message, continue a document).
|
||||
- If the text area already has text, continue it naturally.
|
||||
|
||||
Rules:
|
||||
- Output ONLY the text to be inserted. No quotes, no explanations, no meta-commentary.
|
||||
- Be concise but complete — a full thought, not a fragment.
|
||||
- Match the tone and formality of the surrounding context.
|
||||
- If the screen shows code, write code. If it shows a casual chat, be casual. If it shows a formal email, be formal.
|
||||
- Do NOT describe the screenshot or explain your reasoning.
|
||||
- If you cannot determine what to write, output nothing."""
|
||||
|
||||
APP_CONTEXT_BLOCK = """
|
||||
|
||||
The user is currently working in "{app_name}" (window: "{window_title}"). Use this to understand the type of application and adapt your tone and format accordingly."""
|
||||
|
||||
KB_CONTEXT_BLOCK = """
|
||||
|
||||
You also have access to the user's knowledge base documents below. Use them to write more accurate, informed, and contextually relevant text. Do NOT cite or reference the documents explicitly — just let the knowledge inform your writing naturally.
|
||||
|
||||
<knowledge_base>
|
||||
{kb_context}
|
||||
</knowledge_base>"""
|
||||
|
||||
|
||||
def _build_system_prompt(app_name: str, window_title: str, kb_context: str) -> str:
|
||||
"""Assemble the system prompt from optional context blocks."""
|
||||
prompt = VISION_SYSTEM_PROMPT
|
||||
if app_name:
|
||||
prompt += APP_CONTEXT_BLOCK.format(app_name=app_name, window_title=window_title)
|
||||
if kb_context:
|
||||
prompt += KB_CONTEXT_BLOCK.format(kb_context=kb_context)
|
||||
return prompt
|
||||
|
||||
|
||||
def _is_vision_unsupported_error(e: Exception) -> bool:
|
||||
"""Check if an exception indicates the model doesn't support vision/images."""
|
||||
msg = str(e).lower()
|
||||
return "content must be a string" in msg or "does not support image" in msg
|
||||
|
||||
|
||||
async def _extract_query_from_screenshot(
|
||||
llm, screenshot_data_url: str,
|
||||
app_name: str = "", window_title: str = "",
|
||||
) -> str | None:
|
||||
"""Ask the Vision LLM to describe what the user is working on.
|
||||
|
||||
Raises vision-unsupported errors so the caller can return a
|
||||
friendly message immediately instead of retrying with astream.
|
||||
"""
|
||||
if app_name:
|
||||
prompt_text = EXTRACT_QUERY_PROMPT_WITH_APP.format(
|
||||
app_name=app_name, window_title=window_title,
|
||||
)
|
||||
else:
|
||||
prompt_text = EXTRACT_QUERY_PROMPT
|
||||
|
||||
try:
|
||||
response = await llm.ainvoke([
|
||||
HumanMessage(content=[
|
||||
{"type": "text", "text": prompt_text},
|
||||
{"type": "image_url", "image_url": {"url": screenshot_data_url}},
|
||||
]),
|
||||
])
|
||||
query = response.content.strip() if hasattr(response, "content") else ""
|
||||
return query if query else None
|
||||
except Exception as e:
|
||||
if _is_vision_unsupported_error(e):
|
||||
raise
|
||||
logger.warning(f"Failed to extract query from screenshot: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def _search_knowledge_base(
|
||||
session: AsyncSession, search_space_id: int, query: str
|
||||
) -> str:
|
||||
"""Search the KB and return formatted context string."""
|
||||
try:
|
||||
retriever = ChucksHybridSearchRetriever(session)
|
||||
results = await retriever.hybrid_search(
|
||||
query_text=query,
|
||||
top_k=KB_TOP_K,
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
|
||||
if not results:
|
||||
return ""
|
||||
|
||||
parts: list[str] = []
|
||||
char_count = 0
|
||||
for doc in results:
|
||||
title = doc.get("document", {}).get("title", "Untitled")
|
||||
for chunk in doc.get("chunks", []):
|
||||
content = chunk.get("content", "").strip()
|
||||
if not content:
|
||||
continue
|
||||
entry = f"[{title}]\n{content}"
|
||||
if char_count + len(entry) > KB_MAX_CHARS:
|
||||
break
|
||||
parts.append(entry)
|
||||
char_count += len(entry)
|
||||
if char_count >= KB_MAX_CHARS:
|
||||
break
|
||||
|
||||
return "\n\n---\n\n".join(parts)
|
||||
except Exception as e:
|
||||
logger.warning(f"KB search failed, proceeding without context: {e}")
|
||||
return ""
|
||||
|
||||
|
||||
async def stream_vision_autocomplete(
|
||||
screenshot_data_url: str,
|
||||
search_space_id: int,
|
||||
session: AsyncSession,
|
||||
*,
|
||||
app_name: str = "",
|
||||
window_title: str = "",
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Analyze a screenshot with the vision LLM and stream a text completion.
|
||||
|
||||
Pipeline:
|
||||
1. Extract a search query from the screenshot (non-streaming)
|
||||
2. Search the knowledge base for relevant context
|
||||
3. Stream the final completion with screenshot + KB + app context
|
||||
"""
|
||||
streaming = VercelStreamingService()
|
||||
vision_error_msg = (
|
||||
"The selected model does not support vision. "
|
||||
"Please set a vision-capable model (e.g. GPT-4o, Gemini) in your search space settings."
|
||||
)
|
||||
|
||||
llm = await get_vision_llm(session, search_space_id)
|
||||
if not llm:
|
||||
yield streaming.format_message_start()
|
||||
yield streaming.format_error("No Vision LLM configured for this search space")
|
||||
yield streaming.format_done()
|
||||
return
|
||||
|
||||
kb_context = ""
|
||||
try:
|
||||
query = await _extract_query_from_screenshot(
|
||||
llm, screenshot_data_url, app_name=app_name, window_title=window_title,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Vision autocomplete: selected model does not support vision: {e}")
|
||||
yield streaming.format_message_start()
|
||||
yield streaming.format_error(vision_error_msg)
|
||||
yield streaming.format_done()
|
||||
return
|
||||
|
||||
if query:
|
||||
kb_context = await _search_knowledge_base(session, search_space_id, query)
|
||||
|
||||
system_prompt = _build_system_prompt(app_name, window_title, kb_context)
|
||||
|
||||
messages = [
|
||||
SystemMessage(content=system_prompt),
|
||||
HumanMessage(content=[
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Analyze this screenshot. Understand the full context of what the user is working on, then generate the text they most likely want to write in the active text area.",
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": screenshot_data_url},
|
||||
},
|
||||
]),
|
||||
]
|
||||
|
||||
text_started = False
|
||||
text_id = ""
|
||||
try:
|
||||
yield streaming.format_message_start()
|
||||
text_id = streaming.generate_text_id()
|
||||
yield streaming.format_text_start(text_id)
|
||||
text_started = True
|
||||
|
||||
async for chunk in llm.astream(messages):
|
||||
token = chunk.content if hasattr(chunk, "content") else str(chunk)
|
||||
if token:
|
||||
yield streaming.format_text_delta(text_id, token)
|
||||
|
||||
yield streaming.format_text_end(text_id)
|
||||
yield streaming.format_finish()
|
||||
yield streaming.format_done()
|
||||
|
||||
except Exception as e:
|
||||
if text_started:
|
||||
yield streaming.format_text_end(text_id)
|
||||
|
||||
if _is_vision_unsupported_error(e):
|
||||
logger.warning(f"Vision autocomplete: selected model does not support vision: {e}")
|
||||
yield streaming.format_error(vision_error_msg)
|
||||
else:
|
||||
logger.error(f"Vision autocomplete streaming error: {e}", exc_info=True)
|
||||
yield streaming.format_error("Autocomplete failed. Please try again.")
|
||||
yield streaming.format_done()
|
||||
|
|
@ -28,6 +28,7 @@ from app.indexing_pipeline.connector_document import ConnectorDocument
|
|||
from app.indexing_pipeline.document_hashing import compute_identifier_hash
|
||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.page_limit_service import PageLimitService
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.tasks.connector_indexers.base import (
|
||||
check_document_by_unique_identifier,
|
||||
|
|
@ -278,6 +279,12 @@ async def _index_full_scan(
|
|||
},
|
||||
)
|
||||
|
||||
page_limit_service = PageLimitService(session)
|
||||
pages_used, pages_limit = await page_limit_service.get_page_usage(user_id)
|
||||
remaining_quota = pages_limit - pages_used
|
||||
batch_estimated_pages = 0
|
||||
page_limit_reached = False
|
||||
|
||||
renamed_count = 0
|
||||
skipped = 0
|
||||
files_to_download: list[dict] = []
|
||||
|
|
@ -307,6 +314,21 @@ async def _index_full_scan(
|
|||
elif skip_item(file):
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
file_pages = PageLimitService.estimate_pages_from_metadata(
|
||||
file.get("name", ""), file.get("size")
|
||||
)
|
||||
if batch_estimated_pages + file_pages > remaining_quota:
|
||||
if not page_limit_reached:
|
||||
logger.warning(
|
||||
"Page limit reached during Dropbox full scan, "
|
||||
"skipping remaining files"
|
||||
)
|
||||
page_limit_reached = True
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
batch_estimated_pages += file_pages
|
||||
files_to_download.append(file)
|
||||
|
||||
batch_indexed, failed = await _download_and_index(
|
||||
|
|
@ -320,6 +342,14 @@ async def _index_full_scan(
|
|||
on_heartbeat=on_heartbeat_callback,
|
||||
)
|
||||
|
||||
if batch_indexed > 0 and files_to_download and batch_estimated_pages > 0:
|
||||
pages_to_deduct = max(
|
||||
1, batch_estimated_pages * batch_indexed // len(files_to_download)
|
||||
)
|
||||
await page_limit_service.update_page_usage(
|
||||
user_id, pages_to_deduct, allow_exceed=True
|
||||
)
|
||||
|
||||
indexed = renamed_count + batch_indexed
|
||||
logger.info(
|
||||
f"Full scan complete: {indexed} indexed, {skipped} skipped, {failed} failed"
|
||||
|
|
@ -340,6 +370,11 @@ async def _index_selected_files(
|
|||
on_heartbeat: HeartbeatCallbackType | None = None,
|
||||
) -> tuple[int, int, list[str]]:
|
||||
"""Index user-selected files using the parallel pipeline."""
|
||||
page_limit_service = PageLimitService(session)
|
||||
pages_used, pages_limit = await page_limit_service.get_page_usage(user_id)
|
||||
remaining_quota = pages_limit - pages_used
|
||||
batch_estimated_pages = 0
|
||||
|
||||
files_to_download: list[dict] = []
|
||||
errors: list[str] = []
|
||||
renamed_count = 0
|
||||
|
|
@ -364,6 +399,15 @@ async def _index_selected_files(
|
|||
skipped += 1
|
||||
continue
|
||||
|
||||
file_pages = PageLimitService.estimate_pages_from_metadata(
|
||||
file.get("name", ""), file.get("size")
|
||||
)
|
||||
if batch_estimated_pages + file_pages > remaining_quota:
|
||||
display = file_name or file_path
|
||||
errors.append(f"File '{display}': page limit would be exceeded")
|
||||
continue
|
||||
|
||||
batch_estimated_pages += file_pages
|
||||
files_to_download.append(file)
|
||||
|
||||
batch_indexed, _failed = await _download_and_index(
|
||||
|
|
@ -377,6 +421,14 @@ async def _index_selected_files(
|
|||
on_heartbeat=on_heartbeat,
|
||||
)
|
||||
|
||||
if batch_indexed > 0 and files_to_download and batch_estimated_pages > 0:
|
||||
pages_to_deduct = max(
|
||||
1, batch_estimated_pages * batch_indexed // len(files_to_download)
|
||||
)
|
||||
await page_limit_service.update_page_usage(
|
||||
user_id, pages_to_deduct, allow_exceed=True
|
||||
)
|
||||
|
||||
return renamed_count + batch_indexed, skipped, errors
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ from app.indexing_pipeline.indexing_pipeline_service import (
|
|||
PlaceholderInfo,
|
||||
)
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.page_limit_service import PageLimitService
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.tasks.connector_indexers.base import (
|
||||
check_document_by_unique_identifier,
|
||||
|
|
@ -327,6 +328,12 @@ async def _process_single_file(
|
|||
return 1, 0, 0
|
||||
return 0, 1, 0
|
||||
|
||||
page_limit_service = PageLimitService(session)
|
||||
estimated_pages = PageLimitService.estimate_pages_from_metadata(
|
||||
file_name, file.get("size")
|
||||
)
|
||||
await page_limit_service.check_page_limit(user_id, estimated_pages)
|
||||
|
||||
markdown, drive_metadata, error = await download_and_extract_content(
|
||||
drive_client, file
|
||||
)
|
||||
|
|
@ -363,6 +370,9 @@ async def _process_single_file(
|
|||
)
|
||||
await pipeline.index(document, connector_doc, user_llm)
|
||||
|
||||
await page_limit_service.update_page_usage(
|
||||
user_id, estimated_pages, allow_exceed=True
|
||||
)
|
||||
logger.info(f"Successfully indexed Google Drive file: {file_name}")
|
||||
return 1, 0, 0
|
||||
|
||||
|
|
@ -466,6 +476,11 @@ async def _index_selected_files(
|
|||
|
||||
Returns (indexed_count, skipped_count, errors).
|
||||
"""
|
||||
page_limit_service = PageLimitService(session)
|
||||
pages_used, pages_limit = await page_limit_service.get_page_usage(user_id)
|
||||
remaining_quota = pages_limit - pages_used
|
||||
batch_estimated_pages = 0
|
||||
|
||||
files_to_download: list[dict] = []
|
||||
errors: list[str] = []
|
||||
renamed_count = 0
|
||||
|
|
@ -486,6 +501,15 @@ async def _index_selected_files(
|
|||
skipped += 1
|
||||
continue
|
||||
|
||||
file_pages = PageLimitService.estimate_pages_from_metadata(
|
||||
file.get("name", ""), file.get("size")
|
||||
)
|
||||
if batch_estimated_pages + file_pages > remaining_quota:
|
||||
display = file_name or file_id
|
||||
errors.append(f"File '{display}': page limit would be exceeded")
|
||||
continue
|
||||
|
||||
batch_estimated_pages += file_pages
|
||||
files_to_download.append(file)
|
||||
|
||||
await _create_drive_placeholders(
|
||||
|
|
@ -507,6 +531,14 @@ async def _index_selected_files(
|
|||
on_heartbeat=on_heartbeat,
|
||||
)
|
||||
|
||||
if batch_indexed > 0 and files_to_download and batch_estimated_pages > 0:
|
||||
pages_to_deduct = max(
|
||||
1, batch_estimated_pages * batch_indexed // len(files_to_download)
|
||||
)
|
||||
await page_limit_service.update_page_usage(
|
||||
user_id, pages_to_deduct, allow_exceed=True
|
||||
)
|
||||
|
||||
return renamed_count + batch_indexed, skipped, errors
|
||||
|
||||
|
||||
|
|
@ -545,6 +577,12 @@ async def _index_full_scan(
|
|||
# ------------------------------------------------------------------
|
||||
# Phase 1 (serial): collect files, run skip checks, track renames
|
||||
# ------------------------------------------------------------------
|
||||
page_limit_service = PageLimitService(session)
|
||||
pages_used, pages_limit = await page_limit_service.get_page_usage(user_id)
|
||||
remaining_quota = pages_limit - pages_used
|
||||
batch_estimated_pages = 0
|
||||
page_limit_reached = False
|
||||
|
||||
renamed_count = 0
|
||||
skipped = 0
|
||||
files_processed = 0
|
||||
|
|
@ -593,6 +631,20 @@ async def _index_full_scan(
|
|||
skipped += 1
|
||||
continue
|
||||
|
||||
file_pages = PageLimitService.estimate_pages_from_metadata(
|
||||
file.get("name", ""), file.get("size")
|
||||
)
|
||||
if batch_estimated_pages + file_pages > remaining_quota:
|
||||
if not page_limit_reached:
|
||||
logger.warning(
|
||||
"Page limit reached during Google Drive full scan, "
|
||||
"skipping remaining files"
|
||||
)
|
||||
page_limit_reached = True
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
batch_estimated_pages += file_pages
|
||||
files_to_download.append(file)
|
||||
|
||||
page_token = next_token
|
||||
|
|
@ -636,6 +688,14 @@ async def _index_full_scan(
|
|||
on_heartbeat=on_heartbeat_callback,
|
||||
)
|
||||
|
||||
if batch_indexed > 0 and files_to_download and batch_estimated_pages > 0:
|
||||
pages_to_deduct = max(
|
||||
1, batch_estimated_pages * batch_indexed // len(files_to_download)
|
||||
)
|
||||
await page_limit_service.update_page_usage(
|
||||
user_id, pages_to_deduct, allow_exceed=True
|
||||
)
|
||||
|
||||
indexed = renamed_count + batch_indexed
|
||||
logger.info(
|
||||
f"Full scan complete: {indexed} indexed, {skipped} skipped, {failed} failed"
|
||||
|
|
@ -686,6 +746,12 @@ async def _index_with_delta_sync(
|
|||
# ------------------------------------------------------------------
|
||||
# Phase 1 (serial): handle removals, collect files for download
|
||||
# ------------------------------------------------------------------
|
||||
page_limit_service = PageLimitService(session)
|
||||
pages_used, pages_limit = await page_limit_service.get_page_usage(user_id)
|
||||
remaining_quota = pages_limit - pages_used
|
||||
batch_estimated_pages = 0
|
||||
page_limit_reached = False
|
||||
|
||||
renamed_count = 0
|
||||
skipped = 0
|
||||
files_to_download: list[dict] = []
|
||||
|
|
@ -715,6 +781,20 @@ async def _index_with_delta_sync(
|
|||
skipped += 1
|
||||
continue
|
||||
|
||||
file_pages = PageLimitService.estimate_pages_from_metadata(
|
||||
file.get("name", ""), file.get("size")
|
||||
)
|
||||
if batch_estimated_pages + file_pages > remaining_quota:
|
||||
if not page_limit_reached:
|
||||
logger.warning(
|
||||
"Page limit reached during Google Drive delta sync, "
|
||||
"skipping remaining files"
|
||||
)
|
||||
page_limit_reached = True
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
batch_estimated_pages += file_pages
|
||||
files_to_download.append(file)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
|
|
@ -742,6 +822,14 @@ async def _index_with_delta_sync(
|
|||
on_heartbeat=on_heartbeat_callback,
|
||||
)
|
||||
|
||||
if batch_indexed > 0 and files_to_download and batch_estimated_pages > 0:
|
||||
pages_to_deduct = max(
|
||||
1, batch_estimated_pages * batch_indexed // len(files_to_download)
|
||||
)
|
||||
await page_limit_service.update_page_usage(
|
||||
user_id, pages_to_deduct, allow_exceed=True
|
||||
)
|
||||
|
||||
indexed = renamed_count + batch_indexed
|
||||
logger.info(
|
||||
f"Delta sync complete: {indexed} indexed, {skipped} skipped, {failed} failed"
|
||||
|
|
|
|||
|
|
@ -205,6 +205,7 @@ def _compute_final_pages(
|
|||
actual = page_limit_service.estimate_pages_from_content_length(content_length)
|
||||
return max(estimated_pages, actual)
|
||||
|
||||
|
||||
DEFAULT_EXCLUDE_PATTERNS = [
|
||||
".git",
|
||||
"node_modules",
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ from app.indexing_pipeline.connector_document import ConnectorDocument
|
|||
from app.indexing_pipeline.document_hashing import compute_identifier_hash
|
||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.page_limit_service import PageLimitService
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.tasks.connector_indexers.base import (
|
||||
check_document_by_unique_identifier,
|
||||
|
|
@ -291,6 +292,11 @@ async def _index_selected_files(
|
|||
on_heartbeat: HeartbeatCallbackType | None = None,
|
||||
) -> tuple[int, int, list[str]]:
|
||||
"""Index user-selected files using the parallel pipeline."""
|
||||
page_limit_service = PageLimitService(session)
|
||||
pages_used, pages_limit = await page_limit_service.get_page_usage(user_id)
|
||||
remaining_quota = pages_limit - pages_used
|
||||
batch_estimated_pages = 0
|
||||
|
||||
files_to_download: list[dict] = []
|
||||
errors: list[str] = []
|
||||
renamed_count = 0
|
||||
|
|
@ -311,6 +317,15 @@ async def _index_selected_files(
|
|||
skipped += 1
|
||||
continue
|
||||
|
||||
file_pages = PageLimitService.estimate_pages_from_metadata(
|
||||
file.get("name", ""), file.get("size")
|
||||
)
|
||||
if batch_estimated_pages + file_pages > remaining_quota:
|
||||
display = file_name or file_id
|
||||
errors.append(f"File '{display}': page limit would be exceeded")
|
||||
continue
|
||||
|
||||
batch_estimated_pages += file_pages
|
||||
files_to_download.append(file)
|
||||
|
||||
batch_indexed, _failed = await _download_and_index(
|
||||
|
|
@ -324,6 +339,14 @@ async def _index_selected_files(
|
|||
on_heartbeat=on_heartbeat,
|
||||
)
|
||||
|
||||
if batch_indexed > 0 and files_to_download and batch_estimated_pages > 0:
|
||||
pages_to_deduct = max(
|
||||
1, batch_estimated_pages * batch_indexed // len(files_to_download)
|
||||
)
|
||||
await page_limit_service.update_page_usage(
|
||||
user_id, pages_to_deduct, allow_exceed=True
|
||||
)
|
||||
|
||||
return renamed_count + batch_indexed, skipped, errors
|
||||
|
||||
|
||||
|
|
@ -358,6 +381,12 @@ async def _index_full_scan(
|
|||
},
|
||||
)
|
||||
|
||||
page_limit_service = PageLimitService(session)
|
||||
pages_used, pages_limit = await page_limit_service.get_page_usage(user_id)
|
||||
remaining_quota = pages_limit - pages_used
|
||||
batch_estimated_pages = 0
|
||||
page_limit_reached = False
|
||||
|
||||
renamed_count = 0
|
||||
skipped = 0
|
||||
files_to_download: list[dict] = []
|
||||
|
|
@ -383,6 +412,21 @@ async def _index_full_scan(
|
|||
else:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
file_pages = PageLimitService.estimate_pages_from_metadata(
|
||||
file.get("name", ""), file.get("size")
|
||||
)
|
||||
if batch_estimated_pages + file_pages > remaining_quota:
|
||||
if not page_limit_reached:
|
||||
logger.warning(
|
||||
"Page limit reached during OneDrive full scan, "
|
||||
"skipping remaining files"
|
||||
)
|
||||
page_limit_reached = True
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
batch_estimated_pages += file_pages
|
||||
files_to_download.append(file)
|
||||
|
||||
batch_indexed, failed = await _download_and_index(
|
||||
|
|
@ -396,6 +440,14 @@ async def _index_full_scan(
|
|||
on_heartbeat=on_heartbeat_callback,
|
||||
)
|
||||
|
||||
if batch_indexed > 0 and files_to_download and batch_estimated_pages > 0:
|
||||
pages_to_deduct = max(
|
||||
1, batch_estimated_pages * batch_indexed // len(files_to_download)
|
||||
)
|
||||
await page_limit_service.update_page_usage(
|
||||
user_id, pages_to_deduct, allow_exceed=True
|
||||
)
|
||||
|
||||
indexed = renamed_count + batch_indexed
|
||||
logger.info(
|
||||
f"Full scan complete: {indexed} indexed, {skipped} skipped, {failed} failed"
|
||||
|
|
@ -441,6 +493,12 @@ async def _index_with_delta_sync(
|
|||
|
||||
logger.info(f"Processing {len(changes)} delta changes")
|
||||
|
||||
page_limit_service = PageLimitService(session)
|
||||
pages_used, pages_limit = await page_limit_service.get_page_usage(user_id)
|
||||
remaining_quota = pages_limit - pages_used
|
||||
batch_estimated_pages = 0
|
||||
page_limit_reached = False
|
||||
|
||||
renamed_count = 0
|
||||
skipped = 0
|
||||
files_to_download: list[dict] = []
|
||||
|
|
@ -471,6 +529,20 @@ async def _index_with_delta_sync(
|
|||
skipped += 1
|
||||
continue
|
||||
|
||||
file_pages = PageLimitService.estimate_pages_from_metadata(
|
||||
change.get("name", ""), change.get("size")
|
||||
)
|
||||
if batch_estimated_pages + file_pages > remaining_quota:
|
||||
if not page_limit_reached:
|
||||
logger.warning(
|
||||
"Page limit reached during OneDrive delta sync, "
|
||||
"skipping remaining files"
|
||||
)
|
||||
page_limit_reached = True
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
batch_estimated_pages += file_pages
|
||||
files_to_download.append(change)
|
||||
|
||||
batch_indexed, failed = await _download_and_index(
|
||||
|
|
@ -484,6 +556,14 @@ async def _index_with_delta_sync(
|
|||
on_heartbeat=on_heartbeat_callback,
|
||||
)
|
||||
|
||||
if batch_indexed > 0 and files_to_download and batch_estimated_pages > 0:
|
||||
pages_to_deduct = max(
|
||||
1, batch_estimated_pages * batch_indexed // len(files_to_download)
|
||||
)
|
||||
await page_limit_service.update_page_usage(
|
||||
user_id, pages_to_deduct, allow_exceed=True
|
||||
)
|
||||
|
||||
indexed = renamed_count + batch_indexed
|
||||
logger.info(
|
||||
f"Delta sync complete: {indexed} indexed, {skipped} skipped, {failed} failed"
|
||||
|
|
|
|||
|
|
@ -11,6 +11,8 @@ import hmac
|
|||
import json
|
||||
import logging
|
||||
import time
|
||||
from random import SystemRandom
|
||||
from string import ascii_letters, digits
|
||||
from uuid import UUID
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
|
|
@ -18,6 +20,25 @@ from fastapi import HTTPException
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_PKCE_CHARS = ascii_letters + digits + "-._~"
|
||||
_PKCE_RNG = SystemRandom()
|
||||
|
||||
|
||||
def generate_code_verifier(length: int = 128) -> str:
|
||||
"""Generate a PKCE code_verifier (RFC 7636, 43-128 unreserved chars)."""
|
||||
return "".join(_PKCE_RNG.choice(_PKCE_CHARS) for _ in range(length))
|
||||
|
||||
|
||||
def generate_pkce_pair(length: int = 128) -> tuple[str, str]:
|
||||
"""Generate a PKCE code_verifier and its S256 code_challenge."""
|
||||
verifier = generate_code_verifier(length)
|
||||
challenge = (
|
||||
base64.urlsafe_b64encode(hashlib.sha256(verifier.encode()).digest())
|
||||
.decode()
|
||||
.rstrip("=")
|
||||
)
|
||||
return verifier, challenge
|
||||
|
||||
|
||||
class OAuthStateManager:
|
||||
"""Manages secure OAuth state parameters with HMAC signatures."""
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
Prerequisites: PostgreSQL + pgvector only.
|
||||
|
||||
External system boundaries are mocked:
|
||||
- ETL parsing — LlamaParse (external API) and Docling (heavy library)
|
||||
- LLM summarization, text embedding, text chunking (external APIs)
|
||||
- Redis heartbeat (external infrastructure)
|
||||
- Task dispatch is swapped via DI (InlineTaskDispatcher)
|
||||
|
|
@ -11,6 +12,7 @@ External system boundaries are mocked:
|
|||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
from collections.abc import AsyncGenerator
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
|
|
@ -298,3 +300,67 @@ def _mock_redis_heartbeat(monkeypatch):
|
|||
"app.tasks.celery_tasks.document_tasks._run_heartbeat_loop",
|
||||
AsyncMock(),
|
||||
)
|
||||
|
||||
|
||||
_MOCK_ETL_MARKDOWN = "# Mocked Document\n\nThis is mocked ETL content."
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_etl_parsing(monkeypatch):
|
||||
"""Mock ETL parsing services — LlamaParse and Docling are external boundaries.
|
||||
|
||||
Preserves the real contract: empty/corrupt files raise an error just like
|
||||
the actual services would, so tests covering failure paths keep working.
|
||||
"""
|
||||
|
||||
def _reject_empty(file_path: str) -> None:
|
||||
if os.path.getsize(file_path) == 0:
|
||||
raise RuntimeError(f"Cannot parse empty file: {file_path}")
|
||||
|
||||
# -- LlamaParse mock (external API) --------------------------------
|
||||
|
||||
class _FakeMarkdownDoc:
|
||||
def __init__(self, text: str):
|
||||
self.text = text
|
||||
|
||||
class _FakeLlamaParseResult:
|
||||
async def aget_markdown_documents(self, *, split_by_page=False):
|
||||
return [_FakeMarkdownDoc(_MOCK_ETL_MARKDOWN)]
|
||||
|
||||
async def _fake_llamacloud_parse(**kwargs):
|
||||
_reject_empty(kwargs["file_path"])
|
||||
return _FakeLlamaParseResult()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.tasks.document_processors.file_processors.parse_with_llamacloud_retry",
|
||||
_fake_llamacloud_parse,
|
||||
)
|
||||
|
||||
# -- Docling mock (heavy library boundary) -------------------------
|
||||
|
||||
async def _fake_docling_parse(file_path: str, filename: str):
|
||||
_reject_empty(file_path)
|
||||
return _MOCK_ETL_MARKDOWN
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.tasks.document_processors.file_processors.parse_with_docling",
|
||||
_fake_docling_parse,
|
||||
)
|
||||
|
||||
class _FakeDoclingResult:
|
||||
class Document:
|
||||
@staticmethod
|
||||
def export_to_markdown():
|
||||
return _MOCK_ETL_MARKDOWN
|
||||
|
||||
document = Document()
|
||||
|
||||
class _FakeDocumentConverter:
|
||||
def convert(self, file_path):
|
||||
_reject_empty(file_path)
|
||||
return _FakeDoclingResult()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"docling.document_converter.DocumentConverter",
|
||||
_FakeDocumentConverter,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1015,7 +1015,7 @@ class TestPageLimits:
|
|||
|
||||
(tmp_path / "note.md").write_text("# Hello World\n\nContent here.")
|
||||
|
||||
count, _skipped, _root_folder_id, err = await index_local_folder(
|
||||
count, _skipped, _root_folder_id, _err = await index_local_folder(
|
||||
session=db_session,
|
||||
search_space_id=db_search_space.id,
|
||||
user_id=str(db_user.id),
|
||||
|
|
|
|||
|
|
@ -248,12 +248,33 @@ def _folder_dict(file_id: str, name: str) -> dict:
|
|||
}
|
||||
|
||||
|
||||
def _make_page_limit_session(pages_used=0, pages_limit=999_999):
|
||||
"""Build a mock DB session that real PageLimitService can operate against."""
|
||||
|
||||
class _FakeUser:
|
||||
def __init__(self, pu, pl):
|
||||
self.pages_used = pu
|
||||
self.pages_limit = pl
|
||||
|
||||
fake_user = _FakeUser(pages_used, pages_limit)
|
||||
session = AsyncMock()
|
||||
|
||||
def _make_result(*_a, **_kw):
|
||||
r = MagicMock()
|
||||
r.first.return_value = (fake_user.pages_used, fake_user.pages_limit)
|
||||
r.unique.return_value.scalar_one_or_none.return_value = fake_user
|
||||
return r
|
||||
|
||||
session.execute = AsyncMock(side_effect=_make_result)
|
||||
return session, fake_user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def full_scan_mocks(mock_drive_client, monkeypatch):
|
||||
"""Wire up all mocks needed to call _index_full_scan in isolation."""
|
||||
import app.tasks.connector_indexers.google_drive_indexer as _mod
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session, _ = _make_page_limit_session()
|
||||
mock_connector = MagicMock()
|
||||
mock_task_logger = MagicMock()
|
||||
mock_task_logger.log_task_progress = AsyncMock()
|
||||
|
|
@ -472,7 +493,7 @@ async def test_delta_sync_removals_serial_rest_parallel(monkeypatch):
|
|||
AsyncMock(return_value=MagicMock()),
|
||||
)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session, _ = _make_page_limit_session()
|
||||
mock_task_logger = MagicMock()
|
||||
mock_task_logger.log_task_progress = AsyncMock()
|
||||
|
||||
|
|
@ -512,7 +533,7 @@ def selected_files_mocks(mock_drive_client, monkeypatch):
|
|||
"""Wire up mocks for _index_selected_files tests."""
|
||||
import app.tasks.connector_indexers.google_drive_indexer as _mod
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session, _ = _make_page_limit_session()
|
||||
|
||||
get_file_results: dict[str, tuple[dict | None, str | None]] = {}
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,680 @@
|
|||
"""Tests for page limit enforcement in connector indexers.
|
||||
|
||||
Covers:
|
||||
A) PageLimitService.estimate_pages_from_metadata — pure function (no mocks)
|
||||
B) Page-limit quota gating in _index_selected_files tested through the
|
||||
real PageLimitService with a mock DB session (system boundary).
|
||||
Google Drive is the primary, with OneDrive/Dropbox smoke tests.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.page_limit_service import PageLimitService
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
_USER_ID = "00000000-0000-0000-0000-000000000001"
|
||||
_CONNECTOR_ID = 42
|
||||
_SEARCH_SPACE_ID = 1
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# A) PageLimitService.estimate_pages_from_metadata — pure function
|
||||
# No mocks: it's a staticmethod with no I/O.
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestEstimatePagesFromMetadata:
|
||||
"""Vertical slices for the page estimation staticmethod."""
|
||||
|
||||
def test_pdf_100kb_returns_1(self):
|
||||
assert PageLimitService.estimate_pages_from_metadata(".pdf", 100 * 1024) == 1
|
||||
|
||||
def test_pdf_500kb_returns_5(self):
|
||||
assert PageLimitService.estimate_pages_from_metadata(".pdf", 500 * 1024) == 5
|
||||
|
||||
def test_pdf_1mb(self):
|
||||
assert PageLimitService.estimate_pages_from_metadata(".pdf", 1024 * 1024) == 10
|
||||
|
||||
def test_docx_50kb_returns_1(self):
|
||||
assert PageLimitService.estimate_pages_from_metadata(".docx", 50 * 1024) == 1
|
||||
|
||||
def test_docx_200kb(self):
|
||||
assert PageLimitService.estimate_pages_from_metadata(".docx", 200 * 1024) == 4
|
||||
|
||||
def test_pptx_uses_200kb_per_page(self):
|
||||
assert PageLimitService.estimate_pages_from_metadata(".pptx", 600 * 1024) == 3
|
||||
|
||||
def test_xlsx_uses_100kb_per_page(self):
|
||||
assert PageLimitService.estimate_pages_from_metadata(".xlsx", 300 * 1024) == 3
|
||||
|
||||
def test_txt_uses_3000_bytes_per_page(self):
|
||||
assert PageLimitService.estimate_pages_from_metadata(".txt", 9000) == 3
|
||||
|
||||
def test_image_always_returns_1(self):
|
||||
for ext in (".jpg", ".png", ".gif", ".webp"):
|
||||
assert PageLimitService.estimate_pages_from_metadata(ext, 5_000_000) == 1
|
||||
|
||||
def test_audio_uses_1mb_per_page(self):
|
||||
assert (
|
||||
PageLimitService.estimate_pages_from_metadata(".mp3", 3 * 1024 * 1024) == 3
|
||||
)
|
||||
|
||||
def test_video_uses_5mb_per_page(self):
|
||||
assert (
|
||||
PageLimitService.estimate_pages_from_metadata(".mp4", 15 * 1024 * 1024) == 3
|
||||
)
|
||||
|
||||
def test_unknown_ext_uses_80kb_per_page(self):
|
||||
assert PageLimitService.estimate_pages_from_metadata(".xyz", 160 * 1024) == 2
|
||||
|
||||
def test_zero_size_returns_1(self):
|
||||
assert PageLimitService.estimate_pages_from_metadata(".pdf", 0) == 1
|
||||
|
||||
def test_negative_size_returns_1(self):
|
||||
assert PageLimitService.estimate_pages_from_metadata(".pdf", -500) == 1
|
||||
|
||||
def test_minimum_is_always_1(self):
|
||||
assert PageLimitService.estimate_pages_from_metadata(".pdf", 50) == 1
|
||||
|
||||
def test_epub_uses_50kb_per_page(self):
|
||||
assert PageLimitService.estimate_pages_from_metadata(".epub", 250 * 1024) == 5
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# B) Page-limit enforcement in connector indexers
|
||||
# System boundary mocked: DB session (for PageLimitService)
|
||||
# System boundary mocked: external API clients, download/ETL
|
||||
# NOT mocked: PageLimitService itself (our own code)
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class _FakeUser:
|
||||
"""Stands in for the User ORM model at the DB boundary."""
|
||||
|
||||
def __init__(self, pages_used: int = 0, pages_limit: int = 100):
|
||||
self.pages_used = pages_used
|
||||
self.pages_limit = pages_limit
|
||||
|
||||
|
||||
def _make_page_limit_session(pages_used: int = 0, pages_limit: int = 100):
|
||||
"""Build a mock DB session that real PageLimitService can operate against.
|
||||
|
||||
Every ``session.execute()`` returns a result compatible with both
|
||||
``get_page_usage`` (.first() → tuple) and ``update_page_usage``
|
||||
(.unique().scalar_one_or_none() → User-like).
|
||||
"""
|
||||
fake_user = _FakeUser(pages_used, pages_limit)
|
||||
session = AsyncMock()
|
||||
|
||||
def _make_result(*_args, **_kwargs):
|
||||
result = MagicMock()
|
||||
result.first.return_value = (fake_user.pages_used, fake_user.pages_limit)
|
||||
result.unique.return_value.scalar_one_or_none.return_value = fake_user
|
||||
return result
|
||||
|
||||
session.execute = AsyncMock(side_effect=_make_result)
|
||||
return session, fake_user
|
||||
|
||||
|
||||
def _make_gdrive_file(file_id: str, name: str, size: int = 80 * 1024) -> dict:
|
||||
return {
|
||||
"id": file_id,
|
||||
"name": name,
|
||||
"mimeType": "application/octet-stream",
|
||||
"size": str(size),
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Google Drive: _index_selected_files
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gdrive_selected_mocks(monkeypatch):
|
||||
"""Mocks for Google Drive _index_selected_files — only system boundaries."""
|
||||
import app.tasks.connector_indexers.google_drive_indexer as _mod
|
||||
|
||||
session, fake_user = _make_page_limit_session(0, 100)
|
||||
|
||||
get_file_results: dict[str, tuple[dict | None, str | None]] = {}
|
||||
|
||||
async def _fake_get_file(client, file_id):
|
||||
return get_file_results.get(file_id, (None, f"Not configured: {file_id}"))
|
||||
|
||||
monkeypatch.setattr(_mod, "get_file_by_id", _fake_get_file)
|
||||
monkeypatch.setattr(
|
||||
_mod, "_should_skip_file", AsyncMock(return_value=(False, None))
|
||||
)
|
||||
|
||||
download_and_index_mock = AsyncMock(return_value=(0, 0))
|
||||
monkeypatch.setattr(_mod, "_download_and_index", download_and_index_mock)
|
||||
|
||||
pipeline_mock = MagicMock()
|
||||
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
|
||||
monkeypatch.setattr(
|
||||
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock)
|
||||
)
|
||||
|
||||
return {
|
||||
"mod": _mod,
|
||||
"session": session,
|
||||
"fake_user": fake_user,
|
||||
"get_file_results": get_file_results,
|
||||
"download_and_index_mock": download_and_index_mock,
|
||||
}
|
||||
|
||||
|
||||
async def _run_gdrive_selected(mocks, file_ids):
|
||||
from app.tasks.connector_indexers.google_drive_indexer import (
|
||||
_index_selected_files,
|
||||
)
|
||||
|
||||
return await _index_selected_files(
|
||||
MagicMock(),
|
||||
mocks["session"],
|
||||
file_ids,
|
||||
connector_id=_CONNECTOR_ID,
|
||||
search_space_id=_SEARCH_SPACE_ID,
|
||||
user_id=_USER_ID,
|
||||
enable_summary=True,
|
||||
)
|
||||
|
||||
|
||||
async def test_gdrive_files_within_quota_are_downloaded(gdrive_selected_mocks):
|
||||
"""Files whose cumulative estimated pages fit within remaining quota
|
||||
are sent to _download_and_index."""
|
||||
m = gdrive_selected_mocks
|
||||
m["fake_user"].pages_used = 0
|
||||
m["fake_user"].pages_limit = 100
|
||||
|
||||
for fid in ("f1", "f2", "f3"):
|
||||
m["get_file_results"][fid] = (
|
||||
_make_gdrive_file(fid, f"{fid}.xyz", size=80 * 1024),
|
||||
None,
|
||||
)
|
||||
m["download_and_index_mock"].return_value = (3, 0)
|
||||
|
||||
indexed, _skipped, errors = await _run_gdrive_selected(
|
||||
m, [("f1", "f1.xyz"), ("f2", "f2.xyz"), ("f3", "f3.xyz")]
|
||||
)
|
||||
|
||||
assert indexed == 3
|
||||
assert errors == []
|
||||
call_files = m["download_and_index_mock"].call_args[0][2]
|
||||
assert len(call_files) == 3
|
||||
|
||||
|
||||
async def test_gdrive_files_exceeding_quota_rejected(gdrive_selected_mocks):
|
||||
"""Files whose pages would exceed remaining quota are rejected."""
|
||||
m = gdrive_selected_mocks
|
||||
m["fake_user"].pages_used = 98
|
||||
m["fake_user"].pages_limit = 100
|
||||
|
||||
m["get_file_results"]["big"] = (
|
||||
_make_gdrive_file("big", "huge.pdf", size=500 * 1024),
|
||||
None,
|
||||
)
|
||||
|
||||
indexed, _skipped, errors = await _run_gdrive_selected(m, [("big", "huge.pdf")])
|
||||
|
||||
assert indexed == 0
|
||||
assert len(errors) == 1
|
||||
assert "page limit" in errors[0].lower()
|
||||
|
||||
|
||||
async def test_gdrive_quota_mix_partial_indexing(gdrive_selected_mocks):
|
||||
"""3rd file pushes over quota → only first two indexed."""
|
||||
m = gdrive_selected_mocks
|
||||
m["fake_user"].pages_used = 0
|
||||
m["fake_user"].pages_limit = 2
|
||||
|
||||
for fid in ("f1", "f2", "f3"):
|
||||
m["get_file_results"][fid] = (
|
||||
_make_gdrive_file(fid, f"{fid}.xyz", size=80 * 1024),
|
||||
None,
|
||||
)
|
||||
m["download_and_index_mock"].return_value = (2, 0)
|
||||
|
||||
indexed, _skipped, errors = await _run_gdrive_selected(
|
||||
m, [("f1", "f1.xyz"), ("f2", "f2.xyz"), ("f3", "f3.xyz")]
|
||||
)
|
||||
|
||||
assert indexed == 2
|
||||
assert len(errors) == 1
|
||||
call_files = m["download_and_index_mock"].call_args[0][2]
|
||||
assert {f["id"] for f in call_files} == {"f1", "f2"}
|
||||
|
||||
|
||||
async def test_gdrive_proportional_page_deduction(gdrive_selected_mocks):
|
||||
"""Pages deducted are proportional to successfully indexed files."""
|
||||
m = gdrive_selected_mocks
|
||||
m["fake_user"].pages_used = 0
|
||||
m["fake_user"].pages_limit = 100
|
||||
|
||||
for fid in ("f1", "f2", "f3", "f4"):
|
||||
m["get_file_results"][fid] = (
|
||||
_make_gdrive_file(fid, f"{fid}.xyz", size=80 * 1024),
|
||||
None,
|
||||
)
|
||||
m["download_and_index_mock"].return_value = (2, 2)
|
||||
|
||||
await _run_gdrive_selected(
|
||||
m,
|
||||
[("f1", "f1.xyz"), ("f2", "f2.xyz"), ("f3", "f3.xyz"), ("f4", "f4.xyz")],
|
||||
)
|
||||
|
||||
assert m["fake_user"].pages_used == 2
|
||||
|
||||
|
||||
async def test_gdrive_no_deduction_when_nothing_indexed(gdrive_selected_mocks):
|
||||
"""If batch_indexed == 0, user's pages_used stays unchanged."""
|
||||
m = gdrive_selected_mocks
|
||||
m["fake_user"].pages_used = 5
|
||||
m["fake_user"].pages_limit = 100
|
||||
|
||||
m["get_file_results"]["f1"] = (
|
||||
_make_gdrive_file("f1", "f1.xyz", size=80 * 1024),
|
||||
None,
|
||||
)
|
||||
m["download_and_index_mock"].return_value = (0, 1)
|
||||
|
||||
await _run_gdrive_selected(m, [("f1", "f1.xyz")])
|
||||
|
||||
assert m["fake_user"].pages_used == 5
|
||||
|
||||
|
||||
async def test_gdrive_zero_quota_rejects_all(gdrive_selected_mocks):
|
||||
"""When pages_used == pages_limit, every file is rejected."""
|
||||
m = gdrive_selected_mocks
|
||||
m["fake_user"].pages_used = 100
|
||||
m["fake_user"].pages_limit = 100
|
||||
|
||||
for fid in ("f1", "f2"):
|
||||
m["get_file_results"][fid] = (
|
||||
_make_gdrive_file(fid, f"{fid}.xyz", size=80 * 1024),
|
||||
None,
|
||||
)
|
||||
|
||||
indexed, _skipped, errors = await _run_gdrive_selected(
|
||||
m, [("f1", "f1.xyz"), ("f2", "f2.xyz")]
|
||||
)
|
||||
|
||||
assert indexed == 0
|
||||
assert len(errors) == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Google Drive: _index_full_scan
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gdrive_full_scan_mocks(monkeypatch):
|
||||
import app.tasks.connector_indexers.google_drive_indexer as _mod
|
||||
|
||||
session, fake_user = _make_page_limit_session(0, 100)
|
||||
mock_task_logger = MagicMock()
|
||||
mock_task_logger.log_task_progress = AsyncMock()
|
||||
|
||||
monkeypatch.setattr(
|
||||
_mod, "_should_skip_file", AsyncMock(return_value=(False, None))
|
||||
)
|
||||
|
||||
download_mock = AsyncMock(return_value=([], 0))
|
||||
monkeypatch.setattr(_mod, "_download_files_parallel", download_mock)
|
||||
|
||||
batch_mock = AsyncMock(return_value=([], 0, 0))
|
||||
pipeline_mock = MagicMock()
|
||||
pipeline_mock.index_batch_parallel = batch_mock
|
||||
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
|
||||
monkeypatch.setattr(
|
||||
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock)
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
_mod, "get_user_long_context_llm", AsyncMock(return_value=MagicMock())
|
||||
)
|
||||
|
||||
return {
|
||||
"mod": _mod,
|
||||
"session": session,
|
||||
"fake_user": fake_user,
|
||||
"task_logger": mock_task_logger,
|
||||
"download_mock": download_mock,
|
||||
"batch_mock": batch_mock,
|
||||
}
|
||||
|
||||
|
||||
async def _run_gdrive_full_scan(mocks, max_files=500):
|
||||
from app.tasks.connector_indexers.google_drive_indexer import _index_full_scan
|
||||
|
||||
return await _index_full_scan(
|
||||
MagicMock(),
|
||||
mocks["session"],
|
||||
MagicMock(),
|
||||
_CONNECTOR_ID,
|
||||
_SEARCH_SPACE_ID,
|
||||
_USER_ID,
|
||||
"folder-root",
|
||||
"My Folder",
|
||||
mocks["task_logger"],
|
||||
MagicMock(),
|
||||
max_files,
|
||||
include_subfolders=False,
|
||||
enable_summary=True,
|
||||
)
|
||||
|
||||
|
||||
async def test_gdrive_full_scan_skips_over_quota(gdrive_full_scan_mocks, monkeypatch):
|
||||
m = gdrive_full_scan_mocks
|
||||
m["fake_user"].pages_used = 0
|
||||
m["fake_user"].pages_limit = 2
|
||||
|
||||
page_files = [
|
||||
_make_gdrive_file(f"f{i}", f"file{i}.xyz", size=80 * 1024) for i in range(5)
|
||||
]
|
||||
monkeypatch.setattr(
|
||||
m["mod"],
|
||||
"get_files_in_folder",
|
||||
AsyncMock(return_value=(page_files, None, None)),
|
||||
)
|
||||
m["download_mock"].return_value = ([], 0)
|
||||
m["batch_mock"].return_value = ([], 2, 0)
|
||||
|
||||
_indexed, skipped = await _run_gdrive_full_scan(m)
|
||||
|
||||
call_files = m["download_mock"].call_args[0][1]
|
||||
assert len(call_files) == 2
|
||||
assert skipped == 3
|
||||
|
||||
|
||||
async def test_gdrive_full_scan_deducts_after_indexing(
|
||||
gdrive_full_scan_mocks, monkeypatch
|
||||
):
|
||||
m = gdrive_full_scan_mocks
|
||||
m["fake_user"].pages_used = 0
|
||||
m["fake_user"].pages_limit = 100
|
||||
|
||||
page_files = [
|
||||
_make_gdrive_file(f"f{i}", f"file{i}.xyz", size=80 * 1024) for i in range(3)
|
||||
]
|
||||
monkeypatch.setattr(
|
||||
m["mod"],
|
||||
"get_files_in_folder",
|
||||
AsyncMock(return_value=(page_files, None, None)),
|
||||
)
|
||||
mock_docs = [MagicMock() for _ in range(3)]
|
||||
m["download_mock"].return_value = (mock_docs, 0)
|
||||
m["batch_mock"].return_value = ([], 3, 0)
|
||||
|
||||
await _run_gdrive_full_scan(m)
|
||||
|
||||
assert m["fake_user"].pages_used == 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Google Drive: _index_with_delta_sync
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_gdrive_delta_sync_skips_over_quota(monkeypatch):
|
||||
import app.tasks.connector_indexers.google_drive_indexer as _mod
|
||||
|
||||
session, _ = _make_page_limit_session(0, 2)
|
||||
|
||||
changes = [
|
||||
{
|
||||
"fileId": f"mod{i}",
|
||||
"file": _make_gdrive_file(f"mod{i}", f"mod{i}.xyz", size=80 * 1024),
|
||||
}
|
||||
for i in range(5)
|
||||
]
|
||||
monkeypatch.setattr(
|
||||
_mod,
|
||||
"fetch_all_changes",
|
||||
AsyncMock(return_value=(changes, "new-token", None)),
|
||||
)
|
||||
monkeypatch.setattr(_mod, "categorize_change", lambda change: "modified")
|
||||
monkeypatch.setattr(
|
||||
_mod, "_should_skip_file", AsyncMock(return_value=(False, None))
|
||||
)
|
||||
|
||||
download_mock = AsyncMock(return_value=([], 0))
|
||||
monkeypatch.setattr(_mod, "_download_files_parallel", download_mock)
|
||||
|
||||
batch_mock = AsyncMock(return_value=([], 2, 0))
|
||||
pipeline_mock = MagicMock()
|
||||
pipeline_mock.index_batch_parallel = batch_mock
|
||||
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
|
||||
monkeypatch.setattr(
|
||||
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock)
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
_mod, "get_user_long_context_llm", AsyncMock(return_value=MagicMock())
|
||||
)
|
||||
|
||||
mock_task_logger = MagicMock()
|
||||
mock_task_logger.log_task_progress = AsyncMock()
|
||||
|
||||
_indexed, skipped = await _mod._index_with_delta_sync(
|
||||
MagicMock(),
|
||||
session,
|
||||
MagicMock(),
|
||||
_CONNECTOR_ID,
|
||||
_SEARCH_SPACE_ID,
|
||||
_USER_ID,
|
||||
"folder-root",
|
||||
"start-token",
|
||||
mock_task_logger,
|
||||
MagicMock(),
|
||||
max_files=500,
|
||||
enable_summary=True,
|
||||
)
|
||||
|
||||
call_files = download_mock.call_args[0][1]
|
||||
assert len(call_files) == 2
|
||||
assert skipped == 3
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# C) OneDrive smoke tests — verify page limit wiring
|
||||
# ===================================================================
|
||||
|
||||
|
||||
def _make_onedrive_file(file_id: str, name: str, size: int = 80 * 1024) -> dict:
|
||||
return {
|
||||
"id": file_id,
|
||||
"name": name,
|
||||
"file": {"mimeType": "application/octet-stream"},
|
||||
"size": str(size),
|
||||
"lastModifiedDateTime": "2026-01-01T00:00:00Z",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def onedrive_selected_mocks(monkeypatch):
|
||||
import app.tasks.connector_indexers.onedrive_indexer as _mod
|
||||
|
||||
session, fake_user = _make_page_limit_session(0, 100)
|
||||
|
||||
get_file_results: dict[str, tuple[dict | None, str | None]] = {}
|
||||
|
||||
async def _fake_get_file(client, file_id):
|
||||
return get_file_results.get(file_id, (None, f"Not found: {file_id}"))
|
||||
|
||||
monkeypatch.setattr(_mod, "get_file_by_id", _fake_get_file)
|
||||
monkeypatch.setattr(
|
||||
_mod, "_should_skip_file", AsyncMock(return_value=(False, None))
|
||||
)
|
||||
|
||||
download_and_index_mock = AsyncMock(return_value=(0, 0))
|
||||
monkeypatch.setattr(_mod, "_download_and_index", download_and_index_mock)
|
||||
|
||||
pipeline_mock = MagicMock()
|
||||
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
|
||||
monkeypatch.setattr(
|
||||
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock)
|
||||
)
|
||||
|
||||
return {
|
||||
"session": session,
|
||||
"fake_user": fake_user,
|
||||
"get_file_results": get_file_results,
|
||||
"download_and_index_mock": download_and_index_mock,
|
||||
}
|
||||
|
||||
|
||||
async def _run_onedrive_selected(mocks, file_ids):
|
||||
from app.tasks.connector_indexers.onedrive_indexer import _index_selected_files
|
||||
|
||||
return await _index_selected_files(
|
||||
MagicMock(),
|
||||
mocks["session"],
|
||||
file_ids,
|
||||
connector_id=_CONNECTOR_ID,
|
||||
search_space_id=_SEARCH_SPACE_ID,
|
||||
user_id=_USER_ID,
|
||||
enable_summary=True,
|
||||
)
|
||||
|
||||
|
||||
async def test_onedrive_over_quota_rejected(onedrive_selected_mocks):
|
||||
"""OneDrive: files exceeding quota produce errors, not downloads."""
|
||||
m = onedrive_selected_mocks
|
||||
m["fake_user"].pages_used = 99
|
||||
m["fake_user"].pages_limit = 100
|
||||
|
||||
m["get_file_results"]["big"] = (
|
||||
_make_onedrive_file("big", "huge.pdf", size=500 * 1024),
|
||||
None,
|
||||
)
|
||||
|
||||
indexed, _skipped, errors = await _run_onedrive_selected(m, [("big", "huge.pdf")])
|
||||
|
||||
assert indexed == 0
|
||||
assert len(errors) == 1
|
||||
assert "page limit" in errors[0].lower()
|
||||
|
||||
|
||||
async def test_onedrive_deducts_after_success(onedrive_selected_mocks):
|
||||
"""OneDrive: pages_used increases after successful indexing."""
|
||||
m = onedrive_selected_mocks
|
||||
m["fake_user"].pages_used = 0
|
||||
m["fake_user"].pages_limit = 100
|
||||
|
||||
for fid in ("f1", "f2"):
|
||||
m["get_file_results"][fid] = (
|
||||
_make_onedrive_file(fid, f"{fid}.xyz", size=80 * 1024),
|
||||
None,
|
||||
)
|
||||
m["download_and_index_mock"].return_value = (2, 0)
|
||||
|
||||
await _run_onedrive_selected(m, [("f1", "f1.xyz"), ("f2", "f2.xyz")])
|
||||
|
||||
assert m["fake_user"].pages_used == 2
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# D) Dropbox smoke tests — verify page limit wiring
|
||||
# ===================================================================
|
||||
|
||||
|
||||
def _make_dropbox_file(file_path: str, name: str, size: int = 80 * 1024) -> dict:
|
||||
return {
|
||||
"id": f"id:{file_path}",
|
||||
"name": name,
|
||||
".tag": "file",
|
||||
"path_lower": file_path,
|
||||
"size": str(size),
|
||||
"server_modified": "2026-01-01T00:00:00Z",
|
||||
"content_hash": f"hash_{name}",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dropbox_selected_mocks(monkeypatch):
|
||||
import app.tasks.connector_indexers.dropbox_indexer as _mod
|
||||
|
||||
session, fake_user = _make_page_limit_session(0, 100)
|
||||
|
||||
get_file_results: dict[str, tuple[dict | None, str | None]] = {}
|
||||
|
||||
async def _fake_get_file(client, file_path):
|
||||
return get_file_results.get(file_path, (None, f"Not found: {file_path}"))
|
||||
|
||||
monkeypatch.setattr(_mod, "get_file_by_path", _fake_get_file)
|
||||
monkeypatch.setattr(
|
||||
_mod, "_should_skip_file", AsyncMock(return_value=(False, None))
|
||||
)
|
||||
|
||||
download_and_index_mock = AsyncMock(return_value=(0, 0))
|
||||
monkeypatch.setattr(_mod, "_download_and_index", download_and_index_mock)
|
||||
|
||||
pipeline_mock = MagicMock()
|
||||
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
|
||||
monkeypatch.setattr(
|
||||
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock)
|
||||
)
|
||||
|
||||
return {
|
||||
"session": session,
|
||||
"fake_user": fake_user,
|
||||
"get_file_results": get_file_results,
|
||||
"download_and_index_mock": download_and_index_mock,
|
||||
}
|
||||
|
||||
|
||||
async def _run_dropbox_selected(mocks, file_paths):
|
||||
from app.tasks.connector_indexers.dropbox_indexer import _index_selected_files
|
||||
|
||||
return await _index_selected_files(
|
||||
MagicMock(),
|
||||
mocks["session"],
|
||||
file_paths,
|
||||
connector_id=_CONNECTOR_ID,
|
||||
search_space_id=_SEARCH_SPACE_ID,
|
||||
user_id=_USER_ID,
|
||||
enable_summary=True,
|
||||
)
|
||||
|
||||
|
||||
async def test_dropbox_over_quota_rejected(dropbox_selected_mocks):
|
||||
"""Dropbox: files exceeding quota produce errors, not downloads."""
|
||||
m = dropbox_selected_mocks
|
||||
m["fake_user"].pages_used = 99
|
||||
m["fake_user"].pages_limit = 100
|
||||
|
||||
m["get_file_results"]["/huge.pdf"] = (
|
||||
_make_dropbox_file("/huge.pdf", "huge.pdf", size=500 * 1024),
|
||||
None,
|
||||
)
|
||||
|
||||
indexed, _skipped, errors = await _run_dropbox_selected(
|
||||
m, [("/huge.pdf", "huge.pdf")]
|
||||
)
|
||||
|
||||
assert indexed == 0
|
||||
assert len(errors) == 1
|
||||
assert "page limit" in errors[0].lower()
|
||||
|
||||
|
||||
async def test_dropbox_deducts_after_success(dropbox_selected_mocks):
|
||||
"""Dropbox: pages_used increases after successful indexing."""
|
||||
m = dropbox_selected_mocks
|
||||
m["fake_user"].pages_used = 0
|
||||
m["fake_user"].pages_limit = 100
|
||||
|
||||
for name in ("f1.xyz", "f2.xyz"):
|
||||
path = f"/{name}"
|
||||
m["get_file_results"][path] = (
|
||||
_make_dropbox_file(path, name, size=80 * 1024),
|
||||
None,
|
||||
)
|
||||
m["download_and_index_mock"].return_value = (2, 0)
|
||||
|
||||
await _run_dropbox_selected(m, [("/f1.xyz", "f1.xyz"), ("/f2.xyz", "f2.xyz")])
|
||||
|
||||
assert m["fake_user"].pages_used == 2
|
||||
1
surfsense_desktop/.npmrc
Normal file
1
surfsense_desktop/.npmrc
Normal file
|
|
@ -0,0 +1 @@
|
|||
node-linker=hoisted
|
||||
|
|
@ -9,6 +9,12 @@ directories:
|
|||
files:
|
||||
- dist/**/*
|
||||
- "!node_modules"
|
||||
- node_modules/node-gyp-build/**/*
|
||||
- node_modules/bindings/**/*
|
||||
- node_modules/file-uri-to-path/**/*
|
||||
- node_modules/node-mac-permissions/**/*
|
||||
- "!node_modules/node-mac-permissions/src"
|
||||
- "!node_modules/node-mac-permissions/binding.gyp"
|
||||
- "!src"
|
||||
- "!scripts"
|
||||
- "!release"
|
||||
|
|
@ -29,12 +35,20 @@ extraResources:
|
|||
filter: ["**/*"]
|
||||
asarUnpack:
|
||||
- "**/*.node"
|
||||
- "node_modules/node-gyp-build/**/*"
|
||||
- "node_modules/bindings/**/*"
|
||||
- "node_modules/file-uri-to-path/**/*"
|
||||
- "node_modules/node-mac-permissions/**/*"
|
||||
mac:
|
||||
icon: assets/icon.icns
|
||||
category: public.app-category.productivity
|
||||
artifactName: "${productName}-${version}-${arch}.${ext}"
|
||||
hardenedRuntime: true
|
||||
hardenedRuntime: false
|
||||
gatekeeperAssess: false
|
||||
extendInfo:
|
||||
NSAccessibilityUsageDescription: "SurfSense uses accessibility features to insert suggestions into the active application."
|
||||
NSScreenCaptureUsageDescription: "SurfSense uses screen capture to analyze your screen and provide context-aware writing suggestions."
|
||||
NSAppleEventsUsageDescription: "SurfSense uses Apple Events to interact with the active application."
|
||||
target:
|
||||
- target: dmg
|
||||
arch: [x64, arm64]
|
||||
|
|
|
|||
|
|
@ -11,12 +11,14 @@
|
|||
"dist:mac": "pnpm build && electron-builder --mac --config electron-builder.yml",
|
||||
"dist:win": "pnpm build && electron-builder --win --config electron-builder.yml",
|
||||
"dist:linux": "pnpm build && electron-builder --linux --config electron-builder.yml",
|
||||
"typecheck": "tsc --noEmit"
|
||||
"typecheck": "tsc --noEmit",
|
||||
"postinstall": "electron-rebuild"
|
||||
},
|
||||
"author": "MODSetter",
|
||||
"license": "MIT",
|
||||
"packageManager": "pnpm@10.24.0",
|
||||
"devDependencies": {
|
||||
"@electron/rebuild": "^4.0.3",
|
||||
"@types/node": "^25.5.0",
|
||||
"concurrently": "^9.2.1",
|
||||
"dotenv": "^17.3.1",
|
||||
|
|
@ -27,9 +29,11 @@
|
|||
"wait-on": "^9.0.4"
|
||||
},
|
||||
"dependencies": {
|
||||
"bindings": "^1.5.0",
|
||||
"chokidar": "^5.0.0",
|
||||
"electron-store": "^11.0.2",
|
||||
"electron-updater": "^6.8.3",
|
||||
"get-port-please": "^3.2.0"
|
||||
"get-port-please": "^3.2.0",
|
||||
"node-mac-permissions": "^2.5.0"
|
||||
}
|
||||
}
|
||||
|
|
|
|||
36
surfsense_desktop/pnpm-lock.yaml
generated
36
surfsense_desktop/pnpm-lock.yaml
generated
|
|
@ -8,6 +8,9 @@ importers:
|
|||
|
||||
.:
|
||||
dependencies:
|
||||
bindings:
|
||||
specifier: ^1.5.0
|
||||
version: 1.5.0
|
||||
chokidar:
|
||||
specifier: ^5.0.0
|
||||
version: 5.0.0
|
||||
|
|
@ -20,7 +23,13 @@ importers:
|
|||
get-port-please:
|
||||
specifier: ^3.2.0
|
||||
version: 3.2.0
|
||||
node-mac-permissions:
|
||||
specifier: ^2.5.0
|
||||
version: 2.5.0
|
||||
devDependencies:
|
||||
'@electron/rebuild':
|
||||
specifier: ^4.0.3
|
||||
version: 4.0.3
|
||||
'@types/node':
|
||||
specifier: ^25.5.0
|
||||
version: 25.5.0
|
||||
|
|
@ -349,6 +358,7 @@ packages:
|
|||
'@xmldom/xmldom@0.8.11':
|
||||
resolution: {integrity: sha512-cQzWCtO6C8TQiYl1ruKNn2U6Ao4o4WBBcbL61yJl84x+j5sOWWFU9X7DpND8XZG3daDppSsigMdfAIl2upQBRw==}
|
||||
engines: {node: '>=10.0.0'}
|
||||
deprecated: this version has critical issues, please update to the latest version
|
||||
|
||||
abbrev@3.0.1:
|
||||
resolution: {integrity: sha512-AO2ac6pjRB3SJmGJo+v5/aK6Omggp6fsLrs6wN9bd35ulu4cCwaAU9+7ZhXjeqHVkaHThLuzH0nZr0YpCDhygg==}
|
||||
|
|
@ -444,6 +454,9 @@ packages:
|
|||
base64-js@1.5.1:
|
||||
resolution: {integrity: sha512-AKpaYlHn8t4SVbOHCy+b5+KKgvR4vrsD8vbvrbiQJps7fKDTkjkDry6ji0rUJjC0kzbNePLwzxq8iypo41qeWA==}
|
||||
|
||||
bindings@1.5.0:
|
||||
resolution: {integrity: sha512-p2q/t/mhvuOj/UeLlV6566GD/guowlr0hHxClI0W9m7MWYkL1F0hLo+0Aexs9HSPCtR1SXQ0TD3MMKrXZajbiQ==}
|
||||
|
||||
bl@4.1.0:
|
||||
resolution: {integrity: sha512-1W07cM9gS6DcLperZfFSj+bWLtaPGSOHWhPiGzXmvVJbRLdG82sH/Kn8EtW1VqWVA54AKf2h5k5BbnIbwF3h6w==}
|
||||
|
||||
|
|
@ -785,6 +798,9 @@ packages:
|
|||
picomatch:
|
||||
optional: true
|
||||
|
||||
file-uri-to-path@1.0.0:
|
||||
resolution: {integrity: sha512-0Zt+s3L7Vf1biwWZ29aARiVYLx7iMGnEUl9x33fbB/j3jR81u/O2LbqK+Bm1CDSNDKVtJ/YjwY7TUd5SkeLQLw==}
|
||||
|
||||
filelist@1.0.6:
|
||||
resolution: {integrity: sha512-5giy2PkLYY1cP39p17Ech+2xlpTRL9HLspOfEgm0L6CwBXBTgsK5ou0JtzYuepxkaQ/tvhCFIJ5uXo0OrM2DxA==}
|
||||
|
||||
|
|
@ -1163,6 +1179,9 @@ packages:
|
|||
node-addon-api@1.7.2:
|
||||
resolution: {integrity: sha512-ibPK3iA+vaY1eEjESkQkM0BbCqFOaZMiXRTtdB0u7b4djtY6JnsjvPdUHVMg6xQt3B8fpTTWHI9A+ADjM9frzg==}
|
||||
|
||||
node-addon-api@7.1.1:
|
||||
resolution: {integrity: sha512-5m3bsyrjFWE1xf7nz7YXdN4udnVtXK6/Yfgn5qnahL6bCkf2yKt4k3nuTKAtT4r3IG8JNR2ncsIMdZuAzJjHQQ==}
|
||||
|
||||
node-api-version@0.2.1:
|
||||
resolution: {integrity: sha512-2xP/IGGMmmSQpI1+O/k72jF/ykvZ89JeuKX3TLJAYPDVLUalrshrLHkeVcCCZqG/eEa635cr8IBYzgnDvM2O8Q==}
|
||||
|
||||
|
|
@ -1171,6 +1190,10 @@ packages:
|
|||
engines: {node: ^18.17.0 || >=20.5.0}
|
||||
hasBin: true
|
||||
|
||||
node-mac-permissions@2.5.0:
|
||||
resolution: {integrity: sha512-zR8SVCaN3WqV1xwWd04XVAdzm3UTdjbxciLrZtB0Cc7F2Kd34AJfhPD4hm1HU0YH3oGUZO4X9OBLY5ijSTHsGw==}
|
||||
os: [darwin]
|
||||
|
||||
nopt@8.1.0:
|
||||
resolution: {integrity: sha512-ieGu42u/Qsa4TFktmaKEwM6MQH0pOWnaB3htzh0JRtx84+Mebc0cbZYN5bC+6WTZ4+77xrL9Pn5m7CV6VIkV7A==}
|
||||
engines: {node: ^18.17.0 || >=20.5.0}
|
||||
|
|
@ -2028,6 +2051,10 @@ snapshots:
|
|||
|
||||
base64-js@1.5.1: {}
|
||||
|
||||
bindings@1.5.0:
|
||||
dependencies:
|
||||
file-uri-to-path: 1.0.0
|
||||
|
||||
bl@4.1.0:
|
||||
dependencies:
|
||||
buffer: 5.7.1
|
||||
|
|
@ -2486,6 +2513,8 @@ snapshots:
|
|||
optionalDependencies:
|
||||
picomatch: 4.0.3
|
||||
|
||||
file-uri-to-path@1.0.0: {}
|
||||
|
||||
filelist@1.0.6:
|
||||
dependencies:
|
||||
minimatch: 5.1.9
|
||||
|
|
@ -2885,6 +2914,8 @@ snapshots:
|
|||
node-addon-api@1.7.2:
|
||||
optional: true
|
||||
|
||||
node-addon-api@7.1.1: {}
|
||||
|
||||
node-api-version@0.2.1:
|
||||
dependencies:
|
||||
semver: 7.7.4
|
||||
|
|
@ -2904,6 +2935,11 @@ snapshots:
|
|||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
|
||||
node-mac-permissions@2.5.0:
|
||||
dependencies:
|
||||
bindings: 1.5.0
|
||||
node-addon-api: 7.1.1
|
||||
|
||||
nopt@8.1.0:
|
||||
dependencies:
|
||||
abbrev: 3.0.1
|
||||
|
|
|
|||
|
|
@ -104,7 +104,7 @@ async function buildElectron() {
|
|||
bundle: true,
|
||||
platform: 'node',
|
||||
target: 'node18',
|
||||
external: ['electron'],
|
||||
external: ['electron', 'node-mac-permissions', 'bindings', 'file-uri-to-path'],
|
||||
sourcemap: true,
|
||||
minify: false,
|
||||
define: {
|
||||
|
|
|
|||
|
|
@ -6,6 +6,17 @@ export const IPC_CHANNELS = {
|
|||
SET_QUICK_ASK_MODE: 'set-quick-ask-mode',
|
||||
GET_QUICK_ASK_MODE: 'get-quick-ask-mode',
|
||||
REPLACE_TEXT: 'replace-text',
|
||||
// Permissions
|
||||
GET_PERMISSIONS_STATUS: 'get-permissions-status',
|
||||
REQUEST_ACCESSIBILITY: 'request-accessibility',
|
||||
REQUEST_SCREEN_RECORDING: 'request-screen-recording',
|
||||
RESTART_APP: 'restart-app',
|
||||
// Autocomplete
|
||||
AUTOCOMPLETE_CONTEXT: 'autocomplete-context',
|
||||
ACCEPT_SUGGESTION: 'accept-suggestion',
|
||||
DISMISS_SUGGESTION: 'dismiss-suggestion',
|
||||
SET_AUTOCOMPLETE_ENABLED: 'set-autocomplete-enabled',
|
||||
GET_AUTOCOMPLETE_ENABLED: 'get-autocomplete-enabled',
|
||||
// Folder sync channels
|
||||
FOLDER_SYNC_SELECT_FOLDER: 'folder-sync:select-folder',
|
||||
FOLDER_SYNC_ADD_FOLDER: 'folder-sync:add-folder',
|
||||
|
|
|
|||
|
|
@ -1,5 +1,11 @@
|
|||
import { app, ipcMain, shell } from 'electron';
|
||||
import { IPC_CHANNELS } from './channels';
|
||||
import {
|
||||
getPermissionsStatus,
|
||||
requestAccessibility,
|
||||
requestScreenRecording,
|
||||
restartApp,
|
||||
} from '../modules/permissions';
|
||||
import {
|
||||
selectFolder,
|
||||
addWatchedFolder,
|
||||
|
|
@ -31,6 +37,22 @@ export function registerIpcHandlers(): void {
|
|||
return app.getVersion();
|
||||
});
|
||||
|
||||
ipcMain.handle(IPC_CHANNELS.GET_PERMISSIONS_STATUS, () => {
|
||||
return getPermissionsStatus();
|
||||
});
|
||||
|
||||
ipcMain.handle(IPC_CHANNELS.REQUEST_ACCESSIBILITY, () => {
|
||||
requestAccessibility();
|
||||
});
|
||||
|
||||
ipcMain.handle(IPC_CHANNELS.REQUEST_SCREEN_RECORDING, () => {
|
||||
requestScreenRecording();
|
||||
});
|
||||
|
||||
ipcMain.handle(IPC_CHANNELS.RESTART_APP, () => {
|
||||
restartApp();
|
||||
});
|
||||
|
||||
// Folder sync handlers
|
||||
ipcMain.handle(IPC_CHANNELS.FOLDER_SYNC_SELECT_FOLDER, () => selectFolder());
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import { setupDeepLinks, handlePendingDeepLink } from './modules/deep-links';
|
|||
import { setupAutoUpdater } from './modules/auto-updater';
|
||||
import { setupMenu } from './modules/menu';
|
||||
import { registerQuickAsk, unregisterQuickAsk } from './modules/quick-ask';
|
||||
import { registerAutocomplete, unregisterAutocomplete } from './modules/autocomplete';
|
||||
import { registerFolderWatcher, unregisterFolderWatcher } from './modules/folder-watcher';
|
||||
import { registerIpcHandlers } from './ipc/handlers';
|
||||
|
||||
|
|
@ -17,7 +18,6 @@ if (!setupDeepLinks()) {
|
|||
|
||||
registerIpcHandlers();
|
||||
|
||||
// App lifecycle
|
||||
app.whenReady().then(async () => {
|
||||
setupMenu();
|
||||
try {
|
||||
|
|
@ -27,8 +27,10 @@ app.whenReady().then(async () => {
|
|||
setTimeout(() => app.quit(), 0);
|
||||
return;
|
||||
}
|
||||
createMainWindow();
|
||||
|
||||
createMainWindow('/dashboard');
|
||||
registerQuickAsk();
|
||||
registerAutocomplete();
|
||||
registerFolderWatcher();
|
||||
setupAutoUpdater();
|
||||
|
||||
|
|
@ -36,7 +38,7 @@ app.whenReady().then(async () => {
|
|||
|
||||
app.on('activate', () => {
|
||||
if (BrowserWindow.getAllWindows().length === 0) {
|
||||
createMainWindow();
|
||||
createMainWindow('/dashboard');
|
||||
}
|
||||
});
|
||||
});
|
||||
|
|
@ -49,5 +51,6 @@ app.on('window-all-closed', () => {
|
|||
|
||||
app.on('will-quit', () => {
|
||||
unregisterQuickAsk();
|
||||
unregisterAutocomplete();
|
||||
unregisterFolderWatcher();
|
||||
});
|
||||
|
|
|
|||
132
surfsense_desktop/src/modules/autocomplete/index.ts
Normal file
132
surfsense_desktop/src/modules/autocomplete/index.ts
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
import { clipboard, globalShortcut, ipcMain, screen } from 'electron';
|
||||
import { IPC_CHANNELS } from '../../ipc/channels';
|
||||
import { getFrontmostApp, getWindowTitle, hasAccessibilityPermission, simulatePaste } from '../platform';
|
||||
import { hasScreenRecordingPermission, requestAccessibility, requestScreenRecording } from '../permissions';
|
||||
import { getMainWindow } from '../window';
|
||||
import { captureScreen } from './screenshot';
|
||||
import { createSuggestionWindow, destroySuggestion, getSuggestionWindow } from './suggestion-window';
|
||||
|
||||
const SHORTCUT = 'CommandOrControl+Shift+Space';
|
||||
|
||||
let autocompleteEnabled = true;
|
||||
let savedClipboard = '';
|
||||
let sourceApp = '';
|
||||
let lastSearchSpaceId: string | null = null;
|
||||
|
||||
function isSurfSenseWindow(): boolean {
|
||||
const app = getFrontmostApp();
|
||||
return app === 'Electron' || app === 'SurfSense' || app === 'surfsense-desktop';
|
||||
}
|
||||
|
||||
async function triggerAutocomplete(): Promise<void> {
|
||||
if (!autocompleteEnabled) return;
|
||||
if (isSurfSenseWindow()) return;
|
||||
|
||||
if (!hasScreenRecordingPermission()) {
|
||||
requestScreenRecording();
|
||||
return;
|
||||
}
|
||||
|
||||
sourceApp = getFrontmostApp();
|
||||
const windowTitle = getWindowTitle();
|
||||
savedClipboard = clipboard.readText();
|
||||
|
||||
const screenshot = await captureScreen();
|
||||
if (!screenshot) {
|
||||
console.error('[autocomplete] Screenshot capture failed');
|
||||
return;
|
||||
}
|
||||
|
||||
const mainWin = getMainWindow();
|
||||
if (mainWin && !mainWin.isDestroyed()) {
|
||||
const mainUrl = mainWin.webContents.getURL();
|
||||
const match = mainUrl.match(/\/dashboard\/(\d+)/);
|
||||
if (match) {
|
||||
lastSearchSpaceId = match[1];
|
||||
}
|
||||
}
|
||||
|
||||
if (!lastSearchSpaceId) {
|
||||
console.warn('[autocomplete] No active search space. Open a search space first.');
|
||||
return;
|
||||
}
|
||||
|
||||
const searchSpaceId = lastSearchSpaceId;
|
||||
const cursor = screen.getCursorScreenPoint();
|
||||
const win = createSuggestionWindow(cursor.x, cursor.y);
|
||||
|
||||
win.webContents.once('did-finish-load', () => {
|
||||
const sw = getSuggestionWindow();
|
||||
setTimeout(() => {
|
||||
if (sw && !sw.isDestroyed()) {
|
||||
sw.webContents.send(IPC_CHANNELS.AUTOCOMPLETE_CONTEXT, {
|
||||
screenshot,
|
||||
searchSpaceId,
|
||||
appName: sourceApp,
|
||||
windowTitle,
|
||||
});
|
||||
}
|
||||
}, 300);
|
||||
});
|
||||
}
|
||||
|
||||
async function acceptAndInject(text: string): Promise<void> {
|
||||
if (!sourceApp) return;
|
||||
|
||||
if (!hasAccessibilityPermission()) {
|
||||
requestAccessibility();
|
||||
return;
|
||||
}
|
||||
|
||||
clipboard.writeText(text);
|
||||
destroySuggestion();
|
||||
|
||||
try {
|
||||
await new Promise((r) => setTimeout(r, 50));
|
||||
simulatePaste();
|
||||
await new Promise((r) => setTimeout(r, 100));
|
||||
clipboard.writeText(savedClipboard);
|
||||
} catch {
|
||||
clipboard.writeText(savedClipboard);
|
||||
}
|
||||
}
|
||||
|
||||
function registerIpcHandlers(): void {
|
||||
ipcMain.handle(IPC_CHANNELS.ACCEPT_SUGGESTION, async (_event, text: string) => {
|
||||
await acceptAndInject(text);
|
||||
});
|
||||
ipcMain.handle(IPC_CHANNELS.DISMISS_SUGGESTION, () => {
|
||||
destroySuggestion();
|
||||
});
|
||||
ipcMain.handle(IPC_CHANNELS.SET_AUTOCOMPLETE_ENABLED, (_event, enabled: boolean) => {
|
||||
autocompleteEnabled = enabled;
|
||||
if (!enabled) {
|
||||
destroySuggestion();
|
||||
}
|
||||
});
|
||||
ipcMain.handle(IPC_CHANNELS.GET_AUTOCOMPLETE_ENABLED, () => autocompleteEnabled);
|
||||
}
|
||||
|
||||
export function registerAutocomplete(): void {
|
||||
registerIpcHandlers();
|
||||
|
||||
const ok = globalShortcut.register(SHORTCUT, () => {
|
||||
const sw = getSuggestionWindow();
|
||||
if (sw && !sw.isDestroyed()) {
|
||||
destroySuggestion();
|
||||
return;
|
||||
}
|
||||
triggerAutocomplete();
|
||||
});
|
||||
|
||||
if (!ok) {
|
||||
console.error(`[autocomplete] Failed to register shortcut ${SHORTCUT}`);
|
||||
} else {
|
||||
console.log(`[autocomplete] Registered shortcut ${SHORTCUT}`);
|
||||
}
|
||||
}
|
||||
|
||||
export function unregisterAutocomplete(): void {
|
||||
globalShortcut.unregister(SHORTCUT);
|
||||
destroySuggestion();
|
||||
}
|
||||
27
surfsense_desktop/src/modules/autocomplete/screenshot.ts
Normal file
27
surfsense_desktop/src/modules/autocomplete/screenshot.ts
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
import { desktopCapturer, screen } from 'electron';
|
||||
|
||||
/**
|
||||
* Captures the primary display as a base64-encoded PNG data URL.
|
||||
* Uses the display's actual size for full-resolution capture.
|
||||
*/
|
||||
export async function captureScreen(): Promise<string | null> {
|
||||
try {
|
||||
const primaryDisplay = screen.getPrimaryDisplay();
|
||||
const { width, height } = primaryDisplay.size;
|
||||
|
||||
const sources = await desktopCapturer.getSources({
|
||||
types: ['screen'],
|
||||
thumbnailSize: { width, height },
|
||||
});
|
||||
|
||||
if (!sources.length) {
|
||||
console.error('[screenshot] No screen sources found');
|
||||
return null;
|
||||
}
|
||||
|
||||
return sources[0].thumbnail.toDataURL();
|
||||
} catch (err) {
|
||||
console.error('[screenshot] Failed to capture screen:', err);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
112
surfsense_desktop/src/modules/autocomplete/suggestion-window.ts
Normal file
112
surfsense_desktop/src/modules/autocomplete/suggestion-window.ts
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
import { BrowserWindow, screen, shell } from 'electron';
|
||||
import path from 'path';
|
||||
import { getServerPort } from '../server';
|
||||
|
||||
const TOOLTIP_WIDTH = 420;
|
||||
const TOOLTIP_HEIGHT = 38;
|
||||
const MAX_HEIGHT = 400;
|
||||
|
||||
let suggestionWindow: BrowserWindow | null = null;
|
||||
let resizeTimer: ReturnType<typeof setInterval> | null = null;
|
||||
let cursorOrigin = { x: 0, y: 0 };
|
||||
|
||||
const CURSOR_GAP = 20;
|
||||
|
||||
function positionOnScreen(cursorX: number, cursorY: number, w: number, h: number): { x: number; y: number } {
|
||||
const display = screen.getDisplayNearestPoint({ x: cursorX, y: cursorY });
|
||||
const { x: dx, y: dy, width: dw, height: dh } = display.workArea;
|
||||
|
||||
const x = Math.max(dx, Math.min(cursorX, dx + dw - w));
|
||||
|
||||
const spaceBelow = (dy + dh) - (cursorY + CURSOR_GAP);
|
||||
const y = spaceBelow >= h
|
||||
? cursorY + CURSOR_GAP
|
||||
: cursorY - h - CURSOR_GAP;
|
||||
|
||||
return { x, y: Math.max(dy, y) };
|
||||
}
|
||||
|
||||
function stopResizePolling(): void {
|
||||
if (resizeTimer) { clearInterval(resizeTimer); resizeTimer = null; }
|
||||
}
|
||||
|
||||
function startResizePolling(win: BrowserWindow): void {
|
||||
stopResizePolling();
|
||||
let lastH = 0;
|
||||
resizeTimer = setInterval(async () => {
|
||||
if (!win || win.isDestroyed()) { stopResizePolling(); return; }
|
||||
try {
|
||||
const h: number = await win.webContents.executeJavaScript(
|
||||
`document.body.scrollHeight`
|
||||
);
|
||||
if (h > 0 && h !== lastH) {
|
||||
lastH = h;
|
||||
const clamped = Math.min(h, MAX_HEIGHT);
|
||||
const pos = positionOnScreen(cursorOrigin.x, cursorOrigin.y, TOOLTIP_WIDTH, clamped);
|
||||
win.setBounds({ x: pos.x, y: pos.y, width: TOOLTIP_WIDTH, height: clamped });
|
||||
}
|
||||
} catch {}
|
||||
}, 150);
|
||||
}
|
||||
|
||||
export function getSuggestionWindow(): BrowserWindow | null {
|
||||
return suggestionWindow;
|
||||
}
|
||||
|
||||
export function destroySuggestion(): void {
|
||||
stopResizePolling();
|
||||
if (suggestionWindow && !suggestionWindow.isDestroyed()) {
|
||||
suggestionWindow.close();
|
||||
}
|
||||
suggestionWindow = null;
|
||||
}
|
||||
|
||||
export function createSuggestionWindow(x: number, y: number): BrowserWindow {
|
||||
destroySuggestion();
|
||||
cursorOrigin = { x, y };
|
||||
|
||||
const pos = positionOnScreen(x, y, TOOLTIP_WIDTH, TOOLTIP_HEIGHT);
|
||||
|
||||
suggestionWindow = new BrowserWindow({
|
||||
width: TOOLTIP_WIDTH,
|
||||
height: TOOLTIP_HEIGHT,
|
||||
x: pos.x,
|
||||
y: pos.y,
|
||||
frame: false,
|
||||
transparent: true,
|
||||
focusable: false,
|
||||
alwaysOnTop: true,
|
||||
skipTaskbar: true,
|
||||
hasShadow: true,
|
||||
type: 'panel',
|
||||
webPreferences: {
|
||||
preload: path.join(__dirname, 'preload.js'),
|
||||
contextIsolation: true,
|
||||
nodeIntegration: false,
|
||||
sandbox: true,
|
||||
},
|
||||
show: false,
|
||||
});
|
||||
|
||||
suggestionWindow.loadURL(`http://localhost:${getServerPort()}/desktop/suggestion?t=${Date.now()}`);
|
||||
|
||||
suggestionWindow.once('ready-to-show', () => {
|
||||
suggestionWindow?.showInactive();
|
||||
if (suggestionWindow) startResizePolling(suggestionWindow);
|
||||
});
|
||||
|
||||
suggestionWindow.webContents.setWindowOpenHandler(({ url }) => {
|
||||
if (url.startsWith('http://localhost')) {
|
||||
return { action: 'allow' };
|
||||
}
|
||||
shell.openExternal(url);
|
||||
return { action: 'deny' };
|
||||
});
|
||||
|
||||
suggestionWindow.on('closed', () => {
|
||||
stopResizePolling();
|
||||
suggestionWindow = null;
|
||||
});
|
||||
|
||||
return suggestionWindow;
|
||||
}
|
||||
51
surfsense_desktop/src/modules/permissions.ts
Normal file
51
surfsense_desktop/src/modules/permissions.ts
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
import { app } from 'electron';
|
||||
|
||||
type PermissionStatus = 'authorized' | 'denied' | 'not determined' | 'restricted' | 'limited';
|
||||
|
||||
export interface PermissionsStatus {
|
||||
accessibility: PermissionStatus;
|
||||
screenRecording: PermissionStatus;
|
||||
}
|
||||
|
||||
function isMac(): boolean {
|
||||
return process.platform === 'darwin';
|
||||
}
|
||||
|
||||
function getNodeMacPermissions() {
|
||||
return require('node-mac-permissions');
|
||||
}
|
||||
|
||||
export function getPermissionsStatus(): PermissionsStatus {
|
||||
if (!isMac()) {
|
||||
return { accessibility: 'authorized', screenRecording: 'authorized' };
|
||||
}
|
||||
|
||||
const perms = getNodeMacPermissions();
|
||||
return {
|
||||
accessibility: perms.getAuthStatus('accessibility'),
|
||||
screenRecording: perms.getAuthStatus('screen'),
|
||||
};
|
||||
}
|
||||
|
||||
export function requestAccessibility(): void {
|
||||
if (!isMac()) return;
|
||||
const perms = getNodeMacPermissions();
|
||||
perms.askForAccessibilityAccess();
|
||||
}
|
||||
|
||||
export function hasScreenRecordingPermission(): boolean {
|
||||
if (!isMac()) return true;
|
||||
const perms = getNodeMacPermissions();
|
||||
return perms.getAuthStatus('screen') === 'authorized';
|
||||
}
|
||||
|
||||
export function requestScreenRecording(): void {
|
||||
if (!isMac()) return;
|
||||
const perms = getNodeMacPermissions();
|
||||
perms.askForScreenCaptureAccess();
|
||||
}
|
||||
|
||||
export function restartApp(): void {
|
||||
app.relaunch();
|
||||
app.exit(0);
|
||||
}
|
||||
|
|
@ -19,28 +19,6 @@ export function getFrontmostApp(): string {
|
|||
return '';
|
||||
}
|
||||
|
||||
export function getSelectedText(): string {
|
||||
try {
|
||||
if (process.platform === 'darwin') {
|
||||
return execSync(
|
||||
'osascript -e \'tell application "System Events" to get value of attribute "AXSelectedText" of focused UI element of first application process whose frontmost is true\''
|
||||
).toString().trim();
|
||||
}
|
||||
// Windows: no reliable accessibility API for selected text across apps
|
||||
} catch {
|
||||
return '';
|
||||
}
|
||||
return '';
|
||||
}
|
||||
|
||||
export function simulateCopy(): void {
|
||||
if (process.platform === 'darwin') {
|
||||
execSync('osascript -e \'tell application "System Events" to keystroke "c" using command down\'');
|
||||
} else if (process.platform === 'win32') {
|
||||
execSync('powershell -command "Add-Type -AssemblyName System.Windows.Forms; [System.Windows.Forms.SendKeys]::SendWait(\'^c\')"');
|
||||
}
|
||||
}
|
||||
|
||||
export function simulatePaste(): void {
|
||||
if (process.platform === 'darwin') {
|
||||
execSync('osascript -e \'tell application "System Events" to keystroke "v" using command down\'');
|
||||
|
|
@ -53,3 +31,26 @@ export function checkAccessibilityPermission(): boolean {
|
|||
if (process.platform !== 'darwin') return true;
|
||||
return systemPreferences.isTrustedAccessibilityClient(true);
|
||||
}
|
||||
|
||||
export function getWindowTitle(): string {
|
||||
try {
|
||||
if (process.platform === 'darwin') {
|
||||
return execSync(
|
||||
'osascript -e \'tell application "System Events" to get title of front window of first application process whose frontmost is true\''
|
||||
).toString().trim();
|
||||
}
|
||||
if (process.platform === 'win32') {
|
||||
return execSync(
|
||||
'powershell -command "(Get-Process | Where-Object { $_.MainWindowHandle -eq (Add-Type -MemberDefinition \'[DllImport(\\\"user32.dll\\\")] public static extern IntPtr GetForegroundWindow();\' -Name W -PassThru)::GetForegroundWindow() }).MainWindowTitle"'
|
||||
).toString().trim();
|
||||
}
|
||||
} catch {
|
||||
return '';
|
||||
}
|
||||
return '';
|
||||
}
|
||||
|
||||
export function hasAccessibilityPermission(): boolean {
|
||||
if (process.platform !== 'darwin') return true;
|
||||
return systemPreferences.isTrustedAccessibilityClient(false);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ export function getMainWindow(): BrowserWindow | null {
|
|||
return mainWindow;
|
||||
}
|
||||
|
||||
export function createMainWindow(): BrowserWindow {
|
||||
export function createMainWindow(initialPath = '/dashboard'): BrowserWindow {
|
||||
mainWindow = new BrowserWindow({
|
||||
width: 1280,
|
||||
height: 800,
|
||||
|
|
@ -33,7 +33,7 @@ export function createMainWindow(): BrowserWindow {
|
|||
mainWindow?.show();
|
||||
});
|
||||
|
||||
mainWindow.loadURL(`http://localhost:${getServerPort()}/dashboard`);
|
||||
mainWindow.loadURL(`http://localhost:${getServerPort()}${initialPath}`);
|
||||
|
||||
mainWindow.webContents.setWindowOpenHandler(({ url }) => {
|
||||
if (url.startsWith('http://localhost')) {
|
||||
|
|
|
|||
|
|
@ -21,6 +21,23 @@ contextBridge.exposeInMainWorld('electronAPI', {
|
|||
setQuickAskMode: (mode: string) => ipcRenderer.invoke(IPC_CHANNELS.SET_QUICK_ASK_MODE, mode),
|
||||
getQuickAskMode: () => ipcRenderer.invoke(IPC_CHANNELS.GET_QUICK_ASK_MODE),
|
||||
replaceText: (text: string) => ipcRenderer.invoke(IPC_CHANNELS.REPLACE_TEXT, text),
|
||||
// Permissions
|
||||
getPermissionsStatus: () => ipcRenderer.invoke(IPC_CHANNELS.GET_PERMISSIONS_STATUS),
|
||||
requestAccessibility: () => ipcRenderer.invoke(IPC_CHANNELS.REQUEST_ACCESSIBILITY),
|
||||
requestScreenRecording: () => ipcRenderer.invoke(IPC_CHANNELS.REQUEST_SCREEN_RECORDING),
|
||||
restartApp: () => ipcRenderer.invoke(IPC_CHANNELS.RESTART_APP),
|
||||
// Autocomplete
|
||||
onAutocompleteContext: (callback: (data: { screenshot: string; searchSpaceId?: string; appName?: string; windowTitle?: string }) => void) => {
|
||||
const listener = (_event: unknown, data: { screenshot: string; searchSpaceId?: string; appName?: string; windowTitle?: string }) => callback(data);
|
||||
ipcRenderer.on(IPC_CHANNELS.AUTOCOMPLETE_CONTEXT, listener);
|
||||
return () => {
|
||||
ipcRenderer.removeListener(IPC_CHANNELS.AUTOCOMPLETE_CONTEXT, listener);
|
||||
};
|
||||
},
|
||||
acceptSuggestion: (text: string) => ipcRenderer.invoke(IPC_CHANNELS.ACCEPT_SUGGESTION, text),
|
||||
dismissSuggestion: () => ipcRenderer.invoke(IPC_CHANNELS.DISMISS_SUGGESTION),
|
||||
setAutocompleteEnabled: (enabled: boolean) => ipcRenderer.invoke(IPC_CHANNELS.SET_AUTOCOMPLETE_ENABLED, enabled),
|
||||
getAutocompleteEnabled: () => ipcRenderer.invoke(IPC_CHANNELS.GET_AUTOCOMPLETE_ENABLED),
|
||||
|
||||
// Folder sync
|
||||
selectFolder: () => ipcRenderer.invoke(IPC_CHANNELS.FOLDER_SYNC_SELECT_FOLDER),
|
||||
|
|
|
|||
|
|
@ -0,0 +1,79 @@
|
|||
"use client";
|
||||
|
||||
import { useEffect, useState } from "react";
|
||||
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { Switch } from "@/components/ui/switch";
|
||||
import { Spinner } from "@/components/ui/spinner";
|
||||
|
||||
export function DesktopContent() {
|
||||
const [isElectron, setIsElectron] = useState(false);
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [enabled, setEnabled] = useState(true);
|
||||
|
||||
useEffect(() => {
|
||||
if (!window.electronAPI) {
|
||||
setLoading(false);
|
||||
return;
|
||||
}
|
||||
setIsElectron(true);
|
||||
|
||||
window.electronAPI.getAutocompleteEnabled().then((val) => {
|
||||
setEnabled(val);
|
||||
setLoading(false);
|
||||
});
|
||||
}, []);
|
||||
|
||||
if (!isElectron) {
|
||||
return (
|
||||
<div className="flex flex-col items-center justify-center py-12 text-center">
|
||||
<p className="text-sm text-muted-foreground">
|
||||
Desktop settings are only available in the SurfSense desktop app.
|
||||
</p>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (loading) {
|
||||
return (
|
||||
<div className="flex items-center justify-center py-12">
|
||||
<Spinner size="md" className="text-muted-foreground" />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const handleToggle = async (checked: boolean) => {
|
||||
setEnabled(checked);
|
||||
await window.electronAPI!.setAutocompleteEnabled(checked);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="space-y-4 md:space-y-6">
|
||||
<Card>
|
||||
<CardHeader className="px-3 md:px-6 pt-3 md:pt-6 pb-2 md:pb-3">
|
||||
<CardTitle className="text-base md:text-lg">Autocomplete</CardTitle>
|
||||
<CardDescription className="text-xs md:text-sm">
|
||||
Get inline writing suggestions powered by your knowledge base as you type in any app.
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent className="px-3 md:px-6 pb-3 md:pb-6">
|
||||
<div className="flex items-center justify-between rounded-lg border p-4">
|
||||
<div className="space-y-0.5">
|
||||
<Label htmlFor="autocomplete-toggle" className="text-sm font-medium cursor-pointer">
|
||||
Enable autocomplete
|
||||
</Label>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Show suggestions while typing in other applications.
|
||||
</p>
|
||||
</div>
|
||||
<Switch
|
||||
id="autocomplete-toggle"
|
||||
checked={enabled}
|
||||
onCheckedChange={handleToggle}
|
||||
/>
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
215
surfsense_web/app/desktop/permissions/page.tsx
Normal file
215
surfsense_web/app/desktop/permissions/page.tsx
Normal file
|
|
@ -0,0 +1,215 @@
|
|||
"use client";
|
||||
|
||||
import { useEffect, useState } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { Logo } from "@/components/Logo";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Spinner } from "@/components/ui/spinner";
|
||||
|
||||
type PermissionStatus = "authorized" | "denied" | "not determined" | "restricted" | "limited";
|
||||
|
||||
interface PermissionsStatus {
|
||||
accessibility: PermissionStatus;
|
||||
screenRecording: PermissionStatus;
|
||||
}
|
||||
|
||||
const STEPS = [
|
||||
{
|
||||
id: "screen-recording",
|
||||
title: "Screen Recording",
|
||||
description: "Lets SurfSense capture your screen to understand context and provide smart writing suggestions.",
|
||||
action: "requestScreenRecording",
|
||||
field: "screenRecording" as const,
|
||||
},
|
||||
{
|
||||
id: "accessibility",
|
||||
title: "Accessibility",
|
||||
description: "Lets SurfSense insert suggestions seamlessly, right where you\u2019re typing.",
|
||||
action: "requestAccessibility",
|
||||
field: "accessibility" as const,
|
||||
},
|
||||
];
|
||||
|
||||
function StatusBadge({ status }: { status: PermissionStatus }) {
|
||||
if (status === "authorized") {
|
||||
return (
|
||||
<span className="inline-flex items-center gap-1.5 text-xs font-medium text-green-700 dark:text-green-400">
|
||||
<span className="h-2 w-2 rounded-full bg-green-500" />
|
||||
Granted
|
||||
</span>
|
||||
);
|
||||
}
|
||||
if (status === "denied") {
|
||||
return (
|
||||
<span className="inline-flex items-center gap-1.5 text-xs font-medium text-amber-700 dark:text-amber-400">
|
||||
<span className="h-2 w-2 rounded-full bg-amber-500" />
|
||||
Denied
|
||||
</span>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<span className="inline-flex items-center gap-1.5 text-xs font-medium text-muted-foreground">
|
||||
<span className="h-2 w-2 rounded-full bg-muted-foreground/40" />
|
||||
Pending
|
||||
</span>
|
||||
);
|
||||
}
|
||||
|
||||
export default function DesktopPermissionsPage() {
|
||||
const router = useRouter();
|
||||
const [permissions, setPermissions] = useState<PermissionsStatus | null>(null);
|
||||
const [isElectron, setIsElectron] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
if (!window.electronAPI) return;
|
||||
setIsElectron(true);
|
||||
|
||||
let interval: ReturnType<typeof setInterval> | null = null;
|
||||
|
||||
const isResolved = (s: string) => s === "authorized" || s === "restricted";
|
||||
|
||||
const poll = async () => {
|
||||
const status = await window.electronAPI!.getPermissionsStatus();
|
||||
setPermissions(status);
|
||||
|
||||
if (isResolved(status.accessibility) && isResolved(status.screenRecording)) {
|
||||
if (interval) clearInterval(interval);
|
||||
}
|
||||
};
|
||||
|
||||
poll();
|
||||
interval = setInterval(poll, 2000);
|
||||
return () => { if (interval) clearInterval(interval); };
|
||||
}, []);
|
||||
|
||||
if (!isElectron) {
|
||||
return (
|
||||
<div className="h-screen flex items-center justify-center bg-background">
|
||||
<p className="text-muted-foreground">This page is only available in the desktop app.</p>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (!permissions) {
|
||||
return (
|
||||
<div className="h-screen flex items-center justify-center bg-background">
|
||||
<Spinner size="lg" />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const allGranted = permissions.accessibility === "authorized" && permissions.screenRecording === "authorized";
|
||||
|
||||
const handleRequest = async (action: string) => {
|
||||
if (action === "requestScreenRecording") {
|
||||
await window.electronAPI!.requestScreenRecording();
|
||||
} else if (action === "requestAccessibility") {
|
||||
await window.electronAPI!.requestAccessibility();
|
||||
}
|
||||
};
|
||||
|
||||
const handleContinue = () => {
|
||||
if (allGranted) {
|
||||
window.electronAPI!.restartApp();
|
||||
}
|
||||
};
|
||||
|
||||
const handleSkip = () => {
|
||||
router.push("/dashboard");
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="h-screen flex flex-col items-center p-4 bg-background dark:bg-neutral-900 select-none overflow-hidden">
|
||||
<div className="w-full max-w-lg flex flex-col min-h-0 h-full gap-6 py-8">
|
||||
{/* Header */}
|
||||
<div className="text-center space-y-3 shrink-0">
|
||||
<Logo className="w-12 h-12 mx-auto" />
|
||||
<div className="space-y-1">
|
||||
<h1 className="text-2xl font-semibold tracking-tight">System Permissions</h1>
|
||||
<p className="text-sm text-muted-foreground">
|
||||
SurfSense needs two macOS permissions to provide context-aware writing suggestions.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Steps */}
|
||||
<div className="rounded-xl border bg-background dark:bg-neutral-900 flex-1 min-h-0 overflow-y-auto px-6 py-6 space-y-6">
|
||||
{STEPS.map((step, index) => {
|
||||
const status = permissions[step.field];
|
||||
const isGranted = status === "authorized";
|
||||
|
||||
return (
|
||||
<div
|
||||
key={step.id}
|
||||
className={`rounded-lg border p-4 transition-colors ${
|
||||
isGranted
|
||||
? "border-green-200 bg-green-50/50 dark:border-green-900 dark:bg-green-950/20"
|
||||
: "border-border"
|
||||
}`}
|
||||
>
|
||||
<div className="flex items-start justify-between gap-3">
|
||||
<div className="flex items-start gap-3">
|
||||
<span className="flex h-7 w-7 shrink-0 items-center justify-center rounded-full bg-primary/10 text-sm font-medium text-primary">
|
||||
{isGranted ? "\u2713" : index + 1}
|
||||
</span>
|
||||
<div className="space-y-1">
|
||||
<h3 className="text-sm font-medium">{step.title}</h3>
|
||||
<p className="text-xs text-muted-foreground">{step.description}</p>
|
||||
</div>
|
||||
</div>
|
||||
<StatusBadge status={status} />
|
||||
</div>
|
||||
{!isGranted && (
|
||||
<div className="mt-3 pl-10 space-y-2">
|
||||
<Button
|
||||
size="sm"
|
||||
variant="outline"
|
||||
onClick={() => handleRequest(step.action)}
|
||||
className="text-xs"
|
||||
>
|
||||
Open System Settings
|
||||
</Button>
|
||||
{status === "denied" && (
|
||||
<p className="text-xs text-amber-700 dark:text-amber-400">
|
||||
Toggle SurfSense on in System Settings to continue.
|
||||
</p>
|
||||
)}
|
||||
<p className="text-xs text-muted-foreground">
|
||||
If SurfSense doesn't appear in the list, click <strong>+</strong> and select it from Applications.
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
|
||||
{/* Footer */}
|
||||
<div className="text-center space-y-3 shrink-0">
|
||||
{allGranted ? (
|
||||
<>
|
||||
<Button onClick={handleContinue} className="text-sm h-9 min-w-[180px]">
|
||||
Restart & Get Started
|
||||
</Button>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
A restart is needed for permissions to take effect.
|
||||
</p>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<Button disabled className="text-sm h-9 min-w-[180px]">
|
||||
Grant permissions to continue
|
||||
</Button>
|
||||
<button
|
||||
onClick={handleSkip}
|
||||
className="block mx-auto text-xs text-muted-foreground hover:text-foreground transition-colors"
|
||||
>
|
||||
Skip for now
|
||||
</button>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
13
surfsense_web/app/desktop/suggestion/layout.tsx
Normal file
13
surfsense_web/app/desktop/suggestion/layout.tsx
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
import "./suggestion.css";
|
||||
|
||||
export const metadata = {
|
||||
title: "SurfSense Suggestion",
|
||||
};
|
||||
|
||||
export default function SuggestionLayout({
|
||||
children,
|
||||
}: {
|
||||
children: React.ReactNode;
|
||||
}) {
|
||||
return <div className="suggestion-body">{children}</div>;
|
||||
}
|
||||
219
surfsense_web/app/desktop/suggestion/page.tsx
Normal file
219
surfsense_web/app/desktop/suggestion/page.tsx
Normal file
|
|
@ -0,0 +1,219 @@
|
|||
"use client";
|
||||
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
import { getBearerToken } from "@/lib/auth-utils";
|
||||
|
||||
type SSEEvent =
|
||||
| { type: "text-delta"; id: string; delta: string }
|
||||
| { type: "text-start"; id: string }
|
||||
| { type: "text-end"; id: string }
|
||||
| { type: "start"; messageId: string }
|
||||
| { type: "finish" }
|
||||
| { type: "error"; errorText: string };
|
||||
|
||||
function friendlyError(raw: string | number): string {
|
||||
if (typeof raw === "number") {
|
||||
if (raw === 401) return "Please sign in to use suggestions.";
|
||||
if (raw === 403) return "You don\u2019t have permission for this.";
|
||||
if (raw === 404) return "Suggestion service not found. Is the backend running?";
|
||||
if (raw >= 500) return "Something went wrong on the server. Try again.";
|
||||
return "Something went wrong. Try again.";
|
||||
}
|
||||
const lower = raw.toLowerCase();
|
||||
if (lower.includes("not authenticated") || lower.includes("unauthorized"))
|
||||
return "Please sign in to use suggestions.";
|
||||
if (lower.includes("no vision llm configured") || lower.includes("no llm configured"))
|
||||
return "No Vision LLM configured. Set one in search space settings.";
|
||||
if (lower.includes("does not support vision"))
|
||||
return "Selected model doesn\u2019t support vision. Set a vision-capable model in settings.";
|
||||
if (lower.includes("fetch") || lower.includes("network") || lower.includes("econnrefused"))
|
||||
return "Can\u2019t reach the server. Check your connection.";
|
||||
return "Something went wrong. Try again.";
|
||||
}
|
||||
|
||||
const AUTO_DISMISS_MS = 3000;
|
||||
|
||||
export default function SuggestionPage() {
|
||||
const [suggestion, setSuggestion] = useState("");
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
const [isDesktop, setIsDesktop] = useState(true);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const abortRef = useRef<AbortController | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
if (!window.electronAPI?.onAutocompleteContext) {
|
||||
setIsDesktop(false);
|
||||
setIsLoading(false);
|
||||
}
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
if (!error) return;
|
||||
const timer = setTimeout(() => {
|
||||
window.electronAPI?.dismissSuggestion?.();
|
||||
}, AUTO_DISMISS_MS);
|
||||
return () => clearTimeout(timer);
|
||||
}, [error]);
|
||||
|
||||
const fetchSuggestion = useCallback(
|
||||
async (screenshot: string, searchSpaceId: string, appName?: string, windowTitle?: string) => {
|
||||
abortRef.current?.abort();
|
||||
const controller = new AbortController();
|
||||
abortRef.current = controller;
|
||||
|
||||
setIsLoading(true);
|
||||
setSuggestion("");
|
||||
setError(null);
|
||||
|
||||
const token = getBearerToken();
|
||||
if (!token) {
|
||||
setError(friendlyError("not authenticated"));
|
||||
setIsLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
const backendUrl =
|
||||
process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
|
||||
|
||||
try {
|
||||
const response = await fetch(
|
||||
`${backendUrl}/api/v1/autocomplete/vision/stream`,
|
||||
{
|
||||
method: "POST",
|
||||
headers: {
|
||||
Authorization: `Bearer ${token}`,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
screenshot,
|
||||
search_space_id: parseInt(searchSpaceId, 10),
|
||||
app_name: appName || "",
|
||||
window_title: windowTitle || "",
|
||||
}),
|
||||
signal: controller.signal,
|
||||
},
|
||||
);
|
||||
|
||||
if (!response.ok) {
|
||||
setError(friendlyError(response.status));
|
||||
setIsLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!response.body) {
|
||||
setError(friendlyError("network error"));
|
||||
setIsLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = "";
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
const events = buffer.split(/\r?\n\r?\n/);
|
||||
buffer = events.pop() || "";
|
||||
|
||||
for (const event of events) {
|
||||
const lines = event.split(/\r?\n/);
|
||||
for (const line of lines) {
|
||||
if (!line.startsWith("data: ")) continue;
|
||||
const data = line.slice(6).trim();
|
||||
if (!data || data === "[DONE]") continue;
|
||||
|
||||
try {
|
||||
const parsed: SSEEvent = JSON.parse(data);
|
||||
if (parsed.type === "text-delta") {
|
||||
setSuggestion((prev) => prev + parsed.delta);
|
||||
} else if (parsed.type === "error") {
|
||||
setError(friendlyError(parsed.errorText));
|
||||
}
|
||||
} catch {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
if (err instanceof DOMException && err.name === "AbortError") return;
|
||||
setError(friendlyError("network error"));
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
},
|
||||
[],
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (!window.electronAPI?.onAutocompleteContext) return;
|
||||
|
||||
const cleanup = window.electronAPI.onAutocompleteContext((data) => {
|
||||
const searchSpaceId = data.searchSpaceId || "1";
|
||||
if (data.screenshot) {
|
||||
fetchSuggestion(data.screenshot, searchSpaceId, data.appName, data.windowTitle);
|
||||
}
|
||||
});
|
||||
|
||||
return cleanup;
|
||||
}, [fetchSuggestion]);
|
||||
|
||||
if (!isDesktop) {
|
||||
return (
|
||||
<div className="suggestion-tooltip">
|
||||
<span className="suggestion-error-text">
|
||||
This page is only available in the SurfSense desktop app.
|
||||
</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
<div className="suggestion-tooltip suggestion-error">
|
||||
<span className="suggestion-error-text">{error}</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (isLoading && !suggestion) {
|
||||
return (
|
||||
<div className="suggestion-tooltip">
|
||||
<div className="suggestion-loading">
|
||||
<span className="suggestion-dot" />
|
||||
<span className="suggestion-dot" />
|
||||
<span className="suggestion-dot" />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const handleAccept = () => {
|
||||
if (suggestion) {
|
||||
window.electronAPI?.acceptSuggestion?.(suggestion);
|
||||
}
|
||||
};
|
||||
|
||||
const handleDismiss = () => {
|
||||
window.electronAPI?.dismissSuggestion?.();
|
||||
};
|
||||
|
||||
if (!suggestion) return null;
|
||||
|
||||
return (
|
||||
<div className="suggestion-tooltip">
|
||||
<p className="suggestion-text">{suggestion}</p>
|
||||
<div className="suggestion-actions">
|
||||
<button className="suggestion-btn suggestion-btn-accept" onClick={handleAccept}>
|
||||
Accept
|
||||
</button>
|
||||
<button className="suggestion-btn suggestion-btn-dismiss" onClick={handleDismiss}>
|
||||
Dismiss
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
121
surfsense_web/app/desktop/suggestion/suggestion.css
Normal file
121
surfsense_web/app/desktop/suggestion/suggestion.css
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
html:has(.suggestion-body),
|
||||
body:has(.suggestion-body) {
|
||||
margin: 0 !important;
|
||||
padding: 0 !important;
|
||||
background: transparent !important;
|
||||
overflow: hidden !important;
|
||||
height: auto !important;
|
||||
width: 100% !important;
|
||||
}
|
||||
|
||||
.suggestion-body {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
background: transparent;
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
|
||||
-webkit-font-smoothing: antialiased;
|
||||
user-select: none;
|
||||
-webkit-app-region: no-drag;
|
||||
}
|
||||
|
||||
.suggestion-tooltip {
|
||||
background: #1e1e1e;
|
||||
border: 1px solid #3c3c3c;
|
||||
border-radius: 8px;
|
||||
padding: 8px 12px;
|
||||
margin: 4px;
|
||||
max-width: 400px;
|
||||
box-shadow: 0 4px 16px rgba(0, 0, 0, 0.5);
|
||||
}
|
||||
|
||||
.suggestion-text {
|
||||
color: #d4d4d4;
|
||||
font-size: 13px;
|
||||
line-height: 1.45;
|
||||
margin: 0 0 6px 0;
|
||||
word-wrap: break-word;
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
|
||||
.suggestion-actions {
|
||||
display: flex;
|
||||
justify-content: flex-end;
|
||||
gap: 4px;
|
||||
border-top: 1px solid #2a2a2a;
|
||||
padding-top: 6px;
|
||||
}
|
||||
|
||||
.suggestion-btn {
|
||||
padding: 2px 8px;
|
||||
border-radius: 3px;
|
||||
border: 1px solid #3c3c3c;
|
||||
font-family: inherit;
|
||||
font-size: 10px;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
line-height: 16px;
|
||||
transition: background 0.15s, border-color 0.15s;
|
||||
}
|
||||
|
||||
.suggestion-btn-accept {
|
||||
background: #2563eb;
|
||||
border-color: #3b82f6;
|
||||
color: #fff;
|
||||
}
|
||||
|
||||
.suggestion-btn-accept:hover {
|
||||
background: #1d4ed8;
|
||||
}
|
||||
|
||||
.suggestion-btn-dismiss {
|
||||
background: #2a2a2a;
|
||||
color: #999;
|
||||
}
|
||||
|
||||
.suggestion-btn-dismiss:hover {
|
||||
background: #333;
|
||||
color: #ccc;
|
||||
}
|
||||
|
||||
.suggestion-error {
|
||||
border-color: #5c2626;
|
||||
}
|
||||
|
||||
.suggestion-error-text {
|
||||
color: #f48771;
|
||||
font-size: 12px;
|
||||
}
|
||||
|
||||
.suggestion-loading {
|
||||
display: flex;
|
||||
gap: 5px;
|
||||
padding: 2px 0;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.suggestion-dot {
|
||||
width: 4px;
|
||||
height: 4px;
|
||||
border-radius: 50%;
|
||||
background: #666;
|
||||
animation: suggestion-pulse 1.2s infinite ease-in-out;
|
||||
}
|
||||
|
||||
.suggestion-dot:nth-child(2) {
|
||||
animation-delay: 0.15s;
|
||||
}
|
||||
|
||||
.suggestion-dot:nth-child(3) {
|
||||
animation-delay: 0.3s;
|
||||
}
|
||||
|
||||
@keyframes suggestion-pulse {
|
||||
0%, 80%, 100% {
|
||||
opacity: 0.3;
|
||||
transform: scale(0.8);
|
||||
}
|
||||
40% {
|
||||
opacity: 1;
|
||||
transform: scale(1.1);
|
||||
}
|
||||
}
|
||||
|
|
@ -9,7 +9,7 @@ import {
|
|||
TrashIcon,
|
||||
} from "lucide-react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { memo, useCallback, useEffect, useState } from "react";
|
||||
import { memo, useCallback, useEffect, useMemo, useState } from "react";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
DropdownMenu,
|
||||
|
|
@ -224,6 +224,11 @@ const ThreadListItemComponent = memo(function ThreadListItemComponent({
|
|||
onUnarchive,
|
||||
onDelete,
|
||||
}: ThreadListItemComponentProps) {
|
||||
const relativeTime = useMemo(
|
||||
() => formatRelativeTime(new Date(thread.updatedAt)),
|
||||
[thread.updatedAt]
|
||||
);
|
||||
|
||||
return (
|
||||
<button
|
||||
type="button"
|
||||
|
|
@ -237,7 +242,7 @@ const ThreadListItemComponent = memo(function ThreadListItemComponent({
|
|||
<div className="flex-1 min-w-0">
|
||||
<p className="truncate text-sm font-medium">{thread.title || "New Chat"}</p>
|
||||
<p className="truncate text-xs text-muted-foreground">
|
||||
{formatRelativeTime(new Date(thread.updatedAt))}
|
||||
{relativeTime}
|
||||
</p>
|
||||
</div>
|
||||
<DropdownMenu>
|
||||
|
|
|
|||
|
|
@ -96,9 +96,6 @@ export function EditorPanelContent({
|
|||
}
|
||||
|
||||
try {
|
||||
const response = await authenticatedFetch(
|
||||
`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/editor-content`,
|
||||
{ method: "GET", signal: controller.signal }
|
||||
const url = new URL(
|
||||
`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/editor-content`
|
||||
);
|
||||
|
|
|
|||
|
|
@ -158,14 +158,16 @@ export function PlateEditor({
|
|||
// When not forced read-only, the user can toggle between editing/viewing.
|
||||
const canToggleMode = !readOnly;
|
||||
|
||||
const contextProviderValue = useMemo(()=> ({
|
||||
onSave,
|
||||
hasUnsavedChanges,
|
||||
isSaving,
|
||||
canToggleMode,
|
||||
}), [onSave, hasUnsavedChanges, isSaving, canToggleMode]);
|
||||
|
||||
return (
|
||||
<EditorSaveContext.Provider
|
||||
value={{
|
||||
onSave,
|
||||
hasUnsavedChanges,
|
||||
isSaving,
|
||||
canToggleMode,
|
||||
}}
|
||||
value={contextProviderValue}
|
||||
>
|
||||
<Plate
|
||||
editor={editor}
|
||||
|
|
|
|||
|
|
@ -408,6 +408,7 @@ const AudioCommentIllustration = () => (
|
|||
src="/homepage/comments-audio.webp"
|
||||
alt="Audio Comment Illustration"
|
||||
fill
|
||||
sizes="(max-width: 768px) 100vw, (max-width: 1024px) 50vw, 33vw"
|
||||
className="object-cover"
|
||||
/>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ export const Navbar = ({ scrolledBgClassName }: NavbarProps = {}) => {
|
|||
};
|
||||
|
||||
handleScroll();
|
||||
window.addEventListener("scroll", handleScroll);
|
||||
window.addEventListener("scroll", handleScroll, { passive: true });
|
||||
return () => window.removeEventListener("scroll", handleScroll);
|
||||
}, []);
|
||||
|
||||
|
|
@ -132,7 +132,7 @@ const MobileNav = ({ navItems, isScrolled, scrolledBgClassName }: any) => {
|
|||
};
|
||||
|
||||
document.addEventListener("mousedown", handleClickOutside);
|
||||
document.addEventListener("touchstart", handleClickOutside);
|
||||
document.addEventListener("touchstart", handleClickOutside, { passive: true });
|
||||
return () => {
|
||||
document.removeEventListener("mousedown", handleClickOutside);
|
||||
document.removeEventListener("touchstart", handleClickOutside);
|
||||
|
|
@ -143,7 +143,6 @@ const MobileNav = ({ navItems, isScrolled, scrolledBgClassName }: any) => {
|
|||
<motion.div
|
||||
ref={navRef}
|
||||
animate={{ borderRadius: open ? "4px" : "2rem" }}
|
||||
key={String(open)}
|
||||
className={cn(
|
||||
"relative mx-auto flex w-full max-w-[calc(100vw-2rem)] flex-col items-center justify-between px-4 py-2 lg:hidden transition-all duration-300",
|
||||
isScrolled
|
||||
|
|
|
|||
|
|
@ -81,9 +81,6 @@ export function DocumentTabContent({ documentId, searchSpaceId, title }: Documen
|
|||
}
|
||||
|
||||
try {
|
||||
const response = await authenticatedFetch(
|
||||
`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/editor-content`,
|
||||
{ method: "GET", signal: controller.signal }
|
||||
const url = new URL(
|
||||
`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/editor-content`
|
||||
);
|
||||
|
|
|
|||
|
|
@ -602,11 +602,11 @@ export function OnboardingTour() {
|
|||
};
|
||||
|
||||
window.addEventListener("resize", handleUpdate);
|
||||
window.addEventListener("scroll", handleUpdate, true);
|
||||
window.addEventListener("scroll", handleUpdate, { capture: true, passive: true });
|
||||
|
||||
return () => {
|
||||
window.removeEventListener("resize", handleUpdate);
|
||||
window.removeEventListener("scroll", handleUpdate, true);
|
||||
window.removeEventListener("scroll", handleUpdate, { capture: true });
|
||||
};
|
||||
}, [isActive, targetEl, currentStep?.placement]);
|
||||
|
||||
|
|
|
|||
|
|
@ -123,6 +123,13 @@ export function ReportPanelContent({
|
|||
const [copied, setCopied] = useState(false);
|
||||
const [exporting, setExporting] = useState<string | null>(null);
|
||||
const [saving, setSaving] = useState(false);
|
||||
const copyTimerRef = useRef<ReturnType<typeof setTimeout> | undefined>(undefined);
|
||||
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
if (copyTimerRef.current) clearTimeout(copyTimerRef.current);
|
||||
};
|
||||
}, []);
|
||||
|
||||
// Editor state — tracks the latest markdown from the Plate editor
|
||||
const [editedMarkdown, setEditedMarkdown] = useState<string | null>(null);
|
||||
|
|
@ -197,7 +204,8 @@ export function ReportPanelContent({
|
|||
try {
|
||||
await navigator.clipboard.writeText(currentMarkdown);
|
||||
setCopied(true);
|
||||
setTimeout(() => setCopied(false), 2000);
|
||||
if (copyTimerRef.current) clearTimeout(copyTimerRef.current);
|
||||
copyTimerRef.current = setTimeout(() => setCopied(false), 2000);
|
||||
} catch (err) {
|
||||
console.error("Failed to copy:", err);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import {
|
|||
Bot,
|
||||
CircleCheck,
|
||||
CircleDashed,
|
||||
Eye,
|
||||
FileText,
|
||||
ImageIcon,
|
||||
RefreshCw,
|
||||
|
|
@ -70,6 +71,15 @@ const ROLE_DESCRIPTIONS = {
|
|||
prefKey: "image_generation_config_id" as const,
|
||||
configType: "image" as const,
|
||||
},
|
||||
vision: {
|
||||
icon: Eye,
|
||||
title: "Vision LLM",
|
||||
description: "Vision-capable model for screenshot analysis and context extraction",
|
||||
color: "text-amber-600 dark:text-amber-400",
|
||||
bgColor: "bg-amber-500/10",
|
||||
prefKey: "vision_llm_id" as const,
|
||||
configType: "llm" as const,
|
||||
},
|
||||
};
|
||||
|
||||
interface LLMRoleManagerProps {
|
||||
|
|
@ -115,6 +125,7 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
|||
agent_llm_id: preferences.agent_llm_id ?? "",
|
||||
document_summary_llm_id: preferences.document_summary_llm_id ?? "",
|
||||
image_generation_config_id: preferences.image_generation_config_id ?? "",
|
||||
vision_llm_id: preferences.vision_llm_id ?? "",
|
||||
}));
|
||||
|
||||
const [savingRole, setSavingRole] = useState<string | null>(null);
|
||||
|
|
@ -126,12 +137,14 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
|||
agent_llm_id: preferences.agent_llm_id ?? "",
|
||||
document_summary_llm_id: preferences.document_summary_llm_id ?? "",
|
||||
image_generation_config_id: preferences.image_generation_config_id ?? "",
|
||||
vision_llm_id: preferences.vision_llm_id ?? "",
|
||||
});
|
||||
}
|
||||
}, [
|
||||
preferences?.agent_llm_id,
|
||||
preferences?.document_summary_llm_id,
|
||||
preferences?.image_generation_config_id,
|
||||
preferences?.vision_llm_id,
|
||||
]);
|
||||
|
||||
const handleRoleAssignment = useCallback(
|
||||
|
|
|
|||
|
|
@ -1,13 +1,14 @@
|
|||
"use client";
|
||||
|
||||
import { useAtom } from "jotai";
|
||||
import { Globe, KeyRound, Receipt, Sparkles, User } from "lucide-react";
|
||||
import { Globe, KeyRound, Monitor, Receipt, Sparkles, User } from "lucide-react";
|
||||
import { useTranslations } from "next-intl";
|
||||
import { ApiKeyContent } from "@/app/dashboard/[search_space_id]/user-settings/components/ApiKeyContent";
|
||||
import { CommunityPromptsContent } from "@/app/dashboard/[search_space_id]/user-settings/components/CommunityPromptsContent";
|
||||
import { ProfileContent } from "@/app/dashboard/[search_space_id]/user-settings/components/ProfileContent";
|
||||
import { PromptsContent } from "@/app/dashboard/[search_space_id]/user-settings/components/PromptsContent";
|
||||
import { PurchaseHistoryContent } from "@/app/dashboard/[search_space_id]/user-settings/components/PurchaseHistoryContent";
|
||||
import { DesktopContent } from "@/app/dashboard/[search_space_id]/user-settings/components/DesktopContent";
|
||||
import { userSettingsDialogAtom } from "@/atoms/settings/settings-dialog.atoms";
|
||||
import { SettingsDialog } from "@/components/settings/settings-dialog";
|
||||
|
||||
|
|
@ -37,6 +38,9 @@ export function UserSettingsDialog() {
|
|||
label: "Purchase History",
|
||||
icon: <Receipt className="h-4 w-4" />,
|
||||
},
|
||||
...(typeof window !== "undefined" && window.electronAPI
|
||||
? [{ value: "desktop", label: "Desktop", icon: <Monitor className="h-4 w-4" /> }]
|
||||
: []),
|
||||
];
|
||||
|
||||
return (
|
||||
|
|
@ -54,6 +58,7 @@ export function UserSettingsDialog() {
|
|||
{state.initialTab === "prompts" && <PromptsContent />}
|
||||
{state.initialTab === "community-prompts" && <CommunityPromptsContent />}
|
||||
{state.initialTab === "purchases" && <PurchaseHistoryContent />}
|
||||
{state.initialTab === "desktop" && <DesktopContent />}
|
||||
</div>
|
||||
</SettingsDialog>
|
||||
);
|
||||
|
|
|
|||
|
|
@ -307,6 +307,7 @@ export function Image({
|
|||
src={src}
|
||||
alt={alt}
|
||||
fill
|
||||
sizes="(max-width: 512px) 100vw, 512px"
|
||||
className={cn(
|
||||
"transition-transform duration-300",
|
||||
fit === "cover" ? "object-cover" : "object-contain",
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import React, {
|
|||
useCallback,
|
||||
useContext,
|
||||
useEffect,
|
||||
useMemo,
|
||||
useRef,
|
||||
useState,
|
||||
} from "react";
|
||||
|
|
@ -201,9 +202,9 @@ const Tabs = forwardRef<
|
|||
},
|
||||
[onValueChange, value]
|
||||
);
|
||||
|
||||
const contextValue = useMemo(() => ({ activeValue, onValueChange: handleValueChange }), [activeValue, handleValueChange]);
|
||||
return (
|
||||
<TabsContext.Provider value={{ activeValue, onValueChange: handleValueChange }}>
|
||||
<TabsContext.Provider value={contextValue}>
|
||||
<div ref={ref} className={cn("tabs-container", className)} {...props}>
|
||||
{children}
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import type { VariantProps } from "class-variance-authority";
|
|||
import * as React from "react";
|
||||
import { toggleVariants } from "@/components/ui/toggle";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { useMemo } from "react";
|
||||
|
||||
const ToggleGroupContext = React.createContext<
|
||||
VariantProps<typeof toggleVariants> & {
|
||||
|
|
@ -27,6 +28,8 @@ function ToggleGroup({
|
|||
VariantProps<typeof toggleVariants> & {
|
||||
spacing?: number;
|
||||
}) {
|
||||
const contextValue = useMemo(() => ({variant, size, spacing }), [variant, size, spacing]);
|
||||
|
||||
return (
|
||||
<ToggleGroupPrimitive.Root
|
||||
data-slot="toggle-group"
|
||||
|
|
@ -40,7 +43,7 @@ function ToggleGroup({
|
|||
)}
|
||||
{...props}
|
||||
>
|
||||
<ToggleGroupContext.Provider value={{ variant, size, spacing }}>
|
||||
<ToggleGroupContext.Provider value={contextValue}>
|
||||
{children}
|
||||
</ToggleGroupContext.Provider>
|
||||
</ToggleGroupPrimitive.Root>
|
||||
|
|
|
|||
|
|
@ -1,12 +1,13 @@
|
|||
"use client";
|
||||
|
||||
import type React from "react";
|
||||
import { createContext, useContext, useEffect, useState } from "react";
|
||||
import { createContext, useCallback, useContext, useEffect, useMemo, useState } from "react";
|
||||
import enMessages from "../messages/en.json";
|
||||
import esMessages from "../messages/es.json";
|
||||
import hiMessages from "../messages/hi.json";
|
||||
import ptMessages from "../messages/pt.json";
|
||||
import zhMessages from "../messages/zh.json";
|
||||
import { set } from "zod";
|
||||
|
||||
type Locale = "en" | "es" | "pt" | "hi" | "zh";
|
||||
|
||||
|
|
@ -49,14 +50,14 @@ export function LocaleProvider({ children }: { children: React.ReactNode }) {
|
|||
}, []);
|
||||
|
||||
// Update locale and persist to localStorage
|
||||
const setLocale = (newLocale: Locale) => {
|
||||
const setLocale = useCallback((newLocale: Locale) => {
|
||||
setLocaleState(newLocale);
|
||||
if (typeof window !== "undefined") {
|
||||
localStorage.setItem(LOCALE_STORAGE_KEY, newLocale);
|
||||
// Update HTML lang attribute
|
||||
document.documentElement.lang = newLocale;
|
||||
}
|
||||
};
|
||||
}, []);
|
||||
|
||||
// Set HTML lang attribute when locale changes
|
||||
useEffect(() => {
|
||||
|
|
@ -65,8 +66,10 @@ export function LocaleProvider({ children }: { children: React.ReactNode }) {
|
|||
}
|
||||
}, [locale, mounted]);
|
||||
|
||||
const contextValue = useMemo(() => ({ locale, messages, setLocale }), [locale, messages, setLocale]);
|
||||
|
||||
return (
|
||||
<LocaleContext.Provider value={{ locale, messages, setLocale }}>
|
||||
<LocaleContext.Provider value={contextValue}>
|
||||
{children}
|
||||
</LocaleContext.Provider>
|
||||
);
|
||||
|
|
|
|||
|
|
@ -264,9 +264,11 @@ export const llmPreferences = z.object({
|
|||
agent_llm_id: z.union([z.number(), z.null()]).optional(),
|
||||
document_summary_llm_id: z.union([z.number(), z.null()]).optional(),
|
||||
image_generation_config_id: z.union([z.number(), z.null()]).optional(),
|
||||
vision_llm_id: z.union([z.number(), z.null()]).optional(),
|
||||
agent_llm: z.union([z.record(z.string(), z.unknown()), z.null()]).optional(),
|
||||
document_summary_llm: z.union([z.record(z.string(), z.unknown()), z.null()]).optional(),
|
||||
image_generation_config: z.union([z.record(z.string(), z.unknown()), z.null()]).optional(),
|
||||
vision_llm: z.union([z.record(z.string(), z.unknown()), z.null()]).optional(),
|
||||
});
|
||||
|
||||
/**
|
||||
|
|
@ -287,6 +289,7 @@ export const updateLLMPreferencesRequest = z.object({
|
|||
agent_llm_id: true,
|
||||
document_summary_llm_id: true,
|
||||
image_generation_config_id: true,
|
||||
vision_llm_id: true,
|
||||
}),
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
import { useCallback, useEffect, useState } from "react";
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
import { getBearerToken } from "@/lib/auth-utils";
|
||||
import { copyToClipboard as copyToClipboardUtil } from "@/lib/utils";
|
||||
|
|
@ -14,6 +14,13 @@ export function useApiKey(): UseApiKeyReturn {
|
|||
const [apiKey, setApiKey] = useState<string | null>(null);
|
||||
const [copied, setCopied] = useState(false);
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
const copyTimerRef = useRef<ReturnType<typeof setTimeout> | undefined>(undefined);
|
||||
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
if (copyTimerRef.current) clearTimeout(copyTimerRef.current);
|
||||
};
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
// Load API key from localStorage
|
||||
|
|
@ -41,7 +48,8 @@ export function useApiKey(): UseApiKeyReturn {
|
|||
if (success) {
|
||||
setCopied(true);
|
||||
toast.success("API key copied to clipboard");
|
||||
setTimeout(() => {
|
||||
if (copyTimerRef.current) clearTimeout(copyTimerRef.current);
|
||||
copyTimerRef.current = setTimeout(() => {
|
||||
setCopied(false);
|
||||
}, 2000);
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -15,8 +15,9 @@ export function getMDXComponents(components?: MDXComponents): MDXComponents {
|
|||
img: ({ className, alt, ...props }: React.ComponentProps<"img">) => (
|
||||
<Image
|
||||
{...(props as ImageProps)}
|
||||
className={cn("rounded-md border", className)}
|
||||
alt={alt ?? ""}
|
||||
sizes="(max-width: 768px) 100vw, 896px"
|
||||
className={cn("rounded-md border", className)}
|
||||
/>
|
||||
),
|
||||
Video: ({ className, ...props }: React.ComponentProps<"video">) => (
|
||||
|
|
|
|||
14
surfsense_web/types/window.d.ts
vendored
14
surfsense_web/types/window.d.ts
vendored
|
|
@ -48,6 +48,20 @@ interface ElectronAPI {
|
|||
setQuickAskMode: (mode: string) => Promise<void>;
|
||||
getQuickAskMode: () => Promise<string>;
|
||||
replaceText: (text: string) => Promise<void>;
|
||||
// Permissions
|
||||
getPermissionsStatus: () => Promise<{
|
||||
accessibility: 'authorized' | 'denied' | 'not determined' | 'restricted' | 'limited';
|
||||
screenRecording: 'authorized' | 'denied' | 'not determined' | 'restricted' | 'limited';
|
||||
}>;
|
||||
requestAccessibility: () => Promise<void>;
|
||||
requestScreenRecording: () => Promise<void>;
|
||||
restartApp: () => Promise<void>;
|
||||
// Autocomplete
|
||||
onAutocompleteContext: (callback: (data: { screenshot: string; searchSpaceId?: string; appName?: string; windowTitle?: string }) => void) => () => void;
|
||||
acceptSuggestion: (text: string) => Promise<void>;
|
||||
dismissSuggestion: () => Promise<void>;
|
||||
setAutocompleteEnabled: (enabled: boolean) => Promise<void>;
|
||||
getAutocompleteEnabled: () => Promise<boolean>;
|
||||
// Folder sync
|
||||
selectFolder: () => Promise<string | null>;
|
||||
addWatchedFolder: (config: WatchedFolderConfig) => Promise<WatchedFolderConfig[]>;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue