mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-09 07:42:39 +02:00
feat: implement new chat feature with message persistence and UI integration
This commit is contained in:
parent
f115980d2b
commit
0c3574d049
16 changed files with 1814 additions and 397 deletions
|
|
@ -92,4 +92,3 @@ async def close_checkpointer() -> None:
|
|||
_checkpointer_context = None
|
||||
_checkpointer_initialized = False
|
||||
print("[Checkpointer] PostgreSQL connection closed")
|
||||
|
||||
|
|
|
|||
|
|
@ -126,7 +126,9 @@ def create_generate_podcast_tool(
|
|||
# Check if a podcast is already being generated for this search space
|
||||
active_task_id = get_active_podcast_task(search_space_id)
|
||||
if active_task_id:
|
||||
print(f"[generate_podcast] Blocked duplicate request. Active task: {active_task_id}")
|
||||
print(
|
||||
f"[generate_podcast] Blocked duplicate request. Active task: {active_task_id}"
|
||||
)
|
||||
return {
|
||||
"status": "already_generating",
|
||||
"task_id": active_task_id,
|
||||
|
|
|
|||
|
|
@ -5,7 +5,10 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware
|
||||
|
||||
from app.agents.new_chat.checkpointer import close_checkpointer, setup_checkpointer_tables
|
||||
from app.agents.new_chat.checkpointer import (
|
||||
close_checkpointer,
|
||||
setup_checkpointer_tables,
|
||||
)
|
||||
from app.config import config
|
||||
from app.db import User, create_db_and_tables, get_async_session
|
||||
from app.routes import router as crud_router
|
||||
|
|
|
|||
|
|
@ -332,6 +332,75 @@ class Chat(BaseModel, TimestampMixin):
|
|||
search_space = relationship("SearchSpace", back_populates="chats")
|
||||
|
||||
|
||||
class NewChatMessageRole(str, Enum):
|
||||
"""Role enum for new chat messages."""
|
||||
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
SYSTEM = "system"
|
||||
|
||||
|
||||
class NewChatThread(BaseModel, TimestampMixin):
|
||||
"""
|
||||
Thread model for the new chat feature using assistant-ui.
|
||||
Each thread represents a conversation with message history.
|
||||
LangGraph checkpointer uses thread_id for state persistence.
|
||||
"""
|
||||
|
||||
__tablename__ = "new_chat_threads"
|
||||
|
||||
title = Column(String(500), nullable=False, default="New Chat", index=True)
|
||||
archived = Column(Boolean, nullable=False, default=False)
|
||||
updated_at = Column(
|
||||
TIMESTAMP(timezone=True),
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(UTC),
|
||||
onupdate=lambda: datetime.now(UTC),
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Foreign keys
|
||||
search_space_id = Column(
|
||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
|
||||
# Relationships
|
||||
search_space = relationship("SearchSpace", back_populates="new_chat_threads")
|
||||
messages = relationship(
|
||||
"NewChatMessage",
|
||||
back_populates="thread",
|
||||
order_by="NewChatMessage.created_at",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
|
||||
class NewChatMessage(BaseModel, TimestampMixin):
|
||||
"""
|
||||
Message model for the new chat feature.
|
||||
Stores individual messages in assistant-ui format.
|
||||
"""
|
||||
|
||||
__tablename__ = "new_chat_messages"
|
||||
|
||||
role = Column(SQLAlchemyEnum(NewChatMessageRole), nullable=False)
|
||||
# Content stored as JSONB to support rich content (text, tool calls, etc.)
|
||||
content = Column(JSONB, nullable=False)
|
||||
|
||||
# Foreign key to thread
|
||||
thread_id = Column(
|
||||
Integer,
|
||||
ForeignKey("new_chat_threads.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Relationship
|
||||
thread = relationship("NewChatThread", back_populates="messages")
|
||||
|
||||
|
||||
class Document(BaseModel, TimestampMixin):
|
||||
__tablename__ = "documents"
|
||||
|
||||
|
|
@ -435,6 +504,12 @@ class SearchSpace(BaseModel, TimestampMixin):
|
|||
order_by="Chat.id",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
new_chat_threads = relationship(
|
||||
"NewChatThread",
|
||||
back_populates="search_space",
|
||||
order_by="NewChatThread.updated_at.desc()",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
logs = relationship(
|
||||
"Log",
|
||||
back_populates="search_space",
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from .google_gmail_add_connector_route import (
|
|||
from .llm_config_routes import router as llm_config_router
|
||||
from .logs_routes import router as logs_router
|
||||
from .luma_add_connector_route import router as luma_add_connector_router
|
||||
from .new_chat_routes import router as new_chat_router
|
||||
from .notes_routes import router as notes_router
|
||||
from .podcasts_routes import router as podcasts_router
|
||||
from .rbac_routes import router as rbac_router
|
||||
|
|
@ -30,6 +31,7 @@ router.include_router(documents_router)
|
|||
router.include_router(notes_router)
|
||||
router.include_router(podcasts_router)
|
||||
router.include_router(chats_router)
|
||||
router.include_router(new_chat_router) # New chat with assistant-ui persistence
|
||||
router.include_router(search_source_connectors_router)
|
||||
router.include_router(google_calendar_add_connector_router)
|
||||
router.include_router(google_gmail_add_connector_router)
|
||||
|
|
|
|||
597
surfsense_backend/app/routes/new_chat_routes.py
Normal file
597
surfsense_backend/app/routes/new_chat_routes.py
Normal file
|
|
@ -0,0 +1,597 @@
|
|||
"""
|
||||
Routes for the new chat feature with assistant-ui integration.
|
||||
|
||||
These endpoints support the ThreadHistoryAdapter pattern from assistant-ui:
|
||||
- GET /threads - List threads for sidebar (ThreadListPrimitive)
|
||||
- POST /threads - Create a new thread
|
||||
- GET /threads/{thread_id} - Get thread with messages (load)
|
||||
- PUT /threads/{thread_id} - Update thread (rename, archive)
|
||||
- DELETE /threads/{thread_id} - Delete thread
|
||||
- POST /threads/{thread_id}/messages - Append message
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.db import (
|
||||
NewChatMessage,
|
||||
NewChatMessageRole,
|
||||
NewChatThread,
|
||||
Permission,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.schemas.new_chat import (
|
||||
NewChatMessageAppend,
|
||||
NewChatMessageRead,
|
||||
NewChatThreadCreate,
|
||||
NewChatThreadRead,
|
||||
NewChatThreadUpdate,
|
||||
NewChatThreadWithMessages,
|
||||
ThreadHistoryLoadResponse,
|
||||
ThreadListItem,
|
||||
ThreadListResponse,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Thread Endpoints
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.get("/threads", response_model=ThreadListResponse)
|
||||
async def list_threads(
|
||||
search_space_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
List all threads for the current user in a search space.
|
||||
Returns threads and archived_threads for ThreadListPrimitive.
|
||||
|
||||
Requires CHATS_READ permission.
|
||||
"""
|
||||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
search_space_id,
|
||||
Permission.CHATS_READ.value,
|
||||
"You don't have permission to read chats in this search space",
|
||||
)
|
||||
|
||||
# Get all threads for this user in this search space
|
||||
query = (
|
||||
select(NewChatThread)
|
||||
.filter(
|
||||
NewChatThread.search_space_id == search_space_id,
|
||||
NewChatThread.user_id == user.id,
|
||||
)
|
||||
.order_by(NewChatThread.updated_at.desc())
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
all_threads = result.scalars().all()
|
||||
|
||||
# Separate active and archived threads
|
||||
threads = []
|
||||
archived_threads = []
|
||||
|
||||
for thread in all_threads:
|
||||
item = ThreadListItem(
|
||||
id=thread.id,
|
||||
title=thread.title,
|
||||
archived=thread.archived,
|
||||
createdAt=thread.created_at,
|
||||
updatedAt=thread.updated_at,
|
||||
)
|
||||
if thread.archived:
|
||||
archived_threads.append(item)
|
||||
else:
|
||||
threads.append(item)
|
||||
|
||||
return ThreadListResponse(threads=threads, archived_threads=archived_threads)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except OperationalError:
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Database operation failed. Please try again later."
|
||||
) from None
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"An unexpected error occurred while fetching threads: {e!s}",
|
||||
) from None
|
||||
|
||||
|
||||
@router.post("/threads", response_model=NewChatThreadRead)
|
||||
async def create_thread(
|
||||
thread: NewChatThreadCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Create a new chat thread.
|
||||
|
||||
Requires CHATS_CREATE permission.
|
||||
"""
|
||||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
thread.search_space_id,
|
||||
Permission.CHATS_CREATE.value,
|
||||
"You don't have permission to create chats in this search space",
|
||||
)
|
||||
|
||||
now = datetime.now(UTC)
|
||||
db_thread = NewChatThread(
|
||||
title=thread.title,
|
||||
archived=thread.archived,
|
||||
search_space_id=thread.search_space_id,
|
||||
user_id=user.id,
|
||||
updated_at=now,
|
||||
)
|
||||
session.add(db_thread)
|
||||
await session.commit()
|
||||
await session.refresh(db_thread)
|
||||
return db_thread
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except IntegrityError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Database constraint violation. Please check your input data.",
|
||||
) from None
|
||||
except OperationalError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Database operation failed. Please try again later."
|
||||
) from None
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"An unexpected error occurred while creating the thread: {e!s}",
|
||||
) from None
|
||||
|
||||
|
||||
@router.get("/threads/{thread_id}", response_model=ThreadHistoryLoadResponse)
|
||||
async def get_thread_messages(
|
||||
thread_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Get a thread with all its messages.
|
||||
This is used by ThreadHistoryAdapter.load() to restore conversation.
|
||||
|
||||
Requires CHATS_READ permission.
|
||||
"""
|
||||
try:
|
||||
# Get thread with messages
|
||||
result = await session.execute(
|
||||
select(NewChatThread)
|
||||
.options(selectinload(NewChatThread.messages))
|
||||
.filter(NewChatThread.id == thread_id)
|
||||
)
|
||||
thread = result.scalars().first()
|
||||
|
||||
if not thread:
|
||||
raise HTTPException(status_code=404, detail="Thread not found")
|
||||
|
||||
# Check permission and ownership
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
thread.search_space_id,
|
||||
Permission.CHATS_READ.value,
|
||||
"You don't have permission to read chats in this search space",
|
||||
)
|
||||
|
||||
# Ensure user owns this thread
|
||||
if thread.user_id != user.id:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You don't have access to this thread"
|
||||
)
|
||||
|
||||
# Return messages in the format expected by assistant-ui
|
||||
messages = [
|
||||
NewChatMessageRead(
|
||||
id=msg.id,
|
||||
thread_id=msg.thread_id,
|
||||
role=msg.role,
|
||||
content=msg.content,
|
||||
created_at=msg.created_at,
|
||||
)
|
||||
for msg in thread.messages
|
||||
]
|
||||
|
||||
return ThreadHistoryLoadResponse(messages=messages)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except OperationalError:
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Database operation failed. Please try again later."
|
||||
) from None
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"An unexpected error occurred while fetching the thread: {e!s}",
|
||||
) from None
|
||||
|
||||
|
||||
@router.get("/threads/{thread_id}/full", response_model=NewChatThreadWithMessages)
|
||||
async def get_thread_full(
|
||||
thread_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Get full thread details with all messages.
|
||||
|
||||
Requires CHATS_READ permission.
|
||||
"""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(NewChatThread)
|
||||
.options(selectinload(NewChatThread.messages))
|
||||
.filter(NewChatThread.id == thread_id)
|
||||
)
|
||||
thread = result.scalars().first()
|
||||
|
||||
if not thread:
|
||||
raise HTTPException(status_code=404, detail="Thread not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
thread.search_space_id,
|
||||
Permission.CHATS_READ.value,
|
||||
"You don't have permission to read chats in this search space",
|
||||
)
|
||||
|
||||
if thread.user_id != user.id:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You don't have access to this thread"
|
||||
)
|
||||
|
||||
return thread
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except OperationalError:
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Database operation failed. Please try again later."
|
||||
) from None
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"An unexpected error occurred while fetching the thread: {e!s}",
|
||||
) from None
|
||||
|
||||
|
||||
@router.put("/threads/{thread_id}", response_model=NewChatThreadRead)
|
||||
async def update_thread(
|
||||
thread_id: int,
|
||||
thread_update: NewChatThreadUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Update a thread (title, archived status).
|
||||
Used for renaming and archiving threads.
|
||||
|
||||
Requires CHATS_UPDATE permission.
|
||||
"""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(NewChatThread).filter(NewChatThread.id == thread_id)
|
||||
)
|
||||
db_thread = result.scalars().first()
|
||||
|
||||
if not db_thread:
|
||||
raise HTTPException(status_code=404, detail="Thread not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
db_thread.search_space_id,
|
||||
Permission.CHATS_UPDATE.value,
|
||||
"You don't have permission to update chats in this search space",
|
||||
)
|
||||
|
||||
if db_thread.user_id != user.id:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You don't have access to this thread"
|
||||
)
|
||||
|
||||
# Update fields
|
||||
update_data = thread_update.model_dump(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(db_thread, key, value)
|
||||
|
||||
db_thread.updated_at = datetime.now(UTC)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(db_thread)
|
||||
return db_thread
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except IntegrityError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Database constraint violation. Please check your input data.",
|
||||
) from None
|
||||
except OperationalError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Database operation failed. Please try again later."
|
||||
) from None
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"An unexpected error occurred while updating the thread: {e!s}",
|
||||
) from None
|
||||
|
||||
|
||||
@router.delete("/threads/{thread_id}", response_model=dict)
|
||||
async def delete_thread(
|
||||
thread_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Delete a thread and all its messages.
|
||||
|
||||
Requires CHATS_DELETE permission.
|
||||
"""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(NewChatThread).filter(NewChatThread.id == thread_id)
|
||||
)
|
||||
db_thread = result.scalars().first()
|
||||
|
||||
if not db_thread:
|
||||
raise HTTPException(status_code=404, detail="Thread not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
db_thread.search_space_id,
|
||||
Permission.CHATS_DELETE.value,
|
||||
"You don't have permission to delete chats in this search space",
|
||||
)
|
||||
|
||||
if db_thread.user_id != user.id:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You don't have access to this thread"
|
||||
)
|
||||
|
||||
await session.delete(db_thread)
|
||||
await session.commit()
|
||||
return {"message": "Thread deleted successfully"}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except IntegrityError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Cannot delete thread due to existing dependencies."
|
||||
) from None
|
||||
except OperationalError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Database operation failed. Please try again later."
|
||||
) from None
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"An unexpected error occurred while deleting the thread: {e!s}",
|
||||
) from None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Message Endpoints
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.post("/threads/{thread_id}/messages", response_model=NewChatMessageRead)
|
||||
async def append_message(
|
||||
thread_id: int,
|
||||
request: Request,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Append a message to a thread.
|
||||
This is used by ThreadHistoryAdapter.append() to persist messages.
|
||||
|
||||
Requires CHATS_UPDATE permission.
|
||||
"""
|
||||
try:
|
||||
# Parse raw body - extract only role and content, ignoring extra fields
|
||||
raw_body = await request.json()
|
||||
role = raw_body.get("role")
|
||||
content = raw_body.get("content")
|
||||
|
||||
if not role:
|
||||
raise HTTPException(status_code=400, detail="Missing required field: role")
|
||||
if content is None:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing required field: content"
|
||||
)
|
||||
|
||||
# Create message object manually
|
||||
message = NewChatMessageAppend(role=role, content=content)
|
||||
# Get thread
|
||||
result = await session.execute(
|
||||
select(NewChatThread).filter(NewChatThread.id == thread_id)
|
||||
)
|
||||
thread = result.scalars().first()
|
||||
|
||||
if not thread:
|
||||
raise HTTPException(status_code=404, detail="Thread not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
thread.search_space_id,
|
||||
Permission.CHATS_UPDATE.value,
|
||||
"You don't have permission to update chats in this search space",
|
||||
)
|
||||
|
||||
if thread.user_id != user.id:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You don't have access to this thread"
|
||||
)
|
||||
|
||||
# Convert string role to enum
|
||||
role_str = (
|
||||
message.role.lower() if isinstance(message.role, str) else message.role
|
||||
)
|
||||
try:
|
||||
message_role = NewChatMessageRole(role_str)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid role: {message.role}. Must be 'user', 'assistant', or 'system'.",
|
||||
) from None
|
||||
|
||||
# Create message
|
||||
db_message = NewChatMessage(
|
||||
thread_id=thread_id,
|
||||
role=message_role,
|
||||
content=message.content,
|
||||
)
|
||||
session.add(db_message)
|
||||
|
||||
# Update thread's updated_at timestamp
|
||||
thread.updated_at = datetime.now(UTC)
|
||||
|
||||
# Auto-generate title from first user message if title is still default
|
||||
if thread.title == "New Chat" and role_str == "user":
|
||||
# Extract text content for title
|
||||
content = message.content
|
||||
if isinstance(content, str):
|
||||
title_text = content
|
||||
elif isinstance(content, list):
|
||||
# Find first text content
|
||||
title_text = ""
|
||||
for part in content:
|
||||
if isinstance(part, dict) and part.get("type") == "text":
|
||||
title_text = part.get("text", "")
|
||||
break
|
||||
elif isinstance(part, str):
|
||||
title_text = part
|
||||
break
|
||||
else:
|
||||
title_text = str(content)
|
||||
|
||||
# Truncate title
|
||||
if title_text:
|
||||
thread.title = title_text[:100] + (
|
||||
"..." if len(title_text) > 100 else ""
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(db_message)
|
||||
return db_message
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except IntegrityError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Database constraint violation. Please check your input data.",
|
||||
) from None
|
||||
except OperationalError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Database operation failed. Please try again later."
|
||||
) from None
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"An unexpected error occurred while appending the message: {e!s}",
|
||||
) from None
|
||||
|
||||
|
||||
@router.get("/threads/{thread_id}/messages", response_model=list[NewChatMessageRead])
|
||||
async def list_messages(
|
||||
thread_id: int,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
List messages in a thread with pagination.
|
||||
|
||||
Requires CHATS_READ permission.
|
||||
"""
|
||||
try:
|
||||
# Verify thread exists and user has access
|
||||
result = await session.execute(
|
||||
select(NewChatThread).filter(NewChatThread.id == thread_id)
|
||||
)
|
||||
thread = result.scalars().first()
|
||||
|
||||
if not thread:
|
||||
raise HTTPException(status_code=404, detail="Thread not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
thread.search_space_id,
|
||||
Permission.CHATS_READ.value,
|
||||
"You don't have permission to read chats in this search space",
|
||||
)
|
||||
|
||||
if thread.user_id != user.id:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You don't have access to this thread"
|
||||
)
|
||||
|
||||
# Get messages
|
||||
query = (
|
||||
select(NewChatMessage)
|
||||
.filter(NewChatMessage.thread_id == thread_id)
|
||||
.order_by(NewChatMessage.created_at)
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
return result.scalars().all()
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except OperationalError:
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Database operation failed. Please try again later."
|
||||
) from None
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"An unexpected error occurred while fetching messages: {e!s}",
|
||||
) from None
|
||||
|
|
@ -21,6 +21,18 @@ from .documents import (
|
|||
)
|
||||
from .llm_config import LLMConfigBase, LLMConfigCreate, LLMConfigRead, LLMConfigUpdate
|
||||
from .logs import LogBase, LogCreate, LogFilter, LogRead, LogUpdate
|
||||
from .new_chat import (
|
||||
NewChatMessageAppend,
|
||||
NewChatMessageCreate,
|
||||
NewChatMessageRead,
|
||||
NewChatThreadCreate,
|
||||
NewChatThreadRead,
|
||||
NewChatThreadUpdate,
|
||||
NewChatThreadWithMessages,
|
||||
ThreadHistoryLoadResponse,
|
||||
ThreadListItem,
|
||||
ThreadListResponse,
|
||||
)
|
||||
from .podcasts import (
|
||||
PodcastBase,
|
||||
PodcastCreate,
|
||||
|
|
@ -98,7 +110,15 @@ __all__ = [
|
|||
"MembershipRead",
|
||||
"MembershipReadWithUser",
|
||||
"MembershipUpdate",
|
||||
# New chat schemas (assistant-ui integration)
|
||||
"NewChatMessageAppend",
|
||||
"NewChatMessageCreate",
|
||||
"NewChatMessageRead",
|
||||
"NewChatRequest",
|
||||
"NewChatThreadCreate",
|
||||
"NewChatThreadRead",
|
||||
"NewChatThreadUpdate",
|
||||
"NewChatThreadWithMessages",
|
||||
"PaginatedResponse",
|
||||
"PermissionInfo",
|
||||
"PermissionsListResponse",
|
||||
|
|
@ -119,6 +139,9 @@ __all__ = [
|
|||
"SearchSpaceRead",
|
||||
"SearchSpaceUpdate",
|
||||
"SearchSpaceWithStats",
|
||||
"ThreadHistoryLoadResponse",
|
||||
"ThreadListItem",
|
||||
"ThreadListResponse",
|
||||
"TimestampModel",
|
||||
"UserCreate",
|
||||
"UserRead",
|
||||
|
|
|
|||
129
surfsense_backend/app/schemas/new_chat.py
Normal file
129
surfsense_backend/app/schemas/new_chat.py
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
"""
|
||||
Pydantic schemas for the new chat feature with assistant-ui integration.
|
||||
|
||||
These schemas follow the assistant-ui ThreadHistoryAdapter pattern:
|
||||
- ThreadRecord: id, title, archived, createdAt, updatedAt
|
||||
- MessageRecord: id, threadId, role, content, createdAt
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from app.db import NewChatMessageRole
|
||||
|
||||
from .base import IDModel, TimestampModel
|
||||
|
||||
# =============================================================================
|
||||
# Message Schemas
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class NewChatMessageBase(BaseModel):
|
||||
"""Base schema for new chat messages."""
|
||||
|
||||
role: NewChatMessageRole
|
||||
content: Any # JSONB content - can be text, tool calls, etc.
|
||||
|
||||
|
||||
class NewChatMessageCreate(NewChatMessageBase):
|
||||
"""Schema for creating a new message."""
|
||||
|
||||
thread_id: int
|
||||
|
||||
|
||||
class NewChatMessageRead(NewChatMessageBase, IDModel, TimestampModel):
|
||||
"""Schema for reading a message."""
|
||||
|
||||
thread_id: int
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class NewChatMessageAppend(BaseModel):
|
||||
"""
|
||||
Schema for appending a message via the history adapter.
|
||||
This is the format assistant-ui sends when calling append().
|
||||
"""
|
||||
|
||||
role: str # Accept string and validate in route handler
|
||||
content: Any
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Thread Schemas
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class NewChatThreadBase(BaseModel):
|
||||
"""Base schema for new chat threads."""
|
||||
|
||||
title: str = Field(default="New Chat", max_length=500)
|
||||
archived: bool = False
|
||||
|
||||
|
||||
class NewChatThreadCreate(NewChatThreadBase):
|
||||
"""Schema for creating a new thread."""
|
||||
|
||||
search_space_id: int
|
||||
|
||||
|
||||
class NewChatThreadUpdate(BaseModel):
|
||||
"""Schema for updating a thread."""
|
||||
|
||||
title: str | None = None
|
||||
archived: bool | None = None
|
||||
|
||||
|
||||
class NewChatThreadRead(NewChatThreadBase, IDModel):
|
||||
"""
|
||||
Schema for reading a thread (matches assistant-ui ThreadRecord).
|
||||
"""
|
||||
|
||||
search_space_id: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class NewChatThreadWithMessages(NewChatThreadRead):
|
||||
"""Schema for reading a thread with its messages."""
|
||||
|
||||
messages: list[NewChatMessageRead] = []
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# History Adapter Response Schemas
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ThreadHistoryLoadResponse(BaseModel):
|
||||
"""
|
||||
Response format for the ThreadHistoryAdapter.load() method.
|
||||
Returns messages array for the current thread.
|
||||
"""
|
||||
|
||||
messages: list[NewChatMessageRead]
|
||||
|
||||
|
||||
class ThreadListItem(BaseModel):
|
||||
"""
|
||||
Thread list item for sidebar display.
|
||||
Matches assistant-ui ThreadListPrimitive expected format.
|
||||
"""
|
||||
|
||||
id: int
|
||||
title: str
|
||||
archived: bool
|
||||
created_at: datetime = Field(alias="createdAt")
|
||||
updated_at: datetime = Field(alias="updatedAt")
|
||||
|
||||
model_config = ConfigDict(from_attributes=True, populate_by_name=True)
|
||||
|
||||
|
||||
class ThreadListResponse(BaseModel):
|
||||
"""Response containing list of threads for the sidebar."""
|
||||
|
||||
threads: list[ThreadListItem]
|
||||
archived_threads: list[ThreadListItem]
|
||||
|
|
@ -7,14 +7,13 @@ import sys
|
|||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.config import config
|
||||
from app.tasks.podcast_tasks import generate_chat_podcast
|
||||
|
||||
# Import for content-based podcast (new-chat)
|
||||
from app.agents.podcaster.graph import graph as podcaster_graph
|
||||
from app.agents.podcaster.state import State as PodcasterState
|
||||
from app.celery_app import celery_app
|
||||
from app.config import config
|
||||
from app.db import Podcast
|
||||
from app.tasks.podcast_tasks import generate_chat_podcast
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -201,15 +200,16 @@ async def _generate_content_podcast(
|
|||
serializable_transcript = []
|
||||
for entry in podcast_transcript:
|
||||
if hasattr(entry, "speaker_id"):
|
||||
serializable_transcript.append({
|
||||
"speaker_id": entry.speaker_id,
|
||||
"dialog": entry.dialog
|
||||
})
|
||||
serializable_transcript.append(
|
||||
{"speaker_id": entry.speaker_id, "dialog": entry.dialog}
|
||||
)
|
||||
else:
|
||||
serializable_transcript.append({
|
||||
"speaker_id": entry.get("speaker_id", 0),
|
||||
"dialog": entry.get("dialog", "")
|
||||
})
|
||||
serializable_transcript.append(
|
||||
{
|
||||
"speaker_id": entry.get("speaker_id", 0),
|
||||
"dialog": entry.get("dialog", ""),
|
||||
}
|
||||
)
|
||||
|
||||
# Save podcast to database
|
||||
podcast = Podcast(
|
||||
|
|
|
|||
|
|
@ -9,12 +9,15 @@ import json
|
|||
from collections.abc import AsyncGenerator
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.messages import HumanMessage
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
|
||||
from app.agents.new_chat.checkpointer import get_checkpointer
|
||||
from app.agents.new_chat.llm_config import create_chat_litellm_from_config, load_llm_config_from_yaml
|
||||
from app.agents.new_chat.llm_config import (
|
||||
create_chat_litellm_from_config,
|
||||
load_llm_config_from_yaml,
|
||||
)
|
||||
from app.schemas.chats import ChatMessage
|
||||
from app.services.connector_service import ConnectorService
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
|
|
@ -101,7 +104,7 @@ async def stream_new_chat(
|
|||
# elif msg.role == "assistant":
|
||||
# langchain_messages.append(AIMessage(content=msg.content))
|
||||
# else:
|
||||
# Fallback: just use the current user query
|
||||
# Fallback: just use the current user query
|
||||
langchain_messages.append(HumanMessage(content=user_query))
|
||||
|
||||
input_state = {
|
||||
|
|
@ -219,7 +222,9 @@ async def stream_new_chat(
|
|||
elif isinstance(raw_output, dict):
|
||||
tool_output = raw_output
|
||||
else:
|
||||
tool_output = {"result": str(raw_output) if raw_output else "completed"}
|
||||
tool_output = {
|
||||
"result": str(raw_output) if raw_output else "completed"
|
||||
}
|
||||
|
||||
tool_call_id = f"call_{run_id[:32]}" if run_id else "call_unknown"
|
||||
|
||||
|
|
@ -228,16 +233,25 @@ async def stream_new_chat(
|
|||
# Stream the full podcast result so frontend can render the audio player
|
||||
yield streaming_service.format_tool_output_available(
|
||||
tool_call_id,
|
||||
tool_output if isinstance(tool_output, dict) else {"result": tool_output},
|
||||
tool_output
|
||||
if isinstance(tool_output, dict)
|
||||
else {"result": tool_output},
|
||||
)
|
||||
# Send appropriate terminal message based on status
|
||||
if isinstance(tool_output, dict) and tool_output.get("status") == "success":
|
||||
if (
|
||||
isinstance(tool_output, dict)
|
||||
and tool_output.get("status") == "success"
|
||||
):
|
||||
yield streaming_service.format_terminal_info(
|
||||
f"Podcast generated successfully: {tool_output.get('title', 'Podcast')}",
|
||||
"success",
|
||||
)
|
||||
else:
|
||||
error_msg = tool_output.get("error", "Unknown error") if isinstance(tool_output, dict) else "Unknown error"
|
||||
error_msg = (
|
||||
tool_output.get("error", "Unknown error")
|
||||
if isinstance(tool_output, dict)
|
||||
else "Unknown error"
|
||||
)
|
||||
yield streaming_service.format_terminal_info(
|
||||
f"Podcast generation failed: {error_msg}",
|
||||
"error",
|
||||
|
|
|
|||
|
|
@ -47,6 +47,9 @@ export function DashboardClientLayout({
|
|||
// Check if we're on the researcher page
|
||||
const isResearcherPage = pathname?.includes("/researcher");
|
||||
|
||||
// Check if we're on the new-chat page (uses separate thread persistence)
|
||||
const isNewChatPage = pathname?.includes("/new-chat");
|
||||
|
||||
// Show indicator when chat becomes active and panel is closed
|
||||
useEffect(() => {
|
||||
if (activeChatId && !isChatPannelOpen) {
|
||||
|
|
@ -151,6 +154,12 @@ export function DashboardClientLayout({
|
|||
}, [search_space_id]);
|
||||
|
||||
useEffect(() => {
|
||||
// Skip setting activeChatIdAtom on new-chat page (uses separate thread persistence)
|
||||
if (isNewChatPage) {
|
||||
setActiveChatIdState(null);
|
||||
return;
|
||||
}
|
||||
|
||||
const activeChatId =
|
||||
typeof chat_id === "string"
|
||||
? chat_id
|
||||
|
|
@ -159,7 +168,7 @@ export function DashboardClientLayout({
|
|||
: "";
|
||||
if (!activeChatId) return;
|
||||
setActiveChatIdState(activeChatId);
|
||||
}, [chat_id, search_space_id]);
|
||||
}, [chat_id, search_space_id, isNewChatPage]);
|
||||
|
||||
// Show loading screen while checking onboarding status (only on first load)
|
||||
if (!hasCheckedOnboarding && (loading || accessLoading) && !isOnboardingPage) {
|
||||
|
|
|
|||
|
|
@ -1,23 +1,73 @@
|
|||
"use client";
|
||||
|
||||
import { AssistantRuntimeProvider, useLocalRuntime } from "@assistant-ui/react";
|
||||
import { useParams } from "next/navigation";
|
||||
import { useMemo } from "react";
|
||||
import {
|
||||
AssistantRuntimeProvider,
|
||||
useExternalStoreRuntime,
|
||||
type ThreadMessageLike,
|
||||
} from "@assistant-ui/react";
|
||||
import { useParams, useRouter } from "next/navigation";
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||
import { Thread } from "@/components/assistant-ui/thread";
|
||||
import { GeneratePodcastToolUI } from "@/components/tool-ui/generate-podcast";
|
||||
import { createNewChatAdapter } from "@/lib/chat/new-chat-transport";
|
||||
import {
|
||||
createThread,
|
||||
getThreadMessages,
|
||||
appendMessage,
|
||||
type MessageRecord,
|
||||
} from "@/lib/chat/thread-persistence";
|
||||
import { getBearerToken } from "@/lib/auth-utils";
|
||||
import { toast } from "sonner";
|
||||
import {
|
||||
isPodcastGenerating,
|
||||
looksLikePodcastRequest,
|
||||
setActivePodcastTaskId,
|
||||
} from "@/lib/chat/podcast-state";
|
||||
|
||||
/**
|
||||
* Convert backend message to assistant-ui ThreadMessageLike format
|
||||
*/
|
||||
function convertToThreadMessage(msg: MessageRecord): ThreadMessageLike {
|
||||
let content: ThreadMessageLike["content"];
|
||||
|
||||
if (typeof msg.content === "string") {
|
||||
content = [{ type: "text", text: msg.content }];
|
||||
} else if (Array.isArray(msg.content)) {
|
||||
content = msg.content as ThreadMessageLike["content"];
|
||||
} else {
|
||||
content = [{ type: "text", text: String(msg.content) }];
|
||||
}
|
||||
|
||||
return {
|
||||
id: `msg-${msg.id}`,
|
||||
role: msg.role,
|
||||
content,
|
||||
createdAt: new Date(msg.created_at),
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Tools that should render custom UI in the chat.
|
||||
*/
|
||||
const TOOLS_WITH_UI = new Set(["generate_podcast"]);
|
||||
|
||||
export default function NewChatPage() {
|
||||
const params = useParams();
|
||||
const router = useRouter();
|
||||
const [isInitializing, setIsInitializing] = useState(true);
|
||||
const [threadId, setThreadId] = useState<number | null>(null);
|
||||
const [messages, setMessages] = useState<ThreadMessageLike[]>([]);
|
||||
const [isRunning, setIsRunning] = useState(false);
|
||||
const abortControllerRef = useRef<AbortController | null>(null);
|
||||
|
||||
// Extract search_space_id and chat_id from URL params
|
||||
// Extract search_space_id from URL params
|
||||
const searchSpaceId = useMemo(() => {
|
||||
const id = params.search_space_id;
|
||||
const parsed = typeof id === "string" ? Number.parseInt(id, 10) : 0;
|
||||
return Number.isNaN(parsed) ? 0 : parsed;
|
||||
}, [params.search_space_id]);
|
||||
|
||||
const chatId = useMemo(() => {
|
||||
// Extract chat_id from URL params
|
||||
const urlChatId = useMemo(() => {
|
||||
const id = params.chat_id;
|
||||
let parsed = 0;
|
||||
if (Array.isArray(id) && id.length > 0) {
|
||||
|
|
@ -28,18 +78,368 @@ export default function NewChatPage() {
|
|||
return Number.isNaN(parsed) ? 0 : parsed;
|
||||
}, [params.chat_id]);
|
||||
|
||||
// Create the adapter with the extracted params
|
||||
const adapter = useMemo(
|
||||
() => createNewChatAdapter({ searchSpaceId, chatId }),
|
||||
[searchSpaceId, chatId]
|
||||
// Initialize thread and load messages
|
||||
const initializeThread = useCallback(async () => {
|
||||
setIsInitializing(true);
|
||||
|
||||
try {
|
||||
if (urlChatId > 0) {
|
||||
// Thread exists - load messages
|
||||
setThreadId(urlChatId);
|
||||
const response = await getThreadMessages(urlChatId);
|
||||
if (response.messages && response.messages.length > 0) {
|
||||
const loadedMessages = response.messages.map(convertToThreadMessage);
|
||||
setMessages(loadedMessages);
|
||||
}
|
||||
} else {
|
||||
// Create new thread
|
||||
const newThread = await createThread(searchSpaceId, "New Chat");
|
||||
setThreadId(newThread.id);
|
||||
router.replace(`/dashboard/${searchSpaceId}/new-chat/${newThread.id}`);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("[NewChatPage] Failed to initialize thread:", error);
|
||||
setThreadId(Date.now());
|
||||
} finally {
|
||||
setIsInitializing(false);
|
||||
}
|
||||
}, [urlChatId, searchSpaceId, router]);
|
||||
|
||||
// Initialize on mount
|
||||
useEffect(() => {
|
||||
initializeThread();
|
||||
}, [initializeThread]);
|
||||
|
||||
// Cancel ongoing request
|
||||
const cancelRun = useCallback(async () => {
|
||||
if (abortControllerRef.current) {
|
||||
abortControllerRef.current.abort();
|
||||
abortControllerRef.current = null;
|
||||
}
|
||||
setIsRunning(false);
|
||||
}, []);
|
||||
|
||||
// Handle new message from user
|
||||
const onNew = useCallback(
|
||||
async (message: ThreadMessageLike) => {
|
||||
if (!threadId) return;
|
||||
|
||||
// Extract user query text
|
||||
let userQuery = "";
|
||||
for (const part of message.content) {
|
||||
if (typeof part === "object" && part.type === "text" && "text" in part) {
|
||||
userQuery += part.text;
|
||||
}
|
||||
}
|
||||
|
||||
if (!userQuery.trim()) return;
|
||||
|
||||
// Check if podcast is already generating
|
||||
if (isPodcastGenerating() && looksLikePodcastRequest(userQuery)) {
|
||||
toast.warning("A podcast is already being generated.");
|
||||
return;
|
||||
}
|
||||
|
||||
const token = getBearerToken();
|
||||
if (!token) {
|
||||
toast.error("Not authenticated. Please log in again.");
|
||||
return;
|
||||
}
|
||||
|
||||
// Add user message to state
|
||||
const userMsgId = `msg-user-${Date.now()}`;
|
||||
const userMessage: ThreadMessageLike = {
|
||||
id: userMsgId,
|
||||
role: "user",
|
||||
content: message.content,
|
||||
createdAt: new Date(),
|
||||
};
|
||||
setMessages((prev) => [...prev, userMessage]);
|
||||
|
||||
// Persist user message (don't await, fire and forget)
|
||||
appendMessage(threadId, {
|
||||
role: "user",
|
||||
content: message.content,
|
||||
}).catch((err) => console.error("Failed to persist user message:", err));
|
||||
|
||||
// Start streaming response
|
||||
setIsRunning(true);
|
||||
const controller = new AbortController();
|
||||
abortControllerRef.current = controller;
|
||||
|
||||
// Prepare assistant message
|
||||
const assistantMsgId = `msg-assistant-${Date.now()}`;
|
||||
let accumulatedText = "";
|
||||
const toolCalls = new Map<
|
||||
string,
|
||||
{
|
||||
toolCallId: string;
|
||||
toolName: string;
|
||||
args: Record<string, unknown>;
|
||||
result?: unknown;
|
||||
}
|
||||
>();
|
||||
|
||||
// Helper to build content
|
||||
const buildContent = (): ThreadMessageLike["content"] => {
|
||||
const parts: Array<
|
||||
| { type: "text"; text: string }
|
||||
| {
|
||||
type: "tool-call";
|
||||
toolCallId: string;
|
||||
toolName: string;
|
||||
args: Record<string, unknown>;
|
||||
result?: unknown;
|
||||
}
|
||||
> = [];
|
||||
if (accumulatedText) {
|
||||
parts.push({ type: "text", text: accumulatedText });
|
||||
}
|
||||
for (const toolCall of toolCalls.values()) {
|
||||
if (TOOLS_WITH_UI.has(toolCall.toolName)) {
|
||||
parts.push({
|
||||
type: "tool-call",
|
||||
toolCallId: toolCall.toolCallId,
|
||||
toolName: toolCall.toolName,
|
||||
args: toolCall.args,
|
||||
result: toolCall.result,
|
||||
});
|
||||
}
|
||||
}
|
||||
return parts.length > 0
|
||||
? (parts as ThreadMessageLike["content"])
|
||||
: [{ type: "text", text: "" }];
|
||||
};
|
||||
|
||||
// Add placeholder assistant message
|
||||
setMessages((prev) => [
|
||||
...prev,
|
||||
{
|
||||
id: assistantMsgId,
|
||||
role: "assistant",
|
||||
content: [{ type: "text", text: "" }],
|
||||
createdAt: new Date(),
|
||||
},
|
||||
]);
|
||||
|
||||
try {
|
||||
const backendUrl =
|
||||
process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
|
||||
|
||||
// Build message history for context
|
||||
const messageHistory = messages
|
||||
.filter((m) => m.role === "user" || m.role === "assistant")
|
||||
.map((m) => {
|
||||
let text = "";
|
||||
for (const part of m.content) {
|
||||
if (
|
||||
typeof part === "object" &&
|
||||
part.type === "text" &&
|
||||
"text" in part
|
||||
) {
|
||||
text += part.text;
|
||||
}
|
||||
}
|
||||
return { role: m.role, content: text };
|
||||
})
|
||||
.filter((m) => m.content.length > 0);
|
||||
|
||||
const response = await fetch(`${backendUrl}/api/v1/new_chat`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
Authorization: `Bearer ${token}`,
|
||||
},
|
||||
body: JSON.stringify({
|
||||
chat_id: threadId,
|
||||
user_query: userQuery.trim(),
|
||||
search_space_id: searchSpaceId,
|
||||
messages: messageHistory,
|
||||
}),
|
||||
signal: controller.signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`Backend error: ${response.status}`);
|
||||
}
|
||||
|
||||
if (!response.body) {
|
||||
throw new Error("No response body");
|
||||
}
|
||||
|
||||
// Parse SSE stream
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = "";
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
const events = buffer.split(/\r?\n\r?\n/);
|
||||
buffer = events.pop() || "";
|
||||
|
||||
for (const event of events) {
|
||||
const lines = event.split(/\r?\n/);
|
||||
for (const line of lines) {
|
||||
if (!line.startsWith("data: ")) continue;
|
||||
const data = line.slice(6).trim();
|
||||
if (!data || data === "[DONE]") continue;
|
||||
|
||||
try {
|
||||
const parsed = JSON.parse(data);
|
||||
|
||||
switch (parsed.type) {
|
||||
case "text-delta":
|
||||
accumulatedText += parsed.delta;
|
||||
setMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.id === assistantMsgId
|
||||
? { ...m, content: buildContent() }
|
||||
: m
|
||||
)
|
||||
);
|
||||
break;
|
||||
|
||||
case "tool-input-start":
|
||||
toolCalls.set(parsed.toolCallId, {
|
||||
toolCallId: parsed.toolCallId,
|
||||
toolName: parsed.toolName,
|
||||
args: {},
|
||||
});
|
||||
setMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.id === assistantMsgId
|
||||
? { ...m, content: buildContent() }
|
||||
: m
|
||||
)
|
||||
);
|
||||
break;
|
||||
|
||||
case "tool-input-available": {
|
||||
const tc = toolCalls.get(parsed.toolCallId);
|
||||
if (tc) tc.args = parsed.input || {};
|
||||
else
|
||||
toolCalls.set(parsed.toolCallId, {
|
||||
toolCallId: parsed.toolCallId,
|
||||
toolName: parsed.toolName,
|
||||
args: parsed.input || {},
|
||||
});
|
||||
setMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.id === assistantMsgId
|
||||
? { ...m, content: buildContent() }
|
||||
: m
|
||||
)
|
||||
);
|
||||
break;
|
||||
}
|
||||
|
||||
case "tool-output-available": {
|
||||
const tc = toolCalls.get(parsed.toolCallId);
|
||||
if (tc) {
|
||||
tc.result = parsed.output;
|
||||
if (
|
||||
tc.toolName === "generate_podcast" &&
|
||||
parsed.output?.status === "processing" &&
|
||||
parsed.output?.task_id
|
||||
) {
|
||||
setActivePodcastTaskId(parsed.output.task_id);
|
||||
}
|
||||
}
|
||||
setMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.id === assistantMsgId
|
||||
? { ...m, content: buildContent() }
|
||||
: m
|
||||
)
|
||||
);
|
||||
break;
|
||||
}
|
||||
|
||||
case "error":
|
||||
throw new Error(parsed.errorText || "Server error");
|
||||
}
|
||||
} catch (e) {
|
||||
if (e instanceof SyntaxError) continue;
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
reader.releaseLock();
|
||||
}
|
||||
|
||||
// Persist assistant message
|
||||
const finalContent = buildContent();
|
||||
if (accumulatedText || toolCalls.size > 0) {
|
||||
appendMessage(threadId, {
|
||||
role: "assistant",
|
||||
content: finalContent,
|
||||
}).catch((err) =>
|
||||
console.error("Failed to persist assistant message:", err)
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
if (error instanceof Error && error.name === "AbortError") {
|
||||
// Request was cancelled
|
||||
return;
|
||||
}
|
||||
console.error("[NewChatPage] Chat error:", error);
|
||||
toast.error("Failed to get response. Please try again.");
|
||||
// Update assistant message with error
|
||||
setMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.id === assistantMsgId
|
||||
? {
|
||||
...m,
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: "Sorry, there was an error. Please try again.",
|
||||
},
|
||||
],
|
||||
}
|
||||
: m
|
||||
)
|
||||
);
|
||||
} finally {
|
||||
setIsRunning(false);
|
||||
abortControllerRef.current = null;
|
||||
}
|
||||
},
|
||||
[threadId, searchSpaceId, messages]
|
||||
);
|
||||
|
||||
// Use LocalRuntime with our custom adapter
|
||||
const runtime = useLocalRuntime(adapter);
|
||||
// Convert message (pass through since already in correct format)
|
||||
const convertMessage = useCallback(
|
||||
(message: ThreadMessageLike): ThreadMessageLike => message,
|
||||
[]
|
||||
);
|
||||
|
||||
// Create external store runtime
|
||||
const runtime = useExternalStoreRuntime({
|
||||
messages,
|
||||
isRunning,
|
||||
onNew,
|
||||
convertMessage,
|
||||
onCancel: cancelRun,
|
||||
});
|
||||
|
||||
// Show loading state
|
||||
if (isInitializing) {
|
||||
return (
|
||||
<div className="flex h-[calc(100vh-64px)] items-center justify-center">
|
||||
<div className="text-muted-foreground">Loading chat...</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<AssistantRuntimeProvider runtime={runtime}>
|
||||
{/* Register tool UI components */}
|
||||
<GeneratePodcastToolUI />
|
||||
<div className="h-[calc(100vh-64px)] max-h-[calc(100vh-64px)] overflow-hidden">
|
||||
<Thread />
|
||||
|
|
|
|||
286
surfsense_web/components/assistant-ui/thread-list.tsx
Normal file
286
surfsense_web/components/assistant-ui/thread-list.tsx
Normal file
|
|
@ -0,0 +1,286 @@
|
|||
"use client";
|
||||
|
||||
import { useCallback, useEffect, useState } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { ArchiveIcon, MessageSquareIcon, PlusIcon, TrashIcon, MoreVerticalIcon, RotateCcwIcon } from "lucide-react";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
DropdownMenu,
|
||||
DropdownMenuContent,
|
||||
DropdownMenuItem,
|
||||
DropdownMenuSeparator,
|
||||
DropdownMenuTrigger,
|
||||
} from "@/components/ui/dropdown-menu";
|
||||
import {
|
||||
type ThreadListItem,
|
||||
createThreadListManager,
|
||||
type ThreadListState,
|
||||
} from "@/lib/chat/thread-persistence";
|
||||
|
||||
interface ThreadListProps {
|
||||
searchSpaceId: number;
|
||||
currentThreadId?: number;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export function ThreadList({ searchSpaceId, currentThreadId, className }: ThreadListProps) {
|
||||
const router = useRouter();
|
||||
const [state, setState] = useState<ThreadListState>({
|
||||
threads: [],
|
||||
archivedThreads: [],
|
||||
isLoading: true,
|
||||
error: null,
|
||||
});
|
||||
const [showArchived, setShowArchived] = useState(false);
|
||||
|
||||
// Create the thread list manager
|
||||
const manager = useCallback(
|
||||
() =>
|
||||
createThreadListManager({
|
||||
searchSpaceId,
|
||||
currentThreadId: currentThreadId ?? null,
|
||||
onThreadSwitch: (threadId) => {
|
||||
router.push(`/dashboard/${searchSpaceId}/new-chat/${threadId}`);
|
||||
},
|
||||
onNewThread: (threadId) => {
|
||||
router.push(`/dashboard/${searchSpaceId}/new-chat/${threadId}`);
|
||||
},
|
||||
}),
|
||||
[searchSpaceId, currentThreadId, router]
|
||||
);
|
||||
|
||||
// Load threads on mount and when searchSpaceId changes
|
||||
const loadThreads = useCallback(async () => {
|
||||
setState((prev) => ({ ...prev, isLoading: true }));
|
||||
const newState = await manager().loadThreads();
|
||||
setState(newState);
|
||||
}, [manager]);
|
||||
|
||||
useEffect(() => {
|
||||
loadThreads();
|
||||
}, [loadThreads]);
|
||||
|
||||
// Handle new thread creation
|
||||
const handleNewThread = async () => {
|
||||
await manager().createNewThread();
|
||||
await loadThreads();
|
||||
};
|
||||
|
||||
// Handle thread actions
|
||||
const handleArchive = async (threadId: number) => {
|
||||
const success = await manager().archiveThread(threadId);
|
||||
if (success) await loadThreads();
|
||||
};
|
||||
|
||||
const handleUnarchive = async (threadId: number) => {
|
||||
const success = await manager().unarchiveThread(threadId);
|
||||
if (success) await loadThreads();
|
||||
};
|
||||
|
||||
const handleDelete = async (threadId: number) => {
|
||||
const success = await manager().deleteThread(threadId);
|
||||
if (success) {
|
||||
await loadThreads();
|
||||
// If we deleted the current thread, redirect to new chat
|
||||
if (threadId === currentThreadId) {
|
||||
router.push(`/dashboard/${searchSpaceId}/new-chat`);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const handleSwitchToThread = (threadId: number) => {
|
||||
manager().switchToThread(threadId);
|
||||
};
|
||||
|
||||
const displayedThreads = showArchived ? state.archivedThreads : state.threads;
|
||||
|
||||
if (state.isLoading) {
|
||||
return (
|
||||
<div className={cn("flex h-full flex-col", className)}>
|
||||
<div className="flex items-center justify-center p-4">
|
||||
<span className="text-muted-foreground text-sm">Loading threads...</span>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (state.error) {
|
||||
return (
|
||||
<div className={cn("flex h-full flex-col", className)}>
|
||||
<div className="p-4 text-center">
|
||||
<span className="text-destructive text-sm">{state.error}</span>
|
||||
<Button variant="ghost" size="sm" className="mt-2" onClick={loadThreads}>
|
||||
Retry
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className={cn("flex h-full flex-col", className)}>
|
||||
{/* Header with New Chat button */}
|
||||
<div className="flex items-center justify-between border-b p-3">
|
||||
<h2 className="font-semibold text-sm">Conversations</h2>
|
||||
<Button variant="ghost" size="icon" className="size-8" onClick={handleNewThread} title="New Chat">
|
||||
<PlusIcon className="size-4" />
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
{/* Tab toggle for active/archived */}
|
||||
<div className="flex border-b">
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setShowArchived(false)}
|
||||
className={cn(
|
||||
"flex-1 px-3 py-2 text-center text-xs font-medium transition-colors",
|
||||
!showArchived
|
||||
? "border-b-2 border-primary text-primary"
|
||||
: "text-muted-foreground hover:text-foreground"
|
||||
)}
|
||||
>
|
||||
Active ({state.threads.length})
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setShowArchived(true)}
|
||||
className={cn(
|
||||
"flex-1 px-3 py-2 text-center text-xs font-medium transition-colors",
|
||||
showArchived
|
||||
? "border-b-2 border-primary text-primary"
|
||||
: "text-muted-foreground hover:text-foreground"
|
||||
)}
|
||||
>
|
||||
Archived ({state.archivedThreads.length})
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* Thread list */}
|
||||
<div className="flex-1 overflow-y-auto">
|
||||
{displayedThreads.length === 0 ? (
|
||||
<div className="flex flex-col items-center justify-center p-6 text-center">
|
||||
<MessageSquareIcon className="mb-2 size-8 text-muted-foreground/50" />
|
||||
<p className="text-muted-foreground text-sm">
|
||||
{showArchived ? "No archived conversations" : "No conversations yet"}
|
||||
</p>
|
||||
{!showArchived && (
|
||||
<Button variant="outline" size="sm" className="mt-3" onClick={handleNewThread}>
|
||||
<PlusIcon className="mr-1 size-3" />
|
||||
Start a conversation
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
) : (
|
||||
<div className="space-y-1 p-2">
|
||||
{displayedThreads.map((thread) => (
|
||||
<ThreadListItemComponent
|
||||
key={thread.id}
|
||||
thread={thread}
|
||||
isActive={thread.id === currentThreadId}
|
||||
isArchived={showArchived}
|
||||
onClick={() => handleSwitchToThread(thread.id)}
|
||||
onArchive={() => handleArchive(thread.id)}
|
||||
onUnarchive={() => handleUnarchive(thread.id)}
|
||||
onDelete={() => handleDelete(thread.id)}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
interface ThreadListItemComponentProps {
|
||||
thread: ThreadListItem;
|
||||
isActive: boolean;
|
||||
isArchived: boolean;
|
||||
onClick: () => void;
|
||||
onArchive: () => void;
|
||||
onUnarchive: () => void;
|
||||
onDelete: () => void;
|
||||
}
|
||||
|
||||
function ThreadListItemComponent({
|
||||
thread,
|
||||
isActive,
|
||||
isArchived,
|
||||
onClick,
|
||||
onArchive,
|
||||
onUnarchive,
|
||||
onDelete,
|
||||
}: ThreadListItemComponentProps) {
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"group flex items-center gap-2 rounded-lg px-3 py-2 transition-colors cursor-pointer",
|
||||
isActive ? "bg-accent text-accent-foreground" : "hover:bg-muted/50"
|
||||
)}
|
||||
onClick={onClick}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter" || e.key === " ") onClick();
|
||||
}}
|
||||
role="button"
|
||||
tabIndex={0}
|
||||
>
|
||||
<MessageSquareIcon className="size-4 shrink-0 text-muted-foreground" />
|
||||
<div className="flex-1 min-w-0">
|
||||
<p className="truncate text-sm font-medium">{thread.title || "New Chat"}</p>
|
||||
<p className="truncate text-xs text-muted-foreground">
|
||||
{formatRelativeTime(new Date(thread.updatedAt))}
|
||||
</p>
|
||||
</div>
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
className="size-7 opacity-0 group-hover:opacity-100 transition-opacity"
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
>
|
||||
<MoreVerticalIcon className="size-4" />
|
||||
</Button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="end">
|
||||
{isArchived ? (
|
||||
<DropdownMenuItem onClick={onUnarchive}>
|
||||
<RotateCcwIcon className="mr-2 size-4" />
|
||||
Unarchive
|
||||
</DropdownMenuItem>
|
||||
) : (
|
||||
<DropdownMenuItem onClick={onArchive}>
|
||||
<ArchiveIcon className="mr-2 size-4" />
|
||||
Archive
|
||||
</DropdownMenuItem>
|
||||
)}
|
||||
<DropdownMenuSeparator />
|
||||
<DropdownMenuItem onClick={onDelete} className="text-destructive focus:text-destructive">
|
||||
<TrashIcon className="mr-2 size-4" />
|
||||
Delete
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Format a date as relative time (e.g., "2 hours ago", "Yesterday")
|
||||
*/
|
||||
function formatRelativeTime(date: Date): string {
|
||||
const now = new Date();
|
||||
const diffMs = now.getTime() - date.getTime();
|
||||
const diffSecs = Math.floor(diffMs / 1000);
|
||||
const diffMins = Math.floor(diffSecs / 60);
|
||||
const diffHours = Math.floor(diffMins / 60);
|
||||
const diffDays = Math.floor(diffHours / 24);
|
||||
|
||||
if (diffSecs < 60) return "Just now";
|
||||
if (diffMins < 60) return `${diffMins} min${diffMins === 1 ? "" : "s"} ago`;
|
||||
if (diffHours < 24) return `${diffHours} hour${diffHours === 1 ? "" : "s"} ago`;
|
||||
if (diffDays === 1) return "Yesterday";
|
||||
if (diffDays < 7) return `${diffDays} days ago`;
|
||||
|
||||
return date.toLocaleDateString();
|
||||
}
|
||||
|
|
@ -1,345 +0,0 @@
|
|||
/**
|
||||
* Custom ChatModelAdapter for the new-chat feature using LocalRuntime.
|
||||
* Connects directly to the FastAPI backend using the Vercel AI SDK Data Stream Protocol.
|
||||
*/
|
||||
|
||||
import type { ChatModelAdapter, ChatModelRunOptions } from "@assistant-ui/react";
|
||||
import { toast } from "sonner";
|
||||
import { getBearerToken } from "@/lib/auth-utils";
|
||||
import {
|
||||
isPodcastGenerating,
|
||||
looksLikePodcastRequest,
|
||||
setActivePodcastTaskId,
|
||||
} from "@/lib/chat/podcast-state";
|
||||
|
||||
interface NewChatAdapterConfig {
|
||||
searchSpaceId: number;
|
||||
chatId: number;
|
||||
}
|
||||
|
||||
interface ChatMessageForBackend {
|
||||
role: "user" | "assistant";
|
||||
content: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts assistant-ui messages to a simple format for the backend
|
||||
*/
|
||||
function convertMessagesToBackendFormat(
|
||||
messages: ChatModelRunOptions["messages"]
|
||||
): ChatMessageForBackend[] {
|
||||
return messages
|
||||
.filter((m) => m.role === "user" || m.role === "assistant")
|
||||
.map((m) => {
|
||||
// Extract text content from the message parts
|
||||
let content = "";
|
||||
for (const part of m.content) {
|
||||
if (part.type === "text") {
|
||||
content += part.text;
|
||||
}
|
||||
}
|
||||
return {
|
||||
role: m.role as "user" | "assistant",
|
||||
content: content.trim(),
|
||||
};
|
||||
})
|
||||
.filter((m) => m.content.length > 0); // Filter out empty messages
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents an in-progress or completed tool call
|
||||
*/
|
||||
interface ToolCallState {
|
||||
toolCallId: string;
|
||||
toolName: string;
|
||||
args: Record<string, unknown>;
|
||||
result?: unknown;
|
||||
}
|
||||
|
||||
/**
|
||||
* Tools that should render custom UI in the chat.
|
||||
* Other tools (like search_knowledge_base) will be hidden from the UI.
|
||||
*/
|
||||
const TOOLS_WITH_UI = new Set(["generate_podcast"]);
|
||||
|
||||
/**
|
||||
* Creates a ChatModelAdapter that connects to the FastAPI new_chat endpoint.
|
||||
*
|
||||
* The backend expects:
|
||||
* - POST /api/v1/new_chat
|
||||
* - Body: { chat_id: number, user_query: string, search_space_id: number, messages: [...] }
|
||||
* - Returns: SSE stream with Vercel AI SDK Data Stream Protocol
|
||||
*/
|
||||
export function createNewChatAdapter(config: NewChatAdapterConfig): ChatModelAdapter {
|
||||
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
|
||||
|
||||
return {
|
||||
async *run({ messages, abortSignal }: ChatModelRunOptions) {
|
||||
// Get the last user message
|
||||
const lastUserMessage = messages.filter((m) => m.role === "user").pop();
|
||||
|
||||
if (!lastUserMessage) {
|
||||
throw new Error("No user message found");
|
||||
}
|
||||
|
||||
// Extract text content from the last user message
|
||||
let userQuery = "";
|
||||
for (const part of lastUserMessage.content) {
|
||||
if (part.type === "text") {
|
||||
userQuery += part.text;
|
||||
}
|
||||
}
|
||||
|
||||
if (!userQuery.trim()) {
|
||||
throw new Error("User query cannot be empty");
|
||||
}
|
||||
|
||||
// Check if user is requesting a podcast while one is already generating
|
||||
if (isPodcastGenerating() && looksLikePodcastRequest(userQuery)) {
|
||||
toast.warning("A podcast is already being generated. Please wait for it to complete.");
|
||||
// Return a message telling the user to wait
|
||||
yield {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: "A podcast is already being generated. Please wait for it to complete before requesting another one.",
|
||||
},
|
||||
],
|
||||
};
|
||||
return;
|
||||
}
|
||||
|
||||
const token = getBearerToken();
|
||||
if (!token) {
|
||||
throw new Error("Not authenticated. Please log in again.");
|
||||
}
|
||||
|
||||
// Convert all messages to backend format for chat history
|
||||
const messageHistory = convertMessagesToBackendFormat(messages);
|
||||
|
||||
const response = await fetch(`${backendUrl}/api/v1/new_chat`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
Authorization: `Bearer ${token}`,
|
||||
},
|
||||
body: JSON.stringify({
|
||||
chat_id: config.chatId,
|
||||
user_query: userQuery.trim(),
|
||||
search_space_id: config.searchSpaceId,
|
||||
messages: messageHistory,
|
||||
}),
|
||||
signal: abortSignal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text().catch(() => "Unknown error");
|
||||
throw new Error(`Backend error (${response.status}): ${errorText}`);
|
||||
}
|
||||
|
||||
if (!response.body) {
|
||||
throw new Error("No response body");
|
||||
}
|
||||
|
||||
// Parse the SSE stream (Vercel AI SDK Data Stream Protocol)
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = "";
|
||||
let accumulatedText = "";
|
||||
|
||||
// Track tool calls by their ID
|
||||
const toolCalls = new Map<string, ToolCallState>();
|
||||
|
||||
/**
|
||||
* Build the content array with text and tool calls.
|
||||
* Only includes tools that have custom UI (defined in TOOLS_WITH_UI).
|
||||
*/
|
||||
function buildContent() {
|
||||
const content: Array<
|
||||
| { type: "text"; text: string }
|
||||
| { type: "tool-call"; toolCallId: string; toolName: string; args: Record<string, unknown>; result?: unknown }
|
||||
> = [];
|
||||
|
||||
// Add text content if any
|
||||
if (accumulatedText) {
|
||||
content.push({ type: "text" as const, text: accumulatedText });
|
||||
}
|
||||
|
||||
// Only add tool calls that have custom UI registered
|
||||
// Other tools (like search_knowledge_base) are hidden from the UI
|
||||
for (const toolCall of toolCalls.values()) {
|
||||
if (TOOLS_WITH_UI.has(toolCall.toolName)) {
|
||||
content.push({
|
||||
type: "tool-call" as const,
|
||||
toolCallId: toolCall.toolCallId,
|
||||
toolName: toolCall.toolName,
|
||||
args: toolCall.args,
|
||||
result: toolCall.result,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return content;
|
||||
}
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) {
|
||||
break;
|
||||
}
|
||||
|
||||
const chunk = decoder.decode(value, { stream: true });
|
||||
buffer += chunk;
|
||||
|
||||
// Split on double newlines (handle both \n\n and \r\n\r\n)
|
||||
const events = buffer.split(/\r?\n\r?\n/);
|
||||
buffer = events.pop() || "";
|
||||
|
||||
for (const event of events) {
|
||||
// Each event can have multiple lines, find the data line
|
||||
const lines = event.split(/\r?\n/);
|
||||
for (const line of lines) {
|
||||
if (!line.startsWith("data: ")) continue;
|
||||
|
||||
const data = line.slice(6).trim(); // Remove "data: " prefix
|
||||
|
||||
// Handle [DONE] marker
|
||||
if (data === "[DONE]") {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!data) continue;
|
||||
|
||||
try {
|
||||
const parsed = JSON.parse(data);
|
||||
|
||||
// Handle different message types from the Data Stream Protocol
|
||||
switch (parsed.type) {
|
||||
case "text-delta":
|
||||
accumulatedText += parsed.delta;
|
||||
yield { content: buildContent() };
|
||||
break;
|
||||
|
||||
case "tool-input-start": {
|
||||
// Tool call is starting - create a new tool call entry
|
||||
const { toolCallId, toolName } = parsed;
|
||||
toolCalls.set(toolCallId, {
|
||||
toolCallId,
|
||||
toolName,
|
||||
args: {},
|
||||
});
|
||||
// Yield to show tool is starting (running state)
|
||||
yield { content: buildContent() };
|
||||
break;
|
||||
}
|
||||
|
||||
case "tool-input-available": {
|
||||
// Tool input is complete - update the args
|
||||
const { toolCallId, toolName, input } = parsed;
|
||||
const existing = toolCalls.get(toolCallId);
|
||||
if (existing) {
|
||||
existing.args = input || {};
|
||||
} else {
|
||||
// Create new entry if we missed tool-input-start
|
||||
toolCalls.set(toolCallId, {
|
||||
toolCallId,
|
||||
toolName,
|
||||
args: input || {},
|
||||
});
|
||||
}
|
||||
yield { content: buildContent() };
|
||||
break;
|
||||
}
|
||||
|
||||
case "tool-output-available": {
|
||||
// Tool execution is complete - add the result
|
||||
const { toolCallId, output } = parsed;
|
||||
const existing = toolCalls.get(toolCallId);
|
||||
if (existing) {
|
||||
existing.result = output;
|
||||
|
||||
// If this is a podcast tool with status="processing", set the state immediately
|
||||
// This ensures subsequent podcast requests are intercepted
|
||||
if (
|
||||
existing.toolName === "generate_podcast" &&
|
||||
output &&
|
||||
typeof output === "object" &&
|
||||
"status" in output &&
|
||||
output.status === "processing" &&
|
||||
"task_id" in output &&
|
||||
typeof output.task_id === "string"
|
||||
) {
|
||||
setActivePodcastTaskId(output.task_id);
|
||||
}
|
||||
}
|
||||
yield { content: buildContent() };
|
||||
break;
|
||||
}
|
||||
|
||||
case "error":
|
||||
throw new Error(parsed.errorText || "Unknown error from server");
|
||||
|
||||
// Other types like text-start, text-end, start-step, finish-step, etc.
|
||||
// are handled implicitly
|
||||
default:
|
||||
break;
|
||||
}
|
||||
} catch (e) {
|
||||
// Skip non-JSON lines
|
||||
if (e instanceof SyntaxError) {
|
||||
continue;
|
||||
}
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle any remaining buffer
|
||||
if (buffer.trim()) {
|
||||
const lines = buffer.split(/\r?\n/);
|
||||
for (const line of lines) {
|
||||
if (line.startsWith("data: ")) {
|
||||
const data = line.slice(6).trim();
|
||||
if (data && data !== "[DONE]") {
|
||||
try {
|
||||
const parsed = JSON.parse(data);
|
||||
if (parsed.type === "text-delta") {
|
||||
accumulatedText += parsed.delta;
|
||||
yield { content: buildContent() };
|
||||
} else if (parsed.type === "tool-output-available") {
|
||||
const { toolCallId, output } = parsed;
|
||||
const existing = toolCalls.get(toolCallId);
|
||||
if (existing) {
|
||||
existing.result = output;
|
||||
|
||||
// Set podcast state if processing
|
||||
if (
|
||||
existing.toolName === "generate_podcast" &&
|
||||
output &&
|
||||
typeof output === "object" &&
|
||||
"status" in output &&
|
||||
output.status === "processing" &&
|
||||
"task_id" in output &&
|
||||
typeof output.task_id === "string"
|
||||
) {
|
||||
setActivePodcastTaskId(output.task_id);
|
||||
}
|
||||
}
|
||||
yield { content: buildContent() };
|
||||
}
|
||||
} catch {
|
||||
// Ignore parse errors
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
reader.releaseLock();
|
||||
}
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
223
surfsense_web/lib/chat/thread-persistence.ts
Normal file
223
surfsense_web/lib/chat/thread-persistence.ts
Normal file
|
|
@ -0,0 +1,223 @@
|
|||
/**
|
||||
* Thread persistence utilities for the new chat feature.
|
||||
* Provides API functions and thread list management.
|
||||
*/
|
||||
|
||||
import { baseApiService } from "@/lib/apis/base-api.service";
|
||||
|
||||
// =============================================================================
|
||||
// Types matching backend schemas
|
||||
// =============================================================================
|
||||
|
||||
export interface ThreadRecord {
|
||||
id: number;
|
||||
title: string;
|
||||
archived: boolean;
|
||||
search_space_id: number;
|
||||
created_at: string;
|
||||
updated_at: string;
|
||||
}
|
||||
|
||||
export interface MessageRecord {
|
||||
id: number;
|
||||
thread_id: number;
|
||||
role: "user" | "assistant" | "system";
|
||||
content: unknown;
|
||||
created_at: string;
|
||||
}
|
||||
|
||||
export interface ThreadListResponse {
|
||||
threads: ThreadListItem[];
|
||||
archived_threads: ThreadListItem[];
|
||||
}
|
||||
|
||||
export interface ThreadListItem {
|
||||
id: number;
|
||||
title: string;
|
||||
archived: boolean;
|
||||
createdAt: string;
|
||||
updatedAt: string;
|
||||
}
|
||||
|
||||
export interface ThreadHistoryLoadResponse {
|
||||
messages: MessageRecord[];
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// API Service Functions
|
||||
// =============================================================================
|
||||
|
||||
/**
|
||||
* Fetch list of threads for a search space
|
||||
*/
|
||||
export async function fetchThreads(
|
||||
searchSpaceId: number
|
||||
): Promise<ThreadListResponse> {
|
||||
return baseApiService.get<ThreadListResponse>(
|
||||
`/api/v1/threads?search_space_id=${searchSpaceId}`
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new thread
|
||||
*/
|
||||
export async function createThread(
|
||||
searchSpaceId: number,
|
||||
title = "New Chat"
|
||||
): Promise<ThreadRecord> {
|
||||
return baseApiService.post<ThreadRecord>("/api/v1/threads", undefined, {
|
||||
body: {
|
||||
title,
|
||||
archived: false,
|
||||
search_space_id: searchSpaceId,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Get thread messages
|
||||
*/
|
||||
export async function getThreadMessages(
|
||||
threadId: number
|
||||
): Promise<ThreadHistoryLoadResponse> {
|
||||
return baseApiService.get<ThreadHistoryLoadResponse>(
|
||||
`/api/v1/threads/${threadId}`
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Append a message to a thread
|
||||
*/
|
||||
export async function appendMessage(
|
||||
threadId: number,
|
||||
message: { role: "user" | "assistant" | "system"; content: unknown }
|
||||
): Promise<MessageRecord> {
|
||||
return baseApiService.post<MessageRecord>(
|
||||
`/api/v1/threads/${threadId}/messages`,
|
||||
undefined,
|
||||
{ body: message }
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Update thread (rename, archive)
|
||||
*/
|
||||
export async function updateThread(
|
||||
threadId: number,
|
||||
updates: { title?: string; archived?: boolean }
|
||||
): Promise<ThreadRecord> {
|
||||
return baseApiService.put<ThreadRecord>(
|
||||
`/api/v1/threads/${threadId}`,
|
||||
undefined,
|
||||
{ body: updates }
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Delete a thread
|
||||
*/
|
||||
export async function deleteThread(threadId: number): Promise<void> {
|
||||
await baseApiService.delete(`/api/v1/threads/${threadId}`);
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Thread List Manager (for thread list sidebar)
|
||||
// =============================================================================
|
||||
|
||||
export interface ThreadListAdapterConfig {
|
||||
searchSpaceId: number;
|
||||
currentThreadId: number | null;
|
||||
onThreadSwitch: (threadId: number) => void;
|
||||
onNewThread: (threadId: number) => void;
|
||||
}
|
||||
|
||||
export interface ThreadListState {
|
||||
threads: ThreadListItem[];
|
||||
archivedThreads: ThreadListItem[];
|
||||
isLoading: boolean;
|
||||
error: string | null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a thread list management object.
|
||||
* This provides methods to manage the thread list for the sidebar.
|
||||
*/
|
||||
export function createThreadListManager(config: ThreadListAdapterConfig) {
|
||||
return {
|
||||
async loadThreads(): Promise<ThreadListState> {
|
||||
try {
|
||||
const response = await fetchThreads(config.searchSpaceId);
|
||||
return {
|
||||
threads: response.threads,
|
||||
archivedThreads: response.archived_threads,
|
||||
isLoading: false,
|
||||
error: null,
|
||||
};
|
||||
} catch (error) {
|
||||
console.error("[ThreadListManager] Failed to load threads:", error);
|
||||
return {
|
||||
threads: [],
|
||||
archivedThreads: [],
|
||||
isLoading: false,
|
||||
error:
|
||||
error instanceof Error ? error.message : "Failed to load threads",
|
||||
};
|
||||
}
|
||||
},
|
||||
|
||||
async createNewThread(title = "New Chat"): Promise<number | null> {
|
||||
try {
|
||||
const thread = await createThread(config.searchSpaceId, title);
|
||||
config.onNewThread(thread.id);
|
||||
return thread.id;
|
||||
} catch (error) {
|
||||
console.error("[ThreadListManager] Failed to create thread:", error);
|
||||
return null;
|
||||
}
|
||||
},
|
||||
|
||||
switchToThread(threadId: number) {
|
||||
config.onThreadSwitch(threadId);
|
||||
},
|
||||
|
||||
async renameThread(threadId: number, newTitle: string): Promise<boolean> {
|
||||
try {
|
||||
await updateThread(threadId, { title: newTitle });
|
||||
return true;
|
||||
} catch (error) {
|
||||
console.error("[ThreadListManager] Failed to rename thread:", error);
|
||||
return false;
|
||||
}
|
||||
},
|
||||
|
||||
async archiveThread(threadId: number): Promise<boolean> {
|
||||
try {
|
||||
await updateThread(threadId, { archived: true });
|
||||
return true;
|
||||
} catch (error) {
|
||||
console.error("[ThreadListManager] Failed to archive thread:", error);
|
||||
return false;
|
||||
}
|
||||
},
|
||||
|
||||
async unarchiveThread(threadId: number): Promise<boolean> {
|
||||
try {
|
||||
await updateThread(threadId, { archived: false });
|
||||
return true;
|
||||
} catch (error) {
|
||||
console.error("[ThreadListManager] Failed to unarchive thread:", error);
|
||||
return false;
|
||||
}
|
||||
},
|
||||
|
||||
async deleteThread(threadId: number): Promise<boolean> {
|
||||
try {
|
||||
await deleteThread(threadId);
|
||||
return true;
|
||||
} catch (error) {
|
||||
console.error("[ThreadListManager] Failed to delete thread:", error);
|
||||
return false;
|
||||
}
|
||||
},
|
||||
};
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue