Mirror of https://github.com/MODSetter/SurfSense.git, synced 2026-05-12 09:12:40 +02:00

Fixed all ruff lint and formatting errors

This commit is contained in: parent 0a03c42cc5, commit d359a59f6d
85 changed files with 5520 additions and 3870 deletions
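The changes below are mechanical ruff fixes: deprecated typing generics (List, Dict, Tuple) become the PEP 585 builtins, quotes are normalized to double quotes, trailing commas are added, and long signatures and calls are rewrapped. A minimal before/after sketch of the pattern, using a hypothetical model and field rather than anything from the repo:

from pydantic import BaseModel, Field


class Example(BaseModel):
    # Before ruff:  tags: List[str] = Field(..., description="...")  (needs `from typing import List`)
    # After ruff: builtin list[...] generic, double quotes, wrapped Field(...) call.
    tags: list[str] = Field(
        ..., description="Illustrative field, not from the repo"
    )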
@@ -1,27 +1,37 @@
-from typing import List, Dict, Any, Tuple, NamedTuple
+from typing import Any, NamedTuple
 
 from langchain_core.messages import BaseMessage
+from litellm import get_model_info, token_counter
 from pydantic import BaseModel, Field
-from litellm import token_counter, get_model_info
 
 
 class Section(BaseModel):
     """A section in the answer outline."""
 
     section_id: int = Field(..., description="The zero-based index of the section")
     section_title: str = Field(..., description="The title of the section")
-    questions: List[str] = Field(..., description="Questions to research for this section")
+    questions: list[str] = Field(
+        ..., description="Questions to research for this section"
+    )
 
 
 class AnswerOutline(BaseModel):
     """The complete answer outline with all sections."""
-    answer_outline: List[Section] = Field(..., description="List of sections in the answer outline")
 
+    answer_outline: list[Section] = Field(
+        ..., description="List of sections in the answer outline"
+    )
 
 
 class DocumentTokenInfo(NamedTuple):
     """Information about a document and its token cost."""
 
     index: int
-    document: Dict[str, Any]
+    document: dict[str, Any]
     formatted_content: str
     token_count: int
 
 
 def get_connector_emoji(connector_name: str) -> str:
     """Get an appropriate emoji for a connector type."""
     connector_emojis = {
@@ -34,7 +44,7 @@ def get_connector_emoji(connector_name: str) -> str:
         "GITHUB_CONNECTOR": "🐙",
         "LINEAR_CONNECTOR": "📊",
         "TAVILY_API": "🔍",
-        "LINKUP_API": "🔗"
+        "LINKUP_API": "🔗",
     }
     return connector_emojis.get(connector_name, "🔎")
 
@@ -51,31 +61,26 @@ def get_connector_friendly_name(connector_name: str) -> str:
         "GITHUB_CONNECTOR": "GitHub",
         "LINEAR_CONNECTOR": "Linear",
         "TAVILY_API": "Tavily Search",
-        "LINKUP_API": "Linkup Search"
+        "LINKUP_API": "Linkup Search",
     }
     return connector_friendly_names.get(connector_name, connector_name)
 
 
-def convert_langchain_messages_to_dict(messages: List[BaseMessage]) -> List[Dict[str, str]]:
+def convert_langchain_messages_to_dict(
+    messages: list[BaseMessage],
+) -> list[dict[str, str]]:
     """Convert LangChain messages to format expected by token_counter."""
-    role_mapping = {
-        'system': 'system',
-        'human': 'user',
-        'ai': 'assistant'
-    }
+    role_mapping = {"system": "system", "human": "user", "ai": "assistant"}
 
     converted_messages = []
     for msg in messages:
-        role = role_mapping.get(getattr(msg, 'type', None), 'user')
-        converted_messages.append({
-            "role": role,
-            "content": str(msg.content)
-        })
+        role = role_mapping.get(getattr(msg, "type", None), "user")
+        converted_messages.append({"role": role, "content": str(msg.content)})
 
     return converted_messages
 
 
-def format_document_for_citation(document: Dict[str, Any]) -> str:
+def format_document_for_citation(document: dict[str, Any]) -> str:
     """Format a single document for citation in the standard XML format."""
     content = document.get("content", "")
     doc_info = document.get("document", {})
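For context on how the simplified convert_langchain_messages_to_dict is used: it turns LangChain message objects into the plain role/content dicts that litellm's token_counter expects. A minimal sketch, assuming the function from this module is in scope; the model name is illustrative, not taken from the repo:

from langchain_core.messages import HumanMessage, SystemMessage
from litellm import token_counter

messages = [
    SystemMessage(content="You are a research assistant."),
    HumanMessage(content="Outline the answer."),
]

# LangChain message types ("system", "human", "ai") map to chat roles.
converted = convert_langchain_messages_to_dict(messages)
# e.g. [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}]

num_tokens = token_counter(messages=converted, model="gpt-4o")  # illustrative model name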
@@ -93,7 +98,9 @@ def format_document_for_citation(document: Dict[str, Any]) -> str:
 </document>"""
 
 
-def format_documents_section(documents: List[Dict[str, Any]], section_title: str = "Source material") -> str:
+def format_documents_section(
+    documents: list[dict[str, Any]], section_title: str = "Source material"
+) -> str:
     """Format multiple documents into a complete documents section."""
     if not documents:
         return ""
@@ -106,7 +113,9 @@ def format_documents_section(documents: List[Dict[str, Any]], section_title: str
 </documents>"""
 
 
-def calculate_document_token_costs(documents: List[Dict[str, Any]], model: str) -> List[DocumentTokenInfo]:
+def calculate_document_token_costs(
+    documents: list[dict[str, Any]], model: str
+) -> list[DocumentTokenInfo]:
     """Pre-calculate token costs for each document."""
     document_token_info = []
 
@@ -115,24 +124,24 @@ def calculate_document_token_costs(documents: List[Dict[str, Any]], model: str)
 
         # Calculate token count for this document
         token_count = token_counter(
-            messages=[{"role": "user", "content": formatted_doc}],
-            model=model
+            messages=[{"role": "user", "content": formatted_doc}], model=model
         )
 
-        document_token_info.append(DocumentTokenInfo(
-            index=i,
-            document=doc,
-            formatted_content=formatted_doc,
-            token_count=token_count
-        ))
+        document_token_info.append(
+            DocumentTokenInfo(
+                index=i,
+                document=doc,
+                formatted_content=formatted_doc,
+                token_count=token_count,
+            )
+        )
 
     return document_token_info
 
 
 def find_optimal_documents_with_binary_search(
-    document_tokens: List[DocumentTokenInfo],
-    available_tokens: int
-) -> List[DocumentTokenInfo]:
+    document_tokens: list[DocumentTokenInfo], available_tokens: int
+) -> list[DocumentTokenInfo]:
     """Use binary search to find the maximum number of documents that fit within token limit."""
     if not document_tokens or available_tokens <= 0:
         return []
@@ -143,8 +152,7 @@ def find_optimal_documents_with_binary_search(
     while left <= right:
         mid = (left + right) // 2
         current_docs = document_tokens[:mid]
-        current_token_sum = sum(
-            doc_info.token_count for doc_info in current_docs)
+        current_token_sum = sum(doc_info.token_count for doc_info in current_docs)
 
         if current_token_sum <= available_tokens:
             optimal_docs = current_docs
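The hunks above only reshape find_optimal_documents_with_binary_search; the logic is unchanged: token counts are non-negative, so prefix sums grow monotonically with prefix length, and the longest prefix that fits the budget can be found by binary-searching the length. A standalone sketch of that idea (simplified, not the repo's exact function):

def max_prefix_within_budget(token_counts: list[int], budget: int) -> int:
    """Return the length of the longest prefix of token_counts whose sum fits in budget."""
    left, right = 0, len(token_counts)
    best = 0
    while left <= right:
        mid = (left + right) // 2
        if sum(token_counts[:mid]) <= budget:
            best = mid  # this prefix fits; try a longer one
            left = mid + 1
        else:
            right = mid - 1  # over budget; try a shorter prefix
    return best

For example, max_prefix_within_budget([300, 400, 500], 800) returns 2: the first two documents fit, the third would overflow.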
@@ -159,20 +167,18 @@ def get_model_context_window(model_name: str) -> int:
     """Get the total context window size for a model (input + output tokens)."""
     try:
         model_info = get_model_info(model_name)
-        context_window = model_info.get(
-            'max_input_tokens', 4096)  # Default fallback
+        context_window = model_info.get("max_input_tokens", 4096)  # Default fallback
         return context_window
     except Exception as e:
         print(
-            f"Warning: Could not get model info for {model_name}, using default 4096 tokens. Error: {e}")
+            f"Warning: Could not get model info for {model_name}, using default 4096 tokens. Error: {e}"
+        )
         return 4096  # Conservative fallback
 
 
 def optimize_documents_for_token_limit(
-    documents: List[Dict[str, Any]],
-    base_messages: List[BaseMessage],
-    model_name: str
-) -> Tuple[List[Dict[str, Any]], bool]:
+    documents: list[dict[str, Any]], base_messages: list[BaseMessage], model_name: str
+) -> tuple[list[dict[str, Any]], bool]:
     """
     Optimize documents to fit within token limits using binary search.
 
@@ -197,7 +203,8 @@ def optimize_documents_for_token_limit(
     available_tokens_for_docs = context_window - base_tokens
 
     print(
-        f"Token optimization: Context window={context_window}, Base={base_tokens}, Available for docs={available_tokens_for_docs}")
+        f"Token optimization: Context window={context_window}, Base={base_tokens}, Available for docs={available_tokens_for_docs}"
+    )
 
     if available_tokens_for_docs <= 0:
         print("No tokens available for documents after base content and output buffer")
@@ -208,8 +215,7 @@ def optimize_documents_for_token_limit(
 
     # Find optimal number of documents using binary search
     optimal_doc_info = find_optimal_documents_with_binary_search(
-        document_token_info,
-        available_tokens_for_docs
+        document_token_info, available_tokens_for_docs
     )
 
     # Extract the original document objects
@@ -217,12 +223,13 @@ def optimize_documents_for_token_limit(
     has_documents_remaining = len(optimized_documents) > 0
 
     print(
-        f"Token optimization result: Using {len(optimized_documents)}/{len(documents)} documents")
+        f"Token optimization result: Using {len(optimized_documents)}/{len(documents)} documents"
+    )
 
     return optimized_documents, has_documents_remaining
 
 
-def calculate_token_count(messages: List[BaseMessage], model_name: str) -> int:
+def calculate_token_count(messages: list[BaseMessage], model_name: str) -> int:
     """Calculate token count for a list of LangChain messages."""
     model = model_name
     messages_dict = convert_langchain_messages_to_dict(messages)
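Taken together, these utilities implement one flow: count the tokens of the base prompt, subtract them from the model's context window, keep only as many documents as fit, and format the survivors into the documents block. A minimal usage sketch; the import path, model name, and document metadata keys are assumptions, not confirmed by this diff:

from langchain_core.messages import HumanMessage, SystemMessage

# Hypothetical import path for the functions changed above.
# from app.agents.utils import optimize_documents_for_token_limit, format_documents_section

base_messages = [
    SystemMessage(content="You are a research assistant."),
    HumanMessage(content="Summarize the findings on this topic."),
]
documents = [
    {"content": "First source text...", "document": {"title": "Doc 1"}},  # keys beyond "content"/"document" are assumed
    {"content": "Second source text...", "document": {"title": "Doc 2"}},
]

# Drop trailing documents until base messages plus documents fit the model's window.
kept_docs, has_docs = optimize_documents_for_token_limit(
    documents, base_messages, "gpt-4o"  # illustrative model name
)
if has_docs:
    source_block = format_documents_section(kept_docs, "Source material")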