mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-27 09:46:25 +02:00
fix: Fix rerank_documents node in sub_section_writer & qna_agent
This commit is contained in:
parent
3e1db2ac6b
commit
671984acbd
6 changed files with 2708 additions and 1923 deletions
|
|
@ -1,7 +1,18 @@
|
|||
from typing import List, Dict, Any, Tuple, NamedTuple
|
||||
from langchain_core.messages import BaseMessage
|
||||
from pydantic import BaseModel, Field
|
||||
from litellm import token_counter, get_model_info
|
||||
|
||||
class Section(BaseModel):
|
||||
"""A section in the answer outline."""
|
||||
section_id: int = Field(..., description="The zero-based index of the section")
|
||||
section_title: str = Field(..., description="The title of the section")
|
||||
questions: List[str] = Field(..., description="Questions to research for this section")
|
||||
|
||||
class AnswerOutline(BaseModel):
|
||||
"""The complete answer outline with all sections."""
|
||||
answer_outline: List[Section] = Field(..., description="List of sections in the answer outline")
|
||||
|
||||
|
||||
class DocumentTokenInfo(NamedTuple):
|
||||
"""Information about a document and its token cost."""
|
||||
|
|
@ -9,6 +20,40 @@ class DocumentTokenInfo(NamedTuple):
|
|||
document: Dict[str, Any]
|
||||
formatted_content: str
|
||||
token_count: int
|
||||
|
||||
|
||||
def get_connector_emoji(connector_name: str) -> str:
|
||||
"""Get an appropriate emoji for a connector type."""
|
||||
connector_emojis = {
|
||||
"YOUTUBE_VIDEO": "📹",
|
||||
"EXTENSION": "🧩",
|
||||
"CRAWLED_URL": "🌐",
|
||||
"FILE": "📄",
|
||||
"SLACK_CONNECTOR": "💬",
|
||||
"NOTION_CONNECTOR": "📘",
|
||||
"GITHUB_CONNECTOR": "🐙",
|
||||
"LINEAR_CONNECTOR": "📊",
|
||||
"TAVILY_API": "🔍",
|
||||
"LINKUP_API": "🔗"
|
||||
}
|
||||
return connector_emojis.get(connector_name, "🔎")
|
||||
|
||||
|
||||
def get_connector_friendly_name(connector_name: str) -> str:
|
||||
"""Convert technical connector IDs to user-friendly names."""
|
||||
connector_friendly_names = {
|
||||
"YOUTUBE_VIDEO": "YouTube",
|
||||
"EXTENSION": "Browser Extension",
|
||||
"CRAWLED_URL": "Web Pages",
|
||||
"FILE": "Files",
|
||||
"SLACK_CONNECTOR": "Slack",
|
||||
"NOTION_CONNECTOR": "Notion",
|
||||
"GITHUB_CONNECTOR": "GitHub",
|
||||
"LINEAR_CONNECTOR": "Linear",
|
||||
"TAVILY_API": "Tavily Search",
|
||||
"LINKUP_API": "Linkup Search"
|
||||
}
|
||||
return connector_friendly_names.get(connector_name, connector_name)
|
||||
|
||||
|
||||
def convert_langchain_messages_to_dict(messages: List[BaseMessage]) -> List[Dict[str, str]]:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue