mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-22 08:38:13 +02:00
feat: add AWS Bedrock support
This commit is contained in:
parent
1604e306ec
commit
fe84f086ba
30 changed files with 546 additions and 195 deletions
|
|
@ -4,19 +4,16 @@ import json
|
|||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from api.db.models import WorkflowRunModel
|
||||
from api.services.gen_ai.json_parser import parse_llm_json
|
||||
from api.services.pipecat.service_factory import create_llm_service_from_provider
|
||||
from api.services.workflow.qa.conversation import (
|
||||
build_conversation_structure,
|
||||
format_transcript,
|
||||
split_events_by_node,
|
||||
)
|
||||
from api.services.workflow.qa.llm_config import (
|
||||
accumulate_token_usage,
|
||||
resolve_llm_config,
|
||||
)
|
||||
from api.services.workflow.qa.llm_config import resolve_llm_config
|
||||
from api.services.workflow.qa.metrics import compute_call_metrics
|
||||
from api.services.workflow.qa.node_summary import (
|
||||
CONVERSATION_SUMMARY_SYSTEM_PROMPT,
|
||||
|
|
@ -28,15 +25,22 @@ from api.services.workflow.qa.tracing import (
|
|||
setup_langfuse_parent_context,
|
||||
)
|
||||
from api.utils.template_renderer import render_template
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
|
||||
|
||||
async def _run_llm_inference(llm, messages: list[dict]) -> str | None:
|
||||
"""Run a one-shot LLM inference using the pipecat service."""
|
||||
context = LLMContext()
|
||||
context.set_messages(messages)
|
||||
return await llm.run_inference(context)
|
||||
|
||||
|
||||
async def _generate_conversation_summary(
|
||||
client: AsyncOpenAI,
|
||||
llm,
|
||||
model: str,
|
||||
transcript: str,
|
||||
parent_ctx,
|
||||
node_name: str,
|
||||
total_token_usage: dict,
|
||||
) -> str:
|
||||
"""Generate a summary of the conversation so far (before the current node).
|
||||
|
||||
|
|
@ -48,13 +52,7 @@ async def _generate_conversation_summary(
|
|||
]
|
||||
|
||||
try:
|
||||
response = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
temperature=0,
|
||||
)
|
||||
summary = response.choices[0].message.content or ""
|
||||
accumulate_token_usage(total_token_usage, response)
|
||||
summary = await _run_llm_inference(llm, messages) or ""
|
||||
|
||||
span_name = f"conversation-summary-before-{node_name}"
|
||||
add_qa_span_to_trace(parent_ctx, model, messages, summary, span_name)
|
||||
|
|
@ -82,7 +80,7 @@ async def run_per_node_qa_analysis(
|
|||
Falls back to whole-call QA if events lack node_id.
|
||||
|
||||
Returns:
|
||||
Dict with node_results, token_usage, model
|
||||
Dict with node_results, model
|
||||
"""
|
||||
logs = workflow_run.logs or {}
|
||||
rtf_events = logs.get("realtime_feedback_events", [])
|
||||
|
|
@ -107,7 +105,9 @@ async def run_per_node_qa_analysis(
|
|||
return {"error": "no_system_prompt", "node_results": {}}
|
||||
|
||||
# Resolve LLM config
|
||||
model, api_key, base_url = await resolve_llm_config(qa_node_data, workflow_run)
|
||||
provider, model, api_key, service_kwargs = await resolve_llm_config(
|
||||
qa_node_data, workflow_run
|
||||
)
|
||||
if not api_key:
|
||||
logger.warning(
|
||||
f"No LLM API key configured for QA analysis on run {workflow_run_id}"
|
||||
|
|
@ -122,13 +122,9 @@ async def run_per_node_qa_analysis(
|
|||
# Set up Langfuse tracing
|
||||
parent_ctx = setup_langfuse_parent_context(workflow_run)
|
||||
|
||||
# Build LLM client
|
||||
client_kwargs: dict[str, Any] = {"api_key": api_key}
|
||||
if base_url:
|
||||
client_kwargs["base_url"] = base_url
|
||||
client = AsyncOpenAI(**client_kwargs)
|
||||
# Build LLM service
|
||||
llm = create_llm_service_from_provider(provider, model, api_key, **service_kwargs)
|
||||
|
||||
total_token_usage: dict[str, int] = {}
|
||||
node_results: dict[str, Any] = {}
|
||||
prior_conversation: list[dict] = [] # Running accumulation of all prior nodes
|
||||
|
||||
|
|
@ -150,12 +146,11 @@ async def run_per_node_qa_analysis(
|
|||
if idx > 0 and prior_conversation:
|
||||
prior_transcript = format_transcript(prior_conversation)
|
||||
previous_conversation_summary = await _generate_conversation_summary(
|
||||
client,
|
||||
llm,
|
||||
model,
|
||||
prior_transcript,
|
||||
parent_ctx,
|
||||
node_name,
|
||||
total_token_usage,
|
||||
)
|
||||
|
||||
# Substitute placeholders in the user's system prompt
|
||||
|
|
@ -174,14 +169,7 @@ async def run_per_node_qa_analysis(
|
|||
|
||||
# Call QA LLM
|
||||
try:
|
||||
response = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
temperature=0,
|
||||
extra_body={"stream": False},
|
||||
)
|
||||
raw_response = response.choices[0].message.content
|
||||
accumulate_token_usage(total_token_usage, response)
|
||||
raw_response = await _run_llm_inference(llm, messages)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"QA LLM call failed for node '{node_name}' on run {workflow_run_id}: {e}"
|
||||
|
|
@ -221,13 +209,10 @@ async def run_per_node_qa_analysis(
|
|||
# Append this node's conversation to running total
|
||||
prior_conversation.extend(node_conversation)
|
||||
|
||||
result: dict[str, Any] = {
|
||||
return {
|
||||
"node_results": node_results,
|
||||
"model": model,
|
||||
}
|
||||
if total_token_usage:
|
||||
result["token_usage"] = total_token_usage
|
||||
return result
|
||||
|
||||
|
||||
async def _run_whole_call_qa_analysis(
|
||||
|
|
@ -262,7 +247,9 @@ async def _run_whole_call_qa_analysis(
|
|||
logger.warning("No system prompt defined for QA Node")
|
||||
return {"error": "no_system_prompt", "node_results": {}}
|
||||
|
||||
model, api_key, base_url = await resolve_llm_config(qa_node_data, workflow_run)
|
||||
provider, model, api_key, service_kwargs = await resolve_llm_config(
|
||||
qa_node_data, workflow_run
|
||||
)
|
||||
|
||||
if not api_key:
|
||||
logger.warning(
|
||||
|
|
@ -284,27 +271,14 @@ async def _run_whole_call_qa_analysis(
|
|||
]
|
||||
|
||||
# Call LLM
|
||||
client_kwargs: dict[str, Any] = {"api_key": api_key}
|
||||
if base_url:
|
||||
client_kwargs["base_url"] = base_url
|
||||
|
||||
client = AsyncOpenAI(**client_kwargs)
|
||||
llm = create_llm_service_from_provider(provider, model, api_key, **service_kwargs)
|
||||
|
||||
try:
|
||||
response = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
temperature=0,
|
||||
)
|
||||
raw_response = response.choices[0].message.content
|
||||
raw_response = await _run_llm_inference(llm, messages)
|
||||
except Exception as e:
|
||||
logger.error(f"QA LLM call failed for run {workflow_run_id}: {e}")
|
||||
return {"error": str(e), "node_results": {}}
|
||||
|
||||
# Extract token usage
|
||||
token_usage: dict[str, int] = {}
|
||||
accumulate_token_usage(token_usage, response)
|
||||
|
||||
# Parse response
|
||||
node_result: dict[str, Any] = {
|
||||
"node_name": "whole_call",
|
||||
|
|
@ -325,10 +299,7 @@ async def _run_whole_call_qa_analysis(
|
|||
parent_ctx = setup_langfuse_parent_context(workflow_run)
|
||||
add_qa_span_to_trace(parent_ctx, model, messages, raw_response, "qa-analysis")
|
||||
|
||||
result: dict[str, Any] = {
|
||||
return {
|
||||
"node_results": {"whole_call": node_result},
|
||||
"model": model,
|
||||
}
|
||||
if token_usage:
|
||||
result["token_usage"] = token_usage
|
||||
return result
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue