mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-22 08:38:13 +02:00
feat: add support for self hosted llm models
This commit is contained in:
parent
31e075d114
commit
ac0731a374
17 changed files with 179 additions and 48 deletions
|
|
@ -8,7 +8,11 @@ from api.services.configuration.registry import ServiceProviders
|
|||
from pipecat.services.aws.llm import AWSBedrockLLMService, AWSBedrockLLMSettings
|
||||
from pipecat.services.azure.llm import AzureLLMService, AzureLLMSettings
|
||||
from pipecat.services.cartesia.stt import CartesiaSTTService
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService, CartesiaTTSSettings, GenerationConfig
|
||||
from pipecat.services.cartesia.tts import (
|
||||
CartesiaTTSService,
|
||||
CartesiaTTSSettings,
|
||||
GenerationConfig,
|
||||
)
|
||||
from pipecat.services.deepgram.flux.stt import (
|
||||
DeepgramFluxSTTService,
|
||||
DeepgramFluxSTTSettings,
|
||||
|
|
@ -212,13 +216,19 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
)
|
||||
elif user_config.tts.provider == ServiceProviders.CARTESIA.value:
|
||||
speed = getattr(user_config.tts, "speed", None)
|
||||
generation_config = GenerationConfig(speed=speed) if speed and speed != 1.0 else None
|
||||
generation_config = (
|
||||
GenerationConfig(speed=speed) if speed and speed != 1.0 else None
|
||||
)
|
||||
return CartesiaTTSService(
|
||||
api_key=user_config.tts.api_key,
|
||||
settings=CartesiaTTSSettings(
|
||||
voice=user_config.tts.voice,
|
||||
model=user_config.tts.model,
|
||||
**({"generation_config": generation_config} if generation_config else {}),
|
||||
**(
|
||||
{"generation_config": generation_config}
|
||||
if generation_config
|
||||
else {}
|
||||
),
|
||||
),
|
||||
text_filters=[xml_function_tag_filter],
|
||||
silence_time_s=1.0,
|
||||
|
|
@ -353,6 +363,12 @@ def create_llm_service_from_provider(
|
|||
aws_region=aws_region,
|
||||
settings=AWSBedrockLLMSettings(model=model),
|
||||
)
|
||||
elif provider == ServiceProviders.SELF_HOSTED.value:
|
||||
return OpenAILLMService(
|
||||
base_url=base_url or "http://localhost:11434/v1",
|
||||
api_key=api_key or "none",
|
||||
settings=OpenAILLMSettings(model=model),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid LLM provider {provider}")
|
||||
|
||||
|
|
@ -368,6 +384,8 @@ def create_llm_service(user_config):
|
|||
kwargs["base_url"] = user_config.llm.base_url
|
||||
elif provider == ServiceProviders.AZURE.value:
|
||||
kwargs["endpoint"] = user_config.llm.endpoint
|
||||
elif provider == ServiceProviders.SELF_HOSTED.value:
|
||||
kwargs["base_url"] = user_config.llm.base_url
|
||||
elif provider == ServiceProviders.AWS_BEDROCK.value:
|
||||
kwargs["aws_access_key"] = user_config.llm.aws_access_key
|
||||
kwargs["aws_secret_key"] = user_config.llm.aws_secret_key
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue