diff --git a/surfsense_backend/app/agents/new_chat/checkpointer.py b/surfsense_backend/app/agents/new_chat/checkpointer.py index 8db3706a7..637b2926f 100644 --- a/surfsense_backend/app/agents/new_chat/checkpointer.py +++ b/surfsense_backend/app/agents/new_chat/checkpointer.py @@ -18,59 +18,59 @@ _checkpointer_initialized: bool = False def get_postgres_connection_string() -> str: """ Convert the async DATABASE_URL to a sync postgres connection string for psycopg3. - + The DATABASE_URL is typically in format: postgresql+asyncpg://user:pass@host:port/dbname - + We need to convert it to: postgresql://user:pass@host:port/dbname """ db_url = config.DATABASE_URL - + # Handle asyncpg driver prefix if db_url.startswith("postgresql+asyncpg://"): return db_url.replace("postgresql+asyncpg://", "postgresql://") - + # Handle other async prefixes if "+asyncpg" in db_url: return db_url.replace("+asyncpg", "") - + return db_url async def get_checkpointer() -> AsyncPostgresSaver: """ Get or create the global AsyncPostgresSaver instance. - + This function: 1. Creates the checkpointer if it doesn't exist 2. Sets up the required database tables on first call 3. Returns the cached instance on subsequent calls - + Returns: AsyncPostgresSaver: The configured checkpointer instance """ global _checkpointer, _checkpointer_context, _checkpointer_initialized - + if _checkpointer is None: conn_string = get_postgres_connection_string() # from_conn_string returns an async context manager # We need to enter the context to get the actual checkpointer _checkpointer_context = AsyncPostgresSaver.from_conn_string(conn_string) _checkpointer = await _checkpointer_context.__aenter__() - + # Setup tables on first call (idempotent) if not _checkpointer_initialized: await _checkpointer.setup() _checkpointer_initialized = True - + return _checkpointer async def setup_checkpointer_tables() -> None: """ Explicitly setup the checkpointer tables. - + This can be called during application startup to ensure tables exist before any agent calls. """ @@ -81,15 +81,14 @@ async def setup_checkpointer_tables() -> None: async def close_checkpointer() -> None: """ Close the checkpointer connection. - + This should be called during application shutdown. """ global _checkpointer, _checkpointer_context, _checkpointer_initialized - + if _checkpointer_context is not None: await _checkpointer_context.__aexit__(None, None, None) _checkpointer = None _checkpointer_context = None _checkpointer_initialized = False print("[Checkpointer] PostgreSQL connection closed") - diff --git a/surfsense_backend/app/agents/new_chat/new_chat_test.py b/surfsense_backend/app/agents/new_chat/new_chat_test.py index 6a4e9bd02..857fee6cc 100644 --- a/surfsense_backend/app/agents/new_chat/new_chat_test.py +++ b/surfsense_backend/app/agents/new_chat/new_chat_test.py @@ -81,4 +81,4 @@ async def run_test(): if __name__ == "__main__": - asyncio.run(run_test()) \ No newline at end of file + asyncio.run(run_test()) diff --git a/surfsense_backend/app/agents/new_chat/podcast.py b/surfsense_backend/app/agents/new_chat/podcast.py index d57d0fb21..46974d184 100644 --- a/surfsense_backend/app/agents/new_chat/podcast.py +++ b/surfsense_backend/app/agents/new_chat/podcast.py @@ -126,7 +126,9 @@ def create_generate_podcast_tool( # Check if a podcast is already being generated for this search space active_task_id = get_active_podcast_task(search_space_id) if active_task_id: - print(f"[generate_podcast] Blocked duplicate request. Active task: {active_task_id}") + print( + f"[generate_podcast] Blocked duplicate request. Active task: {active_task_id}" + ) return { "status": "already_generating", "task_id": active_task_id, diff --git a/surfsense_backend/app/app.py b/surfsense_backend/app/app.py index b81e2d36d..7d7e88a28 100644 --- a/surfsense_backend/app/app.py +++ b/surfsense_backend/app/app.py @@ -5,7 +5,10 @@ from fastapi.middleware.cors import CORSMiddleware from sqlalchemy.ext.asyncio import AsyncSession from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware -from app.agents.new_chat.checkpointer import close_checkpointer, setup_checkpointer_tables +from app.agents.new_chat.checkpointer import ( + close_checkpointer, + setup_checkpointer_tables, +) from app.config import config from app.db import User, create_db_and_tables, get_async_session from app.routes import router as crud_router diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index c338240b3..9b7811022 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -332,6 +332,75 @@ class Chat(BaseModel, TimestampMixin): search_space = relationship("SearchSpace", back_populates="chats") +class NewChatMessageRole(str, Enum): + """Role enum for new chat messages.""" + + USER = "user" + ASSISTANT = "assistant" + SYSTEM = "system" + + +class NewChatThread(BaseModel, TimestampMixin): + """ + Thread model for the new chat feature using assistant-ui. + Each thread represents a conversation with message history. + LangGraph checkpointer uses thread_id for state persistence. + """ + + __tablename__ = "new_chat_threads" + + title = Column(String(500), nullable=False, default="New Chat", index=True) + archived = Column(Boolean, nullable=False, default=False) + updated_at = Column( + TIMESTAMP(timezone=True), + nullable=False, + default=lambda: datetime.now(UTC), + onupdate=lambda: datetime.now(UTC), + index=True, + ) + + # Foreign keys + search_space_id = Column( + Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False + ) + user_id = Column( + UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False + ) + + # Relationships + search_space = relationship("SearchSpace", back_populates="new_chat_threads") + messages = relationship( + "NewChatMessage", + back_populates="thread", + order_by="NewChatMessage.created_at", + cascade="all, delete-orphan", + ) + + +class NewChatMessage(BaseModel, TimestampMixin): + """ + Message model for the new chat feature. + Stores individual messages in assistant-ui format. + """ + + __tablename__ = "new_chat_messages" + + role = Column(SQLAlchemyEnum(NewChatMessageRole), nullable=False) + # Content stored as JSONB to support rich content (text, tool calls, etc.) + content = Column(JSONB, nullable=False) + + # Foreign key to thread + thread_id = Column( + Integer, + ForeignKey("new_chat_threads.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + + # Relationship + thread = relationship("NewChatThread", back_populates="messages") + + class Document(BaseModel, TimestampMixin): __tablename__ = "documents" @@ -435,6 +504,12 @@ class SearchSpace(BaseModel, TimestampMixin): order_by="Chat.id", cascade="all, delete-orphan", ) + new_chat_threads = relationship( + "NewChatThread", + back_populates="search_space", + order_by="NewChatThread.updated_at.desc()", + cascade="all, delete-orphan", + ) logs = relationship( "Log", back_populates="search_space", diff --git a/surfsense_backend/app/routes/__init__.py b/surfsense_backend/app/routes/__init__.py index c9d70588d..5430e8b1e 100644 --- a/surfsense_backend/app/routes/__init__.py +++ b/surfsense_backend/app/routes/__init__.py @@ -15,6 +15,7 @@ from .google_gmail_add_connector_route import ( from .llm_config_routes import router as llm_config_router from .logs_routes import router as logs_router from .luma_add_connector_route import router as luma_add_connector_router +from .new_chat_routes import router as new_chat_router from .notes_routes import router as notes_router from .podcasts_routes import router as podcasts_router from .rbac_routes import router as rbac_router @@ -30,6 +31,7 @@ router.include_router(documents_router) router.include_router(notes_router) router.include_router(podcasts_router) router.include_router(chats_router) +router.include_router(new_chat_router) # New chat with assistant-ui persistence router.include_router(search_source_connectors_router) router.include_router(google_calendar_add_connector_router) router.include_router(google_gmail_add_connector_router) diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py new file mode 100644 index 000000000..c3102db67 --- /dev/null +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -0,0 +1,597 @@ +""" +Routes for the new chat feature with assistant-ui integration. + +These endpoints support the ThreadHistoryAdapter pattern from assistant-ui: +- GET /threads - List threads for sidebar (ThreadListPrimitive) +- POST /threads - Create a new thread +- GET /threads/{thread_id} - Get thread with messages (load) +- PUT /threads/{thread_id} - Update thread (rename, archive) +- DELETE /threads/{thread_id} - Delete thread +- POST /threads/{thread_id}/messages - Append message +""" + +from datetime import UTC, datetime + +from fastapi import APIRouter, Depends, HTTPException, Request +from sqlalchemy.exc import IntegrityError, OperationalError +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select +from sqlalchemy.orm import selectinload + +from app.db import ( + NewChatMessage, + NewChatMessageRole, + NewChatThread, + Permission, + User, + get_async_session, +) +from app.schemas.new_chat import ( + NewChatMessageAppend, + NewChatMessageRead, + NewChatThreadCreate, + NewChatThreadRead, + NewChatThreadUpdate, + NewChatThreadWithMessages, + ThreadHistoryLoadResponse, + ThreadListItem, + ThreadListResponse, +) +from app.users import current_active_user +from app.utils.rbac import check_permission + +router = APIRouter() + + +# ============================================================================= +# Thread Endpoints +# ============================================================================= + + +@router.get("/threads", response_model=ThreadListResponse) +async def list_threads( + search_space_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + List all threads for the current user in a search space. + Returns threads and archived_threads for ThreadListPrimitive. + + Requires CHATS_READ permission. + """ + try: + await check_permission( + session, + user, + search_space_id, + Permission.CHATS_READ.value, + "You don't have permission to read chats in this search space", + ) + + # Get all threads for this user in this search space + query = ( + select(NewChatThread) + .filter( + NewChatThread.search_space_id == search_space_id, + NewChatThread.user_id == user.id, + ) + .order_by(NewChatThread.updated_at.desc()) + ) + + result = await session.execute(query) + all_threads = result.scalars().all() + + # Separate active and archived threads + threads = [] + archived_threads = [] + + for thread in all_threads: + item = ThreadListItem( + id=thread.id, + title=thread.title, + archived=thread.archived, + createdAt=thread.created_at, + updatedAt=thread.updated_at, + ) + if thread.archived: + archived_threads.append(item) + else: + threads.append(item) + + return ThreadListResponse(threads=threads, archived_threads=archived_threads) + + except HTTPException: + raise + except OperationalError: + raise HTTPException( + status_code=503, detail="Database operation failed. Please try again later." + ) from None + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"An unexpected error occurred while fetching threads: {e!s}", + ) from None + + +@router.post("/threads", response_model=NewChatThreadRead) +async def create_thread( + thread: NewChatThreadCreate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Create a new chat thread. + + Requires CHATS_CREATE permission. + """ + try: + await check_permission( + session, + user, + thread.search_space_id, + Permission.CHATS_CREATE.value, + "You don't have permission to create chats in this search space", + ) + + now = datetime.now(UTC) + db_thread = NewChatThread( + title=thread.title, + archived=thread.archived, + search_space_id=thread.search_space_id, + user_id=user.id, + updated_at=now, + ) + session.add(db_thread) + await session.commit() + await session.refresh(db_thread) + return db_thread + + except HTTPException: + raise + except IntegrityError: + await session.rollback() + raise HTTPException( + status_code=400, + detail="Database constraint violation. Please check your input data.", + ) from None + except OperationalError: + await session.rollback() + raise HTTPException( + status_code=503, detail="Database operation failed. Please try again later." + ) from None + except Exception as e: + await session.rollback() + raise HTTPException( + status_code=500, + detail=f"An unexpected error occurred while creating the thread: {e!s}", + ) from None + + +@router.get("/threads/{thread_id}", response_model=ThreadHistoryLoadResponse) +async def get_thread_messages( + thread_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Get a thread with all its messages. + This is used by ThreadHistoryAdapter.load() to restore conversation. + + Requires CHATS_READ permission. + """ + try: + # Get thread with messages + result = await session.execute( + select(NewChatThread) + .options(selectinload(NewChatThread.messages)) + .filter(NewChatThread.id == thread_id) + ) + thread = result.scalars().first() + + if not thread: + raise HTTPException(status_code=404, detail="Thread not found") + + # Check permission and ownership + await check_permission( + session, + user, + thread.search_space_id, + Permission.CHATS_READ.value, + "You don't have permission to read chats in this search space", + ) + + # Ensure user owns this thread + if thread.user_id != user.id: + raise HTTPException( + status_code=403, detail="You don't have access to this thread" + ) + + # Return messages in the format expected by assistant-ui + messages = [ + NewChatMessageRead( + id=msg.id, + thread_id=msg.thread_id, + role=msg.role, + content=msg.content, + created_at=msg.created_at, + ) + for msg in thread.messages + ] + + return ThreadHistoryLoadResponse(messages=messages) + + except HTTPException: + raise + except OperationalError: + raise HTTPException( + status_code=503, detail="Database operation failed. Please try again later." + ) from None + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"An unexpected error occurred while fetching the thread: {e!s}", + ) from None + + +@router.get("/threads/{thread_id}/full", response_model=NewChatThreadWithMessages) +async def get_thread_full( + thread_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Get full thread details with all messages. + + Requires CHATS_READ permission. + """ + try: + result = await session.execute( + select(NewChatThread) + .options(selectinload(NewChatThread.messages)) + .filter(NewChatThread.id == thread_id) + ) + thread = result.scalars().first() + + if not thread: + raise HTTPException(status_code=404, detail="Thread not found") + + await check_permission( + session, + user, + thread.search_space_id, + Permission.CHATS_READ.value, + "You don't have permission to read chats in this search space", + ) + + if thread.user_id != user.id: + raise HTTPException( + status_code=403, detail="You don't have access to this thread" + ) + + return thread + + except HTTPException: + raise + except OperationalError: + raise HTTPException( + status_code=503, detail="Database operation failed. Please try again later." + ) from None + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"An unexpected error occurred while fetching the thread: {e!s}", + ) from None + + +@router.put("/threads/{thread_id}", response_model=NewChatThreadRead) +async def update_thread( + thread_id: int, + thread_update: NewChatThreadUpdate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Update a thread (title, archived status). + Used for renaming and archiving threads. + + Requires CHATS_UPDATE permission. + """ + try: + result = await session.execute( + select(NewChatThread).filter(NewChatThread.id == thread_id) + ) + db_thread = result.scalars().first() + + if not db_thread: + raise HTTPException(status_code=404, detail="Thread not found") + + await check_permission( + session, + user, + db_thread.search_space_id, + Permission.CHATS_UPDATE.value, + "You don't have permission to update chats in this search space", + ) + + if db_thread.user_id != user.id: + raise HTTPException( + status_code=403, detail="You don't have access to this thread" + ) + + # Update fields + update_data = thread_update.model_dump(exclude_unset=True) + for key, value in update_data.items(): + setattr(db_thread, key, value) + + db_thread.updated_at = datetime.now(UTC) + + await session.commit() + await session.refresh(db_thread) + return db_thread + + except HTTPException: + raise + except IntegrityError: + await session.rollback() + raise HTTPException( + status_code=400, + detail="Database constraint violation. Please check your input data.", + ) from None + except OperationalError: + await session.rollback() + raise HTTPException( + status_code=503, detail="Database operation failed. Please try again later." + ) from None + except Exception as e: + await session.rollback() + raise HTTPException( + status_code=500, + detail=f"An unexpected error occurred while updating the thread: {e!s}", + ) from None + + +@router.delete("/threads/{thread_id}", response_model=dict) +async def delete_thread( + thread_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Delete a thread and all its messages. + + Requires CHATS_DELETE permission. + """ + try: + result = await session.execute( + select(NewChatThread).filter(NewChatThread.id == thread_id) + ) + db_thread = result.scalars().first() + + if not db_thread: + raise HTTPException(status_code=404, detail="Thread not found") + + await check_permission( + session, + user, + db_thread.search_space_id, + Permission.CHATS_DELETE.value, + "You don't have permission to delete chats in this search space", + ) + + if db_thread.user_id != user.id: + raise HTTPException( + status_code=403, detail="You don't have access to this thread" + ) + + await session.delete(db_thread) + await session.commit() + return {"message": "Thread deleted successfully"} + + except HTTPException: + raise + except IntegrityError: + await session.rollback() + raise HTTPException( + status_code=400, detail="Cannot delete thread due to existing dependencies." + ) from None + except OperationalError: + await session.rollback() + raise HTTPException( + status_code=503, detail="Database operation failed. Please try again later." + ) from None + except Exception as e: + await session.rollback() + raise HTTPException( + status_code=500, + detail=f"An unexpected error occurred while deleting the thread: {e!s}", + ) from None + + +# ============================================================================= +# Message Endpoints +# ============================================================================= + + +@router.post("/threads/{thread_id}/messages", response_model=NewChatMessageRead) +async def append_message( + thread_id: int, + request: Request, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Append a message to a thread. + This is used by ThreadHistoryAdapter.append() to persist messages. + + Requires CHATS_UPDATE permission. + """ + try: + # Parse raw body - extract only role and content, ignoring extra fields + raw_body = await request.json() + role = raw_body.get("role") + content = raw_body.get("content") + + if not role: + raise HTTPException(status_code=400, detail="Missing required field: role") + if content is None: + raise HTTPException( + status_code=400, detail="Missing required field: content" + ) + + # Create message object manually + message = NewChatMessageAppend(role=role, content=content) + # Get thread + result = await session.execute( + select(NewChatThread).filter(NewChatThread.id == thread_id) + ) + thread = result.scalars().first() + + if not thread: + raise HTTPException(status_code=404, detail="Thread not found") + + await check_permission( + session, + user, + thread.search_space_id, + Permission.CHATS_UPDATE.value, + "You don't have permission to update chats in this search space", + ) + + if thread.user_id != user.id: + raise HTTPException( + status_code=403, detail="You don't have access to this thread" + ) + + # Convert string role to enum + role_str = ( + message.role.lower() if isinstance(message.role, str) else message.role + ) + try: + message_role = NewChatMessageRole(role_str) + except ValueError: + raise HTTPException( + status_code=400, + detail=f"Invalid role: {message.role}. Must be 'user', 'assistant', or 'system'.", + ) from None + + # Create message + db_message = NewChatMessage( + thread_id=thread_id, + role=message_role, + content=message.content, + ) + session.add(db_message) + + # Update thread's updated_at timestamp + thread.updated_at = datetime.now(UTC) + + # Auto-generate title from first user message if title is still default + if thread.title == "New Chat" and role_str == "user": + # Extract text content for title + content = message.content + if isinstance(content, str): + title_text = content + elif isinstance(content, list): + # Find first text content + title_text = "" + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + title_text = part.get("text", "") + break + elif isinstance(part, str): + title_text = part + break + else: + title_text = str(content) + + # Truncate title + if title_text: + thread.title = title_text[:100] + ( + "..." if len(title_text) > 100 else "" + ) + + await session.commit() + await session.refresh(db_message) + return db_message + + except HTTPException: + raise + except IntegrityError: + await session.rollback() + raise HTTPException( + status_code=400, + detail="Database constraint violation. Please check your input data.", + ) from None + except OperationalError: + await session.rollback() + raise HTTPException( + status_code=503, detail="Database operation failed. Please try again later." + ) from None + except Exception as e: + await session.rollback() + raise HTTPException( + status_code=500, + detail=f"An unexpected error occurred while appending the message: {e!s}", + ) from None + + +@router.get("/threads/{thread_id}/messages", response_model=list[NewChatMessageRead]) +async def list_messages( + thread_id: int, + skip: int = 0, + limit: int = 100, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + List messages in a thread with pagination. + + Requires CHATS_READ permission. + """ + try: + # Verify thread exists and user has access + result = await session.execute( + select(NewChatThread).filter(NewChatThread.id == thread_id) + ) + thread = result.scalars().first() + + if not thread: + raise HTTPException(status_code=404, detail="Thread not found") + + await check_permission( + session, + user, + thread.search_space_id, + Permission.CHATS_READ.value, + "You don't have permission to read chats in this search space", + ) + + if thread.user_id != user.id: + raise HTTPException( + status_code=403, detail="You don't have access to this thread" + ) + + # Get messages + query = ( + select(NewChatMessage) + .filter(NewChatMessage.thread_id == thread_id) + .order_by(NewChatMessage.created_at) + .offset(skip) + .limit(limit) + ) + + result = await session.execute(query) + return result.scalars().all() + + except HTTPException: + raise + except OperationalError: + raise HTTPException( + status_code=503, detail="Database operation failed. Please try again later." + ) from None + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"An unexpected error occurred while fetching messages: {e!s}", + ) from None diff --git a/surfsense_backend/app/schemas/__init__.py b/surfsense_backend/app/schemas/__init__.py index a4308f6a2..92f9cdc78 100644 --- a/surfsense_backend/app/schemas/__init__.py +++ b/surfsense_backend/app/schemas/__init__.py @@ -21,6 +21,18 @@ from .documents import ( ) from .llm_config import LLMConfigBase, LLMConfigCreate, LLMConfigRead, LLMConfigUpdate from .logs import LogBase, LogCreate, LogFilter, LogRead, LogUpdate +from .new_chat import ( + NewChatMessageAppend, + NewChatMessageCreate, + NewChatMessageRead, + NewChatThreadCreate, + NewChatThreadRead, + NewChatThreadUpdate, + NewChatThreadWithMessages, + ThreadHistoryLoadResponse, + ThreadListItem, + ThreadListResponse, +) from .podcasts import ( PodcastBase, PodcastCreate, @@ -98,7 +110,15 @@ __all__ = [ "MembershipRead", "MembershipReadWithUser", "MembershipUpdate", + # New chat schemas (assistant-ui integration) + "NewChatMessageAppend", + "NewChatMessageCreate", + "NewChatMessageRead", "NewChatRequest", + "NewChatThreadCreate", + "NewChatThreadRead", + "NewChatThreadUpdate", + "NewChatThreadWithMessages", "PaginatedResponse", "PermissionInfo", "PermissionsListResponse", @@ -119,6 +139,9 @@ __all__ = [ "SearchSpaceRead", "SearchSpaceUpdate", "SearchSpaceWithStats", + "ThreadHistoryLoadResponse", + "ThreadListItem", + "ThreadListResponse", "TimestampModel", "UserCreate", "UserRead", diff --git a/surfsense_backend/app/schemas/new_chat.py b/surfsense_backend/app/schemas/new_chat.py new file mode 100644 index 000000000..e1cf4efb8 --- /dev/null +++ b/surfsense_backend/app/schemas/new_chat.py @@ -0,0 +1,129 @@ +""" +Pydantic schemas for the new chat feature with assistant-ui integration. + +These schemas follow the assistant-ui ThreadHistoryAdapter pattern: +- ThreadRecord: id, title, archived, createdAt, updatedAt +- MessageRecord: id, threadId, role, content, createdAt +""" + +from datetime import datetime +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field + +from app.db import NewChatMessageRole + +from .base import IDModel, TimestampModel + +# ============================================================================= +# Message Schemas +# ============================================================================= + + +class NewChatMessageBase(BaseModel): + """Base schema for new chat messages.""" + + role: NewChatMessageRole + content: Any # JSONB content - can be text, tool calls, etc. + + +class NewChatMessageCreate(NewChatMessageBase): + """Schema for creating a new message.""" + + thread_id: int + + +class NewChatMessageRead(NewChatMessageBase, IDModel, TimestampModel): + """Schema for reading a message.""" + + thread_id: int + model_config = ConfigDict(from_attributes=True) + + +class NewChatMessageAppend(BaseModel): + """ + Schema for appending a message via the history adapter. + This is the format assistant-ui sends when calling append(). + """ + + role: str # Accept string and validate in route handler + content: Any + + +# ============================================================================= +# Thread Schemas +# ============================================================================= + + +class NewChatThreadBase(BaseModel): + """Base schema for new chat threads.""" + + title: str = Field(default="New Chat", max_length=500) + archived: bool = False + + +class NewChatThreadCreate(NewChatThreadBase): + """Schema for creating a new thread.""" + + search_space_id: int + + +class NewChatThreadUpdate(BaseModel): + """Schema for updating a thread.""" + + title: str | None = None + archived: bool | None = None + + +class NewChatThreadRead(NewChatThreadBase, IDModel): + """ + Schema for reading a thread (matches assistant-ui ThreadRecord). + """ + + search_space_id: int + created_at: datetime + updated_at: datetime + + model_config = ConfigDict(from_attributes=True) + + +class NewChatThreadWithMessages(NewChatThreadRead): + """Schema for reading a thread with its messages.""" + + messages: list[NewChatMessageRead] = [] + + +# ============================================================================= +# History Adapter Response Schemas +# ============================================================================= + + +class ThreadHistoryLoadResponse(BaseModel): + """ + Response format for the ThreadHistoryAdapter.load() method. + Returns messages array for the current thread. + """ + + messages: list[NewChatMessageRead] + + +class ThreadListItem(BaseModel): + """ + Thread list item for sidebar display. + Matches assistant-ui ThreadListPrimitive expected format. + """ + + id: int + title: str + archived: bool + created_at: datetime = Field(alias="createdAt") + updated_at: datetime = Field(alias="updatedAt") + + model_config = ConfigDict(from_attributes=True, populate_by_name=True) + + +class ThreadListResponse(BaseModel): + """Response containing list of threads for the sidebar.""" + + threads: list[ThreadListItem] + archived_threads: list[ThreadListItem] diff --git a/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py b/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py index 1abfba193..aa28259ab 100644 --- a/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py @@ -7,14 +7,13 @@ import sys from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.pool import NullPool -from app.celery_app import celery_app -from app.config import config -from app.tasks.podcast_tasks import generate_chat_podcast - # Import for content-based podcast (new-chat) from app.agents.podcaster.graph import graph as podcaster_graph from app.agents.podcaster.state import State as PodcasterState +from app.celery_app import celery_app +from app.config import config from app.db import Podcast +from app.tasks.podcast_tasks import generate_chat_podcast logger = logging.getLogger(__name__) @@ -201,15 +200,16 @@ async def _generate_content_podcast( serializable_transcript = [] for entry in podcast_transcript: if hasattr(entry, "speaker_id"): - serializable_transcript.append({ - "speaker_id": entry.speaker_id, - "dialog": entry.dialog - }) + serializable_transcript.append( + {"speaker_id": entry.speaker_id, "dialog": entry.dialog} + ) else: - serializable_transcript.append({ - "speaker_id": entry.get("speaker_id", 0), - "dialog": entry.get("dialog", "") - }) + serializable_transcript.append( + { + "speaker_id": entry.get("speaker_id", 0), + "dialog": entry.get("dialog", ""), + } + ) # Save podcast to database podcast = Podcast( diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 159ca9c9a..5ddd097e6 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -9,12 +9,15 @@ import json from collections.abc import AsyncGenerator from uuid import UUID -from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.messages import HumanMessage from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent from app.agents.new_chat.checkpointer import get_checkpointer -from app.agents.new_chat.llm_config import create_chat_litellm_from_config, load_llm_config_from_yaml +from app.agents.new_chat.llm_config import ( + create_chat_litellm_from_config, + load_llm_config_from_yaml, +) from app.schemas.chats import ChatMessage from app.services.connector_service import ConnectorService from app.services.new_streaming_service import VercelStreamingService @@ -92,7 +95,7 @@ async def stream_new_chat( # Build input with message history from frontend langchain_messages = [] - + # if messages: # # Convert frontend messages to LangChain format # for msg in messages: @@ -101,9 +104,9 @@ async def stream_new_chat( # elif msg.role == "assistant": # langchain_messages.append(AIMessage(content=msg.content)) # else: - # Fallback: just use the current user query + # Fallback: just use the current user query langchain_messages.append(HumanMessage(content=user_query)) - + input_state = { # Lets not pass this message atm because we are using the checkpointer to manage the conversation history # We will use this to simulate group chat functionality in the future @@ -219,7 +222,9 @@ async def stream_new_chat( elif isinstance(raw_output, dict): tool_output = raw_output else: - tool_output = {"result": str(raw_output) if raw_output else "completed"} + tool_output = { + "result": str(raw_output) if raw_output else "completed" + } tool_call_id = f"call_{run_id[:32]}" if run_id else "call_unknown" @@ -228,16 +233,25 @@ async def stream_new_chat( # Stream the full podcast result so frontend can render the audio player yield streaming_service.format_tool_output_available( tool_call_id, - tool_output if isinstance(tool_output, dict) else {"result": tool_output}, + tool_output + if isinstance(tool_output, dict) + else {"result": tool_output}, ) # Send appropriate terminal message based on status - if isinstance(tool_output, dict) and tool_output.get("status") == "success": + if ( + isinstance(tool_output, dict) + and tool_output.get("status") == "success" + ): yield streaming_service.format_terminal_info( f"Podcast generated successfully: {tool_output.get('title', 'Podcast')}", "success", ) else: - error_msg = tool_output.get("error", "Unknown error") if isinstance(tool_output, dict) else "Unknown error" + error_msg = ( + tool_output.get("error", "Unknown error") + if isinstance(tool_output, dict) + else "Unknown error" + ) yield streaming_service.format_terminal_info( f"Podcast generation failed: {error_msg}", "error", diff --git a/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx b/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx index c0f5bf0b0..5583af64d 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx @@ -47,6 +47,9 @@ export function DashboardClientLayout({ // Check if we're on the researcher page const isResearcherPage = pathname?.includes("/researcher"); + // Check if we're on the new-chat page (uses separate thread persistence) + const isNewChatPage = pathname?.includes("/new-chat"); + // Show indicator when chat becomes active and panel is closed useEffect(() => { if (activeChatId && !isChatPannelOpen) { @@ -151,6 +154,12 @@ export function DashboardClientLayout({ }, [search_space_id]); useEffect(() => { + // Skip setting activeChatIdAtom on new-chat page (uses separate thread persistence) + if (isNewChatPage) { + setActiveChatIdState(null); + return; + } + const activeChatId = typeof chat_id === "string" ? chat_id @@ -159,7 +168,7 @@ export function DashboardClientLayout({ : ""; if (!activeChatId) return; setActiveChatIdState(activeChatId); - }, [chat_id, search_space_id]); + }, [chat_id, search_space_id, isNewChatPage]); // Show loading screen while checking onboarding status (only on first load) if (!hasCheckedOnboarding && (loading || accessLoading) && !isOnboardingPage) { diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index cd28a26fa..ab880813e 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -1,23 +1,73 @@ "use client"; -import { AssistantRuntimeProvider, useLocalRuntime } from "@assistant-ui/react"; -import { useParams } from "next/navigation"; -import { useMemo } from "react"; +import { + AssistantRuntimeProvider, + useExternalStoreRuntime, + type ThreadMessageLike, +} from "@assistant-ui/react"; +import { useParams, useRouter } from "next/navigation"; +import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { Thread } from "@/components/assistant-ui/thread"; import { GeneratePodcastToolUI } from "@/components/tool-ui/generate-podcast"; -import { createNewChatAdapter } from "@/lib/chat/new-chat-transport"; +import { + createThread, + getThreadMessages, + appendMessage, + type MessageRecord, +} from "@/lib/chat/thread-persistence"; +import { getBearerToken } from "@/lib/auth-utils"; +import { toast } from "sonner"; +import { + isPodcastGenerating, + looksLikePodcastRequest, + setActivePodcastTaskId, +} from "@/lib/chat/podcast-state"; + +/** + * Convert backend message to assistant-ui ThreadMessageLike format + */ +function convertToThreadMessage(msg: MessageRecord): ThreadMessageLike { + let content: ThreadMessageLike["content"]; + + if (typeof msg.content === "string") { + content = [{ type: "text", text: msg.content }]; + } else if (Array.isArray(msg.content)) { + content = msg.content as ThreadMessageLike["content"]; + } else { + content = [{ type: "text", text: String(msg.content) }]; + } + + return { + id: `msg-${msg.id}`, + role: msg.role, + content, + createdAt: new Date(msg.created_at), + }; +} + +/** + * Tools that should render custom UI in the chat. + */ +const TOOLS_WITH_UI = new Set(["generate_podcast"]); export default function NewChatPage() { const params = useParams(); + const router = useRouter(); + const [isInitializing, setIsInitializing] = useState(true); + const [threadId, setThreadId] = useState(null); + const [messages, setMessages] = useState([]); + const [isRunning, setIsRunning] = useState(false); + const abortControllerRef = useRef(null); - // Extract search_space_id and chat_id from URL params + // Extract search_space_id from URL params const searchSpaceId = useMemo(() => { const id = params.search_space_id; const parsed = typeof id === "string" ? Number.parseInt(id, 10) : 0; return Number.isNaN(parsed) ? 0 : parsed; }, [params.search_space_id]); - const chatId = useMemo(() => { + // Extract chat_id from URL params + const urlChatId = useMemo(() => { const id = params.chat_id; let parsed = 0; if (Array.isArray(id) && id.length > 0) { @@ -28,18 +78,368 @@ export default function NewChatPage() { return Number.isNaN(parsed) ? 0 : parsed; }, [params.chat_id]); - // Create the adapter with the extracted params - const adapter = useMemo( - () => createNewChatAdapter({ searchSpaceId, chatId }), - [searchSpaceId, chatId] + // Initialize thread and load messages + const initializeThread = useCallback(async () => { + setIsInitializing(true); + + try { + if (urlChatId > 0) { + // Thread exists - load messages + setThreadId(urlChatId); + const response = await getThreadMessages(urlChatId); + if (response.messages && response.messages.length > 0) { + const loadedMessages = response.messages.map(convertToThreadMessage); + setMessages(loadedMessages); + } + } else { + // Create new thread + const newThread = await createThread(searchSpaceId, "New Chat"); + setThreadId(newThread.id); + router.replace(`/dashboard/${searchSpaceId}/new-chat/${newThread.id}`); + } + } catch (error) { + console.error("[NewChatPage] Failed to initialize thread:", error); + setThreadId(Date.now()); + } finally { + setIsInitializing(false); + } + }, [urlChatId, searchSpaceId, router]); + + // Initialize on mount + useEffect(() => { + initializeThread(); + }, [initializeThread]); + + // Cancel ongoing request + const cancelRun = useCallback(async () => { + if (abortControllerRef.current) { + abortControllerRef.current.abort(); + abortControllerRef.current = null; + } + setIsRunning(false); + }, []); + + // Handle new message from user + const onNew = useCallback( + async (message: ThreadMessageLike) => { + if (!threadId) return; + + // Extract user query text + let userQuery = ""; + for (const part of message.content) { + if (typeof part === "object" && part.type === "text" && "text" in part) { + userQuery += part.text; + } + } + + if (!userQuery.trim()) return; + + // Check if podcast is already generating + if (isPodcastGenerating() && looksLikePodcastRequest(userQuery)) { + toast.warning("A podcast is already being generated."); + return; + } + + const token = getBearerToken(); + if (!token) { + toast.error("Not authenticated. Please log in again."); + return; + } + + // Add user message to state + const userMsgId = `msg-user-${Date.now()}`; + const userMessage: ThreadMessageLike = { + id: userMsgId, + role: "user", + content: message.content, + createdAt: new Date(), + }; + setMessages((prev) => [...prev, userMessage]); + + // Persist user message (don't await, fire and forget) + appendMessage(threadId, { + role: "user", + content: message.content, + }).catch((err) => console.error("Failed to persist user message:", err)); + + // Start streaming response + setIsRunning(true); + const controller = new AbortController(); + abortControllerRef.current = controller; + + // Prepare assistant message + const assistantMsgId = `msg-assistant-${Date.now()}`; + let accumulatedText = ""; + const toolCalls = new Map< + string, + { + toolCallId: string; + toolName: string; + args: Record; + result?: unknown; + } + >(); + + // Helper to build content + const buildContent = (): ThreadMessageLike["content"] => { + const parts: Array< + | { type: "text"; text: string } + | { + type: "tool-call"; + toolCallId: string; + toolName: string; + args: Record; + result?: unknown; + } + > = []; + if (accumulatedText) { + parts.push({ type: "text", text: accumulatedText }); + } + for (const toolCall of toolCalls.values()) { + if (TOOLS_WITH_UI.has(toolCall.toolName)) { + parts.push({ + type: "tool-call", + toolCallId: toolCall.toolCallId, + toolName: toolCall.toolName, + args: toolCall.args, + result: toolCall.result, + }); + } + } + return parts.length > 0 + ? (parts as ThreadMessageLike["content"]) + : [{ type: "text", text: "" }]; + }; + + // Add placeholder assistant message + setMessages((prev) => [ + ...prev, + { + id: assistantMsgId, + role: "assistant", + content: [{ type: "text", text: "" }], + createdAt: new Date(), + }, + ]); + + try { + const backendUrl = + process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; + + // Build message history for context + const messageHistory = messages + .filter((m) => m.role === "user" || m.role === "assistant") + .map((m) => { + let text = ""; + for (const part of m.content) { + if ( + typeof part === "object" && + part.type === "text" && + "text" in part + ) { + text += part.text; + } + } + return { role: m.role, content: text }; + }) + .filter((m) => m.content.length > 0); + + const response = await fetch(`${backendUrl}/api/v1/new_chat`, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${token}`, + }, + body: JSON.stringify({ + chat_id: threadId, + user_query: userQuery.trim(), + search_space_id: searchSpaceId, + messages: messageHistory, + }), + signal: controller.signal, + }); + + if (!response.ok) { + throw new Error(`Backend error: ${response.status}`); + } + + if (!response.body) { + throw new Error("No response body"); + } + + // Parse SSE stream + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let buffer = ""; + + try { + while (true) { + const { done, value } = await reader.read(); + if (done) break; + + buffer += decoder.decode(value, { stream: true }); + const events = buffer.split(/\r?\n\r?\n/); + buffer = events.pop() || ""; + + for (const event of events) { + const lines = event.split(/\r?\n/); + for (const line of lines) { + if (!line.startsWith("data: ")) continue; + const data = line.slice(6).trim(); + if (!data || data === "[DONE]") continue; + + try { + const parsed = JSON.parse(data); + + switch (parsed.type) { + case "text-delta": + accumulatedText += parsed.delta; + setMessages((prev) => + prev.map((m) => + m.id === assistantMsgId + ? { ...m, content: buildContent() } + : m + ) + ); + break; + + case "tool-input-start": + toolCalls.set(parsed.toolCallId, { + toolCallId: parsed.toolCallId, + toolName: parsed.toolName, + args: {}, + }); + setMessages((prev) => + prev.map((m) => + m.id === assistantMsgId + ? { ...m, content: buildContent() } + : m + ) + ); + break; + + case "tool-input-available": { + const tc = toolCalls.get(parsed.toolCallId); + if (tc) tc.args = parsed.input || {}; + else + toolCalls.set(parsed.toolCallId, { + toolCallId: parsed.toolCallId, + toolName: parsed.toolName, + args: parsed.input || {}, + }); + setMessages((prev) => + prev.map((m) => + m.id === assistantMsgId + ? { ...m, content: buildContent() } + : m + ) + ); + break; + } + + case "tool-output-available": { + const tc = toolCalls.get(parsed.toolCallId); + if (tc) { + tc.result = parsed.output; + if ( + tc.toolName === "generate_podcast" && + parsed.output?.status === "processing" && + parsed.output?.task_id + ) { + setActivePodcastTaskId(parsed.output.task_id); + } + } + setMessages((prev) => + prev.map((m) => + m.id === assistantMsgId + ? { ...m, content: buildContent() } + : m + ) + ); + break; + } + + case "error": + throw new Error(parsed.errorText || "Server error"); + } + } catch (e) { + if (e instanceof SyntaxError) continue; + throw e; + } + } + } + } + } finally { + reader.releaseLock(); + } + + // Persist assistant message + const finalContent = buildContent(); + if (accumulatedText || toolCalls.size > 0) { + appendMessage(threadId, { + role: "assistant", + content: finalContent, + }).catch((err) => + console.error("Failed to persist assistant message:", err) + ); + } + } catch (error) { + if (error instanceof Error && error.name === "AbortError") { + // Request was cancelled + return; + } + console.error("[NewChatPage] Chat error:", error); + toast.error("Failed to get response. Please try again."); + // Update assistant message with error + setMessages((prev) => + prev.map((m) => + m.id === assistantMsgId + ? { + ...m, + content: [ + { + type: "text", + text: "Sorry, there was an error. Please try again.", + }, + ], + } + : m + ) + ); + } finally { + setIsRunning(false); + abortControllerRef.current = null; + } + }, + [threadId, searchSpaceId, messages] ); - // Use LocalRuntime with our custom adapter - const runtime = useLocalRuntime(adapter); + // Convert message (pass through since already in correct format) + const convertMessage = useCallback( + (message: ThreadMessageLike): ThreadMessageLike => message, + [] + ); + + // Create external store runtime + const runtime = useExternalStoreRuntime({ + messages, + isRunning, + onNew, + convertMessage, + onCancel: cancelRun, + }); + + // Show loading state + if (isInitializing) { + return ( +
+
Loading chat...
+
+ ); + } return ( - {/* Register tool UI components */}
diff --git a/surfsense_web/components/assistant-ui/thread-list.tsx b/surfsense_web/components/assistant-ui/thread-list.tsx new file mode 100644 index 000000000..de479e6b8 --- /dev/null +++ b/surfsense_web/components/assistant-ui/thread-list.tsx @@ -0,0 +1,286 @@ +"use client"; + +import { useCallback, useEffect, useState } from "react"; +import { useRouter } from "next/navigation"; +import { ArchiveIcon, MessageSquareIcon, PlusIcon, TrashIcon, MoreVerticalIcon, RotateCcwIcon } from "lucide-react"; +import { cn } from "@/lib/utils"; +import { Button } from "@/components/ui/button"; +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuSeparator, + DropdownMenuTrigger, +} from "@/components/ui/dropdown-menu"; +import { + type ThreadListItem, + createThreadListManager, + type ThreadListState, +} from "@/lib/chat/thread-persistence"; + +interface ThreadListProps { + searchSpaceId: number; + currentThreadId?: number; + className?: string; +} + +export function ThreadList({ searchSpaceId, currentThreadId, className }: ThreadListProps) { + const router = useRouter(); + const [state, setState] = useState({ + threads: [], + archivedThreads: [], + isLoading: true, + error: null, + }); + const [showArchived, setShowArchived] = useState(false); + + // Create the thread list manager + const manager = useCallback( + () => + createThreadListManager({ + searchSpaceId, + currentThreadId: currentThreadId ?? null, + onThreadSwitch: (threadId) => { + router.push(`/dashboard/${searchSpaceId}/new-chat/${threadId}`); + }, + onNewThread: (threadId) => { + router.push(`/dashboard/${searchSpaceId}/new-chat/${threadId}`); + }, + }), + [searchSpaceId, currentThreadId, router] + ); + + // Load threads on mount and when searchSpaceId changes + const loadThreads = useCallback(async () => { + setState((prev) => ({ ...prev, isLoading: true })); + const newState = await manager().loadThreads(); + setState(newState); + }, [manager]); + + useEffect(() => { + loadThreads(); + }, [loadThreads]); + + // Handle new thread creation + const handleNewThread = async () => { + await manager().createNewThread(); + await loadThreads(); + }; + + // Handle thread actions + const handleArchive = async (threadId: number) => { + const success = await manager().archiveThread(threadId); + if (success) await loadThreads(); + }; + + const handleUnarchive = async (threadId: number) => { + const success = await manager().unarchiveThread(threadId); + if (success) await loadThreads(); + }; + + const handleDelete = async (threadId: number) => { + const success = await manager().deleteThread(threadId); + if (success) { + await loadThreads(); + // If we deleted the current thread, redirect to new chat + if (threadId === currentThreadId) { + router.push(`/dashboard/${searchSpaceId}/new-chat`); + } + } + }; + + const handleSwitchToThread = (threadId: number) => { + manager().switchToThread(threadId); + }; + + const displayedThreads = showArchived ? state.archivedThreads : state.threads; + + if (state.isLoading) { + return ( +
+
+ Loading threads... +
+
+ ); + } + + if (state.error) { + return ( +
+
+ {state.error} + +
+
+ ); + } + + return ( +
+ {/* Header with New Chat button */} +
+

Conversations

+ +
+ + {/* Tab toggle for active/archived */} +
+ + +
+ + {/* Thread list */} +
+ {displayedThreads.length === 0 ? ( +
+ +

+ {showArchived ? "No archived conversations" : "No conversations yet"} +

+ {!showArchived && ( + + )} +
+ ) : ( +
+ {displayedThreads.map((thread) => ( + handleSwitchToThread(thread.id)} + onArchive={() => handleArchive(thread.id)} + onUnarchive={() => handleUnarchive(thread.id)} + onDelete={() => handleDelete(thread.id)} + /> + ))} +
+ )} +
+
+ ); +} + +interface ThreadListItemComponentProps { + thread: ThreadListItem; + isActive: boolean; + isArchived: boolean; + onClick: () => void; + onArchive: () => void; + onUnarchive: () => void; + onDelete: () => void; +} + +function ThreadListItemComponent({ + thread, + isActive, + isArchived, + onClick, + onArchive, + onUnarchive, + onDelete, +}: ThreadListItemComponentProps) { + return ( +
{ + if (e.key === "Enter" || e.key === " ") onClick(); + }} + role="button" + tabIndex={0} + > + +
+

{thread.title || "New Chat"}

+

+ {formatRelativeTime(new Date(thread.updatedAt))} +

+
+ + + + + + {isArchived ? ( + + + Unarchive + + ) : ( + + + Archive + + )} + + + + Delete + + + +
+ ); +} + +/** + * Format a date as relative time (e.g., "2 hours ago", "Yesterday") + */ +function formatRelativeTime(date: Date): string { + const now = new Date(); + const diffMs = now.getTime() - date.getTime(); + const diffSecs = Math.floor(diffMs / 1000); + const diffMins = Math.floor(diffSecs / 60); + const diffHours = Math.floor(diffMins / 60); + const diffDays = Math.floor(diffHours / 24); + + if (diffSecs < 60) return "Just now"; + if (diffMins < 60) return `${diffMins} min${diffMins === 1 ? "" : "s"} ago`; + if (diffHours < 24) return `${diffHours} hour${diffHours === 1 ? "" : "s"} ago`; + if (diffDays === 1) return "Yesterday"; + if (diffDays < 7) return `${diffDays} days ago`; + + return date.toLocaleDateString(); +} diff --git a/surfsense_web/lib/chat/new-chat-transport.ts b/surfsense_web/lib/chat/new-chat-transport.ts deleted file mode 100644 index 0bde3066d..000000000 --- a/surfsense_web/lib/chat/new-chat-transport.ts +++ /dev/null @@ -1,345 +0,0 @@ -/** - * Custom ChatModelAdapter for the new-chat feature using LocalRuntime. - * Connects directly to the FastAPI backend using the Vercel AI SDK Data Stream Protocol. - */ - -import type { ChatModelAdapter, ChatModelRunOptions } from "@assistant-ui/react"; -import { toast } from "sonner"; -import { getBearerToken } from "@/lib/auth-utils"; -import { - isPodcastGenerating, - looksLikePodcastRequest, - setActivePodcastTaskId, -} from "@/lib/chat/podcast-state"; - -interface NewChatAdapterConfig { - searchSpaceId: number; - chatId: number; -} - -interface ChatMessageForBackend { - role: "user" | "assistant"; - content: string; -} - -/** - * Converts assistant-ui messages to a simple format for the backend - */ -function convertMessagesToBackendFormat( - messages: ChatModelRunOptions["messages"] -): ChatMessageForBackend[] { - return messages - .filter((m) => m.role === "user" || m.role === "assistant") - .map((m) => { - // Extract text content from the message parts - let content = ""; - for (const part of m.content) { - if (part.type === "text") { - content += part.text; - } - } - return { - role: m.role as "user" | "assistant", - content: content.trim(), - }; - }) - .filter((m) => m.content.length > 0); // Filter out empty messages -} - -/** - * Represents an in-progress or completed tool call - */ -interface ToolCallState { - toolCallId: string; - toolName: string; - args: Record; - result?: unknown; -} - -/** - * Tools that should render custom UI in the chat. - * Other tools (like search_knowledge_base) will be hidden from the UI. - */ -const TOOLS_WITH_UI = new Set(["generate_podcast"]); - -/** - * Creates a ChatModelAdapter that connects to the FastAPI new_chat endpoint. - * - * The backend expects: - * - POST /api/v1/new_chat - * - Body: { chat_id: number, user_query: string, search_space_id: number, messages: [...] } - * - Returns: SSE stream with Vercel AI SDK Data Stream Protocol - */ -export function createNewChatAdapter(config: NewChatAdapterConfig): ChatModelAdapter { - const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; - - return { - async *run({ messages, abortSignal }: ChatModelRunOptions) { - // Get the last user message - const lastUserMessage = messages.filter((m) => m.role === "user").pop(); - - if (!lastUserMessage) { - throw new Error("No user message found"); - } - - // Extract text content from the last user message - let userQuery = ""; - for (const part of lastUserMessage.content) { - if (part.type === "text") { - userQuery += part.text; - } - } - - if (!userQuery.trim()) { - throw new Error("User query cannot be empty"); - } - - // Check if user is requesting a podcast while one is already generating - if (isPodcastGenerating() && looksLikePodcastRequest(userQuery)) { - toast.warning("A podcast is already being generated. Please wait for it to complete."); - // Return a message telling the user to wait - yield { - content: [ - { - type: "text", - text: "A podcast is already being generated. Please wait for it to complete before requesting another one.", - }, - ], - }; - return; - } - - const token = getBearerToken(); - if (!token) { - throw new Error("Not authenticated. Please log in again."); - } - - // Convert all messages to backend format for chat history - const messageHistory = convertMessagesToBackendFormat(messages); - - const response = await fetch(`${backendUrl}/api/v1/new_chat`, { - method: "POST", - headers: { - "Content-Type": "application/json", - Authorization: `Bearer ${token}`, - }, - body: JSON.stringify({ - chat_id: config.chatId, - user_query: userQuery.trim(), - search_space_id: config.searchSpaceId, - messages: messageHistory, - }), - signal: abortSignal, - }); - - if (!response.ok) { - const errorText = await response.text().catch(() => "Unknown error"); - throw new Error(`Backend error (${response.status}): ${errorText}`); - } - - if (!response.body) { - throw new Error("No response body"); - } - - // Parse the SSE stream (Vercel AI SDK Data Stream Protocol) - const reader = response.body.getReader(); - const decoder = new TextDecoder(); - let buffer = ""; - let accumulatedText = ""; - - // Track tool calls by their ID - const toolCalls = new Map(); - - /** - * Build the content array with text and tool calls. - * Only includes tools that have custom UI (defined in TOOLS_WITH_UI). - */ - function buildContent() { - const content: Array< - | { type: "text"; text: string } - | { type: "tool-call"; toolCallId: string; toolName: string; args: Record; result?: unknown } - > = []; - - // Add text content if any - if (accumulatedText) { - content.push({ type: "text" as const, text: accumulatedText }); - } - - // Only add tool calls that have custom UI registered - // Other tools (like search_knowledge_base) are hidden from the UI - for (const toolCall of toolCalls.values()) { - if (TOOLS_WITH_UI.has(toolCall.toolName)) { - content.push({ - type: "tool-call" as const, - toolCallId: toolCall.toolCallId, - toolName: toolCall.toolName, - args: toolCall.args, - result: toolCall.result, - }); - } - } - - return content; - } - - try { - while (true) { - const { done, value } = await reader.read(); - if (done) { - break; - } - - const chunk = decoder.decode(value, { stream: true }); - buffer += chunk; - - // Split on double newlines (handle both \n\n and \r\n\r\n) - const events = buffer.split(/\r?\n\r?\n/); - buffer = events.pop() || ""; - - for (const event of events) { - // Each event can have multiple lines, find the data line - const lines = event.split(/\r?\n/); - for (const line of lines) { - if (!line.startsWith("data: ")) continue; - - const data = line.slice(6).trim(); // Remove "data: " prefix - - // Handle [DONE] marker - if (data === "[DONE]") { - continue; - } - - if (!data) continue; - - try { - const parsed = JSON.parse(data); - - // Handle different message types from the Data Stream Protocol - switch (parsed.type) { - case "text-delta": - accumulatedText += parsed.delta; - yield { content: buildContent() }; - break; - - case "tool-input-start": { - // Tool call is starting - create a new tool call entry - const { toolCallId, toolName } = parsed; - toolCalls.set(toolCallId, { - toolCallId, - toolName, - args: {}, - }); - // Yield to show tool is starting (running state) - yield { content: buildContent() }; - break; - } - - case "tool-input-available": { - // Tool input is complete - update the args - const { toolCallId, toolName, input } = parsed; - const existing = toolCalls.get(toolCallId); - if (existing) { - existing.args = input || {}; - } else { - // Create new entry if we missed tool-input-start - toolCalls.set(toolCallId, { - toolCallId, - toolName, - args: input || {}, - }); - } - yield { content: buildContent() }; - break; - } - - case "tool-output-available": { - // Tool execution is complete - add the result - const { toolCallId, output } = parsed; - const existing = toolCalls.get(toolCallId); - if (existing) { - existing.result = output; - - // If this is a podcast tool with status="processing", set the state immediately - // This ensures subsequent podcast requests are intercepted - if ( - existing.toolName === "generate_podcast" && - output && - typeof output === "object" && - "status" in output && - output.status === "processing" && - "task_id" in output && - typeof output.task_id === "string" - ) { - setActivePodcastTaskId(output.task_id); - } - } - yield { content: buildContent() }; - break; - } - - case "error": - throw new Error(parsed.errorText || "Unknown error from server"); - - // Other types like text-start, text-end, start-step, finish-step, etc. - // are handled implicitly - default: - break; - } - } catch (e) { - // Skip non-JSON lines - if (e instanceof SyntaxError) { - continue; - } - throw e; - } - } - } - } - - // Handle any remaining buffer - if (buffer.trim()) { - const lines = buffer.split(/\r?\n/); - for (const line of lines) { - if (line.startsWith("data: ")) { - const data = line.slice(6).trim(); - if (data && data !== "[DONE]") { - try { - const parsed = JSON.parse(data); - if (parsed.type === "text-delta") { - accumulatedText += parsed.delta; - yield { content: buildContent() }; - } else if (parsed.type === "tool-output-available") { - const { toolCallId, output } = parsed; - const existing = toolCalls.get(toolCallId); - if (existing) { - existing.result = output; - - // Set podcast state if processing - if ( - existing.toolName === "generate_podcast" && - output && - typeof output === "object" && - "status" in output && - output.status === "processing" && - "task_id" in output && - typeof output.task_id === "string" - ) { - setActivePodcastTaskId(output.task_id); - } - } - yield { content: buildContent() }; - } - } catch { - // Ignore parse errors - } - } - } - } - } - } finally { - reader.releaseLock(); - } - }, - }; -} - diff --git a/surfsense_web/lib/chat/thread-persistence.ts b/surfsense_web/lib/chat/thread-persistence.ts new file mode 100644 index 000000000..f25c47c87 --- /dev/null +++ b/surfsense_web/lib/chat/thread-persistence.ts @@ -0,0 +1,223 @@ +/** + * Thread persistence utilities for the new chat feature. + * Provides API functions and thread list management. + */ + +import { baseApiService } from "@/lib/apis/base-api.service"; + +// ============================================================================= +// Types matching backend schemas +// ============================================================================= + +export interface ThreadRecord { + id: number; + title: string; + archived: boolean; + search_space_id: number; + created_at: string; + updated_at: string; +} + +export interface MessageRecord { + id: number; + thread_id: number; + role: "user" | "assistant" | "system"; + content: unknown; + created_at: string; +} + +export interface ThreadListResponse { + threads: ThreadListItem[]; + archived_threads: ThreadListItem[]; +} + +export interface ThreadListItem { + id: number; + title: string; + archived: boolean; + createdAt: string; + updatedAt: string; +} + +export interface ThreadHistoryLoadResponse { + messages: MessageRecord[]; +} + +// ============================================================================= +// API Service Functions +// ============================================================================= + +/** + * Fetch list of threads for a search space + */ +export async function fetchThreads( + searchSpaceId: number +): Promise { + return baseApiService.get( + `/api/v1/threads?search_space_id=${searchSpaceId}` + ); +} + +/** + * Create a new thread + */ +export async function createThread( + searchSpaceId: number, + title = "New Chat" +): Promise { + return baseApiService.post("/api/v1/threads", undefined, { + body: { + title, + archived: false, + search_space_id: searchSpaceId, + }, + }); +} + +/** + * Get thread messages + */ +export async function getThreadMessages( + threadId: number +): Promise { + return baseApiService.get( + `/api/v1/threads/${threadId}` + ); +} + +/** + * Append a message to a thread + */ +export async function appendMessage( + threadId: number, + message: { role: "user" | "assistant" | "system"; content: unknown } +): Promise { + return baseApiService.post( + `/api/v1/threads/${threadId}/messages`, + undefined, + { body: message } + ); +} + +/** + * Update thread (rename, archive) + */ +export async function updateThread( + threadId: number, + updates: { title?: string; archived?: boolean } +): Promise { + return baseApiService.put( + `/api/v1/threads/${threadId}`, + undefined, + { body: updates } + ); +} + +/** + * Delete a thread + */ +export async function deleteThread(threadId: number): Promise { + await baseApiService.delete(`/api/v1/threads/${threadId}`); +} + +// ============================================================================= +// Thread List Manager (for thread list sidebar) +// ============================================================================= + +export interface ThreadListAdapterConfig { + searchSpaceId: number; + currentThreadId: number | null; + onThreadSwitch: (threadId: number) => void; + onNewThread: (threadId: number) => void; +} + +export interface ThreadListState { + threads: ThreadListItem[]; + archivedThreads: ThreadListItem[]; + isLoading: boolean; + error: string | null; +} + +/** + * Creates a thread list management object. + * This provides methods to manage the thread list for the sidebar. + */ +export function createThreadListManager(config: ThreadListAdapterConfig) { + return { + async loadThreads(): Promise { + try { + const response = await fetchThreads(config.searchSpaceId); + return { + threads: response.threads, + archivedThreads: response.archived_threads, + isLoading: false, + error: null, + }; + } catch (error) { + console.error("[ThreadListManager] Failed to load threads:", error); + return { + threads: [], + archivedThreads: [], + isLoading: false, + error: + error instanceof Error ? error.message : "Failed to load threads", + }; + } + }, + + async createNewThread(title = "New Chat"): Promise { + try { + const thread = await createThread(config.searchSpaceId, title); + config.onNewThread(thread.id); + return thread.id; + } catch (error) { + console.error("[ThreadListManager] Failed to create thread:", error); + return null; + } + }, + + switchToThread(threadId: number) { + config.onThreadSwitch(threadId); + }, + + async renameThread(threadId: number, newTitle: string): Promise { + try { + await updateThread(threadId, { title: newTitle }); + return true; + } catch (error) { + console.error("[ThreadListManager] Failed to rename thread:", error); + return false; + } + }, + + async archiveThread(threadId: number): Promise { + try { + await updateThread(threadId, { archived: true }); + return true; + } catch (error) { + console.error("[ThreadListManager] Failed to archive thread:", error); + return false; + } + }, + + async unarchiveThread(threadId: number): Promise { + try { + await updateThread(threadId, { archived: false }); + return true; + } catch (error) { + console.error("[ThreadListManager] Failed to unarchive thread:", error); + return false; + } + }, + + async deleteThread(threadId: number): Promise { + try { + await deleteThread(threadId); + return true; + } catch (error) { + console.error("[ThreadListManager] Failed to delete thread:", error); + return false; + } + }, + }; +}