mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
91 lines
2.8 KiB
Python
91 lines
2.8 KiB
Python
from unittest.mock import AsyncMock
|
|
|
|
import pytest
|
|
|
|
import api.services.workflow.text_chat_session_service as text_chat_session_service
|
|
from api.db.models import WorkflowRunTextSessionModel
|
|
from api.services.workflow.text_chat_session_service import (
|
|
TextChatSessionExecutionError,
|
|
TextChatTurnNotFoundError,
|
|
_reload_text_chat_session,
|
|
build_pending_text_chat_turn,
|
|
truncate_text_chat_future_turns,
|
|
validate_text_chat_turn_cursor,
|
|
)
|
|
|
|
|
|
def test_build_pending_text_chat_turn_sets_pending_shape():
|
|
turn = build_pending_text_chat_turn(user_text="Hello")
|
|
|
|
assert turn["id"].startswith("turn_")
|
|
assert turn["status"] == "pending"
|
|
assert turn["user_message"]["text"] == "Hello"
|
|
assert turn["assistant_message"] is None
|
|
assert turn["events"] == []
|
|
assert turn["usage"] == {}
|
|
|
|
|
|
def test_truncate_text_chat_future_turns_moves_rewound_branch_to_discarded_future():
|
|
session_data = {
|
|
"cursor_turn_id": "turn-2",
|
|
"turns": [
|
|
{"id": "turn-1"},
|
|
{"id": "turn-2"},
|
|
{"id": "turn-3"},
|
|
],
|
|
"discarded_future": [],
|
|
}
|
|
|
|
active_turns, discarded_future = truncate_text_chat_future_turns(session_data)
|
|
|
|
assert active_turns == [{"id": "turn-1"}, {"id": "turn-2"}]
|
|
assert discarded_future[0]["rewound_from_turn_id"] == "turn-2"
|
|
assert discarded_future[0]["turns"] == [{"id": "turn-3"}]
|
|
|
|
|
|
def test_validate_text_chat_turn_cursor_raises_for_missing_turn():
|
|
with pytest.raises(TextChatTurnNotFoundError):
|
|
validate_text_chat_turn_cursor(
|
|
{"turns": [{"id": "turn-1"}]},
|
|
"turn-404",
|
|
)
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_reload_text_chat_session_uses_run_id_to_resolve_organization(
|
|
monkeypatch,
|
|
):
|
|
reloaded_session = WorkflowRunTextSessionModel(workflow_run_id=123)
|
|
get_org_id = AsyncMock(return_value=77)
|
|
get_text_session = AsyncMock(return_value=reloaded_session)
|
|
|
|
monkeypatch.setattr(
|
|
text_chat_session_service.db_client,
|
|
"get_organization_id_by_workflow_run_id",
|
|
get_org_id,
|
|
)
|
|
monkeypatch.setattr(
|
|
text_chat_session_service.db_client,
|
|
"get_workflow_run_text_session",
|
|
get_text_session,
|
|
)
|
|
|
|
result = await _reload_text_chat_session(123)
|
|
|
|
assert result is reloaded_session
|
|
get_org_id.assert_awaited_once_with(123)
|
|
get_text_session.assert_awaited_once_with(123, organization_id=77)
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_reload_text_chat_session_raises_when_run_organization_is_missing(
|
|
monkeypatch,
|
|
):
|
|
monkeypatch.setattr(
|
|
text_chat_session_service.db_client,
|
|
"get_organization_id_by_workflow_run_id",
|
|
AsyncMock(return_value=None),
|
|
)
|
|
|
|
with pytest.raises(TextChatSessionExecutionError, match="organization not found"):
|
|
await _reload_text_chat_session(123)
|