Fixed all ruff lint and formatting errors

Utkarsh-Patel-13 2025-07-24 14:43:48 -07:00
parent 0a03c42cc5
commit d359a59f6d
85 changed files with 5520 additions and 3870 deletions
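
Most of the diff below is two mechanical classes of change: ruff's pyupgrade-style (UP) rules replace the typing aliases List/Dict/Tuple with the built-in generics available since Python 3.9, and the ruff formatter normalizes quotes, trailing commas, and line wrapping. A minimal before/after sketch of that kind of rewrite (illustrative only, not code from this commit):

# Before: typing aliases and single quotes, which ruff flags and rewrites.
from typing import Dict, List


def pick(items: List[str], index: int) -> Dict[str, str]:
    return {'value': items[index]}


# After: built-in generics and double quotes, the style applied throughout this commit.
def pick(items: list[str], index: int) -> dict[str, str]:
    return {"value": items[index]}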


@@ -1,27 +1,37 @@
-from typing import List, Dict, Any, Tuple, NamedTuple
+from typing import Any, NamedTuple
 from langchain_core.messages import BaseMessage
+from litellm import get_model_info, token_counter
 from pydantic import BaseModel, Field
-from litellm import token_counter, get_model_info
 class Section(BaseModel):
     """A section in the answer outline."""
     section_id: int = Field(..., description="The zero-based index of the section")
     section_title: str = Field(..., description="The title of the section")
-    questions: List[str] = Field(..., description="Questions to research for this section")
+    questions: list[str] = Field(
+        ..., description="Questions to research for this section"
+    )
 class AnswerOutline(BaseModel):
     """The complete answer outline with all sections."""
-    answer_outline: List[Section] = Field(..., description="List of sections in the answer outline")
+    answer_outline: list[Section] = Field(
+        ..., description="List of sections in the answer outline"
+    )
 class DocumentTokenInfo(NamedTuple):
     """Information about a document and its token cost."""
     index: int
-    document: Dict[str, Any]
+    document: dict[str, Any]
     formatted_content: str
     token_count: int
 def get_connector_emoji(connector_name: str) -> str:
     """Get an appropriate emoji for a connector type."""
     connector_emojis = {
@@ -34,7 +44,7 @@ def get_connector_emoji(connector_name: str) -> str:
         "GITHUB_CONNECTOR": "🐙",
         "LINEAR_CONNECTOR": "📊",
         "TAVILY_API": "🔍",
-        "LINKUP_API": "🔗"
+        "LINKUP_API": "🔗",
     }
     return connector_emojis.get(connector_name, "🔎")
@@ -51,31 +61,26 @@ def get_connector_friendly_name(connector_name: str) -> str:
         "GITHUB_CONNECTOR": "GitHub",
         "LINEAR_CONNECTOR": "Linear",
         "TAVILY_API": "Tavily Search",
-        "LINKUP_API": "Linkup Search"
+        "LINKUP_API": "Linkup Search",
     }
     return connector_friendly_names.get(connector_name, connector_name)
-def convert_langchain_messages_to_dict(messages: List[BaseMessage]) -> List[Dict[str, str]]:
+def convert_langchain_messages_to_dict(
+    messages: list[BaseMessage],
+) -> list[dict[str, str]]:
     """Convert LangChain messages to format expected by token_counter."""
-    role_mapping = {
-        'system': 'system',
-        'human': 'user',
-        'ai': 'assistant'
-    }
+    role_mapping = {"system": "system", "human": "user", "ai": "assistant"}
     converted_messages = []
     for msg in messages:
-        role = role_mapping.get(getattr(msg, 'type', None), 'user')
-        converted_messages.append({
-            "role": role,
-            "content": str(msg.content)
-        })
+        role = role_mapping.get(getattr(msg, "type", None), "user")
+        converted_messages.append({"role": role, "content": str(msg.content)})
     return converted_messages
-def format_document_for_citation(document: Dict[str, Any]) -> str:
+def format_document_for_citation(document: dict[str, Any]) -> str:
     """Format a single document for citation in the standard XML format."""
     content = document.get("content", "")
     doc_info = document.get("document", {})
@@ -93,7 +98,9 @@ def format_document_for_citation(document: Dict[str, Any]) -> str:
 </document>"""
-def format_documents_section(documents: List[Dict[str, Any]], section_title: str = "Source material") -> str:
+def format_documents_section(
+    documents: list[dict[str, Any]], section_title: str = "Source material"
+) -> str:
     """Format multiple documents into a complete documents section."""
     if not documents:
         return ""
@@ -106,7 +113,9 @@ def format_documents_section(documents: List[Dict[str, Any]], section_title: str
 </documents>"""
-def calculate_document_token_costs(documents: List[Dict[str, Any]], model: str) -> List[DocumentTokenInfo]:
+def calculate_document_token_costs(
+    documents: list[dict[str, Any]], model: str
+) -> list[DocumentTokenInfo]:
     """Pre-calculate token costs for each document."""
     document_token_info = []
@@ -115,24 +124,24 @@ def calculate_document_token_costs(documents: List[Dict[str, Any]], model: str)
         # Calculate token count for this document
         token_count = token_counter(
-            messages=[{"role": "user", "content": formatted_doc}],
-            model=model
+            messages=[{"role": "user", "content": formatted_doc}], model=model
         )
-        document_token_info.append(DocumentTokenInfo(
-            index=i,
-            document=doc,
-            formatted_content=formatted_doc,
-            token_count=token_count
-        ))
+        document_token_info.append(
+            DocumentTokenInfo(
+                index=i,
+                document=doc,
+                formatted_content=formatted_doc,
+                token_count=token_count,
+            )
+        )
     return document_token_info
 def find_optimal_documents_with_binary_search(
-    document_tokens: List[DocumentTokenInfo],
-    available_tokens: int
-) -> List[DocumentTokenInfo]:
+    document_tokens: list[DocumentTokenInfo], available_tokens: int
+) -> list[DocumentTokenInfo]:
     """Use binary search to find the maximum number of documents that fit within token limit."""
     if not document_tokens or available_tokens <= 0:
         return []
@@ -143,8 +152,7 @@ def find_optimal_documents_with_binary_search(
     while left <= right:
         mid = (left + right) // 2
         current_docs = document_tokens[:mid]
-        current_token_sum = sum(
-            doc_info.token_count for doc_info in current_docs)
+        current_token_sum = sum(doc_info.token_count for doc_info in current_docs)
         if current_token_sum <= available_tokens:
             optimal_docs = current_docs
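
The change above only reflows a sum() call, but the surrounding routine is the heart of the selection logic: document token counts are non-negative, so the running total over the first k documents never decreases, and "the first k documents fit the budget" becomes a monotone predicate that binary search can answer. A standalone sketch of that idea (a hypothetical helper, using precomputed prefix sums rather than re-summing on every probe as the code above does):

def max_docs_within_budget(token_counts: list[int], budget: int) -> int:
    """Return how many leading documents fit in `budget` tokens (illustrative only)."""
    prefix = [0]
    for count in token_counts:
        prefix.append(prefix[-1] + count)  # prefix[k] = tokens used by the first k docs
    left, right, best = 0, len(token_counts), 0
    while left <= right:
        mid = (left + right) // 2
        if prefix[mid] <= budget:
            best = mid  # the first `mid` documents fit; try to take more
            left = mid + 1
        else:
            right = mid - 1  # over budget; take fewer documents
    return best

# Documents costing 40, 35, and 50 tokens against a 100-token budget: only the first two fit.
assert max_docs_within_budget([40, 35, 50], 100) == 2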
@@ -159,20 +167,18 @@ def get_model_context_window(model_name: str) -> int:
     """Get the total context window size for a model (input + output tokens)."""
     try:
         model_info = get_model_info(model_name)
-        context_window = model_info.get(
-            'max_input_tokens', 4096)  # Default fallback
+        context_window = model_info.get("max_input_tokens", 4096)  # Default fallback
         return context_window
     except Exception as e:
         print(
-            f"Warning: Could not get model info for {model_name}, using default 4096 tokens. Error: {e}")
+            f"Warning: Could not get model info for {model_name}, using default 4096 tokens. Error: {e}"
+        )
         return 4096  # Conservative fallback
 def optimize_documents_for_token_limit(
-    documents: List[Dict[str, Any]],
-    base_messages: List[BaseMessage],
-    model_name: str
-) -> Tuple[List[Dict[str, Any]], bool]:
+    documents: list[dict[str, Any]], base_messages: list[BaseMessage], model_name: str
+) -> tuple[list[dict[str, Any]], bool]:
     """
     Optimize documents to fit within token limits using binary search.
@@ -197,7 +203,8 @@ def optimize_documents_for_token_limit(
     available_tokens_for_docs = context_window - base_tokens
     print(
-        f"Token optimization: Context window={context_window}, Base={base_tokens}, Available for docs={available_tokens_for_docs}")
+        f"Token optimization: Context window={context_window}, Base={base_tokens}, Available for docs={available_tokens_for_docs}"
+    )
     if available_tokens_for_docs <= 0:
         print("No tokens available for documents after base content and output buffer")
@@ -208,8 +215,7 @@
     # Find optimal number of documents using binary search
     optimal_doc_info = find_optimal_documents_with_binary_search(
-        document_token_info,
-        available_tokens_for_docs
+        document_token_info, available_tokens_for_docs
     )
     # Extract the original document objects
@@ -217,12 +223,13 @@
     has_documents_remaining = len(optimized_documents) > 0
     print(
-        f"Token optimization result: Using {len(optimized_documents)}/{len(documents)} documents")
+        f"Token optimization result: Using {len(optimized_documents)}/{len(documents)} documents"
+    )
     return optimized_documents, has_documents_remaining
-def calculate_token_count(messages: List[BaseMessage], model_name: str) -> int:
+def calculate_token_count(messages: list[BaseMessage], model_name: str) -> int:
     """Calculate token count for a list of LangChain messages."""
     model = model_name
     messages_dict = convert_langchain_messages_to_dict(messages)
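
Taken together, the helpers in this file implement a simple context-budgeting flow: count the tokens of the fixed prompt, subtract them from the model's context window, and keep the largest prefix of documents that still fits. A hedged usage sketch based only on the signatures visible above; the import path and the metadata shape are assumptions, since the diff does not show the file's location or the nested document fields:

from langchain_core.messages import HumanMessage, SystemMessage

# Assumed module path for illustration; adjust to wherever this utils file lives.
from app.agents.researcher.utils import (
    calculate_token_count,
    optimize_documents_for_token_limit,
)

base_messages = [
    SystemMessage(content="Answer using only the provided sources."),
    HumanMessage(content="Summarize the findings on topic X."),
]
documents = [
    # Only the "content" and "document" keys appear in the diff; "title" is a guess.
    {"content": "First source text...", "document": {"title": "Doc A"}},
    {"content": "Second source text...", "document": {"title": "Doc B"}},
]

# Drops trailing documents until the formatted source material fits the model's window.
kept_docs, has_docs = optimize_documents_for_token_limit(
    documents, base_messages, "gpt-4o-mini"
)
prompt_tokens = calculate_token_count(base_messages, "gpt-4o-mini")
print(f"Keeping {len(kept_docs)}/{len(documents)} documents; base prompt is {prompt_tokens} tokens")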