import hashlib
import logging
import warnings

import numpy as np
from litellm import get_model_info, token_counter

from app.config import config
from app.db import Chunk, DocumentType
from app.prompts import SUMMARY_PROMPT_TEMPLATE

logger = logging.getLogger(__name__)


def _get_embedding_max_tokens() -> int:
    """Get the max token limit for the configured embedding model.

    Checks model properties in order: max_seq_length, _max_tokens.
    Falls back to 8192 (OpenAI embedding default).
    """
    model = config.embedding_model_instance
    for attr in ("max_seq_length", "_max_tokens"):
        val = getattr(model, attr, None)
        if isinstance(val, int) and val > 0:
            return val
    return 8192


def truncate_for_embedding(text: str) -> str:
    """Truncate text to fit within the embedding model's context window.

    Uses the embedding model's own tokenizer for accurate token counting,
    so the result is model-agnostic regardless of the underlying provider.
    """
    max_tokens = _get_embedding_max_tokens()
    # Cheap character-based fast path: assuming >= 3 chars per token on
    # average, a text this short cannot exceed the limit, so we skip the
    # (comparatively expensive) tokenizer round-trip entirely.
    if len(text) // 3 <= max_tokens:
        return text

    tokenizer = config.embedding_model_instance.get_tokenizer()
    tokens = tokenizer.encode(text)
    if len(tokens) <= max_tokens:
        return text

    warnings.warn(
        f"Truncating text from {len(tokens)} to {max_tokens} tokens for embedding.",
        stacklevel=2,
    )
    return tokenizer.decode(tokens[:max_tokens])


def embed_text(text: str) -> np.ndarray:
    """Truncate text to fit and embed it.

    Drop-in replacement for ``config.embedding_model_instance.embed(text)``
    that never exceeds the model's context window.
    """
    return config.embedding_model_instance.embed(truncate_for_embedding(text))


def embed_texts(texts: list[str]) -> list[np.ndarray]:
    """Batch-embed multiple texts in a single call.

    Each text is truncated to fit the model's context window before
    embedding. For API-based models (``://`` in the model string) this uses
    ``embed_batch`` to collapse many network round-trips into one. For local
    models (SentenceTransformers) it falls back to sequential ``embed`` calls
    to avoid padding overhead.
    """
    if not texts:
        return []
    truncated = [truncate_for_embedding(t) for t in texts]
    if config.is_local_embedding_model:
        return [config.embedding_model_instance.embed(t) for t in truncated]
    return config.embedding_model_instance.embed_batch(truncated)


def get_model_context_window(model_name: str) -> int:
    """Get the maximum number of input tokens accepted by a model.

    Args:
        model_name: LiteLLM model identifier to look up.

    Returns:
        The model's ``max_input_tokens`` value, or a conservative 4096
        fallback when the model is unknown or the value is missing.
    """
    try:
        model_info = get_model_info(model_name)
        context_window = model_info.get("max_input_tokens")
        # Handle case where key exists but value is None
        if context_window is None:
            logger.warning(
                "max_input_tokens is None for %s, using default 4096 tokens.",
                model_name,
            )
            return 4096  # Conservative fallback
        return context_window
    except Exception as e:
        logger.warning(
            "Could not get model info for %s, using default 4096 tokens. Error: %s",
            model_name,
            e,
        )
        return 4096  # Conservative fallback


def optimize_content_for_context_window(
    content: str, document_metadata: dict | None, model_name: str
) -> str:
    """
    Optimize content length to fit within model context window using binary search.

    Args:
        content: Original document content
        document_metadata: Optional metadata dictionary
        model_name: Model name for token counting

    Returns:
        Optimized content that fits within context window
    """
    if not content:
        return content

    # Get model context window
    context_window = get_model_context_window(model_name)

    # Reserve tokens for: system prompt, metadata, template overhead, and output
    # Conservative estimate: 2000 tokens for prompt + metadata + output buffer
    # TODO: Calculate Summary System Prompt Token Count Here
    reserved_tokens = 2000

    # Add metadata token cost if present
    if document_metadata:
        metadata_text = f"\n\n{document_metadata}\n\n"
        metadata_tokens = token_counter(
            messages=[{"role": "user", "content": metadata_text}], model=model_name
        )
        reserved_tokens += metadata_tokens

    available_tokens = context_window - reserved_tokens

    if available_tokens <= 100:  # Minimum viable content
        logger.warning(
            "Very limited tokens available for content: %s", available_tokens
        )
        return content[:500]  # Fallback to first 500 chars

    # Binary search to find the longest prefix of `content` whose token
    # count (as counted by the target model's tokenizer) still fits.
    left, right = 0, len(content)
    optimal_length = 0

    while left <= right:
        mid = (left + right) // 2
        test_content = content[:mid]

        # Test token count for this content length
        test_document = f"\n\n{test_content}\n\n"
        test_tokens = token_counter(
            messages=[{"role": "user", "content": test_document}], model=model_name
        )

        if test_tokens <= available_tokens:
            optimal_length = mid
            left = mid + 1
        else:
            right = mid - 1

    optimized_content = content[:optimal_length] if optimal_length > 0 else content[:500]

    if optimal_length < len(content):
        logger.info(
            "Content optimized: %s -> %s chars to fit in %s available tokens",
            len(content),
            optimal_length,
            available_tokens,
        )

    return optimized_content


async def generate_document_summary(
    content: str,
    user_llm,
    document_metadata: dict | None = None,
) -> tuple[str, list[float]]:
    """
    Generate summary and embedding for document content with metadata.

    Args:
        content: Document content
        user_llm: User's LLM instance
        document_metadata: Optional metadata dictionary to include in summary

    Returns:
        Tuple of (enhanced_summary_content, summary_embedding)
    """
    # Get model name from user_llm for token counting
    model_name = getattr(user_llm, "model", "gpt-3.5-turbo")  # Fallback to default

    # Optimize content to fit within context window
    optimized_content = optimize_content_for_context_window(
        content, document_metadata, model_name
    )

    summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm
    # Only interpolate metadata when it is actually present; previously a
    # missing dict injected the literal string "None" into the LLM prompt.
    if document_metadata:
        content_with_metadata = (
            f"\n\n{document_metadata}\n\n\n\n\n\n{optimized_content}\n\n"
        )
    else:
        content_with_metadata = f"\n\n\n\n\n\n{optimized_content}\n\n"
    summary_result = await summary_chain.ainvoke({"document": content_with_metadata})
    summary_content = summary_result.content

    # Combine summary with metadata if provided
    if document_metadata:
        metadata_parts = []
        metadata_parts.append("# DOCUMENT METADATA")
        for key, value in document_metadata.items():
            if value:  # Only include non-empty values
                formatted_key = key.replace("_", " ").title()
                metadata_parts.append(f"**{formatted_key}:** {value}")
        metadata_section = "\n".join(metadata_parts)
        enhanced_summary_content = (
            f"{metadata_section}\n\n# DOCUMENT SUMMARY\n\n{summary_content}"
        )
    else:
        enhanced_summary_content = summary_content

    summary_embedding = embed_text(enhanced_summary_content)
    return enhanced_summary_content, summary_embedding


async def create_document_chunks(content: str) -> list[Chunk]:
    """
    Create chunks from document content.

    Args:
        content: Document content to chunk

    Returns:
        List of Chunk objects with embeddings
    """
    chunk_texts = [c.text for c in config.chunker_instance.chunk(content)]
    chunk_embeddings = embed_texts(chunk_texts)
    # strict=False: lengths should always match, but a mismatch would then
    # silently drop trailing items — NOTE(review): consider strict=True.
    return [
        Chunk(content=text, embedding=emb)
        for text, emb in zip(chunk_texts, chunk_embeddings, strict=False)
    ]


async def convert_element_to_markdown(element) -> str:
    """
    Convert an Unstructured element to markdown format based on its category.

    Args:
        element: The Unstructured API element object (assumes it exposes
            ``metadata["category"]`` and ``page_content`` — the langchain
            wrapper shape; Table elements also need ``metadata["text_as_html"]``)

    Returns:
        str: Markdown formatted string (empty string for empty content)
    """
    element_category = element.metadata["category"]
    content = element.page_content

    if not content:
        return ""

    # Category -> markdown formatter; unknown categories pass through as-is.
    markdown_mapping = {
        "Formula": lambda x: f"```math\n{x}\n```",
        "FigureCaption": lambda x: f"*Figure: {x}*",
        "NarrativeText": lambda x: f"{x}\n\n",
        "ListItem": lambda x: f"- {x}\n",
        "Title": lambda x: f"# {x}\n\n",
        "Address": lambda x: f"> {x}\n\n",
        "EmailAddress": lambda x: f"`{x}`",
        "Image": lambda x: f"![{x}]({x})",
        "PageBreak": lambda x: "\n---\n",
        "Table": lambda x: f"```html\n{element.metadata['text_as_html']}\n```",
        "Header": lambda x: f"## {x}\n\n",
        "Footer": lambda x: f"*{x}*\n\n",
        "CodeSnippet": lambda x: f"```\n{x}\n```",
        "PageNumber": lambda x: f"*Page {x}*\n\n",
        "UncategorizedText": lambda x: f"{x}\n\n",
    }

    converter = markdown_mapping.get(element_category, lambda x: x)
    return converter(content)


async def convert_document_to_markdown(elements):
    """
    Convert all document elements to markdown.

    Args:
        elements: List of Unstructured API elements

    Returns:
        str: Complete markdown document
    """
    markdown_parts = []
    for element in elements:
        markdown_text = await convert_element_to_markdown(element)
        if markdown_text:
            markdown_parts.append(markdown_text)
    return "".join(markdown_parts)


def generate_content_hash(content: str, search_space_id: int) -> str:
    """Generate SHA-256 hash for the given content combined with search space ID."""
    combined_data = f"{search_space_id}:{content}"
    return hashlib.sha256(combined_data.encode("utf-8")).hexdigest()


def generate_unique_identifier_hash(
    document_type: DocumentType,
    unique_identifier: str | int | float,
    search_space_id: int,
) -> str:
    """
    Generate SHA-256 hash for a unique document identifier from connector sources.

    This function creates a consistent hash based on the document type, its
    unique identifier from the source system, and the search space ID.
    This helps prevent duplicate documents when syncing from various
    connectors like Slack, Notion, Jira, etc.

    Args:
        document_type: The type of document (e.g., SLACK_CONNECTOR, NOTION_CONNECTOR)
        unique_identifier: The unique ID from the source system (e.g., message ID, page ID)
        search_space_id: The search space this document belongs to

    Returns:
        str: SHA-256 hash string representing the unique document identifier

    Example:
        >>> generate_unique_identifier_hash(
        ...     DocumentType.SLACK_CONNECTOR,
        ...     "1234567890.123456",
        ...     42
        ... )
        'a1b2c3d4e5f6...'
    """
    # Convert unique_identifier to string to handle different types
    identifier_str = str(unique_identifier)

    # Combine document type value, unique identifier, and search space ID
    combined_data = f"{document_type.value}:{identifier_str}:{search_space_id}"
    return hashlib.sha256(combined_data.encode("utf-8")).hexdigest()