mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-18 21:15:16 +02:00
Merge pull request #1515 from AnishSarkar22/hotfix/streaming
hotfix(chat): Chat answer streaming and smooth markdown rendering
This commit is contained in:
commit
03e57bdf7e
4 changed files with 163 additions and 5 deletions
|
|
@ -130,7 +130,9 @@ async def load_llm_bundle(
|
|||
billing_tier="free",
|
||||
)
|
||||
return (
|
||||
SanitizedChatLiteLLM(model=model_string, **litellm_kwargs),
|
||||
SanitizedChatLiteLLM(
|
||||
model=model_string, **{**litellm_kwargs, "streaming": True}
|
||||
),
|
||||
agent_config,
|
||||
None,
|
||||
)
|
||||
|
|
@ -174,7 +176,9 @@ async def load_llm_bundle(
|
|||
billing_tier=str(global_model.get("billing_tier", "free")).lower(),
|
||||
)
|
||||
return (
|
||||
SanitizedChatLiteLLM(model=model_string, **litellm_kwargs),
|
||||
SanitizedChatLiteLLM(
|
||||
model=model_string, **{**litellm_kwargs, "streaming": True}
|
||||
),
|
||||
agent_config,
|
||||
None,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,154 @@
|
|||
"""Contracts for chat LLM construction in streaming flows.
|
||||
|
||||
``stream_new_chat`` / ``stream_resume_chat`` depend on LangChain receiving
|
||||
token chunks from ``ChatLiteLLM``. ``langchain-litellm`` defaults
|
||||
``streaming`` to ``False``, so the shared bundle loader must opt in
|
||||
explicitly for both DB-backed and global model paths.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
import app.tasks.chat.streaming.flows.shared.llm_bundle as llm_bundle
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class _CapturedChatLiteLLM:
|
||||
calls: list[dict[str, Any]] = []
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
self.kwargs = kwargs
|
||||
self.__class__.calls.append(kwargs)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _patch_common_bundle_dependencies(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Keep these tests focused on the LLM constructor contract."""
|
||||
|
||||
_CapturedChatLiteLLM.calls = []
|
||||
|
||||
async def _fake_search_space(_session: Any, _search_space_id: int) -> SimpleNamespace:
|
||||
return SimpleNamespace(id=42, user_id="user-1")
|
||||
|
||||
monkeypatch.setattr(llm_bundle, "_load_search_space", _fake_search_space)
|
||||
monkeypatch.setattr(llm_bundle, "SanitizedChatLiteLLM", _CapturedChatLiteLLM)
|
||||
monkeypatch.setattr(llm_bundle, "register_model_usage_metadata", lambda **_kw: None)
|
||||
monkeypatch.setattr(
|
||||
llm_bundle,
|
||||
"has_capability",
|
||||
lambda _model, capability: capability in {"chat", "vision"},
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def test_load_llm_bundle_enables_streaming_for_db_models(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
connection = SimpleNamespace(
|
||||
provider="openai",
|
||||
api_key="sk-test",
|
||||
base_url=None,
|
||||
extra={"litellm_params": {"temperature": 0.1}},
|
||||
)
|
||||
model = SimpleNamespace(
|
||||
id=7,
|
||||
model_id="gpt-4o-mini",
|
||||
display_name="GPT 4o Mini",
|
||||
connection=connection,
|
||||
)
|
||||
|
||||
async def _fake_db_model(_session: Any, *, model_id: int, search_space: Any) -> Any:
|
||||
assert model_id == 7
|
||||
assert search_space.id == 42
|
||||
return model
|
||||
|
||||
monkeypatch.setattr(llm_bundle, "_load_db_model", _fake_db_model)
|
||||
monkeypatch.setattr(
|
||||
llm_bundle,
|
||||
"to_litellm",
|
||||
lambda _conn, _model_id: (
|
||||
"openai/gpt-4o-mini",
|
||||
{"api_key": "sk-test", "temperature": 0.1},
|
||||
),
|
||||
)
|
||||
|
||||
llm, agent_config, error = await llm_bundle.load_llm_bundle(
|
||||
object(),
|
||||
config_id=7,
|
||||
search_space_id=42,
|
||||
)
|
||||
|
||||
assert error is None
|
||||
assert llm is not None
|
||||
assert agent_config is not None
|
||||
assert _CapturedChatLiteLLM.calls == [
|
||||
{
|
||||
"model": "openai/gpt-4o-mini",
|
||||
"api_key": "sk-test",
|
||||
"temperature": 0.1,
|
||||
"streaming": True,
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
async def test_load_llm_bundle_enables_streaming_for_global_models(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
global_model = {
|
||||
"id": -11,
|
||||
"connection_id": -101,
|
||||
"model_id": "claude-sonnet-4-5",
|
||||
"display_name": "Claude Sonnet",
|
||||
"billing_tier": "premium",
|
||||
}
|
||||
global_connection = {
|
||||
"id": -101,
|
||||
"provider": "anthropic",
|
||||
"api_key": "sk-ant-test",
|
||||
"base_url": None,
|
||||
"extra": {"litellm_params": {"temperature": 0.2}},
|
||||
}
|
||||
monkeypatch.setattr(
|
||||
llm_bundle.config,
|
||||
"GLOBAL_MODELS",
|
||||
[global_model],
|
||||
raising=False,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
llm_bundle.config,
|
||||
"GLOBAL_CONNECTIONS",
|
||||
[global_connection],
|
||||
raising=False,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
llm_bundle,
|
||||
"to_litellm",
|
||||
lambda _conn, _model_id: (
|
||||
"anthropic/claude-sonnet-4-5",
|
||||
{"api_key": "sk-ant-test", "temperature": 0.2},
|
||||
),
|
||||
)
|
||||
|
||||
llm, agent_config, error = await llm_bundle.load_llm_bundle(
|
||||
object(),
|
||||
config_id=-11,
|
||||
search_space_id=42,
|
||||
)
|
||||
|
||||
assert error is None
|
||||
assert llm is not None
|
||||
assert agent_config is not None
|
||||
assert _CapturedChatLiteLLM.calls == [
|
||||
{
|
||||
"model": "anthropic/claude-sonnet-4-5",
|
||||
"api_key": "sk-ant-test",
|
||||
"temperature": 0.2,
|
||||
"streaming": True,
|
||||
}
|
||||
]
|
||||
|
|
@ -27,8 +27,8 @@ export interface ChatViewportProps {
|
|||
export const ChatViewport: FC<ChatViewportProps> = ({ children, footer }) => (
|
||||
<ThreadPrimitive.Viewport
|
||||
turnAnchor="top"
|
||||
autoScroll={false}
|
||||
scrollToBottomOnRunStart={false}
|
||||
autoScroll
|
||||
scrollToBottomOnRunStart
|
||||
scrollToBottomOnInitialize
|
||||
scrollToBottomOnThreadSwitch
|
||||
className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 scroll-smooth"
|
||||
|
|
|
|||
|
|
@ -110,7 +110,7 @@ const MarkdownTextImpl = () => {
|
|||
return (
|
||||
<CitationUrlMapContext.Provider value={urlMapRef}>
|
||||
<MarkdownTextPrimitive
|
||||
smooth={false}
|
||||
smooth
|
||||
remarkPlugins={[remarkGfm, [remarkMath, { singleDollarTextMath: false }]]}
|
||||
rehypePlugins={[rehypeKatex]}
|
||||
className="aui-md"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue