mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-01 03:46:25 +02:00
Merge pull request #471 from MODSetter/dev
feat: added top_k in chat Interface.
This commit is contained in:
commit
a9c9e3fe86
10 changed files with 217 additions and 5 deletions
|
|
@ -27,6 +27,7 @@ class Configuration:
|
|||
search_mode: SearchMode
|
||||
document_ids_to_add_in_context: list[int]
|
||||
language: str | None = None
|
||||
top_k: int = 10
|
||||
|
||||
@classmethod
|
||||
def from_runnable_config(
|
||||
|
|
|
|||
|
|
@ -1366,8 +1366,8 @@ async def handle_qna_workflow(
|
|||
}
|
||||
)
|
||||
|
||||
# Use a reasonable top_k for QNA - not too many documents to avoid overwhelming the LLM
|
||||
top_k = 5 if configuration.search_mode == SearchMode.DOCUMENTS else 20
|
||||
# Use the top_k value from configuration
|
||||
top_k = configuration.top_k
|
||||
|
||||
relevant_documents = []
|
||||
user_selected_documents = []
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ from app.utils.validators import (
|
|||
validate_research_mode,
|
||||
validate_search_mode,
|
||||
validate_search_space_id,
|
||||
validate_top_k,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
|
@ -54,6 +55,7 @@ async def handle_chat_data(
|
|||
request_data.get("document_ids_to_add_in_context")
|
||||
)
|
||||
search_mode_str = validate_search_mode(request_data.get("search_mode"))
|
||||
top_k = validate_top_k(request_data.get("top_k"))
|
||||
# print("RESQUEST DATA:", request_data)
|
||||
# print("SELECTED CONNECTORS:", selected_connectors)
|
||||
|
||||
|
|
@ -123,6 +125,7 @@ async def handle_chat_data(
|
|||
search_mode_str,
|
||||
document_ids_to_add_in_context,
|
||||
language,
|
||||
top_k,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ async def stream_connector_search_results(
|
|||
search_mode_str: str,
|
||||
document_ids_to_add_in_context: list[int],
|
||||
language: str | None = None,
|
||||
top_k: int = 10,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Stream connector search results to the client
|
||||
|
|
@ -56,6 +57,7 @@ async def stream_connector_search_results(
|
|||
"search_mode": search_mode,
|
||||
"document_ids_to_add_in_context": document_ids_to_add_in_context,
|
||||
"language": language, # Add language to the configuration
|
||||
"top_k": top_k, # Add top_k to the configuration
|
||||
}
|
||||
}
|
||||
# print(f"Researcher configuration: {config['configurable']}") # Debug print
|
||||
|
|
|
|||
|
|
@ -241,6 +241,60 @@ def validate_search_mode(search_mode: Any) -> str:
|
|||
return normalized_mode
|
||||
|
||||
|
||||
def validate_top_k(top_k: Any) -> int:
|
||||
"""
|
||||
Validate and convert top_k to integer.
|
||||
|
||||
Args:
|
||||
top_k: The top_k value to validate
|
||||
|
||||
Returns:
|
||||
int: Validated top_k value (defaults to 10 if None)
|
||||
|
||||
Raises:
|
||||
HTTPException: If validation fails
|
||||
"""
|
||||
if top_k is None:
|
||||
return 10 # Default value
|
||||
|
||||
if isinstance(top_k, bool):
|
||||
raise HTTPException(
|
||||
status_code=400, detail="top_k must be an integer, not a boolean"
|
||||
)
|
||||
|
||||
if isinstance(top_k, int):
|
||||
if top_k <= 0:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="top_k must be a positive integer"
|
||||
)
|
||||
if top_k > 100:
|
||||
raise HTTPException(status_code=400, detail="top_k must not exceed 100")
|
||||
return top_k
|
||||
|
||||
if isinstance(top_k, str):
|
||||
if not top_k.strip():
|
||||
raise HTTPException(status_code=400, detail="top_k cannot be empty")
|
||||
|
||||
if not re.match(r"^[1-9]\d*$", top_k.strip()):
|
||||
raise HTTPException(
|
||||
status_code=400, detail="top_k must be a valid positive integer"
|
||||
)
|
||||
|
||||
value = int(top_k.strip())
|
||||
if value <= 0:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="top_k must be a positive integer"
|
||||
)
|
||||
if value > 100:
|
||||
raise HTTPException(status_code=400, detail="top_k must not exceed 100")
|
||||
return value
|
||||
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="top_k must be an integer or string representation of an integer",
|
||||
)
|
||||
|
||||
|
||||
def validate_messages(messages: Any) -> list[dict]:
|
||||
"""
|
||||
Validate messages structure.
|
||||
|
|
|
|||
|
|
@ -27,6 +27,8 @@ export default function ResearcherPage() {
|
|||
setSelectedConnectors,
|
||||
selectedDocuments,
|
||||
setSelectedDocuments,
|
||||
topK,
|
||||
setTopK,
|
||||
} = useChatState({
|
||||
search_space_id: search_space_id as string,
|
||||
chat_id: chatIdParam,
|
||||
|
|
@ -66,6 +68,7 @@ export default function ResearcherPage() {
|
|||
selectedConnectors: string[];
|
||||
searchMode: "DOCUMENTS" | "CHUNKS";
|
||||
researchMode: "QNA"; // Always QNA mode
|
||||
topK: number;
|
||||
}
|
||||
|
||||
const getChatStateStorageKey = (searchSpaceId: string, chatId: string) =>
|
||||
|
|
@ -105,6 +108,7 @@ export default function ResearcherPage() {
|
|||
research_mode: researchMode,
|
||||
search_mode: searchMode,
|
||||
document_ids_to_add_in_context: documentIds,
|
||||
top_k: topK,
|
||||
},
|
||||
},
|
||||
onError: (error) => {
|
||||
|
|
@ -124,6 +128,7 @@ export default function ResearcherPage() {
|
|||
selectedConnectors,
|
||||
searchMode,
|
||||
researchMode,
|
||||
topK,
|
||||
});
|
||||
router.replace(`/dashboard/${search_space_id}/researcher/${newChatId}`);
|
||||
}
|
||||
|
|
@ -145,10 +150,18 @@ export default function ResearcherPage() {
|
|||
setSelectedDocuments(restoredState.selectedDocuments);
|
||||
setSelectedConnectors(restoredState.selectedConnectors);
|
||||
setSearchMode(restoredState.searchMode);
|
||||
setTopK(restoredState.topK);
|
||||
// researchMode is always "QNA", no need to restore
|
||||
}
|
||||
}
|
||||
}, [chatIdParam, search_space_id, setSelectedDocuments, setSelectedConnectors, setSearchMode]);
|
||||
}, [
|
||||
chatIdParam,
|
||||
search_space_id,
|
||||
setSelectedDocuments,
|
||||
setSelectedConnectors,
|
||||
setSearchMode,
|
||||
setTopK,
|
||||
]);
|
||||
|
||||
// Set all sources as default for new chats
|
||||
useEffect(() => {
|
||||
|
|
@ -234,6 +247,8 @@ export default function ResearcherPage() {
|
|||
selectedConnectors={selectedConnectors}
|
||||
searchMode={searchMode}
|
||||
onSearchModeChange={setSearchMode}
|
||||
topK={topK}
|
||||
onTopKChange={setTopK}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
"use client";
|
||||
|
||||
import { ChatInput } from "@llamaindex/chat-ui";
|
||||
import { Brain, Check, FolderOpen, Zap } from "lucide-react";
|
||||
import { Brain, Check, FolderOpen, Minus, Plus, Zap } from "lucide-react";
|
||||
import { useParams } from "next/navigation";
|
||||
import React, { Suspense, useCallback, useState } from "react";
|
||||
import { DocumentsDataTable } from "@/components/chat/DocumentsDataTable";
|
||||
|
|
@ -15,6 +15,7 @@ import {
|
|||
DialogTitle,
|
||||
DialogTrigger,
|
||||
} from "@/components/ui/dialog";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
|
|
@ -22,6 +23,7 @@ import {
|
|||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip";
|
||||
import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
|
||||
import { useDocumentTypes } from "@/hooks/use-document-types";
|
||||
import type { Document } from "@/hooks/use-documents";
|
||||
|
|
@ -447,6 +449,119 @@ const SearchModeSelector = React.memo(
|
|||
|
||||
SearchModeSelector.displayName = "SearchModeSelector";
|
||||
|
||||
const TopKSelector = React.memo(
|
||||
({ topK = 10, onTopKChange }: { topK?: number; onTopKChange?: (topK: number) => void }) => {
|
||||
const MIN_VALUE = 1;
|
||||
const MAX_VALUE = 100;
|
||||
|
||||
const handleIncrement = React.useCallback(() => {
|
||||
if (topK < MAX_VALUE) {
|
||||
onTopKChange?.(topK + 1);
|
||||
}
|
||||
}, [topK, onTopKChange]);
|
||||
|
||||
const handleDecrement = React.useCallback(() => {
|
||||
if (topK > MIN_VALUE) {
|
||||
onTopKChange?.(topK - 1);
|
||||
}
|
||||
}, [topK, onTopKChange]);
|
||||
|
||||
const handleInputChange = React.useCallback(
|
||||
(e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
const value = e.target.value;
|
||||
// Allow empty input for editing
|
||||
if (value === "") {
|
||||
return;
|
||||
}
|
||||
const numValue = parseInt(value, 10);
|
||||
if (!isNaN(numValue) && numValue >= MIN_VALUE && numValue <= MAX_VALUE) {
|
||||
onTopKChange?.(numValue);
|
||||
}
|
||||
},
|
||||
[onTopKChange]
|
||||
);
|
||||
|
||||
const handleInputBlur = React.useCallback(
|
||||
(e: React.FocusEvent<HTMLInputElement>) => {
|
||||
const value = e.target.value;
|
||||
if (value === "") {
|
||||
// Reset to default if empty
|
||||
onTopKChange?.(10);
|
||||
return;
|
||||
}
|
||||
const numValue = parseInt(value, 10);
|
||||
if (isNaN(numValue) || numValue < MIN_VALUE) {
|
||||
onTopKChange?.(MIN_VALUE);
|
||||
} else if (numValue > MAX_VALUE) {
|
||||
onTopKChange?.(MAX_VALUE);
|
||||
}
|
||||
},
|
||||
[onTopKChange]
|
||||
);
|
||||
|
||||
return (
|
||||
<TooltipProvider>
|
||||
<Tooltip delayDuration={200}>
|
||||
<TooltipTrigger asChild>
|
||||
<div className="flex items-center h-8 border rounded-md bg-background hover:bg-accent/50 transition-colors">
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
className="h-full w-7 rounded-l-md rounded-r-none hover:bg-accent border-r"
|
||||
onClick={handleDecrement}
|
||||
disabled={topK <= MIN_VALUE}
|
||||
>
|
||||
<Minus className="h-3.5 w-3.5" />
|
||||
</Button>
|
||||
<div className="flex flex-col items-center justify-center px-2 min-w-[60px]">
|
||||
<Input
|
||||
type="number"
|
||||
value={topK}
|
||||
onChange={handleInputChange}
|
||||
onBlur={handleInputBlur}
|
||||
min={MIN_VALUE}
|
||||
max={MAX_VALUE}
|
||||
className="h-5 w-full px-1 text-center text-sm font-semibold border-0 bg-transparent focus-visible:ring-0 focus-visible:ring-offset-0 [appearance:textfield] [&::-webkit-outer-spin-button]:appearance-none [&::-webkit-inner-spin-button]:appearance-none"
|
||||
/>
|
||||
<span className="text-[10px] text-muted-foreground leading-none">Results</span>
|
||||
</div>
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
className="h-full w-7 rounded-r-md rounded-l-none hover:bg-accent border-l"
|
||||
onClick={handleIncrement}
|
||||
disabled={topK >= MAX_VALUE}
|
||||
>
|
||||
<Plus className="h-3.5 w-3.5" />
|
||||
</Button>
|
||||
</div>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent side="top" className="max-w-xs">
|
||||
<div className="space-y-2">
|
||||
<p className="text-sm font-semibold">Results per Source</p>
|
||||
<p className="text-xs text-muted-foreground leading-relaxed">
|
||||
Control how many results to fetch from each data source. Set a higher number to get
|
||||
more information, or a lower number for faster, more focused results.
|
||||
</p>
|
||||
<div className="flex items-center gap-2 text-xs text-muted-foreground pt-1 border-t">
|
||||
<span>Recommended: 5-20</span>
|
||||
<span>•</span>
|
||||
<span>
|
||||
Range: {MIN_VALUE}-{MAX_VALUE}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
TopKSelector.displayName = "TopKSelector";
|
||||
|
||||
const LLMSelector = React.memo(() => {
|
||||
const { search_space_id } = useParams();
|
||||
const searchSpaceId = Number(search_space_id);
|
||||
|
|
@ -604,6 +719,8 @@ const CustomChatInputOptions = React.memo(
|
|||
selectedConnectors,
|
||||
searchMode,
|
||||
onSearchModeChange,
|
||||
topK,
|
||||
onTopKChange,
|
||||
}: {
|
||||
onDocumentSelectionChange?: (documents: Document[]) => void;
|
||||
selectedDocuments?: Document[];
|
||||
|
|
@ -611,6 +728,8 @@ const CustomChatInputOptions = React.memo(
|
|||
selectedConnectors?: string[];
|
||||
searchMode?: "DOCUMENTS" | "CHUNKS";
|
||||
onSearchModeChange?: (mode: "DOCUMENTS" | "CHUNKS") => void;
|
||||
topK?: number;
|
||||
onTopKChange?: (topK: number) => void;
|
||||
}) => {
|
||||
// Memoize the loading fallback to prevent recreation
|
||||
const loadingFallback = React.useMemo(
|
||||
|
|
@ -637,6 +756,8 @@ const CustomChatInputOptions = React.memo(
|
|||
<div className="h-4 w-px bg-border hidden sm:block" />
|
||||
<SearchModeSelector searchMode={searchMode} onSearchModeChange={onSearchModeChange} />
|
||||
<div className="h-4 w-px bg-border hidden sm:block" />
|
||||
<TopKSelector topK={topK} onTopKChange={onTopKChange} />
|
||||
<div className="h-4 w-px bg-border hidden sm:block" />
|
||||
<LLMSelector />
|
||||
</div>
|
||||
);
|
||||
|
|
@ -653,6 +774,8 @@ export const ChatInputUI = React.memo(
|
|||
selectedConnectors,
|
||||
searchMode,
|
||||
onSearchModeChange,
|
||||
topK,
|
||||
onTopKChange,
|
||||
}: {
|
||||
onDocumentSelectionChange?: (documents: Document[]) => void;
|
||||
selectedDocuments?: Document[];
|
||||
|
|
@ -660,6 +783,8 @@ export const ChatInputUI = React.memo(
|
|||
selectedConnectors?: string[];
|
||||
searchMode?: "DOCUMENTS" | "CHUNKS";
|
||||
onSearchModeChange?: (mode: "DOCUMENTS" | "CHUNKS") => void;
|
||||
topK?: number;
|
||||
onTopKChange?: (topK: number) => void;
|
||||
}) => {
|
||||
return (
|
||||
<ChatInput>
|
||||
|
|
@ -674,6 +799,8 @@ export const ChatInputUI = React.memo(
|
|||
selectedConnectors={selectedConnectors}
|
||||
searchMode={searchMode}
|
||||
onSearchModeChange={onSearchModeChange}
|
||||
topK={topK}
|
||||
onTopKChange={onTopKChange}
|
||||
/>
|
||||
</ChatInput>
|
||||
);
|
||||
|
|
|
|||
|
|
@ -13,6 +13,8 @@ interface ChatInterfaceProps {
|
|||
selectedConnectors?: string[];
|
||||
searchMode?: "DOCUMENTS" | "CHUNKS";
|
||||
onSearchModeChange?: (mode: "DOCUMENTS" | "CHUNKS") => void;
|
||||
topK?: number;
|
||||
onTopKChange?: (topK: number) => void;
|
||||
}
|
||||
|
||||
export default function ChatInterface({
|
||||
|
|
@ -23,6 +25,8 @@ export default function ChatInterface({
|
|||
selectedConnectors = [],
|
||||
searchMode,
|
||||
onSearchModeChange,
|
||||
topK = 10,
|
||||
onTopKChange,
|
||||
}: ChatInterfaceProps) {
|
||||
return (
|
||||
<LlamaIndexChatSection handler={handler} className="flex h-full">
|
||||
|
|
@ -36,6 +40,8 @@ export default function ChatInterface({
|
|||
selectedConnectors={selectedConnectors}
|
||||
searchMode={searchMode}
|
||||
onSearchModeChange={onSearchModeChange}
|
||||
topK={topK}
|
||||
onTopKChange={onTopKChange}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -14,10 +14,11 @@ export function useChatState({ chat_id }: UseChatStateProps) {
|
|||
const [currentChatId, setCurrentChatId] = useState<string | null>(chat_id || null);
|
||||
|
||||
// Chat configuration state
|
||||
const [searchMode, setSearchMode] = useState<"DOCUMENTS" | "CHUNKS">("CHUNKS");
|
||||
const [searchMode, setSearchMode] = useState<"DOCUMENTS" | "CHUNKS">("DOCUMENTS");
|
||||
const [researchMode, setResearchMode] = useState<ResearchMode>("QNA");
|
||||
const [selectedConnectors, setSelectedConnectors] = useState<string[]>([]);
|
||||
const [selectedDocuments, setSelectedDocuments] = useState<Document[]>([]);
|
||||
const [topK, setTopK] = useState<number>(5);
|
||||
|
||||
useEffect(() => {
|
||||
const bearerToken = localStorage.getItem("surfsense_bearer_token");
|
||||
|
|
@ -39,6 +40,8 @@ export function useChatState({ chat_id }: UseChatStateProps) {
|
|||
setSelectedConnectors,
|
||||
selectedDocuments,
|
||||
setSelectedDocuments,
|
||||
topK,
|
||||
setTopK,
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@
|
|||
"build": "next build",
|
||||
"start": "next start",
|
||||
"lint": "next lint",
|
||||
"format": "biome check --write ./ --max-diagnostics 500",
|
||||
"debug": "cross-env NODE_OPTIONS=--inspect next dev --turbopack",
|
||||
"debug:browser": "cross-env NODE_OPTIONS=--inspect next dev --turbopack",
|
||||
"debug:server": "cross-env NODE_OPTIONS=--inspect=0.0.0.0:9229 next dev --turbopack",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue