Merge pull request #430 from CREDO23/feat/chat-pannel

[Feature] Add the chat panel
This commit is contained in:
Rohan Verma 2025-11-11 17:04:39 -08:00 committed by GitHub
commit 0835a192a2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
38 changed files with 1219 additions and 72 deletions

View file

@ -18,6 +18,7 @@ class Configuration:
podcast_title: str
user_id: str
search_space_id: int
user_prompt: str | None = None
@classmethod
def from_runnable_config(

View file

@ -29,6 +29,7 @@ async def create_podcast_transcript(
configuration = Configuration.from_runnable_config(config)
user_id = configuration.user_id
search_space_id = configuration.search_space_id
user_prompt = configuration.user_prompt
# Get user's long context LLM
llm = await get_user_long_context_llm(state.db_session, user_id, search_space_id)
@ -38,7 +39,7 @@ async def create_podcast_transcript(
raise RuntimeError(error_message)
# Get the prompt
prompt = get_podcast_generation_prompt()
prompt = get_podcast_generation_prompt(user_prompt)
# Create the messages
messages = [

View file

@ -1,12 +1,23 @@
import datetime
def get_podcast_generation_prompt():
def get_podcast_generation_prompt(user_prompt: str | None = None):
return f"""
Today's date: {datetime.datetime.now().strftime("%Y-%m-%d")}
<podcast_generation_system>
You are a master podcast scriptwriter, adept at transforming diverse input content into a lively, engaging, and natural-sounding conversation between two distinct podcast hosts. Your primary objective is to craft authentic, flowing dialogue that captures the spontaneity and chemistry of a real podcast discussion, completely avoiding any hint of robotic scripting or stiff formality. Think dynamic interplay, not just information delivery.
{
f'''
You **MUST** strictly adhere to the following user instruction while generating the podcast script:
<user_instruction>
{user_prompt}
</user_instruction>
'''
if user_prompt
else ""
}
<input>
- '<source_content>': A block of text containing the information to be discussed in the podcast. This could be research findings, an article summary, a detailed outline, user chat history related to the topic, or any other relevant raw information. The content might be unstructured but serves as the factual basis for the podcast dialogue.
</input>

View file

@ -9,6 +9,7 @@ from sqlalchemy import (
ARRAY,
JSON,
TIMESTAMP,
BigInteger,
Boolean,
Column,
Enum as SQLAlchemyEnum,
@ -157,6 +158,7 @@ class Chat(BaseModel, TimestampMixin):
title = Column(String, nullable=False, index=True)
initial_connectors = Column(ARRAY(String), nullable=True)
messages = Column(JSON, nullable=False)
state_version = Column(BigInteger, nullable=False, default=1)
search_space_id = Column(
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
@ -203,6 +205,10 @@ class Podcast(BaseModel, TimestampMixin):
title = Column(String, nullable=False, index=True)
podcast_transcript = Column(JSON, nullable=False, default={})
file_location = Column(String(500), nullable=False, default="")
chat_id = Column(
Integer, ForeignKey("chats.id", ondelete="CASCADE"), nullable=True
) # If generated from a chat, this will be the chat id, else null ( can be from a document or a chat )
chat_state_version = Column(BigInteger, nullable=True)
search_space_id = Column(
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False

View file

@ -199,6 +199,7 @@ async def read_chats(
Chat.initial_connectors,
Chat.search_space_id,
Chat.created_at,
Chat.state_version,
)
.join(SearchSpace)
.filter(SearchSpace.user_id == user.id)
@ -261,7 +262,10 @@ async def update_chat(
db_chat = await read_chat(chat_id, session, user)
update_data = chat_update.model_dump(exclude_unset=True)
for key, value in update_data.items():
if key == "messages":
db_chat.state_version = len(update_data["messages"])
setattr(db_chat, key, value)
await session.commit()
await session.refresh(db_chat)
return db_chat

View file

@ -155,7 +155,11 @@ async def delete_podcast(
async def generate_chat_podcast_with_new_session(
chat_id: int, search_space_id: int, podcast_title: str, user_id: int
chat_id: int,
search_space_id: int,
user_id: int,
podcast_title: str | None = None,
user_prompt: str | None = None,
):
"""Create a new session and process chat podcast generation."""
from app.db import async_session_maker
@ -163,7 +167,7 @@ async def generate_chat_podcast_with_new_session(
async with async_session_maker() as session:
try:
await generate_chat_podcast(
session, chat_id, search_space_id, podcast_title, user_id
session, chat_id, search_space_id, user_id, podcast_title, user_prompt
)
except Exception as e:
import logging
@ -211,7 +215,11 @@ async def generate_podcast(
# Add Celery tasks for each chat ID
for chat_id in valid_chat_ids:
generate_chat_podcast_task.delay(
chat_id, request.search_space_id, request.podcast_title, user.id
chat_id,
request.search_space_id,
user.id,
request.podcast_title,
request.user_prompt,
)
return {
@ -287,3 +295,27 @@ async def stream_podcast(
raise HTTPException(
status_code=500, detail=f"Error streaming podcast: {e!s}"
) from e
@router.get("/podcasts/by-chat/{chat_id}", response_model=PodcastRead | None)
async def get_podcast_by_chat_id(
    chat_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Return the podcast attached to ``chat_id`` for the current user.

    The query joins ``Podcast`` to ``SearchSpace`` and keeps only rows whose
    search space belongs to ``user``, so a podcast in someone else's search
    space is never returned. The first matching row is returned; when no row
    matches the result is ``None``.

    Raises:
        HTTPException: re-raised unchanged if one surfaces inside the lookup;
            any other exception is wrapped in a 500 response with the
            original error text in ``detail``.
    """
    try:
        # Access control lives in the filter: the chat id must match AND the
        # joined search space must be owned by the requesting user.
        owned_podcast_query = (
            select(Podcast)
            .join(SearchSpace)
            .filter(Podcast.chat_id == chat_id, SearchSpace.user_id == user.id)
        )
        rows = await session.execute(owned_podcast_query)
        return rows.scalars().first()
    except HTTPException as he:
        raise he
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error fetching podcast: {e!s}"
        ) from e

View file

@ -13,12 +13,14 @@ class ChatBase(BaseModel):
initial_connectors: list[str] | None = None
messages: list[Any]
search_space_id: int
state_version: int = 1
class ChatBaseWithoutMessages(BaseModel):
    # Chat schema that declares no `messages` field: only the type, title,
    # owning search space id and state version are carried.
    # (A `#` comment is used instead of a docstring so the generated schema
    # description is left untouched.)
    type: ChatType
    title: str
    search_space_id: int
    # Defaults to 1 when the value is not supplied.
    # NOTE(review): presumably mirrors the `state_version` column on the Chat
    # model (which also defaults to 1) — confirm against the ORM model.
    state_version: int = 1
class ClientAttachment(BaseModel):

View file

@ -10,6 +10,7 @@ class PodcastBase(BaseModel):
podcast_transcript: list[Any]
file_location: str = ""
search_space_id: int
chat_state_version: int | None = None
class PodcastCreate(PodcastBase):
@ -28,4 +29,5 @@ class PodcastGenerateRequest(BaseModel):
type: Literal["DOCUMENT", "CHAT"]
ids: list[int]
search_space_id: int
podcast_title: str = "SurfSense Podcast"
podcast_title: str | None = None
user_prompt: str | None = None

View file

@ -38,7 +38,12 @@ def get_celery_session_maker():
@celery_app.task(name="generate_chat_podcast", bind=True)
def generate_chat_podcast_task(
self, chat_id: int, search_space_id: int, podcast_title: str, user_id: int
self,
chat_id: int,
search_space_id: int,
user_id: int,
podcast_title: str | None = None,
user_prompt: str | None = None,
):
"""
Celery task to generate podcast from chat.
@ -46,15 +51,18 @@ def generate_chat_podcast_task(
Args:
chat_id: ID of the chat to generate podcast from
search_space_id: ID of the search space
user_id: ID of the user
podcast_title: Title for the podcast
user_id: ID of the user
user_prompt: Optional prompt from the user to guide the podcast generation
"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_generate_chat_podcast(chat_id, search_space_id, podcast_title, user_id)
_generate_chat_podcast(
chat_id, search_space_id, user_id, podcast_title, user_prompt
)
)
loop.run_until_complete(loop.shutdown_asyncgens())
finally:
@ -63,13 +71,17 @@ def generate_chat_podcast_task(
async def _generate_chat_podcast(
chat_id: int, search_space_id: int, podcast_title: str, user_id: int
chat_id: int,
search_space_id: int,
user_id: int,
podcast_title: str | None = None,
user_prompt: str | None = None,
):
"""Generate chat podcast with new session."""
async with get_celery_session_maker()() as session:
try:
await generate_chat_podcast(
session, chat_id, search_space_id, podcast_title, user_id
session, chat_id, search_space_id, user_id, podcast_title, user_prompt
)
except Exception as e:
logger.error(f"Error generating podcast from chat: {e!s}")

View file

@ -19,8 +19,9 @@ async def generate_chat_podcast(
session: AsyncSession,
chat_id: int,
search_space_id: int,
podcast_title: str,
user_id: int,
podcast_title: str | None = None,
user_prompt: str | None = None,
):
task_logger = TaskLoggingService(session, search_space_id)
@ -34,6 +35,7 @@ async def generate_chat_podcast(
"search_space_id": search_space_id,
"podcast_title": podcast_title,
"user_id": str(user_id),
"user_prompt": user_prompt,
},
)
@ -96,9 +98,10 @@ async def generate_chat_podcast(
config = {
"configurable": {
"podcast_title": "SurfSense",
"podcast_title": podcast_title or "SurfSense Podcast",
"user_id": str(user_id),
"search_space_id": search_space_id,
"user_prompt": user_prompt,
}
}
# Initialize state with database session and streaming service
@ -139,33 +142,49 @@ async def generate_chat_podcast(
},
)
podcast = Podcast(
title=f"{podcast_title}",
podcast_transcript=serializable_transcript,
file_location=result["final_podcast_file_path"],
search_space_id=search_space_id,
# check if podcast already exists for this chat (re-generation)
existing_podcast = await session.execute(
select(Podcast).filter(Podcast.chat_id == chat_id)
)
existing_podcast = existing_podcast.scalars().first()
# Add to session and commit
session.add(podcast)
await session.commit()
await session.refresh(podcast)
if existing_podcast:
existing_podcast.podcast_transcript = serializable_transcript
existing_podcast.file_location = result["final_podcast_file_path"]
existing_podcast.chat_state_version = chat.state_version
await session.commit()
await session.refresh(existing_podcast)
return existing_podcast
else:
podcast = Podcast(
title=f"{podcast_title}",
podcast_transcript=serializable_transcript,
file_location=result["final_podcast_file_path"],
search_space_id=search_space_id,
chat_state_version=chat.state_version,
chat_id=chat.id,
)
# Log success
await task_logger.log_task_success(
log_entry,
f"Successfully generated podcast for chat {chat_id}",
{
"podcast_id": podcast.id,
"podcast_title": podcast_title,
"transcript_entries": len(serializable_transcript),
"file_location": result.get("final_podcast_file_path"),
"processed_messages": processed_messages,
"content_length": len(chat_history_str),
},
)
# Add to session and commit
session.add(podcast)
await session.commit()
await session.refresh(podcast)
return podcast
# Log success
await task_logger.log_task_success(
log_entry,
f"Successfully generated podcast for chat {chat_id}",
{
"podcast_id": podcast.id,
"podcast_title": podcast_title,
"transcript_entries": len(serializable_transcript),
"file_location": result.get("final_podcast_file_path"),
"processed_messages": processed_messages,
"content_length": len(chat_history_str),
},
)
return podcast
except ValueError as ve:
# ValueError is already logged above for chat not found