Merge pull request #448 from MODSetter/dev

chore: update configuration for rerankers
This commit is contained in:
Rohan Verma 2025-10-29 23:53:28 -07:00 committed by GitHub
commit b79befdef6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 62 additions and 51 deletions

View file

@ -3,6 +3,7 @@ DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense
#Celery Config #Celery Config
CELERY_BROKER_URL=redis://localhost:6379/0 CELERY_BROKER_URL=redis://localhost:6379/0
CELERY_RESULT_BACKEND=redis://localhost:6379/0 CELERY_RESULT_BACKEND=redis://localhost:6379/0
# Periodic task interval # Periodic task interval
# # Run every minute (default) # # Run every minute (default)
# SCHEDULE_CHECKER_INTERVAL=1m # SCHEDULE_CHECKER_INTERVAL=1m
@ -18,7 +19,6 @@ CELERY_RESULT_BACKEND=redis://localhost:6379/0
# # Run every 2 hours # # Run every 2 hours
# SCHEDULE_CHECKER_INTERVAL=2h # SCHEDULE_CHECKER_INTERVAL=2h
SCHEDULE_CHECKER_INTERVAL=5m SCHEDULE_CHECKER_INTERVAL=5m
SECRET_KEY=SECRET SECRET_KEY=SECRET
@ -26,14 +26,16 @@ NEXT_FRONTEND_URL=http://localhost:3000
# Auth # Auth
AUTH_TYPE=GOOGLE or LOCAL AUTH_TYPE=GOOGLE or LOCAL
REGISTRATION_ENABLED= TRUE or FALSE REGISTRATION_ENABLED=TRUE or FALSE
# For Google Auth Only # For Google Auth Only
GOOGLE_OAUTH_CLIENT_ID=924507538m GOOGLE_OAUTH_CLIENT_ID=924507538m
GOOGLE_OAUTH_CLIENT_SECRET=GOCSV GOOGLE_OAUTH_CLIENT_SECRET=GOCSV
# Connector Specific Configs
GOOGLE_CALENDAR_REDIRECT_URI=http://localhost:8000/api/v1/auth/google/calendar/connector/callback GOOGLE_CALENDAR_REDIRECT_URI=http://localhost:8000/api/v1/auth/google/calendar/connector/callback
GOOGLE_GMAIL_REDIRECT_URI=http://localhost:8000/api/v1/auth/google/gmail/connector/callback GOOGLE_GMAIL_REDIRECT_URI=http://localhost:8000/api/v1/auth/google/gmail/connector/callback
# Airtable OAuth # Airtable OAuth for Aitable Connector
AIRTABLE_CLIENT_ID=your_airtable_client_id AIRTABLE_CLIENT_ID=your_airtable_client_id
AIRTABLE_CLIENT_SECRET=your_airtable_client_secret AIRTABLE_CLIENT_SECRET=your_airtable_client_secret
AIRTABLE_REDIRECT_URI=http://localhost:8000/api/v1/auth/airtable/connector/callback AIRTABLE_REDIRECT_URI=http://localhost:8000/api/v1/auth/airtable/connector/callback
@ -51,20 +53,21 @@ AIRTABLE_REDIRECT_URI=http://localhost:8000/api/v1/auth/airtable/connector/callb
# # Get Cohere embeddings # # Get Cohere embeddings
# embeddings = AutoEmbeddings.get_embeddings("cohere://embed-english-light-v3.0", api_key="...") # embeddings = AutoEmbeddings.get_embeddings("cohere://embed-english-light-v3.0", api_key="...")
EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2 EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
# Rerankers Config
RERANKERS_ENABLED=TRUE or FALSE(Default: FALSE)
RERANKERS_MODEL_NAME=ms-marco-MiniLM-L-12-v2 RERANKERS_MODEL_NAME=ms-marco-MiniLM-L-12-v2
RERANKERS_MODEL_TYPE=flashrank RERANKERS_MODEL_TYPE=flashrank
# TTS_SERVICE=local/kokoro for local Kokoro TTS or # TTS_SERVICE=local/kokoro for local Kokoro TTS or
# LiteLLM TTS Provider: https://docs.litellm.ai/docs/text_to_speech#supported-providers # LiteLLM TTS Provider: https://docs.litellm.ai/docs/text_to_speech#supported-providers
TTS_SERVICE=openai/tts-1 TTS_SERVICE=local/kokoro
# Respective TTS Service API # Respective TTS Service API
TTS_SERVICE_API_KEY= # TTS_SERVICE_API_KEY=
# OPTIONAL: TTS Provider API Base # OPTIONAL: TTS Provider API Base
TTS_SERVICE_API_BASE= # TTS_SERVICE_API_BASE=
# STT Service Configuration # STT Service Configuration
# For local Faster-Whisper: local/MODEL_SIZE (tiny, base, small, medium, large-v3) # For local Faster-Whisper: local/MODEL_SIZE (tiny, base, small, medium, large-v3)

View file

@ -24,6 +24,8 @@ async def rerank_documents(state: State, config: RunnableConfig) -> dict[str, An
reranks them using the reranker service based on the user's query, reranks them using the reranker service based on the user's query,
and updates the state with the reranked documents. and updates the state with the reranked documents.
If reranking is disabled, returns the original documents without processing.
Returns: Returns:
Dict containing the reranked documents. Dict containing the reranked documents.
""" """
@ -40,10 +42,12 @@ async def rerank_documents(state: State, config: RunnableConfig) -> dict[str, An
# Get reranker service from app config # Get reranker service from app config
reranker_service = RerankerService.get_reranker_instance() reranker_service = RerankerService.get_reranker_instance()
# Use documents as is if no reranker service is available # If reranking is not enabled, return original documents without processing
reranked_docs = documents if not reranker_service:
print("Reranking is disabled. Using original document order.")
return {"reranked_documents": documents}
if reranker_service: # Perform reranking
try: try:
# Convert documents to format expected by reranker if needed # Convert documents to format expected by reranker if needed
reranker_input_docs = [ reranker_input_docs = [
@ -54,9 +58,7 @@ async def rerank_documents(state: State, config: RunnableConfig) -> dict[str, An
"document": { "document": {
"id": doc.get("document", {}).get("id", ""), "id": doc.get("document", {}).get("id", ""),
"title": doc.get("document", {}).get("title", ""), "title": doc.get("document", {}).get("title", ""),
"document_type": doc.get("document", {}).get( "document_type": doc.get("document", {}).get("document_type", ""),
"document_type", ""
),
"metadata": doc.get("document", {}).get("metadata", {}), "metadata": doc.get("document", {}).get("metadata", {}),
}, },
} }
@ -71,15 +73,15 @@ async def rerank_documents(state: State, config: RunnableConfig) -> dict[str, An
# Sort by score in descending order # Sort by score in descending order
reranked_docs.sort(key=lambda x: x.get("score", 0), reverse=True) reranked_docs.sort(key=lambda x: x.get("score", 0), reverse=True)
print( print(f"Reranked {len(reranked_docs)} documents for Q&A query: {user_query}")
f"Reranked {len(reranked_docs)} documents for Q&A query: {user_query}"
)
except Exception as e:
print(f"Error during reranking: {e!s}")
# Use original docs if reranking fails
return {"reranked_documents": reranked_docs} return {"reranked_documents": reranked_docs}
except Exception as e:
print(f"Error during reranking: {e!s}")
# Fall back to original documents if reranking fails
return {"reranked_documents": documents}
async def answer_question(state: State, config: RunnableConfig) -> dict[str, Any]: async def answer_question(state: State, config: RunnableConfig) -> dict[str, Any]:
""" """

View file

@ -92,12 +92,16 @@ class Config:
) )
# Reranker's Configuration | Pinecode, Cohere etc. Read more at https://github.com/AnswerDotAI/rerankers?tab=readme-ov-file#usage # Reranker's Configuration | Pinecode, Cohere etc. Read more at https://github.com/AnswerDotAI/rerankers?tab=readme-ov-file#usage
RERANKERS_ENABLED = os.getenv("RERANKERS_ENABLED", "FALSE").upper() == "TRUE"
if RERANKERS_ENABLED:
RERANKERS_MODEL_NAME = os.getenv("RERANKERS_MODEL_NAME") RERANKERS_MODEL_NAME = os.getenv("RERANKERS_MODEL_NAME")
RERANKERS_MODEL_TYPE = os.getenv("RERANKERS_MODEL_TYPE") RERANKERS_MODEL_TYPE = os.getenv("RERANKERS_MODEL_TYPE")
reranker_instance = Reranker( reranker_instance = Reranker(
model_name=RERANKERS_MODEL_NAME, model_name=RERANKERS_MODEL_NAME,
model_type=RERANKERS_MODEL_TYPE, model_type=RERANKERS_MODEL_TYPE,
) )
else:
reranker_instance = None
# OAuth JWT # OAuth JWT
SECRET_KEY = os.getenv("SECRET_KEY") SECRET_KEY = os.getenv("SECRET_KEY")

View file

@ -88,8 +88,9 @@ Before you begin, ensure you have:
| GOOGLE_OAUTH_CLIENT_ID | (Optional) Client ID from Google Cloud Console (required if AUTH_TYPE=GOOGLE) | | GOOGLE_OAUTH_CLIENT_ID | (Optional) Client ID from Google Cloud Console (required if AUTH_TYPE=GOOGLE) |
| GOOGLE_OAUTH_CLIENT_SECRET | (Optional) Client secret from Google Cloud Console (required if AUTH_TYPE=GOOGLE) | | GOOGLE_OAUTH_CLIENT_SECRET | (Optional) Client secret from Google Cloud Console (required if AUTH_TYPE=GOOGLE) |
| EMBEDDING_MODEL | Name of the embedding model (e.g., `sentence-transformers/all-MiniLM-L6-v2`, `openai://text-embedding-ada-002`) | | EMBEDDING_MODEL | Name of the embedding model (e.g., `sentence-transformers/all-MiniLM-L6-v2`, `openai://text-embedding-ada-002`) |
| RERANKERS_MODEL_NAME | Name of the reranker model (e.g., `ms-marco-MiniLM-L-12-v2`) | | RERANKERS_ENABLED | (Optional) Enable or disable document reranking for improved search results (e.g., `TRUE` or `FALSE`, default: `FALSE`) |
| RERANKERS_MODEL_TYPE | Type of reranker model (e.g., `flashrank`) | | RERANKERS_MODEL_NAME | Name of the reranker model (e.g., `ms-marco-MiniLM-L-12-v2`) (required if RERANKERS_ENABLED=TRUE) |
| RERANKERS_MODEL_TYPE | Type of reranker model (e.g., `flashrank`) (required if RERANKERS_ENABLED=TRUE) |
| TTS_SERVICE | Text-to-Speech API provider for Podcasts (e.g., `local/kokoro`, `openai/tts-1`). See [supported providers](https://docs.litellm.ai/docs/text_to_speech#supported-providers) | | TTS_SERVICE | Text-to-Speech API provider for Podcasts (e.g., `local/kokoro`, `openai/tts-1`). See [supported providers](https://docs.litellm.ai/docs/text_to_speech#supported-providers) |
| TTS_SERVICE_API_KEY | (Optional if local) API key for the Text-to-Speech service | | TTS_SERVICE_API_KEY | (Optional if local) API key for the Text-to-Speech service |
| TTS_SERVICE_API_BASE | (Optional) Custom API base URL for the Text-to-Speech service | | TTS_SERVICE_API_BASE | (Optional) Custom API base URL for the Text-to-Speech service |

View file

@ -73,8 +73,9 @@ Edit the `.env` file and set the following variables:
| GOOGLE_OAUTH_CLIENT_ID | (Optional) Client ID from Google Cloud Console (required if AUTH_TYPE=GOOGLE) | | GOOGLE_OAUTH_CLIENT_ID | (Optional) Client ID from Google Cloud Console (required if AUTH_TYPE=GOOGLE) |
| GOOGLE_OAUTH_CLIENT_SECRET | (Optional) Client secret from Google Cloud Console (required if AUTH_TYPE=GOOGLE) | | GOOGLE_OAUTH_CLIENT_SECRET | (Optional) Client secret from Google Cloud Console (required if AUTH_TYPE=GOOGLE) |
| EMBEDDING_MODEL | Name of the embedding model (e.g., `sentence-transformers/all-MiniLM-L6-v2`, `openai://text-embedding-ada-002`) | | EMBEDDING_MODEL | Name of the embedding model (e.g., `sentence-transformers/all-MiniLM-L6-v2`, `openai://text-embedding-ada-002`) |
| RERANKERS_MODEL_NAME | Name of the reranker model (e.g., `ms-marco-MiniLM-L-12-v2`) | | RERANKERS_ENABLED | (Optional) Enable or disable document reranking for improved search results (e.g., `TRUE` or `FALSE`, default: `FALSE`) |
| RERANKERS_MODEL_TYPE | Type of reranker model (e.g., `flashrank`) | | RERANKERS_MODEL_NAME | Name of the reranker model (e.g., `ms-marco-MiniLM-L-12-v2`) (required if RERANKERS_ENABLED=TRUE) |
| RERANKERS_MODEL_TYPE | Type of reranker model (e.g., `flashrank`) (required if RERANKERS_ENABLED=TRUE) |
| TTS_SERVICE | Text-to-Speech API provider for Podcasts (e.g., `local/kokoro`, `openai/tts-1`). See [supported providers](https://docs.litellm.ai/docs/text_to_speech#supported-providers) | | TTS_SERVICE | Text-to-Speech API provider for Podcasts (e.g., `local/kokoro`, `openai/tts-1`). See [supported providers](https://docs.litellm.ai/docs/text_to_speech#supported-providers) |
| TTS_SERVICE_API_KEY | (Optional if local) API key for the Text-to-Speech service | | TTS_SERVICE_API_KEY | (Optional if local) API key for the Text-to-Speech service |
| TTS_SERVICE_API_BASE | (Optional) Custom API base URL for the Text-to-Speech service | | TTS_SERVICE_API_BASE | (Optional) Custom API base URL for the Text-to-Speech service |