mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 00:36:31 +02:00
feat: pass thread_id through podcast generation chain
This commit is contained in:
parent
062998738a
commit
7017a14107
5 changed files with 13 additions and 1 deletions
|
|
@ -35,6 +35,7 @@ async def create_surfsense_deep_agent(
|
|||
connector_service: ConnectorService,
|
||||
checkpointer: Checkpointer,
|
||||
user_id: str | None = None,
|
||||
thread_id: int | None = None,
|
||||
agent_config: AgentConfig | None = None,
|
||||
enabled_tools: list[str] | None = None,
|
||||
disabled_tools: list[str] | None = None,
|
||||
|
|
@ -123,6 +124,7 @@ async def create_surfsense_deep_agent(
|
|||
"connector_service": connector_service,
|
||||
"firecrawl_api_key": firecrawl_api_key,
|
||||
"user_id": user_id, # Required for memory tools
|
||||
"thread_id": thread_id, # For podcast tool
|
||||
}
|
||||
|
||||
# Build tools using the async registry (includes MCP tools)
|
||||
|
|
|
|||
|
|
@ -69,6 +69,7 @@ def clear_active_podcast_task(search_space_id: int) -> None:
|
|||
def create_generate_podcast_tool(
|
||||
search_space_id: int,
|
||||
db_session: AsyncSession,
|
||||
thread_id: int | None = None,
|
||||
):
|
||||
"""
|
||||
Factory function to create the generate_podcast tool with injected dependencies.
|
||||
|
|
@ -76,6 +77,7 @@ def create_generate_podcast_tool(
|
|||
Args:
|
||||
search_space_id: The user's search space ID
|
||||
db_session: Database session (not used - Celery creates its own)
|
||||
thread_id: The chat thread ID for associating the podcast
|
||||
|
||||
Returns:
|
||||
A configured tool function for generating podcasts
|
||||
|
|
@ -145,6 +147,7 @@ def create_generate_podcast_tool(
|
|||
search_space_id=search_space_id,
|
||||
podcast_title=podcast_title,
|
||||
user_prompt=user_prompt,
|
||||
thread_id=thread_id,
|
||||
)
|
||||
|
||||
# Mark this task as active
|
||||
|
|
|
|||
|
|
@ -102,8 +102,9 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
factory=lambda deps: create_generate_podcast_tool(
|
||||
search_space_id=deps["search_space_id"],
|
||||
db_session=deps["db_session"],
|
||||
thread_id=deps["thread_id"],
|
||||
),
|
||||
requires=["search_space_id", "db_session"],
|
||||
requires=["search_space_id", "db_session", "thread_id"],
|
||||
),
|
||||
# Link preview tool - fetches Open Graph metadata for URLs
|
||||
ToolDefinition(
|
||||
|
|
|
|||
|
|
@ -67,6 +67,7 @@ def generate_content_podcast_task(
|
|||
search_space_id: int,
|
||||
podcast_title: str = "SurfSense Podcast",
|
||||
user_prompt: str | None = None,
|
||||
thread_id: int | None = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Celery task to generate podcast from source content (for new-chat).
|
||||
|
|
@ -78,6 +79,7 @@ def generate_content_podcast_task(
|
|||
search_space_id: ID of the search space
|
||||
podcast_title: Title for the podcast
|
||||
user_prompt: Optional instructions for podcast style/tone
|
||||
thread_id: Optional ID of the chat thread that generated this podcast
|
||||
|
||||
Returns:
|
||||
dict with podcast_id on success, or error info on failure
|
||||
|
|
@ -92,6 +94,7 @@ def generate_content_podcast_task(
|
|||
search_space_id,
|
||||
podcast_title,
|
||||
user_prompt,
|
||||
thread_id,
|
||||
)
|
||||
)
|
||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||
|
|
@ -111,6 +114,7 @@ async def _generate_content_podcast(
|
|||
search_space_id: int,
|
||||
podcast_title: str = "SurfSense Podcast",
|
||||
user_prompt: str | None = None,
|
||||
thread_id: int | None = None,
|
||||
) -> dict:
|
||||
"""Generate content-based podcast with new session."""
|
||||
async with get_celery_session_maker()() as session:
|
||||
|
|
@ -158,6 +162,7 @@ async def _generate_content_podcast(
|
|||
podcast_transcript=serializable_transcript,
|
||||
file_location=file_path,
|
||||
search_space_id=search_space_id,
|
||||
thread_id=thread_id,
|
||||
)
|
||||
session.add(podcast)
|
||||
await session.commit()
|
||||
|
|
|
|||
|
|
@ -255,6 +255,7 @@ async def stream_new_chat(
|
|||
connector_service=connector_service,
|
||||
checkpointer=checkpointer,
|
||||
user_id=user_id, # Pass user ID for memory tools
|
||||
thread_id=chat_id, # Pass chat ID for podcast association
|
||||
agent_config=agent_config, # Pass prompt configuration
|
||||
firecrawl_api_key=firecrawl_api_key, # Pass Firecrawl API key if configured
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue