fix: create mps account on migrate to v2

This commit is contained in:
Abhishek Kumar 2026-06-12 14:53:36 +05:30
parent 8f241b89d2
commit 724e1d456b
14 changed files with 666 additions and 61 deletions

View file

@ -5,7 +5,11 @@ from loguru import logger
from pydantic import BaseModel
from sqlalchemy.exc import IntegrityError
from api.constants import DEFAULT_CAMPAIGN_RETRY_CONFIG, DEFAULT_ORG_CONCURRENCY_LIMIT
from api.constants import (
DEFAULT_CAMPAIGN_RETRY_CONFIG,
DEFAULT_ORG_CONCURRENCY_LIMIT,
DEPLOYMENT_MODE,
)
from api.db import db_client
from api.db.models import UserModel
from api.db.telephony_configuration_client import TelephonyConfigurationInUseError
@ -55,6 +59,7 @@ from api.services.configuration.registry import (
ServiceProviders,
ServiceType,
)
from api.services.mps_billing import ensure_hosted_mps_billing_account_v2
from api.services.organization_context import (
OrganizationContextResponse,
get_organization_context,
@ -359,6 +364,23 @@ async def migrate_model_configuration_v2(
except ValueError as exc:
raise HTTPException(status_code=422, detail=exc.args[0])
if DEPLOYMENT_MODE != "oss":
try:
await ensure_hosted_mps_billing_account_v2(
organization_id,
created_by=str(user.provider_id),
)
except Exception as exc:
logger.error(
"Failed to initialize MPS billing v2 account for organization {}: {}",
organization_id,
exc,
)
raise HTTPException(
status_code=502,
detail="Failed to initialize MPS billing v2 account",
)
await upsert_organization_ai_model_configuration_v2(
organization_id,
configuration,

View file

@ -74,6 +74,10 @@ class MPSBillingCreditsResponse(BaseModel):
total_quota: float = 0.0
account: Optional[MPSBillingAccountResponse] = None
ledger_entries: List[MPSCreditLedgerEntryResponse] = Field(default_factory=list)
total_count: int = 0
page: int = 1
limit: int = 50
total_pages: int = 0
def _optional_int(value: Any) -> Optional[int]:
@ -224,10 +228,11 @@ async def _legacy_mps_credits_response(user: UserModel) -> MPSBillingCreditsResp
@router.get("/billing/credits", response_model=MPSBillingCreditsResponse)
async def get_billing_credits(
page: int = Query(1, ge=1),
limit: int = Query(50, ge=1, le=100),
user: UserModel = Depends(get_user),
):
"""Return legacy MPS credits or v2 billing ledger details for the org."""
"""Return legacy MPS credits or paginated v2 billing ledger details for the org."""
try:
if DEPLOYMENT_MODE == "oss" or not user.selected_organization_id:
return await _legacy_mps_credits_response(user)
@ -239,11 +244,18 @@ async def get_billing_credits(
ledger = await mps_service_key_client.get_credit_ledger(
organization_id=organization_id,
page=page,
limit=limit,
created_by=str(user.provider_id),
)
account = ledger.get("account") or {}
ledger_entries = ledger.get("ledger_entries") or []
total_count = int(ledger.get("total_count") or len(ledger_entries))
response_limit = int(ledger.get("limit") or limit)
total_pages = int(
ledger.get("total_pages")
or ((total_count + response_limit - 1) // response_limit)
)
workflow_ids_by_run_id: dict[int, int] = {}
workflow_run_ids = {
workflow_run_id
@ -266,6 +278,8 @@ async def get_billing_credits(
for entry in ledger_entries
if float(entry.get("credits_delta") or 0.0) < 0
)
if ledger.get("total_debits_credits") is not None:
total_debits = float(ledger["total_debits_credits"])
return MPSBillingCreditsResponse(
billing_version="v2",
@ -308,6 +322,10 @@ async def get_billing_credits(
)
for entry in ledger_entries
],
total_count=total_count,
page=int(ledger.get("page") or page),
limit=response_limit,
total_pages=total_pages,
)
except HTTPException:
raise

View file

@ -12,6 +12,7 @@ from api.enums import PostHogEvent
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.auth.stack_auth import stackauth
from api.services.configuration.registry import ServiceProviders
from api.services.mps_billing import ensure_hosted_mps_billing_account_v2
from api.services.posthog_client import capture_event
from api.utils.auth import decode_jwt_token
@ -110,6 +111,19 @@ async def get_user(
# This prevents race conditions where multiple concurrent requests
# might try to create configurations
if org_was_created:
try:
await ensure_hosted_mps_billing_account_v2(
organization.id,
created_by=str(stack_user["id"]),
)
except Exception:
logger.warning(
"Failed to initialize hosted MPS billing account for "
"organization {}",
organization.id,
exc_info=True,
)
existing_cfg = await db_client.get_user_configurations(user_model.id)
if not (existing_cfg.llm or existing_cfg.tts or existing_cfg.stt):
mps_config = await create_user_configuration_with_mps_key(
@ -232,7 +246,7 @@ async def create_user_configuration_with_mps_key(
response = await client.post(
f"{MPS_API_URL}/api/v1/service-keys/",
json={
"name": f"Default Dograh Model Service Key",
"name": "Default Dograh Model Service Key",
"description": "Auto-generated key for OSS user",
"expires_in_days": 7, # Short-lived for OSS
"created_by": user_provider_id,
@ -250,7 +264,7 @@ async def create_user_configuration_with_mps_key(
response = await client.post(
f"{MPS_API_URL}/api/v1/service-keys/",
json={
"name": f"Default Dograh Model Service Key",
"name": "Default Dograh Model Service Key",
"description": f"Auto-generated key for organization {organization_id}",
"organization_id": organization_id,
"expires_in_days": 90, # Longer-lived for authenticated users

View file

@ -0,0 +1,23 @@
from typing import Optional
from api.constants import DEPLOYMENT_MODE
from api.services.mps_service_key_client import mps_service_key_client
async def ensure_hosted_mps_billing_account_v2(
    organization_id: int,
    *,
    created_by: Optional[str] = None,
) -> Optional[dict]:
    """Make sure a hosted organization has an MPS billing v2 account.

    OSS deployments rely on legacy per-key quota accounting and never get an
    MPS billing account, so the call is a no-op for them.

    Args:
        organization_id: Organization whose billing account is ensured.
        created_by: Optional identifier forwarded unchanged to the MPS client.

    Returns:
        ``None`` when ``DEPLOYMENT_MODE`` is ``"oss"``; otherwise the value
        returned by ``mps_service_key_client.ensure_billing_account_v2``.
    """
    if DEPLOYMENT_MODE != "oss":
        # Hosted deployment: delegate to the shared MPS client.
        return await mps_service_key_client.ensure_billing_account_v2(
            organization_id=organization_id,
            created_by=created_by,
        )
    return None

View file

@ -394,6 +394,7 @@ class MPSServiceKeyClient:
async def get_credit_ledger(
self,
organization_id: int,
page: int = 1,
limit: int = 50,
created_by: Optional[str] = None,
) -> dict:
@ -401,7 +402,7 @@ class MPSServiceKeyClient:
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.get(
f"{self.base_url}/api/v1/billing/accounts/{organization_id}/ledger",
params={"limit": limit},
params={"page": page, "limit": limit},
headers=self._get_headers(
organization_id=organization_id,
created_by=created_by,
@ -449,6 +450,34 @@ class MPSServiceKeyClient:
response=response,
)
async def ensure_billing_account_v2(
    self,
    organization_id: int,
    created_by: Optional[str] = None,
) -> dict:
    """Create or return the MPS v2 billing account for an organization.

    Sends a GET to the organization's ``balance`` endpoint and returns the
    decoded JSON body when MPS answers with HTTP 200.

    NOTE(review): this method only reads the balance endpoint; presumably
    MPS provisions the v2 account lazily on that read - confirm against the
    MPS API.

    Args:
        organization_id: Organization whose billing account is requested.
        created_by: Optional identifier passed on to ``_get_headers``.

    Returns:
        The JSON payload of the balance response.

    Raises:
        httpx.HTTPStatusError: If MPS responds with any status other than 200.
    """
    async with httpx.AsyncClient(timeout=self.timeout) as client:
        balance_url = (
            f"{self.base_url}/api/v1/billing/accounts/{organization_id}/balance"
        )
        response = await client.get(
            balance_url,
            headers=self._get_headers(
                organization_id=organization_id,
                created_by=created_by,
            ),
        )

        # Guard clause: anything other than 200 is logged and surfaced.
        if response.status_code != 200:
            logger.error(
                "Failed to ensure MPS billing account v2: "
                f"{response.status_code} - {response.text}"
            )
            raise httpx.HTTPStatusError(
                f"Failed to ensure MPS billing account v2: {response.text}",
                request=response.request,
                response=response,
            )

        return response.json()
async def create_correlation_id(
self,
*,

View file

@ -1,9 +1,13 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from pydantic import ValidationError
from api.schemas.ai_model_configuration import (
DograhManagedAIModelConfiguration,
EffectiveAIModelConfiguration,
OrganizationAIModelConfigurationResponse,
OrganizationAIModelConfigurationV2,
compile_ai_model_configuration_v2,
)
@ -358,3 +362,98 @@ def test_workflow_model_override_migration_removes_invalid_v1_override_marker():
assert changed is True
assert "model_overrides" not in migrated
assert migrated["ambient_noise_configuration"] == {"enabled": False}
@pytest.mark.asyncio
async def test_migrate_model_configuration_v2_initializes_hosted_mps_billing(
    monkeypatch,
):
    """Migrating a hosted org to v2 ensures its MPS billing account.

    Every collaborator of the route is stubbed, so the test pins only the
    calls the route makes and the response it hands back.
    """
    from api.routes import organization as organization_routes

    class FakeValidator:
        """Validator stub that always reports success."""

        async def validate(self, *args, **kwargs):
            return {"status": [{"model": "all", "message": "ok"}]}

    # Legacy configuration in which every service is Dograh-managed.
    legacy_config = EffectiveAIModelConfiguration(
        llm=DograhLLMService(
            provider="dograh",
            api_key=["mps-secret"],
            model="default",
        ),
        tts=DograhTTSService(
            provider="dograh",
            api_key=["mps-secret"],
            model="default",
            voice="default",
        ),
        stt=DograhSTTService(
            provider="dograh",
            api_key=["mps-secret"],
            model="default",
        ),
    )
    canned_response = OrganizationAIModelConfigurationResponse(
        configuration={"version": 2, "mode": "dograh"},
        effective_configuration={},
        source="organization_v2",
    )

    ensure_billing_mock = AsyncMock(return_value={"billing_mode": "v2"})
    upsert_mock = AsyncMock()
    migrate_workflows_mock = AsyncMock()

    # Module-level names on the route module, replaced in one pass.
    route_stubs = {
        "DEPLOYMENT_MODE": "saas",
        "get_organization_ai_model_configuration_v2": AsyncMock(return_value=None),
        "UserConfigurationValidator": lambda: FakeValidator(),
        "ensure_hosted_mps_billing_account_v2": ensure_billing_mock,
        "upsert_organization_ai_model_configuration_v2": upsert_mock,
        "migrate_workflow_model_configurations_to_v2": migrate_workflows_mock,
        "_model_configuration_v2_response": AsyncMock(
            return_value=canned_response
        ),
    }
    for attribute, stub in route_stubs.items():
        monkeypatch.setattr(organization_routes, attribute, stub)
    monkeypatch.setattr(
        organization_routes.db_client,
        "get_user_configurations",
        AsyncMock(return_value=legacy_config),
    )

    requesting_user = SimpleNamespace(
        id=7,
        provider_id="provider-123",
        selected_organization_id=42,
    )

    result = await organization_routes.migrate_model_configuration_v2(
        force=False,
        user=requesting_user,
    )

    ensure_billing_mock.assert_awaited_once_with(42, created_by="provider-123")
    upsert_mock.assert_awaited_once()
    migrate_workflows_mock.assert_awaited_once_with(
        organization_id=42,
        fallback_user_config=legacy_config,
    )
    assert result == canned_response

View file

@ -0,0 +1,68 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from api.services.auth import depends as auth_depends
@pytest.mark.asyncio
async def test_get_user_initializes_hosted_mps_billing_for_new_org(monkeypatch):
    """A newly created organization triggers MPS billing initialization.

    The organization lookup reports ``created=True`` and the user already
    has an LLM configuration, so the billing call is the only side effect
    this test pins.
    """
    provider_user = {
        "id": "stack-user-1",
        "selected_team_id": "team-1",
        "primary_email_verified": False,
    }
    db_user = SimpleNamespace(
        id=7,
        email=None,
        provider_id="stack-user-1",
        selected_organization_id=None,
    )
    new_org = SimpleNamespace(id=42)
    stored_config = SimpleNamespace(llm=object(), tts=None, stt=None)
    ensure_billing_mock = AsyncMock(return_value={"billing_mode": "v2"})

    monkeypatch.setattr(auth_depends, "AUTH_PROVIDER", "stack")
    monkeypatch.setattr(
        auth_depends.stackauth,
        "get_user",
        AsyncMock(return_value=provider_user),
    )

    # Database client methods touched by get_user, replaced in one pass.
    db_stubs = {
        "get_or_create_user_by_provider_id": AsyncMock(
            return_value=(db_user, False)
        ),
        "get_or_create_organization_by_provider_id": AsyncMock(
            return_value=(new_org, True)
        ),
        "add_user_to_organization": AsyncMock(),
        "update_user_selected_organization": AsyncMock(),
        "get_user_configurations": AsyncMock(return_value=stored_config),
    }
    for method_name, stub in db_stubs.items():
        monkeypatch.setattr(auth_depends.db_client, method_name, stub)

    monkeypatch.setattr(
        auth_depends,
        "ensure_hosted_mps_billing_account_v2",
        ensure_billing_mock,
    )

    resolved_user = await auth_depends.get_user(authorization="Bearer token")

    assert resolved_user is db_user
    assert resolved_user.selected_organization_id == 42
    ensure_billing_mock.assert_awaited_once_with(42, created_by="stack-user-1")

View file

@ -175,6 +175,130 @@ async def test_get_billing_account_status_uses_hosted_org_auth(monkeypatch):
]
@pytest.mark.asyncio
async def test_ensure_billing_account_v2_uses_balance_endpoint(monkeypatch):
    """ensure_billing_account_v2 GETs the balance endpoint with hosted auth headers."""
    recorded_requests = []

    def account_payload():
        # Fresh dict per call so the fake response and the expectation
        # never share an object.
        return {
            "id": 7,
            "organization_id": 42,
            "billing_mode": "v2",
            "cached_balance_credits": "0.0000",
            "currency": "USD",
        }

    class FakeAsyncClient:
        """Stand-in for the HTTP client that records every GET it receives."""

        def __init__(self, timeout):
            self.timeout = timeout

        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            return None

        async def get(self, url, headers):
            recorded_requests.append(("GET", url, headers))
            return _Response(200, account_payload())

    patches = [
        ("api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient),
        ("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas"),
        ("api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"),
    ]
    for target, replacement in patches:
        monkeypatch.setattr(target, replacement)

    client = MPSServiceKeyClient()
    account = await client.ensure_billing_account_v2(
        organization_id=42,
        created_by="provider-123",
    )

    assert account == account_payload()

    expected_headers = {
        "Content-Type": "application/json",
        "X-Secret-Key": "mps-secret",
        "X-Organization-Id": "42",
    }
    expected_url = f"{client.base_url}/api/v1/billing/accounts/42/balance"
    assert recorded_requests == [("GET", expected_url, expected_headers)]
@pytest.mark.asyncio
async def test_get_credit_ledger_sends_page_and_limit(monkeypatch):
    """get_credit_ledger forwards both ``page`` and ``limit`` as query params."""
    recorded_requests = []

    def ledger_payload():
        # Fresh dict per call so the fake response and the expectation
        # never share an object.
        return {
            "account": {"organization_id": 42},
            "ledger_entries": [],
            "total_count": 0,
            "page": 3,
            "limit": 25,
            "total_pages": 0,
        }

    class FakeAsyncClient:
        """Stand-in for the HTTP client that records every GET it receives."""

        def __init__(self, timeout):
            self.timeout = timeout

        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            return None

        async def get(self, url, params, headers):
            recorded_requests.append(("GET", url, params, headers))
            return _Response(200, ledger_payload())

    patches = [
        ("api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient),
        ("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas"),
        ("api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"),
    ]
    for target, replacement in patches:
        monkeypatch.setattr(target, replacement)

    client = MPSServiceKeyClient()
    ledger = await client.get_credit_ledger(
        organization_id=42,
        page=3,
        limit=25,
    )

    assert ledger == ledger_payload()

    expected_headers = {
        "Content-Type": "application/json",
        "X-Secret-Key": "mps-secret",
        "X-Organization-Id": "42",
    }
    expected_url = f"{client.base_url}/api/v1/billing/accounts/42/ledger"
    assert recorded_requests == [
        ("GET", expected_url, {"page": 3, "limit": 25}, expected_headers)
    ]
@pytest.mark.asyncio
async def test_report_platform_usage_uses_hosted_secret_auth(monkeypatch):
calls = []

View file

@ -31,3 +31,69 @@ async def test_get_mps_billing_account_status_uses_user_provider_id(monkeypatch)
organization_id=42,
created_by="provider-123",
)
@pytest.mark.asyncio
async def test_get_billing_credits_pages_v2_ledger(monkeypatch):
    """The v2 credits route forwards paging inputs and echoes ledger paging metadata."""
    ledger_account = {
        "id": 7,
        "organization_id": 42,
        "billing_mode": "v2",
        "cached_balance_credits": 250,
        "currency": "USD",
    }
    grant_entry = {
        "id": 99,
        "entry_type": "grant",
        "origin": "account_creation",
        "credits_delta": 250,
        "balance_after": 250,
        "created_at": "2026-06-12T00:00:00Z",
    }
    ledger_mock = AsyncMock(
        return_value={
            "account": ledger_account,
            "ledger_entries": [grant_entry],
            "total_debits_credits": 75,
            "total_count": 101,
            "page": 3,
            "limit": 25,
            "total_pages": 5,
        }
    )

    monkeypatch.setattr(organization_usage, "DEPLOYMENT_MODE", "saas")
    monkeypatch.setattr(
        organization_usage,
        "_get_mps_billing_account_status",
        AsyncMock(return_value={"billing_mode": "v2"}),
    )
    monkeypatch.setattr(
        organization_usage.mps_service_key_client,
        "get_credit_ledger",
        ledger_mock,
    )

    requesting_user = SimpleNamespace(
        provider_id="provider-123",
        selected_organization_id=42,
    )

    credits = await organization_usage.get_billing_credits(
        page=3,
        limit=25,
        user=requesting_user,
    )

    ledger_mock.assert_awaited_once_with(
        organization_id=42,
        page=3,
        limit=25,
        created_by="provider-123",
    )

    # The mocked ledger reports total_debits_credits=75 with paging metadata
    # of 101 entries, page 3, limit 25, 5 pages; the response must match.
    assert credits.billing_version == "v2"
    assert credits.total_credits_used == 75
    assert credits.total_count == 101
    assert credits.page == 3
    assert credits.limit == 25
    assert credits.total_pages == 5
    assert credits.ledger_entries[0].id == 99

View file

@ -1,7 +1,14 @@
"use client";
import { CircleDollarSign, CreditCard, RefreshCw } from "lucide-react";
import {
ChevronLeft,
ChevronRight,
CircleDollarSign,
CreditCard,
RefreshCw,
} from "lucide-react";
import Link from "next/link";
import { useRouter, useSearchParams } from "next/navigation";
import { useCallback, useEffect, useMemo, useState } from "react";
import { toast } from "sonner";
@ -23,6 +30,8 @@ import {
import { useAppConfig } from "@/context/AppConfigContext";
import { useAuth } from "@/lib/auth";
const LEDGER_PAGE_SIZE = 50;
const formatCredits = (value: number | null | undefined) => (
(value ?? 0).toLocaleString(undefined, {
maximumFractionDigits: 2,
@ -93,13 +102,26 @@ const getRunHref = (entry: MpsCreditLedgerEntryResponse) => {
return `/workflow/${entry.workflow_id}/run/${entry.workflow_run_id}`;
};
/**
 * Read the 1-based ledger page number from the `page` query parameter.
 *
 * Falls back to page 1 when the parameter is missing, empty, not numeric,
 * or not a positive number.
 */
const getPageFromSearchParams = (
    searchParams: { get: (name: string) => string | null },
) => {
    const rawPage = searchParams.get("page");
    if (!rawPage) {
        return 1;
    }

    const parsedPage = Number.parseInt(rawPage, 10);
    // parseInt yields NaN for non-numeric input; NaN, zero and negatives
    // all collapse to the first page.
    if (!Number.isFinite(parsedPage) || parsedPage <= 0) {
        return 1;
    }
    return parsedPage;
};
export default function BillingPage() {
const router = useRouter();
const searchParams = useSearchParams();
const auth = useAuth();
const { config } = useAppConfig();
const [credits, setCredits] = useState<MpsBillingCreditsResponse | null>(null);
const [loading, setLoading] = useState(true);
const [refreshing, setRefreshing] = useState(false);
const [purchasing, setPurchasing] = useState(false);
const [currentPage, setCurrentPage] = useState(
() => getPageFromSearchParams(searchParams),
);
const isBillingV2 = credits?.billing_version === "v2";
const canPurchaseCredits = isBillingV2 && config?.deploymentMode !== "oss";
@ -109,8 +131,14 @@ export default function BillingPage() {
const usagePercent = totalQuota > 0 ? Math.min(100, Math.round((usedCredits / totalQuota) * 100)) : 0;
const ledgerEntries = useMemo(() => credits?.ledger_entries ?? [], [credits?.ledger_entries]);
const ledgerPage = credits?.page ?? currentPage;
const ledgerTotalCount = credits?.total_count ?? ledgerEntries.length;
const ledgerTotalPages = credits?.total_pages ?? 0;
const fetchCredits = useCallback(async ({ silent = false }: { silent?: boolean } = {}) => {
const fetchCredits = useCallback(async (
page: number,
{ silent = false }: { silent?: boolean } = {},
) => {
if (auth.loading) {
return;
}
@ -128,7 +156,7 @@ export default function BillingPage() {
try {
const response = await getBillingCreditsApiV1OrganizationsBillingCreditsGet({
query: { limit: 50 },
query: { page, limit: LEDGER_PAGE_SIZE },
});
if (response.error) {
@ -146,11 +174,36 @@ export default function BillingPage() {
}, [auth.isAuthenticated, auth.loading]);
useEffect(() => {
fetchCredits();
}, [fetchCredits]);
const nextPage = getPageFromSearchParams(searchParams);
setCurrentPage((previousPage) => (
previousPage === nextPage ? previousPage : nextPage
));
}, [searchParams]);
useEffect(() => {
fetchCredits(currentPage);
}, [currentPage, fetchCredits]);
const handleRefresh = () => {
fetchCredits({ silent: true });
fetchCredits(currentPage, { silent: true });
};
const updateUrlPage = useCallback((page: number) => {
const newParams = new URLSearchParams(searchParams.toString());
if (page > 1) {
newParams.set("page", page.toString());
} else {
newParams.delete("page");
}
const queryString = newParams.toString();
router.push(queryString ? `/billing?${queryString}` : "/billing");
}, [router, searchParams]);
const handlePageChange = (page: number) => {
const nextPage = Math.max(1, page);
setCurrentPage(nextPage);
updateUrlPage(nextPage);
};
const handlePurchaseCredits = async () => {
@ -233,7 +286,7 @@ export default function BillingPage() {
</CardHeader>
<CardContent>
<p className="text-sm text-muted-foreground">
{isBillingV2 ? "Recent ledger debit total" : "Current allocation usage"}
{isBillingV2 ? "Total ledger debits" : "Current allocation usage"}
</p>
</CardContent>
</Card>
@ -315,6 +368,33 @@ export default function BillingPage() {
No ledger entries yet
</div>
)}
{ledgerTotalPages > 1 && (
<div className="flex items-center justify-between mt-6">
<p className="text-sm text-muted-foreground">
Page {ledgerPage} of {ledgerTotalPages} ({ledgerTotalCount} total entries)
</p>
<div className="flex gap-2">
<Button
variant="outline"
size="sm"
onClick={() => handlePageChange(ledgerPage - 1)}
disabled={ledgerPage <= 1 || loading || refreshing}
>
<ChevronLeft className="h-4 w-4" />
Previous
</Button>
<Button
variant="outline"
size="sm"
onClick={() => handlePageChange(ledgerPage + 1)}
disabled={ledgerPage >= ledgerTotalPages || loading || refreshing}
>
Next
<ChevronRight className="h-4 w-4" />
</Button>
</div>
</div>
)}
</CardContent>
</Card>
) : (

View file

@ -1256,7 +1256,7 @@ export const getMpsCreditsApiV1OrganizationsUsageMpsCreditsGet = <ThrowOnError e
/**
* Get Billing Credits
*
* Return legacy MPS credits or v2 billing ledger details for the org.
* Return legacy MPS credits or paginated v2 billing ledger details for the org.
*/
export const getBillingCreditsApiV1OrganizationsBillingCreditsGet = <ThrowOnError extends boolean = false>(options?: Options<GetBillingCreditsApiV1OrganizationsBillingCreditsGetData, ThrowOnError>) => (options?.client ?? client).get<GetBillingCreditsApiV1OrganizationsBillingCreditsGetResponses, GetBillingCreditsApiV1OrganizationsBillingCreditsGetErrors, ThrowOnError>({ url: '/api/v1/organizations/billing/credits', ...options });

View file

@ -3138,6 +3138,22 @@ export type MpsBillingCreditsResponse = {
* Ledger Entries
*/
ledger_entries?: Array<MpsCreditLedgerEntryResponse>;
/**
* Total Count
*/
total_count?: number;
/**
* Page
*/
page?: number;
/**
* Limit
*/
limit?: number;
/**
* Total Pages
*/
total_pages?: number;
};
/**
@ -11482,6 +11498,10 @@ export type GetBillingCreditsApiV1OrganizationsBillingCreditsGetData = {
};
path?: never;
query?: {
/**
* Page
*/
page?: number;
/**
* Limit
*/

View file

@ -17,7 +17,7 @@ import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
import { LANGUAGE_DISPLAY_NAMES } from "@/constants/languages";
type ModelMode = "dograh" | "byok";
type ModelMode = "realtime" | "dograh" | "byok";
interface DograhDefaults {
voices: string[];
@ -125,24 +125,35 @@ function effectiveConfigToLegacyShape(config: Record<string, unknown> | null): R
};
}
function emptyByokInitialConfig(): Record<string, unknown> {
function emptyByokInitialConfig(isRealtime: boolean): Record<string, unknown> {
return {
is_realtime: false,
is_realtime: isRealtime,
};
}
// The v2 editor surfaces realtime ("Speech to Speech") and pipeline (BYOK) as
// separate tabs, so each tab gets its own initial config. A tab is pre-filled
// only when the saved (or effective) configuration matches that tab's mode;
// otherwise it starts empty so the other tab's data does not leak across.
function getByokInitialConfig(
configuration: Record<string, unknown> | null,
effectiveConfiguration: Record<string, unknown> | null,
wantRealtime: boolean,
): Record<string, unknown> {
const byokConfiguration = byokConfigToLegacyShape(configuration);
if (byokConfiguration) return byokConfiguration;
const matchesTab = (config: Record<string, unknown> | null) =>
config ? Boolean(config.is_realtime) === wantRealtime : false;
if (configuration?.mode === "dograh" || isDograhEffectiveConfig(effectiveConfiguration)) {
return emptyByokInitialConfig();
const byokConfiguration = byokConfigToLegacyShape(configuration);
if (byokConfiguration) {
return matchesTab(byokConfiguration) ? byokConfiguration : emptyByokInitialConfig(wantRealtime);
}
return effectiveConfigToLegacyShape(effectiveConfiguration) || emptyByokInitialConfig();
if (configuration?.mode === "dograh" || isDograhEffectiveConfig(effectiveConfiguration)) {
return emptyByokInitialConfig(wantRealtime);
}
const effective = effectiveConfigToLegacyShape(effectiveConfiguration);
return matchesTab(effective) ? (effective as Record<string, unknown>) : emptyByokInitialConfig(wantRealtime);
}
function buildDograhState(
@ -185,10 +196,12 @@ function preferredMode(
configuration: Record<string, unknown> | null,
effectiveConfiguration: Record<string, unknown> | null,
): ModelMode {
if (configuration?.mode === "dograh" || configuration?.mode === "byok") {
return configuration.mode;
if (configuration?.mode === "dograh") return "dograh";
if (configuration?.mode === "byok") {
return asRecord(configuration.byok)?.mode === "realtime" ? "realtime" : "byok";
}
return isDograhEffectiveConfig(effectiveConfiguration) ? "dograh" : "byok";
if (isDograhEffectiveConfig(effectiveConfiguration)) return "dograh";
return Boolean(effectiveConfiguration?.is_realtime) ? "realtime" : "byok";
}
function hasRequiredApiKey(
@ -249,7 +262,8 @@ export function AIModelConfigurationV2Editor({
speed: defaults.dograh.defaults.speed,
language: defaults.dograh.defaults.language,
}));
const [byokInitialConfig, setByokInitialConfig] = useState<Record<string, unknown> | null>(null);
const [realtimeInitialConfig, setRealtimeInitialConfig] = useState<Record<string, unknown> | null>(null);
const [pipelineInitialConfig, setPipelineInitialConfig] = useState<Record<string, unknown> | null>(null);
const [isSavingDograh, setIsSavingDograh] = useState(false);
const [error, setError] = useState<string | null>(null);
@ -258,7 +272,8 @@ export function AIModelConfigurationV2Editor({
const rawEffectiveConfiguration = asRecord(effectiveConfiguration);
setMode(preferredMode(rawConfiguration, rawEffectiveConfiguration));
setDograh(buildDograhState(defaults, rawConfiguration, rawEffectiveConfiguration));
setByokInitialConfig(getByokInitialConfig(rawConfiguration, rawEffectiveConfiguration));
setRealtimeInitialConfig(getByokInitialConfig(rawConfiguration, rawEffectiveConfiguration, true));
setPipelineInitialConfig(getByokInitialConfig(rawConfiguration, rawEffectiveConfiguration, false));
}, [configuration, defaults, effectiveConfiguration]);
const saveDograhConfiguration = async () => {
@ -322,28 +337,30 @@ export function AIModelConfigurationV2Editor({
)}
<Tabs value={mode} onValueChange={(value) => setMode(value as ModelMode)} className="space-y-6">
<TabsList className="grid w-full grid-cols-2">
<TabsList className="grid w-full grid-cols-3">
<TabsTrigger value="realtime">Speech to Speech</TabsTrigger>
<TabsTrigger value="dograh">Dograh</TabsTrigger>
<TabsTrigger value="byok">BYOK</TabsTrigger>
</TabsList>
<TabsContent value="realtime" className="mt-0">
<p className="mb-4 text-sm text-muted-foreground">
A single speech-to-speech model handles the conversation in realtime (no separate transcriber or voice). An LLM is still required for variable extraction and QA.
</p>
<ServiceConfigurationForm
key={`realtime-${JSON.stringify(realtimeInitialConfig)}`}
mode="global"
forceRealtime
configurationDefaults={defaultsForByok}
initialConfig={realtimeInitialConfig}
submitLabel={submitLabel}
onSave={saveByokConfiguration}
/>
</TabsContent>
<TabsContent value="dograh" className="mt-0">
<div className="rounded-lg border p-5">
<div className="grid gap-4 sm:grid-cols-2">
<div className="space-y-2 sm:col-span-2">
<Label htmlFor="dograh-api-key">API Key</Label>
<div className="relative">
<KeyRound className="pointer-events-none absolute left-3 top-1/2 h-4 w-4 -translate-y-1/2 text-muted-foreground" />
<Input
id="dograh-api-key"
className="pl-9"
value={dograh.api_key}
onChange={(event) => setDograh({ ...dograh, api_key: event.target.value })}
placeholder="Enter API key"
/>
</div>
</div>
<div className="space-y-2">
<Label>Voice</Label>
<Select value={dograh.voice} onValueChange={(voice) => setDograh({ ...dograh, voice })}>
@ -394,6 +411,20 @@ export function AIModelConfigurationV2Editor({
</SelectContent>
</Select>
</div>
<div className="space-y-2 sm:col-span-2">
<Label htmlFor="dograh-api-key">API Key</Label>
<div className="relative">
<KeyRound className="pointer-events-none absolute left-3 top-1/2 h-4 w-4 -translate-y-1/2 text-muted-foreground" />
<Input
id="dograh-api-key"
className="pl-9"
value={dograh.api_key}
onChange={(event) => setDograh({ ...dograh, api_key: event.target.value })}
placeholder="Enter API key"
/>
</div>
</div>
</div>
<Button type="button" className="mt-6 w-full" onClick={saveDograhConfiguration} disabled={isSavingDograh}>
@ -405,10 +436,11 @@ export function AIModelConfigurationV2Editor({
<TabsContent value="byok" className="mt-0">
<ServiceConfigurationForm
key={JSON.stringify(byokInitialConfig)}
key={`byok-${JSON.stringify(pipelineInitialConfig)}`}
mode="global"
forceRealtime={false}
configurationDefaults={defaultsForByok}
initialConfig={byokInitialConfig}
initialConfig={pipelineInitialConfig}
submitLabel={submitLabel}
onSave={saveByokConfiguration}
/>

View file

@ -101,6 +101,13 @@ export interface ServiceConfigurationFormProps {
submitLabel?: string;
configurationDefaults?: ServiceConfigurationDefaults | null;
initialConfig?: Record<string, unknown> | null;
/**
* When set, locks the realtime/pipeline mode to this value and hides the
* in-form toggle. The v2 editor uses this to surface realtime
* ("Speech to Speech") and pipeline (BYOK) as separate top-level tabs.
* Leave undefined to keep the user-controllable toggle (legacy + overrides).
*/
forceRealtime?: boolean;
}
function getProviderDisplayName(
@ -130,10 +137,11 @@ export function ServiceConfigurationForm({
submitLabel,
configurationDefaults,
initialConfig,
forceRealtime,
}: ServiceConfigurationFormProps) {
const [apiError, setApiError] = useState<string | null>(null);
const [isSaving, setIsSaving] = useState(false);
const [isRealtime, setIsRealtime] = useState(false);
const [isRealtime, setIsRealtime] = useState(forceRealtime ?? false);
const { userConfig } = useUserConfig();
const [schemas, setSchemas] = useState<Record<ServiceSegment, Record<string, ProviderSchema>>>({
llm: {},
@ -227,9 +235,9 @@ export function ServiceConfigurationForm({
realtime: realtimeSchemas,
});
// Restore realtime toggle
// Restore realtime toggle (skip when the parent locks the mode)
const configData = configSource as Record<string, unknown> | null;
if (configData?.is_realtime) {
if (forceRealtime === undefined && configData?.is_realtime) {
setIsRealtime(true);
}
@ -867,22 +875,24 @@ export function ServiceConfigurationForm({
return (
<form onSubmit={handleSubmit(onSubmit)}>
{/* Realtime toggle */}
<div className="flex items-center justify-between mb-4 p-4 border rounded-lg">
<div>
<Label htmlFor="realtime-toggle" className="text-sm font-medium">
Realtime Mode
</Label>
<p className="text-xs text-muted-foreground mt-0.5">
Uses a single speech-to-speech model (no separate STT/TTS). An LLM is still required for variable extraction and QA.
</p>
{/* Realtime toggle — hidden when the parent locks the mode (v2 tabs) */}
{forceRealtime === undefined && (
<div className="flex items-center justify-between mb-4 p-4 border rounded-lg">
<div>
<Label htmlFor="realtime-toggle" className="text-sm font-medium">
Realtime Mode
</Label>
<p className="text-xs text-muted-foreground mt-0.5">
Uses a single speech-to-speech model (no separate STT/TTS). An LLM is still required for variable extraction and QA.
</p>
</div>
<Switch
id="realtime-toggle"
checked={isRealtime}
onCheckedChange={setIsRealtime}
/>
</div>
<Switch
id="realtime-toggle"
checked={isRealtime}
onCheckedChange={setIsRealtime}
/>
</div>
)}
<Card>
<CardContent className="pt-6">