feat: SurfSense v0.0.6 init

DESKTOP-RTLN3BA\$punk 2025-03-14 18:53:14 -07:00
parent 18fc19e8d9
commit da23012970
58 changed files with 8284 additions and 2076 deletions

View file

@@ -0,0 +1,12 @@
from fastapi import HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select

from app.db import User


# Helper function to check user ownership
async def check_ownership(session: AsyncSession, model, item_id: int, user: User):
    item = await session.execute(
        select(model).filter(model.id == item_id, model.user_id == user.id)
    )
    item = item.scalars().first()
    if not item:
        raise HTTPException(
            status_code=404,
            detail="Item not found or you don't have permission to access it"
        )
    return item
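
A minimal usage sketch (not part of this diff): a route can delegate its permission check to this helper. The Document model and the dependency providers below are hypothetical placeholders for whatever the app actually wires in.

from fastapi import APIRouter, Depends

router = APIRouter()

@router.get("/documents/{document_id}")
async def read_document(
    document_id: int,
    session: AsyncSession = Depends(get_async_session),  # hypothetical session provider
    user: User = Depends(current_active_user),  # hypothetical fastapi-users dependency
):
    # Raises HTTPException(404) unless the row exists and belongs to this user
    document = await check_ownership(session, Document, document_id, user)
    return document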

View file

@@ -0,0 +1,385 @@
import json
from typing import List, Dict, Any, Optional, Tuple

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from tavily import TavilyClient

from app.retriver.chunks_hybrid_search import ChucksHybridSearchRetriever
from app.db import SearchSourceConnector, SearchSourceConnectorType


class ConnectorService:
    def __init__(self, session: AsyncSession):
        self.session = session
        self.retriever = ChucksHybridSearchRetriever(session)
        self.source_id_counter = 1

    async def search_crawled_urls(self, user_query: str, user_id: int, search_space_id: int, top_k: int = 20) -> tuple:
        """
        Search for crawled URLs and return both the source information and langchain documents

        Returns:
            tuple: (sources_info, langchain_documents)
        """
        crawled_urls_chunks = await self.retriever.hybrid_search(
            query_text=user_query,
            top_k=top_k,
            user_id=user_id,
            search_space_id=search_space_id,
            document_type="CRAWLED_URL"
        )

        # Map crawled_urls_chunks to the required format
        mapped_sources = {}
        for i, chunk in enumerate(crawled_urls_chunks):
            # Fix for UI
            crawled_urls_chunks[i]['document']['id'] = self.source_id_counter

            # Extract document metadata
            document = chunk.get('document', {})
            metadata = document.get('metadata', {})

            # Create a mapped source entry
            source = {
                "id": self.source_id_counter,
                "title": document.get('title', 'Untitled Document'),
                "description": metadata.get('og:description', metadata.get('ogDescription', chunk.get('content', '')[:100])),
                "url": metadata.get('url', '')
            }
            self.source_id_counter += 1

            # Use a unique identifier for tracking unique sources
            source_key = source.get("url") or source.get("title")
            if source_key and source_key not in mapped_sources:
                mapped_sources[source_key] = source

        # Convert to list of sources
        sources_list = list(mapped_sources.values())

        # Create result object
        result_object = {
            "id": 1,
            "name": "Crawled URLs",
            "type": "CRAWLED_URL",
            "sources": sources_list,
        }

        return result_object, crawled_urls_chunks
    async def search_files(self, user_query: str, user_id: int, search_space_id: int, top_k: int = 20) -> tuple:
        """
        Search for files and return both the source information and langchain documents

        Returns:
            tuple: (sources_info, langchain_documents)
        """
        files_chunks = await self.retriever.hybrid_search(
            query_text=user_query,
            top_k=top_k,
            user_id=user_id,
            search_space_id=search_space_id,
            document_type="FILE"
        )

        # Map files_chunks to the required format
        mapped_sources = {}
        for i, chunk in enumerate(files_chunks):
            # Fix for UI
            files_chunks[i]['document']['id'] = self.source_id_counter

            # Extract document metadata
            document = chunk.get('document', {})
            metadata = document.get('metadata', {})

            # Create a mapped source entry
            source = {
                "id": self.source_id_counter,
                "title": document.get('title', 'Untitled Document'),
                "description": metadata.get('og:description', metadata.get('ogDescription', chunk.get('content', '')[:100])),
                "url": metadata.get('url', '')
            }
            self.source_id_counter += 1

            # Use a unique identifier for tracking unique sources
            source_key = source.get("url") or source.get("title")
            if source_key and source_key not in mapped_sources:
                mapped_sources[source_key] = source

        # Convert to list of sources
        sources_list = list(mapped_sources.values())

        # Create result object
        result_object = {
            "id": 2,
            "name": "Files",
            "type": "FILE",
            "sources": sources_list,
        }

        return result_object, files_chunks
    async def get_connector_by_type(self, user_id: int, connector_type: SearchSourceConnectorType) -> Optional[SearchSourceConnector]:
        """
        Get a connector by type for a specific user

        Args:
            user_id: The user's ID
            connector_type: The connector type to retrieve

        Returns:
            Optional[SearchSourceConnector]: The connector if found, None otherwise
        """
        result = await self.session.execute(
            select(SearchSourceConnector)
            .filter(
                SearchSourceConnector.user_id == user_id,
                SearchSourceConnector.connector_type == connector_type
            )
        )
        return result.scalars().first()
    async def search_tavily(self, user_query: str, user_id: int, top_k: int = 20) -> tuple:
        """
        Search using Tavily API and return both the source information and documents

        Args:
            user_query: The user's query
            user_id: The user's ID
            top_k: Maximum number of results to return

        Returns:
            tuple: (sources_info, documents)
        """
        # Get Tavily connector configuration
        tavily_connector = await self.get_connector_by_type(user_id, SearchSourceConnectorType.TAVILY_API)

        if not tavily_connector:
            # Return empty results if no Tavily connector is configured
            return {
                "id": 3,
                "name": "Tavily Search",
                "type": "TAVILY_API",
                "sources": [],
            }, []

        # Initialize Tavily client with API key from connector config
        tavily_api_key = tavily_connector.config.get("TAVILY_API_KEY")
        tavily_client = TavilyClient(api_key=tavily_api_key)

        # Perform search with Tavily
        try:
            response = tavily_client.search(
                query=user_query,
                max_results=top_k,
                search_depth="advanced"  # Use advanced search for better results
            )

            # Extract results from Tavily response
            tavily_results = response.get("results", [])

            # Map Tavily results to the required format
            sources_list = []
            documents = []
            # Source IDs continue from the shared source_id_counter, so they
            # cannot collide with IDs issued by the other connectors
            for i, result in enumerate(tavily_results):
                # Create a source entry
                source = {
                    "id": self.source_id_counter,
                    "title": result.get("title", "Tavily Result"),
                    "description": result.get("content", "")[:100],
                    "url": result.get("url", "")
                }
                sources_list.append(source)

                # Create a document entry
                document = {
                    "chunk_id": f"tavily_chunk_{i}",
                    "content": result.get("content", ""),
                    "score": result.get("score", 0.0),
                    "document": {
                        "id": self.source_id_counter,
                        "title": result.get("title", "Tavily Result"),
                        "document_type": "TAVILY_API",
                        "metadata": {
                            "url": result.get("url", ""),
                            "published_date": result.get("published_date", ""),
                            "source": "TAVILY_API"
                        }
                    }
                }
                documents.append(document)
                self.source_id_counter += 1

            # Create result object
            result_object = {
                "id": 3,
                "name": "Tavily Search",
                "type": "TAVILY_API",
                "sources": sources_list,
            }

            return result_object, documents
        except Exception as e:
            # Log the error and return empty results
            print(f"Error searching with Tavily: {str(e)}")
            return {
                "id": 3,
                "name": "Tavily Search",
                "type": "TAVILY_API",
                "sources": [],
            }, []
    async def search_slack(self, user_query: str, user_id: int, search_space_id: int, top_k: int = 20) -> tuple:
        """
        Search for Slack messages and return both the source information and langchain documents

        Returns:
            tuple: (sources_info, langchain_documents)
        """
        slack_chunks = await self.retriever.hybrid_search(
            query_text=user_query,
            top_k=top_k,
            user_id=user_id,
            search_space_id=search_space_id,
            document_type="SLACK_CONNECTOR"
        )

        # Map slack_chunks to the required format
        mapped_sources = {}
        for i, chunk in enumerate(slack_chunks):
            # Fix for UI
            slack_chunks[i]['document']['id'] = self.source_id_counter

            # Extract document metadata
            document = chunk.get('document', {})
            metadata = document.get('metadata', {})

            # Create a mapped source entry with Slack-specific metadata
            channel_name = metadata.get('channel_name', 'Unknown Channel')
            channel_id = metadata.get('channel_id', '')
            message_date = metadata.get('start_date', '')

            # Create a more descriptive title for Slack messages
            title = f"Slack: {channel_name}"
            if message_date:
                title += f" ({message_date})"

            # Create a more descriptive description for Slack messages
            description = chunk.get('content', '')[:100]
            if len(description) == 100:
                description += "..."

            # For the URL, construct a link to the Slack channel if the channel ID is available
            url = ""
            if channel_id:
                url = f"https://slack.com/app_redirect?channel={channel_id}"

            source = {
                "id": self.source_id_counter,
                "title": title,
                "description": description,
                "url": url,
            }
            self.source_id_counter += 1

            # Use channel_id and chunk_id as a unique identifier for tracking unique sources
            source_key = f"{channel_id}_{chunk.get('chunk_id', i)}"
            if source_key and source_key not in mapped_sources:
                mapped_sources[source_key] = source

        # Convert to list of sources
        sources_list = list(mapped_sources.values())

        # Create result object
        result_object = {
            "id": 4,
            "name": "Slack",
            "type": "SLACK_CONNECTOR",
            "sources": sources_list,
        }

        return result_object, slack_chunks
    async def search_notion(self, user_query: str, user_id: int, search_space_id: int, top_k: int = 20) -> tuple:
        """
        Search for Notion pages and return both the source information and langchain documents

        Args:
            user_query: The user's query
            user_id: The user's ID
            search_space_id: The search space ID to search in
            top_k: Maximum number of results to return

        Returns:
            tuple: (sources_info, langchain_documents)
        """
        notion_chunks = await self.retriever.hybrid_search(
            query_text=user_query,
            top_k=top_k,
            user_id=user_id,
            search_space_id=search_space_id,
            document_type="NOTION_CONNECTOR"
        )

        # Map notion_chunks to the required format
        mapped_sources = {}
        for i, chunk in enumerate(notion_chunks):
            # Fix for UI
            notion_chunks[i]['document']['id'] = self.source_id_counter

            # Extract document metadata
            document = chunk.get('document', {})
            metadata = document.get('metadata', {})

            # Create a mapped source entry with Notion-specific metadata
            page_title = metadata.get('page_title', 'Untitled Page')
            page_id = metadata.get('page_id', '')
            indexed_at = metadata.get('indexed_at', '')

            # Create a more descriptive title for Notion pages
            title = f"Notion: {page_title}"
            if indexed_at:
                title += f" (indexed: {indexed_at})"

            # Create a more descriptive description for Notion pages
            description = chunk.get('content', '')[:100]
            if len(description) == 100:
                description += "..."

            # For the URL, construct a link to the Notion page if the page ID is available
            url = ""
            if page_id:
                # Notion page URLs follow this format
                url = f"https://notion.so/{page_id.replace('-', '')}"

            source = {
                "id": self.source_id_counter,
                "title": title,
                "description": description,
                "url": url,
            }
            self.source_id_counter += 1

            # Use page_id and chunk_id as a unique identifier for tracking unique sources
            source_key = f"{page_id}_{chunk.get('chunk_id', i)}"
            if source_key and source_key not in mapped_sources:
                mapped_sources[source_key] = source

        # Convert to list of sources
        sources_list = list(mapped_sources.values())

        # Create result object
        result_object = {
            "id": 5,
            "name": "Notion",
            "type": "NOTION_CONNECTOR",
            "sources": sources_list,
        }

        return result_object, notion_chunks
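
A sketch of how a caller might fan out across these connectors and merge the results, assuming the service is built per request around an AsyncSession; the function name and flow are illustrative, not part of this diff.

async def gather_sources(session: AsyncSession, user_query: str, user_id: int, search_space_id: int):
    connector_service = ConnectorService(session)
    all_sources, all_chunks = [], []

    # Each search returns (sources_info, chunks); the shared source_id_counter
    # keeps citation IDs unique across connector types.
    for search in (
        connector_service.search_crawled_urls,
        connector_service.search_files,
        connector_service.search_slack,
        connector_service.search_notion,
    ):
        sources_info, chunks = await search(user_query, user_id, search_space_id, top_k=20)
        all_sources.append(sources_info)
        all_chunks.extend(chunks)

    # Tavily is keyed per user and takes no search_space_id
    tavily_info, tavily_docs = await connector_service.search_tavily(user_query, user_id, top_k=20)
    all_sources.append(tavily_info)
    all_chunks.extend(tavily_docs)

    return all_sources, all_chunks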

View file

@@ -0,0 +1,136 @@
async def convert_element_to_markdown(element) -> str:
    """
    Convert an Unstructured element to markdown format based on its category.

    Args:
        element: The Unstructured API element object

    Returns:
        str: Markdown formatted string
    """
    element_category = element.metadata["category"]
    content = element.page_content

    if not content:
        return ""

    markdown_mapping = {
        "Formula": lambda x: f"```math\n{x}\n```",
        "FigureCaption": lambda x: f"*Figure: {x}*",
        "NarrativeText": lambda x: f"{x}\n\n",
        "ListItem": lambda x: f"- {x}\n",
        "Title": lambda x: f"# {x}\n\n",
        "Address": lambda x: f"> {x}\n\n",
        "EmailAddress": lambda x: f"`{x}`",
        "Image": lambda x: f"![{x}]({x})",
        "PageBreak": lambda x: "\n---\n",
        "Table": lambda x: f"```html\n{element.metadata['text_as_html']}\n```",
        "Header": lambda x: f"## {x}\n\n",
        "Footer": lambda x: f"*{x}*\n\n",
        "CodeSnippet": lambda x: f"```\n{x}\n```",
        "PageNumber": lambda x: f"*Page {x}*\n\n",
        "UncategorizedText": lambda x: f"{x}\n\n"
    }

    converter = markdown_mapping.get(element_category, lambda x: x)
    return converter(content)


async def convert_document_to_markdown(elements):
    """
    Convert all document elements to markdown.

    Args:
        elements: List of Unstructured API elements

    Returns:
        str: Complete markdown document
    """
    markdown_parts = []
    for element in elements:
        markdown_text = await convert_element_to_markdown(element)
        if markdown_text:
            markdown_parts.append(markdown_text)
    return "".join(markdown_parts)
def convert_chunks_to_langchain_documents(chunks):
    """
    Convert chunks from hybrid search results to LangChain Document objects.

    Args:
        chunks: List of chunk dictionaries from hybrid search results

    Returns:
        List of LangChain Document objects
    """
    try:
        from langchain_core.documents import Document as LangChainDocument
    except ImportError:
        raise ImportError(
            "LangChain is not installed. Please install it with `pip install langchain langchain-core`"
        )

    langchain_docs = []
    for chunk in chunks:
        # Extract content from the chunk
        content = chunk.get("content", "")

        # Create metadata dictionary
        metadata = {
            "chunk_id": chunk.get("chunk_id"),
            "score": chunk.get("score"),
            "rank": chunk.get("rank") if "rank" in chunk else None,
        }

        # Add document information to metadata
        if "document" in chunk:
            doc = chunk["document"]
            metadata.update({
                "document_id": doc.get("id"),
                "document_title": doc.get("title"),
                "document_type": doc.get("document_type"),
            })

            # Add document metadata if available
            if "metadata" in doc:
                # Prefix document metadata keys to avoid conflicts
                doc_metadata = {f"doc_meta_{k}": v for k, v in doc.get("metadata", {}).items()}
                metadata.update(doc_metadata)

                # Add source URL if available in metadata
                if "url" in doc.get("metadata", {}):
                    metadata["source"] = doc["metadata"]["url"]
                elif "sourceURL" in doc.get("metadata", {}):
                    metadata["source"] = doc["metadata"]["sourceURL"]

        # Ensure source_id is set for citation purposes
        # Use document_id as the source_id if available
        if "document_id" in metadata:
            metadata["source_id"] = metadata["document_id"]

        # Update content for citation mode - format as XML with explicit source_id
        new_content = f"""
<document>
    <metadata>
        <source_id>{metadata.get("source_id", metadata.get("document_id", "unknown"))}</source_id>
    </metadata>
    <content>
        <text>
{content}
        </text>
    </content>
</document>
"""

        # Create LangChain Document
        langchain_doc = LangChainDocument(
            page_content=new_content,
            metadata=metadata
        )

        langchain_docs.append(langchain_doc)

    return langchain_docs
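
A self-contained example of the expected chunk shape and the metadata the wrapper derives from it:

chunks = [{
    "chunk_id": 42,
    "content": "SurfSense supports hybrid search with reranking.",
    "score": 0.87,
    "document": {
        "id": 7,
        "title": "SurfSense README",
        "document_type": "CRAWLED_URL",
        "metadata": {"url": "https://github.com/MODSetter/SurfSense"},
    },
}]

docs = convert_chunks_to_langchain_documents(chunks)
print(docs[0].metadata["source_id"])  # 7 -- the value the LLM must cite as [7]
print(docs[0].metadata["source"])     # https://github.com/MODSetter/SurfSense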

View file

@@ -0,0 +1,95 @@
import logging
from typing import List, Dict, Any, Optional

from rerankers import Document as RerankerDocument


class RerankerService:
    """
    Service for reranking documents using a configured reranker
    """

    def __init__(self, reranker_instance=None):
        """
        Initialize the reranker service

        Args:
            reranker_instance: The reranker instance to use for reranking
        """
        self.reranker_instance = reranker_instance

    def rerank_documents(self, query_text: str, documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Rerank documents using the configured reranker

        Args:
            query_text: The query text to use for reranking
            documents: List of document dictionaries to rerank

        Returns:
            List[Dict[str, Any]]: Reranked documents
        """
        if not self.reranker_instance or not documents:
            return documents

        try:
            # Create Document objects for the rerankers library
            reranker_docs = []
            for i, doc in enumerate(documents):
                chunk_id = doc.get("chunk_id", f"chunk_{i}")
                content = doc.get("content", "")
                score = doc.get("score", 0.0)
                document_info = doc.get("document", {})

                reranker_docs.append(
                    RerankerDocument(
                        text=content,
                        doc_id=chunk_id,
                        metadata={
                            'document_id': document_info.get("id", ""),
                            'document_title': document_info.get("title", ""),
                            'document_type': document_info.get("document_type", ""),
                            'rrf_score': score
                        }
                    )
                )

            # Rerank using the configured reranker
            reranking_results = self.reranker_instance.rank(
                query=query_text,
                docs=reranker_docs
            )

            # Process the results from the reranker
            # Convert to serializable dictionaries
            serialized_results = []
            for result in reranking_results.results:
                # Find the original document by id
                original_doc = next((doc for doc in documents if doc.get("chunk_id") == result.document.doc_id), None)
                if original_doc:
                    # Create a new document with the reranked score
                    reranked_doc = original_doc.copy()
                    reranked_doc["score"] = float(result.score)
                    reranked_doc["rank"] = result.rank
                    serialized_results.append(reranked_doc)

            return serialized_results
        except Exception as e:
            # Log the error
            logging.error(f"Error during reranking: {str(e)}")
            # Fall back to original documents without reranking
            return documents

    @staticmethod
    def get_reranker_instance(config=None) -> Optional['RerankerService']:
        """
        Get a reranker service instance based on configuration

        Args:
            config: Configuration object that may contain a reranker_instance

        Returns:
            Optional[RerankerService]: A reranker service instance or None
        """
        if config and hasattr(config, 'reranker_instance') and config.reranker_instance:
            return RerankerService(config.reranker_instance)
        return None
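
A usage sketch, assuming the AnswerDotAI rerankers package; the model name and type are illustrative assumptions, not pinned by this diff.

from rerankers import Reranker

reranker_service = RerankerService(
    Reranker("ms-marco-MiniLM-L-12-v2", model_type="flashrank")  # assumed model choice
)
reranked_chunks = reranker_service.rerank_documents(
    "how does surfsense chunk documents?",
    documents=hybrid_search_chunks,  # chunk dicts as produced by the hybrid retriever
)
# On any failure (or with no reranker configured) the original chunks come back unchanged.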

View file

@@ -0,0 +1,211 @@
import asyncio
import re
from typing import List, Dict, Any, AsyncGenerator, Callable, Optional

from langchain.schema import Document
from gpt_researcher.agent import GPTResearcher
from gpt_researcher.utils.enum import ReportType, Tone, ReportSource
from dotenv import load_dotenv

load_dotenv()


class ResearchService:
    @staticmethod
    async def create_custom_prompt(user_query: str) -> str:
        citation_prompt = f"""
You are a research assistant tasked with analyzing documents and providing comprehensive answers with proper citations in IEEE format.
<instructions>
1. Carefully analyze all provided documents in the <document> sections.
2. Extract relevant information that addresses the user's query.
3. Synthesize a comprehensive, well-structured answer using information from these documents.
4. For EVERY piece of information you include from the documents, add an IEEE-style citation in square brackets [X] where X is the source_id from the document's metadata.
5. Make sure ALL factual statements from the documents have proper citations.
6. If multiple documents support the same point, include all relevant citations [X], [Y].
7. Present information in a logical, coherent flow.
8. Use your own words to connect ideas, but cite ALL information from the documents.
9. If documents contain conflicting information, acknowledge this and present both perspectives with appropriate citations.
10. Do not make up or include information not found in the provided documents.
11. CRITICAL: You MUST use the exact source_id value from each document's metadata for citations. Do not create your own citation numbers.
12. CRITICAL: Every citation MUST be in the IEEE format [X] where X is the exact source_id value.
13. CRITICAL: Never renumber or reorder citations - always use the original source_id values.
14. CRITICAL: Do not return citations as clickable links.
15. CRITICAL: Never format citations as markdown links like "([1](https://example.com))". Always use plain square brackets only.
16. CRITICAL: Citations must ONLY appear as [X] or [X], [Y], [Z] format - never with parentheses, hyperlinks, or other formatting.
17. CRITICAL: Never make up citation numbers. Only use source_id values that are explicitly provided in the document metadata.
18. CRITICAL: If you are unsure about a source_id, do not include a citation rather than guessing or making one up.
</instructions>
<format>
- Write in clear, professional language suitable for academic or technical audiences
- Organize your response with appropriate paragraphs, headings, and structure
- Every fact from the documents must have an IEEE-style citation in square brackets [X] where X is the EXACT source_id from the document's metadata
- Citations should appear at the end of the sentence containing the information they support
- Multiple citations should be separated by commas: [X], [Y], [Z]
- Do not include a references section; return only the in-text citation numbers in the answer
- NEVER create your own citation numbering system - use the exact source_id values from the documents.
- NEVER format citations as clickable links or as markdown links like "([1](https://example.com))". Always use plain square brackets only.
- NEVER make up citation numbers if you are unsure about the source_id. It is better to omit the citation than to guess.
</format>
<input_example>
<document>
<metadata>
<source_id>1</source_id>
</metadata>
<content>
<text>
The Great Barrier Reef is the world's largest coral reef system, stretching over 2,300 kilometers along the coast of Queensland, Australia. It comprises over 2,900 individual reefs and 900 islands.
</text>
</content>
</document>
<document>
<metadata>
<source_id>13</source_id>
</metadata>
<content>
<text>
Climate change poses a significant threat to coral reefs worldwide. Rising ocean temperatures have led to mass coral bleaching events in the Great Barrier Reef in 2016, 2017, and 2020.
</text>
</content>
</document>
<document>
<metadata>
<source_id>21</source_id>
</metadata>
<content>
<text>
The Great Barrier Reef was designated a UNESCO World Heritage Site in 1981 due to its outstanding universal value and biological diversity. It is home to over 1,500 species of fish and 400 types of coral.
</text>
</content>
</document>
</input_example>
<output_example>
The Great Barrier Reef is the world's largest coral reef system, stretching over 2,300 kilometers along the coast of Queensland, Australia [1]. It was designated a UNESCO World Heritage Site in 1981 due to its outstanding universal value and biological diversity [21]. The reef is home to over 1,500 species of fish and 400 types of coral [21]. Unfortunately, climate change poses a significant threat to coral reefs worldwide, with rising ocean temperatures leading to mass coral bleaching events in the Great Barrier Reef in 2016, 2017, and 2020 [13]. The reef system comprises over 2,900 individual reefs and 900 islands [1], making it an ecological treasure that requires protection from multiple threats [1], [13].
</output_example>
<incorrect_citation_formats>
DO NOT use any of these incorrect citation formats:
- Using parentheses and markdown links: ([1](https://github.com/MODSetter/SurfSense))
- Using parentheses around brackets: ([1])
- Using hyperlinked text: [link to source 1](https://example.com)
- Using footnote style: ... reef system¹
- Making up citation numbers when source_id is unknown
ONLY use plain square brackets [1] or multiple citations [1], [2], [3]
</incorrect_citation_formats>
Note that the citation numbers match exactly with the source_id values (1, 13, and 21) and are not renumbered sequentially. Citations follow IEEE style with square brackets and appear at the end of sentences.
Now, please research the following query:
<user_query_to_research>
{user_query}
</user_query_to_research>
"""
return citation_prompt
    @staticmethod
    async def stream_research(
        user_query: str,
        documents: Optional[List[Document]] = None,
        on_progress: Optional[Callable] = None,
        research_mode: str = "GENERAL"
    ) -> str:
        """
        Stream the research process using GPTResearcher

        Args:
            user_query: The user's query
            documents: List of Document objects to use for research
            on_progress: Optional callback for progress updates
            research_mode: Research mode to use

        Returns:
            str: The final research report
        """
        # Create a custom websocket-like object to capture streaming output
        class StreamingWebsocket:
            async def send_json(self, data):
                if on_progress:
                    try:
                        # Filter out excessive logging of the prompt
                        if data.get("type") == "logs":
                            output = data.get("output", "")
                            # Check if this is a verbose prompt log
                            if "You are a research assistant tasked with analyzing documents" in output and len(output) > 500:
                                # Replace with a shorter message
                                data["output"] = f"Processing research for query: {user_query}"
                        result = await on_progress(data)
                        return result
                    except Exception as e:
                        print(f"Error in on_progress callback: {e}")
                return None

        streaming_websocket = StreamingWebsocket()

        custom_prompt_for_ieee_citations = await ResearchService.create_custom_prompt(user_query)
        if research_mode == "GENERAL":
            research_report_type = ReportType.CustomReport.value
        elif research_mode == "DEEP":
            research_report_type = ReportType.ResearchReport.value
        elif research_mode == "DEEPER":
            research_report_type = ReportType.DetailedReport.value
        # elif research_mode == "DEEPEST":
        #     research_report_type = ReportType.DeepResearch.value
        else:
            # Fall back to a custom report so an unknown mode doesn't leave
            # research_report_type undefined below
            research_report_type = ReportType.CustomReport.value
        # Initialize GPTResearcher with the streaming websocket
        researcher = GPTResearcher(
            query=custom_prompt_for_ieee_citations,
            report_type=research_report_type,
            report_format="IEEE",
            report_source=ReportSource.LangChainDocuments.value,
            tone=Tone.Formal,
            documents=documents,
            verbose=True,
            websocket=streaming_websocket
        )

        # Conduct research
        await researcher.conduct_research()

        # Generate report with streaming
        report = await researcher.write_report()

        # Fix citation format
        report = ResearchService.fix_citation_format(report)

        return report
    @staticmethod
    def fix_citation_format(text: str) -> str:
        """
        Fix any incorrectly formatted citations in the text.

        Args:
            text: The text to fix

        Returns:
            str: The text with fixed citations
        """
        if not text:
            return text

        # More specific pattern to match only numeric citations in markdown-style links
        # This matches patterns like ([1](https://github.com/...)) but not general links like ([Click here](https://...))
        pattern = r'\(\[(\d+)\]\((https?://[^\)]+)\)\)'

        # Replace with just [X] where X is the number
        text = re.sub(pattern, r'[\1]', text)

        # Also match other incorrect formats like ([1]) and convert to [1]
        # Only match if the content inside brackets is a number
        text = re.sub(r'\(\[(\d+)\]\)', r'[\1]', text)

        return text
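
Both substitutions in action on the link styles the prompt forbids:

sample = "SurfSense is open source ([1](https://github.com/MODSetter/SurfSense)) and self-hostable ([2])."
print(ResearchService.fix_citation_format(sample))
# SurfSense is open source [1] and self-hostable [2].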

View file

@@ -0,0 +1,99 @@
import json
from typing import List, Dict, Any, Generator


class StreamingService:
    def __init__(self):
        self.terminal_idx = 1
        self.message_annotations = [
            {
                "type": "TERMINAL_INFO",
                "content": []
            },
            {
                "type": "SOURCES",
                "content": []
            },
            {
                "type": "ANSWER",
                "content": []
            }
        ]

    def add_terminal_message(self, text: str, message_type: str = "info") -> str:
        """
        Add a terminal message to the annotations and return the formatted response

        Args:
            text: The message text
            message_type: The message type (info, success, error)

        Returns:
            str: The formatted response string
        """
        self.message_annotations[0]["content"].append({
            "id": self.terminal_idx,
            "text": text,
            "type": message_type
        })
        self.terminal_idx += 1
        return self._format_annotations()

    def update_sources(self, sources: List[Dict[str, Any]]) -> str:
        """
        Update the sources in the annotations and return the formatted response

        Args:
            sources: List of source objects

        Returns:
            str: The formatted response string
        """
        self.message_annotations[1]["content"] = sources
        return self._format_annotations()

    def update_answer(self, answer_content: List[str]) -> str:
        """
        Update the answer in the annotations and return the formatted response

        Args:
            answer_content: The answer content as a list of strings

        Returns:
            str: The formatted response string
        """
        self.message_annotations[2] = {
            "type": "ANSWER",
            "content": answer_content
        }
        return self._format_annotations()

    def _format_annotations(self) -> str:
        """
        Format the annotations as a string

        Returns:
            str: The formatted annotations string
        """
        return f'8:{json.dumps(self.message_annotations)}\n'

    def format_completion(self, prompt_tokens: int = 156, completion_tokens: int = 204) -> str:
        """
        Format a completion message

        Args:
            prompt_tokens: Number of prompt tokens
            completion_tokens: Number of completion tokens

        Returns:
            str: The formatted completion string
        """
        total_tokens = prompt_tokens + completion_tokens
        completion_data = {
            "finishReason": "stop",
            "usage": {
                "promptTokens": prompt_tokens,
                "completionTokens": completion_tokens,
                "totalTokens": total_tokens
            }
        }
        return f'd:{json.dumps(completion_data)}\n'
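
The 8: and d: prefixes appear to follow the Vercel AI SDK data-stream framing (message annotations and the finish message, respectively), so each return value can be yielded straight into a streaming HTTP response. A short sketch:

streaming_service = StreamingService()

print(streaming_service.add_terminal_message("Searching connectors...", "info"), end="")
print(streaming_service.update_sources([{"id": 1, "name": "Crawled URLs", "type": "CRAWLED_URL", "sources": []}]), end="")
print(streaming_service.update_answer(["The answer so far..."]), end="")
print(streaming_service.format_completion(prompt_tokens=120, completion_tokens=80), end="")
# 8:[{"type": "TERMINAL_INFO", ...}]  <- full annotation state, resent on each call
# d:{"finishReason": "stop", "usage": {...}}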