mirror of https://github.com/MODSetter/SurfSense.git, synced 2026-05-12 17:22:38 +02:00
test(backend): add deterministic chat fake LLM
parent 8536bac29a, commit 55c33ca1c8
3 changed files with 170 additions and 7 deletions
surfsense_backend/tests/e2e/fakes/chat_llm.py (new file, 99 lines)
@@ -0,0 +1,99 @@
from __future__ import annotations

import json
from collections.abc import AsyncIterator
from typing import Any, Self

from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult

DRIVE_CANARY_TOKEN = "SURFSENSE_E2E_CANARY_TOKEN_DRIVE_001"
NO_RELEVANT_CONTENT_SENTINEL = "No relevant indexed content found."
NO_RELEVANT_CONTENT_QUERY = "E2E_NO_RELEVANT_CONTENT_SMOKE"


def _content_to_text(content: Any) -> str:
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        return " ".join(_content_to_text(item) for item in content)
    if isinstance(content, dict):
        text = content.get("text") or content.get("content")
        if isinstance(text, str):
            return text
        return json.dumps(content, sort_keys=True)
    if content is None:
        return ""
    return str(content)


def _messages_to_text(messages: list[BaseMessage]) -> str:
    return "\n".join(_content_to_text(message.content) for message in messages)


class FakeChatLLM(BaseChatModel):
    @property
    def _llm_type(self) -> str:
        return "e2e-fake-chat"

    def bind_tools(self, tools: Any, **kwargs: Any) -> Self:
        return self

    def _response_for(self, messages: list[BaseMessage]) -> str:
        latest_human = next(
            (
                _content_to_text(message.content)
                for message in reversed(messages)
                if message.type == "human"
            ),
            "",
        )
        if NO_RELEVANT_CONTENT_QUERY in latest_human:
            return NO_RELEVANT_CONTENT_SENTINEL

        prompt_text = _messages_to_text(messages)
        if (
            "e2e-canary" in prompt_text
            or "fake-file-canary" in prompt_text
            or DRIVE_CANARY_TOKEN in prompt_text
        ):
            return f"Drive content found: {DRIVE_CANARY_TOKEN}"
        return NO_RELEVANT_CONTENT_SENTINEL

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        del stop, run_manager, kwargs
        message = AIMessage(content=self._response_for(messages), tool_calls=[])
        return ChatResult(generations=[ChatGeneration(message=message)])

    async def _astream(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: AsyncCallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        del stop, run_manager, kwargs
        yield ChatGenerationChunk(
            message=AIMessageChunk(content=self._response_for(messages))
        )


def fake_create_chat_litellm_from_agent_config(*args: Any, **kwargs: Any) -> FakeChatLLM:
    del args, kwargs
    return FakeChatLLM()


def fake_create_chat_litellm_from_config(*args: Any, **kwargs: Any) -> FakeChatLLM:
    del args, kwargs
    return FakeChatLLM()
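A quick way to see the determinism this fake provides (an illustrative sketch, not part of the commit; it assumes the file above is importable as tests.e2e.fakes.chat_llm):

# Illustrative only: exercising FakeChatLLM directly via LangChain's invoke().
from langchain_core.messages import HumanMessage

from tests.e2e.fakes.chat_llm import (
    DRIVE_CANARY_TOKEN,
    NO_RELEVANT_CONTENT_SENTINEL,
    FakeChatLLM,
)

llm = FakeChatLLM()

# A prompt containing any canary marker always yields the Drive response...
hit = llm.invoke([HumanMessage(content="summarise the e2e-canary document")])
assert DRIVE_CANARY_TOKEN in hit.content

# ...and anything else falls through to the no-content sentinel.
miss = llm.invoke([HumanMessage(content="an unrelated question")])
assert miss.content == NO_RELEVANT_CONTENT_SENTINEL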
@@ -80,13 +80,17 @@ from unittest.mock import patch  # noqa: E402

 from app.app import app  # noqa: E402
 from tests.e2e.fakes import embeddings as _fake_embeddings  # noqa: E402
+from tests.e2e.fakes.chat_llm import (  # noqa: E402
+    fake_create_chat_litellm_from_agent_config,
+    fake_create_chat_litellm_from_config,
+)
 from tests.e2e.fakes.llm import fake_get_user_long_context_llm  # noqa: E402

 _active_patches: list = []


 def _patch_llm_bindings() -> None:
-    """Replace get_user_long_context_llm at every known binding site."""
+    """Replace LLM factories at every known binding site."""
     targets = [
         "app.services.llm_service.get_user_long_context_llm",
         "app.tasks.connector_indexers.google_drive_indexer.get_user_long_context_llm",
@@ -111,6 +115,33 @@ def _patch_llm_bindings() -> None:
             exc,
         )
+
+    chat_targets = [
+        (
+            "app.agents.new_chat.llm_config.create_chat_litellm_from_agent_config",
+            fake_create_chat_litellm_from_agent_config,
+        ),
+        (
+            "app.agents.new_chat.llm_config.create_chat_litellm_from_config",
+            fake_create_chat_litellm_from_config,
+        ),
+        (
+            "app.tasks.chat.stream_new_chat.create_chat_litellm_from_agent_config",
+            fake_create_chat_litellm_from_agent_config,
+        ),
+        (
+            "app.tasks.chat.stream_new_chat.create_chat_litellm_from_config",
+            fake_create_chat_litellm_from_config,
+        ),
+    ]
+    for target, replacement in chat_targets:
+        try:
+            p = patch(target, replacement)
+            p.start()
+            _active_patches.append(p)
+            logger.info("[fake-chat-llm] patched %s", target)
+        except (ModuleNotFoundError, AttributeError) as exc:
+            logger.warning("[fake-chat-llm] could not patch %s: %s.", target, exc)


 _patch_llm_bindings()
 _fake_embeddings.install(_active_patches)
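Why chat_targets lists each consuming module rather than just the defining one: `from m import f` copies the function binding into the importer's namespace, so patching only the source module leaves previously imported copies untouched. A minimal, self-contained sketch of that pitfall (hypothetical module names, not part of the commit):

# Illustrative only: why every binding site needs its own patch.
import sys
import types
from unittest.mock import patch

# Simulate `from factory import create_llm` executed at import time.
factory = types.ModuleType("factory")
factory.create_llm = lambda: "real"
sys.modules["factory"] = factory

consumer = types.ModuleType("consumer")
consumer.create_llm = factory.create_llm  # the copied binding
sys.modules["consumer"] = consumer

with patch("factory.create_llm", lambda: "fake"):
    print(consumer.create_llm())  # prints "real", the copy is unaffected

with patch("consumer.create_llm", lambda: "fake"):
    print(consumer.create_llm())  # prints "fake", patching the binding site works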
@@ -129,8 +160,9 @@ app.add_middleware(ScenarioMiddleware)
 # 6) Start uvicorn, mirroring main.py's behaviour.
 # ---------------------------------------------------------------------------

-import asyncio
-import uvicorn
+import asyncio  # noqa: E402
+
+import uvicorn  # noqa: E402


 def _main() -> None:
@@ -58,16 +58,17 @@ logger.warning(
 # 3) Import the production celery_app. All task modules load here.
 # ---------------------------------------------------------------------------

-from app.celery_app import celery_app  # noqa: E402
-
-
 # ---------------------------------------------------------------------------
 # 4) Patch LLM + embedding bindings inside the worker process.
 # ---------------------------------------------------------------------------

 from unittest.mock import patch  # noqa: E402

+from app.celery_app import celery_app  # noqa: E402
 from tests.e2e.fakes import embeddings as _fake_embeddings  # noqa: E402
+from tests.e2e.fakes.chat_llm import (  # noqa: E402
+    fake_create_chat_litellm_from_agent_config,
+    fake_create_chat_litellm_from_config,
+)
 from tests.e2e.fakes.llm import fake_get_user_long_context_llm  # noqa: E402

 _active_patches: list = []
@@ -94,6 +95,37 @@ def _patch_llm_bindings() -> None:
             exc,
         )
+
+    chat_targets = [
+        (
+            "app.agents.new_chat.llm_config.create_chat_litellm_from_agent_config",
+            fake_create_chat_litellm_from_agent_config,
+        ),
+        (
+            "app.agents.new_chat.llm_config.create_chat_litellm_from_config",
+            fake_create_chat_litellm_from_config,
+        ),
+        (
+            "app.tasks.chat.stream_new_chat.create_chat_litellm_from_agent_config",
+            fake_create_chat_litellm_from_agent_config,
+        ),
+        (
+            "app.tasks.chat.stream_new_chat.create_chat_litellm_from_config",
+            fake_create_chat_litellm_from_config,
+        ),
+    ]
+    for target, replacement in chat_targets:
+        try:
+            p = patch(target, replacement)
+            p.start()
+            _active_patches.append(p)
+            logger.info("[fake-chat-llm] patched %s in celery worker", target)
+        except (ModuleNotFoundError, AttributeError) as exc:
+            logger.warning(
+                "[fake-chat-llm] could not patch %s in celery worker: %s.",
+                target,
+                exc,
+            )


 _patch_llm_bindings()
 _fake_embeddings.install(_active_patches)
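With both the API process and the celery worker patched, every chat code path resolves to FakeChatLLM, including streaming: _astream yields a single deterministic chunk, so a streamed response matches the non-streamed one exactly. A small async sketch (illustrative, not part of the commit):

# Illustrative only: the streaming path collapses to one deterministic chunk.
import asyncio

from langchain_core.messages import HumanMessage

from tests.e2e.fakes.chat_llm import FakeChatLLM


async def main() -> None:
    llm = FakeChatLLM()
    chunks = [c async for c in llm.astream([HumanMessage(content="e2e-canary")])]
    streamed = "".join(str(c.content) for c in chunks)
    # The joined stream equals the blocking response for the same prompt.
    assert streamed == llm.invoke([HumanMessage(content="e2e-canary")]).content


asyncio.run(main())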