fix: enforce codex local step budget

This commit is contained in:
Andrey Avtomonov 2026-06-01 18:08:02 +02:00
parent f27fc9c9a5
commit 5966a09c49
4 changed files with 113 additions and 13 deletions

View file

@ -34,12 +34,47 @@ function promptWithSystem(system: string | undefined, prompt: string): string {
return [system, prompt].filter(Boolean).join('\n\n');
}
async function collectEvents(events: AsyncIterable<unknown>): Promise<unknown[]> {
/** Optional controls accepted by `collectEvents`. */
interface CollectCodexEventsOptions {
/** Number of completed MCP tool calls at which collection stops; when undefined, no budget is enforced. */
stepBudget?: number;
/** If provided, aborted by `collectEvents` at the moment the step budget is reached. */
abortController?: AbortController;
}
/** Outcome of draining a Codex event stream with `collectEvents`. */
interface CollectCodexEventsResult {
/** Every event read from the stream, in arrival order. */
events: unknown[];
/** True when collection stopped because the step budget was reached. */
budgetExceeded: boolean;
}
/**
 * Views an arbitrary value as a string-keyed record so its properties can be probed.
 *
 * @param value - Any value, typically a raw event or one of its nested fields.
 * @returns The same reference typed as a record when `value` is a non-null object
 *   (arrays included); otherwise `undefined`.
 */
function eventRecord(value: unknown): Record<string, unknown> | undefined {
  // Falsy values (null, undefined, 0, '') and non-object primitives/functions are not records.
  if (!value || typeof value !== 'object') {
    return undefined;
  }
  return value as Record<string, unknown>;
}
/**
 * Tells whether a raw event reports a finished MCP tool call.
 *
 * @param event - A raw event of unknown shape.
 * @returns `true` only when the event's `type` is `'item.completed'` and its
 *   `item.type` is `'mcp_tool_call'`; `false` for anything else, including non-objects.
 */
function isCompletedMcpToolCall(event: unknown): boolean {
  const envelope = eventRecord(event);
  const payload = eventRecord(envelope?.item);
  const isCompletion = envelope?.type === 'item.completed';
  return isCompletion && payload?.type === 'mcp_tool_call';
}
/**
 * Drains an event stream into an array, optionally enforcing a local step budget.
 *
 * Every event read is recorded, including the one that exhausts the budget. A "step" is
 * an event for which `isCompletedMcpToolCall` returns true. When `options.stepBudget` is
 * set and the number of such events reaches it, the optional abort controller is aborted
 * and iteration stops early.
 *
 * @param events - Stream of raw events to consume.
 * @param options - Optional step budget and the abort controller to trigger when it is hit.
 * @returns The events read so far and whether collection stopped because of the budget.
 */
async function collectEvents(
  events: AsyncIterable<unknown>,
  options: CollectCodexEventsOptions = {},
): Promise<CollectCodexEventsResult> {
  const { stepBudget, abortController } = options;
  const collected: unknown[] = [];
  let completedToolSteps = 0;
  let budgetExceeded = false;
  for await (const event of events) {
    // Record first so the budget-exhausting event is still part of the result.
    collected.push(event);
    if (stepBudget === undefined || !isCompletedMcpToolCall(event)) {
      continue;
    }
    completedToolSteps += 1;
    if (completedToolSteps >= stepBudget) {
      budgetExceeded = true;
      // Signal the producer to stop before leaving the loop.
      abortController?.abort();
      break;
    }
  }
  // Single return: the previous bare `return collected;` returned the wrong shape and
  // made this statement unreachable, losing the `budgetExceeded` flag.
  return { events: collected, budgetExceeded };
}
function metrics(summary: CodexExecEventSummary, startedAt: number): { totalMs: number; usage: LlmTokenUsage } {
@ -125,7 +160,7 @@ export class CodexKtxLlmRuntime implements KtxLlmRuntimePort {
}
: {}),
});
const events = await collectEvents(
const collected = await collectEvents(
await this.runner.runStreamed({
projectDir: this.deps.projectDir,
model,
@ -134,7 +169,7 @@ export class CodexKtxLlmRuntime implements KtxLlmRuntimePort {
env: config.env,
}),
);
const summary = summarizeCodexExecEvents(events, { startedAt });
const summary = summarizeCodexExecEvents(collected.events, { startedAt });
input.onMetrics?.(metrics(summary, startedAt));
return assertSuccessfulText(summary);
} finally {
@ -166,7 +201,7 @@ export class CodexKtxLlmRuntime implements KtxLlmRuntimePort {
}
: {}),
});
const events = await collectEvents(
const collected = await collectEvents(
await this.runner.runStreamed({
projectDir: this.deps.projectDir,
model,
@ -176,7 +211,7 @@ export class CodexKtxLlmRuntime implements KtxLlmRuntimePort {
outputSchema: z.toJSONSchema(input.schema, { target: 'draft-7' }) as Record<string, unknown>,
}),
);
const summary = summarizeCodexExecEvents(events, { startedAt });
const summary = summarizeCodexExecEvents(collected.events, { startedAt });
input.onMetrics?.(metrics(summary, startedAt));
return parseStructuredOutput(input.schema, assertSuccessfulText(summary));
} finally {
@ -207,16 +242,19 @@ export class CodexKtxLlmRuntime implements KtxLlmRuntimePort {
}
: {}),
});
const events = await collectEvents(
const abortController = new AbortController();
const collected = await collectEvents(
await this.runner.runStreamed({
projectDir: this.deps.projectDir,
model,
prompt: promptWithSystem(params.systemPrompt, params.userPrompt),
configOverrides: config.configOverrides,
env: config.env,
signal: abortController.signal,
}),
{ stepBudget: params.stepBudget, abortController },
);
const summary = summarizeCodexExecEvents(events, { startedAt });
const summary = summarizeCodexExecEvents(collected.events, { startedAt });
for (let index = 1; index <= summary.stepCount; index += 1) {
try {
await params.onStepFinish?.({ stepIndex: index, stepBudget: params.stepBudget });
@ -227,7 +265,7 @@ export class CodexKtxLlmRuntime implements KtxLlmRuntimePort {
}
}
const error = summaryError(summary);
const stopReason = error ? 'error' : summary.stopReason;
const stopReason = collected.budgetExceeded ? 'budget' : error ? 'error' : summary.stopReason;
return {
stopReason,
...(stopReason === 'error' && error ? { error } : {}),

View file

@ -1,4 +1,4 @@
import { Codex, type CodexOptions, type ThreadOptions } from '@openai/codex-sdk';
import { Codex, type CodexOptions, type ThreadOptions, type TurnOptions } from '@openai/codex-sdk';
export interface CodexSdkRunnerInput {
projectDir: string;
@ -7,6 +7,7 @@ export interface CodexSdkRunnerInput {
configOverrides?: Record<string, unknown>;
env?: Record<string, string>;
outputSchema?: Record<string, unknown>;
signal?: AbortSignal;
}
export interface CodexSdkRunner {
@ -19,7 +20,7 @@ export interface CodexSdkCliRunnerOptions {
}
/**
 * Minimal structural view of a Codex thread: only the streamed-run entry point.
 *
 * The turn options use the SDK's `TurnOptions` type; the older inline
 * `{ outputSchema? }` signature was a redundant, narrower duplicate and is dropped.
 */
type CodexThread = {
  runStreamed(input: string, turnOptions?: TurnOptions): Promise<{ events: AsyncIterable<unknown> }>;
};
type CodexClient = {
@ -82,9 +83,13 @@ export class CodexSdkCliRunner implements CodexSdkRunner {
webSearchMode: 'disabled',
approvalPolicy: 'never',
});
const turnOptions: TurnOptions = {
...(input.outputSchema ? { outputSchema: input.outputSchema } : {}),
...(input.signal ? { signal: input.signal } : {}),
};
const streamed = await thread.runStreamed(
input.prompt,
input.outputSchema ? { outputSchema: input.outputSchema } : undefined,
Object.keys(turnOptions).length > 0 ? turnOptions : undefined,
);
return streamed.events;
}

View file

@ -17,6 +17,24 @@ function runner(items: unknown[]) {
};
}
/**
 * Builds a fake runner whose stream contains two completed MCP tool calls and which
 * remembers the abort signal it was invoked with.
 *
 * @returns An object with a mocked `runStreamed` and an `observedSignal` accessor that
 *   yields the signal passed to the most recent `runStreamed` call (or `undefined`).
 */
function budgetRunner() {
  let capturedSignal: AbortSignal | undefined;

  const runStreamed = vi.fn(async (input: { signal?: AbortSignal }) => {
    capturedSignal = input.signal;
    // Scripted turn: two MCP tool calls, each started and then completed.
    const scriptedEvents: unknown[] = [
      { type: 'turn.started' },
      { type: 'item.started', item: { type: 'mcp_tool_call', server: 'ktx', tool: 'first', status: 'in_progress' } },
      { type: 'item.completed', item: { type: 'mcp_tool_call', server: 'ktx', tool: 'first', status: 'completed' } },
      { type: 'item.started', item: { type: 'mcp_tool_call', server: 'ktx', tool: 'second', status: 'in_progress' } },
      { type: 'item.completed', item: { type: 'mcp_tool_call', server: 'ktx', tool: 'second', status: 'completed' } },
      { type: 'turn.completed', usage: { input_tokens: 1, output_tokens: 1 } },
    ];
    return events(scriptedEvents);
  });

  return {
    observedSignal: () => capturedSignal,
    runStreamed,
  };
}
describe('CodexKtxLlmRuntime', () => {
it('generates text with the role-selected model and metrics', async () => {
const onMetrics = vi.fn();
@ -207,6 +225,40 @@ describe('CodexKtxLlmRuntime', () => {
});
});
// Verifies that runAgentLoop stops with 'budget' and aborts the runner's signal once
// the number of completed MCP tool calls reaches the configured step budget.
it('returns budget and aborts the Codex stream when local MCP step budget is reached', async () => {
// The fake stream from budgetRunner() contains two completed MCP tool calls.
const fakeRunner = budgetRunner();
const runtime = new CodexKtxLlmRuntime({
projectDir: '/tmp/project',
modelSlots: { default: 'codex' },
runner: fakeRunner,
});
const onStepFinish = vi.fn();
const result = await runtime.runAgentLoop({
modelRole: 'default',
systemPrompt: 'system',
userPrompt: 'user',
// Budget of one: smaller than the two tool calls the fake stream would emit.
stepBudget: 1,
telemetryTags: {},
onStepFinish,
toolSet: {
first: {
name: 'first',
description: 'First tool',
inputSchema: z.object({}),
execute: vi.fn(),
},
},
});
// Hitting the budget is reported as 'budget', not as an error.
expect(result.stopReason).toBe('budget');
expect(result.error).toBeUndefined();
// Only one step is expected to be counted and reported.
expect(result.metrics).toMatchObject({ stepCount: 1 });
expect(onStepFinish).toHaveBeenCalledTimes(1);
expect(onStepFinish).toHaveBeenCalledWith({ stepIndex: 1, stepBudget: 1 });
// The signal handed to the runner must have been aborted.
expect(fakeRunner.observedSignal()?.aborted).toBe(true);
});
it('probes Codex authentication through a minimal non-interactive turn', async () => {
const fakeRunner = runner([
{ type: 'turn.started' },

View file

@ -45,6 +45,7 @@ describe('CodexSdkCliRunner', () => {
};
try {
const controller = new AbortController();
const events = await runner.runStreamed({
projectDir: '/tmp/ktx-project',
model: 'gpt-5.3-codex',
@ -54,6 +55,7 @@ describe('CodexSdkCliRunner', () => {
},
env: { KTX_CODEX_RUNTIME_MCP_TOKEN: 'run-token' },
outputSchema,
signal: controller.signal,
});
expect(sdkMock.Codex).toHaveBeenCalledWith({
@ -77,7 +79,10 @@ describe('CodexSdkCliRunner', () => {
webSearchMode: 'disabled',
approvalPolicy: 'never',
});
expect(sdkMock.runStreamed).toHaveBeenCalledWith('Return JSON.', { outputSchema });
expect(sdkMock.runStreamed).toHaveBeenCalledWith('Return JSON.', {
outputSchema,
signal: controller.signal,
});
await expect(collectAsync(events)).resolves.toEqual([
{ type: 'turn.completed', usage: { input_tokens: 1, output_tokens: 2 } },
]);