mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-09 07:42:39 +02:00
feat: implement new chat feature with message persistence and UI integration
This commit is contained in:
parent
f115980d2b
commit
0c3574d049
16 changed files with 1814 additions and 397 deletions
|
|
@ -18,59 +18,59 @@ _checkpointer_initialized: bool = False
|
||||||
def get_postgres_connection_string() -> str:
|
def get_postgres_connection_string() -> str:
|
||||||
"""
|
"""
|
||||||
Convert the async DATABASE_URL to a sync postgres connection string for psycopg3.
|
Convert the async DATABASE_URL to a sync postgres connection string for psycopg3.
|
||||||
|
|
||||||
The DATABASE_URL is typically in format:
|
The DATABASE_URL is typically in format:
|
||||||
postgresql+asyncpg://user:pass@host:port/dbname
|
postgresql+asyncpg://user:pass@host:port/dbname
|
||||||
|
|
||||||
We need to convert it to:
|
We need to convert it to:
|
||||||
postgresql://user:pass@host:port/dbname
|
postgresql://user:pass@host:port/dbname
|
||||||
"""
|
"""
|
||||||
db_url = config.DATABASE_URL
|
db_url = config.DATABASE_URL
|
||||||
|
|
||||||
# Handle asyncpg driver prefix
|
# Handle asyncpg driver prefix
|
||||||
if db_url.startswith("postgresql+asyncpg://"):
|
if db_url.startswith("postgresql+asyncpg://"):
|
||||||
return db_url.replace("postgresql+asyncpg://", "postgresql://")
|
return db_url.replace("postgresql+asyncpg://", "postgresql://")
|
||||||
|
|
||||||
# Handle other async prefixes
|
# Handle other async prefixes
|
||||||
if "+asyncpg" in db_url:
|
if "+asyncpg" in db_url:
|
||||||
return db_url.replace("+asyncpg", "")
|
return db_url.replace("+asyncpg", "")
|
||||||
|
|
||||||
return db_url
|
return db_url
|
||||||
|
|
||||||
|
|
||||||
async def get_checkpointer() -> AsyncPostgresSaver:
|
async def get_checkpointer() -> AsyncPostgresSaver:
|
||||||
"""
|
"""
|
||||||
Get or create the global AsyncPostgresSaver instance.
|
Get or create the global AsyncPostgresSaver instance.
|
||||||
|
|
||||||
This function:
|
This function:
|
||||||
1. Creates the checkpointer if it doesn't exist
|
1. Creates the checkpointer if it doesn't exist
|
||||||
2. Sets up the required database tables on first call
|
2. Sets up the required database tables on first call
|
||||||
3. Returns the cached instance on subsequent calls
|
3. Returns the cached instance on subsequent calls
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
AsyncPostgresSaver: The configured checkpointer instance
|
AsyncPostgresSaver: The configured checkpointer instance
|
||||||
"""
|
"""
|
||||||
global _checkpointer, _checkpointer_context, _checkpointer_initialized
|
global _checkpointer, _checkpointer_context, _checkpointer_initialized
|
||||||
|
|
||||||
if _checkpointer is None:
|
if _checkpointer is None:
|
||||||
conn_string = get_postgres_connection_string()
|
conn_string = get_postgres_connection_string()
|
||||||
# from_conn_string returns an async context manager
|
# from_conn_string returns an async context manager
|
||||||
# We need to enter the context to get the actual checkpointer
|
# We need to enter the context to get the actual checkpointer
|
||||||
_checkpointer_context = AsyncPostgresSaver.from_conn_string(conn_string)
|
_checkpointer_context = AsyncPostgresSaver.from_conn_string(conn_string)
|
||||||
_checkpointer = await _checkpointer_context.__aenter__()
|
_checkpointer = await _checkpointer_context.__aenter__()
|
||||||
|
|
||||||
# Setup tables on first call (idempotent)
|
# Setup tables on first call (idempotent)
|
||||||
if not _checkpointer_initialized:
|
if not _checkpointer_initialized:
|
||||||
await _checkpointer.setup()
|
await _checkpointer.setup()
|
||||||
_checkpointer_initialized = True
|
_checkpointer_initialized = True
|
||||||
|
|
||||||
return _checkpointer
|
return _checkpointer
|
||||||
|
|
||||||
|
|
||||||
async def setup_checkpointer_tables() -> None:
|
async def setup_checkpointer_tables() -> None:
|
||||||
"""
|
"""
|
||||||
Explicitly setup the checkpointer tables.
|
Explicitly setup the checkpointer tables.
|
||||||
|
|
||||||
This can be called during application startup to ensure
|
This can be called during application startup to ensure
|
||||||
tables exist before any agent calls.
|
tables exist before any agent calls.
|
||||||
"""
|
"""
|
||||||
|
|
@ -81,15 +81,14 @@ async def setup_checkpointer_tables() -> None:
|
||||||
async def close_checkpointer() -> None:
|
async def close_checkpointer() -> None:
|
||||||
"""
|
"""
|
||||||
Close the checkpointer connection.
|
Close the checkpointer connection.
|
||||||
|
|
||||||
This should be called during application shutdown.
|
This should be called during application shutdown.
|
||||||
"""
|
"""
|
||||||
global _checkpointer, _checkpointer_context, _checkpointer_initialized
|
global _checkpointer, _checkpointer_context, _checkpointer_initialized
|
||||||
|
|
||||||
if _checkpointer_context is not None:
|
if _checkpointer_context is not None:
|
||||||
await _checkpointer_context.__aexit__(None, None, None)
|
await _checkpointer_context.__aexit__(None, None, None)
|
||||||
_checkpointer = None
|
_checkpointer = None
|
||||||
_checkpointer_context = None
|
_checkpointer_context = None
|
||||||
_checkpointer_initialized = False
|
_checkpointer_initialized = False
|
||||||
print("[Checkpointer] PostgreSQL connection closed")
|
print("[Checkpointer] PostgreSQL connection closed")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -81,4 +81,4 @@ async def run_test():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(run_test())
|
asyncio.run(run_test())
|
||||||
|
|
|
||||||
|
|
@ -126,7 +126,9 @@ def create_generate_podcast_tool(
|
||||||
# Check if a podcast is already being generated for this search space
|
# Check if a podcast is already being generated for this search space
|
||||||
active_task_id = get_active_podcast_task(search_space_id)
|
active_task_id = get_active_podcast_task(search_space_id)
|
||||||
if active_task_id:
|
if active_task_id:
|
||||||
print(f"[generate_podcast] Blocked duplicate request. Active task: {active_task_id}")
|
print(
|
||||||
|
f"[generate_podcast] Blocked duplicate request. Active task: {active_task_id}"
|
||||||
|
)
|
||||||
return {
|
return {
|
||||||
"status": "already_generating",
|
"status": "already_generating",
|
||||||
"task_id": active_task_id,
|
"task_id": active_task_id,
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,10 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware
|
from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware
|
||||||
|
|
||||||
from app.agents.new_chat.checkpointer import close_checkpointer, setup_checkpointer_tables
|
from app.agents.new_chat.checkpointer import (
|
||||||
|
close_checkpointer,
|
||||||
|
setup_checkpointer_tables,
|
||||||
|
)
|
||||||
from app.config import config
|
from app.config import config
|
||||||
from app.db import User, create_db_and_tables, get_async_session
|
from app.db import User, create_db_and_tables, get_async_session
|
||||||
from app.routes import router as crud_router
|
from app.routes import router as crud_router
|
||||||
|
|
|
||||||
|
|
@ -332,6 +332,75 @@ class Chat(BaseModel, TimestampMixin):
|
||||||
search_space = relationship("SearchSpace", back_populates="chats")
|
search_space = relationship("SearchSpace", back_populates="chats")
|
||||||
|
|
||||||
|
|
||||||
|
class NewChatMessageRole(str, Enum):
|
||||||
|
"""Role enum for new chat messages."""
|
||||||
|
|
||||||
|
USER = "user"
|
||||||
|
ASSISTANT = "assistant"
|
||||||
|
SYSTEM = "system"
|
||||||
|
|
||||||
|
|
||||||
|
class NewChatThread(BaseModel, TimestampMixin):
|
||||||
|
"""
|
||||||
|
Thread model for the new chat feature using assistant-ui.
|
||||||
|
Each thread represents a conversation with message history.
|
||||||
|
LangGraph checkpointer uses thread_id for state persistence.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "new_chat_threads"
|
||||||
|
|
||||||
|
title = Column(String(500), nullable=False, default="New Chat", index=True)
|
||||||
|
archived = Column(Boolean, nullable=False, default=False)
|
||||||
|
updated_at = Column(
|
||||||
|
TIMESTAMP(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
default=lambda: datetime.now(UTC),
|
||||||
|
onupdate=lambda: datetime.now(UTC),
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Foreign keys
|
||||||
|
search_space_id = Column(
|
||||||
|
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||||||
|
)
|
||||||
|
user_id = Column(
|
||||||
|
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
search_space = relationship("SearchSpace", back_populates="new_chat_threads")
|
||||||
|
messages = relationship(
|
||||||
|
"NewChatMessage",
|
||||||
|
back_populates="thread",
|
||||||
|
order_by="NewChatMessage.created_at",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class NewChatMessage(BaseModel, TimestampMixin):
|
||||||
|
"""
|
||||||
|
Message model for the new chat feature.
|
||||||
|
Stores individual messages in assistant-ui format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "new_chat_messages"
|
||||||
|
|
||||||
|
role = Column(SQLAlchemyEnum(NewChatMessageRole), nullable=False)
|
||||||
|
# Content stored as JSONB to support rich content (text, tool calls, etc.)
|
||||||
|
content = Column(JSONB, nullable=False)
|
||||||
|
|
||||||
|
# Foreign key to thread
|
||||||
|
thread_id = Column(
|
||||||
|
Integer,
|
||||||
|
ForeignKey("new_chat_threads.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Relationship
|
||||||
|
thread = relationship("NewChatThread", back_populates="messages")
|
||||||
|
|
||||||
|
|
||||||
class Document(BaseModel, TimestampMixin):
|
class Document(BaseModel, TimestampMixin):
|
||||||
__tablename__ = "documents"
|
__tablename__ = "documents"
|
||||||
|
|
||||||
|
|
@ -435,6 +504,12 @@ class SearchSpace(BaseModel, TimestampMixin):
|
||||||
order_by="Chat.id",
|
order_by="Chat.id",
|
||||||
cascade="all, delete-orphan",
|
cascade="all, delete-orphan",
|
||||||
)
|
)
|
||||||
|
new_chat_threads = relationship(
|
||||||
|
"NewChatThread",
|
||||||
|
back_populates="search_space",
|
||||||
|
order_by="NewChatThread.updated_at.desc()",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
)
|
||||||
logs = relationship(
|
logs = relationship(
|
||||||
"Log",
|
"Log",
|
||||||
back_populates="search_space",
|
back_populates="search_space",
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,7 @@ from .google_gmail_add_connector_route import (
|
||||||
from .llm_config_routes import router as llm_config_router
|
from .llm_config_routes import router as llm_config_router
|
||||||
from .logs_routes import router as logs_router
|
from .logs_routes import router as logs_router
|
||||||
from .luma_add_connector_route import router as luma_add_connector_router
|
from .luma_add_connector_route import router as luma_add_connector_router
|
||||||
|
from .new_chat_routes import router as new_chat_router
|
||||||
from .notes_routes import router as notes_router
|
from .notes_routes import router as notes_router
|
||||||
from .podcasts_routes import router as podcasts_router
|
from .podcasts_routes import router as podcasts_router
|
||||||
from .rbac_routes import router as rbac_router
|
from .rbac_routes import router as rbac_router
|
||||||
|
|
@ -30,6 +31,7 @@ router.include_router(documents_router)
|
||||||
router.include_router(notes_router)
|
router.include_router(notes_router)
|
||||||
router.include_router(podcasts_router)
|
router.include_router(podcasts_router)
|
||||||
router.include_router(chats_router)
|
router.include_router(chats_router)
|
||||||
|
router.include_router(new_chat_router) # New chat with assistant-ui persistence
|
||||||
router.include_router(search_source_connectors_router)
|
router.include_router(search_source_connectors_router)
|
||||||
router.include_router(google_calendar_add_connector_router)
|
router.include_router(google_calendar_add_connector_router)
|
||||||
router.include_router(google_gmail_add_connector_router)
|
router.include_router(google_gmail_add_connector_router)
|
||||||
|
|
|
||||||
597
surfsense_backend/app/routes/new_chat_routes.py
Normal file
597
surfsense_backend/app/routes/new_chat_routes.py
Normal file
|
|
@ -0,0 +1,597 @@
|
||||||
|
"""
|
||||||
|
Routes for the new chat feature with assistant-ui integration.
|
||||||
|
|
||||||
|
These endpoints support the ThreadHistoryAdapter pattern from assistant-ui:
|
||||||
|
- GET /threads - List threads for sidebar (ThreadListPrimitive)
|
||||||
|
- POST /threads - Create a new thread
|
||||||
|
- GET /threads/{thread_id} - Get thread with messages (load)
|
||||||
|
- PUT /threads/{thread_id} - Update thread (rename, archive)
|
||||||
|
- DELETE /threads/{thread_id} - Delete thread
|
||||||
|
- POST /threads/{thread_id}/messages - Append message
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
|
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.future import select
|
||||||
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
|
from app.db import (
|
||||||
|
NewChatMessage,
|
||||||
|
NewChatMessageRole,
|
||||||
|
NewChatThread,
|
||||||
|
Permission,
|
||||||
|
User,
|
||||||
|
get_async_session,
|
||||||
|
)
|
||||||
|
from app.schemas.new_chat import (
|
||||||
|
NewChatMessageAppend,
|
||||||
|
NewChatMessageRead,
|
||||||
|
NewChatThreadCreate,
|
||||||
|
NewChatThreadRead,
|
||||||
|
NewChatThreadUpdate,
|
||||||
|
NewChatThreadWithMessages,
|
||||||
|
ThreadHistoryLoadResponse,
|
||||||
|
ThreadListItem,
|
||||||
|
ThreadListResponse,
|
||||||
|
)
|
||||||
|
from app.users import current_active_user
|
||||||
|
from app.utils.rbac import check_permission
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Thread Endpoints
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/threads", response_model=ThreadListResponse)
|
||||||
|
async def list_threads(
|
||||||
|
search_space_id: int,
|
||||||
|
session: AsyncSession = Depends(get_async_session),
|
||||||
|
user: User = Depends(current_active_user),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
List all threads for the current user in a search space.
|
||||||
|
Returns threads and archived_threads for ThreadListPrimitive.
|
||||||
|
|
||||||
|
Requires CHATS_READ permission.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
await check_permission(
|
||||||
|
session,
|
||||||
|
user,
|
||||||
|
search_space_id,
|
||||||
|
Permission.CHATS_READ.value,
|
||||||
|
"You don't have permission to read chats in this search space",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get all threads for this user in this search space
|
||||||
|
query = (
|
||||||
|
select(NewChatThread)
|
||||||
|
.filter(
|
||||||
|
NewChatThread.search_space_id == search_space_id,
|
||||||
|
NewChatThread.user_id == user.id,
|
||||||
|
)
|
||||||
|
.order_by(NewChatThread.updated_at.desc())
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await session.execute(query)
|
||||||
|
all_threads = result.scalars().all()
|
||||||
|
|
||||||
|
# Separate active and archived threads
|
||||||
|
threads = []
|
||||||
|
archived_threads = []
|
||||||
|
|
||||||
|
for thread in all_threads:
|
||||||
|
item = ThreadListItem(
|
||||||
|
id=thread.id,
|
||||||
|
title=thread.title,
|
||||||
|
archived=thread.archived,
|
||||||
|
createdAt=thread.created_at,
|
||||||
|
updatedAt=thread.updated_at,
|
||||||
|
)
|
||||||
|
if thread.archived:
|
||||||
|
archived_threads.append(item)
|
||||||
|
else:
|
||||||
|
threads.append(item)
|
||||||
|
|
||||||
|
return ThreadListResponse(threads=threads, archived_threads=archived_threads)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except OperationalError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=503, detail="Database operation failed. Please try again later."
|
||||||
|
) from None
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"An unexpected error occurred while fetching threads: {e!s}",
|
||||||
|
) from None
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/threads", response_model=NewChatThreadRead)
|
||||||
|
async def create_thread(
|
||||||
|
thread: NewChatThreadCreate,
|
||||||
|
session: AsyncSession = Depends(get_async_session),
|
||||||
|
user: User = Depends(current_active_user),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create a new chat thread.
|
||||||
|
|
||||||
|
Requires CHATS_CREATE permission.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
await check_permission(
|
||||||
|
session,
|
||||||
|
user,
|
||||||
|
thread.search_space_id,
|
||||||
|
Permission.CHATS_CREATE.value,
|
||||||
|
"You don't have permission to create chats in this search space",
|
||||||
|
)
|
||||||
|
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
db_thread = NewChatThread(
|
||||||
|
title=thread.title,
|
||||||
|
archived=thread.archived,
|
||||||
|
search_space_id=thread.search_space_id,
|
||||||
|
user_id=user.id,
|
||||||
|
updated_at=now,
|
||||||
|
)
|
||||||
|
session.add(db_thread)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(db_thread)
|
||||||
|
return db_thread
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except IntegrityError:
|
||||||
|
await session.rollback()
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="Database constraint violation. Please check your input data.",
|
||||||
|
) from None
|
||||||
|
except OperationalError:
|
||||||
|
await session.rollback()
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=503, detail="Database operation failed. Please try again later."
|
||||||
|
) from None
|
||||||
|
except Exception as e:
|
||||||
|
await session.rollback()
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"An unexpected error occurred while creating the thread: {e!s}",
|
||||||
|
) from None
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/threads/{thread_id}", response_model=ThreadHistoryLoadResponse)
|
||||||
|
async def get_thread_messages(
|
||||||
|
thread_id: int,
|
||||||
|
session: AsyncSession = Depends(get_async_session),
|
||||||
|
user: User = Depends(current_active_user),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get a thread with all its messages.
|
||||||
|
This is used by ThreadHistoryAdapter.load() to restore conversation.
|
||||||
|
|
||||||
|
Requires CHATS_READ permission.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Get thread with messages
|
||||||
|
result = await session.execute(
|
||||||
|
select(NewChatThread)
|
||||||
|
.options(selectinload(NewChatThread.messages))
|
||||||
|
.filter(NewChatThread.id == thread_id)
|
||||||
|
)
|
||||||
|
thread = result.scalars().first()
|
||||||
|
|
||||||
|
if not thread:
|
||||||
|
raise HTTPException(status_code=404, detail="Thread not found")
|
||||||
|
|
||||||
|
# Check permission and ownership
|
||||||
|
await check_permission(
|
||||||
|
session,
|
||||||
|
user,
|
||||||
|
thread.search_space_id,
|
||||||
|
Permission.CHATS_READ.value,
|
||||||
|
"You don't have permission to read chats in this search space",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure user owns this thread
|
||||||
|
if thread.user_id != user.id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403, detail="You don't have access to this thread"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return messages in the format expected by assistant-ui
|
||||||
|
messages = [
|
||||||
|
NewChatMessageRead(
|
||||||
|
id=msg.id,
|
||||||
|
thread_id=msg.thread_id,
|
||||||
|
role=msg.role,
|
||||||
|
content=msg.content,
|
||||||
|
created_at=msg.created_at,
|
||||||
|
)
|
||||||
|
for msg in thread.messages
|
||||||
|
]
|
||||||
|
|
||||||
|
return ThreadHistoryLoadResponse(messages=messages)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except OperationalError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=503, detail="Database operation failed. Please try again later."
|
||||||
|
) from None
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"An unexpected error occurred while fetching the thread: {e!s}",
|
||||||
|
) from None
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/threads/{thread_id}/full", response_model=NewChatThreadWithMessages)
|
||||||
|
async def get_thread_full(
|
||||||
|
thread_id: int,
|
||||||
|
session: AsyncSession = Depends(get_async_session),
|
||||||
|
user: User = Depends(current_active_user),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get full thread details with all messages.
|
||||||
|
|
||||||
|
Requires CHATS_READ permission.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = await session.execute(
|
||||||
|
select(NewChatThread)
|
||||||
|
.options(selectinload(NewChatThread.messages))
|
||||||
|
.filter(NewChatThread.id == thread_id)
|
||||||
|
)
|
||||||
|
thread = result.scalars().first()
|
||||||
|
|
||||||
|
if not thread:
|
||||||
|
raise HTTPException(status_code=404, detail="Thread not found")
|
||||||
|
|
||||||
|
await check_permission(
|
||||||
|
session,
|
||||||
|
user,
|
||||||
|
thread.search_space_id,
|
||||||
|
Permission.CHATS_READ.value,
|
||||||
|
"You don't have permission to read chats in this search space",
|
||||||
|
)
|
||||||
|
|
||||||
|
if thread.user_id != user.id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403, detail="You don't have access to this thread"
|
||||||
|
)
|
||||||
|
|
||||||
|
return thread
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except OperationalError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=503, detail="Database operation failed. Please try again later."
|
||||||
|
) from None
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"An unexpected error occurred while fetching the thread: {e!s}",
|
||||||
|
) from None
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/threads/{thread_id}", response_model=NewChatThreadRead)
|
||||||
|
async def update_thread(
|
||||||
|
thread_id: int,
|
||||||
|
thread_update: NewChatThreadUpdate,
|
||||||
|
session: AsyncSession = Depends(get_async_session),
|
||||||
|
user: User = Depends(current_active_user),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Update a thread (title, archived status).
|
||||||
|
Used for renaming and archiving threads.
|
||||||
|
|
||||||
|
Requires CHATS_UPDATE permission.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = await session.execute(
|
||||||
|
select(NewChatThread).filter(NewChatThread.id == thread_id)
|
||||||
|
)
|
||||||
|
db_thread = result.scalars().first()
|
||||||
|
|
||||||
|
if not db_thread:
|
||||||
|
raise HTTPException(status_code=404, detail="Thread not found")
|
||||||
|
|
||||||
|
await check_permission(
|
||||||
|
session,
|
||||||
|
user,
|
||||||
|
db_thread.search_space_id,
|
||||||
|
Permission.CHATS_UPDATE.value,
|
||||||
|
"You don't have permission to update chats in this search space",
|
||||||
|
)
|
||||||
|
|
||||||
|
if db_thread.user_id != user.id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403, detail="You don't have access to this thread"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update fields
|
||||||
|
update_data = thread_update.model_dump(exclude_unset=True)
|
||||||
|
for key, value in update_data.items():
|
||||||
|
setattr(db_thread, key, value)
|
||||||
|
|
||||||
|
db_thread.updated_at = datetime.now(UTC)
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(db_thread)
|
||||||
|
return db_thread
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except IntegrityError:
|
||||||
|
await session.rollback()
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="Database constraint violation. Please check your input data.",
|
||||||
|
) from None
|
||||||
|
except OperationalError:
|
||||||
|
await session.rollback()
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=503, detail="Database operation failed. Please try again later."
|
||||||
|
) from None
|
||||||
|
except Exception as e:
|
||||||
|
await session.rollback()
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"An unexpected error occurred while updating the thread: {e!s}",
|
||||||
|
) from None
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/threads/{thread_id}", response_model=dict)
|
||||||
|
async def delete_thread(
|
||||||
|
thread_id: int,
|
||||||
|
session: AsyncSession = Depends(get_async_session),
|
||||||
|
user: User = Depends(current_active_user),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Delete a thread and all its messages.
|
||||||
|
|
||||||
|
Requires CHATS_DELETE permission.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = await session.execute(
|
||||||
|
select(NewChatThread).filter(NewChatThread.id == thread_id)
|
||||||
|
)
|
||||||
|
db_thread = result.scalars().first()
|
||||||
|
|
||||||
|
if not db_thread:
|
||||||
|
raise HTTPException(status_code=404, detail="Thread not found")
|
||||||
|
|
||||||
|
await check_permission(
|
||||||
|
session,
|
||||||
|
user,
|
||||||
|
db_thread.search_space_id,
|
||||||
|
Permission.CHATS_DELETE.value,
|
||||||
|
"You don't have permission to delete chats in this search space",
|
||||||
|
)
|
||||||
|
|
||||||
|
if db_thread.user_id != user.id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403, detail="You don't have access to this thread"
|
||||||
|
)
|
||||||
|
|
||||||
|
await session.delete(db_thread)
|
||||||
|
await session.commit()
|
||||||
|
return {"message": "Thread deleted successfully"}
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except IntegrityError:
|
||||||
|
await session.rollback()
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400, detail="Cannot delete thread due to existing dependencies."
|
||||||
|
) from None
|
||||||
|
except OperationalError:
|
||||||
|
await session.rollback()
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=503, detail="Database operation failed. Please try again later."
|
||||||
|
) from None
|
||||||
|
except Exception as e:
|
||||||
|
await session.rollback()
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"An unexpected error occurred while deleting the thread: {e!s}",
|
||||||
|
) from None
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Message Endpoints
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/threads/{thread_id}/messages", response_model=NewChatMessageRead)
|
||||||
|
async def append_message(
|
||||||
|
thread_id: int,
|
||||||
|
request: Request,
|
||||||
|
session: AsyncSession = Depends(get_async_session),
|
||||||
|
user: User = Depends(current_active_user),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Append a message to a thread.
|
||||||
|
This is used by ThreadHistoryAdapter.append() to persist messages.
|
||||||
|
|
||||||
|
Requires CHATS_UPDATE permission.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Parse raw body - extract only role and content, ignoring extra fields
|
||||||
|
raw_body = await request.json()
|
||||||
|
role = raw_body.get("role")
|
||||||
|
content = raw_body.get("content")
|
||||||
|
|
||||||
|
if not role:
|
||||||
|
raise HTTPException(status_code=400, detail="Missing required field: role")
|
||||||
|
if content is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400, detail="Missing required field: content"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create message object manually
|
||||||
|
message = NewChatMessageAppend(role=role, content=content)
|
||||||
|
# Get thread
|
||||||
|
result = await session.execute(
|
||||||
|
select(NewChatThread).filter(NewChatThread.id == thread_id)
|
||||||
|
)
|
||||||
|
thread = result.scalars().first()
|
||||||
|
|
||||||
|
if not thread:
|
||||||
|
raise HTTPException(status_code=404, detail="Thread not found")
|
||||||
|
|
||||||
|
await check_permission(
|
||||||
|
session,
|
||||||
|
user,
|
||||||
|
thread.search_space_id,
|
||||||
|
Permission.CHATS_UPDATE.value,
|
||||||
|
"You don't have permission to update chats in this search space",
|
||||||
|
)
|
||||||
|
|
||||||
|
if thread.user_id != user.id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403, detail="You don't have access to this thread"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert string role to enum
|
||||||
|
role_str = (
|
||||||
|
message.role.lower() if isinstance(message.role, str) else message.role
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
message_role = NewChatMessageRole(role_str)
|
||||||
|
except ValueError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"Invalid role: {message.role}. Must be 'user', 'assistant', or 'system'.",
|
||||||
|
) from None
|
||||||
|
|
||||||
|
# Create message
|
||||||
|
db_message = NewChatMessage(
|
||||||
|
thread_id=thread_id,
|
||||||
|
role=message_role,
|
||||||
|
content=message.content,
|
||||||
|
)
|
||||||
|
session.add(db_message)
|
||||||
|
|
||||||
|
# Update thread's updated_at timestamp
|
||||||
|
thread.updated_at = datetime.now(UTC)
|
||||||
|
|
||||||
|
# Auto-generate title from first user message if title is still default
|
||||||
|
if thread.title == "New Chat" and role_str == "user":
|
||||||
|
# Extract text content for title
|
||||||
|
content = message.content
|
||||||
|
if isinstance(content, str):
|
||||||
|
title_text = content
|
||||||
|
elif isinstance(content, list):
|
||||||
|
# Find first text content
|
||||||
|
title_text = ""
|
||||||
|
for part in content:
|
||||||
|
if isinstance(part, dict) and part.get("type") == "text":
|
||||||
|
title_text = part.get("text", "")
|
||||||
|
break
|
||||||
|
elif isinstance(part, str):
|
||||||
|
title_text = part
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
title_text = str(content)
|
||||||
|
|
||||||
|
# Truncate title
|
||||||
|
if title_text:
|
||||||
|
thread.title = title_text[:100] + (
|
||||||
|
"..." if len(title_text) > 100 else ""
|
||||||
|
)
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(db_message)
|
||||||
|
return db_message
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except IntegrityError:
|
||||||
|
await session.rollback()
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="Database constraint violation. Please check your input data.",
|
||||||
|
) from None
|
||||||
|
except OperationalError:
|
||||||
|
await session.rollback()
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=503, detail="Database operation failed. Please try again later."
|
||||||
|
) from None
|
||||||
|
except Exception as e:
|
||||||
|
await session.rollback()
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"An unexpected error occurred while appending the message: {e!s}",
|
||||||
|
) from None
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/threads/{thread_id}/messages", response_model=list[NewChatMessageRead])
|
||||||
|
async def list_messages(
|
||||||
|
thread_id: int,
|
||||||
|
skip: int = 0,
|
||||||
|
limit: int = 100,
|
||||||
|
session: AsyncSession = Depends(get_async_session),
|
||||||
|
user: User = Depends(current_active_user),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
List messages in a thread with pagination.
|
||||||
|
|
||||||
|
Requires CHATS_READ permission.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Verify thread exists and user has access
|
||||||
|
result = await session.execute(
|
||||||
|
select(NewChatThread).filter(NewChatThread.id == thread_id)
|
||||||
|
)
|
||||||
|
thread = result.scalars().first()
|
||||||
|
|
||||||
|
if not thread:
|
||||||
|
raise HTTPException(status_code=404, detail="Thread not found")
|
||||||
|
|
||||||
|
await check_permission(
|
||||||
|
session,
|
||||||
|
user,
|
||||||
|
thread.search_space_id,
|
||||||
|
Permission.CHATS_READ.value,
|
||||||
|
"You don't have permission to read chats in this search space",
|
||||||
|
)
|
||||||
|
|
||||||
|
if thread.user_id != user.id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403, detail="You don't have access to this thread"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get messages
|
||||||
|
query = (
|
||||||
|
select(NewChatMessage)
|
||||||
|
.filter(NewChatMessage.thread_id == thread_id)
|
||||||
|
.order_by(NewChatMessage.created_at)
|
||||||
|
.offset(skip)
|
||||||
|
.limit(limit)
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await session.execute(query)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except OperationalError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=503, detail="Database operation failed. Please try again later."
|
||||||
|
) from None
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"An unexpected error occurred while fetching messages: {e!s}",
|
||||||
|
) from None
|
||||||
|
|
@ -21,6 +21,18 @@ from .documents import (
|
||||||
)
|
)
|
||||||
from .llm_config import LLMConfigBase, LLMConfigCreate, LLMConfigRead, LLMConfigUpdate
|
from .llm_config import LLMConfigBase, LLMConfigCreate, LLMConfigRead, LLMConfigUpdate
|
||||||
from .logs import LogBase, LogCreate, LogFilter, LogRead, LogUpdate
|
from .logs import LogBase, LogCreate, LogFilter, LogRead, LogUpdate
|
||||||
|
from .new_chat import (
|
||||||
|
NewChatMessageAppend,
|
||||||
|
NewChatMessageCreate,
|
||||||
|
NewChatMessageRead,
|
||||||
|
NewChatThreadCreate,
|
||||||
|
NewChatThreadRead,
|
||||||
|
NewChatThreadUpdate,
|
||||||
|
NewChatThreadWithMessages,
|
||||||
|
ThreadHistoryLoadResponse,
|
||||||
|
ThreadListItem,
|
||||||
|
ThreadListResponse,
|
||||||
|
)
|
||||||
from .podcasts import (
|
from .podcasts import (
|
||||||
PodcastBase,
|
PodcastBase,
|
||||||
PodcastCreate,
|
PodcastCreate,
|
||||||
|
|
@ -98,7 +110,15 @@ __all__ = [
|
||||||
"MembershipRead",
|
"MembershipRead",
|
||||||
"MembershipReadWithUser",
|
"MembershipReadWithUser",
|
||||||
"MembershipUpdate",
|
"MembershipUpdate",
|
||||||
|
# New chat schemas (assistant-ui integration)
|
||||||
|
"NewChatMessageAppend",
|
||||||
|
"NewChatMessageCreate",
|
||||||
|
"NewChatMessageRead",
|
||||||
"NewChatRequest",
|
"NewChatRequest",
|
||||||
|
"NewChatThreadCreate",
|
||||||
|
"NewChatThreadRead",
|
||||||
|
"NewChatThreadUpdate",
|
||||||
|
"NewChatThreadWithMessages",
|
||||||
"PaginatedResponse",
|
"PaginatedResponse",
|
||||||
"PermissionInfo",
|
"PermissionInfo",
|
||||||
"PermissionsListResponse",
|
"PermissionsListResponse",
|
||||||
|
|
@ -119,6 +139,9 @@ __all__ = [
|
||||||
"SearchSpaceRead",
|
"SearchSpaceRead",
|
||||||
"SearchSpaceUpdate",
|
"SearchSpaceUpdate",
|
||||||
"SearchSpaceWithStats",
|
"SearchSpaceWithStats",
|
||||||
|
"ThreadHistoryLoadResponse",
|
||||||
|
"ThreadListItem",
|
||||||
|
"ThreadListResponse",
|
||||||
"TimestampModel",
|
"TimestampModel",
|
||||||
"UserCreate",
|
"UserCreate",
|
||||||
"UserRead",
|
"UserRead",
|
||||||
|
|
|
||||||
129
surfsense_backend/app/schemas/new_chat.py
Normal file
129
surfsense_backend/app/schemas/new_chat.py
Normal file
|
|
@ -0,0 +1,129 @@
|
||||||
|
"""
|
||||||
|
Pydantic schemas for the new chat feature with assistant-ui integration.
|
||||||
|
|
||||||
|
These schemas follow the assistant-ui ThreadHistoryAdapter pattern:
|
||||||
|
- ThreadRecord: id, title, archived, createdAt, updatedAt
|
||||||
|
- MessageRecord: id, threadId, role, content, createdAt
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
|
from app.db import NewChatMessageRole
|
||||||
|
|
||||||
|
from .base import IDModel, TimestampModel
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Message Schemas
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class NewChatMessageBase(BaseModel):
|
||||||
|
"""Base schema for new chat messages."""
|
||||||
|
|
||||||
|
role: NewChatMessageRole
|
||||||
|
content: Any # JSONB content - can be text, tool calls, etc.
|
||||||
|
|
||||||
|
|
||||||
|
class NewChatMessageCreate(NewChatMessageBase):
|
||||||
|
"""Schema for creating a new message."""
|
||||||
|
|
||||||
|
thread_id: int
|
||||||
|
|
||||||
|
|
||||||
|
class NewChatMessageRead(NewChatMessageBase, IDModel, TimestampModel):
|
||||||
|
"""Schema for reading a message."""
|
||||||
|
|
||||||
|
thread_id: int
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
|
||||||
|
class NewChatMessageAppend(BaseModel):
|
||||||
|
"""
|
||||||
|
Schema for appending a message via the history adapter.
|
||||||
|
This is the format assistant-ui sends when calling append().
|
||||||
|
"""
|
||||||
|
|
||||||
|
role: str # Accept string and validate in route handler
|
||||||
|
content: Any
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Thread Schemas
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class NewChatThreadBase(BaseModel):
|
||||||
|
"""Base schema for new chat threads."""
|
||||||
|
|
||||||
|
title: str = Field(default="New Chat", max_length=500)
|
||||||
|
archived: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class NewChatThreadCreate(NewChatThreadBase):
|
||||||
|
"""Schema for creating a new thread."""
|
||||||
|
|
||||||
|
search_space_id: int
|
||||||
|
|
||||||
|
|
||||||
|
class NewChatThreadUpdate(BaseModel):
|
||||||
|
"""Schema for updating a thread."""
|
||||||
|
|
||||||
|
title: str | None = None
|
||||||
|
archived: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class NewChatThreadRead(NewChatThreadBase, IDModel):
|
||||||
|
"""
|
||||||
|
Schema for reading a thread (matches assistant-ui ThreadRecord).
|
||||||
|
"""
|
||||||
|
|
||||||
|
search_space_id: int
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime
|
||||||
|
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
|
||||||
|
class NewChatThreadWithMessages(NewChatThreadRead):
|
||||||
|
"""Schema for reading a thread with its messages."""
|
||||||
|
|
||||||
|
messages: list[NewChatMessageRead] = []
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# History Adapter Response Schemas
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadHistoryLoadResponse(BaseModel):
|
||||||
|
"""
|
||||||
|
Response format for the ThreadHistoryAdapter.load() method.
|
||||||
|
Returns messages array for the current thread.
|
||||||
|
"""
|
||||||
|
|
||||||
|
messages: list[NewChatMessageRead]
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadListItem(BaseModel):
|
||||||
|
"""
|
||||||
|
Thread list item for sidebar display.
|
||||||
|
Matches assistant-ui ThreadListPrimitive expected format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
id: int
|
||||||
|
title: str
|
||||||
|
archived: bool
|
||||||
|
created_at: datetime = Field(alias="createdAt")
|
||||||
|
updated_at: datetime = Field(alias="updatedAt")
|
||||||
|
|
||||||
|
model_config = ConfigDict(from_attributes=True, populate_by_name=True)
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadListResponse(BaseModel):
|
||||||
|
"""Response containing list of threads for the sidebar."""
|
||||||
|
|
||||||
|
threads: list[ThreadListItem]
|
||||||
|
archived_threads: list[ThreadListItem]
|
||||||
|
|
@ -7,14 +7,13 @@ import sys
|
||||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||||
from sqlalchemy.pool import NullPool
|
from sqlalchemy.pool import NullPool
|
||||||
|
|
||||||
from app.celery_app import celery_app
|
|
||||||
from app.config import config
|
|
||||||
from app.tasks.podcast_tasks import generate_chat_podcast
|
|
||||||
|
|
||||||
# Import for content-based podcast (new-chat)
|
# Import for content-based podcast (new-chat)
|
||||||
from app.agents.podcaster.graph import graph as podcaster_graph
|
from app.agents.podcaster.graph import graph as podcaster_graph
|
||||||
from app.agents.podcaster.state import State as PodcasterState
|
from app.agents.podcaster.state import State as PodcasterState
|
||||||
|
from app.celery_app import celery_app
|
||||||
|
from app.config import config
|
||||||
from app.db import Podcast
|
from app.db import Podcast
|
||||||
|
from app.tasks.podcast_tasks import generate_chat_podcast
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -201,15 +200,16 @@ async def _generate_content_podcast(
|
||||||
serializable_transcript = []
|
serializable_transcript = []
|
||||||
for entry in podcast_transcript:
|
for entry in podcast_transcript:
|
||||||
if hasattr(entry, "speaker_id"):
|
if hasattr(entry, "speaker_id"):
|
||||||
serializable_transcript.append({
|
serializable_transcript.append(
|
||||||
"speaker_id": entry.speaker_id,
|
{"speaker_id": entry.speaker_id, "dialog": entry.dialog}
|
||||||
"dialog": entry.dialog
|
)
|
||||||
})
|
|
||||||
else:
|
else:
|
||||||
serializable_transcript.append({
|
serializable_transcript.append(
|
||||||
"speaker_id": entry.get("speaker_id", 0),
|
{
|
||||||
"dialog": entry.get("dialog", "")
|
"speaker_id": entry.get("speaker_id", 0),
|
||||||
})
|
"dialog": entry.get("dialog", ""),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# Save podcast to database
|
# Save podcast to database
|
||||||
podcast = Podcast(
|
podcast = Podcast(
|
||||||
|
|
|
||||||
|
|
@ -9,12 +9,15 @@ import json
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from langchain_core.messages import AIMessage, HumanMessage
|
from langchain_core.messages import HumanMessage
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
|
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
|
||||||
from app.agents.new_chat.checkpointer import get_checkpointer
|
from app.agents.new_chat.checkpointer import get_checkpointer
|
||||||
from app.agents.new_chat.llm_config import create_chat_litellm_from_config, load_llm_config_from_yaml
|
from app.agents.new_chat.llm_config import (
|
||||||
|
create_chat_litellm_from_config,
|
||||||
|
load_llm_config_from_yaml,
|
||||||
|
)
|
||||||
from app.schemas.chats import ChatMessage
|
from app.schemas.chats import ChatMessage
|
||||||
from app.services.connector_service import ConnectorService
|
from app.services.connector_service import ConnectorService
|
||||||
from app.services.new_streaming_service import VercelStreamingService
|
from app.services.new_streaming_service import VercelStreamingService
|
||||||
|
|
@ -92,7 +95,7 @@ async def stream_new_chat(
|
||||||
|
|
||||||
# Build input with message history from frontend
|
# Build input with message history from frontend
|
||||||
langchain_messages = []
|
langchain_messages = []
|
||||||
|
|
||||||
# if messages:
|
# if messages:
|
||||||
# # Convert frontend messages to LangChain format
|
# # Convert frontend messages to LangChain format
|
||||||
# for msg in messages:
|
# for msg in messages:
|
||||||
|
|
@ -101,9 +104,9 @@ async def stream_new_chat(
|
||||||
# elif msg.role == "assistant":
|
# elif msg.role == "assistant":
|
||||||
# langchain_messages.append(AIMessage(content=msg.content))
|
# langchain_messages.append(AIMessage(content=msg.content))
|
||||||
# else:
|
# else:
|
||||||
# Fallback: just use the current user query
|
# Fallback: just use the current user query
|
||||||
langchain_messages.append(HumanMessage(content=user_query))
|
langchain_messages.append(HumanMessage(content=user_query))
|
||||||
|
|
||||||
input_state = {
|
input_state = {
|
||||||
# Lets not pass this message atm because we are using the checkpointer to manage the conversation history
|
# Lets not pass this message atm because we are using the checkpointer to manage the conversation history
|
||||||
# We will use this to simulate group chat functionality in the future
|
# We will use this to simulate group chat functionality in the future
|
||||||
|
|
@ -219,7 +222,9 @@ async def stream_new_chat(
|
||||||
elif isinstance(raw_output, dict):
|
elif isinstance(raw_output, dict):
|
||||||
tool_output = raw_output
|
tool_output = raw_output
|
||||||
else:
|
else:
|
||||||
tool_output = {"result": str(raw_output) if raw_output else "completed"}
|
tool_output = {
|
||||||
|
"result": str(raw_output) if raw_output else "completed"
|
||||||
|
}
|
||||||
|
|
||||||
tool_call_id = f"call_{run_id[:32]}" if run_id else "call_unknown"
|
tool_call_id = f"call_{run_id[:32]}" if run_id else "call_unknown"
|
||||||
|
|
||||||
|
|
@ -228,16 +233,25 @@ async def stream_new_chat(
|
||||||
# Stream the full podcast result so frontend can render the audio player
|
# Stream the full podcast result so frontend can render the audio player
|
||||||
yield streaming_service.format_tool_output_available(
|
yield streaming_service.format_tool_output_available(
|
||||||
tool_call_id,
|
tool_call_id,
|
||||||
tool_output if isinstance(tool_output, dict) else {"result": tool_output},
|
tool_output
|
||||||
|
if isinstance(tool_output, dict)
|
||||||
|
else {"result": tool_output},
|
||||||
)
|
)
|
||||||
# Send appropriate terminal message based on status
|
# Send appropriate terminal message based on status
|
||||||
if isinstance(tool_output, dict) and tool_output.get("status") == "success":
|
if (
|
||||||
|
isinstance(tool_output, dict)
|
||||||
|
and tool_output.get("status") == "success"
|
||||||
|
):
|
||||||
yield streaming_service.format_terminal_info(
|
yield streaming_service.format_terminal_info(
|
||||||
f"Podcast generated successfully: {tool_output.get('title', 'Podcast')}",
|
f"Podcast generated successfully: {tool_output.get('title', 'Podcast')}",
|
||||||
"success",
|
"success",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
error_msg = tool_output.get("error", "Unknown error") if isinstance(tool_output, dict) else "Unknown error"
|
error_msg = (
|
||||||
|
tool_output.get("error", "Unknown error")
|
||||||
|
if isinstance(tool_output, dict)
|
||||||
|
else "Unknown error"
|
||||||
|
)
|
||||||
yield streaming_service.format_terminal_info(
|
yield streaming_service.format_terminal_info(
|
||||||
f"Podcast generation failed: {error_msg}",
|
f"Podcast generation failed: {error_msg}",
|
||||||
"error",
|
"error",
|
||||||
|
|
|
||||||
|
|
@ -47,6 +47,9 @@ export function DashboardClientLayout({
|
||||||
// Check if we're on the researcher page
|
// Check if we're on the researcher page
|
||||||
const isResearcherPage = pathname?.includes("/researcher");
|
const isResearcherPage = pathname?.includes("/researcher");
|
||||||
|
|
||||||
|
// Check if we're on the new-chat page (uses separate thread persistence)
|
||||||
|
const isNewChatPage = pathname?.includes("/new-chat");
|
||||||
|
|
||||||
// Show indicator when chat becomes active and panel is closed
|
// Show indicator when chat becomes active and panel is closed
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (activeChatId && !isChatPannelOpen) {
|
if (activeChatId && !isChatPannelOpen) {
|
||||||
|
|
@ -151,6 +154,12 @@ export function DashboardClientLayout({
|
||||||
}, [search_space_id]);
|
}, [search_space_id]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
|
// Skip setting activeChatIdAtom on new-chat page (uses separate thread persistence)
|
||||||
|
if (isNewChatPage) {
|
||||||
|
setActiveChatIdState(null);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
const activeChatId =
|
const activeChatId =
|
||||||
typeof chat_id === "string"
|
typeof chat_id === "string"
|
||||||
? chat_id
|
? chat_id
|
||||||
|
|
@ -159,7 +168,7 @@ export function DashboardClientLayout({
|
||||||
: "";
|
: "";
|
||||||
if (!activeChatId) return;
|
if (!activeChatId) return;
|
||||||
setActiveChatIdState(activeChatId);
|
setActiveChatIdState(activeChatId);
|
||||||
}, [chat_id, search_space_id]);
|
}, [chat_id, search_space_id, isNewChatPage]);
|
||||||
|
|
||||||
// Show loading screen while checking onboarding status (only on first load)
|
// Show loading screen while checking onboarding status (only on first load)
|
||||||
if (!hasCheckedOnboarding && (loading || accessLoading) && !isOnboardingPage) {
|
if (!hasCheckedOnboarding && (loading || accessLoading) && !isOnboardingPage) {
|
||||||
|
|
|
||||||
|
|
@ -1,23 +1,73 @@
|
||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { AssistantRuntimeProvider, useLocalRuntime } from "@assistant-ui/react";
|
import {
|
||||||
import { useParams } from "next/navigation";
|
AssistantRuntimeProvider,
|
||||||
import { useMemo } from "react";
|
useExternalStoreRuntime,
|
||||||
|
type ThreadMessageLike,
|
||||||
|
} from "@assistant-ui/react";
|
||||||
|
import { useParams, useRouter } from "next/navigation";
|
||||||
|
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||||
import { Thread } from "@/components/assistant-ui/thread";
|
import { Thread } from "@/components/assistant-ui/thread";
|
||||||
import { GeneratePodcastToolUI } from "@/components/tool-ui/generate-podcast";
|
import { GeneratePodcastToolUI } from "@/components/tool-ui/generate-podcast";
|
||||||
import { createNewChatAdapter } from "@/lib/chat/new-chat-transport";
|
import {
|
||||||
|
createThread,
|
||||||
|
getThreadMessages,
|
||||||
|
appendMessage,
|
||||||
|
type MessageRecord,
|
||||||
|
} from "@/lib/chat/thread-persistence";
|
||||||
|
import { getBearerToken } from "@/lib/auth-utils";
|
||||||
|
import { toast } from "sonner";
|
||||||
|
import {
|
||||||
|
isPodcastGenerating,
|
||||||
|
looksLikePodcastRequest,
|
||||||
|
setActivePodcastTaskId,
|
||||||
|
} from "@/lib/chat/podcast-state";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert backend message to assistant-ui ThreadMessageLike format
|
||||||
|
*/
|
||||||
|
function convertToThreadMessage(msg: MessageRecord): ThreadMessageLike {
|
||||||
|
let content: ThreadMessageLike["content"];
|
||||||
|
|
||||||
|
if (typeof msg.content === "string") {
|
||||||
|
content = [{ type: "text", text: msg.content }];
|
||||||
|
} else if (Array.isArray(msg.content)) {
|
||||||
|
content = msg.content as ThreadMessageLike["content"];
|
||||||
|
} else {
|
||||||
|
content = [{ type: "text", text: String(msg.content) }];
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
id: `msg-${msg.id}`,
|
||||||
|
role: msg.role,
|
||||||
|
content,
|
||||||
|
createdAt: new Date(msg.created_at),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Tools that should render custom UI in the chat.
|
||||||
|
*/
|
||||||
|
const TOOLS_WITH_UI = new Set(["generate_podcast"]);
|
||||||
|
|
||||||
export default function NewChatPage() {
|
export default function NewChatPage() {
|
||||||
const params = useParams();
|
const params = useParams();
|
||||||
|
const router = useRouter();
|
||||||
|
const [isInitializing, setIsInitializing] = useState(true);
|
||||||
|
const [threadId, setThreadId] = useState<number | null>(null);
|
||||||
|
const [messages, setMessages] = useState<ThreadMessageLike[]>([]);
|
||||||
|
const [isRunning, setIsRunning] = useState(false);
|
||||||
|
const abortControllerRef = useRef<AbortController | null>(null);
|
||||||
|
|
||||||
// Extract search_space_id and chat_id from URL params
|
// Extract search_space_id from URL params
|
||||||
const searchSpaceId = useMemo(() => {
|
const searchSpaceId = useMemo(() => {
|
||||||
const id = params.search_space_id;
|
const id = params.search_space_id;
|
||||||
const parsed = typeof id === "string" ? Number.parseInt(id, 10) : 0;
|
const parsed = typeof id === "string" ? Number.parseInt(id, 10) : 0;
|
||||||
return Number.isNaN(parsed) ? 0 : parsed;
|
return Number.isNaN(parsed) ? 0 : parsed;
|
||||||
}, [params.search_space_id]);
|
}, [params.search_space_id]);
|
||||||
|
|
||||||
const chatId = useMemo(() => {
|
// Extract chat_id from URL params
|
||||||
|
const urlChatId = useMemo(() => {
|
||||||
const id = params.chat_id;
|
const id = params.chat_id;
|
||||||
let parsed = 0;
|
let parsed = 0;
|
||||||
if (Array.isArray(id) && id.length > 0) {
|
if (Array.isArray(id) && id.length > 0) {
|
||||||
|
|
@ -28,18 +78,368 @@ export default function NewChatPage() {
|
||||||
return Number.isNaN(parsed) ? 0 : parsed;
|
return Number.isNaN(parsed) ? 0 : parsed;
|
||||||
}, [params.chat_id]);
|
}, [params.chat_id]);
|
||||||
|
|
||||||
// Create the adapter with the extracted params
|
// Initialize thread and load messages
|
||||||
const adapter = useMemo(
|
const initializeThread = useCallback(async () => {
|
||||||
() => createNewChatAdapter({ searchSpaceId, chatId }),
|
setIsInitializing(true);
|
||||||
[searchSpaceId, chatId]
|
|
||||||
|
try {
|
||||||
|
if (urlChatId > 0) {
|
||||||
|
// Thread exists - load messages
|
||||||
|
setThreadId(urlChatId);
|
||||||
|
const response = await getThreadMessages(urlChatId);
|
||||||
|
if (response.messages && response.messages.length > 0) {
|
||||||
|
const loadedMessages = response.messages.map(convertToThreadMessage);
|
||||||
|
setMessages(loadedMessages);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Create new thread
|
||||||
|
const newThread = await createThread(searchSpaceId, "New Chat");
|
||||||
|
setThreadId(newThread.id);
|
||||||
|
router.replace(`/dashboard/${searchSpaceId}/new-chat/${newThread.id}`);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error("[NewChatPage] Failed to initialize thread:", error);
|
||||||
|
setThreadId(Date.now());
|
||||||
|
} finally {
|
||||||
|
setIsInitializing(false);
|
||||||
|
}
|
||||||
|
}, [urlChatId, searchSpaceId, router]);
|
||||||
|
|
||||||
|
// Initialize on mount
|
||||||
|
useEffect(() => {
|
||||||
|
initializeThread();
|
||||||
|
}, [initializeThread]);
|
||||||
|
|
||||||
|
// Cancel ongoing request
|
||||||
|
const cancelRun = useCallback(async () => {
|
||||||
|
if (abortControllerRef.current) {
|
||||||
|
abortControllerRef.current.abort();
|
||||||
|
abortControllerRef.current = null;
|
||||||
|
}
|
||||||
|
setIsRunning(false);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
// Handle new message from user
|
||||||
|
const onNew = useCallback(
|
||||||
|
async (message: ThreadMessageLike) => {
|
||||||
|
if (!threadId) return;
|
||||||
|
|
||||||
|
// Extract user query text
|
||||||
|
let userQuery = "";
|
||||||
|
for (const part of message.content) {
|
||||||
|
if (typeof part === "object" && part.type === "text" && "text" in part) {
|
||||||
|
userQuery += part.text;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!userQuery.trim()) return;
|
||||||
|
|
||||||
|
// Check if podcast is already generating
|
||||||
|
if (isPodcastGenerating() && looksLikePodcastRequest(userQuery)) {
|
||||||
|
toast.warning("A podcast is already being generated.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const token = getBearerToken();
|
||||||
|
if (!token) {
|
||||||
|
toast.error("Not authenticated. Please log in again.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add user message to state
|
||||||
|
const userMsgId = `msg-user-${Date.now()}`;
|
||||||
|
const userMessage: ThreadMessageLike = {
|
||||||
|
id: userMsgId,
|
||||||
|
role: "user",
|
||||||
|
content: message.content,
|
||||||
|
createdAt: new Date(),
|
||||||
|
};
|
||||||
|
setMessages((prev) => [...prev, userMessage]);
|
||||||
|
|
||||||
|
// Persist user message (don't await, fire and forget)
|
||||||
|
appendMessage(threadId, {
|
||||||
|
role: "user",
|
||||||
|
content: message.content,
|
||||||
|
}).catch((err) => console.error("Failed to persist user message:", err));
|
||||||
|
|
||||||
|
// Start streaming response
|
||||||
|
setIsRunning(true);
|
||||||
|
const controller = new AbortController();
|
||||||
|
abortControllerRef.current = controller;
|
||||||
|
|
||||||
|
// Prepare assistant message
|
||||||
|
const assistantMsgId = `msg-assistant-${Date.now()}`;
|
||||||
|
let accumulatedText = "";
|
||||||
|
const toolCalls = new Map<
|
||||||
|
string,
|
||||||
|
{
|
||||||
|
toolCallId: string;
|
||||||
|
toolName: string;
|
||||||
|
args: Record<string, unknown>;
|
||||||
|
result?: unknown;
|
||||||
|
}
|
||||||
|
>();
|
||||||
|
|
||||||
|
// Helper to build content
|
||||||
|
const buildContent = (): ThreadMessageLike["content"] => {
|
||||||
|
const parts: Array<
|
||||||
|
| { type: "text"; text: string }
|
||||||
|
| {
|
||||||
|
type: "tool-call";
|
||||||
|
toolCallId: string;
|
||||||
|
toolName: string;
|
||||||
|
args: Record<string, unknown>;
|
||||||
|
result?: unknown;
|
||||||
|
}
|
||||||
|
> = [];
|
||||||
|
if (accumulatedText) {
|
||||||
|
parts.push({ type: "text", text: accumulatedText });
|
||||||
|
}
|
||||||
|
for (const toolCall of toolCalls.values()) {
|
||||||
|
if (TOOLS_WITH_UI.has(toolCall.toolName)) {
|
||||||
|
parts.push({
|
||||||
|
type: "tool-call",
|
||||||
|
toolCallId: toolCall.toolCallId,
|
||||||
|
toolName: toolCall.toolName,
|
||||||
|
args: toolCall.args,
|
||||||
|
result: toolCall.result,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return parts.length > 0
|
||||||
|
? (parts as ThreadMessageLike["content"])
|
||||||
|
: [{ type: "text", text: "" }];
|
||||||
|
};
|
||||||
|
|
||||||
|
// Add placeholder assistant message
|
||||||
|
setMessages((prev) => [
|
||||||
|
...prev,
|
||||||
|
{
|
||||||
|
id: assistantMsgId,
|
||||||
|
role: "assistant",
|
||||||
|
content: [{ type: "text", text: "" }],
|
||||||
|
createdAt: new Date(),
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
|
||||||
|
try {
|
||||||
|
const backendUrl =
|
||||||
|
process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
|
||||||
|
|
||||||
|
// Build message history for context
|
||||||
|
const messageHistory = messages
|
||||||
|
.filter((m) => m.role === "user" || m.role === "assistant")
|
||||||
|
.map((m) => {
|
||||||
|
let text = "";
|
||||||
|
for (const part of m.content) {
|
||||||
|
if (
|
||||||
|
typeof part === "object" &&
|
||||||
|
part.type === "text" &&
|
||||||
|
"text" in part
|
||||||
|
) {
|
||||||
|
text += part.text;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return { role: m.role, content: text };
|
||||||
|
})
|
||||||
|
.filter((m) => m.content.length > 0);
|
||||||
|
|
||||||
|
const response = await fetch(`${backendUrl}/api/v1/new_chat`, {
|
||||||
|
method: "POST",
|
||||||
|
headers: {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
Authorization: `Bearer ${token}`,
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
chat_id: threadId,
|
||||||
|
user_query: userQuery.trim(),
|
||||||
|
search_space_id: searchSpaceId,
|
||||||
|
messages: messageHistory,
|
||||||
|
}),
|
||||||
|
signal: controller.signal,
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(`Backend error: ${response.status}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!response.body) {
|
||||||
|
throw new Error("No response body");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse SSE stream
|
||||||
|
const reader = response.body.getReader();
|
||||||
|
const decoder = new TextDecoder();
|
||||||
|
let buffer = "";
|
||||||
|
|
||||||
|
try {
|
||||||
|
while (true) {
|
||||||
|
const { done, value } = await reader.read();
|
||||||
|
if (done) break;
|
||||||
|
|
||||||
|
buffer += decoder.decode(value, { stream: true });
|
||||||
|
const events = buffer.split(/\r?\n\r?\n/);
|
||||||
|
buffer = events.pop() || "";
|
||||||
|
|
||||||
|
for (const event of events) {
|
||||||
|
const lines = event.split(/\r?\n/);
|
||||||
|
for (const line of lines) {
|
||||||
|
if (!line.startsWith("data: ")) continue;
|
||||||
|
const data = line.slice(6).trim();
|
||||||
|
if (!data || data === "[DONE]") continue;
|
||||||
|
|
||||||
|
try {
|
||||||
|
const parsed = JSON.parse(data);
|
||||||
|
|
||||||
|
switch (parsed.type) {
|
||||||
|
case "text-delta":
|
||||||
|
accumulatedText += parsed.delta;
|
||||||
|
setMessages((prev) =>
|
||||||
|
prev.map((m) =>
|
||||||
|
m.id === assistantMsgId
|
||||||
|
? { ...m, content: buildContent() }
|
||||||
|
: m
|
||||||
|
)
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
|
||||||
|
case "tool-input-start":
|
||||||
|
toolCalls.set(parsed.toolCallId, {
|
||||||
|
toolCallId: parsed.toolCallId,
|
||||||
|
toolName: parsed.toolName,
|
||||||
|
args: {},
|
||||||
|
});
|
||||||
|
setMessages((prev) =>
|
||||||
|
prev.map((m) =>
|
||||||
|
m.id === assistantMsgId
|
||||||
|
? { ...m, content: buildContent() }
|
||||||
|
: m
|
||||||
|
)
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
|
||||||
|
case "tool-input-available": {
|
||||||
|
const tc = toolCalls.get(parsed.toolCallId);
|
||||||
|
if (tc) tc.args = parsed.input || {};
|
||||||
|
else
|
||||||
|
toolCalls.set(parsed.toolCallId, {
|
||||||
|
toolCallId: parsed.toolCallId,
|
||||||
|
toolName: parsed.toolName,
|
||||||
|
args: parsed.input || {},
|
||||||
|
});
|
||||||
|
setMessages((prev) =>
|
||||||
|
prev.map((m) =>
|
||||||
|
m.id === assistantMsgId
|
||||||
|
? { ...m, content: buildContent() }
|
||||||
|
: m
|
||||||
|
)
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case "tool-output-available": {
|
||||||
|
const tc = toolCalls.get(parsed.toolCallId);
|
||||||
|
if (tc) {
|
||||||
|
tc.result = parsed.output;
|
||||||
|
if (
|
||||||
|
tc.toolName === "generate_podcast" &&
|
||||||
|
parsed.output?.status === "processing" &&
|
||||||
|
parsed.output?.task_id
|
||||||
|
) {
|
||||||
|
setActivePodcastTaskId(parsed.output.task_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
setMessages((prev) =>
|
||||||
|
prev.map((m) =>
|
||||||
|
m.id === assistantMsgId
|
||||||
|
? { ...m, content: buildContent() }
|
||||||
|
: m
|
||||||
|
)
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case "error":
|
||||||
|
throw new Error(parsed.errorText || "Server error");
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
if (e instanceof SyntaxError) continue;
|
||||||
|
throw e;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
reader.releaseLock();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Persist assistant message
|
||||||
|
const finalContent = buildContent();
|
||||||
|
if (accumulatedText || toolCalls.size > 0) {
|
||||||
|
appendMessage(threadId, {
|
||||||
|
role: "assistant",
|
||||||
|
content: finalContent,
|
||||||
|
}).catch((err) =>
|
||||||
|
console.error("Failed to persist assistant message:", err)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
if (error instanceof Error && error.name === "AbortError") {
|
||||||
|
// Request was cancelled
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
console.error("[NewChatPage] Chat error:", error);
|
||||||
|
toast.error("Failed to get response. Please try again.");
|
||||||
|
// Update assistant message with error
|
||||||
|
setMessages((prev) =>
|
||||||
|
prev.map((m) =>
|
||||||
|
m.id === assistantMsgId
|
||||||
|
? {
|
||||||
|
...m,
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: "text",
|
||||||
|
text: "Sorry, there was an error. Please try again.",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
: m
|
||||||
|
)
|
||||||
|
);
|
||||||
|
} finally {
|
||||||
|
setIsRunning(false);
|
||||||
|
abortControllerRef.current = null;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[threadId, searchSpaceId, messages]
|
||||||
);
|
);
|
||||||
|
|
||||||
// Use LocalRuntime with our custom adapter
|
// Convert message (pass through since already in correct format)
|
||||||
const runtime = useLocalRuntime(adapter);
|
const convertMessage = useCallback(
|
||||||
|
(message: ThreadMessageLike): ThreadMessageLike => message,
|
||||||
|
[]
|
||||||
|
);
|
||||||
|
|
||||||
|
// Create external store runtime
|
||||||
|
const runtime = useExternalStoreRuntime({
|
||||||
|
messages,
|
||||||
|
isRunning,
|
||||||
|
onNew,
|
||||||
|
convertMessage,
|
||||||
|
onCancel: cancelRun,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Show loading state
|
||||||
|
if (isInitializing) {
|
||||||
|
return (
|
||||||
|
<div className="flex h-[calc(100vh-64px)] items-center justify-center">
|
||||||
|
<div className="text-muted-foreground">Loading chat...</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<AssistantRuntimeProvider runtime={runtime}>
|
<AssistantRuntimeProvider runtime={runtime}>
|
||||||
{/* Register tool UI components */}
|
|
||||||
<GeneratePodcastToolUI />
|
<GeneratePodcastToolUI />
|
||||||
<div className="h-[calc(100vh-64px)] max-h-[calc(100vh-64px)] overflow-hidden">
|
<div className="h-[calc(100vh-64px)] max-h-[calc(100vh-64px)] overflow-hidden">
|
||||||
<Thread />
|
<Thread />
|
||||||
|
|
|
||||||
286
surfsense_web/components/assistant-ui/thread-list.tsx
Normal file
286
surfsense_web/components/assistant-ui/thread-list.tsx
Normal file
|
|
@ -0,0 +1,286 @@
|
||||||
|
"use client";
|
||||||
|
|
||||||
|
import { useCallback, useEffect, useState } from "react";
|
||||||
|
import { useRouter } from "next/navigation";
|
||||||
|
import { ArchiveIcon, MessageSquareIcon, PlusIcon, TrashIcon, MoreVerticalIcon, RotateCcwIcon } from "lucide-react";
|
||||||
|
import { cn } from "@/lib/utils";
|
||||||
|
import { Button } from "@/components/ui/button";
|
||||||
|
import {
|
||||||
|
DropdownMenu,
|
||||||
|
DropdownMenuContent,
|
||||||
|
DropdownMenuItem,
|
||||||
|
DropdownMenuSeparator,
|
||||||
|
DropdownMenuTrigger,
|
||||||
|
} from "@/components/ui/dropdown-menu";
|
||||||
|
import {
|
||||||
|
type ThreadListItem,
|
||||||
|
createThreadListManager,
|
||||||
|
type ThreadListState,
|
||||||
|
} from "@/lib/chat/thread-persistence";
|
||||||
|
|
||||||
|
interface ThreadListProps {
|
||||||
|
searchSpaceId: number;
|
||||||
|
currentThreadId?: number;
|
||||||
|
className?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function ThreadList({ searchSpaceId, currentThreadId, className }: ThreadListProps) {
|
||||||
|
const router = useRouter();
|
||||||
|
const [state, setState] = useState<ThreadListState>({
|
||||||
|
threads: [],
|
||||||
|
archivedThreads: [],
|
||||||
|
isLoading: true,
|
||||||
|
error: null,
|
||||||
|
});
|
||||||
|
const [showArchived, setShowArchived] = useState(false);
|
||||||
|
|
||||||
|
// Create the thread list manager
|
||||||
|
const manager = useCallback(
|
||||||
|
() =>
|
||||||
|
createThreadListManager({
|
||||||
|
searchSpaceId,
|
||||||
|
currentThreadId: currentThreadId ?? null,
|
||||||
|
onThreadSwitch: (threadId) => {
|
||||||
|
router.push(`/dashboard/${searchSpaceId}/new-chat/${threadId}`);
|
||||||
|
},
|
||||||
|
onNewThread: (threadId) => {
|
||||||
|
router.push(`/dashboard/${searchSpaceId}/new-chat/${threadId}`);
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
[searchSpaceId, currentThreadId, router]
|
||||||
|
);
|
||||||
|
|
||||||
|
// Load threads on mount and when searchSpaceId changes
|
||||||
|
const loadThreads = useCallback(async () => {
|
||||||
|
setState((prev) => ({ ...prev, isLoading: true }));
|
||||||
|
const newState = await manager().loadThreads();
|
||||||
|
setState(newState);
|
||||||
|
}, [manager]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
loadThreads();
|
||||||
|
}, [loadThreads]);
|
||||||
|
|
||||||
|
// Handle new thread creation
|
||||||
|
const handleNewThread = async () => {
|
||||||
|
await manager().createNewThread();
|
||||||
|
await loadThreads();
|
||||||
|
};
|
||||||
|
|
||||||
|
// Handle thread actions
|
||||||
|
const handleArchive = async (threadId: number) => {
|
||||||
|
const success = await manager().archiveThread(threadId);
|
||||||
|
if (success) await loadThreads();
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleUnarchive = async (threadId: number) => {
|
||||||
|
const success = await manager().unarchiveThread(threadId);
|
||||||
|
if (success) await loadThreads();
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleDelete = async (threadId: number) => {
|
||||||
|
const success = await manager().deleteThread(threadId);
|
||||||
|
if (success) {
|
||||||
|
await loadThreads();
|
||||||
|
// If we deleted the current thread, redirect to new chat
|
||||||
|
if (threadId === currentThreadId) {
|
||||||
|
router.push(`/dashboard/${searchSpaceId}/new-chat`);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleSwitchToThread = (threadId: number) => {
|
||||||
|
manager().switchToThread(threadId);
|
||||||
|
};
|
||||||
|
|
||||||
|
const displayedThreads = showArchived ? state.archivedThreads : state.threads;
|
||||||
|
|
||||||
|
if (state.isLoading) {
|
||||||
|
return (
|
||||||
|
<div className={cn("flex h-full flex-col", className)}>
|
||||||
|
<div className="flex items-center justify-center p-4">
|
||||||
|
<span className="text-muted-foreground text-sm">Loading threads...</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (state.error) {
|
||||||
|
return (
|
||||||
|
<div className={cn("flex h-full flex-col", className)}>
|
||||||
|
<div className="p-4 text-center">
|
||||||
|
<span className="text-destructive text-sm">{state.error}</span>
|
||||||
|
<Button variant="ghost" size="sm" className="mt-2" onClick={loadThreads}>
|
||||||
|
Retry
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className={cn("flex h-full flex-col", className)}>
|
||||||
|
{/* Header with New Chat button */}
|
||||||
|
<div className="flex items-center justify-between border-b p-3">
|
||||||
|
<h2 className="font-semibold text-sm">Conversations</h2>
|
||||||
|
<Button variant="ghost" size="icon" className="size-8" onClick={handleNewThread} title="New Chat">
|
||||||
|
<PlusIcon className="size-4" />
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Tab toggle for active/archived */}
|
||||||
|
<div className="flex border-b">
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
onClick={() => setShowArchived(false)}
|
||||||
|
className={cn(
|
||||||
|
"flex-1 px-3 py-2 text-center text-xs font-medium transition-colors",
|
||||||
|
!showArchived
|
||||||
|
? "border-b-2 border-primary text-primary"
|
||||||
|
: "text-muted-foreground hover:text-foreground"
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
Active ({state.threads.length})
|
||||||
|
</button>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
onClick={() => setShowArchived(true)}
|
||||||
|
className={cn(
|
||||||
|
"flex-1 px-3 py-2 text-center text-xs font-medium transition-colors",
|
||||||
|
showArchived
|
||||||
|
? "border-b-2 border-primary text-primary"
|
||||||
|
: "text-muted-foreground hover:text-foreground"
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
Archived ({state.archivedThreads.length})
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Thread list */}
|
||||||
|
<div className="flex-1 overflow-y-auto">
|
||||||
|
{displayedThreads.length === 0 ? (
|
||||||
|
<div className="flex flex-col items-center justify-center p-6 text-center">
|
||||||
|
<MessageSquareIcon className="mb-2 size-8 text-muted-foreground/50" />
|
||||||
|
<p className="text-muted-foreground text-sm">
|
||||||
|
{showArchived ? "No archived conversations" : "No conversations yet"}
|
||||||
|
</p>
|
||||||
|
{!showArchived && (
|
||||||
|
<Button variant="outline" size="sm" className="mt-3" onClick={handleNewThread}>
|
||||||
|
<PlusIcon className="mr-1 size-3" />
|
||||||
|
Start a conversation
|
||||||
|
</Button>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
) : (
|
||||||
|
<div className="space-y-1 p-2">
|
||||||
|
{displayedThreads.map((thread) => (
|
||||||
|
<ThreadListItemComponent
|
||||||
|
key={thread.id}
|
||||||
|
thread={thread}
|
||||||
|
isActive={thread.id === currentThreadId}
|
||||||
|
isArchived={showArchived}
|
||||||
|
onClick={() => handleSwitchToThread(thread.id)}
|
||||||
|
onArchive={() => handleArchive(thread.id)}
|
||||||
|
onUnarchive={() => handleUnarchive(thread.id)}
|
||||||
|
onDelete={() => handleDelete(thread.id)}
|
||||||
|
/>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
interface ThreadListItemComponentProps {
|
||||||
|
thread: ThreadListItem;
|
||||||
|
isActive: boolean;
|
||||||
|
isArchived: boolean;
|
||||||
|
onClick: () => void;
|
||||||
|
onArchive: () => void;
|
||||||
|
onUnarchive: () => void;
|
||||||
|
onDelete: () => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
function ThreadListItemComponent({
|
||||||
|
thread,
|
||||||
|
isActive,
|
||||||
|
isArchived,
|
||||||
|
onClick,
|
||||||
|
onArchive,
|
||||||
|
onUnarchive,
|
||||||
|
onDelete,
|
||||||
|
}: ThreadListItemComponentProps) {
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
className={cn(
|
||||||
|
"group flex items-center gap-2 rounded-lg px-3 py-2 transition-colors cursor-pointer",
|
||||||
|
isActive ? "bg-accent text-accent-foreground" : "hover:bg-muted/50"
|
||||||
|
)}
|
||||||
|
onClick={onClick}
|
||||||
|
onKeyDown={(e) => {
|
||||||
|
if (e.key === "Enter" || e.key === " ") onClick();
|
||||||
|
}}
|
||||||
|
role="button"
|
||||||
|
tabIndex={0}
|
||||||
|
>
|
||||||
|
<MessageSquareIcon className="size-4 shrink-0 text-muted-foreground" />
|
||||||
|
<div className="flex-1 min-w-0">
|
||||||
|
<p className="truncate text-sm font-medium">{thread.title || "New Chat"}</p>
|
||||||
|
<p className="truncate text-xs text-muted-foreground">
|
||||||
|
{formatRelativeTime(new Date(thread.updatedAt))}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
<DropdownMenu>
|
||||||
|
<DropdownMenuTrigger asChild>
|
||||||
|
<Button
|
||||||
|
variant="ghost"
|
||||||
|
size="icon"
|
||||||
|
className="size-7 opacity-0 group-hover:opacity-100 transition-opacity"
|
||||||
|
onClick={(e) => e.stopPropagation()}
|
||||||
|
>
|
||||||
|
<MoreVerticalIcon className="size-4" />
|
||||||
|
</Button>
|
||||||
|
</DropdownMenuTrigger>
|
||||||
|
<DropdownMenuContent align="end">
|
||||||
|
{isArchived ? (
|
||||||
|
<DropdownMenuItem onClick={onUnarchive}>
|
||||||
|
<RotateCcwIcon className="mr-2 size-4" />
|
||||||
|
Unarchive
|
||||||
|
</DropdownMenuItem>
|
||||||
|
) : (
|
||||||
|
<DropdownMenuItem onClick={onArchive}>
|
||||||
|
<ArchiveIcon className="mr-2 size-4" />
|
||||||
|
Archive
|
||||||
|
</DropdownMenuItem>
|
||||||
|
)}
|
||||||
|
<DropdownMenuSeparator />
|
||||||
|
<DropdownMenuItem onClick={onDelete} className="text-destructive focus:text-destructive">
|
||||||
|
<TrashIcon className="mr-2 size-4" />
|
||||||
|
Delete
|
||||||
|
</DropdownMenuItem>
|
||||||
|
</DropdownMenuContent>
|
||||||
|
</DropdownMenu>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Format a date as relative time (e.g., "2 hours ago", "Yesterday")
|
||||||
|
*/
|
||||||
|
function formatRelativeTime(date: Date): string {
|
||||||
|
const now = new Date();
|
||||||
|
const diffMs = now.getTime() - date.getTime();
|
||||||
|
const diffSecs = Math.floor(diffMs / 1000);
|
||||||
|
const diffMins = Math.floor(diffSecs / 60);
|
||||||
|
const diffHours = Math.floor(diffMins / 60);
|
||||||
|
const diffDays = Math.floor(diffHours / 24);
|
||||||
|
|
||||||
|
if (diffSecs < 60) return "Just now";
|
||||||
|
if (diffMins < 60) return `${diffMins} min${diffMins === 1 ? "" : "s"} ago`;
|
||||||
|
if (diffHours < 24) return `${diffHours} hour${diffHours === 1 ? "" : "s"} ago`;
|
||||||
|
if (diffDays === 1) return "Yesterday";
|
||||||
|
if (diffDays < 7) return `${diffDays} days ago`;
|
||||||
|
|
||||||
|
return date.toLocaleDateString();
|
||||||
|
}
|
||||||
|
|
@ -1,345 +0,0 @@
|
||||||
/**
|
|
||||||
* Custom ChatModelAdapter for the new-chat feature using LocalRuntime.
|
|
||||||
* Connects directly to the FastAPI backend using the Vercel AI SDK Data Stream Protocol.
|
|
||||||
*/
|
|
||||||
|
|
||||||
import type { ChatModelAdapter, ChatModelRunOptions } from "@assistant-ui/react";
|
|
||||||
import { toast } from "sonner";
|
|
||||||
import { getBearerToken } from "@/lib/auth-utils";
|
|
||||||
import {
|
|
||||||
isPodcastGenerating,
|
|
||||||
looksLikePodcastRequest,
|
|
||||||
setActivePodcastTaskId,
|
|
||||||
} from "@/lib/chat/podcast-state";
|
|
||||||
|
|
||||||
interface NewChatAdapterConfig {
|
|
||||||
searchSpaceId: number;
|
|
||||||
chatId: number;
|
|
||||||
}
|
|
||||||
|
|
||||||
interface ChatMessageForBackend {
|
|
||||||
role: "user" | "assistant";
|
|
||||||
content: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Converts assistant-ui messages to a simple format for the backend
|
|
||||||
*/
|
|
||||||
function convertMessagesToBackendFormat(
|
|
||||||
messages: ChatModelRunOptions["messages"]
|
|
||||||
): ChatMessageForBackend[] {
|
|
||||||
return messages
|
|
||||||
.filter((m) => m.role === "user" || m.role === "assistant")
|
|
||||||
.map((m) => {
|
|
||||||
// Extract text content from the message parts
|
|
||||||
let content = "";
|
|
||||||
for (const part of m.content) {
|
|
||||||
if (part.type === "text") {
|
|
||||||
content += part.text;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return {
|
|
||||||
role: m.role as "user" | "assistant",
|
|
||||||
content: content.trim(),
|
|
||||||
};
|
|
||||||
})
|
|
||||||
.filter((m) => m.content.length > 0); // Filter out empty messages
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Represents an in-progress or completed tool call
|
|
||||||
*/
|
|
||||||
interface ToolCallState {
|
|
||||||
toolCallId: string;
|
|
||||||
toolName: string;
|
|
||||||
args: Record<string, unknown>;
|
|
||||||
result?: unknown;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Tools that should render custom UI in the chat.
|
|
||||||
* Other tools (like search_knowledge_base) will be hidden from the UI.
|
|
||||||
*/
|
|
||||||
const TOOLS_WITH_UI = new Set(["generate_podcast"]);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates a ChatModelAdapter that connects to the FastAPI new_chat endpoint.
|
|
||||||
*
|
|
||||||
* The backend expects:
|
|
||||||
* - POST /api/v1/new_chat
|
|
||||||
* - Body: { chat_id: number, user_query: string, search_space_id: number, messages: [...] }
|
|
||||||
* - Returns: SSE stream with Vercel AI SDK Data Stream Protocol
|
|
||||||
*/
|
|
||||||
export function createNewChatAdapter(config: NewChatAdapterConfig): ChatModelAdapter {
|
|
||||||
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
|
|
||||||
|
|
||||||
return {
|
|
||||||
async *run({ messages, abortSignal }: ChatModelRunOptions) {
|
|
||||||
// Get the last user message
|
|
||||||
const lastUserMessage = messages.filter((m) => m.role === "user").pop();
|
|
||||||
|
|
||||||
if (!lastUserMessage) {
|
|
||||||
throw new Error("No user message found");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract text content from the last user message
|
|
||||||
let userQuery = "";
|
|
||||||
for (const part of lastUserMessage.content) {
|
|
||||||
if (part.type === "text") {
|
|
||||||
userQuery += part.text;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!userQuery.trim()) {
|
|
||||||
throw new Error("User query cannot be empty");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if user is requesting a podcast while one is already generating
|
|
||||||
if (isPodcastGenerating() && looksLikePodcastRequest(userQuery)) {
|
|
||||||
toast.warning("A podcast is already being generated. Please wait for it to complete.");
|
|
||||||
// Return a message telling the user to wait
|
|
||||||
yield {
|
|
||||||
content: [
|
|
||||||
{
|
|
||||||
type: "text",
|
|
||||||
text: "A podcast is already being generated. Please wait for it to complete before requesting another one.",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
};
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const token = getBearerToken();
|
|
||||||
if (!token) {
|
|
||||||
throw new Error("Not authenticated. Please log in again.");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert all messages to backend format for chat history
|
|
||||||
const messageHistory = convertMessagesToBackendFormat(messages);
|
|
||||||
|
|
||||||
const response = await fetch(`${backendUrl}/api/v1/new_chat`, {
|
|
||||||
method: "POST",
|
|
||||||
headers: {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
Authorization: `Bearer ${token}`,
|
|
||||||
},
|
|
||||||
body: JSON.stringify({
|
|
||||||
chat_id: config.chatId,
|
|
||||||
user_query: userQuery.trim(),
|
|
||||||
search_space_id: config.searchSpaceId,
|
|
||||||
messages: messageHistory,
|
|
||||||
}),
|
|
||||||
signal: abortSignal,
|
|
||||||
});
|
|
||||||
|
|
||||||
if (!response.ok) {
|
|
||||||
const errorText = await response.text().catch(() => "Unknown error");
|
|
||||||
throw new Error(`Backend error (${response.status}): ${errorText}`);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!response.body) {
|
|
||||||
throw new Error("No response body");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse the SSE stream (Vercel AI SDK Data Stream Protocol)
|
|
||||||
const reader = response.body.getReader();
|
|
||||||
const decoder = new TextDecoder();
|
|
||||||
let buffer = "";
|
|
||||||
let accumulatedText = "";
|
|
||||||
|
|
||||||
// Track tool calls by their ID
|
|
||||||
const toolCalls = new Map<string, ToolCallState>();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Build the content array with text and tool calls.
|
|
||||||
* Only includes tools that have custom UI (defined in TOOLS_WITH_UI).
|
|
||||||
*/
|
|
||||||
function buildContent() {
|
|
||||||
const content: Array<
|
|
||||||
| { type: "text"; text: string }
|
|
||||||
| { type: "tool-call"; toolCallId: string; toolName: string; args: Record<string, unknown>; result?: unknown }
|
|
||||||
> = [];
|
|
||||||
|
|
||||||
// Add text content if any
|
|
||||||
if (accumulatedText) {
|
|
||||||
content.push({ type: "text" as const, text: accumulatedText });
|
|
||||||
}
|
|
||||||
|
|
||||||
// Only add tool calls that have custom UI registered
|
|
||||||
// Other tools (like search_knowledge_base) are hidden from the UI
|
|
||||||
for (const toolCall of toolCalls.values()) {
|
|
||||||
if (TOOLS_WITH_UI.has(toolCall.toolName)) {
|
|
||||||
content.push({
|
|
||||||
type: "tool-call" as const,
|
|
||||||
toolCallId: toolCall.toolCallId,
|
|
||||||
toolName: toolCall.toolName,
|
|
||||||
args: toolCall.args,
|
|
||||||
result: toolCall.result,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return content;
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
while (true) {
|
|
||||||
const { done, value } = await reader.read();
|
|
||||||
if (done) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
const chunk = decoder.decode(value, { stream: true });
|
|
||||||
buffer += chunk;
|
|
||||||
|
|
||||||
// Split on double newlines (handle both \n\n and \r\n\r\n)
|
|
||||||
const events = buffer.split(/\r?\n\r?\n/);
|
|
||||||
buffer = events.pop() || "";
|
|
||||||
|
|
||||||
for (const event of events) {
|
|
||||||
// Each event can have multiple lines, find the data line
|
|
||||||
const lines = event.split(/\r?\n/);
|
|
||||||
for (const line of lines) {
|
|
||||||
if (!line.startsWith("data: ")) continue;
|
|
||||||
|
|
||||||
const data = line.slice(6).trim(); // Remove "data: " prefix
|
|
||||||
|
|
||||||
// Handle [DONE] marker
|
|
||||||
if (data === "[DONE]") {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!data) continue;
|
|
||||||
|
|
||||||
try {
|
|
||||||
const parsed = JSON.parse(data);
|
|
||||||
|
|
||||||
// Handle different message types from the Data Stream Protocol
|
|
||||||
switch (parsed.type) {
|
|
||||||
case "text-delta":
|
|
||||||
accumulatedText += parsed.delta;
|
|
||||||
yield { content: buildContent() };
|
|
||||||
break;
|
|
||||||
|
|
||||||
case "tool-input-start": {
|
|
||||||
// Tool call is starting - create a new tool call entry
|
|
||||||
const { toolCallId, toolName } = parsed;
|
|
||||||
toolCalls.set(toolCallId, {
|
|
||||||
toolCallId,
|
|
||||||
toolName,
|
|
||||||
args: {},
|
|
||||||
});
|
|
||||||
// Yield to show tool is starting (running state)
|
|
||||||
yield { content: buildContent() };
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
case "tool-input-available": {
|
|
||||||
// Tool input is complete - update the args
|
|
||||||
const { toolCallId, toolName, input } = parsed;
|
|
||||||
const existing = toolCalls.get(toolCallId);
|
|
||||||
if (existing) {
|
|
||||||
existing.args = input || {};
|
|
||||||
} else {
|
|
||||||
// Create new entry if we missed tool-input-start
|
|
||||||
toolCalls.set(toolCallId, {
|
|
||||||
toolCallId,
|
|
||||||
toolName,
|
|
||||||
args: input || {},
|
|
||||||
});
|
|
||||||
}
|
|
||||||
yield { content: buildContent() };
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
case "tool-output-available": {
|
|
||||||
// Tool execution is complete - add the result
|
|
||||||
const { toolCallId, output } = parsed;
|
|
||||||
const existing = toolCalls.get(toolCallId);
|
|
||||||
if (existing) {
|
|
||||||
existing.result = output;
|
|
||||||
|
|
||||||
// If this is a podcast tool with status="processing", set the state immediately
|
|
||||||
// This ensures subsequent podcast requests are intercepted
|
|
||||||
if (
|
|
||||||
existing.toolName === "generate_podcast" &&
|
|
||||||
output &&
|
|
||||||
typeof output === "object" &&
|
|
||||||
"status" in output &&
|
|
||||||
output.status === "processing" &&
|
|
||||||
"task_id" in output &&
|
|
||||||
typeof output.task_id === "string"
|
|
||||||
) {
|
|
||||||
setActivePodcastTaskId(output.task_id);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
yield { content: buildContent() };
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
case "error":
|
|
||||||
throw new Error(parsed.errorText || "Unknown error from server");
|
|
||||||
|
|
||||||
// Other types like text-start, text-end, start-step, finish-step, etc.
|
|
||||||
// are handled implicitly
|
|
||||||
default:
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
} catch (e) {
|
|
||||||
// Skip non-JSON lines
|
|
||||||
if (e instanceof SyntaxError) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
throw e;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle any remaining buffer
|
|
||||||
if (buffer.trim()) {
|
|
||||||
const lines = buffer.split(/\r?\n/);
|
|
||||||
for (const line of lines) {
|
|
||||||
if (line.startsWith("data: ")) {
|
|
||||||
const data = line.slice(6).trim();
|
|
||||||
if (data && data !== "[DONE]") {
|
|
||||||
try {
|
|
||||||
const parsed = JSON.parse(data);
|
|
||||||
if (parsed.type === "text-delta") {
|
|
||||||
accumulatedText += parsed.delta;
|
|
||||||
yield { content: buildContent() };
|
|
||||||
} else if (parsed.type === "tool-output-available") {
|
|
||||||
const { toolCallId, output } = parsed;
|
|
||||||
const existing = toolCalls.get(toolCallId);
|
|
||||||
if (existing) {
|
|
||||||
existing.result = output;
|
|
||||||
|
|
||||||
// Set podcast state if processing
|
|
||||||
if (
|
|
||||||
existing.toolName === "generate_podcast" &&
|
|
||||||
output &&
|
|
||||||
typeof output === "object" &&
|
|
||||||
"status" in output &&
|
|
||||||
output.status === "processing" &&
|
|
||||||
"task_id" in output &&
|
|
||||||
typeof output.task_id === "string"
|
|
||||||
) {
|
|
||||||
setActivePodcastTaskId(output.task_id);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
yield { content: buildContent() };
|
|
||||||
}
|
|
||||||
} catch {
|
|
||||||
// Ignore parse errors
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} finally {
|
|
||||||
reader.releaseLock();
|
|
||||||
}
|
|
||||||
},
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
223
surfsense_web/lib/chat/thread-persistence.ts
Normal file
223
surfsense_web/lib/chat/thread-persistence.ts
Normal file
|
|
@ -0,0 +1,223 @@
|
||||||
|
/**
|
||||||
|
* Thread persistence utilities for the new chat feature.
|
||||||
|
* Provides API functions and thread list management.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { baseApiService } from "@/lib/apis/base-api.service";
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// Types matching backend schemas
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
export interface ThreadRecord {
|
||||||
|
id: number;
|
||||||
|
title: string;
|
||||||
|
archived: boolean;
|
||||||
|
search_space_id: number;
|
||||||
|
created_at: string;
|
||||||
|
updated_at: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface MessageRecord {
|
||||||
|
id: number;
|
||||||
|
thread_id: number;
|
||||||
|
role: "user" | "assistant" | "system";
|
||||||
|
content: unknown;
|
||||||
|
created_at: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ThreadListResponse {
|
||||||
|
threads: ThreadListItem[];
|
||||||
|
archived_threads: ThreadListItem[];
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ThreadListItem {
|
||||||
|
id: number;
|
||||||
|
title: string;
|
||||||
|
archived: boolean;
|
||||||
|
createdAt: string;
|
||||||
|
updatedAt: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ThreadHistoryLoadResponse {
|
||||||
|
messages: MessageRecord[];
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// API Service Functions
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Fetch list of threads for a search space
|
||||||
|
*/
|
||||||
|
export async function fetchThreads(
|
||||||
|
searchSpaceId: number
|
||||||
|
): Promise<ThreadListResponse> {
|
||||||
|
return baseApiService.get<ThreadListResponse>(
|
||||||
|
`/api/v1/threads?search_space_id=${searchSpaceId}`
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a new thread
|
||||||
|
*/
|
||||||
|
export async function createThread(
|
||||||
|
searchSpaceId: number,
|
||||||
|
title = "New Chat"
|
||||||
|
): Promise<ThreadRecord> {
|
||||||
|
return baseApiService.post<ThreadRecord>("/api/v1/threads", undefined, {
|
||||||
|
body: {
|
||||||
|
title,
|
||||||
|
archived: false,
|
||||||
|
search_space_id: searchSpaceId,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get thread messages
|
||||||
|
*/
|
||||||
|
export async function getThreadMessages(
|
||||||
|
threadId: number
|
||||||
|
): Promise<ThreadHistoryLoadResponse> {
|
||||||
|
return baseApiService.get<ThreadHistoryLoadResponse>(
|
||||||
|
`/api/v1/threads/${threadId}`
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Append a message to a thread
|
||||||
|
*/
|
||||||
|
export async function appendMessage(
|
||||||
|
threadId: number,
|
||||||
|
message: { role: "user" | "assistant" | "system"; content: unknown }
|
||||||
|
): Promise<MessageRecord> {
|
||||||
|
return baseApiService.post<MessageRecord>(
|
||||||
|
`/api/v1/threads/${threadId}/messages`,
|
||||||
|
undefined,
|
||||||
|
{ body: message }
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Update thread (rename, archive)
|
||||||
|
*/
|
||||||
|
export async function updateThread(
|
||||||
|
threadId: number,
|
||||||
|
updates: { title?: string; archived?: boolean }
|
||||||
|
): Promise<ThreadRecord> {
|
||||||
|
return baseApiService.put<ThreadRecord>(
|
||||||
|
`/api/v1/threads/${threadId}`,
|
||||||
|
undefined,
|
||||||
|
{ body: updates }
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Delete a thread
|
||||||
|
*/
|
||||||
|
export async function deleteThread(threadId: number): Promise<void> {
|
||||||
|
await baseApiService.delete(`/api/v1/threads/${threadId}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// Thread List Manager (for thread list sidebar)
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
export interface ThreadListAdapterConfig {
|
||||||
|
searchSpaceId: number;
|
||||||
|
currentThreadId: number | null;
|
||||||
|
onThreadSwitch: (threadId: number) => void;
|
||||||
|
onNewThread: (threadId: number) => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ThreadListState {
|
||||||
|
threads: ThreadListItem[];
|
||||||
|
archivedThreads: ThreadListItem[];
|
||||||
|
isLoading: boolean;
|
||||||
|
error: string | null;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a thread list management object.
|
||||||
|
* This provides methods to manage the thread list for the sidebar.
|
||||||
|
*/
|
||||||
|
export function createThreadListManager(config: ThreadListAdapterConfig) {
|
||||||
|
return {
|
||||||
|
async loadThreads(): Promise<ThreadListState> {
|
||||||
|
try {
|
||||||
|
const response = await fetchThreads(config.searchSpaceId);
|
||||||
|
return {
|
||||||
|
threads: response.threads,
|
||||||
|
archivedThreads: response.archived_threads,
|
||||||
|
isLoading: false,
|
||||||
|
error: null,
|
||||||
|
};
|
||||||
|
} catch (error) {
|
||||||
|
console.error("[ThreadListManager] Failed to load threads:", error);
|
||||||
|
return {
|
||||||
|
threads: [],
|
||||||
|
archivedThreads: [],
|
||||||
|
isLoading: false,
|
||||||
|
error:
|
||||||
|
error instanceof Error ? error.message : "Failed to load threads",
|
||||||
|
};
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
async createNewThread(title = "New Chat"): Promise<number | null> {
|
||||||
|
try {
|
||||||
|
const thread = await createThread(config.searchSpaceId, title);
|
||||||
|
config.onNewThread(thread.id);
|
||||||
|
return thread.id;
|
||||||
|
} catch (error) {
|
||||||
|
console.error("[ThreadListManager] Failed to create thread:", error);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
switchToThread(threadId: number) {
|
||||||
|
config.onThreadSwitch(threadId);
|
||||||
|
},
|
||||||
|
|
||||||
|
async renameThread(threadId: number, newTitle: string): Promise<boolean> {
|
||||||
|
try {
|
||||||
|
await updateThread(threadId, { title: newTitle });
|
||||||
|
return true;
|
||||||
|
} catch (error) {
|
||||||
|
console.error("[ThreadListManager] Failed to rename thread:", error);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
async archiveThread(threadId: number): Promise<boolean> {
|
||||||
|
try {
|
||||||
|
await updateThread(threadId, { archived: true });
|
||||||
|
return true;
|
||||||
|
} catch (error) {
|
||||||
|
console.error("[ThreadListManager] Failed to archive thread:", error);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
async unarchiveThread(threadId: number): Promise<boolean> {
|
||||||
|
try {
|
||||||
|
await updateThread(threadId, { archived: false });
|
||||||
|
return true;
|
||||||
|
} catch (error) {
|
||||||
|
console.error("[ThreadListManager] Failed to unarchive thread:", error);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
async deleteThread(threadId: number): Promise<boolean> {
|
||||||
|
try {
|
||||||
|
await deleteThread(threadId);
|
||||||
|
return true;
|
||||||
|
} catch (error) {
|
||||||
|
console.error("[ThreadListManager] Failed to delete thread:", error);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
};
|
||||||
|
}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue