diff --git a/surfsense_backend/tests/e2e/fakes/chat_llm.py b/surfsense_backend/tests/e2e/fakes/chat_llm.py
new file mode 100644
index 000000000..eef4e61ad
--- /dev/null
+++ b/surfsense_backend/tests/e2e/fakes/chat_llm.py
@@ -0,0 +1,99 @@
+from __future__ import annotations
+
+import json
+from collections.abc import AsyncIterator
+from typing import Any, Self
+
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
+from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
+
+DRIVE_CANARY_TOKEN = "SURFSENSE_E2E_CANARY_TOKEN_DRIVE_001"
+NO_RELEVANT_CONTENT_SENTINEL = "No relevant indexed content found."
+NO_RELEVANT_CONTENT_QUERY = "E2E_NO_RELEVANT_CONTENT_SMOKE"
+
+
+def _content_to_text(content: Any) -> str:
+    if isinstance(content, str):
+        return content
+    if isinstance(content, list):
+        return " ".join(_content_to_text(item) for item in content)
+    if isinstance(content, dict):
+        text = content.get("text") or content.get("content")
+        if isinstance(text, str):
+            return text
+        return json.dumps(content, sort_keys=True)
+    if content is None:
+        return ""
+    return str(content)
+
+
+def _messages_to_text(messages: list[BaseMessage]) -> str:
+    return "\n".join(_content_to_text(message.content) for message in messages)
+
+
+class FakeChatLLM(BaseChatModel):
+    @property
+    def _llm_type(self) -> str:
+        return "e2e-fake-chat"
+
+    def bind_tools(self, tools: Any, **kwargs: Any) -> Self:
+        return self
+
+    def _response_for(self, messages: list[BaseMessage]) -> str:
+        latest_human = next(
+            (
+                _content_to_text(message.content)
+                for message in reversed(messages)
+                if message.type == "human"
+            ),
+            "",
+        )
+        if NO_RELEVANT_CONTENT_QUERY in latest_human:
+            return NO_RELEVANT_CONTENT_SENTINEL
+
+        prompt_text = _messages_to_text(messages)
+        if (
+            "e2e-canary" in prompt_text
+            or "fake-file-canary" in prompt_text
+            or DRIVE_CANARY_TOKEN in prompt_text
+        ):
+            return f"Drive content found: {DRIVE_CANARY_TOKEN}"
+        return NO_RELEVANT_CONTENT_SENTINEL
+
+    def _generate(
+        self,
+        messages: list[BaseMessage],
+        stop: list[str] | None = None,
+        run_manager: CallbackManagerForLLMRun | None = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        del stop, run_manager, kwargs
+        message = AIMessage(content=self._response_for(messages), tool_calls=[])
+        return ChatResult(generations=[ChatGeneration(message=message)])
+
+    async def _astream(
+        self,
+        messages: list[BaseMessage],
+        stop: list[str] | None = None,
+        run_manager: AsyncCallbackManagerForLLMRun | None = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        del stop, run_manager, kwargs
+        yield ChatGenerationChunk(
+            message=AIMessageChunk(content=self._response_for(messages))
+        )
+
+
+def fake_create_chat_litellm_from_agent_config(*args: Any, **kwargs: Any) -> FakeChatLLM:
+    del args, kwargs
+    return FakeChatLLM()
+
+
+def fake_create_chat_litellm_from_config(*args: Any, **kwargs: Any) -> FakeChatLLM:
+    del args, kwargs
+    return FakeChatLLM()
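
For orientation, the fake's routing can be exercised directly. A minimal sketch, assuming surfsense_backend is on sys.path so tests.e2e.fakes.chat_llm is importable:

    from langchain_core.messages import HumanMessage

    from tests.e2e.fakes.chat_llm import (
        DRIVE_CANARY_TOKEN,
        FakeChatLLM,
        NO_RELEVANT_CONTENT_QUERY,
        NO_RELEVANT_CONTENT_SENTINEL,
    )

    llm = FakeChatLLM()

    # A prompt mentioning a canary marker yields the deterministic drive answer.
    reply = llm.invoke([HumanMessage(content="summarize the e2e-canary document")])
    assert DRIVE_CANARY_TOKEN in reply.content

    # The dedicated smoke query short-circuits to the sentinel response.
    reply = llm.invoke([HumanMessage(content=NO_RELEVANT_CONTENT_QUERY)])
    assert reply.content == NO_RELEVANT_CONTENT_SENTINEL

Because _generate always returns a message with empty tool_calls, agent loops that branch on tool invocations terminate after a single round against this fake.
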
diff --git a/surfsense_backend/tests/e2e/run_backend.py b/surfsense_backend/tests/e2e/run_backend.py
index ba2737d66..5a0e109ff 100644
--- a/surfsense_backend/tests/e2e/run_backend.py
+++ b/surfsense_backend/tests/e2e/run_backend.py
@@ -80,13 +80,17 @@ from unittest.mock import patch  # noqa: E402
 
 from app.app import app  # noqa: E402
 from tests.e2e.fakes import embeddings as _fake_embeddings  # noqa: E402
+from tests.e2e.fakes.chat_llm import (  # noqa: E402
+    fake_create_chat_litellm_from_agent_config,
+    fake_create_chat_litellm_from_config,
+)
 from tests.e2e.fakes.llm import fake_get_user_long_context_llm  # noqa: E402
 
 _active_patches: list = []
 
 
 def _patch_llm_bindings() -> None:
-    """Replace get_user_long_context_llm at every known binding site."""
+    """Replace LLM factories at every known binding site."""
     targets = [
         "app.services.llm_service.get_user_long_context_llm",
         "app.tasks.connector_indexers.google_drive_indexer.get_user_long_context_llm",
@@ -111,6 +115,33 @@ def _patch_llm_bindings() -> None:
                 exc,
             )
 
+    chat_targets = [
+        (
+            "app.agents.new_chat.llm_config.create_chat_litellm_from_agent_config",
+            fake_create_chat_litellm_from_agent_config,
+        ),
+        (
+            "app.agents.new_chat.llm_config.create_chat_litellm_from_config",
+            fake_create_chat_litellm_from_config,
+        ),
+        (
+            "app.tasks.chat.stream_new_chat.create_chat_litellm_from_agent_config",
+            fake_create_chat_litellm_from_agent_config,
+        ),
+        (
+            "app.tasks.chat.stream_new_chat.create_chat_litellm_from_config",
+            fake_create_chat_litellm_from_config,
+        ),
+    ]
+    for target, replacement in chat_targets:
+        try:
+            p = patch(target, replacement)
+            p.start()
+            _active_patches.append(p)
+            logger.info("[fake-chat-llm] patched %s", target)
+        except (ModuleNotFoundError, AttributeError) as exc:
+            logger.warning("[fake-chat-llm] could not patch %s: %s.", target, exc)
+
 _patch_llm_bindings()
 _fake_embeddings.install(_active_patches)
 
@@ -129,8 +160,9 @@ app.add_middleware(ScenarioMiddleware)
 # 6) Start uvicorn, mirroring main.py's behaviour.
 # ---------------------------------------------------------------------------
 
-import asyncio
-import uvicorn
+import asyncio  # noqa: E402
+
+import uvicorn  # noqa: E402
 
 
 def _main() -> None:
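
unittest.mock.patch replaces a name where it is looked up, not where it is defined, which is why chat_targets lists each consuming module (llm_config and stream_new_chat) separately. A condensed sketch of the started-patcher lifecycle these scripts rely on, using one of the target strings from the hunk above:

    from unittest.mock import patch

    from tests.e2e.fakes.chat_llm import fake_create_chat_litellm_from_config

    # Replace the name at its lookup site; the swap stays active until stop().
    patcher = patch(
        "app.tasks.chat.stream_new_chat.create_chat_litellm_from_config",
        fake_create_chat_litellm_from_config,
    )
    patcher.start()
    # ... serve requests with the fake bound ...
    patcher.stop()  # or patch.stopall() to undo every started patcher

run_backend.py appends each started patcher to _active_patches instead of stopping it, since the fakes must stay bound for the lifetime of the server process.
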
diff --git a/surfsense_backend/tests/e2e/run_celery.py b/surfsense_backend/tests/e2e/run_celery.py
index c6e451a56..88c61dba5 100644
--- a/surfsense_backend/tests/e2e/run_celery.py
+++ b/surfsense_backend/tests/e2e/run_celery.py
@@ -58,16 +58,17 @@ logger.warning(
 
 # 3) Import the production celery_app. All task modules load here.
 # ---------------------------------------------------------------------------
 
-from app.celery_app import celery_app  # noqa: E402
-
-
 # ---------------------------------------------------------------------------
 # 4) Patch LLM + embedding bindings inside the worker process.
 # ---------------------------------------------------------------------------
-
 from unittest.mock import patch  # noqa: E402
+from app.celery_app import celery_app  # noqa: E402
 from tests.e2e.fakes import embeddings as _fake_embeddings  # noqa: E402
+from tests.e2e.fakes.chat_llm import (  # noqa: E402
+    fake_create_chat_litellm_from_agent_config,
+    fake_create_chat_litellm_from_config,
+)
 from tests.e2e.fakes.llm import fake_get_user_long_context_llm  # noqa: E402
 
 _active_patches: list = []
@@ -94,6 +95,37 @@ def _patch_llm_bindings() -> None:
                 exc,
             )
 
+    chat_targets = [
+        (
+            "app.agents.new_chat.llm_config.create_chat_litellm_from_agent_config",
+            fake_create_chat_litellm_from_agent_config,
+        ),
+        (
+            "app.agents.new_chat.llm_config.create_chat_litellm_from_config",
+            fake_create_chat_litellm_from_config,
+        ),
+        (
+            "app.tasks.chat.stream_new_chat.create_chat_litellm_from_agent_config",
+            fake_create_chat_litellm_from_agent_config,
+        ),
+        (
+            "app.tasks.chat.stream_new_chat.create_chat_litellm_from_config",
+            fake_create_chat_litellm_from_config,
+        ),
+    ]
+    for target, replacement in chat_targets:
+        try:
+            p = patch(target, replacement)
+            p.start()
+            _active_patches.append(p)
+            logger.info("[fake-chat-llm] patched %s in celery worker", target)
+        except (ModuleNotFoundError, AttributeError) as exc:
+            logger.warning(
+                "[fake-chat-llm] could not patch %s in celery worker: %s.",
+                target,
+                exc,
+            )
+
 _patch_llm_bindings()
 _fake_embeddings.install(_active_patches)
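
With both files patched, the server and the celery worker resolve the chat factories to the same fakes. A quick smoke check inside either process could confirm the binding took effect; a sketch, where the keyword argument is purely illustrative since the fake factories discard whatever they receive:

    from app.tasks.chat import stream_new_chat
    from tests.e2e.fakes.chat_llm import FakeChatLLM

    # After _patch_llm_bindings() has run, the production lookup site
    # resolves to the fake factory regardless of the requested model.
    llm = stream_new_chat.create_chat_litellm_from_config(model="gpt-4o")
    assert isinstance(llm, FakeChatLLM)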