mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-16 08:25:18 +02:00
feat: add huggingface inferece provider endpoint
This commit is contained in:
parent
ef266daa6e
commit
dd3f2e7323
7 changed files with 315 additions and 2 deletions
|
|
@ -49,6 +49,7 @@ class UserConfigurationValidator:
|
|||
ServiceProviders.CAMB.value: self._check_camb_api_key,
|
||||
ServiceProviders.AWS_BEDROCK.value: self._check_aws_bedrock_api_key,
|
||||
ServiceProviders.SPEACHES.value: self._check_speaches_api_key,
|
||||
ServiceProviders.HUGGINGFACE.value: self._check_huggingface_api_key,
|
||||
ServiceProviders.GOOGLE_VERTEX.value: self._check_google_vertex_llm_api_key,
|
||||
ServiceProviders.OPENAI_REALTIME.value: self._check_openai_api_key,
|
||||
ServiceProviders.GROK_REALTIME.value: self._check_grok_realtime_api_key,
|
||||
|
|
@ -360,6 +361,14 @@ class UserConfigurationValidator:
|
|||
raise ValueError("base_url is required for Speaches services")
|
||||
return True
|
||||
|
||||
def _check_huggingface_api_key(self, model: str, api_key: str) -> bool:
|
||||
if not api_key.startswith("hf_"):
|
||||
raise ValueError(
|
||||
"Invalid Hugging Face API token format. Use a token that starts with "
|
||||
"'hf_' and has Inference Providers permission."
|
||||
)
|
||||
return True
|
||||
|
||||
def _check_google_vertex_realtime_api_key(self, model: str, service_config) -> bool:
|
||||
if not getattr(service_config, "project_id", None):
|
||||
raise ValueError("project_id is required for Google Vertex Realtime")
|
||||
|
|
|
|||
|
|
@ -68,6 +68,7 @@ class ServiceProviders(str, Enum):
|
|||
CAMB = "camb"
|
||||
AWS_BEDROCK = "aws_bedrock"
|
||||
SPEACHES = "speaches"
|
||||
HUGGINGFACE = "huggingface"
|
||||
ASSEMBLYAI = "assemblyai"
|
||||
GLADIA = "gladia"
|
||||
RIME = "rime"
|
||||
|
|
@ -94,6 +95,7 @@ class BaseServiceConfiguration(BaseModel):
|
|||
ServiceProviders.DOGRAH,
|
||||
ServiceProviders.AWS_BEDROCK,
|
||||
ServiceProviders.SPEACHES,
|
||||
ServiceProviders.HUGGINGFACE,
|
||||
ServiceProviders.ASSEMBLYAI,
|
||||
ServiceProviders.GLADIA,
|
||||
ServiceProviders.RIME,
|
||||
|
|
@ -255,6 +257,11 @@ SPEACHES_PROVIDER_MODEL_CONFIG = provider_model_config(
|
|||
),
|
||||
provider_docs_url="https://github.com/speaches-ai/speaches",
|
||||
)
|
||||
HUGGINGFACE_PROVIDER_MODEL_CONFIG = provider_model_config(
|
||||
"Hugging Face",
|
||||
description="Hosted Hugging Face Inference Providers API for usage-based inference.",
|
||||
provider_docs_url="https://huggingface.co/docs/inference-providers/en/index",
|
||||
)
|
||||
AZURE_SPEECH_PROVIDER_MODEL_CONFIG = provider_model_config(
|
||||
"Azure Speech Services",
|
||||
description="Azure Cognitive Services Speech — TTS and STT via the Azure Speech SDK.",
|
||||
|
|
@ -471,6 +478,35 @@ class SpeachesLLMConfiguration(BaseLLMConfiguration):
|
|||
)
|
||||
|
||||
|
||||
HUGGINGFACE_LLM_MODELS = [
|
||||
"openai/gpt-oss-120b:cerebras",
|
||||
"deepseek-ai/DeepSeek-R1:fastest",
|
||||
"Qwen/Qwen3-Coder-480B-A35B-Instruct:fastest",
|
||||
]
|
||||
|
||||
|
||||
@register_llm
|
||||
class HuggingFaceLLMConfiguration(BaseLLMConfiguration):
|
||||
model_config = HUGGINGFACE_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.HUGGINGFACE] = ServiceProviders.HUGGINGFACE
|
||||
model: str = Field(
|
||||
default="openai/gpt-oss-120b:cerebras",
|
||||
description="Hugging Face chat-completion model identifier, optionally with provider suffix.",
|
||||
json_schema_extra={
|
||||
"examples": HUGGINGFACE_LLM_MODELS,
|
||||
"allow_custom_input": True,
|
||||
},
|
||||
)
|
||||
base_url: str = Field(
|
||||
default="https://router.huggingface.co/v1",
|
||||
description="Hugging Face OpenAI-compatible chat-completions router base URL.",
|
||||
)
|
||||
bill_to: str | None = Field(
|
||||
default=None,
|
||||
description="Optional Hugging Face organization or user to bill using X-HF-Bill-To.",
|
||||
)
|
||||
|
||||
|
||||
MINIMAX_MODELS = [
|
||||
"MiniMax-M2.7",
|
||||
"MiniMax-M2.7-highspeed",
|
||||
|
|
@ -741,6 +777,7 @@ LLMConfig = Annotated[
|
|||
DograhLLMService,
|
||||
AWSBedrockLLMConfiguration,
|
||||
SpeachesLLMConfiguration,
|
||||
HuggingFaceLLMConfiguration,
|
||||
MiniMaxLLMConfiguration,
|
||||
SarvamLLMConfiguration,
|
||||
],
|
||||
|
|
@ -1334,6 +1371,38 @@ class SpeachesSTTConfiguration(BaseSTTConfiguration):
|
|||
)
|
||||
|
||||
|
||||
HUGGINGFACE_STT_MODELS = [
|
||||
"openai/whisper-large-v3-turbo",
|
||||
"openai/whisper-large-v3",
|
||||
]
|
||||
|
||||
|
||||
@register_stt
|
||||
class HuggingFaceSTTConfiguration(BaseSTTConfiguration):
|
||||
model_config = HUGGINGFACE_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.HUGGINGFACE] = ServiceProviders.HUGGINGFACE
|
||||
model: str = Field(
|
||||
default="openai/whisper-large-v3-turbo",
|
||||
description="Hugging Face ASR model identifier served through Inference Providers.",
|
||||
json_schema_extra={
|
||||
"examples": HUGGINGFACE_STT_MODELS,
|
||||
"allow_custom_input": True,
|
||||
},
|
||||
)
|
||||
base_url: str = Field(
|
||||
default="https://router.huggingface.co/hf-inference",
|
||||
description="Hugging Face Inference Providers router base URL.",
|
||||
)
|
||||
bill_to: str | None = Field(
|
||||
default=None,
|
||||
description="Optional Hugging Face organization or user to bill using X-HF-Bill-To.",
|
||||
)
|
||||
return_timestamps: bool = Field(
|
||||
default=False,
|
||||
description="Request timestamp chunks when supported by the selected provider/model.",
|
||||
)
|
||||
|
||||
|
||||
ASSEMBLYAI_STT_MODELS = ["u3-rt-pro"]
|
||||
ASSEMBLYAI_STT_LANGUAGES = ["en", "es", "de", "fr", "pt", "it"]
|
||||
|
||||
|
|
@ -1406,6 +1475,7 @@ STTConfig = Annotated[
|
|||
SpeechmaticsSTTConfiguration,
|
||||
SarvamSTTConfiguration,
|
||||
SpeachesSTTConfiguration,
|
||||
HuggingFaceSTTConfiguration,
|
||||
AssemblyAISTTConfiguration,
|
||||
GladiaSTTConfiguration,
|
||||
AzureSpeechSTTConfiguration,
|
||||
|
|
|
|||
|
|
@ -39,8 +39,17 @@ from pipecat.services.google.vertex.llm import (
|
|||
GoogleVertexLLMSettings,
|
||||
)
|
||||
from pipecat.services.groq.llm import GroqLLMService, GroqLLMSettings
|
||||
from pipecat.services.huggingface.llm import (
|
||||
HuggingFaceLLMService,
|
||||
HuggingFaceLLMSettings,
|
||||
)
|
||||
from pipecat.services.huggingface.stt import (
|
||||
HuggingFaceSTTService,
|
||||
HuggingFaceSTTSettings,
|
||||
)
|
||||
from pipecat.services.minimax.llm import MiniMaxLLMService
|
||||
from pipecat.services.minimax.tts import MiniMaxTTSSettings
|
||||
from pipecat.services.openai._constants import OPENAI_SAMPLE_RATE
|
||||
from pipecat.services.openai.base_llm import OpenAILLMSettings
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.services.openai.stt import (
|
||||
|
|
@ -218,6 +227,22 @@ def create_stt_service(
|
|||
),
|
||||
sample_rate=audio_config.transport_in_sample_rate,
|
||||
)
|
||||
elif user_config.stt.provider == ServiceProviders.HUGGINGFACE.value:
|
||||
base_url = (
|
||||
getattr(user_config.stt, "base_url", None)
|
||||
or "https://router.huggingface.co/hf-inference"
|
||||
)
|
||||
_validate_runtime_service_url(base_url, "base_url")
|
||||
return HuggingFaceSTTService(
|
||||
api_key=user_config.stt.api_key,
|
||||
base_url=base_url,
|
||||
bill_to=getattr(user_config.stt, "bill_to", None),
|
||||
settings=HuggingFaceSTTSettings(
|
||||
model=user_config.stt.model,
|
||||
return_timestamps=getattr(user_config.stt, "return_timestamps", False),
|
||||
),
|
||||
sample_rate=audio_config.transport_in_sample_rate,
|
||||
)
|
||||
elif user_config.stt.provider == ServiceProviders.ASSEMBLYAI.value:
|
||||
language = getattr(user_config.stt, "language", None)
|
||||
settings_kwargs = {"model": user_config.stt.model, "language": language}
|
||||
|
|
@ -320,6 +345,7 @@ def create_tts_service(
|
|||
kwargs["base_url"] = base_url
|
||||
return OpenAITTSService(
|
||||
api_key=user_config.tts.api_key,
|
||||
sample_rate=OPENAI_SAMPLE_RATE,
|
||||
settings=OpenAITTSSettings(model=user_config.tts.model),
|
||||
text_filters=[xml_function_tag_filter],
|
||||
skip_aggregator_types=["recording_router", "recording"],
|
||||
|
|
@ -581,6 +607,7 @@ def create_llm_service_from_provider(
|
|||
location: str | None = None,
|
||||
credentials: str | None = None,
|
||||
temperature: float | None = None,
|
||||
bill_to: str | None = None,
|
||||
):
|
||||
"""Create an LLM service from explicit provider/model/api_key.
|
||||
|
||||
|
|
@ -663,6 +690,15 @@ def create_llm_service_from_provider(
|
|||
api_key=api_key or "none",
|
||||
settings=SpeachesLLMSettings(model=model),
|
||||
)
|
||||
elif provider == ServiceProviders.HUGGINGFACE.value:
|
||||
base_url = base_url or "https://router.huggingface.co/v1"
|
||||
_validate_runtime_service_url(base_url, "base_url")
|
||||
return HuggingFaceLLMService(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
bill_to=bill_to,
|
||||
settings=HuggingFaceLLMSettings(model=model, temperature=0.1),
|
||||
)
|
||||
elif provider == ServiceProviders.MINIMAX.value:
|
||||
base_url = base_url or "https://api.minimax.io/v1"
|
||||
_validate_runtime_service_url(base_url, "base_url")
|
||||
|
|
@ -875,6 +911,9 @@ def create_llm_service(user_config, correlation_id: str | None = None):
|
|||
kwargs["endpoint"] = user_config.llm.endpoint
|
||||
elif provider == ServiceProviders.SPEACHES.value:
|
||||
kwargs["base_url"] = user_config.llm.base_url
|
||||
elif provider == ServiceProviders.HUGGINGFACE.value:
|
||||
kwargs["base_url"] = user_config.llm.base_url
|
||||
kwargs["bill_to"] = user_config.llm.bill_to
|
||||
elif provider == ServiceProviders.AWS_BEDROCK.value:
|
||||
kwargs["aws_access_key"] = user_config.llm.aws_access_key
|
||||
kwargs["aws_secret_key"] = user_config.llm.aws_secret_key
|
||||
|
|
|
|||
131
api/tests/test_huggingface_stt_service_factory.py
Normal file
131
api/tests/test_huggingface_stt_service_factory.py
Normal file
|
|
@ -0,0 +1,131 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from api.services.configuration.check_validity import UserConfigurationValidator
|
||||
from api.services.configuration.registry import (
|
||||
REGISTRY,
|
||||
HuggingFaceLLMConfiguration,
|
||||
HuggingFaceSTTConfiguration,
|
||||
ServiceProviders,
|
||||
ServiceType,
|
||||
)
|
||||
from api.services.pipecat.service_factory import (
|
||||
create_llm_service,
|
||||
create_stt_service,
|
||||
)
|
||||
|
||||
|
||||
def test_huggingface_stt_configuration_defaults_and_registry():
|
||||
config = HuggingFaceSTTConfiguration(api_key="hf_test")
|
||||
|
||||
assert config.provider == ServiceProviders.HUGGINGFACE
|
||||
assert config.model == "openai/whisper-large-v3-turbo"
|
||||
assert config.base_url == "https://router.huggingface.co/hf-inference"
|
||||
assert config.return_timestamps is False
|
||||
assert (
|
||||
REGISTRY[ServiceType.STT][ServiceProviders.HUGGINGFACE]
|
||||
is HuggingFaceSTTConfiguration
|
||||
)
|
||||
|
||||
|
||||
def test_huggingface_llm_configuration_defaults_and_registry():
|
||||
config = HuggingFaceLLMConfiguration(api_key="hf_test")
|
||||
|
||||
assert config.provider == ServiceProviders.HUGGINGFACE
|
||||
assert config.model == "openai/gpt-oss-120b:cerebras"
|
||||
assert config.base_url == "https://router.huggingface.co/v1"
|
||||
assert (
|
||||
REGISTRY[ServiceType.LLM][ServiceProviders.HUGGINGFACE]
|
||||
is HuggingFaceLLMConfiguration
|
||||
)
|
||||
|
||||
|
||||
def test_create_huggingface_llm_service_uses_openai_compatible_router():
|
||||
user_config = SimpleNamespace(
|
||||
llm=SimpleNamespace(
|
||||
provider=ServiceProviders.HUGGINGFACE.value,
|
||||
api_key="hf_test",
|
||||
model="deepseek-ai/DeepSeek-R1:fastest",
|
||||
base_url="https://router.huggingface.co/v1",
|
||||
bill_to="demo-org",
|
||||
)
|
||||
)
|
||||
|
||||
with patch(
|
||||
"api.services.pipecat.service_factory.HuggingFaceLLMService"
|
||||
) as mock_service:
|
||||
create_llm_service(user_config)
|
||||
|
||||
assert mock_service.call_count == 1
|
||||
kwargs = mock_service.call_args.kwargs
|
||||
assert kwargs["api_key"] == "hf_test"
|
||||
assert kwargs["base_url"] == "https://router.huggingface.co/v1"
|
||||
assert kwargs["bill_to"] == "demo-org"
|
||||
assert kwargs["settings"].model == "deepseek-ai/DeepSeek-R1:fastest"
|
||||
assert kwargs["settings"].temperature == 0.1
|
||||
|
||||
|
||||
def test_create_huggingface_stt_service_uses_hosted_defaults():
|
||||
user_config = SimpleNamespace(
|
||||
stt=SimpleNamespace(
|
||||
provider=ServiceProviders.HUGGINGFACE.value,
|
||||
api_key="hf_test",
|
||||
model="openai/whisper-large-v3-turbo",
|
||||
base_url="https://router.huggingface.co/hf-inference",
|
||||
bill_to="demo-org",
|
||||
return_timestamps=True,
|
||||
)
|
||||
)
|
||||
audio_config = SimpleNamespace(transport_in_sample_rate=16000)
|
||||
|
||||
with patch(
|
||||
"api.services.pipecat.service_factory.HuggingFaceSTTService"
|
||||
) as mock_service:
|
||||
create_stt_service(user_config, audio_config)
|
||||
|
||||
assert mock_service.call_count == 1
|
||||
kwargs = mock_service.call_args.kwargs
|
||||
assert kwargs["api_key"] == "hf_test"
|
||||
assert kwargs["base_url"] == "https://router.huggingface.co/hf-inference"
|
||||
assert kwargs["bill_to"] == "demo-org"
|
||||
assert kwargs["sample_rate"] == 16000
|
||||
assert kwargs["settings"].model == "openai/whisper-large-v3-turbo"
|
||||
assert kwargs["settings"].return_timestamps is True
|
||||
|
||||
|
||||
def test_validator_accepts_huggingface_stt_token_format():
|
||||
validator = UserConfigurationValidator()
|
||||
|
||||
assert (
|
||||
validator._validate_service(
|
||||
HuggingFaceSTTConfiguration(api_key="hf_test"),
|
||||
"stt",
|
||||
)
|
||||
== []
|
||||
)
|
||||
assert (
|
||||
validator._validate_service(
|
||||
HuggingFaceLLMConfiguration(api_key="hf_test"),
|
||||
"llm",
|
||||
)
|
||||
== []
|
||||
)
|
||||
|
||||
|
||||
def test_validator_rejects_non_huggingface_token_format():
|
||||
validator = UserConfigurationValidator()
|
||||
|
||||
errors = validator._validate_service(
|
||||
HuggingFaceSTTConfiguration(api_key="not-hf-token"),
|
||||
"stt",
|
||||
)
|
||||
|
||||
assert errors == [
|
||||
{
|
||||
"model": "stt",
|
||||
"message": (
|
||||
"Invalid Hugging Face API token format. Use a token that starts with "
|
||||
"'hf_' and has Inference Providers permission."
|
||||
),
|
||||
}
|
||||
]
|
||||
31
api/tests/test_openai_tts_service_factory.py
Normal file
31
api/tests/test_openai_tts_service_factory.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from pipecat.services.openai._constants import OPENAI_SAMPLE_RATE
|
||||
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.pipecat.service_factory import create_tts_service
|
||||
|
||||
|
||||
def test_create_openai_tts_service_uses_openai_pcm_sample_rate():
|
||||
user_config = SimpleNamespace(
|
||||
tts=SimpleNamespace(
|
||||
provider=ServiceProviders.OPENAI.value,
|
||||
api_key="test-key",
|
||||
model="gpt-4o-mini-tts",
|
||||
voice="alloy",
|
||||
base_url=None,
|
||||
)
|
||||
)
|
||||
audio_config = SimpleNamespace(
|
||||
transport_out_sample_rate=16000,
|
||||
transport_in_sample_rate=16000,
|
||||
)
|
||||
|
||||
with patch("api.services.pipecat.service_factory.OpenAITTSService") as mock_service:
|
||||
create_tts_service(user_config, audio_config)
|
||||
|
||||
assert mock_service.call_count == 1
|
||||
kwargs = mock_service.call_args.kwargs
|
||||
assert kwargs["sample_rate"] == OPENAI_SAMPLE_RATE
|
||||
assert kwargs["settings"].model == "gpt-4o-mini-tts"
|
||||
Loading…
Add table
Add a link
Reference in a new issue