mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
feat: knowledge base functionality for the voice agent (#120)
* feat: upload file and store embedding * feat: add documents in nodes * feat: add openai embedding service
This commit is contained in:
parent
e2fa4bbb98
commit
ef5b9e40a9
52 changed files with 4551 additions and 114 deletions
|
|
@ -48,6 +48,12 @@ class UserConfigurationValidator:
|
|||
status_list.extend(self._validate_service(configuration.llm, "llm"))
|
||||
status_list.extend(self._validate_service(configuration.stt, "stt"))
|
||||
status_list.extend(self._validate_service(configuration.tts, "tts"))
|
||||
# Embeddings is optional - only validate if configured
|
||||
status_list.extend(
|
||||
self._validate_service(
|
||||
configuration.embeddings, "embeddings", required=False
|
||||
)
|
||||
)
|
||||
|
||||
if status_list:
|
||||
raise ValueError(status_list)
|
||||
|
|
@ -55,11 +61,16 @@ class UserConfigurationValidator:
|
|||
return {"status": [{"model": "all", "message": "ok"}]}
|
||||
|
||||
def _validate_service(
|
||||
self, service_config: Optional[ServiceConfig], service_name: str
|
||||
self,
|
||||
service_config: Optional[ServiceConfig],
|
||||
service_name: str,
|
||||
required: bool = True,
|
||||
) -> list[APIKeyStatus]:
|
||||
"""Validate a service configuration and return any error statuses."""
|
||||
if not service_config:
|
||||
return [{"model": service_name, "message": "API key is missing"}]
|
||||
if required:
|
||||
return [{"model": service_name, "message": "API key is missing"}]
|
||||
return [] # Optional service not configured is OK
|
||||
|
||||
provider = service_config.provider
|
||||
api_key = service_config.api_key
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ left as ``None``.
|
|||
from api.services.configuration.registry import (
|
||||
DeepgramSTTConfiguration,
|
||||
ElevenlabsTTSConfiguration,
|
||||
OpenAIEmbeddingsConfiguration,
|
||||
OpenAILLMService,
|
||||
ServiceProviders,
|
||||
)
|
||||
|
|
@ -22,6 +23,7 @@ _DEFAULTS = {
|
|||
"llm": (ServiceProviders.OPENAI, OpenAILLMService),
|
||||
"tts": (ServiceProviders.ELEVENLABS, ElevenlabsTTSConfiguration),
|
||||
"stt": (ServiceProviders.DEEPGRAM, DeepgramSTTConfiguration),
|
||||
"embeddings": (ServiceProviders.OPENAI, OpenAIEmbeddingsConfiguration),
|
||||
}
|
||||
|
||||
# Public mapping of service name -> default provider
|
||||
|
|
|
|||
|
|
@ -64,6 +64,7 @@ def mask_user_config(config: UserConfiguration) -> Dict[str, Any]:
|
|||
"llm": _mask_service(config.llm),
|
||||
"tts": _mask_service(config.tts),
|
||||
"stt": _mask_service(config.stt),
|
||||
"embeddings": _mask_service(config.embeddings),
|
||||
"test_phone_number": config.test_phone_number,
|
||||
"timezone": config.timezone,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from typing import Dict
|
|||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.services.configuration.masking import is_mask_of
|
||||
|
||||
SERVICE_FIELDS = ("llm", "tts", "stt")
|
||||
SERVICE_FIELDS = ("llm", "tts", "stt", "embeddings")
|
||||
|
||||
|
||||
def merge_user_configurations(
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ class ServiceType(Enum):
|
|||
LLM = auto()
|
||||
TTS = auto()
|
||||
STT = auto()
|
||||
EMBEDDINGS = auto()
|
||||
|
||||
|
||||
class ServiceProviders(str, Enum):
|
||||
|
|
@ -50,11 +51,16 @@ class BaseSTTConfiguration(BaseServiceConfiguration):
|
|||
model: str
|
||||
|
||||
|
||||
class BaseEmbeddingsConfiguration(BaseServiceConfiguration):
|
||||
model: str
|
||||
|
||||
|
||||
# Unified registry for all service types
|
||||
REGISTRY: Dict[ServiceType, Dict[str, Type[BaseServiceConfiguration]]] = {
|
||||
ServiceType.LLM: {},
|
||||
ServiceType.TTS: {},
|
||||
ServiceType.STT: {},
|
||||
ServiceType.EMBEDDINGS: {},
|
||||
}
|
||||
|
||||
T = TypeVar("T", bound=BaseServiceConfiguration)
|
||||
|
|
@ -93,6 +99,10 @@ def register_stt(cls: Type[BaseSTTConfiguration]):
|
|||
return register_service(ServiceType.STT)(cls)
|
||||
|
||||
|
||||
def register_embeddings(cls: Type[BaseEmbeddingsConfiguration]):
|
||||
return register_service(ServiceType.EMBEDDINGS)(cls)
|
||||
|
||||
|
||||
###################################################### LLM ########################################################################
|
||||
|
||||
# Suggested models for each provider (used for UI dropdown)
|
||||
|
|
@ -436,6 +446,27 @@ STTConfig = Annotated[
|
|||
Field(discriminator="provider"),
|
||||
]
|
||||
|
||||
ServiceConfig = Annotated[
|
||||
Union[LLMConfig, TTSConfig, STTConfig], Field(discriminator="provider")
|
||||
###################################################### EMBEDDINGS ########################################################################
|
||||
|
||||
OPENAI_EMBEDDING_MODELS = ["text-embedding-3-small"]
|
||||
|
||||
|
||||
@register_embeddings
|
||||
class OpenAIEmbeddingsConfiguration(BaseEmbeddingsConfiguration):
|
||||
provider: Literal[ServiceProviders.OPENAI] = ServiceProviders.OPENAI
|
||||
model: str = Field(
|
||||
default="text-embedding-3-small",
|
||||
json_schema_extra={"examples": OPENAI_EMBEDDING_MODELS},
|
||||
)
|
||||
api_key: str
|
||||
|
||||
|
||||
EmbeddingsConfig = Annotated[
|
||||
Union[OpenAIEmbeddingsConfiguration],
|
||||
Field(discriminator="provider"),
|
||||
]
|
||||
|
||||
ServiceConfig = Annotated[
|
||||
Union[LLMConfig, TTSConfig, STTConfig, EmbeddingsConfig],
|
||||
Field(discriminator="provider"),
|
||||
]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue