mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-30 19:36:25 +02:00
inject user instruction in the podcast generation task
This commit is contained in:
parent
9c959baadd
commit
2902fd6d28
7 changed files with 50 additions and 13 deletions
|
|
@ -18,6 +18,7 @@ class Configuration:
|
||||||
podcast_title: str
|
podcast_title: str
|
||||||
user_id: str
|
user_id: str
|
||||||
search_space_id: int
|
search_space_id: int
|
||||||
|
user_prompt: str | None = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_runnable_config(
|
def from_runnable_config(
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,7 @@ async def create_podcast_transcript(
|
||||||
configuration = Configuration.from_runnable_config(config)
|
configuration = Configuration.from_runnable_config(config)
|
||||||
user_id = configuration.user_id
|
user_id = configuration.user_id
|
||||||
search_space_id = configuration.search_space_id
|
search_space_id = configuration.search_space_id
|
||||||
|
user_prompt = configuration.user_prompt
|
||||||
|
|
||||||
# Get user's long context LLM
|
# Get user's long context LLM
|
||||||
llm = await get_user_long_context_llm(state.db_session, user_id, search_space_id)
|
llm = await get_user_long_context_llm(state.db_session, user_id, search_space_id)
|
||||||
|
|
@ -38,7 +39,7 @@ async def create_podcast_transcript(
|
||||||
raise RuntimeError(error_message)
|
raise RuntimeError(error_message)
|
||||||
|
|
||||||
# Get the prompt
|
# Get the prompt
|
||||||
prompt = get_podcast_generation_prompt()
|
prompt = get_podcast_generation_prompt(user_prompt)
|
||||||
|
|
||||||
# Create the messages
|
# Create the messages
|
||||||
messages = [
|
messages = [
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,23 @@
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
|
|
||||||
def get_podcast_generation_prompt():
|
def get_podcast_generation_prompt(user_prompt: str | None = None):
|
||||||
return f"""
|
return f"""
|
||||||
Today's date: {datetime.datetime.now().strftime("%Y-%m-%d")}
|
Today's date: {datetime.datetime.now().strftime("%Y-%m-%d")}
|
||||||
<podcast_generation_system>
|
<podcast_generation_system>
|
||||||
You are a master podcast scriptwriter, adept at transforming diverse input content into a lively, engaging, and natural-sounding conversation between two distinct podcast hosts. Your primary objective is to craft authentic, flowing dialogue that captures the spontaneity and chemistry of a real podcast discussion, completely avoiding any hint of robotic scripting or stiff formality. Think dynamic interplay, not just information delivery.
|
You are a master podcast scriptwriter, adept at transforming diverse input content into a lively, engaging, and natural-sounding conversation between two distinct podcast hosts. Your primary objective is to craft authentic, flowing dialogue that captures the spontaneity and chemistry of a real podcast discussion, completely avoiding any hint of robotic scripting or stiff formality. Think dynamic interplay, not just information delivery.
|
||||||
|
|
||||||
|
{
|
||||||
|
f'''
|
||||||
|
You **MUST** strictly adhere to the following user instruction while generating the podcast script:
|
||||||
|
<user_instruction>
|
||||||
|
{user_prompt}
|
||||||
|
</user_instruction>
|
||||||
|
'''
|
||||||
|
if user_prompt
|
||||||
|
else ""
|
||||||
|
}
|
||||||
|
|
||||||
<input>
|
<input>
|
||||||
- '<source_content>': A block of text containing the information to be discussed in the podcast. This could be research findings, an article summary, a detailed outline, user chat history related to the topic, or any other relevant raw information. The content might be unstructured but serves as the factual basis for the podcast dialogue.
|
- '<source_content>': A block of text containing the information to be discussed in the podcast. This could be research findings, an article summary, a detailed outline, user chat history related to the topic, or any other relevant raw information. The content might be unstructured but serves as the factual basis for the podcast dialogue.
|
||||||
</input>
|
</input>
|
||||||
|
|
|
||||||
|
|
@ -155,7 +155,11 @@ async def delete_podcast(
|
||||||
|
|
||||||
|
|
||||||
async def generate_chat_podcast_with_new_session(
|
async def generate_chat_podcast_with_new_session(
|
||||||
chat_id: int, search_space_id: int, podcast_title: str, user_id: int
|
chat_id: int,
|
||||||
|
search_space_id: int,
|
||||||
|
user_id: int,
|
||||||
|
podcast_title: str | None = None,
|
||||||
|
user_prompt: str | None = None,
|
||||||
):
|
):
|
||||||
"""Create a new session and process chat podcast generation."""
|
"""Create a new session and process chat podcast generation."""
|
||||||
from app.db import async_session_maker
|
from app.db import async_session_maker
|
||||||
|
|
@ -163,7 +167,7 @@ async def generate_chat_podcast_with_new_session(
|
||||||
async with async_session_maker() as session:
|
async with async_session_maker() as session:
|
||||||
try:
|
try:
|
||||||
await generate_chat_podcast(
|
await generate_chat_podcast(
|
||||||
session, chat_id, search_space_id, podcast_title, user_id
|
session, chat_id, search_space_id, user_id, podcast_title, user_prompt
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import logging
|
import logging
|
||||||
|
|
@ -211,7 +215,11 @@ async def generate_podcast(
|
||||||
# Add Celery tasks for each chat ID
|
# Add Celery tasks for each chat ID
|
||||||
for chat_id in valid_chat_ids:
|
for chat_id in valid_chat_ids:
|
||||||
generate_chat_podcast_task.delay(
|
generate_chat_podcast_task.delay(
|
||||||
chat_id, request.search_space_id, request.podcast_title, user.id
|
chat_id,
|
||||||
|
request.search_space_id,
|
||||||
|
user.id,
|
||||||
|
request.podcast_title,
|
||||||
|
request.user_prompt,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
|
||||||
|
|
@ -29,4 +29,5 @@ class PodcastGenerateRequest(BaseModel):
|
||||||
type: Literal["DOCUMENT", "CHAT"]
|
type: Literal["DOCUMENT", "CHAT"]
|
||||||
ids: list[int]
|
ids: list[int]
|
||||||
search_space_id: int
|
search_space_id: int
|
||||||
podcast_title: str = "SurfSense Podcast"
|
podcast_title: str | None = None
|
||||||
|
user_prompt: str | None = None
|
||||||
|
|
|
||||||
|
|
@ -38,7 +38,12 @@ def get_celery_session_maker():
|
||||||
|
|
||||||
@celery_app.task(name="generate_chat_podcast", bind=True)
|
@celery_app.task(name="generate_chat_podcast", bind=True)
|
||||||
def generate_chat_podcast_task(
|
def generate_chat_podcast_task(
|
||||||
self, chat_id: int, search_space_id: int, podcast_title: str, user_id: int
|
self,
|
||||||
|
chat_id: int,
|
||||||
|
search_space_id: int,
|
||||||
|
user_id: int,
|
||||||
|
podcast_title: str | None = None,
|
||||||
|
user_prompt: str | None = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Celery task to generate podcast from chat.
|
Celery task to generate podcast from chat.
|
||||||
|
|
@ -46,15 +51,18 @@ def generate_chat_podcast_task(
|
||||||
Args:
|
Args:
|
||||||
chat_id: ID of the chat to generate podcast from
|
chat_id: ID of the chat to generate podcast from
|
||||||
search_space_id: ID of the search space
|
search_space_id: ID of the search space
|
||||||
|
user_id: ID of the user
|
||||||
podcast_title: Title for the podcast
|
podcast_title: Title for the podcast
|
||||||
user_id: ID of the user
|
user_prompt: Optional prompt from the user to guide the podcast generation
|
||||||
"""
|
"""
|
||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
asyncio.set_event_loop(loop)
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
loop.run_until_complete(
|
loop.run_until_complete(
|
||||||
_generate_chat_podcast(chat_id, search_space_id, podcast_title, user_id)
|
_generate_chat_podcast(
|
||||||
|
chat_id, search_space_id, user_id, podcast_title, user_prompt
|
||||||
|
)
|
||||||
)
|
)
|
||||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||||
finally:
|
finally:
|
||||||
|
|
@ -63,13 +71,17 @@ def generate_chat_podcast_task(
|
||||||
|
|
||||||
|
|
||||||
async def _generate_chat_podcast(
|
async def _generate_chat_podcast(
|
||||||
chat_id: int, search_space_id: int, podcast_title: str, user_id: int
|
chat_id: int,
|
||||||
|
search_space_id: int,
|
||||||
|
user_id: int,
|
||||||
|
podcast_title: str | None = None,
|
||||||
|
user_prompt: str | None = None,
|
||||||
):
|
):
|
||||||
"""Generate chat podcast with new session."""
|
"""Generate chat podcast with new session."""
|
||||||
async with get_celery_session_maker()() as session:
|
async with get_celery_session_maker()() as session:
|
||||||
try:
|
try:
|
||||||
await generate_chat_podcast(
|
await generate_chat_podcast(
|
||||||
session, chat_id, search_space_id, podcast_title, user_id
|
session, chat_id, search_space_id, user_id, podcast_title, user_prompt
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error generating podcast from chat: {e!s}")
|
logger.error(f"Error generating podcast from chat: {e!s}")
|
||||||
|
|
|
||||||
|
|
@ -19,8 +19,9 @@ async def generate_chat_podcast(
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
chat_id: int,
|
chat_id: int,
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
podcast_title: str,
|
|
||||||
user_id: int,
|
user_id: int,
|
||||||
|
podcast_title: str | None = None,
|
||||||
|
user_prompt: str | None = None,
|
||||||
):
|
):
|
||||||
task_logger = TaskLoggingService(session, search_space_id)
|
task_logger = TaskLoggingService(session, search_space_id)
|
||||||
|
|
||||||
|
|
@ -34,6 +35,7 @@ async def generate_chat_podcast(
|
||||||
"search_space_id": search_space_id,
|
"search_space_id": search_space_id,
|
||||||
"podcast_title": podcast_title,
|
"podcast_title": podcast_title,
|
||||||
"user_id": str(user_id),
|
"user_id": str(user_id),
|
||||||
|
"user_prompt": user_prompt,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -96,9 +98,10 @@ async def generate_chat_podcast(
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
"configurable": {
|
"configurable": {
|
||||||
"podcast_title": "SurfSense",
|
"podcast_title": podcast_title or "SurfSense Podcast",
|
||||||
"user_id": str(user_id),
|
"user_id": str(user_id),
|
||||||
"search_space_id": search_space_id,
|
"search_space_id": search_space_id,
|
||||||
|
"user_prompt": user_prompt,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
# Initialize state with database session and streaming service
|
# Initialize state with database session and streaming service
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue