mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-06 20:15:17 +02:00
Merge pull request #608 from AnishSarkar22/feature/podcast-agent
Added Podcast agent within chat
This commit is contained in:
commit
f115980d2b
14 changed files with 1488 additions and 23 deletions
|
|
@ -2,7 +2,7 @@
|
|||
SurfSense deep agent implementation.
|
||||
|
||||
This module provides the factory function for creating SurfSense deep agents
|
||||
with knowledge base search capability.
|
||||
with knowledge base search and podcast generation capabilities.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
|
@ -15,6 +15,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.agents.new_chat.context import SurfSenseContextSchema
|
||||
from app.agents.new_chat.knowledge_base import create_search_knowledge_base_tool
|
||||
from app.agents.new_chat.podcast import create_generate_podcast_tool
|
||||
from app.agents.new_chat.system_prompt import build_surfsense_system_prompt
|
||||
from app.services.connector_service import ConnectorService
|
||||
|
||||
|
|
@ -29,12 +30,14 @@ def create_surfsense_deep_agent(
|
|||
db_session: AsyncSession,
|
||||
connector_service: ConnectorService,
|
||||
checkpointer: Checkpointer,
|
||||
user_id: str | None = None,
|
||||
user_instructions: str | None = None,
|
||||
enable_citations: bool = True,
|
||||
enable_podcast: bool = True,
|
||||
additional_tools: Sequence[BaseTool] | None = None,
|
||||
):
|
||||
"""
|
||||
Create a SurfSense deep agent with knowledge base search capability.
|
||||
Create a SurfSense deep agent with knowledge base search and podcast generation capabilities.
|
||||
|
||||
Args:
|
||||
llm: ChatLiteLLM instance
|
||||
|
|
@ -43,10 +46,13 @@ def create_surfsense_deep_agent(
|
|||
connector_service: Initialized connector service
|
||||
checkpointer: LangGraph checkpointer for conversation state persistence.
|
||||
Use AsyncPostgresSaver for production or MemorySaver for testing.
|
||||
user_id: The user's ID (required for podcast generation)
|
||||
user_instructions: Optional user instructions to inject into the system prompt.
|
||||
These will be added to the system prompt to customize agent behavior.
|
||||
enable_citations: Whether to include citation instructions in the system prompt (default: True).
|
||||
When False, the agent will not be instructed to add citations to responses.
|
||||
enable_podcast: Whether to include the podcast generation tool (default: True).
|
||||
When True and user_id is provided, the agent can generate podcasts.
|
||||
additional_tools: Optional sequence of additional tools to inject into the agent.
|
||||
The search_knowledge_base tool will always be included.
|
||||
|
||||
|
|
@ -62,6 +68,16 @@ def create_surfsense_deep_agent(
|
|||
|
||||
# Combine search tool with any additional tools
|
||||
tools = [search_tool]
|
||||
|
||||
# Add podcast tool if enabled and user_id is provided
|
||||
if enable_podcast and user_id:
|
||||
podcast_tool = create_generate_podcast_tool(
|
||||
search_space_id=search_space_id,
|
||||
db_session=db_session,
|
||||
user_id=str(user_id),
|
||||
)
|
||||
tools.append(podcast_tool)
|
||||
|
||||
if additional_tools:
|
||||
tools.extend(additional_tools)
|
||||
|
||||
|
|
|
|||
174
surfsense_backend/app/agents/new_chat/podcast.py
Normal file
174
surfsense_backend/app/agents/new_chat/podcast.py
Normal file
|
|
@ -0,0 +1,174 @@
|
|||
"""
|
||||
Podcast generation tool for the new chat agent.
|
||||
|
||||
This module provides a factory function for creating the generate_podcast tool
|
||||
that submits a Celery task for background podcast generation. The frontend
|
||||
polls for completion and auto-updates when the podcast is ready.
|
||||
|
||||
Duplicate request prevention:
|
||||
- Only one podcast can be generated at a time per search space
|
||||
- Uses Redis to track active podcast tasks
|
||||
- Returns a friendly message if a podcast is already being generated
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import redis
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
# Redis connection for tracking active podcast tasks
|
||||
# Uses the same Redis instance as Celery
|
||||
REDIS_URL = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0")
|
||||
_redis_client: redis.Redis | None = None
|
||||
|
||||
|
||||
def get_redis_client() -> redis.Redis:
    """Return the module-wide Redis client, building it on first use.

    The client is created from ``REDIS_URL`` with ``decode_responses=True``
    and cached in the module-level ``_redis_client`` so that later calls
    reuse the same object.
    """
    global _redis_client
    # Fast path: a client was already built by an earlier call.
    if _redis_client is not None:
        return _redis_client
    _redis_client = redis.from_url(REDIS_URL, decode_responses=True)
    return _redis_client
|
||||
|
||||
|
||||
def get_active_podcast_key(search_space_id: int) -> str:
    """Build the Redis key that tracks a search space's active podcast task."""
    return "podcast:active:{}".format(search_space_id)
|
||||
|
||||
|
||||
def get_active_podcast_task(search_space_id: int) -> str | None:
    """Look up the task ID registered as active for a search space.

    Returns the value stored in Redis under the search space's key. Any
    exception raised during the lookup is swallowed and ``None`` is returned
    instead.
    """
    try:
        redis_conn = get_redis_client()
        lookup_key = get_active_podcast_key(search_space_id)
        stored_task_id = redis_conn.get(lookup_key)
    except Exception:
        # Fail open: when Redis cannot be queried, behave as if nothing is active.
        return None
    return stored_task_id
|
||||
|
||||
|
||||
def set_active_podcast_task(search_space_id: int, task_id: str) -> None:
    """Record ``task_id`` as the active podcast task for a search space.

    The key is written with a 30-minute expiry so that a marker which is
    never cleared still disappears on its own. Failures are reported with a
    printed warning and otherwise ignored.
    """
    # Safety-net lifetime of the marker: 30 minutes, expressed in seconds.
    expiry_seconds = 30 * 60
    try:
        redis_conn = get_redis_client()
        marker_key = get_active_podcast_key(search_space_id)
        redis_conn.setex(marker_key, expiry_seconds, task_id)
    except Exception as exc:
        print(f"[generate_podcast] Warning: Could not set active task in Redis: {exc}")
|
||||
|
||||
|
||||
def clear_active_podcast_task(search_space_id: int) -> None:
    """Remove the active-podcast marker for a search space.

    Failures are reported with a printed warning and otherwise ignored.
    """
    try:
        redis_conn = get_redis_client()
        marker_key = get_active_podcast_key(search_space_id)
        redis_conn.delete(marker_key)
    except Exception as exc:
        print(f"[generate_podcast] Warning: Could not clear active task in Redis: {exc}")
|
||||
|
||||
|
||||
def create_generate_podcast_tool(
    search_space_id: int,
    db_session: AsyncSession,
    user_id: str,
):
    """
    Factory function to create the generate_podcast tool with injected dependencies.

    The returned tool submits a Celery task and returns immediately. Only one
    podcast may be in flight per search space; the slot is claimed atomically
    in Redis before the task is submitted so that concurrent tool calls cannot
    both start a generation. If Redis cannot be reached the tool fails open
    and submits the task anyway.

    Args:
        search_space_id: The user's search space ID
        db_session: Database session (not used - Celery creates its own)
        user_id: The user's ID (as string)

    Returns:
        A configured tool function for generating podcasts
    """
    # Key shared with the Celery task, which deletes it when the task finishes.
    active_key = get_active_podcast_key(search_space_id)
    # Same 30-minute safety-net expiry that set_active_podcast_task uses.
    reservation_ttl_seconds = 1800
    # Value stored between claiming the slot and learning the Celery task ID.
    pending_marker = "pending"

    @tool
    async def generate_podcast(
        source_content: str,
        podcast_title: str = "SurfSense Podcast",
        user_prompt: str | None = None,
    ) -> dict[str, Any]:
        """
        Generate a podcast from the provided content.

        Use this tool when the user asks to create, generate, or make a podcast.
        Common triggers include phrases like:
        - "Give me a podcast about this"
        - "Create a podcast from this conversation"
        - "Generate a podcast summary"
        - "Make a podcast about..."
        - "Turn this into a podcast"

        The tool will start generating a podcast in the background.
        The podcast will be available once generation completes.

        IMPORTANT: Only one podcast can be generated at a time. If a podcast
        is already being generated, this tool will return a message asking
        the user to wait.

        Args:
            source_content: The text content to convert into a podcast.
                This can be a summary, research findings, or any text
                the user wants transformed into an audio podcast.
            podcast_title: Title for the podcast (default: "SurfSense Podcast")
            user_prompt: Optional instructions for podcast style, tone, or format.
                For example: "Make it casual and fun" or "Focus on the key insights"

        Returns:
            A dictionary containing:
            - status: "processing" (task submitted), "already_generating", or "error"
            - task_id: The Celery task ID for polling status (if processing)
            - title: The podcast title
            - message: Status message for the user
        """
        redis_client = None
        reserved = False  # True once this call owns the Redis slot
        submitted = False  # True once the Celery task has been queued
        try:
            active_task_id = None
            try:
                redis_client = get_redis_client()
                # SET NX EX claims the slot in a single atomic step, so two
                # concurrent requests cannot both pass the duplicate check.
                reserved = bool(
                    redis_client.set(
                        active_key,
                        pending_marker,
                        nx=True,
                        ex=reservation_ttl_seconds,
                    )
                )
                if not reserved:
                    active_task_id = redis_client.get(active_key)
            except Exception as e:
                # If Redis is unavailable, allow the request (fail open)
                print(f"[generate_podcast] Warning: Could not reserve podcast slot in Redis: {e}")
                redis_client = None
                reserved = False
                active_task_id = None

            if active_task_id:
                print(f"[generate_podcast] Blocked duplicate request. Active task: {active_task_id}")
                return {
                    "status": "already_generating",
                    # The placeholder is not a pollable Celery ID, so do not expose it.
                    "task_id": None if active_task_id == pending_marker else active_task_id,
                    "title": podcast_title,
                    "message": "A podcast is already being generated. Please wait for it to complete before requesting another one.",
                }

            # Import Celery task here to avoid circular imports
            from app.tasks.celery_tasks.podcast_tasks import (
                generate_content_podcast_task,
            )

            # Submit Celery task for background processing
            task = generate_content_podcast_task.delay(
                source_content=source_content,
                search_space_id=search_space_id,
                user_id=str(user_id),
                podcast_title=podcast_title,
                user_prompt=user_prompt,
            )
            submitted = True

            if reserved:
                try:
                    # XX: only replace the placeholder while it still exists. If the
                    # task already finished and cleared the key, writing it again
                    # would block the search space until the expiry elapses.
                    redis_client.set(
                        active_key,
                        task.id,
                        xx=True,
                        ex=reservation_ttl_seconds,
                    )
                except Exception as e:
                    print(f"[generate_podcast] Warning: Could not set active task in Redis: {e}")
            else:
                # No reservation was obtained (fail-open path): record the task
                # on a best-effort basis instead.
                set_active_podcast_task(search_space_id, task.id)

            print(f"[generate_podcast] Submitted Celery task: {task.id}")

            # Return immediately with task_id for polling
            return {
                "status": "processing",
                "task_id": task.id,
                "title": podcast_title,
                "message": "Podcast generation started. This may take a few minutes.",
            }

        except Exception as e:
            # Release the slot if it was claimed but no task was ever queued,
            # otherwise nothing would clear it until the expiry elapses.
            if reserved and not submitted:
                clear_active_podcast_task(search_space_id)

            error_message = str(e)
            print(f"[generate_podcast] Error submitting task: {error_message}")
            return {
                "status": "error",
                "error": error_message,
                "title": podcast_title,
                "task_id": None,
                "message": f"Podcast generation could not be started: {error_message}",
            }

    return generate_podcast
|
||||
|
|
@ -121,7 +121,8 @@ Today's date (UTC): {resolved_today}
|
|||
</system_instruction>{user_section}
|
||||
<tools>
|
||||
You have access to the following tools:
|
||||
- search_knowledge_base: Search the user's personal knowledge base for relevant information.
|
||||
|
||||
1. search_knowledge_base: Search the user's personal knowledge base for relevant information.
|
||||
- Args:
|
||||
- query: The search query - be specific and include key terms
|
||||
- top_k: Number of results to retrieve (default: 10)
|
||||
|
|
@ -129,6 +130,21 @@ You have access to the following tools:
|
|||
- end_date: Optional ISO date/datetime (e.g. "2025-12-19" or "2025-12-19T23:59:59+00:00")
|
||||
- connectors_to_search: Optional list of connector enums to search. If omitted, searches all.
|
||||
- Returns: Formatted string with relevant documents and their content
|
||||
|
||||
2. generate_podcast: Generate an audio podcast from provided content.
|
||||
- Use this when the user asks to create, generate, or make a podcast.
|
||||
- Trigger phrases: "give me a podcast about", "create a podcast", "generate a podcast", "make a podcast", "turn this into a podcast"
|
||||
- Args:
|
||||
- source_content: The text content to convert into a podcast. This MUST be comprehensive and include:
|
||||
* If discussing the current conversation: Include a detailed summary of the FULL chat history (all user questions and your responses)
|
||||
* If based on knowledge base search: Include the key findings and insights from the search results
|
||||
* You can combine both: conversation context + search results for richer podcasts
|
||||
* The more detailed the source_content, the better the podcast quality
|
||||
- podcast_title: Optional title for the podcast (default: "SurfSense Podcast")
|
||||
- user_prompt: Optional instructions for podcast style/format (e.g., "Make it casual and fun")
|
||||
- Returns: A task_id for tracking. The podcast will be generated in the background.
|
||||
- IMPORTANT: Only one podcast can be generated at a time. If a podcast is already being generated, the tool will return status "already_generating".
|
||||
- After calling this tool, inform the user that podcast generation has started and they will see the player when it's ready (takes 3-5 minutes).
|
||||
</tools>
|
||||
<tool_call_examples>
|
||||
- User: "Fetch all my notes and what's in them?"
|
||||
|
|
@ -136,6 +152,16 @@ You have access to the following tools:
|
|||
|
||||
- User: "What did I discuss on Slack last week about the React migration?"
|
||||
- Call: `search_knowledge_base(query="React migration", connectors_to_search=["SLACK_CONNECTOR"], start_date="YYYY-MM-DD", end_date="YYYY-MM-DD")`
|
||||
|
||||
- User: "Give me a podcast about AI trends based on what we discussed"
|
||||
- First search for relevant content, then call: `generate_podcast(source_content="Based on our conversation and search results: [detailed summary of chat + search findings]", podcast_title="AI Trends Podcast")`
|
||||
|
||||
- User: "Create a podcast summary of this conversation"
|
||||
- Call: `generate_podcast(source_content="Complete conversation summary:\n\nUser asked about [topic 1]:\n[Your detailed response]\n\nUser then asked about [topic 2]:\n[Your detailed response]\n\n[Continue for all exchanges in the conversation]", podcast_title="Conversation Summary")`
|
||||
|
||||
- User: "Make a podcast about quantum computing"
|
||||
- First search: `search_knowledge_base(query="quantum computing")`
|
||||
- Then: `generate_podcast(source_content="Key insights about quantum computing from the knowledge base:\n\n[Comprehensive summary of all relevant search results with key facts, concepts, and findings]", podcast_title="Quantum Computing Explained")`
|
||||
</tool_call_examples>{citation_section}
|
||||
"""
|
||||
|
||||
|
|
|
|||
|
|
@ -444,3 +444,66 @@ async def get_podcast_by_chat_id(
|
|||
raise HTTPException(
|
||||
status_code=500, detail=f"Error fetching podcast: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/podcasts/task/{task_id}/status")
async def get_podcast_task_status(
    task_id: str,
    user: User = Depends(current_active_user),
):
    """
    Report the state of a podcast generation Celery task.

    Polled by the new-chat frontend until the task finishes.

    Returns a dict with:
    - status: "processing" | "success" | "error"
    - state: the Celery task state (only while status == "processing")
    - podcast_id, title, transcript_entries: (only if status == "success")
    - error: (only if status == "error")

    Raises:
        HTTPException: 500 if inspecting the task raises any exception.
    """
    try:
        from celery.result import AsyncResult

        from app.celery_app import celery_app

        async_result = AsyncResult(task_id, app=celery_app)

        # Not finished yet: report the current Celery state.
        if not async_result.ready():
            return {
                "status": "processing",
                "state": async_result.state,
            }

        # Finished, but not successfully.
        if not async_result.successful():
            if async_result.result:
                failure_text = str(async_result.result)
            else:
                failure_text = "Task failed"
            return {
                "status": "error",
                "error": failure_text,
            }

        payload = async_result.result

        # The task is expected to return a dict; anything else is an error.
        if not isinstance(payload, dict):
            return {
                "status": "error",
                "error": "Unexpected task result format",
            }

        # The task itself can finish normally yet report a failure payload.
        if payload.get("status") != "success":
            return {
                "status": "error",
                "error": payload.get("error", "Unknown error"),
            }

        return {
            "status": "success",
            "podcast_id": payload.get("podcast_id"),
            "title": payload.get("title"),
            "transcript_entries": payload.get("transcript_entries"),
        }

    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error checking task status: {e!s}"
        ) from e
|
||||
|
|
|
|||
|
|
@ -11,6 +11,11 @@ from app.celery_app import celery_app
|
|||
from app.config import config
|
||||
from app.tasks.podcast_tasks import generate_chat_podcast
|
||||
|
||||
# Import for content-based podcast (new-chat)
|
||||
from app.agents.podcaster.graph import graph as podcaster_graph
|
||||
from app.agents.podcaster.state import State as PodcasterState
|
||||
from app.db import Podcast
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if sys.platform.startswith("win"):
|
||||
|
|
@ -86,3 +91,149 @@ async def _generate_chat_podcast(
|
|||
except Exception as e:
|
||||
logger.error(f"Error generating podcast from chat: {e!s}")
|
||||
raise
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Content-based podcast generation (for new-chat)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _clear_active_podcast_redis_key(search_space_id: int) -> None:
    """Clear the active podcast task key from Redis when task completes.

    A short-lived Redis client is built from ``CELERY_BROKER_URL`` (falling
    back to ``redis://localhost:6379/0``), the ``podcast:active:<id>`` key is
    deleted, and the client is closed again so its connections are not left
    open after every task. All failures are logged as warnings and never
    raised, because this runs in the task's cleanup path.

    Args:
        search_space_id: ID of the search space whose marker should be removed.
    """
    import os

    import redis

    client = None
    try:
        redis_url = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0")
        client = redis.from_url(redis_url, decode_responses=True)
        key = f"podcast:active:{search_space_id}"
        client.delete(key)
        logger.info(f"Cleared active podcast key for search_space_id={search_space_id}")
    except Exception as e:
        logger.warning(f"Could not clear active podcast key: {e}")
    finally:
        # This client is created per call, so release its connections explicitly
        # instead of waiting for garbage collection.
        if client is not None:
            try:
                client.close()
            except Exception as e:
                logger.warning(f"Could not close Redis client: {e}")
|
||||
|
||||
|
||||
@celery_app.task(name="generate_content_podcast", bind=True)
def generate_content_podcast_task(
    self,
    source_content: str,
    search_space_id: int,
    user_id: str,
    podcast_title: str = "SurfSense Podcast",
    user_prompt: str | None = None,
) -> dict:
    """
    Celery task to generate podcast from source content (for new-chat).

    Unlike generate_chat_podcast which requires a chat_id, this task
    generates a podcast directly from provided content.

    The coroutine runs on a private event loop created for this call.
    Exceptions are not propagated to Celery: they are logged and converted
    into an error dict. The active-podcast Redis key is always cleared and
    the loop is always shut down, whether generation succeeded or failed.

    Args:
        source_content: The text content to convert into a podcast
        search_space_id: ID of the search space
        user_id: ID of the user (as string)
        podcast_title: Title for the podcast
        user_prompt: Optional instructions for podcast style/tone

    Returns:
        dict with podcast_id on success, or error info on failure
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    try:
        return loop.run_until_complete(
            _generate_content_podcast(
                source_content,
                search_space_id,
                user_id,
                podcast_title,
                user_prompt,
            )
        )
    except Exception as e:
        logger.error(f"Error generating content podcast: {e!s}")
        return {"status": "error", "error": str(e)}
    finally:
        # Always clear the active podcast key when task completes (success or failure)
        _clear_active_podcast_redis_key(search_space_id)
        try:
            # Finalize async generators on the failure path as well, before the
            # loop is closed; a problem here must not change the task result.
            loop.run_until_complete(loop.shutdown_asyncgens())
        except Exception as e:
            logger.warning(f"Error shutting down async generators: {e!s}")
        finally:
            asyncio.set_event_loop(None)
            loop.close()
|
||||
|
||||
|
||||
async def _generate_content_podcast(
    source_content: str,
    search_space_id: int,
    user_id: str,
    podcast_title: str = "SurfSense Podcast",
    user_prompt: str | None = None,
) -> dict:
    """Run the podcaster graph on ``source_content`` and persist the result.

    A fresh session is opened from the Celery session maker, the podcaster
    graph is invoked, its transcript is converted to plain dicts, and a
    ``Podcast`` row (with no chat association) is committed.

    Returns:
        A dict with ``status`` "success", the new ``podcast_id``, the
        ``title`` and the number of ``transcript_entries``.

    Raises:
        Exception: any error is logged, the session is rolled back, and the
            original exception is re-raised.
    """
    session_factory = get_celery_session_maker()
    async with session_factory() as session:
        try:
            # Per-run settings handed to the podcaster graph.
            run_config = {
                "configurable": {
                    "podcast_title": podcast_title,
                    "user_id": str(user_id),
                    "search_space_id": search_space_id,
                    "user_prompt": user_prompt,
                }
            }

            # Seed the graph with the content and the session it should use.
            starting_state = PodcasterState(
                source_content=source_content,
                db_session=session,
            )

            graph_output = await podcaster_graph.ainvoke(starting_state, config=run_config)

            raw_transcript = graph_output.get("podcast_transcript", [])
            audio_path = graph_output.get("final_podcast_file_path", "")

            # Entries may be objects exposing attributes or mappings; either
            # way, reduce each one to a plain dict for storage.
            transcript_rows = [
                {"speaker_id": turn.speaker_id, "dialog": turn.dialog}
                if hasattr(turn, "speaker_id")
                else {
                    "speaker_id": turn.get("speaker_id", 0),
                    "dialog": turn.get("dialog", ""),
                }
                for turn in raw_transcript
            ]

            podcast = Podcast(
                title=podcast_title,
                podcast_transcript=transcript_rows,
                file_location=audio_path,
                search_space_id=search_space_id,
                chat_id=None,  # new-chat podcasts are not tied to a chat
                chat_state_version=None,
            )
            session.add(podcast)
            await session.commit()
            await session.refresh(podcast)

            logger.info(f"Successfully generated content podcast: {podcast.id}")

            return {
                "status": "success",
                "podcast_id": podcast.id,
                "title": podcast_title,
                "transcript_entries": len(transcript_rows),
            }

        except Exception as e:
            logger.error(f"Error in _generate_content_podcast: {e!s}")
            await session.rollback()
            raise
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ This module streams responses from the deep agent using the Vercel AI SDK
|
|||
Data Stream Protocol (SSE format).
|
||||
"""
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from uuid import UUID
|
||||
|
||||
|
|
@ -78,13 +79,15 @@ async def stream_new_chat(
|
|||
# Get the PostgreSQL checkpointer for persistent conversation memory
|
||||
checkpointer = await get_checkpointer()
|
||||
|
||||
# Create the deep agent with checkpointer
|
||||
# Create the deep agent with checkpointer with podcast capability
|
||||
agent = create_surfsense_deep_agent(
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
db_session=session,
|
||||
connector_service=connector_service,
|
||||
checkpointer=checkpointer,
|
||||
user_id=str(user_id),
|
||||
enable_podcast=True,
|
||||
)
|
||||
|
||||
# Build input with message history from frontend
|
||||
|
|
@ -182,22 +185,72 @@ async def stream_new_chat(
|
|||
f"Searching knowledge base: {query[:100]}{'...' if len(query) > 100 else ''}",
|
||||
"info",
|
||||
)
|
||||
elif tool_name == "generate_podcast":
|
||||
title = (
|
||||
tool_input.get("podcast_title", "SurfSense Podcast")
|
||||
if isinstance(tool_input, dict)
|
||||
else "SurfSense Podcast"
|
||||
)
|
||||
yield streaming_service.format_terminal_info(
|
||||
f"Generating podcast: {title}",
|
||||
"info",
|
||||
)
|
||||
|
||||
elif event_type == "on_tool_end":
|
||||
run_id = event.get("run_id", "")
|
||||
tool_output = event.get("data", {}).get("output", "")
|
||||
tool_name = event.get("name", "unknown_tool")
|
||||
raw_output = event.get("data", {}).get("output", "")
|
||||
|
||||
# Extract content from ToolMessage if needed
|
||||
# LangGraph may return a ToolMessage object instead of raw dict
|
||||
if hasattr(raw_output, "content"):
|
||||
# It's a ToolMessage object - extract the content
|
||||
content = raw_output.content
|
||||
# If content is a string that looks like JSON, try to parse it
|
||||
if isinstance(content, str):
|
||||
try:
|
||||
tool_output = json.loads(content)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
tool_output = {"result": content}
|
||||
elif isinstance(content, dict):
|
||||
tool_output = content
|
||||
else:
|
||||
tool_output = {"result": str(content)}
|
||||
elif isinstance(raw_output, dict):
|
||||
tool_output = raw_output
|
||||
else:
|
||||
tool_output = {"result": str(raw_output) if raw_output else "completed"}
|
||||
|
||||
tool_call_id = f"call_{run_id[:32]}" if run_id else "call_unknown"
|
||||
|
||||
# Don't stream the full output (can be very large), just acknowledge
|
||||
yield streaming_service.format_tool_output_available(
|
||||
tool_call_id,
|
||||
{"status": "completed", "result_length": len(str(tool_output))},
|
||||
)
|
||||
|
||||
yield streaming_service.format_terminal_info(
|
||||
"Knowledge base search completed", "success"
|
||||
)
|
||||
# Handle different tool outputs
|
||||
if tool_name == "generate_podcast":
|
||||
# Stream the full podcast result so frontend can render the audio player
|
||||
yield streaming_service.format_tool_output_available(
|
||||
tool_call_id,
|
||||
tool_output if isinstance(tool_output, dict) else {"result": tool_output},
|
||||
)
|
||||
# Send appropriate terminal message based on status
|
||||
if isinstance(tool_output, dict) and tool_output.get("status") == "success":
|
||||
yield streaming_service.format_terminal_info(
|
||||
f"Podcast generated successfully: {tool_output.get('title', 'Podcast')}",
|
||||
"success",
|
||||
)
|
||||
else:
|
||||
error_msg = tool_output.get("error", "Unknown error") if isinstance(tool_output, dict) else "Unknown error"
|
||||
yield streaming_service.format_terminal_info(
|
||||
f"Podcast generation failed: {error_msg}",
|
||||
"error",
|
||||
)
|
||||
else:
|
||||
# Don't stream the full output for other tools (can be very large), just acknowledge
|
||||
yield streaming_service.format_tool_output_available(
|
||||
tool_call_id,
|
||||
{"status": "completed", "result_length": len(str(tool_output))},
|
||||
)
|
||||
yield streaming_service.format_terminal_info(
|
||||
"Knowledge base search completed", "success"
|
||||
)
|
||||
|
||||
# Handle chain/agent end to close any open text blocks
|
||||
elif event_type in ("on_chain_end", "on_agent_end"):
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import { AssistantRuntimeProvider, useLocalRuntime } from "@assistant-ui/react";
|
|||
import { useParams } from "next/navigation";
|
||||
import { useMemo } from "react";
|
||||
import { Thread } from "@/components/assistant-ui/thread";
|
||||
import { GeneratePodcastToolUI } from "@/components/tool-ui/generate-podcast";
|
||||
import { createNewChatAdapter } from "@/lib/chat/new-chat-transport";
|
||||
|
||||
export default function NewChatPage() {
|
||||
|
|
@ -38,6 +39,8 @@ export default function NewChatPage() {
|
|||
|
||||
return (
|
||||
<AssistantRuntimeProvider runtime={runtime}>
|
||||
{/* Register tool UI components */}
|
||||
<GeneratePodcastToolUI />
|
||||
<div className="h-[calc(100vh-64px)] max-h-[calc(100vh-64px)] overflow-hidden">
|
||||
<Thread />
|
||||
</div>
|
||||
|
|
|
|||
310
surfsense_web/components/tool-ui/audio.tsx
Normal file
310
surfsense_web/components/tool-ui/audio.tsx
Normal file
|
|
@ -0,0 +1,310 @@
|
|||
"use client";
|
||||
|
||||
import { DownloadIcon, PauseIcon, PlayIcon, Volume2Icon, VolumeXIcon } from "lucide-react";
|
||||
import Image from "next/image";
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Slider } from "@/components/ui/slider";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
interface AudioProps {
|
||||
id: string;
|
||||
assetId?: string;
|
||||
src: string;
|
||||
title: string;
|
||||
description?: string;
|
||||
artwork?: string;
|
||||
durationMs?: number;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
/**
 * Format a duration in seconds as `m:ss`.
 *
 * Non-finite (NaN, ±Infinity) and negative inputs render as "0:00".
 * Minutes are not capped, so 3700 seconds yields "61:40".
 */
function formatTime(seconds: number): string {
	const isDisplayable = Number.isFinite(seconds) && seconds >= 0;
	if (!isDisplayable) {
		return "0:00";
	}

	const wholeMinutes = Math.floor(seconds / 60);
	const leftoverSeconds = Math.floor(seconds % 60);
	// Seconds are always shown as two digits.
	const paddedSeconds = leftoverSeconds.toString().padStart(2, "0");
	return `${wholeMinutes}:${paddedSeconds}`;
}
|
||||
|
||||
export function Audio({
|
||||
id,
|
||||
src,
|
||||
title,
|
||||
description,
|
||||
artwork,
|
||||
durationMs,
|
||||
className,
|
||||
}: AudioProps) {
|
||||
const audioRef = useRef<HTMLAudioElement>(null);
|
||||
const [isPlaying, setIsPlaying] = useState(false);
|
||||
const [currentTime, setCurrentTime] = useState(0);
|
||||
const [duration, setDuration] = useState(durationMs ? durationMs / 1000 : 0);
|
||||
const [volume, setVolume] = useState(1);
|
||||
const [isMuted, setIsMuted] = useState(false);
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
// Handle play/pause
|
||||
const togglePlayPause = useCallback(() => {
|
||||
const audio = audioRef.current;
|
||||
if (!audio) return;
|
||||
|
||||
if (isPlaying) {
|
||||
audio.pause();
|
||||
} else {
|
||||
audio.play().catch((err) => {
|
||||
console.error("Error playing audio:", err);
|
||||
setError("Failed to play audio");
|
||||
});
|
||||
}
|
||||
}, [isPlaying]);
|
||||
|
||||
// Handle seek
|
||||
const handleSeek = useCallback((value: number[]) => {
|
||||
const audio = audioRef.current;
|
||||
if (!audio || !Number.isFinite(value[0])) return;
|
||||
audio.currentTime = value[0];
|
||||
setCurrentTime(value[0]);
|
||||
}, []);
|
||||
|
||||
// Handle volume change
|
||||
const handleVolumeChange = useCallback((value: number[]) => {
|
||||
const audio = audioRef.current;
|
||||
if (!audio || !Number.isFinite(value[0])) return;
|
||||
const newVolume = value[0];
|
||||
audio.volume = newVolume;
|
||||
setVolume(newVolume);
|
||||
setIsMuted(newVolume === 0);
|
||||
}, []);
|
||||
|
||||
// Toggle mute
|
||||
const toggleMute = useCallback(() => {
|
||||
const audio = audioRef.current;
|
||||
if (!audio) return;
|
||||
|
||||
if (isMuted) {
|
||||
audio.volume = volume || 1;
|
||||
setIsMuted(false);
|
||||
} else {
|
||||
audio.volume = 0;
|
||||
setIsMuted(true);
|
||||
}
|
||||
}, [isMuted, volume]);
|
||||
|
||||
// Handle download
|
||||
const handleDownload = useCallback(async () => {
|
||||
try {
|
||||
const response = await fetch(src);
|
||||
const blob = await response.blob();
|
||||
const url = window.URL.createObjectURL(blob);
|
||||
const a = document.createElement("a");
|
||||
a.href = url;
|
||||
a.download = `${title.replace(/[^a-zA-Z0-9]/g, "_")}.mp3`;
|
||||
document.body.appendChild(a);
|
||||
a.click();
|
||||
document.body.removeChild(a);
|
||||
window.URL.revokeObjectURL(url);
|
||||
} catch (err) {
|
||||
console.error("Error downloading audio:", err);
|
||||
}
|
||||
}, [src, title]);
|
||||
|
||||
// Set up audio event listeners
|
||||
useEffect(() => {
|
||||
const audio = audioRef.current;
|
||||
if (!audio) return;
|
||||
|
||||
const handleLoadedMetadata = () => {
|
||||
setDuration(audio.duration);
|
||||
setIsLoading(false);
|
||||
};
|
||||
|
||||
const handleTimeUpdate = () => {
|
||||
setCurrentTime(audio.currentTime);
|
||||
};
|
||||
|
||||
const handlePlay = () => setIsPlaying(true);
|
||||
const handlePause = () => setIsPlaying(false);
|
||||
const handleEnded = () => {
|
||||
setIsPlaying(false);
|
||||
setCurrentTime(0);
|
||||
};
|
||||
const handleError = () => {
|
||||
setError("Failed to load audio");
|
||||
setIsLoading(false);
|
||||
};
|
||||
const handleCanPlay = () => setIsLoading(false);
|
||||
|
||||
audio.addEventListener("loadedmetadata", handleLoadedMetadata);
|
||||
audio.addEventListener("timeupdate", handleTimeUpdate);
|
||||
audio.addEventListener("play", handlePlay);
|
||||
audio.addEventListener("pause", handlePause);
|
||||
audio.addEventListener("ended", handleEnded);
|
||||
audio.addEventListener("error", handleError);
|
||||
audio.addEventListener("canplay", handleCanPlay);
|
||||
|
||||
return () => {
|
||||
audio.removeEventListener("loadedmetadata", handleLoadedMetadata);
|
||||
audio.removeEventListener("timeupdate", handleTimeUpdate);
|
||||
audio.removeEventListener("play", handlePlay);
|
||||
audio.removeEventListener("pause", handlePause);
|
||||
audio.removeEventListener("ended", handleEnded);
|
||||
audio.removeEventListener("error", handleError);
|
||||
audio.removeEventListener("canplay", handleCanPlay);
|
||||
};
|
||||
}, []);
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"flex items-center gap-4 rounded-xl border border-destructive/20 bg-destructive/5 p-4",
|
||||
className,
|
||||
)}
|
||||
>
|
||||
<div className="flex size-16 items-center justify-center rounded-lg bg-destructive/10">
|
||||
<Volume2Icon className="size-8 text-destructive" />
|
||||
</div>
|
||||
<div className="flex-1">
|
||||
<p className="font-medium text-destructive">{title}</p>
|
||||
<p className="text-destructive/70 text-sm">{error}</p>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
id={id}
|
||||
className={cn(
|
||||
"group relative overflow-hidden rounded-xl border bg-gradient-to-br from-background to-muted/30 p-4 shadow-sm transition-all hover:shadow-md",
|
||||
className,
|
||||
)}
|
||||
>
|
||||
{/* Hidden audio element */}
|
||||
<audio ref={audioRef} src={src} preload="metadata">
|
||||
<track kind="captions" srcLang="en" label="English captions" default />
|
||||
</audio>
|
||||
|
||||
<div className="flex gap-4">
|
||||
{/* Artwork */}
|
||||
<div className="relative shrink-0">
|
||||
<div className="relative size-20 overflow-hidden rounded-lg bg-gradient-to-br from-primary/20 to-primary/5 shadow-inner">
|
||||
{artwork ? (
|
||||
<Image
|
||||
src={artwork}
|
||||
alt={title}
|
||||
fill
|
||||
className="object-cover"
|
||||
unoptimized
|
||||
/>
|
||||
) : (
|
||||
<div className="flex size-full items-center justify-center">
|
||||
<Volume2Icon className="size-8 text-primary/50" />
|
||||
</div>
|
||||
)}
|
||||
{/* Play overlay on artwork */}
|
||||
<button
|
||||
type="button"
|
||||
onClick={togglePlayPause}
|
||||
className="absolute inset-0 flex items-center justify-center bg-black/0 opacity-0 transition-all group-hover:bg-black/30 group-hover:opacity-100"
|
||||
aria-label={isPlaying ? "Pause" : "Play"}
|
||||
>
|
||||
{isPlaying ? (
|
||||
<PauseIcon className="size-8 text-white drop-shadow-lg" />
|
||||
) : (
|
||||
<PlayIcon className="size-8 text-white drop-shadow-lg" />
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Content */}
|
||||
<div className="flex min-w-0 flex-1 flex-col justify-between">
|
||||
{/* Title and description */}
|
||||
<div className="min-w-0">
|
||||
<h3 className="truncate font-semibold text-foreground">{title}</h3>
|
||||
{description && (
|
||||
<p className="mt-0.5 line-clamp-1 text-muted-foreground text-sm">
|
||||
{description}
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Progress bar */}
|
||||
<div className="mt-2 space-y-1">
|
||||
<Slider
|
||||
value={[currentTime]}
|
||||
max={duration || 100}
|
||||
step={0.1}
|
||||
onValueChange={handleSeek}
|
||||
className="cursor-pointer"
|
||||
disabled={isLoading}
|
||||
/>
|
||||
<div className="flex justify-between text-muted-foreground text-xs">
|
||||
<span>{formatTime(currentTime)}</span>
|
||||
<span>{formatTime(duration)}</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Controls */}
|
||||
<div className="mt-3 flex items-center justify-between border-t pt-3">
|
||||
<div className="flex items-center gap-2">
|
||||
{/* Play/Pause button */}
|
||||
<Button
|
||||
variant="default"
|
||||
size="sm"
|
||||
onClick={togglePlayPause}
|
||||
disabled={isLoading}
|
||||
className="gap-2"
|
||||
>
|
||||
{isLoading ? (
|
||||
<div className="size-4 animate-spin rounded-full border-2 border-current border-t-transparent" />
|
||||
) : isPlaying ? (
|
||||
<PauseIcon className="size-4" />
|
||||
) : (
|
||||
<PlayIcon className="size-4" />
|
||||
)}
|
||||
{isPlaying ? "Pause" : "Play"}
|
||||
</Button>
|
||||
|
||||
{/* Volume control */}
|
||||
<div className="flex items-center gap-2">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
onClick={toggleMute}
|
||||
className="size-8"
|
||||
>
|
||||
{isMuted ? (
|
||||
<VolumeXIcon className="size-4" />
|
||||
) : (
|
||||
<Volume2Icon className="size-4" />
|
||||
)}
|
||||
</Button>
|
||||
<Slider
|
||||
value={[isMuted ? 0 : volume]}
|
||||
max={1}
|
||||
step={0.01}
|
||||
onValueChange={handleVolumeChange}
|
||||
className="w-20"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Download button */}
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={handleDownload}
|
||||
className="gap-2"
|
||||
>
|
||||
<DownloadIcon className="size-4" />
|
||||
Download
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
427
surfsense_web/components/tool-ui/generate-podcast.tsx
Normal file
427
surfsense_web/components/tool-ui/generate-podcast.tsx
Normal file
|
|
@ -0,0 +1,427 @@
|
|||
"use client";
|
||||
|
||||
import { makeAssistantToolUI } from "@assistant-ui/react";
|
||||
import { AlertCircleIcon, Loader2Icon, MicIcon } from "lucide-react";
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
import { Audio } from "@/components/tool-ui/audio";
|
||||
import { baseApiService } from "@/lib/apis/base-api.service";
|
||||
import { podcastsApiService } from "@/lib/apis/podcasts-api.service";
|
||||
import {
|
||||
clearActivePodcastTaskId,
|
||||
setActivePodcastTaskId,
|
||||
} from "@/lib/chat/podcast-state";
|
||||
import type { PodcastTranscriptEntry } from "@/contracts/types/podcast.types";
|
||||
|
||||
/**
 * Arguments of the generate_podcast tool call as received by the tool UI.
 */
interface GeneratePodcastArgs {
	// NOTE(review): presumably the text the podcast is generated from; not read by this file
	source_content: string;
	// Display title; the tool UI falls back to "SurfSense Podcast" when it is missing
	podcast_title?: string;
	// NOTE(review): presumably extra user instructions for generation; not read by this file
	user_prompt?: string;
}
|
||||
|
||||
/**
 * Result payload of the generate_podcast tool call.
 * The tool UI picks what to render from `status`.
 */
interface GeneratePodcastResult {
	// "processing" starts polling, "already_generating" shows a warning,
	// "success" renders the player and "error" renders the error state
	status: "processing" | "already_generating" | "success" | "error";
	// Required together with status "processing" to start polling the task
	task_id?: string;
	// Required together with status "success" to load the podcast player
	podcast_id?: number;
	// Overrides the title taken from the tool arguments when present
	title?: string;
	// Shown as "<n> dialogue entries" in the player description when truthy
	transcript_entries?: number;
	// Not read by this component
	message?: string;
	// Error text displayed when status is "error"
	error?: string;
}
|
||||
|
||||
/**
 * Response of the task status endpoint polled by PodcastTaskPoller
 * (`/api/v1/podcasts/task/{taskId}/status`).
 */
interface TaskStatusResponse {
	// Polling continues while this is "processing"
	status: "processing" | "success" | "error";
	// Needed to render the player once status is "success"
	podcast_id?: number;
	// Overrides the title passed to the poller when present
	title?: string;
	// Shown as "<n> dialogue entries" in the player description when truthy
	transcript_entries?: number;
	// Not read by this component — NOTE(review): presumably the raw backend task state, confirm
	state?: string;
	// Error text displayed when status is "error"
	error?: string;
}
|
||||
|
||||
/**
 * Placeholder card rendered while a podcast is still being generated.
 * Shows the title, a pulsing microphone badge, a spinner line and a static progress strip.
 */
function PodcastGeneratingState(props: { title: string }) {
	const { title } = props;

	// Microphone badge with a pinging ring behind it
	const badge = (
		<div className="relative">
			<div className="flex size-16 items-center justify-center rounded-full bg-primary/20">
				<MicIcon className="size-8 text-primary" />
			</div>
			{/* Animated rings */}
			<div className="absolute inset-1 animate-ping rounded-full bg-primary/20" />
		</div>
	);

	// Decorative progress strip; it does not reflect real progress
	const progressStrip = (
		<div className="mt-3">
			<div className="h-1.5 w-full overflow-hidden rounded-full bg-primary/10">
				<div className="h-full w-1/3 animate-pulse rounded-full bg-primary" />
			</div>
		</div>
	);

	return (
		<div className="my-4 overflow-hidden rounded-xl border border-primary/20 bg-gradient-to-br from-primary/5 to-primary/10 p-6">
			<div className="flex items-center gap-4">
				{badge}
				<div className="flex-1">
					<h3 className="font-semibold text-foreground text-lg">{title}</h3>
					<div className="mt-2 flex items-center gap-2 text-muted-foreground">
						<Loader2Icon className="size-4 animate-spin" />
						<span className="text-sm">Generating podcast... This may take a few minutes</span>
					</div>
					{progressStrip}
				</div>
			</div>
		</div>
	);
}
|
||||
|
||||
/**
 * Error card rendered when generating or loading a podcast failed.
 * Shows the title, a fixed failure headline and the supplied error text.
 */
function PodcastErrorState(props: { title: string; error: string }) {
	const { title, error } = props;

	// Alert badge shown to the left of the text
	const badge = (
		<div className="flex size-16 shrink-0 items-center justify-center rounded-full bg-destructive/10">
			<AlertCircleIcon className="size-8 text-destructive" />
		</div>
	);

	return (
		<div className="my-4 overflow-hidden rounded-xl border border-destructive/20 bg-destructive/5 p-6">
			<div className="flex items-center gap-4">
				{badge}
				<div className="flex-1">
					<h3 className="font-semibold text-foreground">{title}</h3>
					<p className="mt-1 text-destructive text-sm">Failed to generate podcast</p>
					<p className="mt-2 text-muted-foreground text-sm">{error}</p>
				</div>
			</div>
		</div>
	);
}
|
||||
|
||||
/**
 * Placeholder card rendered while the audio of a finished podcast is being fetched.
 */
function AudioLoadingState(props: { title: string }) {
	const { title } = props;

	// Dimmed microphone badge shown to the left of the text
	const badge = (
		<div className="flex size-16 items-center justify-center rounded-full bg-primary/10">
			<MicIcon className="size-8 text-primary/50" />
		</div>
	);

	return (
		<div className="my-4 overflow-hidden rounded-xl border bg-muted/30 p-6">
			<div className="flex items-center gap-4">
				{badge}
				<div className="flex-1">
					<h3 className="font-semibold text-foreground">{title}</h3>
					<div className="mt-2 flex items-center gap-2 text-muted-foreground">
						<Loader2Icon className="size-4 animate-spin" />
						<span className="text-sm">Loading audio...</span>
					</div>
				</div>
			</div>
		</div>
	);
}
|
||||
|
||||
/**
 * Podcast Player Component - fetches the audio blob and the podcast details
 * (including the transcript) for `podcastId` and renders the audio player with
 * a collapsible transcript.
 *
 * The audio is exposed to the player through an object URL, which is revoked
 * when it is replaced and when the component unmounts. A load that is still in
 * flight when the component unmounts (or `podcastId` changes) is aborted and its
 * result is discarded, so no object URL is created and no state is set for it.
 *
 * @param podcastId   ID of the podcast to load
 * @param title       Title shown in the player and in loading/error states
 * @param description Description passed through to the audio player
 * @param durationMs  Optional duration passed through to the audio player
 */
function PodcastPlayer({
	podcastId,
	title,
	description,
	durationMs,
}: {
	podcastId: number;
	title: string;
	description: string;
	durationMs?: number;
}) {
	const [audioSrc, setAudioSrc] = useState<string | null>(null);
	const [transcript, setTranscript] = useState<PodcastTranscriptEntry[] | null>(null);
	const [isLoading, setIsLoading] = useState(true);
	const [error, setError] = useState<string | null>(null);
	const objectUrlRef = useRef<string | null>(null);

	// Cleanup object URL on unmount
	useEffect(() => {
		return () => {
			if (objectUrlRef.current) {
				URL.revokeObjectURL(objectUrlRef.current);
				objectUrlRef.current = null;
			}
		};
	}, []);

	// Fetch audio and podcast details (including transcript).
	// `isStale` reports whether the effect that started this load has been cleaned up;
	// once it has, the result is dropped without touching state or creating an object URL.
	const loadPodcast = useCallback(
		async (controller: AbortController, isStale: () => boolean) => {
			setIsLoading(true);
			setError(null);

			// Revoke previous object URL if exists
			if (objectUrlRef.current) {
				URL.revokeObjectURL(objectUrlRef.current);
				objectUrlRef.current = null;
			}

			const timeoutId = setTimeout(() => controller.abort(), 60000); // 60s timeout

			try {
				// Fetch audio blob and podcast details in parallel
				const [audioBlob, podcastDetails] = await Promise.all([
					podcastsApiService.loadPodcast({
						request: { id: podcastId },
						controller,
					}),
					podcastsApiService.getPodcastById(podcastId),
				]);

				if (isStale()) return;

				// Create object URL from blob
				const objectUrl = URL.createObjectURL(audioBlob);
				objectUrlRef.current = objectUrl;
				setAudioSrc(objectUrl);

				// Set transcript from podcast details; reset it when there is none so a
				// transcript from a previously loaded podcast is never shown
				setTranscript(podcastDetails?.podcast_transcript ?? null);
			} catch (err) {
				// An abort caused by unmount / id change is not an error worth reporting
				if (isStale()) return;

				console.error("Error loading podcast:", err);
				if (err instanceof DOMException && err.name === "AbortError") {
					setError("Request timed out. Please try again.");
				} else {
					setError(err instanceof Error ? err.message : "Failed to load podcast");
				}
			} finally {
				clearTimeout(timeoutId);
				if (!isStale()) {
					setIsLoading(false);
				}
			}
		},
		[podcastId]
	);

	// Load podcast when component mounts or the podcast ID changes
	useEffect(() => {
		let stale = false;
		const controller = new AbortController();

		loadPodcast(controller, () => stale);

		return () => {
			stale = true;
			controller.abort();
		};
	}, [loadPodcast]);

	if (isLoading) {
		return <AudioLoadingState title={title} />;
	}

	if (error || !audioSrc) {
		return <PodcastErrorState title={title} error={error || "Failed to load audio"} />;
	}

	return (
		<div className="my-4">
			<Audio
				id={`podcast-${podcastId}`}
				src={audioSrc}
				title={title}
				description={description}
				durationMs={durationMs}
				className="w-full"
			/>
			{/* Transcript section */}
			{transcript && transcript.length > 0 && (
				<details className="mt-3 rounded-lg border bg-muted/30 p-3">
					<summary className="cursor-pointer font-medium text-muted-foreground text-sm hover:text-foreground">
						View transcript ({transcript.length} entries)
					</summary>
					<div className="mt-3 space-y-3 max-h-96 overflow-y-auto">
						{transcript.map((entry, idx) => (
							<div key={`${idx}-${entry.speaker_id}`} className="text-sm">
								<span className="font-medium text-primary">
									Speaker {entry.speaker_id + 1}:
								</span>{" "}
								<span className="text-muted-foreground">{entry.dialog}</span>
							</div>
						))}
					</div>
				</details>
			)}
		</div>
	);
}
|
||||
|
||||
/**
 * Polling component that checks task status and shows player when complete.
 *
 * While mounted it registers `taskId` as the active podcast task. It polls
 * `/api/v1/podcasts/task/{taskId}/status` immediately and then every 5 seconds
 * until the status is no longer "processing". Polling errors are logged and
 * polling continues.
 *
 * Responses that arrive after the component unmounted or `taskId` changed are
 * ignored, so they can neither update state nor clear the active-task marker of
 * a newer task.
 *
 * @param taskId ID of the generation task to poll
 * @param title  Title shown until the status response provides its own
 */
function PodcastTaskPoller({
	taskId,
	title,
}: {
	taskId: string;
	title: string;
}) {
	const [taskStatus, setTaskStatus] = useState<TaskStatusResponse>({ status: "processing" });
	const pollingRef = useRef<NodeJS.Timeout | null>(null);

	// Set active podcast state when this component mounts
	useEffect(() => {
		setActivePodcastTaskId(taskId);

		// Clear when component unmounts.
		// NOTE: this clears the active task unconditionally; it does not check that
		// this task is still the active one.
		return () => {
			clearActivePodcastTaskId();
		};
	}, [taskId]);

	// Poll for task status
	useEffect(() => {
		// Flipped by the cleanup so in-flight polls of an old task are discarded
		let cancelled = false;

		const stopPolling = () => {
			if (pollingRef.current) {
				clearInterval(pollingRef.current);
				pollingRef.current = null;
			}
		};

		const pollStatus = async () => {
			try {
				const response = await baseApiService.get<TaskStatusResponse>(
					`/api/v1/podcasts/task/${taskId}/status`
				);
				if (cancelled) return;

				setTaskStatus(response);

				// Stop polling if task is complete or errored
				if (response.status !== "processing") {
					stopPolling();
					// Clear the active podcast state when task completes
					clearActivePodcastTaskId();
				}
			} catch (err) {
				if (cancelled) return;
				console.error("Error polling task status:", err);
				// Don't stop polling on network errors, continue polling
			}
		};

		// Initial poll
		pollStatus();

		// Poll every 5 seconds
		pollingRef.current = setInterval(pollStatus, 5000);

		return () => {
			cancelled = true;
			stopPolling();
		};
	}, [taskId]);

	// Show loading state while processing
	if (taskStatus.status === "processing") {
		return <PodcastGeneratingState title={title} />;
	}

	// Show error state
	if (taskStatus.status === "error") {
		return <PodcastErrorState title={title} error={taskStatus.error || "Generation failed"} />;
	}

	// Show player when complete
	if (taskStatus.status === "success" && taskStatus.podcast_id) {
		return (
			<PodcastPlayer
				podcastId={taskStatus.podcast_id}
				title={taskStatus.title || title}
				description={
					taskStatus.transcript_entries
						? `${taskStatus.transcript_entries} dialogue entries`
						: "SurfSense AI-generated podcast"
				}
			/>
		);
	}

	// Fallback
	return <PodcastErrorState title={title} error="Unexpected state" />;
}
|
||||
|
||||
/**
 * Generate Podcast Tool UI Component
 *
 * This component is registered with assistant-ui to render custom UI
 * when the generate_podcast tool is called by the agent.
 *
 * Rendering is decided in this order:
 *  1. tool call still running / awaiting action -> generating state
 *  2. tool call incomplete -> cancelled notice, error state, or (for any other
 *     reason without a result) an error state instead of an endless spinner
 *  3. no result yet -> generating state
 *  4. result status "error" -> error state
 *  5. result status "already_generating" -> warning, no second poller
 *  6. result status "processing" with a task ID -> poller that swaps in the player
 *  7. result status "success" with a podcast ID -> player
 *  8. anything else -> error state about missing IDs
 */
export const GeneratePodcastToolUI = makeAssistantToolUI<
	GeneratePodcastArgs,
	GeneratePodcastResult
>({
	toolName: "generate_podcast",
	render: function GeneratePodcastUI({ args, result, status }) {
		const title = args.podcast_title || "SurfSense Podcast";

		// Loading state - tool is still running (agent processing)
		if (status.type === "running" || status.type === "requires-action") {
			return <PodcastGeneratingState title={title} />;
		}

		// Incomplete/cancelled state
		if (status.type === "incomplete") {
			if (status.reason === "cancelled") {
				return (
					<div className="my-4 rounded-xl border border-muted p-4 text-muted-foreground">
						<p className="flex items-center gap-2">
							<MicIcon className="size-4" />
							<span className="line-through">Podcast generation cancelled</span>
						</p>
					</div>
				);
			}
			if (status.reason === "error") {
				return (
					<PodcastErrorState
						title={title}
						error={typeof status.error === "string" ? status.error : "An error occurred"}
					/>
				);
			}
			// Any other incomplete reason is terminal as well: without a result there is
			// nothing left to wait for, so do not fall through to the generating spinner.
			if (!result) {
				return <PodcastErrorState title={title} error="Podcast generation did not complete" />;
			}
		}

		// No result yet
		if (!result) {
			return <PodcastGeneratingState title={title} />;
		}

		// Error result
		if (result.status === "error") {
			return <PodcastErrorState title={title} error={result.error || "Unknown error"} />;
		}

		// Already generating - show simple warning, don't create another poller
		// The FIRST tool call will display the podcast when ready
		if (result.status === "already_generating") {
			return (
				<div className="my-4 overflow-hidden rounded-xl border border-amber-500/20 bg-amber-500/5 p-4">
					<div className="flex items-center gap-3">
						<div className="flex size-10 shrink-0 items-center justify-center rounded-full bg-amber-500/20">
							<MicIcon className="size-5 text-amber-500" />
						</div>
						<div>
							<p className="text-amber-600 dark:text-amber-400 text-sm font-medium">
								Podcast already in progress
							</p>
							<p className="text-muted-foreground text-xs mt-0.5">
								Please wait for the current podcast to complete.
							</p>
						</div>
					</div>
				</div>
			);
		}

		// Processing - poll for completion
		if (result.status === "processing" && result.task_id) {
			return <PodcastTaskPoller taskId={result.task_id} title={result.title || title} />;
		}

		// Success with podcast_id (direct result, not via polling)
		if (result.status === "success" && result.podcast_id) {
			return (
				<PodcastPlayer
					podcastId={result.podcast_id}
					title={result.title || title}
					description={
						result.transcript_entries
							? `${result.transcript_entries} dialogue entries`
							: "SurfSense AI-generated podcast"
					}
				/>
			);
		}

		// Fallback - missing required data
		return <PodcastErrorState title={title} error="Missing task ID or podcast ID" />;
	},
});
|
||||
11
surfsense_web/components/tool-ui/index.ts
Normal file
11
surfsense_web/components/tool-ui/index.ts
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
/**
|
||||
* Tool UI Components
|
||||
*
|
||||
* This module exports custom UI components for assistant tools.
|
||||
* These components are registered with assistant-ui to render
|
||||
* rich UI when specific tools are called by the agent.
|
||||
*/
|
||||
|
||||
export { Audio } from "./audio";
|
||||
export { GeneratePodcastToolUI } from "./generate-podcast";
|
||||
|
||||
|
|
@ -1,12 +1,17 @@
|
|||
import { z } from "zod";
|
||||
import { paginationQueryParams } from ".";
|
||||
|
||||
export const podcastTranscriptEntry = z.object({
|
||||
speaker_id: z.number(),
|
||||
dialog: z.string(),
|
||||
});
|
||||
|
||||
export const podcast = z.object({
|
||||
id: z.number(),
|
||||
title: z.string(),
|
||||
created_at: z.string(),
|
||||
file_location: z.string(),
|
||||
podcast_transcript: z.array(z.any()),
|
||||
podcast_transcript: z.array(podcastTranscriptEntry),
|
||||
search_space_id: z.number(),
|
||||
chat_state_version: z.number().nullable(),
|
||||
});
|
||||
|
|
@ -41,6 +46,7 @@ export const getPodcastsRequest = z.object({
|
|||
queryParams: paginationQueryParams.nullish(),
|
||||
});
|
||||
|
||||
export type PodcastTranscriptEntry = z.infer<typeof podcastTranscriptEntry>;
|
||||
export type GeneratePodcastRequest = z.infer<typeof generatePodcastRequest>;
|
||||
export type GetPodcastByChatIdRequest = z.infer<typeof getPodcastByChatIdRequest>;
|
||||
export type GetPodcastByChatIdResponse = z.infer<typeof getPodcastByChaIdResponse>;
|
||||
|
|
|
|||
|
|
@ -62,6 +62,13 @@ class PodcastsApiService {
|
|||
);
|
||||
};
|
||||
|
||||
/**
 * Fetch a single podcast by ID from `/api/v1/podcasts/{id}`.
 * The `podcast` schema (which includes the transcript) is passed to the base API service.
 */
getPodcastById = async (podcastId: number) => {
	const endpoint = `/api/v1/podcasts/${podcastId}`;
	return baseApiService.get(endpoint, podcast);
};
|
||||
|
||||
generatePodcast = async (request: GeneratePodcastRequest) => {
|
||||
// Validate the request
|
||||
const parsedRequest = generatePodcastRequest.safeParse(request);
|
||||
|
|
|
|||
|
|
@ -4,7 +4,13 @@
|
|||
*/
|
||||
|
||||
import type { ChatModelAdapter, ChatModelRunOptions } from "@assistant-ui/react";
|
||||
import { toast } from "sonner";
|
||||
import { getBearerToken } from "@/lib/auth-utils";
|
||||
import {
|
||||
isPodcastGenerating,
|
||||
looksLikePodcastRequest,
|
||||
setActivePodcastTaskId,
|
||||
} from "@/lib/chat/podcast-state";
|
||||
|
||||
interface NewChatAdapterConfig {
|
||||
searchSpaceId: number;
|
||||
|
|
@ -40,6 +46,22 @@ function convertMessagesToBackendFormat(
|
|||
.filter((m) => m.content.length > 0); // Filter out empty messages
|
||||
}
|
||||
|
||||
/**
 * Represents an in-progress or completed tool call.
 * Entries are created when a tool call starts streaming and filled in as
 * later stream events arrive.
 */
interface ToolCallState {
	// ID of the tool call; also used as the key in the adapter's tool call map
	toolCallId: string;
	// Name of the tool; decides whether the call is rendered in the UI
	toolName: string;
	// Tool input; starts as {} and is replaced once the input is available
	args: Record<string, unknown>;
	// Tool output; stays undefined until the output event arrives
	result?: unknown;
}
|
||||
|
||||
/**
|
||||
* Tools that should render custom UI in the chat.
|
||||
* Other tools (like search_knowledge_base) will be hidden from the UI.
|
||||
*/
|
||||
const TOOLS_WITH_UI = new Set(["generate_podcast"]);
|
||||
|
||||
/**
|
||||
* Creates a ChatModelAdapter that connects to the FastAPI new_chat endpoint.
|
||||
*
|
||||
|
|
@ -72,6 +94,21 @@ export function createNewChatAdapter(config: NewChatAdapterConfig): ChatModelAda
|
|||
throw new Error("User query cannot be empty");
|
||||
}
|
||||
|
||||
// Check if user is requesting a podcast while one is already generating
|
||||
if (isPodcastGenerating() && looksLikePodcastRequest(userQuery)) {
|
||||
toast.warning("A podcast is already being generated. Please wait for it to complete.");
|
||||
// Return a message telling the user to wait
|
||||
yield {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: "A podcast is already being generated. Please wait for it to complete before requesting another one.",
|
||||
},
|
||||
],
|
||||
};
|
||||
return;
|
||||
}
|
||||
|
||||
const token = getBearerToken();
|
||||
if (!token) {
|
||||
throw new Error("Not authenticated. Please log in again.");
|
||||
|
|
@ -110,6 +147,41 @@ export function createNewChatAdapter(config: NewChatAdapterConfig): ChatModelAda
|
|||
let buffer = "";
|
||||
let accumulatedText = "";
|
||||
|
||||
// Track tool calls by their ID
|
||||
const toolCalls = new Map<string, ToolCallState>();
|
||||
|
||||
/**
 * Assemble the message content from the current stream state.
 * The accumulated text (if non-empty) comes first, followed by every tracked
 * tool call whose tool name is listed in TOOLS_WITH_UI, in map order.
 * Tool calls without custom UI (like search_knowledge_base) are left out.
 */
function buildContent() {
	type TextPart = { type: "text"; text: string };
	type ToolCallPart = {
		type: "tool-call";
		toolCallId: string;
		toolName: string;
		args: Record<string, unknown>;
		result?: unknown;
	};

	const parts: Array<TextPart | ToolCallPart> = [];

	if (accumulatedText) {
		parts.push({ type: "text" as const, text: accumulatedText });
	}

	toolCalls.forEach((call) => {
		// Skip tools that have no custom UI registered
		if (!TOOLS_WITH_UI.has(call.toolName)) {
			return;
		}
		const { toolCallId, toolName, args, result } = call;
		parts.push({ type: "tool-call" as const, toolCallId, toolName, args, result });
	});

	return parts;
}
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
|
|
@ -146,16 +218,70 @@ export function createNewChatAdapter(config: NewChatAdapterConfig): ChatModelAda
|
|||
switch (parsed.type) {
|
||||
case "text-delta":
|
||||
accumulatedText += parsed.delta;
|
||||
yield {
|
||||
content: [{ type: "text" as const, text: accumulatedText }],
|
||||
};
|
||||
yield { content: buildContent() };
|
||||
break;
|
||||
|
||||
case "tool-input-start": {
|
||||
// Tool call is starting - create a new tool call entry
|
||||
const { toolCallId, toolName } = parsed;
|
||||
toolCalls.set(toolCallId, {
|
||||
toolCallId,
|
||||
toolName,
|
||||
args: {},
|
||||
});
|
||||
// Yield to show tool is starting (running state)
|
||||
yield { content: buildContent() };
|
||||
break;
|
||||
}
|
||||
|
||||
case "tool-input-available": {
|
||||
// Tool input is complete - update the args
|
||||
const { toolCallId, toolName, input } = parsed;
|
||||
const existing = toolCalls.get(toolCallId);
|
||||
if (existing) {
|
||||
existing.args = input || {};
|
||||
} else {
|
||||
// Create new entry if we missed tool-input-start
|
||||
toolCalls.set(toolCallId, {
|
||||
toolCallId,
|
||||
toolName,
|
||||
args: input || {},
|
||||
});
|
||||
}
|
||||
yield { content: buildContent() };
|
||||
break;
|
||||
}
|
||||
|
||||
case "tool-output-available": {
|
||||
// Tool execution is complete - add the result
|
||||
const { toolCallId, output } = parsed;
|
||||
const existing = toolCalls.get(toolCallId);
|
||||
if (existing) {
|
||||
existing.result = output;
|
||||
|
||||
// If this is a podcast tool with status="processing", set the state immediately
|
||||
// This ensures subsequent podcast requests are intercepted
|
||||
if (
|
||||
existing.toolName === "generate_podcast" &&
|
||||
output &&
|
||||
typeof output === "object" &&
|
||||
"status" in output &&
|
||||
output.status === "processing" &&
|
||||
"task_id" in output &&
|
||||
typeof output.task_id === "string"
|
||||
) {
|
||||
setActivePodcastTaskId(output.task_id);
|
||||
}
|
||||
}
|
||||
yield { content: buildContent() };
|
||||
break;
|
||||
}
|
||||
|
||||
case "error":
|
||||
throw new Error(parsed.errorText || "Unknown error from server");
|
||||
|
||||
// Other types like text-start, text-end, tool-*, etc.
|
||||
// are handled implicitly - we just accumulate text deltas
|
||||
// Other types like text-start, text-end, start-step, finish-step, etc.
|
||||
// are handled implicitly
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
|
@ -181,9 +307,27 @@ export function createNewChatAdapter(config: NewChatAdapterConfig): ChatModelAda
|
|||
const parsed = JSON.parse(data);
|
||||
if (parsed.type === "text-delta") {
|
||||
accumulatedText += parsed.delta;
|
||||
yield {
|
||||
content: [{ type: "text" as const, text: accumulatedText }],
|
||||
};
|
||||
yield { content: buildContent() };
|
||||
} else if (parsed.type === "tool-output-available") {
|
||||
const { toolCallId, output } = parsed;
|
||||
const existing = toolCalls.get(toolCallId);
|
||||
if (existing) {
|
||||
existing.result = output;
|
||||
|
||||
// Set podcast state if processing
|
||||
if (
|
||||
existing.toolName === "generate_podcast" &&
|
||||
output &&
|
||||
typeof output === "object" &&
|
||||
"status" in output &&
|
||||
output.status === "processing" &&
|
||||
"task_id" in output &&
|
||||
typeof output.task_id === "string"
|
||||
) {
|
||||
setActivePodcastTaskId(output.task_id);
|
||||
}
|
||||
}
|
||||
yield { content: buildContent() };
|
||||
}
|
||||
} catch {
|
||||
// Ignore parse errors
|
||||
|
|
|
|||
74
surfsense_web/lib/chat/podcast-state.ts
Normal file
74
surfsense_web/lib/chat/podcast-state.ts
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
/**
|
||||
* Module-level state for tracking active podcast generation.
|
||||
* Used by the new-chat adapter to prevent duplicate podcast requests.
|
||||
*/
|
||||
|
||||
// Callback invoked with the current "is a podcast generating" flag on every state change
type PodcastStateListener = (isGenerating: boolean) => void;

// ID of the podcast task currently being generated, or null when none is active
let _activePodcastTaskId: string | null = null;
// Listeners registered through subscribeToPodcastState
const _listeners: Set<PodcastStateListener> = new Set();
|
||||
|
||||
/**
 * Report whether a podcast task is currently registered as active.
 */
export function isPodcastGenerating(): boolean {
	const hasActiveTask = _activePodcastTaskId !== null;
	return hasActiveTask;
}
|
||||
|
||||
/**
 * Return the ID of the active podcast task, or null when there is none.
 */
export function getActivePodcastTaskId(): string | null {
	const activeTaskId = _activePodcastTaskId;
	return activeTaskId;
}
|
||||
|
||||
/**
 * Register `taskId` as the active podcast task and notify all subscribers.
 * Called when podcast generation starts.
 */
export function setActivePodcastTaskId(taskId: string): void {
	// Store first so listeners observe the new state when they are notified
	_activePodcastTaskId = taskId;

	notifyListeners();
}
|
||||
|
||||
/**
 * Clear the active podcast task ID (when podcast generation completes or errors)
 * and notify all subscribers.
 *
 * @param taskId Optional. When given, the state is only cleared if this ID is
 *               still the active one, so a caller holding an older task cannot
 *               wipe the marker of a newer task. When omitted, the active task
 *               is cleared unconditionally.
 */
export function clearActivePodcastTaskId(taskId?: string): void {
	if (taskId !== undefined && taskId !== _activePodcastTaskId) {
		// Another task is active (or nothing is); leave the state untouched
		return;
	}

	_activePodcastTaskId = null;
	notifyListeners();
}
|
||||
|
||||
/**
 * Register a listener for podcast state changes.
 * Returns a function that removes the listener again.
 */
export function subscribeToPodcastState(listener: PodcastStateListener): () => void {
	const unsubscribe = (): void => {
		_listeners.delete(listener);
	};

	_listeners.add(listener);
	return unsubscribe;
}
|
||||
|
||||
// Call every registered listener with the current "is generating" flag
function notifyListeners(): void {
	const generating = _activePodcastTaskId !== null;
	_listeners.forEach((notify) => {
		notify(generating);
	});
}
|
||||
|
||||
/**
 * Heuristic check for whether a chat message asks for a podcast.
 * Returns true as soon as any of the case-insensitive patterns matches.
 */
export function looksLikePodcastRequest(message: string): boolean {
	const requestPatterns: RegExp[] = [
		/\bpodcast\b/i,
		/\bcreate.*podcast\b/i,
		/\bgenerate.*podcast\b/i,
		/\bmake.*podcast\b/i,
		/\bturn.*into.*podcast\b/i,
		/\bpodcast.*about\b/i,
		/\bgive.*podcast\b/i,
	];

	for (const candidate of requestPatterns) {
		if (candidate.test(message)) {
			return true;
		}
	}
	return false;
}
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue