mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-10 08:05:22 +02:00
fix: fix review comments
This commit is contained in:
parent
dfee942f9a
commit
c7e0d06a2b
13 changed files with 477 additions and 253 deletions
|
|
@ -419,8 +419,9 @@ class TestStartGreeting:
|
|||
"""When a node has no greeting, the engine should queue initial LLM generation."""
|
||||
dto = ReactFlowDTO(
|
||||
nodes=[
|
||||
StartCallRFNode(
|
||||
RFNodeDTO(
|
||||
id="start",
|
||||
type="startCall",
|
||||
position=Position(x=0, y=0),
|
||||
data=StartCallNodeData(
|
||||
name="Start",
|
||||
|
|
@ -430,8 +431,9 @@ class TestStartGreeting:
|
|||
extraction_enabled=False,
|
||||
),
|
||||
),
|
||||
EndCallRFNode(
|
||||
RFNodeDTO(
|
||||
id="end",
|
||||
type="endCall",
|
||||
position=Position(x=0, y=200),
|
||||
data=EndCallNodeData(
|
||||
name="End",
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import pytest
|
|||
|
||||
from api.services.pricing import workflow_run_cost as workflow_run_cost_mod
|
||||
from api.services.pricing.workflow_run_cost import (
|
||||
apply_usage_delta_to_organization,
|
||||
build_workflow_run_cost_info,
|
||||
calculate_workflow_run_cost,
|
||||
)
|
||||
|
|
@ -85,3 +86,96 @@ async def test_calculate_workflow_run_cost_keeps_org_usage_side_effect_in_wrappe
|
|||
assert saved_kwargs["run_id"] == workflow_run.id
|
||||
assert "cost_breakdown" in saved_kwargs["cost_info"]
|
||||
update_usage.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_usage_delta_to_organization_uses_incremental_costs(
|
||||
monkeypatch,
|
||||
):
|
||||
workflow_run = _make_workflow_run()
|
||||
workflow_run.cost_info = {"call_id": "preserve-me"}
|
||||
|
||||
usage_delta_one = {
|
||||
"llm": {
|
||||
"OpenAILLMService#0|||gpt-4.1-mini": {
|
||||
"prompt_tokens": 1_000,
|
||||
"completion_tokens": 100,
|
||||
"total_tokens": 1_100,
|
||||
"cache_read_input_tokens": 0,
|
||||
"cache_creation_input_tokens": 0,
|
||||
}
|
||||
},
|
||||
"tts": {},
|
||||
"stt": {},
|
||||
"call_duration_seconds": 3,
|
||||
}
|
||||
usage_delta_two = {
|
||||
"llm": {
|
||||
"OpenAILLMService#0|||gpt-4.1-mini": {
|
||||
"prompt_tokens": 2_000,
|
||||
"completion_tokens": 50,
|
||||
"total_tokens": 2_050,
|
||||
"cache_read_input_tokens": 0,
|
||||
"cache_creation_input_tokens": 0,
|
||||
}
|
||||
},
|
||||
"tts": {},
|
||||
"stt": {},
|
||||
"call_duration_seconds": 4,
|
||||
}
|
||||
merged_usage = {
|
||||
"llm": {
|
||||
"OpenAILLMService#0|||gpt-4.1-mini": {
|
||||
"prompt_tokens": 3_000,
|
||||
"completion_tokens": 150,
|
||||
"total_tokens": 3_150,
|
||||
"cache_read_input_tokens": 0,
|
||||
"cache_creation_input_tokens": 0,
|
||||
}
|
||||
},
|
||||
"tts": {},
|
||||
"stt": {},
|
||||
"call_duration_seconds": 7,
|
||||
}
|
||||
|
||||
get_org = AsyncMock(return_value=SimpleNamespace(id=42, price_per_second_usd=1.5))
|
||||
update_usage = AsyncMock()
|
||||
|
||||
monkeypatch.setattr(
|
||||
workflow_run_cost_mod.db_client, "get_organization_by_id", get_org
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_run_cost_mod.db_client, "update_usage_after_run", update_usage
|
||||
)
|
||||
|
||||
first_delta = await apply_usage_delta_to_organization(workflow_run, usage_delta_one)
|
||||
second_delta = await apply_usage_delta_to_organization(
|
||||
workflow_run, usage_delta_two
|
||||
)
|
||||
total_workflow_run = SimpleNamespace(**workflow_run.__dict__)
|
||||
total_workflow_run.usage_info = merged_usage
|
||||
total_cost = await build_workflow_run_cost_info(total_workflow_run)
|
||||
|
||||
assert first_delta is not None
|
||||
assert second_delta is not None
|
||||
assert total_cost is not None
|
||||
assert update_usage.await_count == 2
|
||||
assert update_usage.await_args_list[0].args == (
|
||||
42,
|
||||
first_delta["dograh_token_usage"],
|
||||
3.0,
|
||||
first_delta["charge_usd"],
|
||||
)
|
||||
assert update_usage.await_args_list[1].args == (
|
||||
42,
|
||||
second_delta["dograh_token_usage"],
|
||||
4.0,
|
||||
second_delta["charge_usd"],
|
||||
)
|
||||
assert (
|
||||
first_delta["dograh_token_usage"] + second_delta["dograh_token_usage"]
|
||||
) == pytest.approx(total_cost["dograh_token_usage"])
|
||||
assert (
|
||||
first_delta["charge_usd"] + second_delta["charge_usd"]
|
||||
== total_cost["charge_usd"]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
|
@ -968,3 +969,226 @@ async def test_text_chat_session_is_not_accessible_from_another_org(
|
|||
f"/api/v1/workflow/{workflow.id}/text-chat/sessions/{created['workflow_run_id']}"
|
||||
)
|
||||
assert get_response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_chat_session_creation_requires_selected_org_scope(
|
||||
db_session,
|
||||
async_session,
|
||||
test_client_factory,
|
||||
):
|
||||
workflow_definition = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "start",
|
||||
"type": "startCall",
|
||||
"position": {"x": 0, "y": 0},
|
||||
"data": {
|
||||
"name": "Start",
|
||||
"prompt": "You are a helpful assistant.",
|
||||
"is_start": True,
|
||||
"allow_interrupt": False,
|
||||
"add_global_prompt": False,
|
||||
},
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
|
||||
org_a = OrganizationModel(provider_id="textchat-scope-a")
|
||||
org_b = OrganizationModel(provider_id="textchat-scope-b")
|
||||
async_session.add_all([org_a, org_b])
|
||||
await async_session.flush()
|
||||
|
||||
user = UserModel(
|
||||
provider_id="textchat-scope-user",
|
||||
selected_organization_id=org_a.id,
|
||||
)
|
||||
async_session.add(user)
|
||||
await async_session.flush()
|
||||
|
||||
await db_session.update_user_configuration(
|
||||
user_id=user.id,
|
||||
configuration=UserConfiguration.model_validate(USER_CONFIGURATION),
|
||||
)
|
||||
|
||||
workflow = await db_session.create_workflow(
|
||||
name="Cross-org workflow",
|
||||
workflow_definition=workflow_definition,
|
||||
user_id=user.id,
|
||||
organization_id=org_b.id,
|
||||
)
|
||||
|
||||
llm = MockLLMService(
|
||||
mock_steps=[MockLLMService.create_text_chunks("Should never run.")],
|
||||
chunk_delay=0.001,
|
||||
)
|
||||
|
||||
async with test_client_factory(user) as client:
|
||||
with (
|
||||
patch(
|
||||
"api.services.workflow.text_chat_runner.create_llm_service",
|
||||
return_value=llm,
|
||||
),
|
||||
patch(
|
||||
"api.services.workflow.text_chat_runner.db_client.has_active_recordings",
|
||||
new=AsyncMock(return_value=False),
|
||||
),
|
||||
):
|
||||
create_response = await client.post(
|
||||
f"/api/v1/workflow/{workflow.id}/text-chat/sessions",
|
||||
json={},
|
||||
)
|
||||
|
||||
assert create_response.status_code == 404
|
||||
_, total_count = await db_session.get_workflow_runs_by_workflow_id(
|
||||
workflow.id,
|
||||
organization_id=org_b.id,
|
||||
)
|
||||
assert total_count == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_chat_session_creation_rejects_quota_before_creating_run(
|
||||
db_session,
|
||||
async_session,
|
||||
test_client_factory,
|
||||
):
|
||||
workflow_definition = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "start",
|
||||
"type": "startCall",
|
||||
"position": {"x": 0, "y": 0},
|
||||
"data": {
|
||||
"name": "Start",
|
||||
"prompt": "You are a helpful assistant.",
|
||||
"is_start": True,
|
||||
"allow_interrupt": False,
|
||||
"add_global_prompt": False,
|
||||
},
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
|
||||
user, workflow = await _create_user_and_workflow(
|
||||
db_session,
|
||||
async_session,
|
||||
workflow_definition=workflow_definition,
|
||||
suffix="quota-create",
|
||||
)
|
||||
|
||||
async with test_client_factory(user) as client:
|
||||
with patch(
|
||||
"api.routes.workflow_text_chat.check_dograh_quota",
|
||||
new=AsyncMock(
|
||||
return_value=SimpleNamespace(
|
||||
has_quota=False,
|
||||
error_message="Quota exceeded",
|
||||
)
|
||||
),
|
||||
):
|
||||
create_response = await client.post(
|
||||
f"/api/v1/workflow/{workflow.id}/text-chat/sessions",
|
||||
json={},
|
||||
)
|
||||
|
||||
assert create_response.status_code == 402
|
||||
assert create_response.json()["detail"] == "Quota exceeded"
|
||||
_, total_count = await db_session.get_workflow_runs_by_workflow_id(
|
||||
workflow.id,
|
||||
organization_id=workflow.organization_id,
|
||||
)
|
||||
assert total_count == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_chat_append_rejects_quota_without_mutating_session(
|
||||
db_session,
|
||||
async_session,
|
||||
test_client_factory,
|
||||
):
|
||||
workflow_definition = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "start",
|
||||
"type": "startCall",
|
||||
"position": {"x": 0, "y": 0},
|
||||
"data": {
|
||||
"name": "Start",
|
||||
"prompt": "You are a helpful assistant.",
|
||||
"is_start": True,
|
||||
"allow_interrupt": False,
|
||||
"add_global_prompt": False,
|
||||
},
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
|
||||
user, workflow = await _create_user_and_workflow(
|
||||
db_session,
|
||||
async_session,
|
||||
workflow_definition=workflow_definition,
|
||||
suffix="quota-append",
|
||||
)
|
||||
|
||||
llm = MockLLMService(
|
||||
mock_steps=[
|
||||
MockLLMService.create_text_chunks("Hello from the workflow tester.")
|
||||
],
|
||||
chunk_delay=0.001,
|
||||
)
|
||||
|
||||
async with test_client_factory(user) as client:
|
||||
with (
|
||||
patch(
|
||||
"api.routes.workflow_text_chat.check_dograh_quota",
|
||||
new=AsyncMock(
|
||||
side_effect=[
|
||||
SimpleNamespace(has_quota=True, error_message=""),
|
||||
SimpleNamespace(
|
||||
has_quota=False,
|
||||
error_message="Quota exceeded on append",
|
||||
),
|
||||
]
|
||||
),
|
||||
),
|
||||
patch(
|
||||
"api.services.workflow.text_chat_runner.create_llm_service",
|
||||
return_value=llm,
|
||||
),
|
||||
patch(
|
||||
"api.services.workflow.text_chat_runner.db_client.has_active_recordings",
|
||||
new=AsyncMock(return_value=False),
|
||||
),
|
||||
):
|
||||
create_response = await client.post(
|
||||
f"/api/v1/workflow/{workflow.id}/text-chat/sessions",
|
||||
json={},
|
||||
)
|
||||
assert create_response.status_code == 200
|
||||
created = create_response.json()
|
||||
|
||||
append_response = await client.post(
|
||||
f"/api/v1/workflow/{workflow.id}/text-chat/sessions/{created['workflow_run_id']}/messages",
|
||||
json={
|
||||
"text": "This should be rejected",
|
||||
"expected_revision": created["revision"],
|
||||
},
|
||||
)
|
||||
assert append_response.status_code == 402
|
||||
|
||||
session_response = await client.get(
|
||||
f"/api/v1/workflow/{workflow.id}/text-chat/sessions/{created['workflow_run_id']}"
|
||||
)
|
||||
assert session_response.status_code == 200
|
||||
|
||||
session_payload = session_response.json()
|
||||
assert append_response.json()["detail"] == "Quota exceeded on append"
|
||||
assert session_payload["revision"] == created["revision"]
|
||||
assert session_payload["session_data"]["turns"] == created["session_data"]["turns"]
|
||||
assert (
|
||||
session_payload["session_data"]["status"] == created["session_data"]["status"]
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue