mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-17 18:35:19 +02:00
feat: enhance Azure OpenAI embeddings and add voice support for Azure provider
- Introduced a subclass of AzureOpenAIEmbeddings with a fixed parameter order to resolve compatibility issues.
- Updated the voice selection logic in the podcaster utility to include Azure voices.
- Modified the page limit service to de-duplicate the query result (`.unique()`) when retrieving users.
This commit is contained in:
parent
34353078fe
commit
a2fb9faad6
4 changed files with 73 additions and 9 deletions
|
|
@ -63,6 +63,17 @@ def get_voice_for_provider(provider: str, speaker_id: int) -> dict | str:
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
return vertex_voices.get(speaker_id, vertex_voices[0])
|
return vertex_voices.get(speaker_id, vertex_voices[0])
|
||||||
|
elif provider_type == "azure":
|
||||||
|
# OpenAI voice mapping - simple string values
|
||||||
|
azure_voices = {
|
||||||
|
0: "alloy", # Default/intro voice
|
||||||
|
1: "echo", # First speaker
|
||||||
|
2: "fable", # Second speaker
|
||||||
|
3: "onyx", # Third speaker
|
||||||
|
4: "nova", # Fourth speaker
|
||||||
|
5: "shimmer", # Fifth speaker
|
||||||
|
}
|
||||||
|
return azure_voices.get(speaker_id, "alloy")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Default fallback to OpenAI format for unknown providers
|
# Default fallback to OpenAI format for unknown providers
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from chonkie import AutoEmbeddings, CodeChunker, RecursiveChunker
|
from chonkie import AutoEmbeddings, CodeChunker, RecursiveChunker
|
||||||
from chonkie.embeddings.azure_openai import AzureOpenAIEmbeddings
|
from chonkie.embeddings.azure_openai import AzureOpenAIEmbeddings
|
||||||
|
|
@ -8,8 +9,46 @@ from chonkie.embeddings.registry import EmbeddingsRegistry
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from rerankers import Reranker
|
from rerankers import Reranker
|
||||||
|
|
||||||
|
|
||||||
|
# Subclass AzureOpenAIEmbeddings to work around its parameter order issue
|
||||||
|
# This is a temporary workaround until the upstream chonkie library is fixed
|
||||||
|
class FixedAzureOpenAIEmbeddings(AzureOpenAIEmbeddings):
    """AzureOpenAIEmbeddings variant whose first parameter is ``model``.

    The constructor accepts ``model`` ahead of ``azure_endpoint`` and then
    forwards every argument to the parent class by keyword, so the parent's
    own positional order no longer matters to callers.

    NOTE(review): presumably the registry instantiates providers with the
    model name as the first positional argument, which is why ``model`` has
    to come first here -- confirm against the chonkie registry code.
    """

    def __init__(
        self,
        model: str = "text-embedding-3-small",
        azure_endpoint: str | None = None,
        tokenizer: Any | None = None,
        dimension: int | None = None,
        azure_api_key: str | None = None,
        api_version: str = "2024-10-21",
        deployment: str | None = None,
        max_retries: int = 3,
        timeout: float = 60.0,
        batch_size: int = 128,
        **kwargs: Any,
    ):
        """Initialize with model as first parameter to avoid conflicts.

        Args:
            model: Embedding model name.
            azure_endpoint: Azure OpenAI endpoint. When falsy, the value of
                the ``AZURE_OPENAI_ENDPOINT`` environment variable is used,
                or an empty string if that variable is unset.
            tokenizer: Forwarded unchanged to the parent class.
            dimension: Forwarded unchanged to the parent class.
            azure_api_key: Forwarded unchanged to the parent class.
            api_version: Forwarded unchanged to the parent class.
            deployment: Forwarded unchanged to the parent class.
            max_retries: Forwarded unchanged to the parent class.
            timeout: Forwarded unchanged to the parent class.
            batch_size: Forwarded unchanged to the parent class.
            **kwargs: Any extra keyword arguments, forwarded unchanged to
                the parent class.
        """
        # Resolve the endpoint here so the parent always receives a string,
        # even when the caller passed None.
        resolved_endpoint = azure_endpoint or os.getenv("AZURE_OPENAI_ENDPOINT", "")

        # Everything is forwarded by keyword, which makes the call
        # independent of the parent's positional parameter order.
        super().__init__(
            azure_endpoint=resolved_endpoint,
            model=model,
            tokenizer=tokenizer,
            dimension=dimension,
            azure_api_key=azure_api_key,
            api_version=api_version,
            deployment=deployment,
            max_retries=max_retries,
            timeout=timeout,
            batch_size=batch_size,
            **kwargs,
        )
|
||||||
|
|
||||||
|
|
||||||
# TODO: Fix this in chonkie upstream
|
# TODO: Fix this in chonkie upstream
|
||||||
# Register Azure OpenAI embeddings with pattern
|
# Register our fixed Azure OpenAI embeddings with pattern
|
||||||
# This automatically infers the following arguments from their corresponding environment variables if they are not provided:
|
# This automatically infers the following arguments from their corresponding environment variables if they are not provided:
|
||||||
# - `api_key` from `AZURE_OPENAI_API_KEY`
|
# - `api_key` from `AZURE_OPENAI_API_KEY`
|
||||||
# - `organization` from `OPENAI_ORG_ID`
|
# - `organization` from `OPENAI_ORG_ID`
|
||||||
|
|
@ -17,11 +56,11 @@ from rerankers import Reranker
|
||||||
# - `azure_ad_token` from `AZURE_OPENAI_AD_TOKEN`
|
# - `azure_ad_token` from `AZURE_OPENAI_AD_TOKEN`
|
||||||
# - `api_version` from `OPENAI_API_VERSION`
|
# - `api_version` from `OPENAI_API_VERSION`
|
||||||
# - `azure_endpoint` from `AZURE_OPENAI_ENDPOINT`
|
# - `azure_endpoint` from `AZURE_OPENAI_ENDPOINT`
|
||||||
EmbeddingsRegistry.register_provider("azure_openai", AzureOpenAIEmbeddings)
|
EmbeddingsRegistry.register_provider("azure_openai", FixedAzureOpenAIEmbeddings)
|
||||||
EmbeddingsRegistry.register_pattern(r"^text-embedding-", AzureOpenAIEmbeddings)
|
EmbeddingsRegistry.register_pattern(r"^text-embedding-", FixedAzureOpenAIEmbeddings)
|
||||||
EmbeddingsRegistry.register_model("text-embedding-ada-002", AzureOpenAIEmbeddings)
|
EmbeddingsRegistry.register_model("text-embedding-ada-002", FixedAzureOpenAIEmbeddings)
|
||||||
EmbeddingsRegistry.register_model("text-embedding-3-small", AzureOpenAIEmbeddings)
|
EmbeddingsRegistry.register_model("text-embedding-3-small", FixedAzureOpenAIEmbeddings)
|
||||||
EmbeddingsRegistry.register_model("text-embedding-3-large", AzureOpenAIEmbeddings)
|
EmbeddingsRegistry.register_model("text-embedding-3-large", FixedAzureOpenAIEmbeddings)
|
||||||
|
|
||||||
|
|
||||||
# Get the base directory of the project
|
# Get the base directory of the project
|
||||||
|
|
@ -83,7 +122,21 @@ class Config:
|
||||||
|
|
||||||
# Chonkie Configuration | Edit this to your needs
|
# Chonkie Configuration | Edit this to your needs
|
||||||
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
|
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
|
||||||
embedding_model_instance = AutoEmbeddings.get_embeddings(EMBEDDING_MODEL)
|
# Azure OpenAI credentials from environment variables
|
||||||
|
AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT")
|
||||||
|
AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY")
|
||||||
|
|
||||||
|
# Pass Azure credentials to embeddings when using Azure OpenAI
|
||||||
|
embedding_kwargs = {}
|
||||||
|
if AZURE_OPENAI_ENDPOINT:
|
||||||
|
embedding_kwargs["azure_endpoint"] = AZURE_OPENAI_ENDPOINT
|
||||||
|
if AZURE_OPENAI_API_KEY:
|
||||||
|
embedding_kwargs["azure_api_key"] = AZURE_OPENAI_API_KEY
|
||||||
|
|
||||||
|
embedding_model_instance = AutoEmbeddings.get_embeddings(
|
||||||
|
EMBEDDING_MODEL,
|
||||||
|
**embedding_kwargs,
|
||||||
|
)
|
||||||
chunker_instance = RecursiveChunker(
|
chunker_instance = RecursiveChunker(
|
||||||
chunk_size=getattr(embedding_model_instance, "max_seq_length", 512)
|
chunk_size=getattr(embedding_model_instance, "max_seq_length", 512)
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -98,7 +98,7 @@ class PageLimitService:
|
||||||
|
|
||||||
# Get user
|
# Get user
|
||||||
result = await self.session.execute(select(User).where(User.id == user_id))
|
result = await self.session.execute(select(User).where(User.id == user_id))
|
||||||
user = result.scalar_one_or_none()
|
user = result.unique().scalar_one_or_none()
|
||||||
|
|
||||||
if not user:
|
if not user:
|
||||||
raise ValueError(f"User with ID {user_id} not found")
|
raise ValueError(f"User with ID {user_id} not found")
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,7 @@ backend_pid=$!
|
||||||
sleep 5
|
sleep 5
|
||||||
|
|
||||||
echo "Starting Celery Worker..."
|
echo "Starting Celery Worker..."
|
||||||
celery -A app.celery_app worker --loglevel=info --concurrency=1 --pool=solo &
|
celery -A app.celery_app worker --loglevel=info &
|
||||||
celery_worker_pid=$!
|
celery_worker_pid=$!
|
||||||
|
|
||||||
# Wait a bit for worker to initialize
|
# Wait a bit for worker to initialize
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue