mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 08:46:22 +02:00
replace text-based autocomplete with vision-based endpoint
This commit is contained in:
parent
ced7f7562a
commit
aeb3f13f91
6 changed files with 102 additions and 133 deletions
|
|
@ -1,28 +1,29 @@
|
||||||
from fastapi import APIRouter, Depends, Query
|
from fastapi import APIRouter, Depends
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
|
from pydantic import BaseModel
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.db import User, get_async_session
|
from app.db import User, get_async_session
|
||||||
from app.services.autocomplete_service import stream_autocomplete
|
|
||||||
from app.services.new_streaming_service import VercelStreamingService
|
from app.services.new_streaming_service import VercelStreamingService
|
||||||
|
from app.services.vision_autocomplete_service import stream_vision_autocomplete
|
||||||
from app.users import current_active_user
|
from app.users import current_active_user
|
||||||
|
|
||||||
router = APIRouter(prefix="/autocomplete", tags=["autocomplete"])
|
router = APIRouter(prefix="/autocomplete", tags=["autocomplete"])
|
||||||
|
|
||||||
|
|
||||||
@router.post("/stream")
|
class VisionAutocompleteRequest(BaseModel):
|
||||||
async def autocomplete_stream(
|
screenshot: str
|
||||||
text: str = Query(..., description="Current text in the input field"),
|
search_space_id: int
|
||||||
cursor_position: int = Query(-1, description="Cursor position in the text (-1 for end)"),
|
|
||||||
search_space_id: int = Query(..., description="Search space ID for KB context and LLM config"),
|
|
||||||
|
@router.post("/vision/stream")
|
||||||
|
async def vision_autocomplete_stream(
|
||||||
|
body: VisionAutocompleteRequest,
|
||||||
user: User = Depends(current_active_user),
|
user: User = Depends(current_active_user),
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
):
|
):
|
||||||
if cursor_position < 0:
|
|
||||||
cursor_position = len(text)
|
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
stream_autocomplete(text, cursor_position, search_space_id, session),
|
stream_vision_autocomplete(body.screenshot, body.search_space_id, session),
|
||||||
media_type="text/event-stream",
|
media_type="text/event-stream",
|
||||||
headers={
|
headers={
|
||||||
**VercelStreamingService.get_response_headers(),
|
**VercelStreamingService.get_response_headers(),
|
||||||
|
|
|
||||||
|
|
@ -1,110 +0,0 @@
|
||||||
import logging
|
|
||||||
from typing import AsyncGenerator
|
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever
|
|
||||||
from app.services.llm_service import get_agent_llm
|
|
||||||
from app.services.new_streaming_service import VercelStreamingService
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
SYSTEM_PROMPT = """You are an inline text autocomplete engine. Your job is to complete the user's text naturally.
|
|
||||||
|
|
||||||
Rules:
|
|
||||||
- Output ONLY the continuation text. Do NOT repeat what the user already typed.
|
|
||||||
- Keep completions concise: 1-3 sentences maximum.
|
|
||||||
- Match the user's tone, style, and language.
|
|
||||||
- If knowledge base context is provided, use it to make the completion factually accurate and personalized.
|
|
||||||
- Do NOT add quotes, explanations, or meta-commentary.
|
|
||||||
- Do NOT start with a space unless grammatically required.
|
|
||||||
- If you cannot produce a useful completion, output nothing."""
|
|
||||||
|
|
||||||
KB_CONTEXT_TEMPLATE = """
|
|
||||||
Relevant knowledge base context (use this to personalize the completion):
|
|
||||||
---
|
|
||||||
{kb_context}
|
|
||||||
---
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
async def _retrieve_kb_context(
|
|
||||||
session: AsyncSession,
|
|
||||||
text: str,
|
|
||||||
search_space_id: int,
|
|
||||||
) -> str:
|
|
||||||
try:
|
|
||||||
retriever = ChucksHybridSearchRetriever(session)
|
|
||||||
chunks = await retriever.vector_search(
|
|
||||||
query_text=text[-200:],
|
|
||||||
top_k=3,
|
|
||||||
search_space_id=search_space_id,
|
|
||||||
)
|
|
||||||
if not chunks:
|
|
||||||
return ""
|
|
||||||
snippets = []
|
|
||||||
for chunk in chunks:
|
|
||||||
content = getattr(chunk, "content", None) or getattr(chunk, "chunk_text", "")
|
|
||||||
if content:
|
|
||||||
snippets.append(content[:300])
|
|
||||||
if not snippets:
|
|
||||||
return ""
|
|
||||||
return KB_CONTEXT_TEMPLATE.format(kb_context="\n\n".join(snippets))
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"KB search failed for autocomplete, proceeding without context: {e}")
|
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
async def stream_autocomplete(
|
|
||||||
text: str,
|
|
||||||
cursor_position: int,
|
|
||||||
search_space_id: int,
|
|
||||||
session: AsyncSession,
|
|
||||||
) -> AsyncGenerator[str, None]:
|
|
||||||
"""Build context, call the LLM, and yield SSE-formatted tokens."""
|
|
||||||
streaming = VercelStreamingService()
|
|
||||||
text_before_cursor = text[:cursor_position] if cursor_position >= 0 else text
|
|
||||||
|
|
||||||
if not text_before_cursor.strip():
|
|
||||||
yield streaming.format_message_start()
|
|
||||||
yield streaming.format_finish()
|
|
||||||
yield streaming.format_done()
|
|
||||||
return
|
|
||||||
|
|
||||||
kb_context = await _retrieve_kb_context(session, text_before_cursor, search_space_id)
|
|
||||||
|
|
||||||
llm = await get_agent_llm(session, search_space_id)
|
|
||||||
if not llm:
|
|
||||||
yield streaming.format_message_start()
|
|
||||||
yield streaming.format_error("No LLM configured for this search space")
|
|
||||||
yield streaming.format_done()
|
|
||||||
return
|
|
||||||
|
|
||||||
system_prompt = SYSTEM_PROMPT
|
|
||||||
if kb_context:
|
|
||||||
system_prompt += kb_context
|
|
||||||
|
|
||||||
messages = [
|
|
||||||
SystemMessage(content=system_prompt),
|
|
||||||
HumanMessage(content=f"Complete this text:\n{text_before_cursor}"),
|
|
||||||
]
|
|
||||||
|
|
||||||
try:
|
|
||||||
yield streaming.format_message_start()
|
|
||||||
text_id = streaming.generate_text_id()
|
|
||||||
yield streaming.format_text_start(text_id)
|
|
||||||
|
|
||||||
async for chunk in llm.astream(messages):
|
|
||||||
token = chunk.content if hasattr(chunk, "content") else str(chunk)
|
|
||||||
if token:
|
|
||||||
yield streaming.format_text_delta(text_id, token)
|
|
||||||
|
|
||||||
yield streaming.format_text_end(text_id)
|
|
||||||
yield streaming.format_finish()
|
|
||||||
yield streaming.format_done()
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Autocomplete streaming error: {e}")
|
|
||||||
yield streaming.format_error(str(e))
|
|
||||||
yield streaming.format_done()
|
|
||||||
|
|
@ -0,0 +1,78 @@
|
||||||
|
import logging
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.services.llm_service import get_vision_llm
|
||||||
|
from app.services.new_streaming_service import VercelStreamingService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
VISION_SYSTEM_PROMPT = """You are a smart writing assistant that analyzes the user's screen to draft or complete text.
|
||||||
|
|
||||||
|
You will receive a screenshot of the user's screen. Your job:
|
||||||
|
1. Analyze the ENTIRE screenshot to understand what the user is working on (email thread, chat conversation, document, code editor, form, etc.).
|
||||||
|
2. Identify the text area where the user will type.
|
||||||
|
3. Based on the full visual context, generate the text the user most likely wants to write.
|
||||||
|
|
||||||
|
Key behavior:
|
||||||
|
- If the text area is EMPTY, draft a full response or message based on what you see on screen (e.g., reply to an email, respond to a chat message, continue a document).
|
||||||
|
- If the text area already has text, continue it naturally.
|
||||||
|
|
||||||
|
Rules:
|
||||||
|
- Output ONLY the text to be inserted. No quotes, no explanations, no meta-commentary.
|
||||||
|
- Be concise but complete — a full thought, not a fragment.
|
||||||
|
- Match the tone and formality of the surrounding context.
|
||||||
|
- If the screen shows code, write code. If it shows a casual chat, be casual. If it shows a formal email, be formal.
|
||||||
|
- Do NOT describe the screenshot or explain your reasoning.
|
||||||
|
- If you cannot determine what to write, output nothing."""
|
||||||
|
|
||||||
|
|
||||||
|
async def stream_vision_autocomplete(
|
||||||
|
screenshot_data_url: str,
|
||||||
|
search_space_id: int,
|
||||||
|
session: AsyncSession,
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
"""Analyze a screenshot with the vision LLM and stream a text completion."""
|
||||||
|
streaming = VercelStreamingService()
|
||||||
|
|
||||||
|
llm = await get_vision_llm(session, search_space_id)
|
||||||
|
if not llm:
|
||||||
|
yield streaming.format_message_start()
|
||||||
|
yield streaming.format_error("No Vision LLM configured for this search space")
|
||||||
|
yield streaming.format_done()
|
||||||
|
return
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
SystemMessage(content=VISION_SYSTEM_PROMPT),
|
||||||
|
HumanMessage(content=[
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "Analyze this screenshot. Understand the full context of what the user is working on, then generate the text they most likely want to write in the active text area.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {"url": screenshot_data_url},
|
||||||
|
},
|
||||||
|
]),
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield streaming.format_message_start()
|
||||||
|
text_id = streaming.generate_text_id()
|
||||||
|
yield streaming.format_text_start(text_id)
|
||||||
|
|
||||||
|
async for chunk in llm.astream(messages):
|
||||||
|
token = chunk.content if hasattr(chunk, "content") else str(chunk)
|
||||||
|
if token:
|
||||||
|
yield streaming.format_text_delta(text_id, token)
|
||||||
|
|
||||||
|
yield streaming.format_text_end(text_id)
|
||||||
|
yield streaming.format_finish()
|
||||||
|
yield streaming.format_done()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Vision autocomplete streaming error: {e}")
|
||||||
|
yield streaming.format_error(str(e))
|
||||||
|
yield streaming.format_done()
|
||||||
|
|
@ -26,8 +26,8 @@ contextBridge.exposeInMainWorld('electronAPI', {
|
||||||
requestAccessibility: () => ipcRenderer.invoke(IPC_CHANNELS.REQUEST_ACCESSIBILITY),
|
requestAccessibility: () => ipcRenderer.invoke(IPC_CHANNELS.REQUEST_ACCESSIBILITY),
|
||||||
restartApp: () => ipcRenderer.invoke(IPC_CHANNELS.RESTART_APP),
|
restartApp: () => ipcRenderer.invoke(IPC_CHANNELS.RESTART_APP),
|
||||||
// Autocomplete
|
// Autocomplete
|
||||||
onAutocompleteContext: (callback: (data: { text: string; cursorPosition: number; searchSpaceId?: string }) => void) => {
|
onAutocompleteContext: (callback: (data: { screenshot: string; searchSpaceId?: string }) => void) => {
|
||||||
const listener = (_event: unknown, data: { text: string; cursorPosition: number; searchSpaceId?: string }) => callback(data);
|
const listener = (_event: unknown, data: { screenshot: string; searchSpaceId?: string }) => callback(data);
|
||||||
ipcRenderer.on(IPC_CHANNELS.AUTOCOMPLETE_CONTEXT, listener);
|
ipcRenderer.on(IPC_CHANNELS.AUTOCOMPLETE_CONTEXT, listener);
|
||||||
return () => {
|
return () => {
|
||||||
ipcRenderer.removeListener(IPC_CHANNELS.AUTOCOMPLETE_CONTEXT, listener);
|
ipcRenderer.removeListener(IPC_CHANNELS.AUTOCOMPLETE_CONTEXT, listener);
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ export default function SuggestionPage() {
|
||||||
const abortRef = useRef<AbortController | null>(null);
|
const abortRef = useRef<AbortController | null>(null);
|
||||||
|
|
||||||
const fetchSuggestion = useCallback(
|
const fetchSuggestion = useCallback(
|
||||||
async (text: string, cursorPosition: number, searchSpaceId: string) => {
|
async (screenshot: string, searchSpaceId: string) => {
|
||||||
abortRef.current?.abort();
|
abortRef.current?.abort();
|
||||||
const controller = new AbortController();
|
const controller = new AbortController();
|
||||||
abortRef.current = controller;
|
abortRef.current = controller;
|
||||||
|
|
@ -37,21 +37,19 @@ export default function SuggestionPage() {
|
||||||
const backendUrl =
|
const backendUrl =
|
||||||
process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
|
process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
|
||||||
|
|
||||||
const params = new URLSearchParams({
|
|
||||||
text,
|
|
||||||
cursor_position: String(cursorPosition),
|
|
||||||
search_space_id: searchSpaceId,
|
|
||||||
});
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const response = await fetch(
|
const response = await fetch(
|
||||||
`${backendUrl}/api/v1/autocomplete/stream?${params}`,
|
`${backendUrl}/api/v1/autocomplete/vision/stream`,
|
||||||
{
|
{
|
||||||
method: "POST",
|
method: "POST",
|
||||||
headers: {
|
headers: {
|
||||||
Authorization: `Bearer ${token}`,
|
Authorization: `Bearer ${token}`,
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
},
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
screenshot,
|
||||||
|
search_space_id: parseInt(searchSpaceId, 10),
|
||||||
|
}),
|
||||||
signal: controller.signal,
|
signal: controller.signal,
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
@ -119,7 +117,9 @@ export default function SuggestionPage() {
|
||||||
|
|
||||||
const cleanup = window.electronAPI.onAutocompleteContext((data) => {
|
const cleanup = window.electronAPI.onAutocompleteContext((data) => {
|
||||||
const searchSpaceId = data.searchSpaceId || "1";
|
const searchSpaceId = data.searchSpaceId || "1";
|
||||||
fetchSuggestion(data.text, data.cursorPosition, searchSpaceId);
|
if (data.screenshot) {
|
||||||
|
fetchSuggestion(data.screenshot, searchSpaceId);
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
return cleanup;
|
return cleanup;
|
||||||
|
|
|
||||||
2
surfsense_web/types/window.d.ts
vendored
2
surfsense_web/types/window.d.ts
vendored
|
|
@ -21,7 +21,7 @@ interface ElectronAPI {
|
||||||
requestAccessibility: () => Promise<void>;
|
requestAccessibility: () => Promise<void>;
|
||||||
restartApp: () => Promise<void>;
|
restartApp: () => Promise<void>;
|
||||||
// Autocomplete
|
// Autocomplete
|
||||||
onAutocompleteContext: (callback: (data: { text: string; cursorPosition: number; searchSpaceId?: string }) => void) => () => void;
|
onAutocompleteContext: (callback: (data: { screenshot: string; searchSpaceId?: string }) => void) => () => void;
|
||||||
acceptSuggestion: (text: string) => Promise<void>;
|
acceptSuggestion: (text: string) => Promise<void>;
|
||||||
dismissSuggestion: () => Promise<void>;
|
dismissSuggestion: () => Promise<void>;
|
||||||
updateSuggestionText: (text: string) => Promise<void>;
|
updateSuggestionText: (text: string) => Promise<void>;
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue