Merge pull request #471 from MODSetter/dev

feat: added top_k in chat Interface.
This commit is contained in:
Rohan Verma 2025-11-06 13:33:47 -08:00 committed by GitHub
commit a9c9e3fe86
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 217 additions and 5 deletions

View file

@ -27,6 +27,7 @@ class Configuration:
search_mode: SearchMode
document_ids_to_add_in_context: list[int]
language: str | None = None
top_k: int = 10
@classmethod
def from_runnable_config(

View file

@ -1366,8 +1366,8 @@ async def handle_qna_workflow(
}
)
# Use a reasonable top_k for QNA - not too many documents to avoid overwhelming the LLM
top_k = 5 if configuration.search_mode == SearchMode.DOCUMENTS else 20
# Use the top_k value from configuration
top_k = configuration.top_k
relevant_documents = []
user_selected_documents = []

View file

@ -24,6 +24,7 @@ from app.utils.validators import (
validate_research_mode,
validate_search_mode,
validate_search_space_id,
validate_top_k,
)
router = APIRouter()
@ -54,6 +55,7 @@ async def handle_chat_data(
request_data.get("document_ids_to_add_in_context")
)
search_mode_str = validate_search_mode(request_data.get("search_mode"))
top_k = validate_top_k(request_data.get("top_k"))
# print("RESQUEST DATA:", request_data)
# print("SELECTED CONNECTORS:", selected_connectors)
@ -123,6 +125,7 @@ async def handle_chat_data(
search_mode_str,
document_ids_to_add_in_context,
language,
top_k,
)
)

View file

@ -21,6 +21,7 @@ async def stream_connector_search_results(
search_mode_str: str,
document_ids_to_add_in_context: list[int],
language: str | None = None,
top_k: int = 10,
) -> AsyncGenerator[str, None]:
"""
Stream connector search results to the client
@ -56,6 +57,7 @@ async def stream_connector_search_results(
"search_mode": search_mode,
"document_ids_to_add_in_context": document_ids_to_add_in_context,
"language": language, # Add language to the configuration
"top_k": top_k, # Add top_k to the configuration
}
}
# print(f"Researcher configuration: {config['configurable']}") # Debug print

View file

@ -241,6 +241,60 @@ def validate_search_mode(search_mode: Any) -> str:
return normalized_mode
def validate_top_k(top_k: Any) -> int:
"""
Validate and convert top_k to integer.
Args:
top_k: The top_k value to validate
Returns:
int: Validated top_k value (defaults to 10 if None)
Raises:
HTTPException: If validation fails
"""
if top_k is None:
return 10 # Default value
if isinstance(top_k, bool):
raise HTTPException(
status_code=400, detail="top_k must be an integer, not a boolean"
)
if isinstance(top_k, int):
if top_k <= 0:
raise HTTPException(
status_code=400, detail="top_k must be a positive integer"
)
if top_k > 100:
raise HTTPException(status_code=400, detail="top_k must not exceed 100")
return top_k
if isinstance(top_k, str):
if not top_k.strip():
raise HTTPException(status_code=400, detail="top_k cannot be empty")
if not re.match(r"^[1-9]\d*$", top_k.strip()):
raise HTTPException(
status_code=400, detail="top_k must be a valid positive integer"
)
value = int(top_k.strip())
if value <= 0:
raise HTTPException(
status_code=400, detail="top_k must be a positive integer"
)
if value > 100:
raise HTTPException(status_code=400, detail="top_k must not exceed 100")
return value
raise HTTPException(
status_code=400,
detail="top_k must be an integer or string representation of an integer",
)
def validate_messages(messages: Any) -> list[dict]:
"""
Validate messages structure.

View file

@ -27,6 +27,8 @@ export default function ResearcherPage() {
setSelectedConnectors,
selectedDocuments,
setSelectedDocuments,
topK,
setTopK,
} = useChatState({
search_space_id: search_space_id as string,
chat_id: chatIdParam,
@ -66,6 +68,7 @@ export default function ResearcherPage() {
selectedConnectors: string[];
searchMode: "DOCUMENTS" | "CHUNKS";
researchMode: "QNA"; // Always QNA mode
topK: number;
}
const getChatStateStorageKey = (searchSpaceId: string, chatId: string) =>
@ -105,6 +108,7 @@ export default function ResearcherPage() {
research_mode: researchMode,
search_mode: searchMode,
document_ids_to_add_in_context: documentIds,
top_k: topK,
},
},
onError: (error) => {
@ -124,6 +128,7 @@ export default function ResearcherPage() {
selectedConnectors,
searchMode,
researchMode,
topK,
});
router.replace(`/dashboard/${search_space_id}/researcher/${newChatId}`);
}
@ -145,10 +150,18 @@ export default function ResearcherPage() {
setSelectedDocuments(restoredState.selectedDocuments);
setSelectedConnectors(restoredState.selectedConnectors);
setSearchMode(restoredState.searchMode);
setTopK(restoredState.topK);
// researchMode is always "QNA", no need to restore
}
}
}, [chatIdParam, search_space_id, setSelectedDocuments, setSelectedConnectors, setSearchMode]);
}, [
chatIdParam,
search_space_id,
setSelectedDocuments,
setSelectedConnectors,
setSearchMode,
setTopK,
]);
// Set all sources as default for new chats
useEffect(() => {
@ -234,6 +247,8 @@ export default function ResearcherPage() {
selectedConnectors={selectedConnectors}
searchMode={searchMode}
onSearchModeChange={setSearchMode}
topK={topK}
onTopKChange={setTopK}
/>
);
}

View file

@ -1,7 +1,7 @@
"use client";
import { ChatInput } from "@llamaindex/chat-ui";
import { Brain, Check, FolderOpen, Zap } from "lucide-react";
import { Brain, Check, FolderOpen, Minus, Plus, Zap } from "lucide-react";
import { useParams } from "next/navigation";
import React, { Suspense, useCallback, useState } from "react";
import { DocumentsDataTable } from "@/components/chat/DocumentsDataTable";
@ -15,6 +15,7 @@ import {
DialogTitle,
DialogTrigger,
} from "@/components/ui/dialog";
import { Input } from "@/components/ui/input";
import {
Select,
SelectContent,
@ -22,6 +23,7 @@ import {
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip";
import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
import { useDocumentTypes } from "@/hooks/use-document-types";
import type { Document } from "@/hooks/use-documents";
@ -447,6 +449,119 @@ const SearchModeSelector = React.memo(
SearchModeSelector.displayName = "SearchModeSelector";
const TopKSelector = React.memo(
({ topK = 10, onTopKChange }: { topK?: number; onTopKChange?: (topK: number) => void }) => {
const MIN_VALUE = 1;
const MAX_VALUE = 100;
const handleIncrement = React.useCallback(() => {
if (topK < MAX_VALUE) {
onTopKChange?.(topK + 1);
}
}, [topK, onTopKChange]);
const handleDecrement = React.useCallback(() => {
if (topK > MIN_VALUE) {
onTopKChange?.(topK - 1);
}
}, [topK, onTopKChange]);
const handleInputChange = React.useCallback(
(e: React.ChangeEvent<HTMLInputElement>) => {
const value = e.target.value;
// Allow empty input for editing
if (value === "") {
return;
}
const numValue = parseInt(value, 10);
if (!isNaN(numValue) && numValue >= MIN_VALUE && numValue <= MAX_VALUE) {
onTopKChange?.(numValue);
}
},
[onTopKChange]
);
const handleInputBlur = React.useCallback(
(e: React.FocusEvent<HTMLInputElement>) => {
const value = e.target.value;
if (value === "") {
// Reset to default if empty
onTopKChange?.(10);
return;
}
const numValue = parseInt(value, 10);
if (isNaN(numValue) || numValue < MIN_VALUE) {
onTopKChange?.(MIN_VALUE);
} else if (numValue > MAX_VALUE) {
onTopKChange?.(MAX_VALUE);
}
},
[onTopKChange]
);
return (
<TooltipProvider>
<Tooltip delayDuration={200}>
<TooltipTrigger asChild>
<div className="flex items-center h-8 border rounded-md bg-background hover:bg-accent/50 transition-colors">
<Button
type="button"
variant="ghost"
size="icon"
className="h-full w-7 rounded-l-md rounded-r-none hover:bg-accent border-r"
onClick={handleDecrement}
disabled={topK <= MIN_VALUE}
>
<Minus className="h-3.5 w-3.5" />
</Button>
<div className="flex flex-col items-center justify-center px-2 min-w-[60px]">
<Input
type="number"
value={topK}
onChange={handleInputChange}
onBlur={handleInputBlur}
min={MIN_VALUE}
max={MAX_VALUE}
className="h-5 w-full px-1 text-center text-sm font-semibold border-0 bg-transparent focus-visible:ring-0 focus-visible:ring-offset-0 [appearance:textfield] [&::-webkit-outer-spin-button]:appearance-none [&::-webkit-inner-spin-button]:appearance-none"
/>
<span className="text-[10px] text-muted-foreground leading-none">Results</span>
</div>
<Button
type="button"
variant="ghost"
size="icon"
className="h-full w-7 rounded-r-md rounded-l-none hover:bg-accent border-l"
onClick={handleIncrement}
disabled={topK >= MAX_VALUE}
>
<Plus className="h-3.5 w-3.5" />
</Button>
</div>
</TooltipTrigger>
<TooltipContent side="top" className="max-w-xs">
<div className="space-y-2">
<p className="text-sm font-semibold">Results per Source</p>
<p className="text-xs text-muted-foreground leading-relaxed">
Control how many results to fetch from each data source. Set a higher number to get
more information, or a lower number for faster, more focused results.
</p>
<div className="flex items-center gap-2 text-xs text-muted-foreground pt-1 border-t">
<span>Recommended: 5-20</span>
<span></span>
<span>
Range: {MIN_VALUE}-{MAX_VALUE}
</span>
</div>
</div>
</TooltipContent>
</Tooltip>
</TooltipProvider>
);
}
);
TopKSelector.displayName = "TopKSelector";
const LLMSelector = React.memo(() => {
const { search_space_id } = useParams();
const searchSpaceId = Number(search_space_id);
@ -604,6 +719,8 @@ const CustomChatInputOptions = React.memo(
selectedConnectors,
searchMode,
onSearchModeChange,
topK,
onTopKChange,
}: {
onDocumentSelectionChange?: (documents: Document[]) => void;
selectedDocuments?: Document[];
@ -611,6 +728,8 @@ const CustomChatInputOptions = React.memo(
selectedConnectors?: string[];
searchMode?: "DOCUMENTS" | "CHUNKS";
onSearchModeChange?: (mode: "DOCUMENTS" | "CHUNKS") => void;
topK?: number;
onTopKChange?: (topK: number) => void;
}) => {
// Memoize the loading fallback to prevent recreation
const loadingFallback = React.useMemo(
@ -637,6 +756,8 @@ const CustomChatInputOptions = React.memo(
<div className="h-4 w-px bg-border hidden sm:block" />
<SearchModeSelector searchMode={searchMode} onSearchModeChange={onSearchModeChange} />
<div className="h-4 w-px bg-border hidden sm:block" />
<TopKSelector topK={topK} onTopKChange={onTopKChange} />
<div className="h-4 w-px bg-border hidden sm:block" />
<LLMSelector />
</div>
);
@ -653,6 +774,8 @@ export const ChatInputUI = React.memo(
selectedConnectors,
searchMode,
onSearchModeChange,
topK,
onTopKChange,
}: {
onDocumentSelectionChange?: (documents: Document[]) => void;
selectedDocuments?: Document[];
@ -660,6 +783,8 @@ export const ChatInputUI = React.memo(
selectedConnectors?: string[];
searchMode?: "DOCUMENTS" | "CHUNKS";
onSearchModeChange?: (mode: "DOCUMENTS" | "CHUNKS") => void;
topK?: number;
onTopKChange?: (topK: number) => void;
}) => {
return (
<ChatInput>
@ -674,6 +799,8 @@ export const ChatInputUI = React.memo(
selectedConnectors={selectedConnectors}
searchMode={searchMode}
onSearchModeChange={onSearchModeChange}
topK={topK}
onTopKChange={onTopKChange}
/>
</ChatInput>
);

View file

@ -13,6 +13,8 @@ interface ChatInterfaceProps {
selectedConnectors?: string[];
searchMode?: "DOCUMENTS" | "CHUNKS";
onSearchModeChange?: (mode: "DOCUMENTS" | "CHUNKS") => void;
topK?: number;
onTopKChange?: (topK: number) => void;
}
export default function ChatInterface({
@ -23,6 +25,8 @@ export default function ChatInterface({
selectedConnectors = [],
searchMode,
onSearchModeChange,
topK = 10,
onTopKChange,
}: ChatInterfaceProps) {
return (
<LlamaIndexChatSection handler={handler} className="flex h-full">
@ -36,6 +40,8 @@ export default function ChatInterface({
selectedConnectors={selectedConnectors}
searchMode={searchMode}
onSearchModeChange={onSearchModeChange}
topK={topK}
onTopKChange={onTopKChange}
/>
</div>
</div>

View file

@ -14,10 +14,11 @@ export function useChatState({ chat_id }: UseChatStateProps) {
const [currentChatId, setCurrentChatId] = useState<string | null>(chat_id || null);
// Chat configuration state
const [searchMode, setSearchMode] = useState<"DOCUMENTS" | "CHUNKS">("CHUNKS");
const [searchMode, setSearchMode] = useState<"DOCUMENTS" | "CHUNKS">("DOCUMENTS");
const [researchMode, setResearchMode] = useState<ResearchMode>("QNA");
const [selectedConnectors, setSelectedConnectors] = useState<string[]>([]);
const [selectedDocuments, setSelectedDocuments] = useState<Document[]>([]);
const [topK, setTopK] = useState<number>(5);
useEffect(() => {
const bearerToken = localStorage.getItem("surfsense_bearer_token");
@ -39,6 +40,8 @@ export function useChatState({ chat_id }: UseChatStateProps) {
setSelectedConnectors,
selectedDocuments,
setSelectedDocuments,
topK,
setTopK,
};
}

View file

@ -9,6 +9,7 @@
"build": "next build",
"start": "next start",
"lint": "next lint",
"format": "biome check --write ./ --max-diagnostics 500",
"debug": "cross-env NODE_OPTIONS=--inspect next dev --turbopack",
"debug:browser": "cross-env NODE_OPTIONS=--inspect next dev --turbopack",
"debug:server": "cross-env NODE_OPTIONS=--inspect=0.0.0.0:9229 next dev --turbopack",