feat: pass thread_id through podcast generation chain

This commit is contained in:
CREDO23 2026-01-26 15:56:34 +02:00
parent 062998738a
commit 7017a14107
5 changed files with 13 additions and 1 deletions

View file

@ -35,6 +35,7 @@ async def create_surfsense_deep_agent(
connector_service: ConnectorService,
checkpointer: Checkpointer,
user_id: str | None = None,
thread_id: int | None = None,
agent_config: AgentConfig | None = None,
enabled_tools: list[str] | None = None,
disabled_tools: list[str] | None = None,
@ -123,6 +124,7 @@ async def create_surfsense_deep_agent(
"connector_service": connector_service,
"firecrawl_api_key": firecrawl_api_key,
"user_id": user_id, # Required for memory tools
"thread_id": thread_id, # For podcast tool
}
# Build tools using the async registry (includes MCP tools)

View file

@ -69,6 +69,7 @@ def clear_active_podcast_task(search_space_id: int) -> None:
def create_generate_podcast_tool(
search_space_id: int,
db_session: AsyncSession,
thread_id: int | None = None,
):
"""
Factory function to create the generate_podcast tool with injected dependencies.
@ -76,6 +77,7 @@ def create_generate_podcast_tool(
Args:
search_space_id: The user's search space ID
db_session: Database session (not used - Celery creates its own)
thread_id: The chat thread ID for associating the podcast
Returns:
A configured tool function for generating podcasts
@ -145,6 +147,7 @@ def create_generate_podcast_tool(
search_space_id=search_space_id,
podcast_title=podcast_title,
user_prompt=user_prompt,
thread_id=thread_id,
)
# Mark this task as active

View file

@ -102,8 +102,9 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
factory=lambda deps: create_generate_podcast_tool(
search_space_id=deps["search_space_id"],
db_session=deps["db_session"],
thread_id=deps["thread_id"],
),
requires=["search_space_id", "db_session"],
requires=["search_space_id", "db_session", "thread_id"],
),
# Link preview tool - fetches Open Graph metadata for URLs
ToolDefinition(

View file

@ -67,6 +67,7 @@ def generate_content_podcast_task(
search_space_id: int,
podcast_title: str = "SurfSense Podcast",
user_prompt: str | None = None,
thread_id: int | None = None,
) -> dict:
"""
Celery task to generate podcast from source content (for new-chat).
@ -78,6 +79,7 @@ def generate_content_podcast_task(
search_space_id: ID of the search space
podcast_title: Title for the podcast
user_prompt: Optional instructions for podcast style/tone
thread_id: Optional ID of the chat thread that generated this podcast
Returns:
dict with podcast_id on success, or error info on failure
@ -92,6 +94,7 @@ def generate_content_podcast_task(
search_space_id,
podcast_title,
user_prompt,
thread_id,
)
)
loop.run_until_complete(loop.shutdown_asyncgens())
@ -111,6 +114,7 @@ async def _generate_content_podcast(
search_space_id: int,
podcast_title: str = "SurfSense Podcast",
user_prompt: str | None = None,
thread_id: int | None = None,
) -> dict:
"""Generate content-based podcast with new session."""
async with get_celery_session_maker()() as session:
@ -158,6 +162,7 @@ async def _generate_content_podcast(
podcast_transcript=serializable_transcript,
file_location=file_path,
search_space_id=search_space_id,
thread_id=thread_id,
)
session.add(podcast)
await session.commit()

View file

@ -255,6 +255,7 @@ async def stream_new_chat(
connector_service=connector_service,
checkpointer=checkpointer,
user_id=user_id, # Pass user ID for memory tools
thread_id=chat_id, # Pass chat ID for podcast association
agent_config=agent_config, # Pass prompt configuration
firecrawl_api_key=firecrawl_api_key, # Pass Firecrawl API key if configured
)