mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-19 08:28:10 +02:00
feat: add huggingface inferece provider endpoint
This commit is contained in:
parent
ef266daa6e
commit
dd3f2e7323
7 changed files with 315 additions and 2 deletions
|
|
@ -49,6 +49,7 @@ class UserConfigurationValidator:
|
|||
ServiceProviders.CAMB.value: self._check_camb_api_key,
|
||||
ServiceProviders.AWS_BEDROCK.value: self._check_aws_bedrock_api_key,
|
||||
ServiceProviders.SPEACHES.value: self._check_speaches_api_key,
|
||||
ServiceProviders.HUGGINGFACE.value: self._check_huggingface_api_key,
|
||||
ServiceProviders.GOOGLE_VERTEX.value: self._check_google_vertex_llm_api_key,
|
||||
ServiceProviders.OPENAI_REALTIME.value: self._check_openai_api_key,
|
||||
ServiceProviders.GROK_REALTIME.value: self._check_grok_realtime_api_key,
|
||||
|
|
@ -360,6 +361,14 @@ class UserConfigurationValidator:
|
|||
raise ValueError("base_url is required for Speaches services")
|
||||
return True
|
||||
|
||||
def _check_huggingface_api_key(self, model: str, api_key: str) -> bool:
|
||||
if not api_key.startswith("hf_"):
|
||||
raise ValueError(
|
||||
"Invalid Hugging Face API token format. Use a token that starts with "
|
||||
"'hf_' and has Inference Providers permission."
|
||||
)
|
||||
return True
|
||||
|
||||
def _check_google_vertex_realtime_api_key(self, model: str, service_config) -> bool:
|
||||
if not getattr(service_config, "project_id", None):
|
||||
raise ValueError("project_id is required for Google Vertex Realtime")
|
||||
|
|
|
|||
|
|
@ -68,6 +68,7 @@ class ServiceProviders(str, Enum):
|
|||
CAMB = "camb"
|
||||
AWS_BEDROCK = "aws_bedrock"
|
||||
SPEACHES = "speaches"
|
||||
HUGGINGFACE = "huggingface"
|
||||
ASSEMBLYAI = "assemblyai"
|
||||
GLADIA = "gladia"
|
||||
RIME = "rime"
|
||||
|
|
@ -94,6 +95,7 @@ class BaseServiceConfiguration(BaseModel):
|
|||
ServiceProviders.DOGRAH,
|
||||
ServiceProviders.AWS_BEDROCK,
|
||||
ServiceProviders.SPEACHES,
|
||||
ServiceProviders.HUGGINGFACE,
|
||||
ServiceProviders.ASSEMBLYAI,
|
||||
ServiceProviders.GLADIA,
|
||||
ServiceProviders.RIME,
|
||||
|
|
@ -255,6 +257,11 @@ SPEACHES_PROVIDER_MODEL_CONFIG = provider_model_config(
|
|||
),
|
||||
provider_docs_url="https://github.com/speaches-ai/speaches",
|
||||
)
|
||||
HUGGINGFACE_PROVIDER_MODEL_CONFIG = provider_model_config(
|
||||
"Hugging Face",
|
||||
description="Hosted Hugging Face Inference Providers API for usage-based inference.",
|
||||
provider_docs_url="https://huggingface.co/docs/inference-providers/en/index",
|
||||
)
|
||||
AZURE_SPEECH_PROVIDER_MODEL_CONFIG = provider_model_config(
|
||||
"Azure Speech Services",
|
||||
description="Azure Cognitive Services Speech — TTS and STT via the Azure Speech SDK.",
|
||||
|
|
@ -471,6 +478,35 @@ class SpeachesLLMConfiguration(BaseLLMConfiguration):
|
|||
)
|
||||
|
||||
|
||||
HUGGINGFACE_LLM_MODELS = [
|
||||
"openai/gpt-oss-120b:cerebras",
|
||||
"deepseek-ai/DeepSeek-R1:fastest",
|
||||
"Qwen/Qwen3-Coder-480B-A35B-Instruct:fastest",
|
||||
]
|
||||
|
||||
|
||||
@register_llm
|
||||
class HuggingFaceLLMConfiguration(BaseLLMConfiguration):
|
||||
model_config = HUGGINGFACE_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.HUGGINGFACE] = ServiceProviders.HUGGINGFACE
|
||||
model: str = Field(
|
||||
default="openai/gpt-oss-120b:cerebras",
|
||||
description="Hugging Face chat-completion model identifier, optionally with provider suffix.",
|
||||
json_schema_extra={
|
||||
"examples": HUGGINGFACE_LLM_MODELS,
|
||||
"allow_custom_input": True,
|
||||
},
|
||||
)
|
||||
base_url: str = Field(
|
||||
default="https://router.huggingface.co/v1",
|
||||
description="Hugging Face OpenAI-compatible chat-completions router base URL.",
|
||||
)
|
||||
bill_to: str | None = Field(
|
||||
default=None,
|
||||
description="Optional Hugging Face organization or user to bill using X-HF-Bill-To.",
|
||||
)
|
||||
|
||||
|
||||
MINIMAX_MODELS = [
|
||||
"MiniMax-M2.7",
|
||||
"MiniMax-M2.7-highspeed",
|
||||
|
|
@ -741,6 +777,7 @@ LLMConfig = Annotated[
|
|||
DograhLLMService,
|
||||
AWSBedrockLLMConfiguration,
|
||||
SpeachesLLMConfiguration,
|
||||
HuggingFaceLLMConfiguration,
|
||||
MiniMaxLLMConfiguration,
|
||||
SarvamLLMConfiguration,
|
||||
],
|
||||
|
|
@ -1334,6 +1371,38 @@ class SpeachesSTTConfiguration(BaseSTTConfiguration):
|
|||
)
|
||||
|
||||
|
||||
HUGGINGFACE_STT_MODELS = [
|
||||
"openai/whisper-large-v3-turbo",
|
||||
"openai/whisper-large-v3",
|
||||
]
|
||||
|
||||
|
||||
@register_stt
|
||||
class HuggingFaceSTTConfiguration(BaseSTTConfiguration):
|
||||
model_config = HUGGINGFACE_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.HUGGINGFACE] = ServiceProviders.HUGGINGFACE
|
||||
model: str = Field(
|
||||
default="openai/whisper-large-v3-turbo",
|
||||
description="Hugging Face ASR model identifier served through Inference Providers.",
|
||||
json_schema_extra={
|
||||
"examples": HUGGINGFACE_STT_MODELS,
|
||||
"allow_custom_input": True,
|
||||
},
|
||||
)
|
||||
base_url: str = Field(
|
||||
default="https://router.huggingface.co/hf-inference",
|
||||
description="Hugging Face Inference Providers router base URL.",
|
||||
)
|
||||
bill_to: str | None = Field(
|
||||
default=None,
|
||||
description="Optional Hugging Face organization or user to bill using X-HF-Bill-To.",
|
||||
)
|
||||
return_timestamps: bool = Field(
|
||||
default=False,
|
||||
description="Request timestamp chunks when supported by the selected provider/model.",
|
||||
)
|
||||
|
||||
|
||||
ASSEMBLYAI_STT_MODELS = ["u3-rt-pro"]
|
||||
ASSEMBLYAI_STT_LANGUAGES = ["en", "es", "de", "fr", "pt", "it"]
|
||||
|
||||
|
|
@ -1406,6 +1475,7 @@ STTConfig = Annotated[
|
|||
SpeechmaticsSTTConfiguration,
|
||||
SarvamSTTConfiguration,
|
||||
SpeachesSTTConfiguration,
|
||||
HuggingFaceSTTConfiguration,
|
||||
AssemblyAISTTConfiguration,
|
||||
GladiaSTTConfiguration,
|
||||
AzureSpeechSTTConfiguration,
|
||||
|
|
|
|||
|
|
@ -39,8 +39,17 @@ from pipecat.services.google.vertex.llm import (
|
|||
GoogleVertexLLMSettings,
|
||||
)
|
||||
from pipecat.services.groq.llm import GroqLLMService, GroqLLMSettings
|
||||
from pipecat.services.huggingface.llm import (
|
||||
HuggingFaceLLMService,
|
||||
HuggingFaceLLMSettings,
|
||||
)
|
||||
from pipecat.services.huggingface.stt import (
|
||||
HuggingFaceSTTService,
|
||||
HuggingFaceSTTSettings,
|
||||
)
|
||||
from pipecat.services.minimax.llm import MiniMaxLLMService
|
||||
from pipecat.services.minimax.tts import MiniMaxTTSSettings
|
||||
from pipecat.services.openai._constants import OPENAI_SAMPLE_RATE
|
||||
from pipecat.services.openai.base_llm import OpenAILLMSettings
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.services.openai.stt import (
|
||||
|
|
@ -218,6 +227,22 @@ def create_stt_service(
|
|||
),
|
||||
sample_rate=audio_config.transport_in_sample_rate,
|
||||
)
|
||||
elif user_config.stt.provider == ServiceProviders.HUGGINGFACE.value:
|
||||
base_url = (
|
||||
getattr(user_config.stt, "base_url", None)
|
||||
or "https://router.huggingface.co/hf-inference"
|
||||
)
|
||||
_validate_runtime_service_url(base_url, "base_url")
|
||||
return HuggingFaceSTTService(
|
||||
api_key=user_config.stt.api_key,
|
||||
base_url=base_url,
|
||||
bill_to=getattr(user_config.stt, "bill_to", None),
|
||||
settings=HuggingFaceSTTSettings(
|
||||
model=user_config.stt.model,
|
||||
return_timestamps=getattr(user_config.stt, "return_timestamps", False),
|
||||
),
|
||||
sample_rate=audio_config.transport_in_sample_rate,
|
||||
)
|
||||
elif user_config.stt.provider == ServiceProviders.ASSEMBLYAI.value:
|
||||
language = getattr(user_config.stt, "language", None)
|
||||
settings_kwargs = {"model": user_config.stt.model, "language": language}
|
||||
|
|
@ -320,6 +345,7 @@ def create_tts_service(
|
|||
kwargs["base_url"] = base_url
|
||||
return OpenAITTSService(
|
||||
api_key=user_config.tts.api_key,
|
||||
sample_rate=OPENAI_SAMPLE_RATE,
|
||||
settings=OpenAITTSSettings(model=user_config.tts.model),
|
||||
text_filters=[xml_function_tag_filter],
|
||||
skip_aggregator_types=["recording_router", "recording"],
|
||||
|
|
@ -581,6 +607,7 @@ def create_llm_service_from_provider(
|
|||
location: str | None = None,
|
||||
credentials: str | None = None,
|
||||
temperature: float | None = None,
|
||||
bill_to: str | None = None,
|
||||
):
|
||||
"""Create an LLM service from explicit provider/model/api_key.
|
||||
|
||||
|
|
@ -663,6 +690,15 @@ def create_llm_service_from_provider(
|
|||
api_key=api_key or "none",
|
||||
settings=SpeachesLLMSettings(model=model),
|
||||
)
|
||||
elif provider == ServiceProviders.HUGGINGFACE.value:
|
||||
base_url = base_url or "https://router.huggingface.co/v1"
|
||||
_validate_runtime_service_url(base_url, "base_url")
|
||||
return HuggingFaceLLMService(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
bill_to=bill_to,
|
||||
settings=HuggingFaceLLMSettings(model=model, temperature=0.1),
|
||||
)
|
||||
elif provider == ServiceProviders.MINIMAX.value:
|
||||
base_url = base_url or "https://api.minimax.io/v1"
|
||||
_validate_runtime_service_url(base_url, "base_url")
|
||||
|
|
@ -875,6 +911,9 @@ def create_llm_service(user_config, correlation_id: str | None = None):
|
|||
kwargs["endpoint"] = user_config.llm.endpoint
|
||||
elif provider == ServiceProviders.SPEACHES.value:
|
||||
kwargs["base_url"] = user_config.llm.base_url
|
||||
elif provider == ServiceProviders.HUGGINGFACE.value:
|
||||
kwargs["base_url"] = user_config.llm.base_url
|
||||
kwargs["bill_to"] = user_config.llm.bill_to
|
||||
elif provider == ServiceProviders.AWS_BEDROCK.value:
|
||||
kwargs["aws_access_key"] = user_config.llm.aws_access_key
|
||||
kwargs["aws_secret_key"] = user_config.llm.aws_secret_key
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue