Mirror of https://github.com/MODSetter/SurfSense.git (synced 2026-04-27)
feat: added configurable summary calculation and various improvements
- Replaced direct embedding calls with a shared utility function across various components to streamline embedding logic.
- Added an `enable_summary` flag to several models and routes to control summary generation behavior.
Parent: dc33a4a68f
Commit: e9892c8fe9
50 changed files with 380 additions and 298 deletions
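The `enable_summary` flag itself is not visible in the hunks excerpted below, which cover only the new embedding utility. As a rough illustration of the behavior the commit message describes, here is a minimal sketch of such a flag gating summary generation on a request model; `DocumentCreate` and `ingest` are hypothetical names for illustration, not SurfSense's actual schema or routes.

```python
# Hypothetical sketch only: DocumentCreate and ingest() are illustrative
# names, not SurfSense's real models or routes.
import asyncio

from pydantic import BaseModel


class DocumentCreate(BaseModel):
    content: str
    enable_summary: bool = True  # the new opt-out described in the commit message


async def ingest(doc: DocumentCreate) -> str | None:
    if not doc.enable_summary:
        return None  # skip the (potentially expensive) LLM summary pass
    # stand-in for the real summary generation + embedding pipeline
    return f"summary of a {len(doc.content)}-char document"


print(asyncio.run(ingest(DocumentCreate(content="hello", enable_summary=False))))  # None
```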
```diff
@@ -1,11 +1,59 @@
 import hashlib
 import logging
+import warnings
 
 import numpy as np
 from litellm import get_model_info, token_counter
 
 from app.config import config
 from app.db import Chunk, DocumentType
 from app.prompts import SUMMARY_PROMPT_TEMPLATE
 
 logger = logging.getLogger(__name__)
+
+
+def _get_embedding_max_tokens() -> int:
+    """Get the max token limit for the configured embedding model.
+
+    Checks model properties in order: max_seq_length, _max_tokens.
+    Falls back to 8192 (OpenAI embedding default).
+    """
+    model = config.embedding_model_instance
+    for attr in ("max_seq_length", "_max_tokens"):
+        val = getattr(model, attr, None)
+        if isinstance(val, int) and val > 0:
+            return val
+    return 8192
+
+
+def truncate_for_embedding(text: str) -> str:
+    """Truncate text to fit within the embedding model's context window.
+
+    Uses the embedding model's own tokenizer for accurate token counting,
+    so the result is model-agnostic regardless of the underlying provider.
+    """
+    max_tokens = _get_embedding_max_tokens()
+    if len(text) // 3 <= max_tokens:
+        return text
+
+    tokenizer = config.embedding_model_instance.get_tokenizer()
+    tokens = tokenizer.encode(text)
+    if len(tokens) <= max_tokens:
+        return text
+
+    warnings.warn(
+        f"Truncating text from {len(tokens)} to {max_tokens} tokens for embedding.",
+        stacklevel=2,
+    )
+    return tokenizer.decode(tokens[:max_tokens])
+
+
+def embed_text(text: str) -> np.ndarray:
+    """Truncate text to fit and embed it. Drop-in replacement for
+    ``config.embedding_model_instance.embed(text)`` that never exceeds the
+    model's context window."""
+    return config.embedding_model_instance.embed(truncate_for_embedding(text))
+
+
+def get_model_context_window(model_name: str) -> int:
+    """Get the total context window size for a model (input + output tokens)."""
```
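To make the fast path concrete: `len(text) // 3 <= max_tokens` assumes tokens average at least about three characters, so a sufficiently short string provably fits and never needs to be tokenized at all. Below is a self-contained sketch of the same control flow; the whitespace tokenizer and the limit of 8 are stand-ins for illustration, not the real embedding model's tokenizer or limit.

```python
# Toy re-creation of truncate_for_embedding()'s control flow; the
# whitespace tokenizer and MAX_TOKENS=8 are assumptions for illustration.
import warnings

MAX_TOKENS = 8  # stands in for _get_embedding_max_tokens()


class ToyTokenizer:
    def encode(self, text: str) -> list[str]:
        return text.split()  # one "token" per whitespace-separated word

    def decode(self, tokens: list[str]) -> str:
        return " ".join(tokens)


def truncate(text: str, tokenizer: ToyTokenizer, max_tokens: int = MAX_TOKENS) -> str:
    # Fast path: if tokens average >= 3 chars, a string this short cannot
    # exceed the limit, so skip tokenization entirely.
    if len(text) // 3 <= max_tokens:
        return text
    tokens = tokenizer.encode(text)
    if len(tokens) <= max_tokens:
        return text
    warnings.warn(f"Truncating {len(tokens)} -> {max_tokens} tokens", stacklevel=2)
    return tokenizer.decode(tokens[:max_tokens])


tok = ToyTokenizer()
print(truncate("short text", tok))                              # fast path: unchanged
print(truncate(" ".join(f"w{i:04d}" for i in range(50)), tok))  # truncated to 8 words
```

Note that the fast path trusts the three-characters-per-token floor: an input dense in single-character tokens could in principle slip past it while still exceeding the token limit, presumably an accepted trade-off for skipping tokenizer work on the common case.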
```diff
@@ -146,7 +194,7 @@ async def generate_document_summary(
     else:
         enhanced_summary_content = summary_content
 
-    summary_embedding = config.embedding_model_instance.embed(enhanced_summary_content)
+    summary_embedding = embed_text(enhanced_summary_content)
 
     return enhanced_summary_content, summary_embedding
```
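The summary path is a likely overflow point: `enhanced_summary_content` can exceed the embedding model's context window for long documents, and the old direct `.embed()` call left that case to the provider (some embedding APIs reject over-long input outright). Routing it through `embed_text` truncates first, so the call can no longer fail on length.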
```diff
@@ -164,7 +212,7 @@ async def create_document_chunks(content: str) -> list[Chunk]:
     return [
         Chunk(
             content=chunk.text,
-            embedding=config.embedding_model_instance.embed(chunk.text),
+            embedding=embed_text(chunk.text),
         )
         for chunk in config.chunker_instance.chunk(content)
     ]
```
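The same one-line swap in `create_document_chunks` routes every chunk embedding through the new guard. Below is a minimal, self-contained sketch of the chunk-then-embed shape; the fixed-size splitter, the character cap, and the 8-dimensional toy vectors are stand-ins for `config.chunker_instance` and the real `embed_text`, purely for illustration.

```python
# Toy stand-ins for SurfSense's configured chunker and embed_text();
# only the chunk-then-embed shape mirrors the diff above.
from dataclasses import dataclass

import numpy as np


@dataclass
class Chunk:
    content: str
    embedding: np.ndarray


def toy_chunk(content: str, size: int = 200) -> list[str]:
    # naive fixed-size splitter standing in for config.chunker_instance
    return [content[i : i + size] for i in range(0, len(content), size)]


def toy_embed_text(text: str) -> np.ndarray:
    capped = text[: 3 * 8192]  # crude character cap standing in for token truncation
    rng = np.random.default_rng(len(capped))  # deterministic toy vector
    return rng.random(8)


def create_chunks(content: str) -> list[Chunk]:
    return [Chunk(content=c, embedding=toy_embed_text(c)) for c in toy_chunk(content)]


chunks = create_chunks("lorem ipsum " * 100)   # 1200 chars -> 6 chunks
print(len(chunks), chunks[0].embedding.shape)  # 6 (8,)
```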