mirror of
https://github.com/willchen96/mike.git
synced 2026-06-28 21:49:37 +02:00
Refactor ProjectPageParts and ProjectPageHeader components for improved loading states and skeleton UI. Update Modal and PageHeader components to support loading states. Enhance RenameableTitle for better caret positioning. Adjust DisplayWorkflowModal to utilize the new Modal component structure. Update WorkflowList to include loading indicators and improve sticky header behavior.
This commit is contained in:
parent
444d1d38e4
commit
1fa0554ea5
49 changed files with 3623 additions and 1587 deletions
1
backend/.gitignore
vendored
1
backend/.gitignore
vendored
|
|
@ -3,5 +3,6 @@ dist
|
|||
.env*
|
||||
!.env.example
|
||||
*.log
|
||||
*.raw-llm-stream.json
|
||||
logs/
|
||||
.DS_Store
|
||||
|
|
|
|||
|
|
@ -0,0 +1,14 @@
|
|||
-- Keep document version tombstones after deleting version file bytes.
|
||||
-- Deleted versions remain visible in history but are ignored by active-file
|
||||
-- lookups and cannot be opened/downloaded/replaced.
|
||||
|
||||
alter table public.document_versions
|
||||
alter column storage_path drop not null;
|
||||
|
||||
alter table public.document_versions
|
||||
add column if not exists deleted_at timestamptz,
|
||||
add column if not exists deleted_by uuid;
|
||||
|
||||
create index if not exists document_versions_active_document_id_idx
|
||||
on public.document_versions(document_id, created_at desc)
|
||||
where deleted_at is null;
|
||||
|
|
@ -20,6 +20,7 @@ create table if not exists public.user_profiles (
|
|||
tabular_model text not null default 'gemini-3-flash-preview',
|
||||
quote_model text,
|
||||
mfa_on_login boolean not null default false,
|
||||
legal_research_us boolean not null default true,
|
||||
created_at timestamptz not null default now(),
|
||||
updated_at timestamptz not null default now()
|
||||
);
|
||||
|
|
@ -119,7 +120,7 @@ create index if not exists idx_documents_project_folder
|
|||
create table if not exists public.document_versions (
|
||||
id uuid primary key default gen_random_uuid(),
|
||||
document_id uuid not null references public.documents(id) on delete cascade,
|
||||
storage_path text not null,
|
||||
storage_path text,
|
||||
pdf_storage_path text,
|
||||
source text not null default 'upload',
|
||||
version_number integer,
|
||||
|
|
@ -127,6 +128,8 @@ create table if not exists public.document_versions (
|
|||
file_type text,
|
||||
size_bytes integer,
|
||||
page_count integer,
|
||||
deleted_at timestamptz,
|
||||
deleted_by uuid,
|
||||
created_at timestamptz not null default now(),
|
||||
constraint document_versions_source_check
|
||||
check (source = any (array[
|
||||
|
|
@ -142,6 +145,10 @@ create table if not exists public.document_versions (
|
|||
create index if not exists document_versions_document_id_idx
|
||||
on public.document_versions(document_id, created_at desc);
|
||||
|
||||
create index if not exists document_versions_active_document_id_idx
|
||||
on public.document_versions(document_id, created_at desc)
|
||||
where deleted_at is null;
|
||||
|
||||
create index if not exists document_versions_doc_vnum_idx
|
||||
on public.document_versions(document_id, version_number);
|
||||
|
||||
|
|
|
|||
|
|
@ -128,6 +128,10 @@ app.post("/chat/create", chatCreateLimiter);
|
|||
app.post("/chat/:chatId/generate-title", chatCreateLimiter);
|
||||
app.post("/single-documents", uploadLimiter);
|
||||
app.post("/single-documents/:documentId/versions", uploadLimiter);
|
||||
app.put(
|
||||
"/single-documents/:documentId/versions/:versionId/file",
|
||||
uploadLimiter,
|
||||
);
|
||||
app.post("/projects/:projectId/documents", uploadLimiter);
|
||||
app.get("/user/export", exportLimiter);
|
||||
app.get("/user/chats/export", exportLimiter);
|
||||
|
|
|
|||
|
|
@ -99,69 +99,79 @@ export type ChatMessage = {
|
|||
// Constants
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export const SYSTEM_PROMPT = `You are Mike, an AI legal assistant that helps lawyers and legal professionals analyze documents, answer legal questions, and draft legal documents.
|
||||
const SYSTEM_PROMPT_BEFORE_RESEARCH = `You are Mike, an AI legal assistant for lawyers and legal professionals. Help analyze documents, answer legal questions, and draft legal documents.
|
||||
|
||||
TOOL BUDGET:
|
||||
You have at most 10 tool-use rounds in a single response. Use tools deliberately, batch independent tool calls in the same round where possible, and reserve enough room to produce a final answer. Do not spend the final tool round gathering more information unless you can answer without another tool call afterward.
|
||||
CORE RULES:
|
||||
- Be precise, professional, and evidence-aware.
|
||||
- Do not fabricate document content.
|
||||
- Use at most 10 tool-use rounds per response. Batch independent tool calls and leave room for the final answer.
|
||||
- If the user selects a workflow with [Workflow: <title> (id: <id>)], immediately call read_workflow with that id and follow the workflow before doing anything else.
|
||||
|
||||
DOCUMENT CITATION INSTRUCTIONS:
|
||||
When you reference specific content from an uploaded/generated document, place a numbered marker [1], [2], etc. inline in your prose at the point of reference.
|
||||
These numbered [N] markers and the <CITATIONS> block are for evidence passages that the UI can open. Uploaded/generated document citations use the document entry shape below. Research tools may define additional source-specific citation entry shapes in their own instructions.
|
||||
DOCUMENT CITATIONS:
|
||||
Use document citations only for verbatim evidence from uploaded or generated documents.
|
||||
|
||||
After your complete response, append a <CITATIONS> block containing a JSON array with one entry per marker:
|
||||
In prose, put sequential markers [1], [2], etc. exactly where the cited claim appears. The marker number is the citation "ref" value, not a page, footnote, section, or document number.
|
||||
|
||||
At the very end of the response, append:
|
||||
<CITATIONS>
|
||||
[
|
||||
{"ref": 1, "doc_id": "doc-0", "page": 3, "quote": "exact verbatim text from the document"},
|
||||
{"ref": 2, "doc_id": "doc-1", "page": "41-42", "quote": "Section 4.2 describes the procedure [[PAGE_BREAK]] in all material respects."}
|
||||
{"ref": 1, "doc_id": "doc-0", "quotes": [{"page": 3, "quote": "exact verbatim text"}]},
|
||||
{"ref": 2, "doc_id": "doc-1", "quotes": [{"page": "41-42", "quote": "text before page break [[PAGE_BREAK]] text after page break"}]}
|
||||
]
|
||||
</CITATIONS>
|
||||
|
||||
CRITICAL: The number inside the [N] marker in your prose is the "ref" value of a citation entry in the <CITATIONS> block — it is NOT a page number, footnote number, section number, or any other number that appears in the document. The marker [1] refers to the entry with "ref": 1 in the JSON block; [2] refers to "ref": 2; and so on. Refs are simple sequential integers you assign (1, 2, 3, …) in the order citations appear in your prose. Never use a page number or a document's own numbering as the marker number. Every [N] you write in prose MUST have a matching {"ref": N, ...} entry in the JSON block.
|
||||
|
||||
Rules:
|
||||
- Only cite text that appears verbatim in the provided documents
|
||||
- In every document <CITATIONS> entry, "doc_id" MUST be the exact chat-local document label you were given (for example "doc-0"). Never use a filename, document UUID, or any other identifier in "doc_id"
|
||||
- Prefer one citation entry per inline marker. If one marker needs multiple supporting passages, use a "quotes" array on that citation entry instead of inventing extra markers. Keep "quotes" arrays short: 1 quote by default, maximum 3.
|
||||
- For document citations, use this shape: {"ref": 1, "doc_id": "doc-0", "quotes": [{"page": 1, "quote": "exact verbatim text"}]}. For legacy compatibility you may also include top-level "page" and "quote" matching the first quote.
|
||||
- Keep quotes short (ideally ≤ 25 words) and narrowly scoped to the specific claim. Don't reuse one quote to support multiple different claims — give each claim its own quote in the citation entry
|
||||
- "page" refers to the sequential [Page N] marker in the text you were given (1-indexed from the first page). IGNORE any page numbers printed inside the document itself (footers, roman numerals, etc.)
|
||||
- For a single-page quote, set its "page" to an integer. If a quote is one continuous sentence that spans two pages, set "page" to "N-M" and insert [[PAGE_BREAK]] in the quote at the page break. Otherwise, use separate quote objects for text on different pages
|
||||
- Put the <CITATIONS> block at the very end of the response. Omit it entirely if there are no citations
|
||||
Citation rules:
|
||||
- Every [N] marker must have exactly one matching entry with "ref": N.
|
||||
- "doc_id" must be the exact chat-local label you were given, such as "doc-0". Never use a filename or document UUID in "doc_id".
|
||||
- Use one citation entry per marker. If one marker needs several passages, use "quotes" with 1 quote by default and at most 3.
|
||||
- Keep quotes short, ideally 25 words or fewer, and tightly matched to the claim.
|
||||
- "page" means the sequential [Page N] marker in the provided text, not printed page numbers inside the document.
|
||||
- For a continuous quote crossing two pages, set "page" to "N-M" and include [[PAGE_BREAK]] at the page break. Otherwise, use separate quote objects.
|
||||
- For legacy compatibility, you may also include top-level "page" and "quote" matching the first quote.
|
||||
- Omit the <CITATIONS> block when there are no citations.
|
||||
|
||||
DOCX GENERATION:
|
||||
If asked to draft or generate a document, use the generate_docx tool to produce a downloadable Word document. Always use this tool rather than just displaying the document content inline when the user asks for a document to be created.
|
||||
If the user follows up on a document you just generated and asks for changes (e.g. "make section 3 longer", "add a termination clause", "change the parties"), default to calling edit_document on that newly generated document. Do not call generate_docx again to regenerate the whole document unless the user explicitly asks for a brand-new document or the change is so sweeping that an edit would not be coherent.
|
||||
Heading hierarchy: always use Heading 1 before introducing Heading 2, Heading 2 before Heading 3, and so on. Never skip levels (e.g. do not jump from Heading 1 to Heading 3).
|
||||
Numbering: all numbering MUST start from 1, never 0. This applies at every level of the hierarchy. Legal clause numbering is applied automatically by the document generator: top-level operative headings render as 1., 2., 3.; the first numbered body clause under a top-level heading renders as 1.1; nested body clauses under that render as (a), (b), (c); deeper nested clauses render as (i), (ii), (iii), then (A), (B), (C). Do NOT use 1.1.1 for legal body clauses when (a) is the expected next level. Never produce 0., 0.1, 1.0, 1.0.1, or any other sequence that begins a level with 0.
|
||||
Never duplicate the numbering prefix in heading text. The heading's own numbering is applied automatically by the document generator, so the heading text must contain the title only. Do NOT prepend "1.", "1.1", "2.", etc. into the heading text itself. For example, a Heading 1 titled "Introduction" must be passed as "Introduction", never as "1. Introduction" (which would render as "1. 1. Introduction"). The same rule applies at every level.
|
||||
Do not repeat the document title as the first section heading. The document generator already renders the title as a centered title paragraph. Put any opening preamble text directly in the first section's content, without a duplicate heading such as "Agreement", "Contract", "Mutual Non-Disclosure Agreement", or another shortened form of the title.
|
||||
Contracts: when generating a contract or agreement, always include a signatures block at the very end of the document on its own page. Set pageBreak: true on that final section so it starts on a fresh page, and include a signature line for each party, typically the party name followed by lines for "By:", "Name:", "Title:", and "Date:". The entire signature block must be plain unnumbered text: do NOT number the signatures heading, do NOT number or letter the introductory signature sentence, party names, "By:", "Name:", "Title:", or "Date:" lines, and do NOT place the signature block inside a numbered clause. Put the signature block in the section's content rather than as a numbered heading.
|
||||
Contract preambles: the preamble of a contract (the opening recitals, parties block, "WHEREAS" clauses, and any introductory narrative before the first operative clause) must NOT be numbered. Render these as unnumbered content (plain paragraphs or an unnumbered heading), and begin numbering only at the first operative clause/section.
|
||||
- If the user asks you to create or draft a document, call generate_docx and provide the downloadable Word document rather than only displaying text inline.
|
||||
- If the user asks to revise a document you just generated, call edit_document on that document unless they explicitly want a brand-new document or the change is too broad for coherent editing.
|
||||
- Use heading levels in order; do not skip from Heading 1 to Heading 3.
|
||||
- Numbering starts at 1, never 0. The generator applies legal numbering automatically. Do not type numbering prefixes into headings.
|
||||
- Do not repeat the document title as the first section heading.
|
||||
- Contract preambles, party blocks, recitals, and WHEREAS clauses are unnumbered. Begin numbering at the first operative clause or section.
|
||||
- Contracts and agreements must end with an unnumbered signature block on a fresh page. Set pageBreak: true on the final section and include signature lines such as By, Name, Title, and Date for each party.
|
||||
|
||||
DOCUMENT EDITING:
|
||||
When using edit_document, any edit that adds, removes, or reorders a numbered clause, section, sub-clause, schedule, exhibit, or list item shifts every downstream number. You MUST update all affected numbering AND every cross-reference to those numbers in the same edit_document call:
|
||||
- Renumber the sibling clauses/sections/sub-clauses that follow the change so the sequence stays contiguous (e.g. if you insert a new Section 4, existing Sections 4, 5, 6… become 5, 6, 7…).
|
||||
- Find every in-document reference to the shifted numbers, e.g. "see Section 5", "pursuant to Clause 4.2(b)", "as set out in Schedule 3", "defined in Section 2.1", and update them to the new numbers. Include defined-term blocks, cross-references in recitals, schedules, and exhibits.
|
||||
- Before issuing the edits, scan the full document (use read_document or find_in_document) to enumerate affected cross-references; do not assume references only appear near the change site.
|
||||
- If you are uncertain whether a reference points to the shifted number or an unrelated number, err on the side of including it as an edit and explain in the reason field.
|
||||
- When deleting square brackets, delete both the opening \`[\` and the closing \`]\`. Never leave behind an unmatched square bracket after an edit.
|
||||
When edit_document adds, deletes, moves, or reorders any numbered clause, section, schedule, exhibit, or list item:
|
||||
- Renumber all affected downstream items in the same edit.
|
||||
- Update all affected cross-references, including references in recitals, definitions, schedules, and exhibits.
|
||||
- Before editing, scan the full document with read_document or find_in_document for affected references.
|
||||
- If a reference might point to a shifted number, include the update and explain the reason.
|
||||
- When deleting square brackets, delete both "[" and "]".`;
|
||||
|
||||
WORKFLOWS:
|
||||
When a user message begins with a [Workflow: <title> (id: <id>)] marker, the user has selected a workflow and you MUST apply it. Immediately call the read_workflow tool with that exact id to load the workflow's full prompt, then follow those instructions for the current turn. Do this before producing any other output or calling any other tools (aside from any document reads the workflow requires). Do not ask the user to confirm — the selection itself is the instruction to apply the workflow.
|
||||
|
||||
${COURTLISTENER_SYSTEM_PROMPT}
|
||||
DOCUMENT NAMING IN PROSE:
|
||||
The chat-local labels ("doc-0", "doc-1", "doc-N", …) are internal handles for tool calls and citation JSON ONLY. NEVER write them in your prose response or in any text the user reads — not in body text, not in headings, not in lists, not in tool-activity descriptions. The user does not know what "doc-0" means and seeing it is jarring. When referring to a document in prose, always use its filename (e.g. "the NDA draft" or "nda_v1.docx"). This rule applies to every word streamed back to the user; the only places "doc-N" identifiers are allowed are inside tool-call arguments and inside the <CITATIONS> JSON block's "doc_id" field.
|
||||
const SYSTEM_PROMPT_AFTER_RESEARCH = `DOCUMENT NAMES IN PROSE:
|
||||
- Chat-local labels such as "doc-0" are internal. Use them only in tool arguments and citation JSON.
|
||||
- Never show "doc-N" labels to the user in prose, headings, lists, or tool activity text.
|
||||
- Refer to documents by filename or a natural description, such as "the NDA draft".
|
||||
|
||||
GENERAL GUIDANCE:
|
||||
- Be precise and professional
|
||||
- Cite the specific document or fetched opinion passage when making evidence-backed claims. Use [N] markers only as described in the citation instructions above
|
||||
- When no documents are provided, answer based on your legal knowledge
|
||||
- Do not fabricate document content
|
||||
- Do not use emojis in your responses.
|
||||
- Cite the exact document or fetched opinion passage for evidence-backed claims.
|
||||
- If no documents are provided, answer from legal knowledge.
|
||||
- Do not use emojis.
|
||||
`;
|
||||
|
||||
/**
|
||||
* Assemble the chat system prompt. When `includeResearchTools` is true the
|
||||
* CourtListener (US case-law) research instructions are spliced in; when
|
||||
* false they are omitted entirely so the model is not told about tools it
|
||||
* does not have. Gated per-user by the Legal Research > US feature toggle.
|
||||
*/
|
||||
export function buildSystemPrompt(includeResearchTools = true): string {
|
||||
return includeResearchTools
|
||||
? `${SYSTEM_PROMPT_BEFORE_RESEARCH}\n\n${COURTLISTENER_SYSTEM_PROMPT}\n${SYSTEM_PROMPT_AFTER_RESEARCH}`
|
||||
: `${SYSTEM_PROMPT_BEFORE_RESEARCH}\n\n${SYSTEM_PROMPT_AFTER_RESEARCH}`;
|
||||
}
|
||||
|
||||
export const SYSTEM_PROMPT = buildSystemPrompt(true);
|
||||
|
||||
export const PROJECT_EXTRA_TOOLS = [
|
||||
{
|
||||
type: "function",
|
||||
|
|
@ -763,9 +773,10 @@ export function buildMessages(
|
|||
}[],
|
||||
systemPromptExtra?: string,
|
||||
docIndex?: DocIndex,
|
||||
includeResearchTools = true,
|
||||
) {
|
||||
const formatted: unknown[] = [];
|
||||
let systemContent = SYSTEM_PROMPT;
|
||||
let systemContent = buildSystemPrompt(includeResearchTools);
|
||||
|
||||
if (systemPromptExtra) {
|
||||
systemContent += `\n\n${systemPromptExtra.trim()}`;
|
||||
|
|
@ -3619,18 +3630,65 @@ const CITATIONS_BLOCK_RE = /<CITATIONS>\s*([\s\S]*?)\s*<\/CITATIONS>/;
|
|||
const CITATIONS_OPEN_TAG = "<CITATIONS>";
|
||||
const CITATIONS_CLOSE_TAG = "</CITATIONS>";
|
||||
|
||||
function parseCitations(text: string): ParsedCitation[] {
|
||||
type CitationParseDiagnostics = {
|
||||
hasBlock: boolean;
|
||||
rawLength: number;
|
||||
error: string | null;
|
||||
};
|
||||
|
||||
function parseCitationsWithDiagnostics(text: string): {
|
||||
citations: ParsedCitation[];
|
||||
diagnostics: CitationParseDiagnostics;
|
||||
} {
|
||||
const match = text.match(CITATIONS_BLOCK_RE);
|
||||
if (!match) return [];
|
||||
try {
|
||||
const raw = JSON.parse(match[1]);
|
||||
if (!Array.isArray(raw)) return [];
|
||||
return raw
|
||||
.map(normalizeCitation)
|
||||
.filter((c): c is ParsedCitation => c !== null);
|
||||
} catch {
|
||||
return [];
|
||||
if (!match) {
|
||||
return {
|
||||
citations: [],
|
||||
diagnostics: {
|
||||
hasBlock: false,
|
||||
rawLength: 0,
|
||||
error: null,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
const raw = match[1] ?? "";
|
||||
try {
|
||||
const parsed = JSON.parse(raw);
|
||||
if (!Array.isArray(parsed)) {
|
||||
return {
|
||||
citations: [],
|
||||
diagnostics: {
|
||||
hasBlock: true,
|
||||
rawLength: raw.length,
|
||||
error: "CITATIONS block JSON was not an array.",
|
||||
},
|
||||
};
|
||||
}
|
||||
return {
|
||||
citations: parsed
|
||||
.map(normalizeCitation)
|
||||
.filter((c): c is ParsedCitation => c !== null),
|
||||
diagnostics: {
|
||||
hasBlock: true,
|
||||
rawLength: raw.length,
|
||||
error: null,
|
||||
},
|
||||
};
|
||||
} catch (error) {
|
||||
return {
|
||||
citations: [],
|
||||
diagnostics: {
|
||||
hasBlock: true,
|
||||
rawLength: raw.length,
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
function parseCitations(text: string): ParsedCitation[] {
|
||||
return parseCitationsWithDiagnostics(text).citations;
|
||||
}
|
||||
|
||||
function parsePartialCitationObjects(text: string): ParsedCitation[] {
|
||||
|
|
@ -4181,7 +4239,8 @@ export async function runLLMStream(params: {
|
|||
flushText();
|
||||
|
||||
// Parse and emit citations from <CITATIONS> block
|
||||
const parsedCitations = parseCitations(fullText);
|
||||
const { citations: parsedCitations, diagnostics: citationDiagnostics } =
|
||||
parseCitationsWithDiagnostics(fullText);
|
||||
const citations = buildCitations
|
||||
? buildCitations(fullText)
|
||||
: parsedCitations.map((c) =>
|
||||
|
|
@ -4191,6 +4250,14 @@ export async function runLLMStream(params: {
|
|||
courtlistenerTurnState.casesByClusterId,
|
||||
),
|
||||
);
|
||||
devLog("[chat/stream] final citation annotations", {
|
||||
hasCitationsBlock: citationDiagnostics.hasBlock,
|
||||
citationsBlockLength: citationDiagnostics.rawLength,
|
||||
parseError: citationDiagnostics.error,
|
||||
parsedCitationCount: parsedCitations.length,
|
||||
emittedAnnotationCount: citations.length,
|
||||
usedCustomCitationBuilder: !!buildCitations,
|
||||
});
|
||||
write(
|
||||
`data: ${JSON.stringify({ type: "citations", status: "final", citations })}\n\n`,
|
||||
);
|
||||
|
|
|
|||
|
|
@ -66,6 +66,7 @@ export async function loadActiveVersion(
|
|||
"id, document_id, storage_path, pdf_storage_path, version_number, filename, source, file_type, size_bytes, page_count",
|
||||
)
|
||||
.eq("id", targetVersionId)
|
||||
.is("deleted_at", null)
|
||||
.single();
|
||||
if (!v || v.document_id !== documentId || !v.storage_path) return null;
|
||||
return {
|
||||
|
|
@ -111,7 +112,8 @@ export async function attachActiveVersionPaths<T extends VersionPathRow>(
|
|||
.select(
|
||||
"id, storage_path, pdf_storage_path, version_number, filename, file_type, size_bytes, page_count",
|
||||
)
|
||||
.in("id", versionIds);
|
||||
.in("id", versionIds)
|
||||
.is("deleted_at", null);
|
||||
const byId = new Map<
|
||||
string,
|
||||
{
|
||||
|
|
@ -174,6 +176,7 @@ export async function attachLatestVersionNumbers<T extends DocRow>(
|
|||
.select("document_id, version_number")
|
||||
.in("document_id", ids)
|
||||
.eq("source", "assistant_edit")
|
||||
.is("deleted_at", null)
|
||||
.not("version_number", "is", null);
|
||||
|
||||
const latestByDoc = new Map<string, number>();
|
||||
|
|
|
|||
|
|
@ -69,19 +69,25 @@ export const COURTLISTENER_TOOL_NAMES = {
|
|||
verifyCitations: "courtlistener_verify_citations",
|
||||
} as const;
|
||||
|
||||
export const COURTLISTENER_SYSTEM_PROMPT = `LEGAL RESEARCH QUERIES:
|
||||
- When a user asks a question on US law, you are required to cite relevant case law in your answer. Always verify US case citations using the courtlistener_verify_citations tool.
|
||||
- The courtlistener_verify_citations tool accepts only a citations array of clean reporter citations. Do not pass case names to this tool. Correct: {"citations":["467 U.S. 837","323 U.S. 134"]}. Incorrect: {"citations":["Chevron U.S.A. v. NRDC","Skidmore v. Swift"]}. If you only have case names and no reporter citations, do not call courtlistener_verify_citations for those names.
|
||||
- If any CourtListener tool call reports that a CourtListener rate limit was exceeded, or returns a 429/throttled/rate-limit error, do not make any further CourtListener API/search calls in that turn. Do not retry, verify more citations, fetch more cases, or run additional CourtListener searches; answer with the information already available and briefly state that CourtListener is rate limiting requests.
|
||||
- For cases you may cite or materially rely on, follow this sequence when reporter citations are available: first use courtlistener_verify_citations with clean reporter citations, then use courtlistener_get_cases to fetch/cache the relevant case clusters, then use courtlistener_find_in_case to search targeted keywords in the cached opinions, and only if those keyword snippets are insufficient use courtlistener_read_case to read selected opinion text.
|
||||
- Only cite cases whose underlying opinion text, or at least the specific relevant opinion passages, has been supplied to you in this turn. courtlistener_get_cases only fetches and caches opinions; it does NOT place full opinion text in your context. It returns text-free opinion metadata so you can choose which opinion(s) matter. After courtlistener_get_cases, use courtlistener_find_in_case for targeted keyword or phrase lookup inside that cached case. If those snippets are not enough, use courtlistener_read_case to read only the specific already-fetched opinion(s) you need. courtlistener_find_in_case and courtlistener_read_case require the case to have been fetched first.
|
||||
- When a fetched case has multiple opinions, do not read all opinions by default. Choose the specific opinion_id or opinion_ids needed from the metadata or search hits. Prefer the lead/majority/controlling opinion when it is sufficient; read concurrences, dissents, or combined opinions only when they are necessary for the user's question.
|
||||
- When using courtlistener_find_in_case, search for terms that are 1-3 words long and actually likely to appear exactly as written in the opinion text. Do not use long sentence-like phrases. Run courtlistener_find_in_case no more than 3 times in a single assistant turn; if those searches are insufficient, read the smallest needed opinion text with courtlistener_read_case or answer with the available information.
|
||||
- Do not cite a case based only on memory, search-result snippets, reporter metadata, citationLinks, or verification results. Those sources may help choose candidates, but final case citations must be grounded in supplied opinion text/passages.
|
||||
- Every case citation in final prose must be rendered as a clickable case-law panel link using the markdown link returned in citationLinks, e.g. [Case Name, Citation](us-case-12345). Do not write plain-text case citations without the link.
|
||||
- Use numbered [N] markers for case citations in the final prose and include each cited case in the final <CITATIONS> block.
|
||||
- Each case entry in the <CITATIONS> block must include quote(s) copied exactly from the supplied opinion text/passages for that case, e.g. {"ref": N, "cluster_id": 123, "quotes": [{"opinion_id": 456, "quote": "exact verbatim opinion text"}]}. Do not include top-level "quote", "doc_id", "page", "case_name", or "citation" for case entries.
|
||||
- If a case is useful but you do not have its opinion text or relevant passages, either fetch the opinions before citing it or say that you could not read the opinion and do not cite or characterize the case beyond basic metadata.`;
|
||||
export const COURTLISTENER_SYSTEM_PROMPT = `US CASE LAW RESEARCH:
|
||||
Use CourtListener when answering US-law questions that require case law.
|
||||
|
||||
Workflow:
|
||||
1. If you have reporter citations, verify them with courtlistener_verify_citations using only clean citations: {"citations":["467 U.S. 837","323 U.S. 134"]}. Never pass case names to this tool.
|
||||
2. Fetch matched clusters with courtlistener_get_cases.
|
||||
3. Get cite-worthy text from the fetched cases with courtlistener_find_in_case. Use short 1-3 word searches, maximum 3 searches per assistant turn.
|
||||
4. If snippets are not enough, read only the necessary opinion(s) with courtlistener_read_case. For multi-opinion cases, choose the specific opinion_id/opinionIds needed; do not read all opinions by default.
|
||||
|
||||
Citation rules:
|
||||
- Final case citations must be based on opinion text or passage snippets supplied in this turn. Do not cite cases based only on memory, metadata, search results, citationLinks, or verification results.
|
||||
- If you mention a CourtListener case as legal support in the final answer, cite it with both: (a) the clickable markdown link returned in citationLinks, and (b) an inline [N] marker. Include the clickable case link only the first time you cite that case; later references to the same case should use the existing inline [N] marker without repeating the link unless clarity requires it.
|
||||
- Assign new annotation refs in first-use order as much as possible: [1], then [2], then [3]. Reuse an existing ref when citing the same case/passage again, even if that means a later sentence cites [3] and then [1] again.
|
||||
- The final <CITATIONS> block must include one matching case entry for each [N] case marker: {"ref": N, "cluster_id": 123, "quotes": [{"opinion_id": 456, "quote": "exact verbatim opinion text"}]}.
|
||||
- Do not use doc_id, page, top-level quote, case_name, or citation fields in case entries.
|
||||
- If you have not obtained opinion text or snippets for a useful case, fetch/read it before citing it, or say you could not read it and do not rely on it.
|
||||
|
||||
Limits:
|
||||
- If any CourtListener call returns a rate-limit/throttling/429 error, stop all CourtListener calls for that turn and answer using only information already available.`;
|
||||
|
||||
export const COURTLISTENER_TOOLS = [
|
||||
{
|
||||
|
|
|
|||
|
|
@ -1,274 +1,294 @@
|
|||
import Anthropic from "@anthropic-ai/sdk";
|
||||
import type { Tool } from "@anthropic-ai/sdk/resources/messages/messages";
|
||||
import type {
|
||||
StreamChatParams,
|
||||
StreamChatResult,
|
||||
NormalizedToolCall,
|
||||
NormalizedToolResult,
|
||||
StreamChatParams,
|
||||
StreamChatResult,
|
||||
NormalizedToolCall,
|
||||
NormalizedToolResult,
|
||||
} from "./types";
|
||||
import { toClaudeTools } from "./tools";
|
||||
import { logRawLlmStream } from "./rawStreamLog";
|
||||
import { createRawLlmStreamRecorder, logRawLlmStream } from "./rawStreamLog";
|
||||
|
||||
type ContentBlock =
|
||||
| { type: "text"; text: string }
|
||||
| { type: "tool_use"; id: string; name: string; input: unknown }
|
||||
| { type: string; [key: string]: unknown };
|
||||
| { type: "text"; text: string }
|
||||
| { type: "tool_use"; id: string; name: string; input: unknown }
|
||||
| { type: string; [key: string]: unknown };
|
||||
|
||||
type NativeMessage = {
|
||||
role: "user" | "assistant";
|
||||
content: string | ContentBlock[];
|
||||
role: "user" | "assistant";
|
||||
content: string | ContentBlock[];
|
||||
};
|
||||
|
||||
const MAX_TOKENS = 16384;
|
||||
|
||||
function apiKey(override?: string | null): string {
|
||||
const key = override?.trim() || process.env.ANTHROPIC_API_KEY?.trim() || "";
|
||||
if (!key) {
|
||||
throw new Error(
|
||||
"Anthropic API key is not configured. Set ANTHROPIC_API_KEY or add a user Anthropic key.",
|
||||
);
|
||||
}
|
||||
return key;
|
||||
const key = override?.trim() || process.env.ANTHROPIC_API_KEY?.trim() || "";
|
||||
if (!key) {
|
||||
throw new Error(
|
||||
"Anthropic API key is not configured. Set ANTHROPIC_API_KEY or add a user Anthropic key.",
|
||||
);
|
||||
}
|
||||
return key;
|
||||
}
|
||||
|
||||
function client(override?: string | null): Anthropic {
|
||||
const apiKeyValue = apiKey(override);
|
||||
return new Anthropic({ apiKey: apiKeyValue });
|
||||
const apiKeyValue = apiKey(override);
|
||||
return new Anthropic({ apiKey: apiKeyValue });
|
||||
}
|
||||
|
||||
function toNativeMessages(
|
||||
messages: StreamChatParams["messages"],
|
||||
messages: StreamChatParams["messages"],
|
||||
): NativeMessage[] {
|
||||
return messages.map((m) => ({ role: m.role, content: m.content }));
|
||||
return messages.map((m) => ({ role: m.role, content: m.content }));
|
||||
}
|
||||
|
||||
function claudeErrorMessage(error: unknown): string {
|
||||
const parsedObject = claudeStreamFailureMessage(error);
|
||||
if (parsedObject) return parsedObject;
|
||||
if (error instanceof Error && error.message) {
|
||||
const parsed = parseClaudeErrorPayload(error.message);
|
||||
if (parsed) return parsed;
|
||||
return error.message.startsWith("Claude error:")
|
||||
? error.message
|
||||
: `Claude error: ${error.message}`;
|
||||
}
|
||||
const parsed = parseClaudeErrorPayload(String(error));
|
||||
const parsedObject = claudeStreamFailureMessage(error);
|
||||
if (parsedObject) return parsedObject;
|
||||
if (error instanceof Error && error.message) {
|
||||
const parsed = parseClaudeErrorPayload(error.message);
|
||||
if (parsed) return parsed;
|
||||
return `Claude error: ${String(error)}`;
|
||||
return error.message.startsWith("Claude error:")
|
||||
? error.message
|
||||
: `Claude error: ${error.message}`;
|
||||
}
|
||||
const parsed = parseClaudeErrorPayload(String(error));
|
||||
if (parsed) return parsed;
|
||||
return `Claude error: ${String(error)}`;
|
||||
}
|
||||
|
||||
function parseClaudeErrorPayload(value: string): string | null {
|
||||
const trimmed = value.trim();
|
||||
const jsonStart = trimmed.indexOf("{");
|
||||
if (jsonStart < 0) return null;
|
||||
const jsonEnd = trimmed.lastIndexOf("}");
|
||||
if (jsonEnd <= jsonStart) return null;
|
||||
const payload = trimmed.slice(jsonStart, jsonEnd + 1);
|
||||
try {
|
||||
const parsed = JSON.parse(payload) as unknown;
|
||||
return claudeStreamFailureMessage(parsed);
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
const trimmed = value.trim();
|
||||
const jsonStart = trimmed.indexOf("{");
|
||||
if (jsonStart < 0) return null;
|
||||
const jsonEnd = trimmed.lastIndexOf("}");
|
||||
if (jsonEnd <= jsonStart) return null;
|
||||
const payload = trimmed.slice(jsonStart, jsonEnd + 1);
|
||||
try {
|
||||
const parsed = JSON.parse(payload) as unknown;
|
||||
return claudeStreamFailureMessage(parsed);
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
function claudeStreamFailureMessage(event: unknown): string | null {
|
||||
if (!event || typeof event !== "object") return null;
|
||||
const record = event as Record<string, unknown>;
|
||||
const error = record.error;
|
||||
if (record.type !== "error" || !error || typeof error !== "object") {
|
||||
return null;
|
||||
}
|
||||
const err = error as Record<string, unknown>;
|
||||
const type =
|
||||
typeof err.type === "string" && err.type.trim()
|
||||
? err.type.trim()
|
||||
: null;
|
||||
const message =
|
||||
typeof err.message === "string" && err.message.trim()
|
||||
? err.message.trim()
|
||||
: "Claude stream failed.";
|
||||
return type ? `Claude error (${type}): ${message}` : `Claude error: ${message}`;
|
||||
if (!event || typeof event !== "object") return null;
|
||||
const record = event as Record<string, unknown>;
|
||||
const error = record.error;
|
||||
if (record.type !== "error" || !error || typeof error !== "object") {
|
||||
return null;
|
||||
}
|
||||
const err = error as Record<string, unknown>;
|
||||
const type =
|
||||
typeof err.type === "string" && err.type.trim() ? err.type.trim() : null;
|
||||
const message =
|
||||
typeof err.message === "string" && err.message.trim()
|
||||
? err.message.trim()
|
||||
: "Claude stream failed.";
|
||||
return type
|
||||
? `Claude error (${type}): ${message}`
|
||||
: `Claude error: ${message}`;
|
||||
}
|
||||
|
||||
function abortError(): Error {
|
||||
const err = new Error("Stream aborted.");
|
||||
err.name = "AbortError";
|
||||
return err;
|
||||
const err = new Error("Stream aborted.");
|
||||
err.name = "AbortError";
|
||||
return err;
|
||||
}
|
||||
|
||||
function throwIfAborted(signal?: AbortSignal) {
|
||||
if (signal?.aborted) throw abortError();
|
||||
if (signal?.aborted) throw abortError();
|
||||
}
|
||||
|
||||
export async function streamClaude(
|
||||
params: StreamChatParams,
|
||||
params: StreamChatParams,
|
||||
): Promise<StreamChatResult> {
|
||||
const {
|
||||
model,
|
||||
systemPrompt,
|
||||
tools = [],
|
||||
callbacks = {},
|
||||
runTools,
|
||||
apiKeys,
|
||||
enableThinking,
|
||||
} = params;
|
||||
const maxIter = params.maxIterations ?? 10;
|
||||
const anthropic = client(apiKeys?.claude);
|
||||
const claudeTools = toClaudeTools(tools);
|
||||
const {
|
||||
model,
|
||||
systemPrompt,
|
||||
tools = [],
|
||||
callbacks = {},
|
||||
runTools,
|
||||
apiKeys,
|
||||
enableThinking,
|
||||
} = params;
|
||||
const maxIter = params.maxIterations ?? 10;
|
||||
const anthropic = client(apiKeys?.claude);
|
||||
const claudeTools = toClaudeTools(tools);
|
||||
|
||||
const messages: NativeMessage[] = toNativeMessages(params.messages);
|
||||
let fullText = "";
|
||||
const messages: NativeMessage[] = toNativeMessages(params.messages);
|
||||
let fullText = "";
|
||||
const rawStreamRecorder = createRawLlmStreamRecorder({
|
||||
provider: "claude",
|
||||
model,
|
||||
});
|
||||
|
||||
try {
|
||||
for (let iter = 0; iter < maxIter; iter++) {
|
||||
throwIfAborted(params.abortSignal);
|
||||
const stream = anthropic.messages.stream({
|
||||
model,
|
||||
system: systemPrompt,
|
||||
messages: messages as Anthropic.MessageParam[],
|
||||
tools: claudeTools.length
|
||||
? (claudeTools as unknown as Tool[])
|
||||
: undefined,
|
||||
max_tokens: MAX_TOKENS,
|
||||
// Claude 4.x models require `thinking.type: "adaptive"` and
|
||||
// drive effort via `output_config.effort` rather than a fixed
|
||||
// token budget. We only opt in when the caller requested it.
|
||||
...(enableThinking
|
||||
? ({
|
||||
thinking: { type: "adaptive" },
|
||||
output_config: { effort: "high" },
|
||||
} as unknown as Record<string, unknown>)
|
||||
: {}),
|
||||
// Extended thinking requires temperature to be default (omitted).
|
||||
});
|
||||
throwIfAborted(params.abortSignal);
|
||||
const stream = anthropic.messages.stream({
|
||||
model,
|
||||
system: systemPrompt,
|
||||
messages: messages as Anthropic.MessageParam[],
|
||||
tools: claudeTools.length
|
||||
? (claudeTools as unknown as Tool[])
|
||||
: undefined,
|
||||
max_tokens: MAX_TOKENS,
|
||||
// Claude 4.x models require `thinking.type: "adaptive"` and
|
||||
// drive effort via `output_config.effort` rather than a fixed
|
||||
// token budget. We only opt in when the caller requested it.
|
||||
...(enableThinking
|
||||
? ({
|
||||
thinking: { type: "adaptive" },
|
||||
output_config: { effort: "high" },
|
||||
} as unknown as Record<string, unknown>)
|
||||
: {}),
|
||||
// Extended thinking requires temperature to be default (omitted).
|
||||
});
|
||||
|
||||
let sawThinking = false;
|
||||
let streamFailureMessage: string | null = null;
|
||||
const abortStream = () => stream.abort();
|
||||
params.abortSignal?.addEventListener("abort", abortStream, {
|
||||
once: true,
|
||||
});
|
||||
let sawThinking = false;
|
||||
let streamFailureMessage: string | null = null;
|
||||
const abortStream = () => stream.abort();
|
||||
params.abortSignal?.addEventListener("abort", abortStream, {
|
||||
once: true,
|
||||
});
|
||||
|
||||
stream.on("streamEvent", (event) => {
|
||||
logRawLlmStream({
|
||||
provider: "claude",
|
||||
model,
|
||||
iteration: iter,
|
||||
label: "streamEvent",
|
||||
payload: event,
|
||||
});
|
||||
const failureMessage = claudeStreamFailureMessage(event);
|
||||
if (failureMessage) {
|
||||
streamFailureMessage = failureMessage;
|
||||
stream.abort();
|
||||
}
|
||||
stream.on("streamEvent", (event) => {
|
||||
logRawLlmStream({
|
||||
provider: "claude",
|
||||
model,
|
||||
iteration: iter,
|
||||
label: "streamEvent",
|
||||
payload: event,
|
||||
});
|
||||
stream.on("error", (error) => {
|
||||
logRawLlmStream({
|
||||
provider: "claude",
|
||||
model,
|
||||
iteration: iter,
|
||||
label: "error",
|
||||
payload: error,
|
||||
});
|
||||
rawStreamRecorder?.record({
|
||||
iteration: iter,
|
||||
label: "streamEvent",
|
||||
payload: event,
|
||||
});
|
||||
|
||||
stream.on("text", (delta) => {
|
||||
callbacks.onContentDelta?.(delta);
|
||||
});
|
||||
if (enableThinking) {
|
||||
stream.on("thinking", (delta) => {
|
||||
sawThinking = true;
|
||||
callbacks.onReasoningDelta?.(delta);
|
||||
});
|
||||
const failureMessage = claudeStreamFailureMessage(event);
|
||||
if (failureMessage) {
|
||||
streamFailureMessage = failureMessage;
|
||||
stream.abort();
|
||||
}
|
||||
|
||||
let final: Awaited<ReturnType<typeof stream.finalMessage>>;
|
||||
try {
|
||||
final = await stream.finalMessage();
|
||||
} catch (error) {
|
||||
if (params.abortSignal?.aborted) throw abortError();
|
||||
if (streamFailureMessage) throw new Error(streamFailureMessage);
|
||||
throw new Error(claudeErrorMessage(error));
|
||||
} finally {
|
||||
params.abortSignal?.removeEventListener("abort", abortStream);
|
||||
}
|
||||
if (sawThinking) callbacks.onReasoningBlockEnd?.();
|
||||
throwIfAborted(params.abortSignal);
|
||||
const stopReason = final.stop_reason;
|
||||
const assistantBlocks = final.content as ContentBlock[];
|
||||
|
||||
// Extract text content and tool_use calls from the final assistant
|
||||
// message so we can accumulate text and drive the tool-call loop.
|
||||
const toolCalls: NormalizedToolCall[] = [];
|
||||
for (const block of assistantBlocks) {
|
||||
if (block.type === "text") {
|
||||
const txt = (block as { text: string }).text;
|
||||
if (typeof txt === "string") fullText += txt;
|
||||
} else if (block.type === "tool_use") {
|
||||
const tu = block as {
|
||||
id: string;
|
||||
name: string;
|
||||
input: unknown;
|
||||
};
|
||||
const call: NormalizedToolCall = {
|
||||
id: tu.id,
|
||||
name: tu.name,
|
||||
input: (tu.input as Record<string, unknown>) ?? {},
|
||||
};
|
||||
callbacks.onToolCallStart?.(call);
|
||||
toolCalls.push(call);
|
||||
}
|
||||
}
|
||||
|
||||
if (stopReason !== "tool_use" || !toolCalls.length || !runTools) {
|
||||
break;
|
||||
}
|
||||
|
||||
const results = await runTools(toolCalls);
|
||||
throwIfAborted(params.abortSignal);
|
||||
|
||||
// Record the assistant turn (preserving the original content blocks,
|
||||
// which Claude requires on the follow-up) and the user turn that
|
||||
// carries the tool_result blocks.
|
||||
messages.push({ role: "assistant", content: assistantBlocks });
|
||||
messages.push({
|
||||
role: "user",
|
||||
content: results.map((r) => ({
|
||||
type: "tool_result",
|
||||
tool_use_id: r.tool_use_id,
|
||||
content: r.content,
|
||||
})),
|
||||
});
|
||||
stream.on("error", (error) => {
|
||||
logRawLlmStream({
|
||||
provider: "claude",
|
||||
model,
|
||||
iteration: iter,
|
||||
label: "error",
|
||||
payload: error,
|
||||
});
|
||||
rawStreamRecorder?.record({
|
||||
iteration: iter,
|
||||
label: "error",
|
||||
payload: error,
|
||||
});
|
||||
});
|
||||
|
||||
stream.on("text", (delta) => {
|
||||
callbacks.onContentDelta?.(delta);
|
||||
});
|
||||
if (enableThinking) {
|
||||
stream.on("thinking", (delta) => {
|
||||
sawThinking = true;
|
||||
callbacks.onReasoningDelta?.(delta);
|
||||
});
|
||||
}
|
||||
|
||||
let final: Awaited<ReturnType<typeof stream.finalMessage>>;
|
||||
try {
|
||||
final = await stream.finalMessage();
|
||||
} catch (error) {
|
||||
if (params.abortSignal?.aborted) throw abortError();
|
||||
if (streamFailureMessage) throw new Error(streamFailureMessage);
|
||||
throw new Error(claudeErrorMessage(error));
|
||||
} finally {
|
||||
params.abortSignal?.removeEventListener("abort", abortStream);
|
||||
}
|
||||
if (sawThinking) callbacks.onReasoningBlockEnd?.();
|
||||
throwIfAborted(params.abortSignal);
|
||||
const stopReason = final.stop_reason;
|
||||
const assistantBlocks = final.content as ContentBlock[];
|
||||
|
||||
// Extract text content and tool_use calls from the final assistant
|
||||
// message so we can accumulate text and drive the tool-call loop.
|
||||
const toolCalls: NormalizedToolCall[] = [];
|
||||
for (const block of assistantBlocks) {
|
||||
if (block.type === "text") {
|
||||
const txt = (block as { text: string }).text;
|
||||
if (typeof txt === "string") fullText += txt;
|
||||
} else if (block.type === "tool_use") {
|
||||
const tu = block as {
|
||||
id: string;
|
||||
name: string;
|
||||
input: unknown;
|
||||
};
|
||||
const call: NormalizedToolCall = {
|
||||
id: tu.id,
|
||||
name: tu.name,
|
||||
input: (tu.input as Record<string, unknown>) ?? {},
|
||||
};
|
||||
callbacks.onToolCallStart?.(call);
|
||||
toolCalls.push(call);
|
||||
}
|
||||
}
|
||||
|
||||
if (stopReason !== "tool_use" || !toolCalls.length || !runTools) {
|
||||
break;
|
||||
}
|
||||
|
||||
const results = await runTools(toolCalls);
|
||||
throwIfAborted(params.abortSignal);
|
||||
|
||||
// Record the assistant turn (preserving the original content blocks,
|
||||
// which Claude requires on the follow-up) and the user turn that
|
||||
// carries the tool_result blocks.
|
||||
messages.push({ role: "assistant", content: assistantBlocks });
|
||||
messages.push({
|
||||
role: "user",
|
||||
content: results.map((r) => ({
|
||||
type: "tool_result",
|
||||
tool_use_id: r.tool_use_id,
|
||||
content: r.content,
|
||||
})),
|
||||
});
|
||||
}
|
||||
|
||||
await rawStreamRecorder?.flush("completed");
|
||||
return { fullText };
|
||||
} catch (error) {
|
||||
await rawStreamRecorder?.flush("error", error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
export async function completeClaudeText(params: {
|
||||
model: string;
|
||||
systemPrompt?: string;
|
||||
user: string;
|
||||
maxTokens?: number;
|
||||
apiKeys?: { claude?: string | null };
|
||||
model: string;
|
||||
systemPrompt?: string;
|
||||
user: string;
|
||||
maxTokens?: number;
|
||||
apiKeys?: { claude?: string | null };
|
||||
}): Promise<string> {
|
||||
const anthropic = client(params.apiKeys?.claude);
|
||||
let resp: Awaited<ReturnType<typeof anthropic.messages.create>>;
|
||||
try {
|
||||
resp = await anthropic.messages.create({
|
||||
model: params.model,
|
||||
max_tokens: params.maxTokens ?? 512,
|
||||
system: params.systemPrompt,
|
||||
messages: [{ role: "user", content: params.user }],
|
||||
});
|
||||
} catch (error) {
|
||||
throw new Error(claudeErrorMessage(error));
|
||||
}
|
||||
const text = resp.content
|
||||
.filter((b): b is Anthropic.TextBlock => b.type === "text")
|
||||
.map((b) => b.text)
|
||||
.join("");
|
||||
return text;
|
||||
const anthropic = client(params.apiKeys?.claude);
|
||||
let resp: Awaited<ReturnType<typeof anthropic.messages.create>>;
|
||||
try {
|
||||
resp = await anthropic.messages.create({
|
||||
model: params.model,
|
||||
max_tokens: params.maxTokens ?? 512,
|
||||
system: params.systemPrompt,
|
||||
messages: [{ role: "user", content: params.user }],
|
||||
});
|
||||
} catch (error) {
|
||||
throw new Error(claudeErrorMessage(error));
|
||||
}
|
||||
const text = resp.content
|
||||
.filter((b): b is Anthropic.TextBlock => b.type === "text")
|
||||
.map((b) => b.text)
|
||||
.join("");
|
||||
return text;
|
||||
}
|
||||
|
||||
// Helper re-export for callers wanting to hand normalized results back in.
|
||||
|
|
|
|||
|
|
@ -1,326 +1,351 @@
|
|||
import { GoogleGenAI } from "@google/genai";
|
||||
import type {
|
||||
StreamChatParams,
|
||||
StreamChatResult,
|
||||
NormalizedToolCall,
|
||||
StreamChatParams,
|
||||
StreamChatResult,
|
||||
NormalizedToolCall,
|
||||
} from "./types";
|
||||
import { toGeminiTools } from "./tools";
|
||||
import { logRawLlmStream } from "./rawStreamLog";
|
||||
import { createRawLlmStreamRecorder, logRawLlmStream } from "./rawStreamLog";
|
||||
|
||||
type GeminiPart = {
|
||||
text?: string;
|
||||
// Set by Gemini when the text content is a thought summary rather than
|
||||
// final-answer prose. Requires `thinkingConfig.includeThoughts: true`.
|
||||
thought?: boolean;
|
||||
functionCall?: { id?: string; name: string; args?: Record<string, unknown> };
|
||||
functionResponse?: {
|
||||
id?: string;
|
||||
name: string;
|
||||
response: Record<string, unknown>;
|
||||
};
|
||||
// Gemini 3 returns a thoughtSignature on parts that contain reasoning or
|
||||
// a functionCall. It must be echoed back verbatim on the same part when
|
||||
// we replay the model's turn, or the API rejects the next call.
|
||||
thoughtSignature?: string;
|
||||
text?: string;
|
||||
// Set by Gemini when the text content is a thought summary rather than
|
||||
// final-answer prose. Requires `thinkingConfig.includeThoughts: true`.
|
||||
thought?: boolean;
|
||||
functionCall?: { id?: string; name: string; args?: Record<string, unknown> };
|
||||
functionResponse?: {
|
||||
id?: string;
|
||||
name: string;
|
||||
response: Record<string, unknown>;
|
||||
};
|
||||
// Gemini 3 returns a thoughtSignature on parts that contain reasoning or
|
||||
// a functionCall. It must be echoed back verbatim on the same part when
|
||||
// we replay the model's turn, or the API rejects the next call.
|
||||
thoughtSignature?: string;
|
||||
};
|
||||
|
||||
type GeminiContent = {
|
||||
role: "user" | "model";
|
||||
parts: GeminiPart[];
|
||||
role: "user" | "model";
|
||||
parts: GeminiPart[];
|
||||
};
|
||||
|
||||
function apiKey(override?: string | null): string {
|
||||
const key = override?.trim() || process.env.GEMINI_API_KEY?.trim() || "";
|
||||
if (!key) {
|
||||
throw new Error(
|
||||
"Gemini API key is not configured. Set GEMINI_API_KEY or add a user Gemini key.",
|
||||
);
|
||||
}
|
||||
return key;
|
||||
const key = override?.trim() || process.env.GEMINI_API_KEY?.trim() || "";
|
||||
if (!key) {
|
||||
throw new Error(
|
||||
"Gemini API key is not configured. Set GEMINI_API_KEY or add a user Gemini key.",
|
||||
);
|
||||
}
|
||||
return key;
|
||||
}
|
||||
|
||||
function client(override?: string | null): GoogleGenAI {
|
||||
return new GoogleGenAI({ apiKey: apiKey(override) });
|
||||
return new GoogleGenAI({ apiKey: apiKey(override) });
|
||||
}
|
||||
|
||||
function toNativeContents(messages: StreamChatParams["messages"]): GeminiContent[] {
|
||||
return messages.map((m) => ({
|
||||
role: m.role === "assistant" ? "model" : "user",
|
||||
parts: [{ text: m.content }],
|
||||
}));
|
||||
function toNativeContents(
|
||||
messages: StreamChatParams["messages"],
|
||||
): GeminiContent[] {
|
||||
return messages.map((m) => ({
|
||||
role: m.role === "assistant" ? "model" : "user",
|
||||
parts: [{ text: m.content }],
|
||||
}));
|
||||
}
|
||||
|
||||
function geminiErrorMessage(error: unknown): string {
|
||||
const parsedObject = geminiStreamFailureMessage(error);
|
||||
if (parsedObject) return parsedObject;
|
||||
if (typeof error === "string") {
|
||||
const parsed = parseGeminiErrorPayload(error);
|
||||
if (parsed) return parsed;
|
||||
return error.startsWith("Gemini error:")
|
||||
? error
|
||||
: `Gemini error: ${error}`;
|
||||
}
|
||||
if (error instanceof Error && error.message) {
|
||||
const parsed = parseGeminiErrorPayload(error.message);
|
||||
if (parsed) return parsed;
|
||||
return error.message.startsWith("Gemini error:")
|
||||
? error.message
|
||||
: `Gemini error: ${error.message}`;
|
||||
}
|
||||
return `Gemini error: ${String(error)}`;
|
||||
const parsedObject = geminiStreamFailureMessage(error);
|
||||
if (parsedObject) return parsedObject;
|
||||
if (typeof error === "string") {
|
||||
const parsed = parseGeminiErrorPayload(error);
|
||||
if (parsed) return parsed;
|
||||
return error.startsWith("Gemini error:") ? error : `Gemini error: ${error}`;
|
||||
}
|
||||
if (error instanceof Error && error.message) {
|
||||
const parsed = parseGeminiErrorPayload(error.message);
|
||||
if (parsed) return parsed;
|
||||
return error.message.startsWith("Gemini error:")
|
||||
? error.message
|
||||
: `Gemini error: ${error.message}`;
|
||||
}
|
||||
return `Gemini error: ${String(error)}`;
|
||||
}
|
||||
|
||||
function parseGeminiErrorPayload(value: string): string | null {
|
||||
const trimmed = value.trim();
|
||||
if (!trimmed.startsWith("{")) return null;
|
||||
try {
|
||||
const parsed = JSON.parse(trimmed) as unknown;
|
||||
return geminiStreamFailureMessage(parsed);
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
const trimmed = value.trim();
|
||||
if (!trimmed.startsWith("{")) return null;
|
||||
try {
|
||||
const parsed = JSON.parse(trimmed) as unknown;
|
||||
return geminiStreamFailureMessage(parsed);
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
function geminiStreamFailureMessage(chunk: unknown): string | null {
|
||||
if (!chunk || typeof chunk !== "object") return null;
|
||||
const record = chunk as Record<string, unknown>;
|
||||
const error = record.error;
|
||||
if (error && typeof error === "object") {
|
||||
const err = error as Record<string, unknown>;
|
||||
const nested =
|
||||
typeof err.message === "string"
|
||||
? parseGeminiErrorPayload(err.message)
|
||||
: null;
|
||||
if (nested) return nested;
|
||||
const message =
|
||||
typeof err.message === "string" && err.message.trim()
|
||||
? err.message.trim()
|
||||
: "Gemini stream failed.";
|
||||
const code =
|
||||
typeof err.code === "string" && err.code.trim()
|
||||
? err.code.trim()
|
||||
: typeof err.code === "number" && Number.isFinite(err.code)
|
||||
? String(err.code)
|
||||
: typeof err.status === "string" && err.status.trim()
|
||||
? err.status.trim()
|
||||
: null;
|
||||
return code ? `Gemini error (${code}): ${message}` : `Gemini error: ${message}`;
|
||||
}
|
||||
|
||||
const promptFeedback = record.promptFeedback;
|
||||
if (promptFeedback && typeof promptFeedback === "object") {
|
||||
const feedback = promptFeedback as Record<string, unknown>;
|
||||
const blockReason =
|
||||
typeof feedback.blockReason === "string"
|
||||
? feedback.blockReason
|
||||
: null;
|
||||
if (blockReason) {
|
||||
const detail =
|
||||
typeof feedback.blockReasonMessage === "string" &&
|
||||
feedback.blockReasonMessage.trim()
|
||||
? feedback.blockReasonMessage.trim()
|
||||
: "The Gemini response was blocked.";
|
||||
return `Gemini error (${blockReason}): ${detail}`;
|
||||
}
|
||||
}
|
||||
|
||||
const candidates = Array.isArray(record.candidates)
|
||||
? (record.candidates as Record<string, unknown>[])
|
||||
: [];
|
||||
const finishReason =
|
||||
typeof candidates[0]?.finishReason === "string"
|
||||
? candidates[0].finishReason
|
||||
if (!chunk || typeof chunk !== "object") return null;
|
||||
const record = chunk as Record<string, unknown>;
|
||||
const error = record.error;
|
||||
if (error && typeof error === "object") {
|
||||
const err = error as Record<string, unknown>;
|
||||
const nested =
|
||||
typeof err.message === "string"
|
||||
? parseGeminiErrorPayload(err.message)
|
||||
: null;
|
||||
if (nested) return nested;
|
||||
const message =
|
||||
typeof err.message === "string" && err.message.trim()
|
||||
? err.message.trim()
|
||||
: "Gemini stream failed.";
|
||||
const code =
|
||||
typeof err.code === "string" && err.code.trim()
|
||||
? err.code.trim()
|
||||
: typeof err.code === "number" && Number.isFinite(err.code)
|
||||
? String(err.code)
|
||||
: typeof err.status === "string" && err.status.trim()
|
||||
? err.status.trim()
|
||||
: null;
|
||||
const errorFinishReasons = new Set([
|
||||
"SAFETY",
|
||||
"RECITATION",
|
||||
"BLOCKLIST",
|
||||
"PROHIBITED_CONTENT",
|
||||
"SPII",
|
||||
"MALFORMED_FUNCTION_CALL",
|
||||
"OTHER",
|
||||
]);
|
||||
if (finishReason && errorFinishReasons.has(finishReason)) {
|
||||
return `Gemini error (${finishReason}): The Gemini stream ended with an error finish reason.`;
|
||||
}
|
||||
return code
|
||||
? `Gemini error (${code}): ${message}`
|
||||
: `Gemini error: ${message}`;
|
||||
}
|
||||
|
||||
return null;
|
||||
const promptFeedback = record.promptFeedback;
|
||||
if (promptFeedback && typeof promptFeedback === "object") {
|
||||
const feedback = promptFeedback as Record<string, unknown>;
|
||||
const blockReason =
|
||||
typeof feedback.blockReason === "string" ? feedback.blockReason : null;
|
||||
if (blockReason) {
|
||||
const detail =
|
||||
typeof feedback.blockReasonMessage === "string" &&
|
||||
feedback.blockReasonMessage.trim()
|
||||
? feedback.blockReasonMessage.trim()
|
||||
: "The Gemini response was blocked.";
|
||||
return `Gemini error (${blockReason}): ${detail}`;
|
||||
}
|
||||
}
|
||||
|
||||
const candidates = Array.isArray(record.candidates)
|
||||
? (record.candidates as Record<string, unknown>[])
|
||||
: [];
|
||||
const finishReason =
|
||||
typeof candidates[0]?.finishReason === "string"
|
||||
? candidates[0].finishReason
|
||||
: null;
|
||||
const errorFinishReasons = new Set([
|
||||
"SAFETY",
|
||||
"RECITATION",
|
||||
"BLOCKLIST",
|
||||
"PROHIBITED_CONTENT",
|
||||
"SPII",
|
||||
"MALFORMED_FUNCTION_CALL",
|
||||
"OTHER",
|
||||
]);
|
||||
if (finishReason && errorFinishReasons.has(finishReason)) {
|
||||
return `Gemini error (${finishReason}): The Gemini stream ended with an error finish reason.`;
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
function abortError(): Error {
|
||||
const err = new Error("Stream aborted.");
|
||||
err.name = "AbortError";
|
||||
return err;
|
||||
const err = new Error("Stream aborted.");
|
||||
err.name = "AbortError";
|
||||
return err;
|
||||
}
|
||||
|
||||
function throwIfAborted(signal?: AbortSignal) {
|
||||
if (signal?.aborted) throw abortError();
|
||||
if (signal?.aborted) throw abortError();
|
||||
}
|
||||
|
||||
export async function streamGemini(
|
||||
params: StreamChatParams,
|
||||
params: StreamChatParams,
|
||||
): Promise<StreamChatResult> {
|
||||
const { model, systemPrompt, tools = [], callbacks = {}, runTools, apiKeys, enableThinking } = params;
|
||||
const maxIter = params.maxIterations ?? 10;
|
||||
const ai = client(apiKeys?.gemini);
|
||||
const functionDeclarations = toGeminiTools(tools);
|
||||
const {
|
||||
model,
|
||||
systemPrompt,
|
||||
tools = [],
|
||||
callbacks = {},
|
||||
runTools,
|
||||
apiKeys,
|
||||
enableThinking,
|
||||
} = params;
|
||||
const maxIter = params.maxIterations ?? 10;
|
||||
const ai = client(apiKeys?.gemini);
|
||||
const functionDeclarations = toGeminiTools(tools);
|
||||
|
||||
const contents: GeminiContent[] = toNativeContents(params.messages);
|
||||
let fullText = "";
|
||||
const contents: GeminiContent[] = toNativeContents(params.messages);
|
||||
let fullText = "";
|
||||
const rawStreamRecorder = createRawLlmStreamRecorder({
|
||||
provider: "gemini",
|
||||
model,
|
||||
});
|
||||
|
||||
try {
|
||||
for (let iter = 0; iter < maxIter; iter++) {
|
||||
throwIfAborted(params.abortSignal);
|
||||
let stream: AsyncIterable<unknown>;
|
||||
try {
|
||||
stream = await ai.models.generateContentStream({
|
||||
model,
|
||||
contents: contents as never,
|
||||
config: {
|
||||
systemInstruction: systemPrompt,
|
||||
tools: functionDeclarations.length
|
||||
? [{ functionDeclarations } as never]
|
||||
: undefined,
|
||||
// When enabled, ask Gemini to surface thought summaries.
|
||||
// When disabled, explicitly zero the thinking budget so the
|
||||
// model skips thinking entirely (saves tokens and latency
|
||||
// for bulk extraction jobs).
|
||||
thinkingConfig: enableThinking
|
||||
? { includeThoughts: true }
|
||||
: { thinkingBudget: 0 },
|
||||
},
|
||||
});
|
||||
} catch (error) {
|
||||
throw new Error(geminiErrorMessage(error));
|
||||
}
|
||||
|
||||
// Per-iteration accumulators.
|
||||
const textParts: string[] = [];
|
||||
const callParts: GeminiPart[] = [];
|
||||
const toolCalls: NormalizedToolCall[] = [];
|
||||
let sawThinking = false;
|
||||
const iterator = stream[Symbol.asyncIterator]();
|
||||
let rejectAbort: ((reason?: unknown) => void) | null = null;
|
||||
const abortPromise = new Promise<never>((_, reject) => {
|
||||
rejectAbort = reject;
|
||||
});
|
||||
const onAbort = () => rejectAbort?.(abortError());
|
||||
params.abortSignal?.addEventListener("abort", onAbort, {
|
||||
once: true,
|
||||
throwIfAborted(params.abortSignal);
|
||||
let stream: AsyncIterable<unknown>;
|
||||
try {
|
||||
stream = await ai.models.generateContentStream({
|
||||
model,
|
||||
contents: contents as never,
|
||||
config: {
|
||||
systemInstruction: systemPrompt,
|
||||
tools: functionDeclarations.length
|
||||
? [{ functionDeclarations } as never]
|
||||
: undefined,
|
||||
// When enabled, ask Gemini to surface thought summaries.
|
||||
// When disabled, explicitly zero the thinking budget so the
|
||||
// model skips thinking entirely (saves tokens and latency
|
||||
// for bulk extraction jobs).
|
||||
thinkingConfig: enableThinking
|
||||
? { includeThoughts: true }
|
||||
: { thinkingBudget: 0 },
|
||||
},
|
||||
});
|
||||
} catch (error) {
|
||||
throw new Error(geminiErrorMessage(error));
|
||||
}
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
throwIfAborted(params.abortSignal);
|
||||
const { value: chunk, done } = await Promise.race([
|
||||
iterator.next(),
|
||||
abortPromise,
|
||||
]);
|
||||
if (done) break;
|
||||
logRawLlmStream({
|
||||
provider: "gemini",
|
||||
model,
|
||||
iteration: iter,
|
||||
label: "chunk",
|
||||
payload: chunk,
|
||||
});
|
||||
const failureMessage = geminiStreamFailureMessage(chunk);
|
||||
if (failureMessage) throw new Error(failureMessage);
|
||||
// Per-iteration accumulators.
|
||||
const textParts: string[] = [];
|
||||
const callParts: GeminiPart[] = [];
|
||||
const toolCalls: NormalizedToolCall[] = [];
|
||||
let sawThinking = false;
|
||||
const iterator = stream[Symbol.asyncIterator]();
|
||||
let rejectAbort: ((reason?: unknown) => void) | null = null;
|
||||
const abortPromise = new Promise<never>((_, reject) => {
|
||||
rejectAbort = reject;
|
||||
});
|
||||
const onAbort = () => rejectAbort?.(abortError());
|
||||
params.abortSignal?.addEventListener("abort", onAbort, {
|
||||
once: true,
|
||||
});
|
||||
|
||||
const parts =
|
||||
(chunk as { candidates?: { content?: { parts?: GeminiPart[] } }[] })
|
||||
.candidates?.[0]?.content?.parts ?? [];
|
||||
try {
|
||||
while (true) {
|
||||
throwIfAborted(params.abortSignal);
|
||||
const { value: chunk, done } = await Promise.race([
|
||||
iterator.next(),
|
||||
abortPromise,
|
||||
]);
|
||||
if (done) break;
|
||||
logRawLlmStream({
|
||||
provider: "gemini",
|
||||
model,
|
||||
iteration: iter,
|
||||
label: "chunk",
|
||||
payload: chunk,
|
||||
});
|
||||
rawStreamRecorder?.record({
|
||||
iteration: iter,
|
||||
label: "chunk",
|
||||
payload: chunk,
|
||||
});
|
||||
const failureMessage = geminiStreamFailureMessage(chunk);
|
||||
if (failureMessage) throw new Error(failureMessage);
|
||||
|
||||
for (const part of parts) {
|
||||
if (part.text) {
|
||||
if (part.thought) {
|
||||
sawThinking = true;
|
||||
callbacks.onReasoningDelta?.(part.text);
|
||||
} else {
|
||||
textParts.push(part.text);
|
||||
callbacks.onContentDelta?.(part.text);
|
||||
}
|
||||
}
|
||||
if (part.functionCall) {
|
||||
// Preserve the whole part (including thoughtSignature)
|
||||
// so it can be echoed verbatim in the replay turn.
|
||||
callParts.push(part);
|
||||
const call: NormalizedToolCall = {
|
||||
id: part.functionCall.id ?? `${part.functionCall.name}-${toolCalls.length}`,
|
||||
name: part.functionCall.name,
|
||||
input: part.functionCall.args ?? {},
|
||||
};
|
||||
callbacks.onToolCallStart?.(call);
|
||||
toolCalls.push(call);
|
||||
}
|
||||
}
|
||||
const parts =
|
||||
(chunk as { candidates?: { content?: { parts?: GeminiPart[] } }[] })
|
||||
.candidates?.[0]?.content?.parts ?? [];
|
||||
|
||||
for (const part of parts) {
|
||||
if (part.text) {
|
||||
if (part.thought) {
|
||||
sawThinking = true;
|
||||
callbacks.onReasoningDelta?.(part.text);
|
||||
} else {
|
||||
textParts.push(part.text);
|
||||
callbacks.onContentDelta?.(part.text);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
if (params.abortSignal?.aborted) throw abortError();
|
||||
throw new Error(geminiErrorMessage(error));
|
||||
} finally {
|
||||
params.abortSignal?.removeEventListener("abort", onAbort);
|
||||
if (params.abortSignal?.aborted) {
|
||||
await iterator.return?.();
|
||||
if (part.functionCall) {
|
||||
// Preserve the whole part (including thoughtSignature)
|
||||
// so it can be echoed verbatim in the replay turn.
|
||||
callParts.push(part);
|
||||
const call: NormalizedToolCall = {
|
||||
id:
|
||||
part.functionCall.id ??
|
||||
`${part.functionCall.name}-${toolCalls.length}`,
|
||||
name: part.functionCall.name,
|
||||
input: part.functionCall.args ?? {},
|
||||
};
|
||||
callbacks.onToolCallStart?.(call);
|
||||
toolCalls.push(call);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (sawThinking) callbacks.onReasoningBlockEnd?.();
|
||||
throwIfAborted(params.abortSignal);
|
||||
|
||||
fullText += textParts.join("");
|
||||
|
||||
if (!toolCalls.length || !runTools) {
|
||||
break;
|
||||
} catch (error) {
|
||||
if (params.abortSignal?.aborted) throw abortError();
|
||||
throw new Error(geminiErrorMessage(error));
|
||||
} finally {
|
||||
params.abortSignal?.removeEventListener("abort", onAbort);
|
||||
if (params.abortSignal?.aborted) {
|
||||
await iterator.return?.();
|
||||
}
|
||||
}
|
||||
|
||||
const results = await runTools(toolCalls);
|
||||
throwIfAborted(params.abortSignal);
|
||||
if (sawThinking) callbacks.onReasoningBlockEnd?.();
|
||||
throwIfAborted(params.abortSignal);
|
||||
|
||||
// Append the model's turn (text + functionCall parts, in that order)
|
||||
// and the matching functionResponse turn.
|
||||
const modelParts: GeminiPart[] = [];
|
||||
if (textParts.length) modelParts.push({ text: textParts.join("") });
|
||||
for (const cp of callParts) modelParts.push(cp);
|
||||
contents.push({ role: "model", parts: modelParts });
|
||||
fullText += textParts.join("");
|
||||
|
||||
contents.push({
|
||||
role: "user",
|
||||
parts: results.map((r) => {
|
||||
const match = toolCalls.find((c) => c.id === r.tool_use_id);
|
||||
return {
|
||||
functionResponse: {
|
||||
...(r.tool_use_id && !r.tool_use_id.startsWith(match?.name ?? "")
|
||||
? { id: r.tool_use_id }
|
||||
: {}),
|
||||
name: match?.name ?? "tool",
|
||||
response: { output: r.content },
|
||||
},
|
||||
};
|
||||
}),
|
||||
});
|
||||
if (!toolCalls.length || !runTools) {
|
||||
break;
|
||||
}
|
||||
|
||||
const results = await runTools(toolCalls);
|
||||
throwIfAborted(params.abortSignal);
|
||||
|
||||
// Append the model's turn (text + functionCall parts, in that order)
|
||||
// and the matching functionResponse turn.
|
||||
const modelParts: GeminiPart[] = [];
|
||||
if (textParts.length) modelParts.push({ text: textParts.join("") });
|
||||
for (const cp of callParts) modelParts.push(cp);
|
||||
contents.push({ role: "model", parts: modelParts });
|
||||
|
||||
contents.push({
|
||||
role: "user",
|
||||
parts: results.map((r) => {
|
||||
const match = toolCalls.find((c) => c.id === r.tool_use_id);
|
||||
return {
|
||||
functionResponse: {
|
||||
...(r.tool_use_id && !r.tool_use_id.startsWith(match?.name ?? "")
|
||||
? { id: r.tool_use_id }
|
||||
: {}),
|
||||
name: match?.name ?? "tool",
|
||||
response: { output: r.content },
|
||||
},
|
||||
};
|
||||
}),
|
||||
});
|
||||
}
|
||||
|
||||
await rawStreamRecorder?.flush("completed");
|
||||
return { fullText };
|
||||
} catch (error) {
|
||||
await rawStreamRecorder?.flush("error", error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
export async function completeGeminiText(params: {
|
||||
model: string;
|
||||
systemPrompt?: string;
|
||||
user: string;
|
||||
apiKeys?: { gemini?: string | null };
|
||||
model: string;
|
||||
systemPrompt?: string;
|
||||
user: string;
|
||||
apiKeys?: { gemini?: string | null };
|
||||
}): Promise<string> {
|
||||
const ai = client(params.apiKeys?.gemini);
|
||||
let resp: Awaited<ReturnType<typeof ai.models.generateContent>>;
|
||||
try {
|
||||
resp = await ai.models.generateContent({
|
||||
model: params.model,
|
||||
contents: [{ role: "user", parts: [{ text: params.user }] }],
|
||||
config: params.systemPrompt
|
||||
? { systemInstruction: params.systemPrompt }
|
||||
: undefined,
|
||||
});
|
||||
} catch (error) {
|
||||
throw new Error(geminiErrorMessage(error));
|
||||
}
|
||||
return resp.text ?? "";
|
||||
const ai = client(params.apiKeys?.gemini);
|
||||
let resp: Awaited<ReturnType<typeof ai.models.generateContent>>;
|
||||
try {
|
||||
resp = await ai.models.generateContent({
|
||||
model: params.model,
|
||||
contents: [{ role: "user", parts: [{ text: params.user }] }],
|
||||
config: params.systemPrompt
|
||||
? { systemInstruction: params.systemPrompt }
|
||||
: undefined,
|
||||
});
|
||||
} catch (error) {
|
||||
throw new Error(geminiErrorMessage(error));
|
||||
}
|
||||
return resp.text ?? "";
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,8 +4,14 @@ import type { Provider } from "./types";
|
|||
// Canonical model IDs
|
||||
// ---------------------------------------------------------------------------
|
||||
// Main-chat tier (top-end) — user picks one of these per message.
|
||||
export const CLAUDE_MAIN_MODELS = ["claude-opus-4-7", "claude-sonnet-4-6"] as const;
|
||||
export const CLAUDE_MAIN_MODELS = [
|
||||
"claude-fable-5",
|
||||
"claude-opus-4-8",
|
||||
"claude-opus-4-7",
|
||||
"claude-sonnet-4-6",
|
||||
] as const;
|
||||
export const GEMINI_MAIN_MODELS = [
|
||||
"gemini-3.5-flash",
|
||||
"gemini-3.1-pro-preview",
|
||||
"gemini-3-flash-preview",
|
||||
] as const;
|
||||
|
|
@ -13,7 +19,7 @@ export const OPENAI_MAIN_MODELS = ["gpt-5.5", "gpt-5.4"] as const;
|
|||
|
||||
// Mid-tier (used for tabular review) — user picks one in account settings.
|
||||
export const CLAUDE_MID_MODELS = ["claude-sonnet-4-6"] as const;
|
||||
export const GEMINI_MID_MODELS = ["gemini-3-flash-preview"] as const;
|
||||
export const GEMINI_MID_MODELS = ["gemini-3.5-flash", "gemini-3-flash-preview"] as const;
|
||||
export const OPENAI_MID_MODELS = ["gpt-5.4"] as const;
|
||||
|
||||
// Low-tier (used for title generation, lightweight extractions) — user picks
|
||||
|
|
|
|||
|
|
@ -1,363 +1,400 @@
|
|||
import type {
|
||||
LlmMessage,
|
||||
NormalizedToolCall,
|
||||
NormalizedToolResult,
|
||||
OpenAIToolSchema,
|
||||
StreamChatParams,
|
||||
StreamChatResult,
|
||||
LlmMessage,
|
||||
NormalizedToolCall,
|
||||
NormalizedToolResult,
|
||||
OpenAIToolSchema,
|
||||
StreamChatParams,
|
||||
StreamChatResult,
|
||||
} from "./types";
|
||||
import { logRawLlmStream } from "./rawStreamLog";
|
||||
import { createRawLlmStreamRecorder, logRawLlmStream } from "./rawStreamLog";
|
||||
|
||||
const OPENAI_RESPONSES_URL = "https://api.openai.com/v1/responses";
|
||||
const MAX_OUTPUT_TOKENS = 16384;
|
||||
const COURTLISTENER_CITATION_REMINDER_TOOL_NAMES = new Set([
|
||||
"courtlistener_find_in_case",
|
||||
"courtlistener_read_case",
|
||||
]);
|
||||
const COURTLISTENER_CITATION_REMINDER = `COURTLISTENER CITATION REMINDER:
|
||||
If your final answer relies on any CourtListener case, every such case reference must have BOTH a clickable markdown case link and an inline [N] marker.
|
||||
Include the clickable case link only the first time you cite that case; later references to the same case should reuse the existing inline [N] marker without repeating the link unless clarity requires it.
|
||||
Assign new refs in first-use order as much as possible: [1], then [2], then [3]. Reuse an existing ref when citing the same case/passage again, even if that means a later sentence cites [3] and then [1] again.
|
||||
End the response with a <CITATIONS> block containing one matching case entry per [N] marker:
|
||||
{"ref": N, "cluster_id": 123, "quotes": [{"opinion_id": 456, "quote": "exact verbatim opinion text"}]}.
|
||||
Do not use doc_id, page, top-level quote, case_name, or citation fields for CourtListener case entries.`;
|
||||
|
||||
type ResponseInputItem =
|
||||
| { role: "user" | "assistant"; content: string }
|
||||
| { type: "function_call_output"; call_id: string; output: string };
|
||||
| { role: "user" | "assistant"; content: string }
|
||||
| { type: "function_call_output"; call_id: string; output: string };
|
||||
|
||||
type ResponseFunctionTool = {
|
||||
type: "function";
|
||||
name: string;
|
||||
description?: string;
|
||||
parameters: Record<string, unknown>;
|
||||
type: "function";
|
||||
name: string;
|
||||
description?: string;
|
||||
parameters: Record<string, unknown>;
|
||||
};
|
||||
|
||||
type ResponseFunctionCallItem = {
|
||||
type: "function_call";
|
||||
call_id?: string;
|
||||
name?: string;
|
||||
arguments?: string;
|
||||
type: "function_call";
|
||||
call_id?: string;
|
||||
name?: string;
|
||||
arguments?: string;
|
||||
};
|
||||
|
||||
type ResponseStreamEvent = {
|
||||
type?: string;
|
||||
delta?: string;
|
||||
response?: {
|
||||
id?: string;
|
||||
output_text?: string;
|
||||
status?: string;
|
||||
error?: { code?: string; message?: string } | null;
|
||||
};
|
||||
type?: string;
|
||||
delta?: string;
|
||||
response?: {
|
||||
id?: string;
|
||||
output_text?: string;
|
||||
status?: string;
|
||||
error?: { code?: string; message?: string } | null;
|
||||
item?: ResponseFunctionCallItem;
|
||||
};
|
||||
error?: { code?: string; message?: string } | null;
|
||||
item?: ResponseFunctionCallItem;
|
||||
};
|
||||
|
||||
function apiKey(override?: string | null): string {
|
||||
const key = override?.trim() || process.env.OPENAI_API_KEY?.trim() || "";
|
||||
if (!key) {
|
||||
throw new Error(
|
||||
"OpenAI API key is not configured. Set OPENAI_API_KEY or add a user OpenAI key.",
|
||||
);
|
||||
}
|
||||
return key;
|
||||
const key = override?.trim() || process.env.OPENAI_API_KEY?.trim() || "";
|
||||
if (!key) {
|
||||
throw new Error(
|
||||
"OpenAI API key is not configured. Set OPENAI_API_KEY or add a user OpenAI key.",
|
||||
);
|
||||
}
|
||||
return key;
|
||||
}
|
||||
|
||||
function toResponseTools(tools: OpenAIToolSchema[]): ResponseFunctionTool[] {
|
||||
return tools.map((tool) => ({
|
||||
type: "function",
|
||||
name: tool.function.name,
|
||||
description: tool.function.description,
|
||||
parameters: tool.function.parameters,
|
||||
}));
|
||||
return tools.map((tool) => ({
|
||||
type: "function",
|
||||
name: tool.function.name,
|
||||
description: tool.function.description,
|
||||
parameters: tool.function.parameters,
|
||||
}));
|
||||
}
|
||||
|
||||
function toResponseInput(messages: LlmMessage[]): ResponseInputItem[] {
|
||||
return messages.map((message) => ({
|
||||
role: message.role,
|
||||
content: message.content,
|
||||
}));
|
||||
return messages.map((message) => ({
|
||||
role: message.role,
|
||||
content: message.content,
|
||||
}));
|
||||
}
|
||||
|
||||
function extractSseJson(buffer: string): { events: unknown[]; rest: string } {
|
||||
const events: unknown[] = [];
|
||||
const chunks = buffer.split(/\n\n/);
|
||||
const rest = chunks.pop() ?? "";
|
||||
const events: unknown[] = [];
|
||||
const chunks = buffer.split(/\n\n/);
|
||||
const rest = chunks.pop() ?? "";
|
||||
|
||||
for (const chunk of chunks) {
|
||||
const dataLines = chunk
|
||||
.split("\n")
|
||||
.map((line) => line.trim())
|
||||
.filter((line) => line.startsWith("data:"))
|
||||
.map((line) => line.slice(5).trim());
|
||||
for (const chunk of chunks) {
|
||||
const dataLines = chunk
|
||||
.split("\n")
|
||||
.map((line) => line.trim())
|
||||
.filter((line) => line.startsWith("data:"))
|
||||
.map((line) => line.slice(5).trim());
|
||||
|
||||
for (const data of dataLines) {
|
||||
if (!data || data === "[DONE]") continue;
|
||||
try {
|
||||
events.push(JSON.parse(data));
|
||||
} catch {
|
||||
// Incomplete events stay buffered until the next read.
|
||||
}
|
||||
}
|
||||
for (const data of dataLines) {
|
||||
if (!data || data === "[DONE]") continue;
|
||||
try {
|
||||
events.push(JSON.parse(data));
|
||||
} catch {
|
||||
// Incomplete events stay buffered until the next read.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return { events, rest };
|
||||
return { events, rest };
|
||||
}
|
||||
|
||||
function parseFunctionCall(item: ResponseFunctionCallItem): NormalizedToolCall {
|
||||
let input: Record<string, unknown> = {};
|
||||
try {
|
||||
const parsed = JSON.parse(item.arguments || "{}");
|
||||
if (parsed && typeof parsed === "object" && !Array.isArray(parsed)) {
|
||||
input = parsed as Record<string, unknown>;
|
||||
}
|
||||
} catch {
|
||||
input = {};
|
||||
let input: Record<string, unknown> = {};
|
||||
try {
|
||||
const parsed = JSON.parse(item.arguments || "{}");
|
||||
if (parsed && typeof parsed === "object" && !Array.isArray(parsed)) {
|
||||
input = parsed as Record<string, unknown>;
|
||||
}
|
||||
} catch {
|
||||
input = {};
|
||||
}
|
||||
|
||||
return {
|
||||
id: item.call_id ?? item.name ?? "function_call",
|
||||
name: item.name ?? "",
|
||||
input,
|
||||
};
|
||||
return {
|
||||
id: item.call_id ?? item.name ?? "function_call",
|
||||
name: item.name ?? "",
|
||||
input,
|
||||
};
|
||||
}
|
||||
|
||||
function openAIStreamFailureMessage(event: ResponseStreamEvent): string | null {
|
||||
const error = event.response?.error ?? event.error ?? null;
|
||||
const failed =
|
||||
event.type === "response.failed" ||
|
||||
event.response?.status === "failed" ||
|
||||
!!error;
|
||||
if (!failed) return null;
|
||||
const error = event.response?.error ?? event.error ?? null;
|
||||
const failed =
|
||||
event.type === "response.failed" ||
|
||||
event.response?.status === "failed" ||
|
||||
!!error;
|
||||
if (!failed) return null;
|
||||
|
||||
const message =
|
||||
typeof error?.message === "string" && error.message.trim()
|
||||
? error.message.trim()
|
||||
: "OpenAI response failed.";
|
||||
const code =
|
||||
typeof error?.code === "string" && error.code.trim()
|
||||
? error.code.trim()
|
||||
: null;
|
||||
return code ? `OpenAI error (${code}): ${message}` : message;
|
||||
const message =
|
||||
typeof error?.message === "string" && error.message.trim()
|
||||
? error.message.trim()
|
||||
: "OpenAI response failed.";
|
||||
const code =
|
||||
typeof error?.code === "string" && error.code.trim()
|
||||
? error.code.trim()
|
||||
: null;
|
||||
return code ? `OpenAI error (${code}): ${message}` : message;
|
||||
}
|
||||
|
||||
function abortError(): Error {
|
||||
const err = new Error("Stream aborted.");
|
||||
err.name = "AbortError";
|
||||
return err;
|
||||
const err = new Error("Stream aborted.");
|
||||
err.name = "AbortError";
|
||||
return err;
|
||||
}
|
||||
|
||||
function throwIfAborted(signal?: AbortSignal) {
|
||||
if (signal?.aborted) throw abortError();
|
||||
if (signal?.aborted) throw abortError();
|
||||
}
|
||||
|
||||
function responseInstructions(systemPrompt: string, includeReminder: boolean) {
|
||||
return includeReminder
|
||||
? `${systemPrompt}\n\n${COURTLISTENER_CITATION_REMINDER}`
|
||||
: systemPrompt;
|
||||
}
|
||||
|
||||
function shouldAppendCourtlistenerCitationReminder(call: NormalizedToolCall) {
|
||||
return COURTLISTENER_CITATION_REMINDER_TOOL_NAMES.has(call.name);
|
||||
}
|
||||
|
||||
async function createResponse(params: {
|
||||
model: string;
|
||||
input: ResponseInputItem[];
|
||||
instructions?: string;
|
||||
tools?: ResponseFunctionTool[];
|
||||
stream?: boolean;
|
||||
maxTokens?: number;
|
||||
previousResponseId?: string;
|
||||
reasoningSummary?: boolean;
|
||||
apiKey: string;
|
||||
signal?: AbortSignal;
|
||||
model: string;
|
||||
input: ResponseInputItem[];
|
||||
instructions?: string;
|
||||
tools?: ResponseFunctionTool[];
|
||||
stream?: boolean;
|
||||
maxTokens?: number;
|
||||
previousResponseId?: string;
|
||||
reasoningSummary?: boolean;
|
||||
apiKey: string;
|
||||
signal?: AbortSignal;
|
||||
}): Promise<Response> {
|
||||
const response = await fetch(OPENAI_RESPONSES_URL, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
Authorization: `Bearer ${params.apiKey}`,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
model: params.model,
|
||||
instructions: params.instructions || undefined,
|
||||
input: params.input,
|
||||
tools: params.tools?.length ? params.tools : undefined,
|
||||
stream: params.stream,
|
||||
max_output_tokens: params.maxTokens ?? MAX_OUTPUT_TOKENS,
|
||||
previous_response_id: params.previousResponseId,
|
||||
reasoning: params.reasoningSummary
|
||||
? { summary: "auto" }
|
||||
: undefined,
|
||||
}),
|
||||
signal: params.signal,
|
||||
});
|
||||
const response = await fetch(OPENAI_RESPONSES_URL, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
Authorization: `Bearer ${params.apiKey}`,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
model: params.model,
|
||||
instructions: params.instructions || undefined,
|
||||
input: params.input,
|
||||
tools: params.tools?.length ? params.tools : undefined,
|
||||
stream: params.stream,
|
||||
max_output_tokens: params.maxTokens ?? MAX_OUTPUT_TOKENS,
|
||||
previous_response_id: params.previousResponseId,
|
||||
reasoning: params.reasoningSummary ? { summary: "auto" } : undefined,
|
||||
}),
|
||||
signal: params.signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const text = await response.text().catch(() => "");
|
||||
const err = new Error(
|
||||
`OpenAI request failed (${response.status}): ${text || response.statusText}`,
|
||||
);
|
||||
(err as { status?: number }).status = response.status;
|
||||
throw err;
|
||||
}
|
||||
if (!response.ok) {
|
||||
const text = await response.text().catch(() => "");
|
||||
const err = new Error(
|
||||
`OpenAI request failed (${response.status}): ${text || response.statusText}`,
|
||||
);
|
||||
(err as { status?: number }).status = response.status;
|
||||
throw err;
|
||||
}
|
||||
|
||||
return response;
|
||||
return response;
|
||||
}
|
||||
|
||||
export async function streamOpenAI(
|
||||
params: StreamChatParams,
|
||||
params: StreamChatParams,
|
||||
): Promise<StreamChatResult> {
|
||||
const {
|
||||
model,
|
||||
systemPrompt,
|
||||
tools = [],
|
||||
callbacks = {},
|
||||
runTools,
|
||||
apiKeys,
|
||||
enableThinking,
|
||||
} = params;
|
||||
const maxIter = params.maxIterations ?? 10;
|
||||
const key = apiKey(apiKeys?.openai);
|
||||
const responseTools = toResponseTools(tools);
|
||||
let input = toResponseInput(params.messages);
|
||||
let previousResponseId: string | undefined;
|
||||
let fullText = "";
|
||||
const hasTools = responseTools.length > 0;
|
||||
const {
|
||||
model,
|
||||
systemPrompt,
|
||||
tools = [],
|
||||
callbacks = {},
|
||||
runTools,
|
||||
apiKeys,
|
||||
enableThinking,
|
||||
} = params;
|
||||
const maxIter = params.maxIterations ?? 10;
|
||||
const key = apiKey(apiKeys?.openai);
|
||||
const responseTools = toResponseTools(tools);
|
||||
let input = toResponseInput(params.messages);
|
||||
let previousResponseId: string | undefined;
|
||||
let fullText = "";
|
||||
let needsCourtlistenerCitationReminder = false;
|
||||
const rawStreamRecorder = createRawLlmStreamRecorder({
|
||||
provider: "openai",
|
||||
model,
|
||||
});
|
||||
|
||||
try {
|
||||
for (let iter = 0; iter < maxIter; iter++) {
|
||||
throwIfAborted(params.abortSignal);
|
||||
const response = await createResponse({
|
||||
model,
|
||||
instructions: responseInstructions(
|
||||
systemPrompt,
|
||||
needsCourtlistenerCitationReminder,
|
||||
),
|
||||
input,
|
||||
tools: responseTools,
|
||||
stream: true,
|
||||
previousResponseId,
|
||||
reasoningSummary: !!enableThinking,
|
||||
apiKey: key,
|
||||
signal: params.abortSignal,
|
||||
});
|
||||
if (!response.body) throw new Error("OpenAI response had no body");
|
||||
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
const toolCalls: NormalizedToolCall[] = [];
|
||||
const startedToolCallIds = new Set<string>();
|
||||
let buffer = "";
|
||||
let sawReasoning = false;
|
||||
|
||||
while (true) {
|
||||
throwIfAborted(params.abortSignal);
|
||||
const response = await createResponse({
|
||||
model,
|
||||
instructions: iter === 0 ? systemPrompt : undefined,
|
||||
input,
|
||||
tools: responseTools,
|
||||
stream: true,
|
||||
previousResponseId,
|
||||
reasoningSummary: !!enableThinking,
|
||||
apiKey: key,
|
||||
signal: params.abortSignal,
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
const decoded = decoder.decode(value, { stream: true });
|
||||
logRawLlmStream({
|
||||
provider: "openai",
|
||||
model,
|
||||
iteration: iter,
|
||||
label: "sse_chunk",
|
||||
payload: decoded,
|
||||
});
|
||||
if (!response.body) throw new Error("OpenAI response had no body");
|
||||
rawStreamRecorder?.record({
|
||||
iteration: iter,
|
||||
label: "sse_chunk",
|
||||
payload: decoded,
|
||||
});
|
||||
buffer += decoded;
|
||||
const extracted = extractSseJson(buffer);
|
||||
buffer = extracted.rest;
|
||||
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
const toolCalls: NormalizedToolCall[] = [];
|
||||
const startedToolCallIds = new Set<string>();
|
||||
let buffer = "";
|
||||
let pendingText = "";
|
||||
let sawReasoning = false;
|
||||
for (const event of extracted.events as ResponseStreamEvent[]) {
|
||||
logRawLlmStream({
|
||||
provider: "openai",
|
||||
model,
|
||||
iteration: iter,
|
||||
label: "sse_event",
|
||||
payload: event,
|
||||
});
|
||||
rawStreamRecorder?.record({
|
||||
iteration: iter,
|
||||
label: "sse_event",
|
||||
payload: event,
|
||||
});
|
||||
|
||||
while (true) {
|
||||
throwIfAborted(params.abortSignal);
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
const failureMessage = openAIStreamFailureMessage(event);
|
||||
if (failureMessage) {
|
||||
throw new Error(failureMessage);
|
||||
}
|
||||
|
||||
const decoded = decoder.decode(value, { stream: true });
|
||||
logRawLlmStream({
|
||||
provider: "openai",
|
||||
model,
|
||||
iteration: iter,
|
||||
label: "sse_chunk",
|
||||
payload: decoded,
|
||||
});
|
||||
buffer += decoded;
|
||||
const extracted = extractSseJson(buffer);
|
||||
buffer = extracted.rest;
|
||||
if (event.response?.id) {
|
||||
previousResponseId = event.response.id;
|
||||
}
|
||||
|
||||
for (const event of extracted.events as ResponseStreamEvent[]) {
|
||||
logRawLlmStream({
|
||||
provider: "openai",
|
||||
model,
|
||||
iteration: iter,
|
||||
label: "sse_event",
|
||||
payload: event,
|
||||
});
|
||||
if (
|
||||
event.type === "response.reasoning_summary_text.delta" &&
|
||||
typeof event.delta === "string"
|
||||
) {
|
||||
sawReasoning = true;
|
||||
callbacks.onReasoningDelta?.(event.delta);
|
||||
}
|
||||
|
||||
const failureMessage = openAIStreamFailureMessage(event);
|
||||
if (failureMessage) {
|
||||
throw new Error(failureMessage);
|
||||
}
|
||||
if (
|
||||
event.type === "response.output_text.delta" &&
|
||||
typeof event.delta === "string"
|
||||
) {
|
||||
fullText += event.delta;
|
||||
callbacks.onContentDelta?.(event.delta);
|
||||
}
|
||||
|
||||
if (event.response?.id) {
|
||||
previousResponseId = event.response.id;
|
||||
}
|
||||
if (
|
||||
event.type === "response.output_item.added" &&
|
||||
event.item?.type === "function_call"
|
||||
) {
|
||||
const call = parseFunctionCall(event.item);
|
||||
startedToolCallIds.add(call.id);
|
||||
callbacks.onToolCallStart?.(call);
|
||||
}
|
||||
|
||||
if (
|
||||
event.type === "response.reasoning_summary_text.delta" &&
|
||||
typeof event.delta === "string"
|
||||
) {
|
||||
sawReasoning = true;
|
||||
callbacks.onReasoningDelta?.(event.delta);
|
||||
}
|
||||
|
||||
if (
|
||||
event.type === "response.output_text.delta" &&
|
||||
typeof event.delta === "string"
|
||||
) {
|
||||
if (hasTools) {
|
||||
pendingText += event.delta;
|
||||
} else {
|
||||
fullText += event.delta;
|
||||
callbacks.onContentDelta?.(event.delta);
|
||||
}
|
||||
}
|
||||
|
||||
if (
|
||||
event.type === "response.output_item.added" &&
|
||||
event.item?.type === "function_call"
|
||||
) {
|
||||
const call = parseFunctionCall(event.item);
|
||||
startedToolCallIds.add(call.id);
|
||||
callbacks.onToolCallStart?.(call);
|
||||
}
|
||||
|
||||
if (
|
||||
event.type === "response.output_item.done" &&
|
||||
event.item?.type === "function_call"
|
||||
) {
|
||||
const call = parseFunctionCall(event.item);
|
||||
if (!startedToolCallIds.has(call.id)) {
|
||||
callbacks.onToolCallStart?.(call);
|
||||
}
|
||||
toolCalls.push(call);
|
||||
}
|
||||
if (
|
||||
event.type === "response.output_item.done" &&
|
||||
event.item?.type === "function_call"
|
||||
) {
|
||||
const call = parseFunctionCall(event.item);
|
||||
if (!startedToolCallIds.has(call.id)) {
|
||||
callbacks.onToolCallStart?.(call);
|
||||
}
|
||||
toolCalls.push(call);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (sawReasoning) callbacks.onReasoningBlockEnd?.();
|
||||
throwIfAborted(params.abortSignal);
|
||||
if (sawReasoning) callbacks.onReasoningBlockEnd?.();
|
||||
throwIfAborted(params.abortSignal);
|
||||
|
||||
if (!toolCalls.length || !runTools) {
|
||||
if (pendingText) {
|
||||
fullText += pendingText;
|
||||
callbacks.onContentDelta?.(pendingText);
|
||||
}
|
||||
break;
|
||||
}
|
||||
if (!toolCalls.length || !runTools) {
|
||||
break;
|
||||
}
|
||||
|
||||
const results = await runTools(toolCalls);
|
||||
throwIfAborted(params.abortSignal);
|
||||
input = results.map((result) => ({
|
||||
type: "function_call_output",
|
||||
call_id: result.tool_use_id,
|
||||
output: result.content,
|
||||
}));
|
||||
if (toolCalls.some(shouldAppendCourtlistenerCitationReminder)) {
|
||||
needsCourtlistenerCitationReminder = true;
|
||||
}
|
||||
|
||||
const results = await runTools(toolCalls);
|
||||
throwIfAborted(params.abortSignal);
|
||||
input = results.map((result) => ({
|
||||
type: "function_call_output",
|
||||
call_id: result.tool_use_id,
|
||||
output: result.content,
|
||||
}));
|
||||
}
|
||||
|
||||
await rawStreamRecorder?.flush("completed");
|
||||
return { fullText };
|
||||
} catch (error) {
|
||||
await rawStreamRecorder?.flush("error", error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
export async function completeOpenAIText(params: {
|
||||
model: string;
|
||||
systemPrompt?: string;
|
||||
user: string;
|
||||
maxTokens?: number;
|
||||
apiKeys?: { openai?: string | null };
|
||||
model: string;
|
||||
systemPrompt?: string;
|
||||
user: string;
|
||||
maxTokens?: number;
|
||||
apiKeys?: { openai?: string | null };
|
||||
}): Promise<string> {
|
||||
const response = await createResponse({
|
||||
model: params.model,
|
||||
instructions: params.systemPrompt,
|
||||
input: [{ role: "user", content: params.user }],
|
||||
maxTokens: params.maxTokens ?? 512,
|
||||
apiKey: apiKey(params.apiKeys?.openai),
|
||||
});
|
||||
const json = (await response.json()) as {
|
||||
output_text?: string;
|
||||
output?: {
|
||||
content?: { type?: string; text?: string }[];
|
||||
}[];
|
||||
};
|
||||
const response = await createResponse({
|
||||
model: params.model,
|
||||
instructions: params.systemPrompt,
|
||||
input: [{ role: "user", content: params.user }],
|
||||
maxTokens: params.maxTokens ?? 512,
|
||||
apiKey: apiKey(params.apiKeys?.openai),
|
||||
});
|
||||
const json = (await response.json()) as {
|
||||
output_text?: string;
|
||||
output?: {
|
||||
content?: { type?: string; text?: string }[];
|
||||
}[];
|
||||
};
|
||||
|
||||
if (typeof json.output_text === "string") return json.output_text;
|
||||
if (typeof json.output_text === "string") return json.output_text;
|
||||
|
||||
return (
|
||||
json.output
|
||||
?.flatMap((item) => item.content ?? [])
|
||||
.filter((content) => content.type === "output_text")
|
||||
.map((content) => content.text ?? "")
|
||||
.join("") ?? ""
|
||||
);
|
||||
return (
|
||||
json.output
|
||||
?.flatMap((item) => item.content ?? [])
|
||||
.filter((content) => content.type === "output_text")
|
||||
.map((content) => content.text ?? "")
|
||||
.join("") ?? ""
|
||||
);
|
||||
}
|
||||
|
||||
export type { NormalizedToolResult };
|
||||
|
|
|
|||
|
|
@ -1,14 +1,170 @@
|
|||
export function logRawLlmStream(args: {
|
||||
provider: string;
|
||||
model: string;
|
||||
iteration: number;
|
||||
label: string;
|
||||
payload: unknown;
|
||||
}) {
|
||||
if (process.env.LOG_RAW_LLM_STREAM !== "true") return;
|
||||
import { randomUUID } from "crypto";
|
||||
import { mkdir, open } from "fs/promises";
|
||||
import type { FileHandle } from "fs/promises";
|
||||
import path from "path";
|
||||
|
||||
console.log(
|
||||
`[raw-llm-stream:${args.provider}:${args.model}:iter-${args.iteration}] ${args.label}`,
|
||||
);
|
||||
console.dir(args.payload, { depth: null, maxArrayLength: null });
|
||||
type RawStreamEntry = {
|
||||
timestamp: string;
|
||||
iteration: number;
|
||||
label: string;
|
||||
payload: unknown;
|
||||
};
|
||||
|
||||
function rawStreamLogDir(): string | null {
|
||||
return process.env.RAW_LLM_STREAM_LOG_DIR?.trim() || null;
|
||||
}
|
||||
|
||||
function safeFilePart(value: string) {
|
||||
return value.replace(/[^a-zA-Z0-9._-]+/g, "-").replace(/^-+|-+$/g, "");
|
||||
}
|
||||
|
||||
function stringifyJson(value: unknown) {
|
||||
const seen = new WeakSet<object>();
|
||||
return JSON.stringify(value, (_key, innerValue: unknown) => {
|
||||
if (typeof innerValue === "bigint") return innerValue.toString();
|
||||
if (innerValue instanceof Error) {
|
||||
return {
|
||||
name: innerValue.name,
|
||||
message: innerValue.message,
|
||||
stack: innerValue.stack,
|
||||
};
|
||||
}
|
||||
if (innerValue && typeof innerValue === "object") {
|
||||
if (seen.has(innerValue)) return "[Circular]";
|
||||
seen.add(innerValue);
|
||||
}
|
||||
return innerValue;
|
||||
});
|
||||
}
|
||||
|
||||
export function logRawLlmStream(args: {
|
||||
provider: string;
|
||||
model: string;
|
||||
iteration: number;
|
||||
label: string;
|
||||
payload: unknown;
|
||||
}) {
|
||||
if (process.env.LOG_RAW_LLM_STREAM !== "true") return;
|
||||
|
||||
console.log(
|
||||
`[raw-llm-stream:${args.provider}:${args.model}:iter-${args.iteration}] ${args.label}`,
|
||||
);
|
||||
console.dir(args.payload, { depth: null, maxArrayLength: null });
|
||||
}
|
||||
|
||||
export function createRawLlmStreamRecorder(args: {
|
||||
provider: string;
|
||||
model: string;
|
||||
}) {
|
||||
const dir = rawStreamLogDir();
|
||||
if (!dir) return null;
|
||||
const logDir = dir;
|
||||
|
||||
const startedAt = new Date();
|
||||
const id = randomUUID();
|
||||
const filename = [
|
||||
safeFilePart(args.provider),
|
||||
safeFilePart(args.model),
|
||||
startedAt.toISOString().replace(/[:.]/g, "-"),
|
||||
id,
|
||||
].join("-");
|
||||
const filePath = path.join(logDir, `${filename}.raw-llm-stream.json`);
|
||||
let fileHandle: FileHandle | null = null;
|
||||
let writeChain: Promise<void> = Promise.resolve();
|
||||
let writeError: unknown = null;
|
||||
let wroteEntry = false;
|
||||
let finalized = false;
|
||||
|
||||
async function ensureOpen() {
|
||||
if (fileHandle) return fileHandle;
|
||||
await mkdir(logDir, { recursive: true });
|
||||
fileHandle = await open(filePath, "w");
|
||||
const header = {
|
||||
id,
|
||||
provider: args.provider,
|
||||
model: args.model,
|
||||
startedAt: startedAt.toISOString(),
|
||||
};
|
||||
await fileHandle.write(`${stringifyJson(header)?.slice(0, -1)},"entries":[`);
|
||||
return fileHandle;
|
||||
}
|
||||
|
||||
function queueWrite(action: () => Promise<void>) {
|
||||
writeChain = writeChain
|
||||
.then(action)
|
||||
.catch((error) => {
|
||||
writeError = error;
|
||||
console.error("[raw-llm-stream] failed to write log file", {
|
||||
filePath,
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
return {
|
||||
record(entry: Omit<RawStreamEntry, "timestamp">) {
|
||||
if (finalized) return;
|
||||
const rawEntry = {
|
||||
timestamp: new Date().toISOString(),
|
||||
...entry,
|
||||
};
|
||||
queueWrite(async () => {
|
||||
const handle = await ensureOpen();
|
||||
const serialized =
|
||||
stringifyJson(rawEntry) ??
|
||||
stringifyJson({
|
||||
timestamp: rawEntry.timestamp,
|
||||
iteration: rawEntry.iteration,
|
||||
label: rawEntry.label,
|
||||
payload: "[Unserializable payload]",
|
||||
});
|
||||
await handle.write(`${wroteEntry ? "," : ""}${serialized}`);
|
||||
wroteEntry = true;
|
||||
});
|
||||
},
|
||||
async flush(status: "completed" | "error", error?: unknown) {
|
||||
if (finalized) return;
|
||||
finalized = true;
|
||||
const errorPayload =
|
||||
error instanceof Error
|
||||
? {
|
||||
name: error.name,
|
||||
message: error.message,
|
||||
stack: error.stack,
|
||||
}
|
||||
: error
|
||||
? { message: String(error) }
|
||||
: undefined;
|
||||
|
||||
const footer = {
|
||||
finishedAt: new Date().toISOString(),
|
||||
status,
|
||||
error: errorPayload,
|
||||
};
|
||||
|
||||
try {
|
||||
await writeChain;
|
||||
const handle = await ensureOpen();
|
||||
await handle.write(`],${stringifyJson(footer)?.slice(1)}\n`);
|
||||
} catch (writeError) {
|
||||
console.error("[raw-llm-stream] failed to write log file", {
|
||||
filePath,
|
||||
error:
|
||||
writeError instanceof Error
|
||||
? writeError.message
|
||||
: String(writeError),
|
||||
});
|
||||
} finally {
|
||||
if (fileHandle) {
|
||||
await fileHandle.close().catch(() => {});
|
||||
fileHandle = null;
|
||||
}
|
||||
if (writeError) {
|
||||
console.error("[raw-llm-stream] log file may be incomplete", {
|
||||
filePath,
|
||||
});
|
||||
}
|
||||
}
|
||||
},
|
||||
};
|
||||
}
|
||||
|
|
|
|||
|
|
@ -51,3 +51,31 @@ export async function getUserApiKeys(
|
|||
const client = db ?? createServerSupabase();
|
||||
return getStoredUserApiKeys(userId, client);
|
||||
}
|
||||
|
||||
/**
|
||||
* Whether the user has US legal research (CourtListener) tools enabled in
|
||||
* chat. Controlled by the Features > Legal Research > Jurisdiction > US
|
||||
* toggle in account settings. Defaults to enabled — both when the user has
|
||||
* no profile row yet and when the column is missing (migration not applied),
|
||||
* so existing behaviour is preserved on partially-migrated deployments.
|
||||
*/
|
||||
export async function getLegalResearchUsEnabled(
|
||||
userId: string,
|
||||
db?: ReturnType<typeof createServerSupabase>,
|
||||
): Promise<boolean> {
|
||||
const client = db ?? createServerSupabase();
|
||||
try {
|
||||
const { data, error } = await client
|
||||
.from("user_profiles")
|
||||
.select("legal_research_us")
|
||||
.eq("user_id", userId)
|
||||
.maybeSingle();
|
||||
if (error || !data) return true;
|
||||
return (
|
||||
(data as { legal_research_us?: boolean | null })
|
||||
.legal_research_us !== false
|
||||
);
|
||||
} catch {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,7 +15,11 @@ import {
|
|||
type ChatMessage,
|
||||
} from "../lib/chatTools";
|
||||
import { completeText } from "../lib/llm";
|
||||
import { getUserApiKeys, getUserModelSettings } from "../lib/userSettings";
|
||||
import {
|
||||
getLegalResearchUsEnabled,
|
||||
getUserApiKeys,
|
||||
getUserModelSettings,
|
||||
} from "../lib/userSettings";
|
||||
import { checkProjectAccess } from "../lib/access";
|
||||
import { safeErrorLog, safeErrorMessage } from "../lib/safeError";
|
||||
|
||||
|
|
@ -552,7 +556,14 @@ chatRouter.post("/", requireAuth, async (req, res) => {
|
|||
db,
|
||||
docIndex,
|
||||
);
|
||||
const apiMessages = buildMessages(enrichedMessages, docAvailability);
|
||||
const legalResearchUs = await getLegalResearchUsEnabled(userId, db);
|
||||
const apiMessages = buildMessages(
|
||||
enrichedMessages,
|
||||
docAvailability,
|
||||
undefined,
|
||||
undefined,
|
||||
legalResearchUs,
|
||||
);
|
||||
|
||||
const workflowStore = await buildWorkflowStore(userId, userEmail, db);
|
||||
|
||||
|
|
@ -588,6 +599,7 @@ chatRouter.post("/", requireAuth, async (req, res) => {
|
|||
db,
|
||||
write,
|
||||
workflowStore,
|
||||
includeResearchTools: legalResearchUs,
|
||||
model,
|
||||
apiKeys,
|
||||
signal: streamAbort.signal,
|
||||
|
|
|
|||
|
|
@ -364,7 +364,7 @@ documentsRouter.get("/:documentId/versions", requireAuth, async (req, res) => {
|
|||
const { data: rows } = await db
|
||||
.from("document_versions")
|
||||
.select(
|
||||
"id, version_number, source, created_at, filename, file_type, size_bytes, page_count",
|
||||
"id, version_number, source, created_at, filename, file_type, size_bytes, page_count, deleted_at, deleted_by",
|
||||
)
|
||||
.eq("document_id", documentId)
|
||||
.order("created_at", { ascending: true });
|
||||
|
|
@ -433,19 +433,12 @@ documentsRouter.post(
|
|||
});
|
||||
}
|
||||
|
||||
const targetActive = await loadActiveVersion(documentId, db);
|
||||
const targetType = targetActive?.file_type ?? "";
|
||||
const active = await loadActiveVersion(sourceDocumentId, db);
|
||||
if (!active)
|
||||
return void res
|
||||
.status(404)
|
||||
.json({ detail: "Source document has no active version." });
|
||||
const sourceType = active.file_type ?? "";
|
||||
if (targetType && sourceType && targetType !== sourceType) {
|
||||
return void res.status(400).json({
|
||||
detail: `Source document type (${sourceType}) does not match document type (${targetType}).`,
|
||||
});
|
||||
}
|
||||
|
||||
const bytes = await downloadFile(active.storage_path);
|
||||
if (!bytes)
|
||||
|
|
@ -603,8 +596,6 @@ documentsRouter.post(
|
|||
if (!access.ok)
|
||||
return void res.status(404).json({ detail: "Document not found" });
|
||||
|
||||
// Reject if the uploaded file's extension doesn't match the document's
|
||||
// declared type — otherwise every downstream viewer/extractor breaks.
|
||||
const suffix = file.originalname.includes(".")
|
||||
? file.originalname.split(".").pop()!.toLowerCase()
|
||||
: "";
|
||||
|
|
@ -614,14 +605,6 @@ documentsRouter.post(
|
|||
});
|
||||
}
|
||||
|
||||
const currentActive = await loadActiveVersion(documentId, db);
|
||||
const expectedType = currentActive?.file_type ?? "";
|
||||
if (expectedType && expectedType !== suffix) {
|
||||
return void res.status(400).json({
|
||||
detail: `Uploaded file type (${suffix}) does not match document type (${expectedType}).`,
|
||||
});
|
||||
}
|
||||
|
||||
// Peg the new version into a predictable /versions/:id path under the
|
||||
// existing document folder so ops can spot the history in storage.
|
||||
const versionSlug = crypto.randomUUID().replace(/-/g, "");
|
||||
|
|
@ -777,6 +760,7 @@ documentsRouter.patch(
|
|||
.update({ filename })
|
||||
.eq("id", versionId)
|
||||
.eq("document_id", documentId)
|
||||
.is("deleted_at", null)
|
||||
.select(
|
||||
"id, version_number, source, created_at, filename, file_type, size_bytes, page_count",
|
||||
)
|
||||
|
|
@ -788,6 +772,160 @@ documentsRouter.patch(
|
|||
},
|
||||
);
|
||||
|
||||
// PUT /single-documents/:documentId/versions/:versionId/file
|
||||
// Replace the file bytes and metadata for an existing version while keeping
|
||||
// its version number and id. This is destructive and owner-only.
|
||||
documentsRouter.put(
|
||||
"/:documentId/versions/:versionId/file",
|
||||
requireAuth,
|
||||
singleFileUpload("file"),
|
||||
async (req, res) => {
|
||||
const userId = res.locals.userId as string;
|
||||
const userEmail = res.locals.userEmail as string | undefined;
|
||||
const { documentId, versionId } = req.params;
|
||||
const db = createServerSupabase();
|
||||
|
||||
const file = req.file;
|
||||
if (!file)
|
||||
return void res.status(400).json({ detail: "file is required" });
|
||||
|
||||
const { data: doc } = await db
|
||||
.from("documents")
|
||||
.select("id, user_id, project_id")
|
||||
.eq("id", documentId)
|
||||
.single();
|
||||
if (!doc)
|
||||
return void res.status(404).json({ detail: "Document not found" });
|
||||
const access = await ensureDocAccess(doc, userId, userEmail, db);
|
||||
if (!access.ok || !access.isOwner)
|
||||
return void res.status(404).json({ detail: "Document not found" });
|
||||
|
||||
const { data: target, error: targetErr } = await db
|
||||
.from("document_versions")
|
||||
.select("id, storage_path, pdf_storage_path, file_type, deleted_at")
|
||||
.eq("id", versionId)
|
||||
.eq("document_id", documentId)
|
||||
.single();
|
||||
if (targetErr || !target)
|
||||
return void res.status(404).json({ detail: "Version not found" });
|
||||
if (target.deleted_at)
|
||||
return void res.status(400).json({ detail: "Version is deleted." });
|
||||
|
||||
const suffix = file.originalname.includes(".")
|
||||
? file.originalname.split(".").pop()!.toLowerCase()
|
||||
: "";
|
||||
if (!ALLOWED_TYPES.has(suffix)) {
|
||||
return void res.status(400).json({
|
||||
detail: `Unsupported file type: ${suffix}. Allowed: pdf, docx, doc`,
|
||||
});
|
||||
}
|
||||
if (target.file_type && target.file_type !== suffix) {
|
||||
return void res.status(400).json({
|
||||
detail: `Uploaded file type (${suffix}) does not match version type (${target.file_type}).`,
|
||||
});
|
||||
}
|
||||
|
||||
const versionSlug = crypto.randomUUID().replace(/-/g, "");
|
||||
const key = versionStorageKey(
|
||||
userId,
|
||||
documentId,
|
||||
versionSlug,
|
||||
file.originalname,
|
||||
);
|
||||
const contentType =
|
||||
suffix === "pdf"
|
||||
? "application/pdf"
|
||||
: "application/vnd.openxmlformats-officedocument.wordprocessingml.document";
|
||||
|
||||
try {
|
||||
await uploadFile(
|
||||
key,
|
||||
file.buffer.buffer.slice(
|
||||
file.buffer.byteOffset,
|
||||
file.buffer.byteOffset + file.buffer.byteLength,
|
||||
) as ArrayBuffer,
|
||||
contentType,
|
||||
);
|
||||
} catch (e) {
|
||||
console.error("[versions/replace] storage write failed", e);
|
||||
return void res
|
||||
.status(500)
|
||||
.json({ detail: "Failed to upload replacement version." });
|
||||
}
|
||||
|
||||
let pdfStoragePath: string | null = null;
|
||||
if (suffix === "docx" || suffix === "doc") {
|
||||
try {
|
||||
const pdfBuf = await docxToPdf(file.buffer);
|
||||
const pdfKey = `converted-pdfs/${userId}/${documentId}/${versionSlug}.pdf`;
|
||||
await uploadFile(
|
||||
pdfKey,
|
||||
pdfBuf.buffer.slice(
|
||||
pdfBuf.byteOffset,
|
||||
pdfBuf.byteOffset + pdfBuf.byteLength,
|
||||
) as ArrayBuffer,
|
||||
"application/pdf",
|
||||
);
|
||||
pdfStoragePath = pdfKey;
|
||||
} catch (err) {
|
||||
console.error(
|
||||
`[versions/replace] DOCX→PDF conversion failed for ${file.originalname}:`,
|
||||
err,
|
||||
);
|
||||
}
|
||||
} else if (suffix === "pdf") {
|
||||
pdfStoragePath = key;
|
||||
}
|
||||
|
||||
const rawBuf = file.buffer.buffer.slice(
|
||||
file.buffer.byteOffset,
|
||||
file.buffer.byteOffset + file.buffer.byteLength,
|
||||
) as ArrayBuffer;
|
||||
const pageCount = suffix === "pdf" ? await countPdfPages(rawBuf) : null;
|
||||
const requestedFilename =
|
||||
typeof req.body?.filename === "string" && req.body.filename.trim()
|
||||
? req.body.filename.trim().slice(0, 200)
|
||||
: file.originalname;
|
||||
const uploadedAt = new Date().toISOString();
|
||||
|
||||
const { data: updated, error: updateErr } = await db
|
||||
.from("document_versions")
|
||||
.update({
|
||||
storage_path: key,
|
||||
pdf_storage_path: pdfStoragePath,
|
||||
filename: requestedFilename,
|
||||
file_type: suffix,
|
||||
size_bytes: file.buffer.byteLength,
|
||||
page_count: pageCount,
|
||||
created_at: uploadedAt,
|
||||
})
|
||||
.eq("id", versionId)
|
||||
.eq("document_id", documentId)
|
||||
.select(
|
||||
"id, version_number, source, created_at, filename, file_type, size_bytes, page_count",
|
||||
)
|
||||
.single();
|
||||
if (updateErr || !updated) {
|
||||
await Promise.all(
|
||||
[key, pdfStoragePath]
|
||||
.filter((path): path is string => !!path)
|
||||
.map((path) => deleteFile(path).catch(() => {})),
|
||||
);
|
||||
return void res.status(500).json({
|
||||
detail: updateErr?.message ?? "Failed to replace version.",
|
||||
});
|
||||
}
|
||||
|
||||
await Promise.all(
|
||||
[target.storage_path, target.pdf_storage_path]
|
||||
.filter((path): path is string => !!path)
|
||||
.map((path) => deleteFile(path).catch(() => {})),
|
||||
);
|
||||
|
||||
res.json(updated);
|
||||
},
|
||||
);
|
||||
|
||||
// DELETE /single-documents/:documentId/versions/:versionId
|
||||
// Delete one version. The last remaining version cannot be deleted; if the
|
||||
// deleted version is current, the newest remaining version becomes current.
|
||||
|
|
@ -813,8 +951,11 @@ documentsRouter.delete(
|
|||
|
||||
const { data: versions, error: versionsErr } = await db
|
||||
.from("document_versions")
|
||||
.select("id, storage_path, pdf_storage_path, version_number, created_at")
|
||||
.eq("document_id", documentId);
|
||||
.select(
|
||||
"id, storage_path, pdf_storage_path, version_number, created_at, deleted_at",
|
||||
)
|
||||
.eq("document_id", documentId)
|
||||
.is("deleted_at", null);
|
||||
if (versionsErr) {
|
||||
return void res.status(500).json({ detail: versionsErr.message });
|
||||
}
|
||||
|
|
@ -825,6 +966,7 @@ documentsRouter.delete(
|
|||
pdf_storage_path: string | null;
|
||||
version_number: number | null;
|
||||
created_at: string | null;
|
||||
deleted_at?: string | null;
|
||||
}[];
|
||||
const target = rows.find((row) => row.id === versionId);
|
||||
if (!target)
|
||||
|
|
@ -850,6 +992,7 @@ documentsRouter.delete(
|
|||
doc.current_version_id === versionId
|
||||
? (remaining[0]?.id ?? null)
|
||||
: doc.current_version_id;
|
||||
const deletedAt = new Date().toISOString();
|
||||
|
||||
if (doc.current_version_id === versionId) {
|
||||
const { error: updateErr } = await db
|
||||
|
|
@ -866,9 +1009,15 @@ documentsRouter.delete(
|
|||
|
||||
const { error: deleteErr } = await db
|
||||
.from("document_versions")
|
||||
.delete()
|
||||
.update({
|
||||
storage_path: null,
|
||||
pdf_storage_path: null,
|
||||
deleted_at: deletedAt,
|
||||
deleted_by: userId,
|
||||
})
|
||||
.eq("id", versionId)
|
||||
.eq("document_id", documentId);
|
||||
.eq("document_id", documentId)
|
||||
.is("deleted_at", null);
|
||||
if (deleteErr) {
|
||||
return void res.status(500).json({ detail: deleteErr.message });
|
||||
}
|
||||
|
|
@ -882,6 +1031,7 @@ documentsRouter.delete(
|
|||
res.json({
|
||||
deleted_version_id: versionId,
|
||||
current_version_id: nextCurrentVersionId,
|
||||
deleted_at: deletedAt,
|
||||
});
|
||||
},
|
||||
);
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ downloadsRouter.get("/:token", requireAuth, async (req, res) => {
|
|||
.from("document_versions")
|
||||
.select("id, document_id")
|
||||
.eq("storage_path", info.path)
|
||||
.is("deleted_at", null)
|
||||
.maybeSingle();
|
||||
if (byStoragePath) {
|
||||
version = byStoragePath as { id: string; document_id: string };
|
||||
|
|
|
|||
|
|
@ -15,7 +15,10 @@ import {
|
|||
PROJECT_EXTRA_TOOLS,
|
||||
type ChatMessage,
|
||||
} from "../lib/chatTools";
|
||||
import { getUserApiKeys } from "../lib/userSettings";
|
||||
import {
|
||||
getLegalResearchUsEnabled,
|
||||
getUserApiKeys,
|
||||
} from "../lib/userSettings";
|
||||
import { checkProjectAccess } from "../lib/access";
|
||||
import { safeErrorLog, safeErrorMessage } from "../lib/safeError";
|
||||
|
||||
|
|
@ -141,10 +144,13 @@ projectChatRouter.post("/", requireAuth, async (req, res) => {
|
|||
systemPromptExtra += `\n\nUSER-ATTACHED DOCUMENTS FOR THIS TURN:\nThe user has attached the following document(s) directly to their latest message. Treat these as the primary focus of the request unless their message clearly says otherwise.\n${lines.join("\n")}`;
|
||||
}
|
||||
|
||||
const legalResearchUs = await getLegalResearchUsEnabled(userId, db);
|
||||
const apiMessages = buildMessages(
|
||||
messagesForLLM,
|
||||
docAvailability,
|
||||
systemPromptExtra,
|
||||
undefined,
|
||||
legalResearchUs,
|
||||
);
|
||||
|
||||
const workflowStore = await buildWorkflowStore(userId, userEmail, db);
|
||||
|
|
@ -176,6 +182,7 @@ projectChatRouter.post("/", requireAuth, async (req, res) => {
|
|||
write,
|
||||
extraTools: PROJECT_EXTRA_TOOLS,
|
||||
workflowStore,
|
||||
includeResearchTools: legalResearchUs,
|
||||
model,
|
||||
apiKeys,
|
||||
signal: streamAbort.signal,
|
||||
|
|
|
|||
|
|
@ -60,6 +60,59 @@ async function deleteProjectDocumentsAndVersionFiles(
|
|||
return error ?? null;
|
||||
}
|
||||
|
||||
async function attachDocumentOwnerLabels(
|
||||
db: ReturnType<typeof createServerSupabase>,
|
||||
docs: { user_id?: string | null }[],
|
||||
) {
|
||||
const ownerIds = docs
|
||||
.map((doc) => doc.user_id)
|
||||
.filter((id): id is string => typeof id === "string" && id.length > 0)
|
||||
.filter((id, index, arr) => arr.indexOf(id) === index);
|
||||
if (ownerIds.length === 0) return;
|
||||
|
||||
const emailByUserId = new Map<string, string>();
|
||||
const userResults = await Promise.allSettled(
|
||||
ownerIds.map(async (id) => {
|
||||
const { data, error } = await db.auth.admin.getUserById(id);
|
||||
if (error) throw error;
|
||||
return { id, email: data.user?.email ?? null };
|
||||
}),
|
||||
);
|
||||
for (const result of userResults) {
|
||||
if (result.status === "fulfilled" && result.value.email) {
|
||||
emailByUserId.set(result.value.id, result.value.email);
|
||||
}
|
||||
}
|
||||
|
||||
const displayNameByUserId = new Map<string, string>();
|
||||
const { data: profiles, error: profilesError } = await db
|
||||
.from("user_profiles")
|
||||
.select("user_id, display_name")
|
||||
.in("user_id", ownerIds);
|
||||
if (profilesError) {
|
||||
console.warn("[projects] failed to load document owner profiles", profilesError);
|
||||
}
|
||||
for (const profile of profiles ?? []) {
|
||||
const displayName =
|
||||
typeof profile.display_name === "string"
|
||||
? profile.display_name.trim()
|
||||
: "";
|
||||
if (displayName) {
|
||||
displayNameByUserId.set(profile.user_id as string, displayName);
|
||||
}
|
||||
}
|
||||
|
||||
for (const doc of docs as ({
|
||||
user_id?: string | null;
|
||||
owner_email?: string | null;
|
||||
owner_display_name?: string | null;
|
||||
})[]) {
|
||||
if (!doc.user_id) continue;
|
||||
doc.owner_email = emailByUserId.get(doc.user_id) ?? null;
|
||||
doc.owner_display_name = displayNameByUserId.get(doc.user_id) ?? null;
|
||||
}
|
||||
}
|
||||
|
||||
// GET /projects
|
||||
projectsRouter.get("/", requireAuth, async (req, res) => {
|
||||
const userId = res.locals.userId as string;
|
||||
|
|
@ -190,10 +243,12 @@ projectsRouter.get("/:projectId", requireAuth, async (req, res) => {
|
|||
]);
|
||||
const docsTyped = (docs ?? []) as unknown as {
|
||||
id: string;
|
||||
user_id?: string | null;
|
||||
current_version_id?: string | null;
|
||||
}[];
|
||||
await attachLatestVersionNumbers(db, docsTyped);
|
||||
await attachActiveVersionPaths(db, docsTyped);
|
||||
await attachDocumentOwnerLabels(db, docsTyped);
|
||||
res.json({
|
||||
...project,
|
||||
is_owner: project.user_id === userId,
|
||||
|
|
@ -335,9 +390,11 @@ projectsRouter.patch("/:projectId", requireAuth, async (req, res) => {
|
|||
]);
|
||||
const docsTyped = (docs ?? []) as unknown as {
|
||||
id: string;
|
||||
user_id?: string | null;
|
||||
current_version_id?: string | null;
|
||||
}[];
|
||||
await attachActiveVersionPaths(db, docsTyped);
|
||||
await attachDocumentOwnerLabels(db, docsTyped);
|
||||
res.json({ ...data, documents: docsTyped, folders: folderData ?? [] });
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -41,6 +41,7 @@ type UserProfileRow = {
|
|||
title_model: string | null;
|
||||
tabular_model: string;
|
||||
mfa_on_login: boolean | null;
|
||||
legal_research_us: boolean | null;
|
||||
};
|
||||
|
||||
function errorMessage(error: unknown): string {
|
||||
|
|
@ -65,6 +66,8 @@ function errorMessage(error: unknown): string {
|
|||
}
|
||||
|
||||
const PROFILE_SELECT =
|
||||
"display_name, organisation, message_credits_used, credits_reset_date, tier, title_model, tabular_model, mfa_on_login, legal_research_us";
|
||||
const PROFILE_SELECT_NO_LEGAL =
|
||||
"display_name, organisation, message_credits_used, credits_reset_date, tier, title_model, tabular_model, mfa_on_login";
|
||||
const LEGACY_PROFILE_SELECT =
|
||||
"display_name, organisation, message_credits_used, credits_reset_date, tier, tabular_model";
|
||||
|
|
@ -80,14 +83,43 @@ function isMissingProfileColumn(error: unknown, column: string): boolean {
|
|||
return record.code === "42703" && message.includes(column);
|
||||
}
|
||||
|
||||
// Loads a profile while tolerating older databases that lack the
|
||||
// legal_research_us column. Tries the full select first, then falls back to
|
||||
// the legacy cascade (which also handles missing title_model / mfa_on_login)
|
||||
// and defaults the feature flag to enabled.
|
||||
async function selectProfile(
|
||||
db: ReturnType<typeof createServerSupabase>,
|
||||
userId: string,
|
||||
mode: "maybe" | "single",
|
||||
) {
|
||||
const fullQuery = db
|
||||
.from("user_profiles")
|
||||
.select(PROFILE_SELECT)
|
||||
.eq("user_id", userId);
|
||||
const full =
|
||||
mode === "single"
|
||||
? await fullQuery.single()
|
||||
: await fullQuery.maybeSingle();
|
||||
if (!full.error) return full;
|
||||
|
||||
const legacy = await selectProfileLegacy(db, userId, mode);
|
||||
if (legacy.data && typeof legacy.data === "object") {
|
||||
const row = legacy.data as Record<string, unknown>;
|
||||
if (!("legal_research_us" in row)) {
|
||||
Object.assign(row, { legal_research_us: true });
|
||||
}
|
||||
}
|
||||
return legacy;
|
||||
}
|
||||
|
||||
async function selectProfileLegacy(
|
||||
db: ReturnType<typeof createServerSupabase>,
|
||||
userId: string,
|
||||
mode: "maybe" | "single",
|
||||
) {
|
||||
const query = db
|
||||
.from("user_profiles")
|
||||
.select(PROFILE_SELECT)
|
||||
.select(PROFILE_SELECT_NO_LEGAL)
|
||||
.eq("user_id", userId);
|
||||
const result =
|
||||
mode === "single" ? await query.single() : await query.maybeSingle();
|
||||
|
|
@ -166,6 +198,7 @@ function serializeProfile(row: UserProfileRow, apiKeyStatus?: ApiKeyStatus) {
|
|||
titleModel: resolveModel(row.title_model, titleFallback),
|
||||
tabularModel: resolveModel(row.tabular_model, DEFAULT_TABULAR_MODEL),
|
||||
mfaOnLogin: row.mfa_on_login === true,
|
||||
legalResearchUs: row.legal_research_us !== false,
|
||||
...(apiKeyStatus ? { apiKeyStatus } : {}),
|
||||
};
|
||||
}
|
||||
|
|
@ -178,6 +211,7 @@ function validateProfilePayload(body: unknown):
|
|||
organisation?: string | null;
|
||||
title_model?: string;
|
||||
tabular_model?: string;
|
||||
legal_research_us?: boolean;
|
||||
updated_at: string;
|
||||
};
|
||||
}
|
||||
|
|
@ -192,6 +226,7 @@ function validateProfilePayload(body: unknown):
|
|||
"organisation",
|
||||
"titleModel",
|
||||
"tabularModel",
|
||||
"legalResearchUs",
|
||||
]);
|
||||
const invalidField = Object.keys(raw).find(
|
||||
(key) => !allowedFields.has(key),
|
||||
|
|
@ -208,6 +243,7 @@ function validateProfilePayload(body: unknown):
|
|||
organisation?: string | null;
|
||||
title_model?: string;
|
||||
tabular_model?: string;
|
||||
legal_research_us?: boolean;
|
||||
updated_at: string;
|
||||
} = { updated_at: new Date().toISOString() };
|
||||
|
||||
|
|
@ -253,6 +289,16 @@ function validateProfilePayload(body: unknown):
|
|||
update.title_model = resolved;
|
||||
}
|
||||
|
||||
if ("legalResearchUs" in raw) {
|
||||
if (typeof raw.legalResearchUs !== "boolean") {
|
||||
return {
|
||||
ok: false,
|
||||
detail: "legalResearchUs must be a boolean",
|
||||
};
|
||||
}
|
||||
update.legal_research_us = raw.legalResearchUs;
|
||||
}
|
||||
|
||||
return { ok: true, update };
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue