mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
Add Sarvam LLM, update Sarvam STT models, expose usage_info on run detail (#351)
* Add Sarvam LLM provider, update Sarvam STT models, expose usage_info on run detail. Depends on pipecat PR dograh-hq/pipecat#43 for STT string language support. Submodule bump will follow after that merges. * test: cover Sarvam STT language mapping; link Sarvam docs --------- Co-authored-by: Sabiha Khan <sabihak89@gmail.com>
This commit is contained in:
parent
13b30dee9e
commit
98d2b24cba
10 changed files with 272 additions and 9 deletions
|
|
@ -32,6 +32,7 @@ from api.services.configuration.resolve import (
|
|||
)
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
from api.services.posthog_client import capture_event
|
||||
from api.services.pricing.run_usage_response import format_public_usage_info
|
||||
from api.services.reports import generate_workflow_report_csv
|
||||
from api.services.storage import storage_fs
|
||||
from api.services.workflow.dto import ReactFlowDTO, sanitize_workflow_definition
|
||||
|
|
@ -1186,6 +1187,7 @@ async def get_workflow_run(
|
|||
}
|
||||
if run.cost_info
|
||||
else None,
|
||||
"usage_info": format_public_usage_info(run.usage_info),
|
||||
"created_at": run.created_at,
|
||||
"definition_id": run.definition_id,
|
||||
"initial_context": run.initial_context,
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ class WorkflowRunResponseSchema(BaseModel):
|
|||
recording_public_url: str | None = None
|
||||
public_access_token: str | None = None
|
||||
cost_info: Dict[str, Any] | None
|
||||
usage_info: Dict[str, Any] | None = None
|
||||
definition_id: int | None # This is for backward compatibility
|
||||
initial_context: dict | None = None
|
||||
gathered_context: dict | None = None
|
||||
|
|
|
|||
|
|
@ -16,6 +16,9 @@ from .google import (
|
|||
)
|
||||
from .sarvam import (
|
||||
SARVAM_LANGUAGES,
|
||||
SARVAM_LLM_MODELS,
|
||||
SARVAM_STT_LANGUAGES_V25,
|
||||
SARVAM_STT_LANGUAGES_V3,
|
||||
SARVAM_STT_MODELS,
|
||||
SARVAM_TTS_MODELS,
|
||||
SARVAM_V2_VOICES,
|
||||
|
|
@ -41,6 +44,9 @@ __all__ = [
|
|||
"GOOGLE_VERTEX_REALTIME_MODELS",
|
||||
"GOOGLE_VERTEX_REALTIME_VOICES",
|
||||
"SARVAM_LANGUAGES",
|
||||
"SARVAM_LLM_MODELS",
|
||||
"SARVAM_STT_LANGUAGES_V25",
|
||||
"SARVAM_STT_LANGUAGES_V3",
|
||||
"SARVAM_STT_MODELS",
|
||||
"SARVAM_TTS_MODELS",
|
||||
"SARVAM_V2_VOICES",
|
||||
|
|
|
|||
|
|
@ -63,4 +63,38 @@ SARVAM_LANGUAGES = (
|
|||
"te-IN",
|
||||
"as-IN",
|
||||
)
|
||||
SARVAM_STT_MODELS = ("saarika:v2.5", "saaras:v2")
|
||||
SARVAM_STT_MODELS = ("saarika:v2.5", "saaras:v3")
|
||||
# saarika:v2.5 language codes (unknown = auto-detect)
|
||||
SARVAM_STT_LANGUAGES_V25 = (
|
||||
"unknown",
|
||||
"hi-IN",
|
||||
"bn-IN",
|
||||
"gu-IN",
|
||||
"kn-IN",
|
||||
"ml-IN",
|
||||
"mr-IN",
|
||||
"od-IN",
|
||||
"pa-IN",
|
||||
"ta-IN",
|
||||
"te-IN",
|
||||
"en-IN",
|
||||
)
|
||||
# saaras:v3 adds these regional languages on top of the v2.5 set. Full list: https://docs.sarvam.ai/api-reference-docs/speech-to-text/transcribe
|
||||
SARVAM_STT_LANGUAGES_V3 = SARVAM_STT_LANGUAGES_V25 + (
|
||||
"as-IN",
|
||||
"ur-IN",
|
||||
"ne-IN",
|
||||
"kok-IN",
|
||||
"ks-IN",
|
||||
"sd-IN",
|
||||
"sa-IN",
|
||||
"sat-IN",
|
||||
"mni-IN",
|
||||
"brx-IN",
|
||||
"mai-IN",
|
||||
"doi-IN",
|
||||
)
|
||||
SARVAM_LLM_MODELS = (
|
||||
"sarvam-30b",
|
||||
"sarvam-105b",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -22,6 +22,9 @@ from api.services.configuration.options import (
|
|||
GOOGLE_VERTEX_REALTIME_MODELS,
|
||||
GOOGLE_VERTEX_REALTIME_VOICES,
|
||||
SARVAM_LANGUAGES,
|
||||
SARVAM_LLM_MODELS,
|
||||
SARVAM_STT_LANGUAGES_V25,
|
||||
SARVAM_STT_LANGUAGES_V3,
|
||||
SARVAM_STT_MODELS,
|
||||
SARVAM_TTS_MODELS,
|
||||
SARVAM_V2_VOICES,
|
||||
|
|
@ -89,7 +92,7 @@ class BaseServiceConfiguration(BaseModel):
|
|||
ServiceProviders.ULTRAVOX_REALTIME,
|
||||
ServiceProviders.GOOGLE_REALTIME,
|
||||
ServiceProviders.GOOGLE_VERTEX_REALTIME,
|
||||
# ServiceProviders.SARVAM,
|
||||
ServiceProviders.SARVAM,
|
||||
]
|
||||
api_key: str | list[str]
|
||||
|
||||
|
|
@ -472,6 +475,29 @@ class MiniMaxLLMConfiguration(BaseLLMConfiguration):
|
|||
)
|
||||
|
||||
|
||||
@register_llm
|
||||
class SarvamLLMConfiguration(BaseLLMConfiguration):
|
||||
model_config = SARVAM_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.SARVAM] = ServiceProviders.SARVAM
|
||||
model: str = Field(
|
||||
default="sarvam-30b",
|
||||
description=(
|
||||
"Sarvam chat model. Use sarvam-30b for low-latency voice agents; "
|
||||
"sarvam-105b for complex multi-step reasoning."
|
||||
),
|
||||
json_schema_extra={"examples": SARVAM_LLM_MODELS, "allow_custom_input": True},
|
||||
)
|
||||
temperature: float = Field(
|
||||
default=0.5,
|
||||
ge=0.0,
|
||||
le=2.0,
|
||||
description=(
|
||||
"Sampling temperature. Sarvam recommends 0.5 for balanced "
|
||||
"conversational responses."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
OPENAI_REALTIME_MODELS = ["gpt-realtime-2"]
|
||||
OPENAI_REALTIME_VOICES = [
|
||||
"alloy",
|
||||
|
|
@ -661,6 +687,7 @@ LLMConfig = Annotated[
|
|||
AWSBedrockLLMConfiguration,
|
||||
SpeachesLLMConfiguration,
|
||||
MiniMaxLLMConfiguration,
|
||||
SarvamLLMConfiguration,
|
||||
],
|
||||
Field(discriminator="provider"),
|
||||
]
|
||||
|
|
@ -1129,13 +1156,24 @@ class SarvamSTTConfiguration(BaseSTTConfiguration):
|
|||
provider: Literal[ServiceProviders.SARVAM] = ServiceProviders.SARVAM
|
||||
model: str = Field(
|
||||
default="saarika:v2.5",
|
||||
description="Sarvam STT model.",
|
||||
description=(
|
||||
"Sarvam STT model. saarika:v2.5 transcribes in the spoken language; "
|
||||
"saaras:v3 is the recommended model with flexible output modes."
|
||||
),
|
||||
json_schema_extra={"examples": SARVAM_STT_MODELS},
|
||||
)
|
||||
language: str = Field(
|
||||
default="hi-IN",
|
||||
description="BCP-47 Indian-language code.",
|
||||
json_schema_extra={"examples": SARVAM_LANGUAGES},
|
||||
default="unknown",
|
||||
description=(
|
||||
"BCP-47 language code. Use unknown for automatic language detection."
|
||||
),
|
||||
json_schema_extra={
|
||||
"examples": SARVAM_STT_LANGUAGES_V25,
|
||||
"model_options": {
|
||||
"saarika:v2.5": SARVAM_STT_LANGUAGES_V25,
|
||||
"saaras:v3": SARVAM_STT_LANGUAGES_V3,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -47,6 +47,7 @@ from pipecat.services.openai.stt import (
|
|||
from pipecat.services.openai.tts import OpenAITTSService, OpenAITTSSettings
|
||||
from pipecat.services.openrouter.llm import OpenRouterLLMService, OpenRouterLLMSettings
|
||||
from pipecat.services.rime.tts import RimeTTSService, RimeTTSSettings
|
||||
from pipecat.services.sarvam.llm import SarvamLLMService, SarvamLLMSettings
|
||||
from pipecat.services.sarvam.stt import SarvamSTTService, SarvamSTTSettings
|
||||
from pipecat.services.sarvam.tts import SarvamTTSService, SarvamTTSSettings
|
||||
from pipecat.services.speaches.llm import SpeachesLLMService, SpeachesLLMSettings
|
||||
|
|
@ -158,7 +159,7 @@ def create_stt_service(
|
|||
sample_rate=audio_config.transport_in_sample_rate,
|
||||
)
|
||||
elif user_config.stt.provider == ServiceProviders.SARVAM.value:
|
||||
# Map Sarvam language code to pipecat Language enum
|
||||
language = getattr(user_config.stt, "language", None)
|
||||
language_mapping = {
|
||||
"bn-IN": Language.BN_IN,
|
||||
"gu-IN": Language.GU_IN,
|
||||
|
|
@ -172,9 +173,18 @@ def create_stt_service(
|
|||
"od-IN": Language.OR_IN,
|
||||
"en-IN": Language.EN_IN,
|
||||
"as-IN": Language.AS_IN,
|
||||
"ur-IN": Language.UR_IN,
|
||||
"kok-IN": Language.KOK_IN,
|
||||
"mai-IN": Language.MAI_IN,
|
||||
"sd-IN": Language.SD_IN,
|
||||
}
|
||||
language = getattr(user_config.stt, "language", None)
|
||||
pipecat_language = language_mapping.get(language, Language.HI_IN)
|
||||
if not language or language == "unknown":
|
||||
pipecat_language = None
|
||||
elif language in language_mapping:
|
||||
pipecat_language = language_mapping[language]
|
||||
else:
|
||||
# Unmapped BCP-47 codes pass through; Sarvam accepts them per https://docs.sarvam.ai/api-reference-docs/speech-to-text/transcribe
|
||||
pipecat_language = language
|
||||
return SarvamSTTService(
|
||||
api_key=user_config.stt.api_key,
|
||||
settings=SarvamSTTSettings(
|
||||
|
|
@ -604,6 +614,14 @@ def create_llm_service_from_provider(
|
|||
temperature=temperature if temperature is not None else 1.0,
|
||||
),
|
||||
)
|
||||
elif provider == ServiceProviders.SARVAM.value:
|
||||
return SarvamLLMService(
|
||||
api_key=api_key,
|
||||
settings=SarvamLLMSettings(
|
||||
model=model,
|
||||
temperature=temperature if temperature is not None else 0.5,
|
||||
),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid LLM provider {provider}")
|
||||
|
||||
|
|
@ -756,5 +774,7 @@ def create_llm_service(user_config):
|
|||
elif provider == ServiceProviders.MINIMAX.value:
|
||||
kwargs["base_url"] = user_config.llm.base_url
|
||||
kwargs["temperature"] = user_config.llm.temperature
|
||||
elif provider == ServiceProviders.SARVAM.value:
|
||||
kwargs["temperature"] = user_config.llm.temperature
|
||||
|
||||
return create_llm_service_from_provider(provider, model, api_key, **kwargs)
|
||||
|
|
|
|||
13
api/services/pricing/run_usage_response.py
Normal file
13
api/services/pricing/run_usage_response.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
"""Format workflow run usage for public API responses."""
|
||||
|
||||
|
||||
def format_public_usage_info(usage_info: dict | None) -> dict | None:
|
||||
if not usage_info:
|
||||
return None
|
||||
|
||||
return {
|
||||
"llm": usage_info.get("llm") or {},
|
||||
"tts": usage_info.get("tts") or {},
|
||||
"stt": usage_info.get("stt") or {},
|
||||
"call_duration_seconds": usage_info.get("call_duration_seconds"),
|
||||
}
|
||||
23
api/tests/test_run_usage_response.py
Normal file
23
api/tests/test_run_usage_response.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
from api.services.pricing.run_usage_response import format_public_usage_info
|
||||
|
||||
|
||||
def test_format_public_usage_info():
|
||||
usage_info = {
|
||||
"llm": {
|
||||
"SarvamLLMService#0|||sarvam-30b": {
|
||||
"prompt_tokens": 100,
|
||||
"completion_tokens": 50,
|
||||
"total_tokens": 150,
|
||||
}
|
||||
},
|
||||
"tts": {"ElevenLabsTTSService#0|||eleven_flash_v2_5": 42},
|
||||
"stt": {},
|
||||
"call_duration_seconds": 12.4,
|
||||
}
|
||||
|
||||
result = format_public_usage_info(usage_info)
|
||||
|
||||
assert result["llm"]["SarvamLLMService#0|||sarvam-30b"]["prompt_tokens"] == 100
|
||||
assert result["tts"]["ElevenLabsTTSService#0|||eleven_flash_v2_5"] == 42
|
||||
assert result["stt"] == {}
|
||||
assert result["call_duration_seconds"] == 12.4
|
||||
114
api/tests/test_sarvam_service_factory.py
Normal file
114
api/tests/test_sarvam_service_factory.py
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from pipecat.services.sarvam.llm import SarvamLLMService as RealSarvamLLMService
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
from api.services.configuration.registry import (
|
||||
SarvamLLMConfiguration,
|
||||
ServiceProviders,
|
||||
)
|
||||
from api.services.pipecat.audio_config import AudioConfig
|
||||
from api.services.pipecat.service_factory import (
|
||||
create_llm_service,
|
||||
create_llm_service_from_provider,
|
||||
create_stt_service,
|
||||
)
|
||||
|
||||
|
||||
class TestSarvamLLMConfiguration:
|
||||
def test_default_values(self):
|
||||
config = SarvamLLMConfiguration(api_key="test-key")
|
||||
assert config.provider == ServiceProviders.SARVAM
|
||||
assert config.model == "sarvam-30b"
|
||||
assert config.temperature == 0.5
|
||||
|
||||
def test_custom_model(self):
|
||||
config = SarvamLLMConfiguration(api_key="test-key", model="sarvam-105b")
|
||||
assert config.model == "sarvam-105b"
|
||||
|
||||
|
||||
class TestSarvamLLMServiceFactory:
|
||||
def test_create_sarvam_llm_service(self):
|
||||
with patch(
|
||||
"api.services.pipecat.service_factory.SarvamLLMService"
|
||||
) as mock_service:
|
||||
mock_service.Settings = RealSarvamLLMService.Settings
|
||||
create_llm_service_from_provider(
|
||||
provider=ServiceProviders.SARVAM.value,
|
||||
model="sarvam-30b",
|
||||
api_key="test-key",
|
||||
)
|
||||
|
||||
assert mock_service.call_count == 1
|
||||
kwargs = mock_service.call_args.kwargs
|
||||
assert kwargs["api_key"] == "test-key"
|
||||
assert kwargs["settings"].model == "sarvam-30b"
|
||||
assert kwargs["settings"].temperature == 0.5
|
||||
|
||||
def test_create_sarvam_llm_service_passes_user_temperature(self):
|
||||
with patch(
|
||||
"api.services.pipecat.service_factory.SarvamLLMService"
|
||||
) as mock_service:
|
||||
mock_service.Settings = RealSarvamLLMService.Settings
|
||||
create_llm_service_from_provider(
|
||||
provider=ServiceProviders.SARVAM.value,
|
||||
model="sarvam-30b",
|
||||
api_key="test-key",
|
||||
temperature=0.8,
|
||||
)
|
||||
|
||||
kwargs = mock_service.call_args.kwargs
|
||||
assert kwargs["settings"].temperature == 0.8
|
||||
|
||||
def test_create_llm_service_extracts_sarvam_temperature(self):
|
||||
user_config = SimpleNamespace(
|
||||
llm=SimpleNamespace(
|
||||
provider=ServiceProviders.SARVAM.value,
|
||||
model="sarvam-30b",
|
||||
api_key="test-key",
|
||||
temperature=0.7,
|
||||
)
|
||||
)
|
||||
|
||||
with patch(
|
||||
"api.services.pipecat.service_factory.SarvamLLMService"
|
||||
) as mock_service:
|
||||
mock_service.Settings = RealSarvamLLMService.Settings
|
||||
create_llm_service(user_config)
|
||||
|
||||
kwargs = mock_service.call_args.kwargs
|
||||
assert kwargs["settings"].temperature == 0.7
|
||||
|
||||
|
||||
class TestSarvamSTTServiceFactory:
|
||||
@pytest.mark.parametrize(
|
||||
"input_language,expected_language",
|
||||
[
|
||||
("unknown", None),
|
||||
(None, None),
|
||||
("hi-IN", Language.HI_IN),
|
||||
("ne-IN", "ne-IN"),
|
||||
],
|
||||
)
|
||||
def test_stt_language_mapping(self, input_language, expected_language):
|
||||
user_config = SimpleNamespace(
|
||||
stt=SimpleNamespace(
|
||||
provider=ServiceProviders.SARVAM.value,
|
||||
model="saaras:v3",
|
||||
api_key="test-key",
|
||||
language=input_language,
|
||||
)
|
||||
)
|
||||
audio_config = AudioConfig(
|
||||
transport_in_sample_rate=16000, transport_out_sample_rate=16000
|
||||
)
|
||||
|
||||
with patch(
|
||||
"api.services.pipecat.service_factory.SarvamSTTService"
|
||||
) as mock_service:
|
||||
create_stt_service(user_config, audio_config)
|
||||
|
||||
kwargs = mock_service.call_args.kwargs
|
||||
assert kwargs["settings"].language == expected_language
|
||||
|
|
@ -84,6 +84,7 @@ export const LANGUAGE_DISPLAY_NAMES: Record<string, string> = {
|
|||
"zh-CN": "Chinese (Simplified)",
|
||||
"zh-TW": "Chinese (Traditional)",
|
||||
// Sarvam Indian languages
|
||||
"unknown": "Auto-detect",
|
||||
"bn-IN": "Bengali",
|
||||
"gu-IN": "Gujarati",
|
||||
"hi-IN": "Hindi",
|
||||
|
|
@ -95,4 +96,15 @@ export const LANGUAGE_DISPLAY_NAMES: Record<string, string> = {
|
|||
"ta-IN": "Tamil",
|
||||
"te-IN": "Telugu",
|
||||
"as-IN": "Assamese",
|
||||
"ur-IN": "Urdu",
|
||||
"ne-IN": "Nepali",
|
||||
"kok-IN": "Konkani",
|
||||
"ks-IN": "Kashmiri",
|
||||
"sd-IN": "Sindhi",
|
||||
"sa-IN": "Sanskrit",
|
||||
"sat-IN": "Santali",
|
||||
"mni-IN": "Manipuri",
|
||||
"brx-IN": "Bodo",
|
||||
"mai-IN": "Maithili",
|
||||
"doi-IN": "Dogri",
|
||||
};
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue