Merge pull request #608 from AnishSarkar22/feature/podcast-agent

Added Podcast agent within chat
This commit is contained in:
Rohan Verma 2025-12-21 14:52:29 -08:00 committed by GitHub
commit f115980d2b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 1488 additions and 23 deletions

View file

@ -2,7 +2,7 @@
SurfSense deep agent implementation. SurfSense deep agent implementation.
This module provides the factory function for creating SurfSense deep agents This module provides the factory function for creating SurfSense deep agents
with knowledge base search capability. with knowledge base search and podcast generation capabilities.
""" """
from collections.abc import Sequence from collections.abc import Sequence
@ -15,6 +15,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.context import SurfSenseContextSchema from app.agents.new_chat.context import SurfSenseContextSchema
from app.agents.new_chat.knowledge_base import create_search_knowledge_base_tool from app.agents.new_chat.knowledge_base import create_search_knowledge_base_tool
from app.agents.new_chat.podcast import create_generate_podcast_tool
from app.agents.new_chat.system_prompt import build_surfsense_system_prompt from app.agents.new_chat.system_prompt import build_surfsense_system_prompt
from app.services.connector_service import ConnectorService from app.services.connector_service import ConnectorService
@ -29,12 +30,14 @@ def create_surfsense_deep_agent(
db_session: AsyncSession, db_session: AsyncSession,
connector_service: ConnectorService, connector_service: ConnectorService,
checkpointer: Checkpointer, checkpointer: Checkpointer,
user_id: str | None = None,
user_instructions: str | None = None, user_instructions: str | None = None,
enable_citations: bool = True, enable_citations: bool = True,
enable_podcast: bool = True,
additional_tools: Sequence[BaseTool] | None = None, additional_tools: Sequence[BaseTool] | None = None,
): ):
""" """
Create a SurfSense deep agent with knowledge base search capability. Create a SurfSense deep agent with knowledge base search and podcast generation capabilities.
Args: Args:
llm: ChatLiteLLM instance llm: ChatLiteLLM instance
@ -43,10 +46,13 @@ def create_surfsense_deep_agent(
connector_service: Initialized connector service connector_service: Initialized connector service
checkpointer: LangGraph checkpointer for conversation state persistence. checkpointer: LangGraph checkpointer for conversation state persistence.
Use AsyncPostgresSaver for production or MemorySaver for testing. Use AsyncPostgresSaver for production or MemorySaver for testing.
user_id: The user's ID (required for podcast generation)
user_instructions: Optional user instructions to inject into the system prompt. user_instructions: Optional user instructions to inject into the system prompt.
These will be added to the system prompt to customize agent behavior. These will be added to the system prompt to customize agent behavior.
enable_citations: Whether to include citation instructions in the system prompt (default: True). enable_citations: Whether to include citation instructions in the system prompt (default: True).
When False, the agent will not be instructed to add citations to responses. When False, the agent will not be instructed to add citations to responses.
enable_podcast: Whether to include the podcast generation tool (default: True).
When True and user_id is provided, the agent can generate podcasts.
additional_tools: Optional sequence of additional tools to inject into the agent. additional_tools: Optional sequence of additional tools to inject into the agent.
The search_knowledge_base tool will always be included. The search_knowledge_base tool will always be included.
@ -62,6 +68,16 @@ def create_surfsense_deep_agent(
# Combine search tool with any additional tools # Combine search tool with any additional tools
tools = [search_tool] tools = [search_tool]
# Add podcast tool if enabled and user_id is provided
if enable_podcast and user_id:
podcast_tool = create_generate_podcast_tool(
search_space_id=search_space_id,
db_session=db_session,
user_id=str(user_id),
)
tools.append(podcast_tool)
if additional_tools: if additional_tools:
tools.extend(additional_tools) tools.extend(additional_tools)

View file

@ -0,0 +1,174 @@
"""
Podcast generation tool for the new chat agent.
This module provides a factory function for creating the generate_podcast tool
that submits a Celery task for background podcast generation. The frontend
polls for completion and auto-updates when the podcast is ready.
Duplicate request prevention:
- Only one podcast can be generated at a time per search space
- Uses Redis to track active podcast tasks
- Returns a friendly message if a podcast is already being generated
"""
import os
from typing import Any
import redis
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
# Redis connection for tracking active podcast tasks
# Uses the same Redis instance as Celery
REDIS_URL = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0")
_redis_client: redis.Redis | None = None
def get_redis_client() -> redis.Redis:
"""Get or create Redis client for podcast task tracking."""
global _redis_client
if _redis_client is None:
_redis_client = redis.from_url(REDIS_URL, decode_responses=True)
return _redis_client
def get_active_podcast_key(search_space_id: int) -> str:
"""Generate Redis key for tracking active podcast task."""
return f"podcast:active:{search_space_id}"
def get_active_podcast_task(search_space_id: int) -> str | None:
"""Check if there's an active podcast task for this search space."""
try:
client = get_redis_client()
return client.get(get_active_podcast_key(search_space_id))
except Exception:
# If Redis is unavailable, allow the request (fail open)
return None
def set_active_podcast_task(search_space_id: int, task_id: str) -> None:
"""Mark a podcast task as active for this search space."""
try:
client = get_redis_client()
# Set with 30-minute expiry as safety net (podcast should complete before this)
client.setex(get_active_podcast_key(search_space_id), 1800, task_id)
except Exception as e:
print(f"[generate_podcast] Warning: Could not set active task in Redis: {e}")
def clear_active_podcast_task(search_space_id: int) -> None:
"""Clear the active podcast task for this search space."""
try:
client = get_redis_client()
client.delete(get_active_podcast_key(search_space_id))
except Exception as e:
print(f"[generate_podcast] Warning: Could not clear active task in Redis: {e}")
def create_generate_podcast_tool(
search_space_id: int,
db_session: AsyncSession,
user_id: str,
):
"""
Factory function to create the generate_podcast tool with injected dependencies.
Args:
search_space_id: The user's search space ID
db_session: Database session (not used - Celery creates its own)
user_id: The user's ID (as string)
Returns:
A configured tool function for generating podcasts
"""
@tool
async def generate_podcast(
source_content: str,
podcast_title: str = "SurfSense Podcast",
user_prompt: str | None = None,
) -> dict[str, Any]:
"""
Generate a podcast from the provided content.
Use this tool when the user asks to create, generate, or make a podcast.
Common triggers include phrases like:
- "Give me a podcast about this"
- "Create a podcast from this conversation"
- "Generate a podcast summary"
- "Make a podcast about..."
- "Turn this into a podcast"
The tool will start generating a podcast in the background.
The podcast will be available once generation completes.
IMPORTANT: Only one podcast can be generated at a time. If a podcast
is already being generated, this tool will return a message asking
the user to wait.
Args:
source_content: The text content to convert into a podcast.
This can be a summary, research findings, or any text
the user wants transformed into an audio podcast.
podcast_title: Title for the podcast (default: "SurfSense Podcast")
user_prompt: Optional instructions for podcast style, tone, or format.
For example: "Make it casual and fun" or "Focus on the key insights"
Returns:
A dictionary containing:
- status: "processing" (task submitted), "already_generating", or "error"
- task_id: The Celery task ID for polling status (if processing)
- title: The podcast title
- message: Status message for the user
"""
try:
# Check if a podcast is already being generated for this search space
active_task_id = get_active_podcast_task(search_space_id)
if active_task_id:
print(f"[generate_podcast] Blocked duplicate request. Active task: {active_task_id}")
return {
"status": "already_generating",
"task_id": active_task_id,
"title": podcast_title,
"message": "A podcast is already being generated. Please wait for it to complete before requesting another one.",
}
# Import Celery task here to avoid circular imports
from app.tasks.celery_tasks.podcast_tasks import (
generate_content_podcast_task,
)
# Submit Celery task for background processing
task = generate_content_podcast_task.delay(
source_content=source_content,
search_space_id=search_space_id,
user_id=str(user_id),
podcast_title=podcast_title,
user_prompt=user_prompt,
)
# Mark this task as active
set_active_podcast_task(search_space_id, task.id)
print(f"[generate_podcast] Submitted Celery task: {task.id}")
# Return immediately with task_id for polling
return {
"status": "processing",
"task_id": task.id,
"title": podcast_title,
"message": "Podcast generation started. This may take a few minutes.",
}
except Exception as e:
error_message = str(e)
print(f"[generate_podcast] Error submitting task: {error_message}")
return {
"status": "error",
"error": error_message,
"title": podcast_title,
"task_id": None,
}
return generate_podcast

View file

@ -121,7 +121,8 @@ Today's date (UTC): {resolved_today}
</system_instruction>{user_section} </system_instruction>{user_section}
<tools> <tools>
You have access to the following tools: You have access to the following tools:
- search_knowledge_base: Search the user's personal knowledge base for relevant information.
1. search_knowledge_base: Search the user's personal knowledge base for relevant information.
- Args: - Args:
- query: The search query - be specific and include key terms - query: The search query - be specific and include key terms
- top_k: Number of results to retrieve (default: 10) - top_k: Number of results to retrieve (default: 10)
@ -129,6 +130,21 @@ You have access to the following tools:
- end_date: Optional ISO date/datetime (e.g. "2025-12-19" or "2025-12-19T23:59:59+00:00") - end_date: Optional ISO date/datetime (e.g. "2025-12-19" or "2025-12-19T23:59:59+00:00")
- connectors_to_search: Optional list of connector enums to search. If omitted, searches all. - connectors_to_search: Optional list of connector enums to search. If omitted, searches all.
- Returns: Formatted string with relevant documents and their content - Returns: Formatted string with relevant documents and their content
2. generate_podcast: Generate an audio podcast from provided content.
- Use this when the user asks to create, generate, or make a podcast.
- Trigger phrases: "give me a podcast about", "create a podcast", "generate a podcast", "make a podcast", "turn this into a podcast"
- Args:
- source_content: The text content to convert into a podcast. This MUST be comprehensive and include:
* If discussing the current conversation: Include a detailed summary of the FULL chat history (all user questions and your responses)
* If based on knowledge base search: Include the key findings and insights from the search results
* You can combine both: conversation context + search results for richer podcasts
* The more detailed the source_content, the better the podcast quality
- podcast_title: Optional title for the podcast (default: "SurfSense Podcast")
- user_prompt: Optional instructions for podcast style/format (e.g., "Make it casual and fun")
- Returns: A task_id for tracking. The podcast will be generated in the background.
- IMPORTANT: Only one podcast can be generated at a time. If a podcast is already being generated, the tool will return status "already_generating".
- After calling this tool, inform the user that podcast generation has started and they will see the player when it's ready (takes 3-5 minutes).
</tools> </tools>
<tool_call_examples> <tool_call_examples>
- User: "Fetch all my notes and what's in them?" - User: "Fetch all my notes and what's in them?"
@ -136,6 +152,16 @@ You have access to the following tools:
- User: "What did I discuss on Slack last week about the React migration?" - User: "What did I discuss on Slack last week about the React migration?"
- Call: `search_knowledge_base(query="React migration", connectors_to_search=["SLACK_CONNECTOR"], start_date="YYYY-MM-DD", end_date="YYYY-MM-DD")` - Call: `search_knowledge_base(query="React migration", connectors_to_search=["SLACK_CONNECTOR"], start_date="YYYY-MM-DD", end_date="YYYY-MM-DD")`
- User: "Give me a podcast about AI trends based on what we discussed"
- First search for relevant content, then call: `generate_podcast(source_content="Based on our conversation and search results: [detailed summary of chat + search findings]", podcast_title="AI Trends Podcast")`
- User: "Create a podcast summary of this conversation"
- Call: `generate_podcast(source_content="Complete conversation summary:\n\nUser asked about [topic 1]:\n[Your detailed response]\n\nUser then asked about [topic 2]:\n[Your detailed response]\n\n[Continue for all exchanges in the conversation]", podcast_title="Conversation Summary")`
- User: "Make a podcast about quantum computing"
- First search: `search_knowledge_base(query="quantum computing")`
- Then: `generate_podcast(source_content="Key insights about quantum computing from the knowledge base:\n\n[Comprehensive summary of all relevant search results with key facts, concepts, and findings]", podcast_title="Quantum Computing Explained")`
</tool_call_examples>{citation_section} </tool_call_examples>{citation_section}
""" """

View file

@ -444,3 +444,66 @@ async def get_podcast_by_chat_id(
raise HTTPException( raise HTTPException(
status_code=500, detail=f"Error fetching podcast: {e!s}" status_code=500, detail=f"Error fetching podcast: {e!s}"
) from e ) from e
@router.get("/podcasts/task/{task_id}/status")
async def get_podcast_task_status(
task_id: str,
user: User = Depends(current_active_user),
):
"""
Get the status of a podcast generation task.
Used by new-chat frontend to poll for completion.
Returns:
- status: "processing" | "success" | "error"
- podcast_id: (only if status == "success")
- title: (only if status == "success")
- error: (only if status == "error")
"""
try:
from celery.result import AsyncResult
from app.celery_app import celery_app
result = AsyncResult(task_id, app=celery_app)
if result.ready():
# Task completed
if result.successful():
task_result = result.result
if isinstance(task_result, dict):
if task_result.get("status") == "success":
return {
"status": "success",
"podcast_id": task_result.get("podcast_id"),
"title": task_result.get("title"),
"transcript_entries": task_result.get("transcript_entries"),
}
else:
return {
"status": "error",
"error": task_result.get("error", "Unknown error"),
}
else:
return {
"status": "error",
"error": "Unexpected task result format",
}
else:
# Task failed
return {
"status": "error",
"error": str(result.result) if result.result else "Task failed",
}
else:
# Task still processing
return {
"status": "processing",
"state": result.state,
}
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Error checking task status: {e!s}"
) from e

View file

@ -11,6 +11,11 @@ from app.celery_app import celery_app
from app.config import config from app.config import config
from app.tasks.podcast_tasks import generate_chat_podcast from app.tasks.podcast_tasks import generate_chat_podcast
# Import for content-based podcast (new-chat)
from app.agents.podcaster.graph import graph as podcaster_graph
from app.agents.podcaster.state import State as PodcasterState
from app.db import Podcast
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if sys.platform.startswith("win"): if sys.platform.startswith("win"):
@ -86,3 +91,149 @@ async def _generate_chat_podcast(
except Exception as e: except Exception as e:
logger.error(f"Error generating podcast from chat: {e!s}") logger.error(f"Error generating podcast from chat: {e!s}")
raise raise
# =============================================================================
# Content-based podcast generation (for new-chat)
# =============================================================================
def _clear_active_podcast_redis_key(search_space_id: int) -> None:
"""Clear the active podcast task key from Redis when task completes."""
import os
import redis
try:
redis_url = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0")
client = redis.from_url(redis_url, decode_responses=True)
key = f"podcast:active:{search_space_id}"
client.delete(key)
logger.info(f"Cleared active podcast key for search_space_id={search_space_id}")
except Exception as e:
logger.warning(f"Could not clear active podcast key: {e}")
@celery_app.task(name="generate_content_podcast", bind=True)
def generate_content_podcast_task(
self,
source_content: str,
search_space_id: int,
user_id: str,
podcast_title: str = "SurfSense Podcast",
user_prompt: str | None = None,
) -> dict:
"""
Celery task to generate podcast from source content (for new-chat).
Unlike generate_chat_podcast which requires a chat_id, this task
generates a podcast directly from provided content.
Args:
source_content: The text content to convert into a podcast
search_space_id: ID of the search space
user_id: ID of the user (as string)
podcast_title: Title for the podcast
user_prompt: Optional instructions for podcast style/tone
Returns:
dict with podcast_id on success, or error info on failure
"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
result = loop.run_until_complete(
_generate_content_podcast(
source_content,
search_space_id,
user_id,
podcast_title,
user_prompt,
)
)
loop.run_until_complete(loop.shutdown_asyncgens())
return result
except Exception as e:
logger.error(f"Error generating content podcast: {e!s}")
return {"status": "error", "error": str(e)}
finally:
# Always clear the active podcast key when task completes (success or failure)
_clear_active_podcast_redis_key(search_space_id)
asyncio.set_event_loop(None)
loop.close()
async def _generate_content_podcast(
source_content: str,
search_space_id: int,
user_id: str,
podcast_title: str = "SurfSense Podcast",
user_prompt: str | None = None,
) -> dict:
"""Generate content-based podcast with new session."""
async with get_celery_session_maker()() as session:
try:
# Configure the podcaster graph
graph_config = {
"configurable": {
"podcast_title": podcast_title,
"user_id": str(user_id),
"search_space_id": search_space_id,
"user_prompt": user_prompt,
}
}
# Initialize the podcaster state with the source content
initial_state = PodcasterState(
source_content=source_content,
db_session=session,
)
# Run the podcaster graph
result = await podcaster_graph.ainvoke(initial_state, config=graph_config)
# Extract results
podcast_transcript = result.get("podcast_transcript", [])
file_path = result.get("final_podcast_file_path", "")
# Convert transcript to serializable format
serializable_transcript = []
for entry in podcast_transcript:
if hasattr(entry, "speaker_id"):
serializable_transcript.append({
"speaker_id": entry.speaker_id,
"dialog": entry.dialog
})
else:
serializable_transcript.append({
"speaker_id": entry.get("speaker_id", 0),
"dialog": entry.get("dialog", "")
})
# Save podcast to database
podcast = Podcast(
title=podcast_title,
podcast_transcript=serializable_transcript,
file_location=file_path,
search_space_id=search_space_id,
chat_id=None, # No chat_id for new-chat podcasts
chat_state_version=None,
)
session.add(podcast)
await session.commit()
await session.refresh(podcast)
logger.info(f"Successfully generated content podcast: {podcast.id}")
return {
"status": "success",
"podcast_id": podcast.id,
"title": podcast_title,
"transcript_entries": len(serializable_transcript),
}
except Exception as e:
logger.error(f"Error in _generate_content_podcast: {e!s}")
await session.rollback()
raise

View file

@ -5,6 +5,7 @@ This module streams responses from the deep agent using the Vercel AI SDK
Data Stream Protocol (SSE format). Data Stream Protocol (SSE format).
""" """
import json
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from uuid import UUID from uuid import UUID
@ -78,13 +79,15 @@ async def stream_new_chat(
# Get the PostgreSQL checkpointer for persistent conversation memory # Get the PostgreSQL checkpointer for persistent conversation memory
checkpointer = await get_checkpointer() checkpointer = await get_checkpointer()
# Create the deep agent with checkpointer # Create the deep agent with checkpointer with podcast capability
agent = create_surfsense_deep_agent( agent = create_surfsense_deep_agent(
llm=llm, llm=llm,
search_space_id=search_space_id, search_space_id=search_space_id,
db_session=session, db_session=session,
connector_service=connector_service, connector_service=connector_service,
checkpointer=checkpointer, checkpointer=checkpointer,
user_id=str(user_id),
enable_podcast=True,
) )
# Build input with message history from frontend # Build input with message history from frontend
@ -182,22 +185,72 @@ async def stream_new_chat(
f"Searching knowledge base: {query[:100]}{'...' if len(query) > 100 else ''}", f"Searching knowledge base: {query[:100]}{'...' if len(query) > 100 else ''}",
"info", "info",
) )
elif tool_name == "generate_podcast":
title = (
tool_input.get("podcast_title", "SurfSense Podcast")
if isinstance(tool_input, dict)
else "SurfSense Podcast"
)
yield streaming_service.format_terminal_info(
f"Generating podcast: {title}",
"info",
)
elif event_type == "on_tool_end": elif event_type == "on_tool_end":
run_id = event.get("run_id", "") run_id = event.get("run_id", "")
tool_output = event.get("data", {}).get("output", "") tool_name = event.get("name", "unknown_tool")
raw_output = event.get("data", {}).get("output", "")
# Extract content from ToolMessage if needed
# LangGraph may return a ToolMessage object instead of raw dict
if hasattr(raw_output, "content"):
# It's a ToolMessage object - extract the content
content = raw_output.content
# If content is a string that looks like JSON, try to parse it
if isinstance(content, str):
try:
tool_output = json.loads(content)
except (json.JSONDecodeError, TypeError):
tool_output = {"result": content}
elif isinstance(content, dict):
tool_output = content
else:
tool_output = {"result": str(content)}
elif isinstance(raw_output, dict):
tool_output = raw_output
else:
tool_output = {"result": str(raw_output) if raw_output else "completed"}
tool_call_id = f"call_{run_id[:32]}" if run_id else "call_unknown" tool_call_id = f"call_{run_id[:32]}" if run_id else "call_unknown"
# Don't stream the full output (can be very large), just acknowledge # Handle different tool outputs
yield streaming_service.format_tool_output_available( if tool_name == "generate_podcast":
tool_call_id, # Stream the full podcast result so frontend can render the audio player
{"status": "completed", "result_length": len(str(tool_output))}, yield streaming_service.format_tool_output_available(
) tool_call_id,
tool_output if isinstance(tool_output, dict) else {"result": tool_output},
yield streaming_service.format_terminal_info( )
"Knowledge base search completed", "success" # Send appropriate terminal message based on status
) if isinstance(tool_output, dict) and tool_output.get("status") == "success":
yield streaming_service.format_terminal_info(
f"Podcast generated successfully: {tool_output.get('title', 'Podcast')}",
"success",
)
else:
error_msg = tool_output.get("error", "Unknown error") if isinstance(tool_output, dict) else "Unknown error"
yield streaming_service.format_terminal_info(
f"Podcast generation failed: {error_msg}",
"error",
)
else:
# Don't stream the full output for other tools (can be very large), just acknowledge
yield streaming_service.format_tool_output_available(
tool_call_id,
{"status": "completed", "result_length": len(str(tool_output))},
)
yield streaming_service.format_terminal_info(
"Knowledge base search completed", "success"
)
# Handle chain/agent end to close any open text blocks # Handle chain/agent end to close any open text blocks
elif event_type in ("on_chain_end", "on_agent_end"): elif event_type in ("on_chain_end", "on_agent_end"):

View file

@ -4,6 +4,7 @@ import { AssistantRuntimeProvider, useLocalRuntime } from "@assistant-ui/react";
import { useParams } from "next/navigation"; import { useParams } from "next/navigation";
import { useMemo } from "react"; import { useMemo } from "react";
import { Thread } from "@/components/assistant-ui/thread"; import { Thread } from "@/components/assistant-ui/thread";
import { GeneratePodcastToolUI } from "@/components/tool-ui/generate-podcast";
import { createNewChatAdapter } from "@/lib/chat/new-chat-transport"; import { createNewChatAdapter } from "@/lib/chat/new-chat-transport";
export default function NewChatPage() { export default function NewChatPage() {
@ -38,6 +39,8 @@ export default function NewChatPage() {
return ( return (
<AssistantRuntimeProvider runtime={runtime}> <AssistantRuntimeProvider runtime={runtime}>
{/* Register tool UI components */}
<GeneratePodcastToolUI />
<div className="h-[calc(100vh-64px)] max-h-[calc(100vh-64px)] overflow-hidden"> <div className="h-[calc(100vh-64px)] max-h-[calc(100vh-64px)] overflow-hidden">
<Thread /> <Thread />
</div> </div>

View file

@ -0,0 +1,310 @@
"use client";
import { DownloadIcon, PauseIcon, PlayIcon, Volume2Icon, VolumeXIcon } from "lucide-react";
import Image from "next/image";
import { useCallback, useEffect, useRef, useState } from "react";
import { Button } from "@/components/ui/button";
import { Slider } from "@/components/ui/slider";
import { cn } from "@/lib/utils";
interface AudioProps {
id: string;
assetId?: string;
src: string;
title: string;
description?: string;
artwork?: string;
durationMs?: number;
className?: string;
}
function formatTime(seconds: number): string {
if (!Number.isFinite(seconds) || seconds < 0) return "0:00";
const mins = Math.floor(seconds / 60);
const secs = Math.floor(seconds % 60);
return `${mins}:${secs.toString().padStart(2, "0")}`;
}
export function Audio({
id,
src,
title,
description,
artwork,
durationMs,
className,
}: AudioProps) {
const audioRef = useRef<HTMLAudioElement>(null);
const [isPlaying, setIsPlaying] = useState(false);
const [currentTime, setCurrentTime] = useState(0);
const [duration, setDuration] = useState(durationMs ? durationMs / 1000 : 0);
const [volume, setVolume] = useState(1);
const [isMuted, setIsMuted] = useState(false);
const [isLoading, setIsLoading] = useState(true);
const [error, setError] = useState<string | null>(null);
// Handle play/pause
const togglePlayPause = useCallback(() => {
const audio = audioRef.current;
if (!audio) return;
if (isPlaying) {
audio.pause();
} else {
audio.play().catch((err) => {
console.error("Error playing audio:", err);
setError("Failed to play audio");
});
}
}, [isPlaying]);
// Handle seek
const handleSeek = useCallback((value: number[]) => {
const audio = audioRef.current;
if (!audio || !Number.isFinite(value[0])) return;
audio.currentTime = value[0];
setCurrentTime(value[0]);
}, []);
// Handle volume change
const handleVolumeChange = useCallback((value: number[]) => {
const audio = audioRef.current;
if (!audio || !Number.isFinite(value[0])) return;
const newVolume = value[0];
audio.volume = newVolume;
setVolume(newVolume);
setIsMuted(newVolume === 0);
}, []);
// Toggle mute
const toggleMute = useCallback(() => {
const audio = audioRef.current;
if (!audio) return;
if (isMuted) {
audio.volume = volume || 1;
setIsMuted(false);
} else {
audio.volume = 0;
setIsMuted(true);
}
}, [isMuted, volume]);
// Handle download
const handleDownload = useCallback(async () => {
try {
const response = await fetch(src);
const blob = await response.blob();
const url = window.URL.createObjectURL(blob);
const a = document.createElement("a");
a.href = url;
a.download = `${title.replace(/[^a-zA-Z0-9]/g, "_")}.mp3`;
document.body.appendChild(a);
a.click();
document.body.removeChild(a);
window.URL.revokeObjectURL(url);
} catch (err) {
console.error("Error downloading audio:", err);
}
}, [src, title]);
// Set up audio event listeners
useEffect(() => {
const audio = audioRef.current;
if (!audio) return;
const handleLoadedMetadata = () => {
setDuration(audio.duration);
setIsLoading(false);
};
const handleTimeUpdate = () => {
setCurrentTime(audio.currentTime);
};
const handlePlay = () => setIsPlaying(true);
const handlePause = () => setIsPlaying(false);
const handleEnded = () => {
setIsPlaying(false);
setCurrentTime(0);
};
const handleError = () => {
setError("Failed to load audio");
setIsLoading(false);
};
const handleCanPlay = () => setIsLoading(false);
audio.addEventListener("loadedmetadata", handleLoadedMetadata);
audio.addEventListener("timeupdate", handleTimeUpdate);
audio.addEventListener("play", handlePlay);
audio.addEventListener("pause", handlePause);
audio.addEventListener("ended", handleEnded);
audio.addEventListener("error", handleError);
audio.addEventListener("canplay", handleCanPlay);
return () => {
audio.removeEventListener("loadedmetadata", handleLoadedMetadata);
audio.removeEventListener("timeupdate", handleTimeUpdate);
audio.removeEventListener("play", handlePlay);
audio.removeEventListener("pause", handlePause);
audio.removeEventListener("ended", handleEnded);
audio.removeEventListener("error", handleError);
audio.removeEventListener("canplay", handleCanPlay);
};
}, []);
if (error) {
return (
<div
className={cn(
"flex items-center gap-4 rounded-xl border border-destructive/20 bg-destructive/5 p-4",
className,
)}
>
<div className="flex size-16 items-center justify-center rounded-lg bg-destructive/10">
<Volume2Icon className="size-8 text-destructive" />
</div>
<div className="flex-1">
<p className="font-medium text-destructive">{title}</p>
<p className="text-destructive/70 text-sm">{error}</p>
</div>
</div>
);
}
return (
<div
id={id}
className={cn(
"group relative overflow-hidden rounded-xl border bg-gradient-to-br from-background to-muted/30 p-4 shadow-sm transition-all hover:shadow-md",
className,
)}
>
{/* Hidden audio element */}
<audio ref={audioRef} src={src} preload="metadata">
<track kind="captions" srcLang="en" label="English captions" default />
</audio>
<div className="flex gap-4">
{/* Artwork */}
<div className="relative shrink-0">
<div className="relative size-20 overflow-hidden rounded-lg bg-gradient-to-br from-primary/20 to-primary/5 shadow-inner">
{artwork ? (
<Image
src={artwork}
alt={title}
fill
className="object-cover"
unoptimized
/>
) : (
<div className="flex size-full items-center justify-center">
<Volume2Icon className="size-8 text-primary/50" />
</div>
)}
{/* Play overlay on artwork */}
<button
type="button"
onClick={togglePlayPause}
className="absolute inset-0 flex items-center justify-center bg-black/0 opacity-0 transition-all group-hover:bg-black/30 group-hover:opacity-100"
aria-label={isPlaying ? "Pause" : "Play"}
>
{isPlaying ? (
<PauseIcon className="size-8 text-white drop-shadow-lg" />
) : (
<PlayIcon className="size-8 text-white drop-shadow-lg" />
)}
</button>
</div>
</div>
{/* Content */}
<div className="flex min-w-0 flex-1 flex-col justify-between">
{/* Title and description */}
<div className="min-w-0">
<h3 className="truncate font-semibold text-foreground">{title}</h3>
{description && (
<p className="mt-0.5 line-clamp-1 text-muted-foreground text-sm">
{description}
</p>
)}
</div>
{/* Progress bar */}
<div className="mt-2 space-y-1">
<Slider
value={[currentTime]}
max={duration || 100}
step={0.1}
onValueChange={handleSeek}
className="cursor-pointer"
disabled={isLoading}
/>
<div className="flex justify-between text-muted-foreground text-xs">
<span>{formatTime(currentTime)}</span>
<span>{formatTime(duration)}</span>
</div>
</div>
</div>
</div>
{/* Controls */}
<div className="mt-3 flex items-center justify-between border-t pt-3">
<div className="flex items-center gap-2">
{/* Play/Pause button */}
<Button
variant="default"
size="sm"
onClick={togglePlayPause}
disabled={isLoading}
className="gap-2"
>
{isLoading ? (
<div className="size-4 animate-spin rounded-full border-2 border-current border-t-transparent" />
) : isPlaying ? (
<PauseIcon className="size-4" />
) : (
<PlayIcon className="size-4" />
)}
{isPlaying ? "Pause" : "Play"}
</Button>
{/* Volume control */}
<div className="flex items-center gap-2">
<Button
variant="ghost"
size="icon"
onClick={toggleMute}
className="size-8"
>
{isMuted ? (
<VolumeXIcon className="size-4" />
) : (
<Volume2Icon className="size-4" />
)}
</Button>
<Slider
value={[isMuted ? 0 : volume]}
max={1}
step={0.01}
onValueChange={handleVolumeChange}
className="w-20"
/>
</div>
</div>
{/* Download button */}
<Button
variant="outline"
size="sm"
onClick={handleDownload}
className="gap-2"
>
<DownloadIcon className="size-4" />
Download
</Button>
</div>
</div>
);
}

View file

@ -0,0 +1,427 @@
"use client";
import { makeAssistantToolUI } from "@assistant-ui/react";
import { AlertCircleIcon, Loader2Icon, MicIcon } from "lucide-react";
import { useCallback, useEffect, useRef, useState } from "react";
import { Audio } from "@/components/tool-ui/audio";
import { baseApiService } from "@/lib/apis/base-api.service";
import { podcastsApiService } from "@/lib/apis/podcasts-api.service";
import {
clearActivePodcastTaskId,
setActivePodcastTaskId,
} from "@/lib/chat/podcast-state";
import type { PodcastTranscriptEntry } from "@/contracts/types/podcast.types";
/**
* Type definitions for the generate_podcast tool
*/
interface GeneratePodcastArgs {
source_content: string;
podcast_title?: string;
user_prompt?: string;
}
interface GeneratePodcastResult {
status: "processing" | "already_generating" | "success" | "error";
task_id?: string;
podcast_id?: number;
title?: string;
transcript_entries?: number;
message?: string;
error?: string;
}
interface TaskStatusResponse {
status: "processing" | "success" | "error";
podcast_id?: number;
title?: string;
transcript_entries?: number;
state?: string;
error?: string;
}
/**
* Loading state component shown while podcast is being generated
*/
function PodcastGeneratingState({ title }: { title: string }) {
return (
<div className="my-4 overflow-hidden rounded-xl border border-primary/20 bg-gradient-to-br from-primary/5 to-primary/10 p-6">
<div className="flex items-center gap-4">
<div className="relative">
<div className="flex size-16 items-center justify-center rounded-full bg-primary/20">
<MicIcon className="size-8 text-primary" />
</div>
{/* Animated rings */}
<div className="absolute inset-1 animate-ping rounded-full bg-primary/20" />
</div>
<div className="flex-1">
<h3 className="font-semibold text-foreground text-lg">{title}</h3>
<div className="mt-2 flex items-center gap-2 text-muted-foreground">
<Loader2Icon className="size-4 animate-spin" />
<span className="text-sm">Generating podcast... This may take a few minutes</span>
</div>
<div className="mt-3">
<div className="h-1.5 w-full overflow-hidden rounded-full bg-primary/10">
<div className="h-full w-1/3 animate-pulse rounded-full bg-primary" />
</div>
</div>
</div>
</div>
</div>
);
}
/**
* Error state component shown when podcast generation fails
*/
function PodcastErrorState({ title, error }: { title: string; error: string }) {
return (
<div className="my-4 overflow-hidden rounded-xl border border-destructive/20 bg-destructive/5 p-6">
<div className="flex items-center gap-4">
<div className="flex size-16 shrink-0 items-center justify-center rounded-full bg-destructive/10">
<AlertCircleIcon className="size-8 text-destructive" />
</div>
<div className="flex-1">
<h3 className="font-semibold text-foreground">{title}</h3>
<p className="mt-1 text-destructive text-sm">Failed to generate podcast</p>
<p className="mt-2 text-muted-foreground text-sm">{error}</p>
</div>
</div>
</div>
);
}
/**
* Audio loading state component
*/
function AudioLoadingState({ title }: { title: string }) {
return (
<div className="my-4 overflow-hidden rounded-xl border bg-muted/30 p-6">
<div className="flex items-center gap-4">
<div className="flex size-16 items-center justify-center rounded-full bg-primary/10">
<MicIcon className="size-8 text-primary/50" />
</div>
<div className="flex-1">
<h3 className="font-semibold text-foreground">{title}</h3>
<div className="mt-2 flex items-center gap-2 text-muted-foreground">
<Loader2Icon className="size-4 animate-spin" />
<span className="text-sm">Loading audio...</span>
</div>
</div>
</div>
</div>
);
}
/**
* Podcast Player Component - Fetches audio and transcript with authentication
*/
function PodcastPlayer({
podcastId,
title,
description,
durationMs,
}: {
podcastId: number;
title: string;
description: string;
durationMs?: number;
}) {
const [audioSrc, setAudioSrc] = useState<string | null>(null);
const [transcript, setTranscript] = useState<PodcastTranscriptEntry[] | null>(null);
const [isLoading, setIsLoading] = useState(true);
const [error, setError] = useState<string | null>(null);
const objectUrlRef = useRef<string | null>(null);
// Cleanup object URL on unmount
useEffect(() => {
return () => {
if (objectUrlRef.current) {
URL.revokeObjectURL(objectUrlRef.current);
}
};
}, []);
// Fetch audio and podcast details (including transcript)
const loadPodcast = useCallback(async () => {
setIsLoading(true);
setError(null);
try {
// Revoke previous object URL if exists
if (objectUrlRef.current) {
URL.revokeObjectURL(objectUrlRef.current);
objectUrlRef.current = null;
}
const controller = new AbortController();
const timeoutId = setTimeout(() => controller.abort(), 60000); // 60s timeout
try {
// Fetch audio blob and podcast details in parallel
const [audioBlob, podcastDetails] = await Promise.all([
podcastsApiService.loadPodcast({
request: { id: podcastId },
controller,
}),
podcastsApiService.getPodcastById(podcastId),
]);
// Create object URL from blob
const objectUrl = URL.createObjectURL(audioBlob);
objectUrlRef.current = objectUrl;
setAudioSrc(objectUrl);
// Set transcript from podcast details
if (podcastDetails?.podcast_transcript) {
setTranscript(podcastDetails.podcast_transcript);
}
} finally {
clearTimeout(timeoutId);
}
} catch (err) {
console.error("Error loading podcast:", err);
if (err instanceof DOMException && err.name === "AbortError") {
setError("Request timed out. Please try again.");
} else {
setError(err instanceof Error ? err.message : "Failed to load podcast");
}
} finally {
setIsLoading(false);
}
}, [podcastId]);
// Load podcast when component mounts
useEffect(() => {
loadPodcast();
}, [loadPodcast]);
if (isLoading) {
return <AudioLoadingState title={title} />;
}
if (error || !audioSrc) {
return <PodcastErrorState title={title} error={error || "Failed to load audio"} />;
}
return (
<div className="my-4">
<Audio
id={`podcast-${podcastId}`}
src={audioSrc}
title={title}
description={description}
durationMs={durationMs}
className="w-full"
/>
{/* Transcript section */}
{transcript && transcript.length > 0 && (
<details className="mt-3 rounded-lg border bg-muted/30 p-3">
<summary className="cursor-pointer font-medium text-muted-foreground text-sm hover:text-foreground">
View transcript ({transcript.length} entries)
</summary>
<div className="mt-3 space-y-3 max-h-96 overflow-y-auto">
{transcript.map((entry, idx) => (
<div key={`${idx}-${entry.speaker_id}`} className="text-sm">
<span className="font-medium text-primary">
Speaker {entry.speaker_id + 1}:
</span>{" "}
<span className="text-muted-foreground">{entry.dialog}</span>
</div>
))}
</div>
</details>
)}
</div>
);
}
/**
* Polling component that checks task status and shows player when complete
*/
function PodcastTaskPoller({
taskId,
title,
}: {
taskId: string;
title: string;
}) {
const [taskStatus, setTaskStatus] = useState<TaskStatusResponse>({ status: "processing" });
const pollingRef = useRef<NodeJS.Timeout | null>(null);
// Set active podcast state when this component mounts
useEffect(() => {
setActivePodcastTaskId(taskId);
// Clear when component unmounts
return () => {
// Only clear if this task is still the active one
clearActivePodcastTaskId();
};
}, [taskId]);
// Poll for task status
useEffect(() => {
const pollStatus = async () => {
try {
const response = await baseApiService.get<TaskStatusResponse>(
`/api/v1/podcasts/task/${taskId}/status`
);
setTaskStatus(response);
// Stop polling if task is complete or errored
if (response.status !== "processing") {
if (pollingRef.current) {
clearInterval(pollingRef.current);
pollingRef.current = null;
}
// Clear the active podcast state when task completes
clearActivePodcastTaskId();
}
} catch (err) {
console.error("Error polling task status:", err);
// Don't stop polling on network errors, continue polling
}
};
// Initial poll
pollStatus();
// Poll every 5 seconds
pollingRef.current = setInterval(pollStatus, 5000);
return () => {
if (pollingRef.current) {
clearInterval(pollingRef.current);
}
};
}, [taskId]);
// Show loading state while processing
if (taskStatus.status === "processing") {
return <PodcastGeneratingState title={title} />;
}
// Show error state
if (taskStatus.status === "error") {
return <PodcastErrorState title={title} error={taskStatus.error || "Generation failed"} />;
}
// Show player when complete
if (taskStatus.status === "success" && taskStatus.podcast_id) {
return (
<PodcastPlayer
podcastId={taskStatus.podcast_id}
title={taskStatus.title || title}
description={
taskStatus.transcript_entries
? `${taskStatus.transcript_entries} dialogue entries`
: "SurfSense AI-generated podcast"
}
/>
);
}
// Fallback
return <PodcastErrorState title={title} error="Unexpected state" />;
}
/**
* Generate Podcast Tool UI Component
*
* This component is registered with assistant-ui to render custom UI
* when the generate_podcast tool is called by the agent.
*
* It polls for task completion and auto-updates when the podcast is ready.
*/
export const GeneratePodcastToolUI = makeAssistantToolUI<
GeneratePodcastArgs,
GeneratePodcastResult
>({
toolName: "generate_podcast",
render: function GeneratePodcastUI({ args, result, status }) {
const title = args.podcast_title || "SurfSense Podcast";
// Loading state - tool is still running (agent processing)
if (status.type === "running" || status.type === "requires-action") {
return <PodcastGeneratingState title={title} />;
}
// Incomplete/cancelled state
if (status.type === "incomplete") {
if (status.reason === "cancelled") {
return (
<div className="my-4 rounded-xl border border-muted p-4 text-muted-foreground">
<p className="flex items-center gap-2">
<MicIcon className="size-4" />
<span className="line-through">Podcast generation cancelled</span>
</p>
</div>
);
}
if (status.reason === "error") {
return (
<PodcastErrorState
title={title}
error={typeof status.error === "string" ? status.error : "An error occurred"}
/>
);
}
}
// No result yet
if (!result) {
return <PodcastGeneratingState title={title} />;
}
// Error result
if (result.status === "error") {
return <PodcastErrorState title={title} error={result.error || "Unknown error"} />;
}
// Already generating - show simple warning, don't create another poller
// The FIRST tool call will display the podcast when ready
if (result.status === "already_generating") {
return (
<div className="my-4 overflow-hidden rounded-xl border border-amber-500/20 bg-amber-500/5 p-4">
<div className="flex items-center gap-3">
<div className="flex size-10 shrink-0 items-center justify-center rounded-full bg-amber-500/20">
<MicIcon className="size-5 text-amber-500" />
</div>
<div>
<p className="text-amber-600 dark:text-amber-400 text-sm font-medium">
Podcast already in progress
</p>
<p className="text-muted-foreground text-xs mt-0.5">
Please wait for the current podcast to complete.
</p>
</div>
</div>
</div>
);
}
// Processing - poll for completion
if (result.status === "processing" && result.task_id) {
return <PodcastTaskPoller taskId={result.task_id} title={result.title || title} />;
}
// Success with podcast_id (direct result, not via polling)
if (result.status === "success" && result.podcast_id) {
return (
<PodcastPlayer
podcastId={result.podcast_id}
title={result.title || title}
description={
result.transcript_entries
? `${result.transcript_entries} dialogue entries`
: "SurfSense AI-generated podcast"
}
/>
);
}
// Fallback - missing required data
return <PodcastErrorState title={title} error="Missing task ID or podcast ID" />;
},
});

View file

@ -0,0 +1,11 @@
/**
* Tool UI Components
*
* This module exports custom UI components for assistant tools.
* These components are registered with assistant-ui to render
* rich UI when specific tools are called by the agent.
*/
export { Audio } from "./audio";
export { GeneratePodcastToolUI } from "./generate-podcast";

View file

@ -1,12 +1,17 @@
import { z } from "zod"; import { z } from "zod";
import { paginationQueryParams } from "."; import { paginationQueryParams } from ".";
export const podcastTranscriptEntry = z.object({
speaker_id: z.number(),
dialog: z.string(),
});
export const podcast = z.object({ export const podcast = z.object({
id: z.number(), id: z.number(),
title: z.string(), title: z.string(),
created_at: z.string(), created_at: z.string(),
file_location: z.string(), file_location: z.string(),
podcast_transcript: z.array(z.any()), podcast_transcript: z.array(podcastTranscriptEntry),
search_space_id: z.number(), search_space_id: z.number(),
chat_state_version: z.number().nullable(), chat_state_version: z.number().nullable(),
}); });
@ -41,6 +46,7 @@ export const getPodcastsRequest = z.object({
queryParams: paginationQueryParams.nullish(), queryParams: paginationQueryParams.nullish(),
}); });
export type PodcastTranscriptEntry = z.infer<typeof podcastTranscriptEntry>;
export type GeneratePodcastRequest = z.infer<typeof generatePodcastRequest>; export type GeneratePodcastRequest = z.infer<typeof generatePodcastRequest>;
export type GetPodcastByChatIdRequest = z.infer<typeof getPodcastByChatIdRequest>; export type GetPodcastByChatIdRequest = z.infer<typeof getPodcastByChatIdRequest>;
export type GetPodcastByChatIdResponse = z.infer<typeof getPodcastByChaIdResponse>; export type GetPodcastByChatIdResponse = z.infer<typeof getPodcastByChaIdResponse>;

View file

@ -62,6 +62,13 @@ class PodcastsApiService {
); );
}; };
/**
* Get a podcast by its ID (includes full transcript)
*/
getPodcastById = async (podcastId: number) => {
return baseApiService.get(`/api/v1/podcasts/${podcastId}`, podcast);
};
generatePodcast = async (request: GeneratePodcastRequest) => { generatePodcast = async (request: GeneratePodcastRequest) => {
// Validate the request // Validate the request
const parsedRequest = generatePodcastRequest.safeParse(request); const parsedRequest = generatePodcastRequest.safeParse(request);

View file

@ -4,7 +4,13 @@
*/ */
import type { ChatModelAdapter, ChatModelRunOptions } from "@assistant-ui/react"; import type { ChatModelAdapter, ChatModelRunOptions } from "@assistant-ui/react";
import { toast } from "sonner";
import { getBearerToken } from "@/lib/auth-utils"; import { getBearerToken } from "@/lib/auth-utils";
import {
isPodcastGenerating,
looksLikePodcastRequest,
setActivePodcastTaskId,
} from "@/lib/chat/podcast-state";
interface NewChatAdapterConfig { interface NewChatAdapterConfig {
searchSpaceId: number; searchSpaceId: number;
@ -40,6 +46,22 @@ function convertMessagesToBackendFormat(
.filter((m) => m.content.length > 0); // Filter out empty messages .filter((m) => m.content.length > 0); // Filter out empty messages
} }
/**
* Represents an in-progress or completed tool call
*/
interface ToolCallState {
toolCallId: string;
toolName: string;
args: Record<string, unknown>;
result?: unknown;
}
/**
* Tools that should render custom UI in the chat.
* Other tools (like search_knowledge_base) will be hidden from the UI.
*/
const TOOLS_WITH_UI = new Set(["generate_podcast"]);
/** /**
* Creates a ChatModelAdapter that connects to the FastAPI new_chat endpoint. * Creates a ChatModelAdapter that connects to the FastAPI new_chat endpoint.
* *
@ -72,6 +94,21 @@ export function createNewChatAdapter(config: NewChatAdapterConfig): ChatModelAda
throw new Error("User query cannot be empty"); throw new Error("User query cannot be empty");
} }
// Check if user is requesting a podcast while one is already generating
if (isPodcastGenerating() && looksLikePodcastRequest(userQuery)) {
toast.warning("A podcast is already being generated. Please wait for it to complete.");
// Return a message telling the user to wait
yield {
content: [
{
type: "text",
text: "A podcast is already being generated. Please wait for it to complete before requesting another one.",
},
],
};
return;
}
const token = getBearerToken(); const token = getBearerToken();
if (!token) { if (!token) {
throw new Error("Not authenticated. Please log in again."); throw new Error("Not authenticated. Please log in again.");
@ -110,6 +147,41 @@ export function createNewChatAdapter(config: NewChatAdapterConfig): ChatModelAda
let buffer = ""; let buffer = "";
let accumulatedText = ""; let accumulatedText = "";
// Track tool calls by their ID
const toolCalls = new Map<string, ToolCallState>();
/**
* Build the content array with text and tool calls.
* Only includes tools that have custom UI (defined in TOOLS_WITH_UI).
*/
function buildContent() {
const content: Array<
| { type: "text"; text: string }
| { type: "tool-call"; toolCallId: string; toolName: string; args: Record<string, unknown>; result?: unknown }
> = [];
// Add text content if any
if (accumulatedText) {
content.push({ type: "text" as const, text: accumulatedText });
}
// Only add tool calls that have custom UI registered
// Other tools (like search_knowledge_base) are hidden from the UI
for (const toolCall of toolCalls.values()) {
if (TOOLS_WITH_UI.has(toolCall.toolName)) {
content.push({
type: "tool-call" as const,
toolCallId: toolCall.toolCallId,
toolName: toolCall.toolName,
args: toolCall.args,
result: toolCall.result,
});
}
}
return content;
}
try { try {
while (true) { while (true) {
const { done, value } = await reader.read(); const { done, value } = await reader.read();
@ -146,16 +218,70 @@ export function createNewChatAdapter(config: NewChatAdapterConfig): ChatModelAda
switch (parsed.type) { switch (parsed.type) {
case "text-delta": case "text-delta":
accumulatedText += parsed.delta; accumulatedText += parsed.delta;
yield { yield { content: buildContent() };
content: [{ type: "text" as const, text: accumulatedText }],
};
break; break;
case "tool-input-start": {
// Tool call is starting - create a new tool call entry
const { toolCallId, toolName } = parsed;
toolCalls.set(toolCallId, {
toolCallId,
toolName,
args: {},
});
// Yield to show tool is starting (running state)
yield { content: buildContent() };
break;
}
case "tool-input-available": {
// Tool input is complete - update the args
const { toolCallId, toolName, input } = parsed;
const existing = toolCalls.get(toolCallId);
if (existing) {
existing.args = input || {};
} else {
// Create new entry if we missed tool-input-start
toolCalls.set(toolCallId, {
toolCallId,
toolName,
args: input || {},
});
}
yield { content: buildContent() };
break;
}
case "tool-output-available": {
// Tool execution is complete - add the result
const { toolCallId, output } = parsed;
const existing = toolCalls.get(toolCallId);
if (existing) {
existing.result = output;
// If this is a podcast tool with status="processing", set the state immediately
// This ensures subsequent podcast requests are intercepted
if (
existing.toolName === "generate_podcast" &&
output &&
typeof output === "object" &&
"status" in output &&
output.status === "processing" &&
"task_id" in output &&
typeof output.task_id === "string"
) {
setActivePodcastTaskId(output.task_id);
}
}
yield { content: buildContent() };
break;
}
case "error": case "error":
throw new Error(parsed.errorText || "Unknown error from server"); throw new Error(parsed.errorText || "Unknown error from server");
// Other types like text-start, text-end, tool-*, etc. // Other types like text-start, text-end, start-step, finish-step, etc.
// are handled implicitly - we just accumulate text deltas // are handled implicitly
default: default:
break; break;
} }
@ -181,9 +307,27 @@ export function createNewChatAdapter(config: NewChatAdapterConfig): ChatModelAda
const parsed = JSON.parse(data); const parsed = JSON.parse(data);
if (parsed.type === "text-delta") { if (parsed.type === "text-delta") {
accumulatedText += parsed.delta; accumulatedText += parsed.delta;
yield { yield { content: buildContent() };
content: [{ type: "text" as const, text: accumulatedText }], } else if (parsed.type === "tool-output-available") {
}; const { toolCallId, output } = parsed;
const existing = toolCalls.get(toolCallId);
if (existing) {
existing.result = output;
// Set podcast state if processing
if (
existing.toolName === "generate_podcast" &&
output &&
typeof output === "object" &&
"status" in output &&
output.status === "processing" &&
"task_id" in output &&
typeof output.task_id === "string"
) {
setActivePodcastTaskId(output.task_id);
}
}
yield { content: buildContent() };
} }
} catch { } catch {
// Ignore parse errors // Ignore parse errors

View file

@ -0,0 +1,74 @@
/**
* Module-level state for tracking active podcast generation.
* Used by the new-chat adapter to prevent duplicate podcast requests.
*/
type PodcastStateListener = (isGenerating: boolean) => void;
let _activePodcastTaskId: string | null = null;
const _listeners: Set<PodcastStateListener> = new Set();
/**
* Check if a podcast is currently being generated
*/
export function isPodcastGenerating(): boolean {
return _activePodcastTaskId !== null;
}
/**
* Get the active podcast task ID
*/
export function getActivePodcastTaskId(): string | null {
return _activePodcastTaskId;
}
/**
* Set the active podcast task ID (when podcast generation starts)
*/
export function setActivePodcastTaskId(taskId: string): void {
_activePodcastTaskId = taskId;
notifyListeners();
}
/**
* Clear the active podcast task ID (when podcast generation completes or errors)
*/
export function clearActivePodcastTaskId(): void {
_activePodcastTaskId = null;
notifyListeners();
}
/**
* Subscribe to podcast state changes
*/
export function subscribeToPodcastState(listener: PodcastStateListener): () => void {
_listeners.add(listener);
return () => {
_listeners.delete(listener);
};
}
function notifyListeners(): void {
const isGenerating = _activePodcastTaskId !== null;
for (const listener of _listeners) {
listener(isGenerating);
}
}
/**
* Check if a message looks like a podcast request
*/
export function looksLikePodcastRequest(message: string): boolean {
const podcastPatterns = [
/\bpodcast\b/i,
/\bcreate.*podcast\b/i,
/\bgenerate.*podcast\b/i,
/\bmake.*podcast\b/i,
/\bturn.*into.*podcast\b/i,
/\bpodcast.*about\b/i,
/\bgive.*podcast\b/i,
];
return podcastPatterns.some((pattern) => pattern.test(message));
}