mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-28 08:49:42 +02:00
Add Sarvam LLM, update Sarvam STT models, expose usage_info on run detail (#351)
* Add Sarvam LLM provider, update Sarvam STT models, expose usage_info on run detail. Depends on pipecat PR dograh-hq/pipecat#43 for STT string language support. Submodule bump will follow after that merges. * test: cover Sarvam STT language mapping; link Sarvam docs --------- Co-authored-by: Sabiha Khan <sabihak89@gmail.com>
This commit is contained in:
parent
13b30dee9e
commit
98d2b24cba
10 changed files with 272 additions and 9 deletions
23
api/tests/test_run_usage_response.py
Normal file
23
api/tests/test_run_usage_response.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
from api.services.pricing.run_usage_response import format_public_usage_info
|
||||
|
||||
|
||||
def test_format_public_usage_info():
|
||||
usage_info = {
|
||||
"llm": {
|
||||
"SarvamLLMService#0|||sarvam-30b": {
|
||||
"prompt_tokens": 100,
|
||||
"completion_tokens": 50,
|
||||
"total_tokens": 150,
|
||||
}
|
||||
},
|
||||
"tts": {"ElevenLabsTTSService#0|||eleven_flash_v2_5": 42},
|
||||
"stt": {},
|
||||
"call_duration_seconds": 12.4,
|
||||
}
|
||||
|
||||
result = format_public_usage_info(usage_info)
|
||||
|
||||
assert result["llm"]["SarvamLLMService#0|||sarvam-30b"]["prompt_tokens"] == 100
|
||||
assert result["tts"]["ElevenLabsTTSService#0|||eleven_flash_v2_5"] == 42
|
||||
assert result["stt"] == {}
|
||||
assert result["call_duration_seconds"] == 12.4
|
||||
114
api/tests/test_sarvam_service_factory.py
Normal file
114
api/tests/test_sarvam_service_factory.py
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from pipecat.services.sarvam.llm import SarvamLLMService as RealSarvamLLMService
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
from api.services.configuration.registry import (
|
||||
SarvamLLMConfiguration,
|
||||
ServiceProviders,
|
||||
)
|
||||
from api.services.pipecat.audio_config import AudioConfig
|
||||
from api.services.pipecat.service_factory import (
|
||||
create_llm_service,
|
||||
create_llm_service_from_provider,
|
||||
create_stt_service,
|
||||
)
|
||||
|
||||
|
||||
class TestSarvamLLMConfiguration:
|
||||
def test_default_values(self):
|
||||
config = SarvamLLMConfiguration(api_key="test-key")
|
||||
assert config.provider == ServiceProviders.SARVAM
|
||||
assert config.model == "sarvam-30b"
|
||||
assert config.temperature == 0.5
|
||||
|
||||
def test_custom_model(self):
|
||||
config = SarvamLLMConfiguration(api_key="test-key", model="sarvam-105b")
|
||||
assert config.model == "sarvam-105b"
|
||||
|
||||
|
||||
class TestSarvamLLMServiceFactory:
|
||||
def test_create_sarvam_llm_service(self):
|
||||
with patch(
|
||||
"api.services.pipecat.service_factory.SarvamLLMService"
|
||||
) as mock_service:
|
||||
mock_service.Settings = RealSarvamLLMService.Settings
|
||||
create_llm_service_from_provider(
|
||||
provider=ServiceProviders.SARVAM.value,
|
||||
model="sarvam-30b",
|
||||
api_key="test-key",
|
||||
)
|
||||
|
||||
assert mock_service.call_count == 1
|
||||
kwargs = mock_service.call_args.kwargs
|
||||
assert kwargs["api_key"] == "test-key"
|
||||
assert kwargs["settings"].model == "sarvam-30b"
|
||||
assert kwargs["settings"].temperature == 0.5
|
||||
|
||||
def test_create_sarvam_llm_service_passes_user_temperature(self):
|
||||
with patch(
|
||||
"api.services.pipecat.service_factory.SarvamLLMService"
|
||||
) as mock_service:
|
||||
mock_service.Settings = RealSarvamLLMService.Settings
|
||||
create_llm_service_from_provider(
|
||||
provider=ServiceProviders.SARVAM.value,
|
||||
model="sarvam-30b",
|
||||
api_key="test-key",
|
||||
temperature=0.8,
|
||||
)
|
||||
|
||||
kwargs = mock_service.call_args.kwargs
|
||||
assert kwargs["settings"].temperature == 0.8
|
||||
|
||||
def test_create_llm_service_extracts_sarvam_temperature(self):
|
||||
user_config = SimpleNamespace(
|
||||
llm=SimpleNamespace(
|
||||
provider=ServiceProviders.SARVAM.value,
|
||||
model="sarvam-30b",
|
||||
api_key="test-key",
|
||||
temperature=0.7,
|
||||
)
|
||||
)
|
||||
|
||||
with patch(
|
||||
"api.services.pipecat.service_factory.SarvamLLMService"
|
||||
) as mock_service:
|
||||
mock_service.Settings = RealSarvamLLMService.Settings
|
||||
create_llm_service(user_config)
|
||||
|
||||
kwargs = mock_service.call_args.kwargs
|
||||
assert kwargs["settings"].temperature == 0.7
|
||||
|
||||
|
||||
class TestSarvamSTTServiceFactory:
|
||||
@pytest.mark.parametrize(
|
||||
"input_language,expected_language",
|
||||
[
|
||||
("unknown", None),
|
||||
(None, None),
|
||||
("hi-IN", Language.HI_IN),
|
||||
("ne-IN", "ne-IN"),
|
||||
],
|
||||
)
|
||||
def test_stt_language_mapping(self, input_language, expected_language):
|
||||
user_config = SimpleNamespace(
|
||||
stt=SimpleNamespace(
|
||||
provider=ServiceProviders.SARVAM.value,
|
||||
model="saaras:v3",
|
||||
api_key="test-key",
|
||||
language=input_language,
|
||||
)
|
||||
)
|
||||
audio_config = AudioConfig(
|
||||
transport_in_sample_rate=16000, transport_out_sample_rate=16000
|
||||
)
|
||||
|
||||
with patch(
|
||||
"api.services.pipecat.service_factory.SarvamSTTService"
|
||||
) as mock_service:
|
||||
create_stt_service(user_config, audio_config)
|
||||
|
||||
kwargs = mock_service.call_args.kwargs
|
||||
assert kwargs["settings"].language == expected_language
|
||||
Loading…
Add table
Add a link
Reference in a new issue