mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 00:36:31 +02:00
feat: enhance Azure OpenAI embeddings and add voice support for Azure provider
- Introduced a fixed parameter order for AzureOpenAIEmbeddings to resolve compatibility issues. - Updated the voice selection logic to include Azure voices in the podcaster utility. - Modified the page limit service to use a more efficient method for retrieving users.
This commit is contained in:
parent
34353078fe
commit
a2fb9faad6
4 changed files with 73 additions and 9 deletions
|
|
@ -63,6 +63,17 @@ def get_voice_for_provider(provider: str, speaker_id: int) -> dict | str:
|
|||
},
|
||||
}
|
||||
return vertex_voices.get(speaker_id, vertex_voices[0])
|
||||
elif provider_type == "azure":
|
||||
# OpenAI voice mapping - simple string values
|
||||
azure_voices = {
|
||||
0: "alloy", # Default/intro voice
|
||||
1: "echo", # First speaker
|
||||
2: "fable", # Second speaker
|
||||
3: "onyx", # Third speaker
|
||||
4: "nova", # Fourth speaker
|
||||
5: "shimmer", # Fifth speaker
|
||||
}
|
||||
return azure_voices.get(speaker_id, "alloy")
|
||||
|
||||
else:
|
||||
# Default fallback to OpenAI format for unknown providers
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from chonkie import AutoEmbeddings, CodeChunker, RecursiveChunker
|
||||
from chonkie.embeddings.azure_openai import AzureOpenAIEmbeddings
|
||||
|
|
@ -8,8 +9,46 @@ from chonkie.embeddings.registry import EmbeddingsRegistry
|
|||
from dotenv import load_dotenv
|
||||
from rerankers import Reranker
|
||||
|
||||
|
||||
# Monkey patch AzureOpenAIEmbeddings to fix parameter order issue
|
||||
# This is a temporary workaround until the upstream chonkie library is fixed
|
||||
class FixedAzureOpenAIEmbeddings(AzureOpenAIEmbeddings):
|
||||
"""Wrapper around AzureOpenAIEmbeddings with fixed parameter order."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "text-embedding-3-small",
|
||||
azure_endpoint: str | None = None,
|
||||
tokenizer: Any | None = None,
|
||||
dimension: int | None = None,
|
||||
azure_api_key: str | None = None,
|
||||
api_version: str = "2024-10-21",
|
||||
deployment: str | None = None,
|
||||
max_retries: int = 3,
|
||||
timeout: float = 60.0,
|
||||
batch_size: int = 128,
|
||||
**kwargs: dict[str, Any],
|
||||
):
|
||||
"""Initialize with model as first parameter to avoid conflicts."""
|
||||
# Call parent's __init__ by explicitly passing azure_endpoint as first arg
|
||||
# to maintain compatibility with the original signature
|
||||
super().__init__(
|
||||
azure_endpoint=azure_endpoint or os.getenv("AZURE_OPENAI_ENDPOINT", ""),
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
dimension=dimension,
|
||||
azure_api_key=azure_api_key,
|
||||
api_version=api_version,
|
||||
deployment=deployment,
|
||||
max_retries=max_retries,
|
||||
timeout=timeout,
|
||||
batch_size=batch_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
# TODO: Fix this in chonkie upstream
|
||||
# Register Azure OpenAI embeddings with pattern
|
||||
# Register our fixed Azure OpenAI embeddings with pattern
|
||||
# This automatically infers the following arguments from their corresponding environment variables if they are not provided:
|
||||
# - `api_key` from `AZURE_OPENAI_API_KEY`
|
||||
# - `organization` from `OPENAI_ORG_ID`
|
||||
|
|
@ -17,11 +56,11 @@ from rerankers import Reranker
|
|||
# - `azure_ad_token` from `AZURE_OPENAI_AD_TOKEN`
|
||||
# - `api_version` from `OPENAI_API_VERSION`
|
||||
# - `azure_endpoint` from `AZURE_OPENAI_ENDPOINT`
|
||||
EmbeddingsRegistry.register_provider("azure_openai", AzureOpenAIEmbeddings)
|
||||
EmbeddingsRegistry.register_pattern(r"^text-embedding-", AzureOpenAIEmbeddings)
|
||||
EmbeddingsRegistry.register_model("text-embedding-ada-002", AzureOpenAIEmbeddings)
|
||||
EmbeddingsRegistry.register_model("text-embedding-3-small", AzureOpenAIEmbeddings)
|
||||
EmbeddingsRegistry.register_model("text-embedding-3-large", AzureOpenAIEmbeddings)
|
||||
EmbeddingsRegistry.register_provider("azure_openai", FixedAzureOpenAIEmbeddings)
|
||||
EmbeddingsRegistry.register_pattern(r"^text-embedding-", FixedAzureOpenAIEmbeddings)
|
||||
EmbeddingsRegistry.register_model("text-embedding-ada-002", FixedAzureOpenAIEmbeddings)
|
||||
EmbeddingsRegistry.register_model("text-embedding-3-small", FixedAzureOpenAIEmbeddings)
|
||||
EmbeddingsRegistry.register_model("text-embedding-3-large", FixedAzureOpenAIEmbeddings)
|
||||
|
||||
|
||||
# Get the base directory of the project
|
||||
|
|
@ -83,7 +122,21 @@ class Config:
|
|||
|
||||
# Chonkie Configuration | Edit this to your needs
|
||||
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
|
||||
embedding_model_instance = AutoEmbeddings.get_embeddings(EMBEDDING_MODEL)
|
||||
# Azure OpenAI credentials from environment variables
|
||||
AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT")
|
||||
AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY")
|
||||
|
||||
# Pass Azure credentials to embeddings when using Azure OpenAI
|
||||
embedding_kwargs = {}
|
||||
if AZURE_OPENAI_ENDPOINT:
|
||||
embedding_kwargs["azure_endpoint"] = AZURE_OPENAI_ENDPOINT
|
||||
if AZURE_OPENAI_API_KEY:
|
||||
embedding_kwargs["azure_api_key"] = AZURE_OPENAI_API_KEY
|
||||
|
||||
embedding_model_instance = AutoEmbeddings.get_embeddings(
|
||||
EMBEDDING_MODEL,
|
||||
**embedding_kwargs,
|
||||
)
|
||||
chunker_instance = RecursiveChunker(
|
||||
chunk_size=getattr(embedding_model_instance, "max_seq_length", 512)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -98,7 +98,7 @@ class PageLimitService:
|
|||
|
||||
# Get user
|
||||
result = await self.session.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
user = result.unique().scalar_one_or_none()
|
||||
|
||||
if not user:
|
||||
raise ValueError(f"User with ID {user_id} not found")
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ backend_pid=$!
|
|||
sleep 5
|
||||
|
||||
echo "Starting Celery Worker..."
|
||||
celery -A app.celery_app worker --loglevel=info --concurrency=1 --pool=solo &
|
||||
celery -A app.celery_app worker --loglevel=info &
|
||||
celery_worker_pid=$!
|
||||
|
||||
# Wait a bit for worker to initialize
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue