mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-05 22:02:39 +02:00
Merge pull request #448 from MODSetter/dev
chore: update configuration for rerankers
This commit is contained in:
commit
b79befdef6
5 changed files with 62 additions and 51 deletions
|
|
@ -3,6 +3,7 @@ DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense
|
||||||
#Celery Config
|
#Celery Config
|
||||||
CELERY_BROKER_URL=redis://localhost:6379/0
|
CELERY_BROKER_URL=redis://localhost:6379/0
|
||||||
CELERY_RESULT_BACKEND=redis://localhost:6379/0
|
CELERY_RESULT_BACKEND=redis://localhost:6379/0
|
||||||
|
|
||||||
# Periodic task interval
|
# Periodic task interval
|
||||||
# # Run every minute (default)
|
# # Run every minute (default)
|
||||||
# SCHEDULE_CHECKER_INTERVAL=1m
|
# SCHEDULE_CHECKER_INTERVAL=1m
|
||||||
|
|
@ -18,7 +19,6 @@ CELERY_RESULT_BACKEND=redis://localhost:6379/0
|
||||||
|
|
||||||
# # Run every 2 hours
|
# # Run every 2 hours
|
||||||
# SCHEDULE_CHECKER_INTERVAL=2h
|
# SCHEDULE_CHECKER_INTERVAL=2h
|
||||||
|
|
||||||
SCHEDULE_CHECKER_INTERVAL=5m
|
SCHEDULE_CHECKER_INTERVAL=5m
|
||||||
|
|
||||||
SECRET_KEY=SECRET
|
SECRET_KEY=SECRET
|
||||||
|
|
@ -30,10 +30,12 @@ REGISTRATION_ENABLED= TRUE or FALSE
|
||||||
# For Google Auth Only
|
# For Google Auth Only
|
||||||
GOOGLE_OAUTH_CLIENT_ID=924507538m
|
GOOGLE_OAUTH_CLIENT_ID=924507538m
|
||||||
GOOGLE_OAUTH_CLIENT_SECRET=GOCSV
|
GOOGLE_OAUTH_CLIENT_SECRET=GOCSV
|
||||||
|
|
||||||
|
# Connector Specific Configs
|
||||||
GOOGLE_CALENDAR_REDIRECT_URI=http://localhost:8000/api/v1/auth/google/calendar/connector/callback
|
GOOGLE_CALENDAR_REDIRECT_URI=http://localhost:8000/api/v1/auth/google/calendar/connector/callback
|
||||||
GOOGLE_GMAIL_REDIRECT_URI=http://localhost:8000/api/v1/auth/google/gmail/connector/callback
|
GOOGLE_GMAIL_REDIRECT_URI=http://localhost:8000/api/v1/auth/google/gmail/connector/callback
|
||||||
|
|
||||||
# Airtable OAuth
|
# Airtable OAuth for Airtable Connector
|
||||||
AIRTABLE_CLIENT_ID=your_airtable_client_id
|
AIRTABLE_CLIENT_ID=your_airtable_client_id
|
||||||
AIRTABLE_CLIENT_SECRET=your_airtable_client_secret
|
AIRTABLE_CLIENT_SECRET=your_airtable_client_secret
|
||||||
AIRTABLE_REDIRECT_URI=http://localhost:8000/api/v1/auth/airtable/connector/callback
|
AIRTABLE_REDIRECT_URI=http://localhost:8000/api/v1/auth/airtable/connector/callback
|
||||||
|
|
@ -51,20 +53,21 @@ AIRTABLE_REDIRECT_URI=http://localhost:8000/api/v1/auth/airtable/connector/callb
|
||||||
|
|
||||||
# # Get Cohere embeddings
|
# # Get Cohere embeddings
|
||||||
# embeddings = AutoEmbeddings.get_embeddings("cohere://embed-english-light-v3.0", api_key="...")
|
# embeddings = AutoEmbeddings.get_embeddings("cohere://embed-english-light-v3.0", api_key="...")
|
||||||
|
|
||||||
EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
|
EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
|
||||||
|
|
||||||
|
# Rerankers Config
|
||||||
|
RERANKERS_ENABLED=TRUE or FALSE (Default: FALSE)
|
||||||
RERANKERS_MODEL_NAME=ms-marco-MiniLM-L-12-v2
|
RERANKERS_MODEL_NAME=ms-marco-MiniLM-L-12-v2
|
||||||
RERANKERS_MODEL_TYPE=flashrank
|
RERANKERS_MODEL_TYPE=flashrank
|
||||||
|
|
||||||
|
|
||||||
# TTS_SERVICE=local/kokoro for local Kokoro TTS or
|
# TTS_SERVICE=local/kokoro for local Kokoro TTS or
|
||||||
# LiteLLM TTS Provider: https://docs.litellm.ai/docs/text_to_speech#supported-providers
|
# LiteLLM TTS Provider: https://docs.litellm.ai/docs/text_to_speech#supported-providers
|
||||||
TTS_SERVICE=openai/tts-1
|
TTS_SERVICE=local/kokoro
|
||||||
# Respective TTS Service API
|
# Respective TTS Service API
|
||||||
TTS_SERVICE_API_KEY=
|
# TTS_SERVICE_API_KEY=
|
||||||
# OPTIONAL: TTS Provider API Base
|
# OPTIONAL: TTS Provider API Base
|
||||||
TTS_SERVICE_API_BASE=
|
# TTS_SERVICE_API_BASE=
|
||||||
|
|
||||||
# STT Service Configuration
|
# STT Service Configuration
|
||||||
# For local Faster-Whisper: local/MODEL_SIZE (tiny, base, small, medium, large-v3)
|
# For local Faster-Whisper: local/MODEL_SIZE (tiny, base, small, medium, large-v3)
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,8 @@ async def rerank_documents(state: State, config: RunnableConfig) -> dict[str, An
|
||||||
reranks them using the reranker service based on the user's query,
|
reranks them using the reranker service based on the user's query,
|
||||||
and updates the state with the reranked documents.
|
and updates the state with the reranked documents.
|
||||||
|
|
||||||
|
If reranking is disabled, returns the original documents without processing.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict containing the reranked documents.
|
Dict containing the reranked documents.
|
||||||
"""
|
"""
|
||||||
|
|
@ -40,10 +42,12 @@ async def rerank_documents(state: State, config: RunnableConfig) -> dict[str, An
|
||||||
# Get reranker service from app config
|
# Get reranker service from app config
|
||||||
reranker_service = RerankerService.get_reranker_instance()
|
reranker_service = RerankerService.get_reranker_instance()
|
||||||
|
|
||||||
# Use documents as is if no reranker service is available
|
# If reranking is not enabled, return original documents without processing
|
||||||
reranked_docs = documents
|
if not reranker_service:
|
||||||
|
print("Reranking is disabled. Using original document order.")
|
||||||
|
return {"reranked_documents": documents}
|
||||||
|
|
||||||
if reranker_service:
|
# Perform reranking
|
||||||
try:
|
try:
|
||||||
# Convert documents to format expected by reranker if needed
|
# Convert documents to format expected by reranker if needed
|
||||||
reranker_input_docs = [
|
reranker_input_docs = [
|
||||||
|
|
@ -54,9 +58,7 @@ async def rerank_documents(state: State, config: RunnableConfig) -> dict[str, An
|
||||||
"document": {
|
"document": {
|
||||||
"id": doc.get("document", {}).get("id", ""),
|
"id": doc.get("document", {}).get("id", ""),
|
||||||
"title": doc.get("document", {}).get("title", ""),
|
"title": doc.get("document", {}).get("title", ""),
|
||||||
"document_type": doc.get("document", {}).get(
|
"document_type": doc.get("document", {}).get("document_type", ""),
|
||||||
"document_type", ""
|
|
||||||
),
|
|
||||||
"metadata": doc.get("document", {}).get("metadata", {}),
|
"metadata": doc.get("document", {}).get("metadata", {}),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
@ -71,15 +73,15 @@ async def rerank_documents(state: State, config: RunnableConfig) -> dict[str, An
|
||||||
# Sort by score in descending order
|
# Sort by score in descending order
|
||||||
reranked_docs.sort(key=lambda x: x.get("score", 0), reverse=True)
|
reranked_docs.sort(key=lambda x: x.get("score", 0), reverse=True)
|
||||||
|
|
||||||
print(
|
print(f"Reranked {len(reranked_docs)} documents for Q&A query: {user_query}")
|
||||||
f"Reranked {len(reranked_docs)} documents for Q&A query: {user_query}"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error during reranking: {e!s}")
|
|
||||||
# Use original docs if reranking fails
|
|
||||||
|
|
||||||
return {"reranked_documents": reranked_docs}
|
return {"reranked_documents": reranked_docs}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error during reranking: {e!s}")
|
||||||
|
# Fall back to original documents if reranking fails
|
||||||
|
return {"reranked_documents": documents}
|
||||||
|
|
||||||
|
|
||||||
async def answer_question(state: State, config: RunnableConfig) -> dict[str, Any]:
|
async def answer_question(state: State, config: RunnableConfig) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -92,12 +92,16 @@ class Config:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Reranker's Configuration | Pinecone, Cohere etc. Read more at https://github.com/AnswerDotAI/rerankers?tab=readme-ov-file#usage
|
# Reranker's Configuration | Pinecone, Cohere etc. Read more at https://github.com/AnswerDotAI/rerankers?tab=readme-ov-file#usage
|
||||||
|
RERANKERS_ENABLED = os.getenv("RERANKERS_ENABLED", "FALSE").upper() == "TRUE"
|
||||||
|
if RERANKERS_ENABLED:
|
||||||
RERANKERS_MODEL_NAME = os.getenv("RERANKERS_MODEL_NAME")
|
RERANKERS_MODEL_NAME = os.getenv("RERANKERS_MODEL_NAME")
|
||||||
RERANKERS_MODEL_TYPE = os.getenv("RERANKERS_MODEL_TYPE")
|
RERANKERS_MODEL_TYPE = os.getenv("RERANKERS_MODEL_TYPE")
|
||||||
reranker_instance = Reranker(
|
reranker_instance = Reranker(
|
||||||
model_name=RERANKERS_MODEL_NAME,
|
model_name=RERANKERS_MODEL_NAME,
|
||||||
model_type=RERANKERS_MODEL_TYPE,
|
model_type=RERANKERS_MODEL_TYPE,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
reranker_instance = None
|
||||||
|
|
||||||
# OAuth JWT
|
# OAuth JWT
|
||||||
SECRET_KEY = os.getenv("SECRET_KEY")
|
SECRET_KEY = os.getenv("SECRET_KEY")
|
||||||
|
|
|
||||||
|
|
@ -88,8 +88,9 @@ Before you begin, ensure you have:
|
||||||
| GOOGLE_OAUTH_CLIENT_ID | (Optional) Client ID from Google Cloud Console (required if AUTH_TYPE=GOOGLE) |
|
| GOOGLE_OAUTH_CLIENT_ID | (Optional) Client ID from Google Cloud Console (required if AUTH_TYPE=GOOGLE) |
|
||||||
| GOOGLE_OAUTH_CLIENT_SECRET | (Optional) Client secret from Google Cloud Console (required if AUTH_TYPE=GOOGLE) |
|
| GOOGLE_OAUTH_CLIENT_SECRET | (Optional) Client secret from Google Cloud Console (required if AUTH_TYPE=GOOGLE) |
|
||||||
| EMBEDDING_MODEL | Name of the embedding model (e.g., `sentence-transformers/all-MiniLM-L6-v2`, `openai://text-embedding-ada-002`) |
|
| EMBEDDING_MODEL | Name of the embedding model (e.g., `sentence-transformers/all-MiniLM-L6-v2`, `openai://text-embedding-ada-002`) |
|
||||||
| RERANKERS_MODEL_NAME | Name of the reranker model (e.g., `ms-marco-MiniLM-L-12-v2`) |
|
| RERANKERS_ENABLED | (Optional) Enable or disable document reranking for improved search results (e.g., `TRUE` or `FALSE`, default: `FALSE`) |
|
||||||
| RERANKERS_MODEL_TYPE | Type of reranker model (e.g., `flashrank`) |
|
| RERANKERS_MODEL_NAME | Name of the reranker model (e.g., `ms-marco-MiniLM-L-12-v2`) (required if RERANKERS_ENABLED=TRUE) |
|
||||||
|
| RERANKERS_MODEL_TYPE | Type of reranker model (e.g., `flashrank`) (required if RERANKERS_ENABLED=TRUE) |
|
||||||
| TTS_SERVICE | Text-to-Speech API provider for Podcasts (e.g., `local/kokoro`, `openai/tts-1`). See [supported providers](https://docs.litellm.ai/docs/text_to_speech#supported-providers) |
|
| TTS_SERVICE | Text-to-Speech API provider for Podcasts (e.g., `local/kokoro`, `openai/tts-1`). See [supported providers](https://docs.litellm.ai/docs/text_to_speech#supported-providers) |
|
||||||
| TTS_SERVICE_API_KEY | (Optional if local) API key for the Text-to-Speech service |
|
| TTS_SERVICE_API_KEY | (Optional if local) API key for the Text-to-Speech service |
|
||||||
| TTS_SERVICE_API_BASE | (Optional) Custom API base URL for the Text-to-Speech service |
|
| TTS_SERVICE_API_BASE | (Optional) Custom API base URL for the Text-to-Speech service |
|
||||||
|
|
|
||||||
|
|
@ -73,8 +73,9 @@ Edit the `.env` file and set the following variables:
|
||||||
| GOOGLE_OAUTH_CLIENT_ID | (Optional) Client ID from Google Cloud Console (required if AUTH_TYPE=GOOGLE) |
|
| GOOGLE_OAUTH_CLIENT_ID | (Optional) Client ID from Google Cloud Console (required if AUTH_TYPE=GOOGLE) |
|
||||||
| GOOGLE_OAUTH_CLIENT_SECRET | (Optional) Client secret from Google Cloud Console (required if AUTH_TYPE=GOOGLE) |
|
| GOOGLE_OAUTH_CLIENT_SECRET | (Optional) Client secret from Google Cloud Console (required if AUTH_TYPE=GOOGLE) |
|
||||||
| EMBEDDING_MODEL | Name of the embedding model (e.g., `sentence-transformers/all-MiniLM-L6-v2`, `openai://text-embedding-ada-002`) |
|
| EMBEDDING_MODEL | Name of the embedding model (e.g., `sentence-transformers/all-MiniLM-L6-v2`, `openai://text-embedding-ada-002`) |
|
||||||
| RERANKERS_MODEL_NAME | Name of the reranker model (e.g., `ms-marco-MiniLM-L-12-v2`) |
|
| RERANKERS_ENABLED | (Optional) Enable or disable document reranking for improved search results (e.g., `TRUE` or `FALSE`, default: `FALSE`) |
|
||||||
| RERANKERS_MODEL_TYPE | Type of reranker model (e.g., `flashrank`) |
|
| RERANKERS_MODEL_NAME | Name of the reranker model (e.g., `ms-marco-MiniLM-L-12-v2`) (required if RERANKERS_ENABLED=TRUE) |
|
||||||
|
| RERANKERS_MODEL_TYPE | Type of reranker model (e.g., `flashrank`) (required if RERANKERS_ENABLED=TRUE) |
|
||||||
| TTS_SERVICE | Text-to-Speech API provider for Podcasts (e.g., `local/kokoro`, `openai/tts-1`). See [supported providers](https://docs.litellm.ai/docs/text_to_speech#supported-providers) |
|
| TTS_SERVICE | Text-to-Speech API provider for Podcasts (e.g., `local/kokoro`, `openai/tts-1`). See [supported providers](https://docs.litellm.ai/docs/text_to_speech#supported-providers) |
|
||||||
| TTS_SERVICE_API_KEY | (Optional if local) API key for the Text-to-Speech service |
|
| TTS_SERVICE_API_KEY | (Optional if local) API key for the Text-to-Speech service |
|
||||||
| TTS_SERVICE_API_BASE | (Optional) Custom API base URL for the Text-to-Speech service |
|
| TTS_SERVICE_API_BASE | (Optional) Custom API base URL for the Text-to-Speech service |
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue