mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-07-04 22:02:16 +02:00
Merge upstream/main into fix/chat-citations
This commit is contained in:
commit
1cb7633920
8 changed files with 191 additions and 19 deletions
|
|
@ -130,7 +130,9 @@ async def load_llm_bundle(
|
||||||
billing_tier="free",
|
billing_tier="free",
|
||||||
)
|
)
|
||||||
return (
|
return (
|
||||||
SanitizedChatLiteLLM(model=model_string, **litellm_kwargs),
|
SanitizedChatLiteLLM(
|
||||||
|
model=model_string, **{**litellm_kwargs, "streaming": True}
|
||||||
|
),
|
||||||
agent_config,
|
agent_config,
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
|
@ -174,7 +176,9 @@ async def load_llm_bundle(
|
||||||
billing_tier=str(global_model.get("billing_tier", "free")).lower(),
|
billing_tier=str(global_model.get("billing_tier", "free")).lower(),
|
||||||
)
|
)
|
||||||
return (
|
return (
|
||||||
SanitizedChatLiteLLM(model=model_string, **litellm_kwargs),
|
SanitizedChatLiteLLM(
|
||||||
|
model=model_string, **{**litellm_kwargs, "streaming": True}
|
||||||
|
),
|
||||||
agent_config,
|
agent_config,
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,154 @@
|
||||||
|
"""Contracts for chat LLM construction in streaming flows.
|
||||||
|
|
||||||
|
``stream_new_chat`` / ``stream_resume_chat`` depend on LangChain receiving
|
||||||
|
token chunks from ``ChatLiteLLM``. ``langchain-litellm`` defaults
|
||||||
|
``streaming`` to ``False``, so the shared bundle loader must opt in
|
||||||
|
explicitly for both DB-backed and global model paths.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import app.tasks.chat.streaming.flows.shared.llm_bundle as llm_bundle
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
class _CapturedChatLiteLLM:
|
||||||
|
calls: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
def __init__(self, **kwargs: Any) -> None:
|
||||||
|
self.kwargs = kwargs
|
||||||
|
self.__class__.calls.append(kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _patch_common_bundle_dependencies(monkeypatch: pytest.MonkeyPatch):
|
||||||
|
"""Keep these tests focused on the LLM constructor contract."""
|
||||||
|
|
||||||
|
_CapturedChatLiteLLM.calls = []
|
||||||
|
|
||||||
|
async def _fake_search_space(_session: Any, _search_space_id: int) -> SimpleNamespace:
|
||||||
|
return SimpleNamespace(id=42, user_id="user-1")
|
||||||
|
|
||||||
|
monkeypatch.setattr(llm_bundle, "_load_search_space", _fake_search_space)
|
||||||
|
monkeypatch.setattr(llm_bundle, "SanitizedChatLiteLLM", _CapturedChatLiteLLM)
|
||||||
|
monkeypatch.setattr(llm_bundle, "register_model_usage_metadata", lambda **_kw: None)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
llm_bundle,
|
||||||
|
"has_capability",
|
||||||
|
lambda _model, capability: capability in {"chat", "vision"},
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_load_llm_bundle_enables_streaming_for_db_models(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
connection = SimpleNamespace(
|
||||||
|
provider="openai",
|
||||||
|
api_key="sk-test",
|
||||||
|
base_url=None,
|
||||||
|
extra={"litellm_params": {"temperature": 0.1}},
|
||||||
|
)
|
||||||
|
model = SimpleNamespace(
|
||||||
|
id=7,
|
||||||
|
model_id="gpt-4o-mini",
|
||||||
|
display_name="GPT 4o Mini",
|
||||||
|
connection=connection,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _fake_db_model(_session: Any, *, model_id: int, search_space: Any) -> Any:
|
||||||
|
assert model_id == 7
|
||||||
|
assert search_space.id == 42
|
||||||
|
return model
|
||||||
|
|
||||||
|
monkeypatch.setattr(llm_bundle, "_load_db_model", _fake_db_model)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
llm_bundle,
|
||||||
|
"to_litellm",
|
||||||
|
lambda _conn, _model_id: (
|
||||||
|
"openai/gpt-4o-mini",
|
||||||
|
{"api_key": "sk-test", "temperature": 0.1},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
llm, agent_config, error = await llm_bundle.load_llm_bundle(
|
||||||
|
object(),
|
||||||
|
config_id=7,
|
||||||
|
search_space_id=42,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert error is None
|
||||||
|
assert llm is not None
|
||||||
|
assert agent_config is not None
|
||||||
|
assert _CapturedChatLiteLLM.calls == [
|
||||||
|
{
|
||||||
|
"model": "openai/gpt-4o-mini",
|
||||||
|
"api_key": "sk-test",
|
||||||
|
"temperature": 0.1,
|
||||||
|
"streaming": True,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_load_llm_bundle_enables_streaming_for_global_models(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
global_model = {
|
||||||
|
"id": -11,
|
||||||
|
"connection_id": -101,
|
||||||
|
"model_id": "claude-sonnet-4-5",
|
||||||
|
"display_name": "Claude Sonnet",
|
||||||
|
"billing_tier": "premium",
|
||||||
|
}
|
||||||
|
global_connection = {
|
||||||
|
"id": -101,
|
||||||
|
"provider": "anthropic",
|
||||||
|
"api_key": "sk-ant-test",
|
||||||
|
"base_url": None,
|
||||||
|
"extra": {"litellm_params": {"temperature": 0.2}},
|
||||||
|
}
|
||||||
|
monkeypatch.setattr(
|
||||||
|
llm_bundle.config,
|
||||||
|
"GLOBAL_MODELS",
|
||||||
|
[global_model],
|
||||||
|
raising=False,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
llm_bundle.config,
|
||||||
|
"GLOBAL_CONNECTIONS",
|
||||||
|
[global_connection],
|
||||||
|
raising=False,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
llm_bundle,
|
||||||
|
"to_litellm",
|
||||||
|
lambda _conn, _model_id: (
|
||||||
|
"anthropic/claude-sonnet-4-5",
|
||||||
|
{"api_key": "sk-ant-test", "temperature": 0.2},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
llm, agent_config, error = await llm_bundle.load_llm_bundle(
|
||||||
|
object(),
|
||||||
|
config_id=-11,
|
||||||
|
search_space_id=42,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert error is None
|
||||||
|
assert llm is not None
|
||||||
|
assert agent_config is not None
|
||||||
|
assert _CapturedChatLiteLLM.calls == [
|
||||||
|
{
|
||||||
|
"model": "anthropic/claude-sonnet-4-5",
|
||||||
|
"api_key": "sk-ant-test",
|
||||||
|
"temperature": 0.2,
|
||||||
|
"streaming": True,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
@ -27,8 +27,8 @@ export interface ChatViewportProps {
|
||||||
export const ChatViewport: FC<ChatViewportProps> = ({ children, footer }) => (
|
export const ChatViewport: FC<ChatViewportProps> = ({ children, footer }) => (
|
||||||
<ThreadPrimitive.Viewport
|
<ThreadPrimitive.Viewport
|
||||||
turnAnchor="top"
|
turnAnchor="top"
|
||||||
autoScroll={false}
|
autoScroll
|
||||||
scrollToBottomOnRunStart={false}
|
scrollToBottomOnRunStart
|
||||||
scrollToBottomOnInitialize
|
scrollToBottomOnInitialize
|
||||||
scrollToBottomOnThreadSwitch
|
scrollToBottomOnThreadSwitch
|
||||||
className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 scroll-smooth"
|
className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 scroll-smooth"
|
||||||
|
|
|
||||||
|
|
@ -110,7 +110,7 @@ const MarkdownTextImpl = () => {
|
||||||
return (
|
return (
|
||||||
<CitationUrlMapContext.Provider value={urlMapRef}>
|
<CitationUrlMapContext.Provider value={urlMapRef}>
|
||||||
<MarkdownTextPrimitive
|
<MarkdownTextPrimitive
|
||||||
smooth={false}
|
smooth
|
||||||
remarkPlugins={[remarkGfm, [remarkMath, { singleDollarTextMath: false }]]}
|
remarkPlugins={[remarkGfm, [remarkMath, { singleDollarTextMath: false }]]}
|
||||||
rehypePlugins={[rehypeKatex]}
|
rehypePlugins={[rehypeKatex]}
|
||||||
className="aui-md"
|
className="aui-md"
|
||||||
|
|
|
||||||
|
|
@ -1577,7 +1577,7 @@ const ComposerAction: FC<ComposerActionProps> = ({
|
||||||
<span>Select a model</span>
|
<span>Select a model</span>
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
<div className="flex items-center gap-2">
|
<div className="ml-auto flex min-w-0 shrink-0 items-center gap-2">
|
||||||
<ChatHeader
|
<ChatHeader
|
||||||
searchSpaceId={searchSpaceId}
|
searchSpaceId={searchSpaceId}
|
||||||
className="h-9 max-w-[44vw] px-2 sm:max-w-[220px] sm:px-3"
|
className="h-9 max-w-[44vw] px-2 sm:max-w-[220px] sm:px-3"
|
||||||
|
|
@ -1600,7 +1600,7 @@ const ComposerAction: FC<ComposerActionProps> = ({
|
||||||
variant="default"
|
variant="default"
|
||||||
size="icon"
|
size="icon"
|
||||||
className={cn(
|
className={cn(
|
||||||
"aui-composer-send size-9 rounded-full",
|
"aui-composer-send size-9 shrink-0 rounded-full",
|
||||||
isSendDisabled && "cursor-not-allowed opacity-50"
|
isSendDisabled && "cursor-not-allowed opacity-50"
|
||||||
)}
|
)}
|
||||||
aria-label="Send message"
|
aria-label="Send message"
|
||||||
|
|
@ -1617,7 +1617,7 @@ const ComposerAction: FC<ComposerActionProps> = ({
|
||||||
type="button"
|
type="button"
|
||||||
variant="default"
|
variant="default"
|
||||||
size="icon"
|
size="icon"
|
||||||
className="aui-composer-cancel size-9 rounded-full"
|
className="aui-composer-cancel size-9 shrink-0 rounded-full"
|
||||||
aria-label="Stop generating"
|
aria-label="Stop generating"
|
||||||
>
|
>
|
||||||
<SquareIcon className="aui-composer-cancel-icon size-3.5 fill-current" />
|
<SquareIcon className="aui-composer-cancel-icon size-3.5 fill-current" />
|
||||||
|
|
|
||||||
|
|
@ -11,13 +11,13 @@ interface ChatHeaderProps {
|
||||||
|
|
||||||
export function ChatHeader({ searchSpaceId, className, onChatModelSelected }: ChatHeaderProps) {
|
export function ChatHeader({ searchSpaceId, className, onChatModelSelected }: ChatHeaderProps) {
|
||||||
return (
|
return (
|
||||||
<div className="flex items-center gap-2">
|
<div className="flex min-w-0 shrink-0 items-center gap-2">
|
||||||
<ModelSelector
|
<ModelSelector
|
||||||
searchSpaceId={searchSpaceId}
|
searchSpaceId={searchSpaceId}
|
||||||
className={className}
|
className={className}
|
||||||
onChatModelSelected={onChatModelSelected}
|
onChatModelSelected={onChatModelSelected}
|
||||||
/>
|
/>
|
||||||
<ImageModelSelector searchSpaceId={searchSpaceId} className={className} />
|
<ImageModelSelector searchSpaceId={searchSpaceId} className={className} mobileIconOnly />
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -33,6 +33,7 @@ import { providerDisplay } from "../settings/model-connections/provider-metadata
|
||||||
interface ImageModelSelectorProps {
|
interface ImageModelSelectorProps {
|
||||||
searchSpaceId: number;
|
searchSpaceId: number;
|
||||||
className?: string;
|
className?: string;
|
||||||
|
mobileIconOnly?: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
type ImageModel = ModelRead & {
|
type ImageModel = ModelRead & {
|
||||||
|
|
@ -95,7 +96,11 @@ function groupedModels(models: ImageModel[]) {
|
||||||
}, {});
|
}, {});
|
||||||
}
|
}
|
||||||
|
|
||||||
export function ImageModelSelector({ searchSpaceId, className }: ImageModelSelectorProps) {
|
export function ImageModelSelector({
|
||||||
|
searchSpaceId,
|
||||||
|
className,
|
||||||
|
mobileIconOnly = false,
|
||||||
|
}: ImageModelSelectorProps) {
|
||||||
const router = useRouter();
|
const router = useRouter();
|
||||||
const isMobile = useIsMobile();
|
const isMobile = useIsMobile();
|
||||||
const [open, setOpen] = useState(false);
|
const [open, setOpen] = useState(false);
|
||||||
|
|
@ -126,6 +131,7 @@ export function ImageModelSelector({ searchSpaceId, className }: ImageModelSelec
|
||||||
const groups = useMemo(() => groupedModels(visibleImageModels), [visibleImageModels]);
|
const groups = useMemo(() => groupedModels(visibleImageModels), [visibleImageModels]);
|
||||||
const loading = globalLoading || connectionsLoading;
|
const loading = globalLoading || connectionsLoading;
|
||||||
const hasSearchQuery = search.trim().length > 0;
|
const hasSearchQuery = search.trim().length > 0;
|
||||||
|
const showIconOnlyTrigger = isMobile && mobileIconOnly;
|
||||||
|
|
||||||
function handleOpenChange(nextOpen: boolean) {
|
function handleOpenChange(nextOpen: boolean) {
|
||||||
if (!nextOpen) setSearch("");
|
if (!nextOpen) setSearch("");
|
||||||
|
|
@ -252,12 +258,14 @@ export function ImageModelSelector({ searchSpaceId, className }: ImageModelSelec
|
||||||
type="button"
|
type="button"
|
||||||
variant="ghost"
|
variant="ghost"
|
||||||
size="sm"
|
size="sm"
|
||||||
|
aria-label="Select image model"
|
||||||
className={cn(
|
className={cn(
|
||||||
"h-8 min-w-0 gap-2 rounded-md px-3 text-muted-foreground transition-colors",
|
"h-8 min-w-0 gap-2 rounded-md px-3 text-muted-foreground transition-colors",
|
||||||
"select-none",
|
"select-none",
|
||||||
"hover:bg-foreground/10 hover:text-foreground",
|
"hover:bg-foreground/10 hover:text-foreground",
|
||||||
"data-[state=open]:bg-foreground/10 data-[state=open]:text-foreground",
|
"data-[state=open]:bg-foreground/10 data-[state=open]:text-foreground",
|
||||||
className
|
className,
|
||||||
|
showIconOnlyTrigger && "h-9 w-auto shrink-0 justify-center gap-1 px-2"
|
||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
{selected ? (
|
{selected ? (
|
||||||
|
|
@ -265,9 +273,11 @@ export function ImageModelSelector({ searchSpaceId, className }: ImageModelSelec
|
||||||
) : (
|
) : (
|
||||||
<ImagePlus className="size-4 shrink-0" />
|
<ImagePlus className="size-4 shrink-0" />
|
||||||
)}
|
)}
|
||||||
<span className="min-w-0 flex-1 truncate text-sm">
|
{showIconOnlyTrigger ? null : (
|
||||||
{selected ? modelName(selected) : "Auto"}
|
<span className="min-w-0 flex-1 truncate text-sm">
|
||||||
</span>
|
{selected ? modelName(selected) : "Auto"}
|
||||||
|
</span>
|
||||||
|
)}
|
||||||
<ChevronDown className="h-3.5 w-3.5 shrink-0" />
|
<ChevronDown className="h-3.5 w-3.5 shrink-0" />
|
||||||
</Button>
|
</Button>
|
||||||
);
|
);
|
||||||
|
|
|
||||||
|
|
@ -131,6 +131,7 @@ export function ModelSelector({
|
||||||
const groups = useMemo(() => groupedModels(visibleChatModels), [visibleChatModels]);
|
const groups = useMemo(() => groupedModels(visibleChatModels), [visibleChatModels]);
|
||||||
const loading = globalLoading || connectionsLoading;
|
const loading = globalLoading || connectionsLoading;
|
||||||
const hasSearchQuery = search.trim().length > 0;
|
const hasSearchQuery = search.trim().length > 0;
|
||||||
|
const showIconOnlyTrigger = isMobile;
|
||||||
|
|
||||||
function handleOpenChange(nextOpen: boolean) {
|
function handleOpenChange(nextOpen: boolean) {
|
||||||
if (!nextOpen) setSearch("");
|
if (!nextOpen) setSearch("");
|
||||||
|
|
@ -276,15 +277,18 @@ export function ModelSelector({
|
||||||
"select-none",
|
"select-none",
|
||||||
"hover:bg-foreground/10 hover:text-foreground",
|
"hover:bg-foreground/10 hover:text-foreground",
|
||||||
"data-[state=open]:bg-foreground/10 data-[state=open]:text-foreground",
|
"data-[state=open]:bg-foreground/10 data-[state=open]:text-foreground",
|
||||||
className
|
className,
|
||||||
|
showIconOnlyTrigger && "h-9 w-auto shrink-0 justify-center gap-1 px-2"
|
||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
{selected
|
{selected
|
||||||
? getProviderIcon(selected.provider, { className: "size-4 shrink-0" })
|
? getProviderIcon(selected.provider, { className: "size-4 shrink-0" })
|
||||||
: getProviderIcon(AUTO_PROVIDER_ICON_KEY, { className: "size-4 shrink-0" })}
|
: getProviderIcon(AUTO_PROVIDER_ICON_KEY, { className: "size-4 shrink-0" })}
|
||||||
<span className="min-w-0 flex-1 truncate text-sm">
|
{showIconOnlyTrigger ? null : (
|
||||||
{selected ? modelName(selected) : "Auto"}
|
<span className="min-w-0 flex-1 truncate text-sm">
|
||||||
</span>
|
{selected ? modelName(selected) : "Auto"}
|
||||||
|
</span>
|
||||||
|
)}
|
||||||
<ChevronDown className="h-3.5 w-3.5 shrink-0" />
|
<ChevronDown className="h-3.5 w-3.5 shrink-0" />
|
||||||
</Button>
|
</Button>
|
||||||
);
|
);
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue