trustgraph/ts/packages/flow/src/model/text-completion/common.ts

366 lines
11 KiB
TypeScript
Raw Normal View History

2026-06-01 23:19:54 -05:00
import {
Llm,
2026-06-01 23:19:54 -05:00
TooManyRequestsError,
errorMessage,
makeLlmServiceShape,
2026-06-01 23:19:54 -05:00
type LlmChunk,
type LlmResult,
type LlmProvider,
2026-06-01 23:19:54 -05:00
} from "@trustgraph/base";
import { Config, Effect, Layer, ManagedRuntime, Ref, Result, Stream } from "effect";
2026-06-01 23:19:54 -05:00
import * as O from "effect/Option";
import * as Predicate from "effect/Predicate";
2026-06-01 23:19:54 -05:00
import * as S from "effect/Schema";
import { AiError, LanguageModel, Prompt, Response } from "effect/unstable/ai";
2026-06-01 23:19:54 -05:00
/**
 * Tagged error for a text-completion provider configuration problem.
 *
 * Fields: a human-readable `message`, the `provider` name, and the
 * configuration `key` the problem relates to.
 */
export class TextCompletionConfigError extends S.TaggedErrorClass<TextCompletionConfigError>()(
"TextCompletionConfigError",
{
message: S.String,
provider: S.String,
key: S.String,
},
) {}
/**
 * Tagged error for a failure reported while calling a text-completion provider.
 *
 * Fields: a human-readable `message` and the `provider` name.
 */
export class TextCompletionProviderError extends S.TaggedErrorClass<TextCompletionProviderError>()(
"TextCompletionProviderError",
{
message: S.String,
provider: S.String,
},
) {}
/**
 * Errors that the provider helpers in this module map failures to: a generic
 * provider failure or a rate-limit signal.
 */
export type TextCompletionRuntimeError =
| TextCompletionProviderError
| TooManyRequestsError;
/** Per-request settings handed to `makeLanguageModel` after defaults are applied. */
export interface LanguageModelProviderRequest {
/** Model name to use for this request. */
readonly model: string;
/** Temperature to use for this request. */
readonly temperature: number;
}
/** Options consumed by `makeLanguageModelProvider`. */
export interface LanguageModelProviderOptions<Requirements> {
/** Provider name attached to mapped provider errors. */
readonly provider: string;
/** Model used when a call does not supply one. */
readonly defaultModel: string;
/** Temperature used when a call does not supply one. */
readonly defaultTemperature: number;
/** Runtime used to run generation effects and to obtain the context for streaming. */
readonly runtime: ManagedRuntime.ManagedRuntime<Requirements, TextCompletionRuntimeError>;
/** Builds the language model service for a resolved request. */
readonly makeLanguageModel: (
request: LanguageModelProviderRequest,
) => Effect.Effect<LanguageModel.Service, TextCompletionRuntimeError, Requirements>;
}
/**
 * Builds a `Layer` that provides the `Llm` service from an effectful provider.
 *
 * The provider effect is resolved, wrapped via `makeLlmServiceShape`, and
 * registered under the `Llm` tag. The provider's error and requirement types
 * carry over to the layer.
 */
export const makeTextCompletionLayer = <E, R>(
  provider: Effect.Effect<LlmProvider, E, R>,
): Layer.Layer<Llm, E, R> => {
  const service = Effect.map(provider, (resolvedProvider) =>
    Llm.of(makeLlmServiceShape(resolvedProvider))
  );
  return Layer.effect(Llm)(service);
};
2026-06-02 04:33:48 -05:00
/** Input/output token counts tracked while a completion is streamed. */
type StreamingTokenTotals = {
readonly inToken: number;
readonly outToken: number;
};
/**
 * Normalised view of one provider stream chunk; each field is `None` when the
 * chunk did not carry that piece of information.
 */
type LlmStreamPart = {
readonly text: O.Option<string>;
readonly inToken: O.Option<number>;
readonly outToken: O.Option<number>;
};
/** Starting token counts (both zero) used before any chunk reports usage. */
const initialTokenTotals = {
inToken: 0,
outToken: 0,
} satisfies StreamingTokenTotals;
/**
 * Produces the next token totals after seeing a stream part.
 *
 * A count present on the part replaces the stored one (it is not added to it);
 * a missing count leaves the stored value unchanged.
 */
const updateTokenTotals = (
  current: StreamingTokenTotals,
  part: LlmStreamPart,
): StreamingTokenTotals => {
  const inToken = O.getOrElse(part.inToken, () => current.inToken);
  const outToken = O.getOrElse(part.outToken, () => current.outToken);
  return { inToken, outToken };
};
/**
 * Builds the terminating chunk of a stream: empty text, the given token
 * counts, and `isFinal` set to `true`.
 */
const finalChunk = (model: string, totals: StreamingTokenTotals): LlmChunk => {
  const { inToken, outToken } = totals;
  return { text: "", inToken, outToken, model, isFinal: true };
};
/**
 * Builds a non-final chunk carrying `text`; token counts are `null` because
 * they are only reported on the final chunk.
 */
const textChunk = (model: string, text: string): LlmChunk => {
  const chunk: LlmChunk = {
    text,
    inToken: null,
    outToken: null,
    model,
    isFinal: false,
  };
  return chunk;
};
/**
 * Maps an unknown failure to a `TextCompletionRuntimeError`.
 *
 * An `AiError` whose reason is a rate-limit or exhausted-quota condition
 * becomes a `TooManyRequestsError`; everything else becomes a provider error
 * via `providerRuntimeError`.
 */
const effectAiProviderError = (
  provider: string,
  error: unknown,
): TextCompletionRuntimeError => {
  if (!AiError.isAiError(error)) {
    return providerRuntimeError(provider, error);
  }
  switch (error.reason._tag) {
    case "RateLimitError":
    case "QuotaExhaustedError":
      return TooManyRequestsError.make({ message: "Rate limit exceeded" });
    default:
      return providerRuntimeError(provider, error);
  }
};
/** Reads the total input-token count from a usage record, defaulting to 0 when absent. */
const usageInputTokens = (usage: Response.Usage): number => {
  const { total } = usage.inputTokens;
  return total ?? 0;
};
/** Reads the total output-token count from a usage record, defaulting to 0 when absent. */
const usageOutputTokens = (usage: Response.Usage): number => {
  const { total } = usage.outputTokens;
  return total ?? 0;
};
/**
 * Builds the two-message prompt sent to the language model: a system message
 * with `system` as its content, followed by a user message whose content is a
 * single text part holding `prompt`.
 */
const languageModelPrompt = (
  system: string,
  prompt: string,
): Prompt.RawInput => {
  const userText = { type: "text", text: prompt } as const;
  return [
    { role: "system", content: system },
    { role: "user", content: [userText] },
  ];
};
2026-06-02 04:33:48 -05:00
/**
 * Extracts the `text` property of a content part.
 *
 * Yields `Some(text)` only when `part` is an object with a string `text`
 * property; any other shape yields `None`.
 */
const contentPartText = (part: unknown): O.Option<string> => {
  if (!Predicate.isObject(part) || !Predicate.hasProperty(part, "text")) {
    return O.none();
  }
  return Predicate.isString(part.text) ? O.some(part.text) : O.none();
};
/**
 * Flattens message content into plain text.
 *
 * A string is returned unchanged. An array has the `text` of each part that
 * carries one concatenated in order. Anything else yields the empty string.
 */
export const textFromContent = (content: unknown): string => {
  if (Predicate.isString(content)) {
    return content;
  }
  if (!Array.isArray(content)) {
    return "";
  }
  let text = "";
  for (const part of content) {
    const piece = contentPartText(part);
    if (O.isSome(piece)) {
      text += piece.value;
    }
  }
  return text;
};
/**
 * Normalises a loosely-typed provider chunk into an `LlmStreamPart`, turning
 * each `null`/`undefined`/missing field into `None`.
 */
export const llmStreamPart = (part: {
  readonly text?: string | null | undefined;
  readonly inToken?: number | null | undefined;
  readonly outToken?: number | null | undefined;
}): LlmStreamPart => {
  const { text, inToken, outToken } = part;
  return {
    text: O.fromNullishOr(text),
    inToken: O.fromNullishOr(inToken),
    outToken: O.fromNullishOr(outToken),
  };
};
/**
 * Adapts a provider's async iterable into a stream of `LlmChunk`s.
 *
 * Each source item is passed through `options.extract`. Any token counts it
 * carries replace the running totals, and non-empty text is emitted as a
 * non-final chunk (items without text are dropped). After the source ends, one
 * final chunk is appended carrying the token totals, taken from
 * `options.finalTokens` when supplied and otherwise from the running totals.
 * Failures raised by the iterable are converted with `options.mapError`.
 */
export const streamTextCompletionChunks = <A>(
  iterable: AsyncIterable<A>,
  options: {
    readonly model: string;
    readonly mapError: (error: unknown) => TextCompletionRuntimeError;
    readonly extract: (chunk: A) => LlmStreamPart;
    readonly finalTokens?: Effect.Effect<StreamingTokenTotals, TextCompletionRuntimeError>;
  },
): Stream.Stream<LlmChunk, TextCompletionRuntimeError> =>
  Stream.unwrap(
    Effect.gen(function* () {
      const totalsRef = yield* Ref.make(initialTokenTotals);

      // Record usage from the item, then yield its text chunk (if any).
      const toTextChunk = (raw: A) =>
        Effect.gen(function* () {
          const extracted = options.extract(raw);
          yield* Ref.update(totalsRef, (seen) => updateTokenTotals(seen, extracted));
          const nonEmptyText = O.filter(extracted.text, (text) => text.length > 0);
          return O.map(nonEmptyText, (text) => textChunk(options.model, text));
        });

      const textChunks = Stream.fromAsyncIterable(iterable, options.mapError).pipe(
        Stream.mapEffect(toTextChunk),
        Stream.filterMap((maybeChunk) =>
          O.match(maybeChunk, {
            onNone: () => Result.fail(undefined),
            onSome: Result.succeed,
          })
        ),
      );

      // Read lazily so the running totals reflect every consumed item.
      const tokenTotals = options.finalTokens ?? Ref.get(totalsRef);
      const closing = Stream.fromEffect(
        Effect.map(tokenTotals, (tokens) => finalChunk(options.model, tokens)),
      );

      return textChunks.pipe(Stream.concat(closing));
    }),
  );
2026-06-01 23:19:54 -05:00
/**
 * Reads an optional string configuration value.
 *
 * Succeeds with the string, or `undefined` when the key is not set. A failure
 * while reading is reported as a `TextCompletionConfigError` tagged with the
 * provider and key.
 */
export const optionalStringConfig = Effect.fn("TextCompletion.optionalStringConfig")(function*(
  provider: string,
  name: string,
) {
  const toConfigError = (cause: unknown) =>
    TextCompletionConfigError.make({
      provider,
      key: name,
      message: errorMessage(cause),
    });
  const lookup = Config.string(name).pipe(
    Config.option,
    Effect.mapError(toConfigError),
  );
  const maybeValue = yield* lookup;
  return O.getOrUndefined(maybeValue);
});
/**
 * Requires a non-empty string.
 *
 * Succeeds with `value` when it is defined and non-empty; otherwise fails with
 * a `TextCompletionConfigError` built from `provider`, `key` and `message`.
 */
export const requiredString = (
  value: string | undefined,
  provider: string,
  key: string,
  message: string,
) => {
  if (value === undefined || value.length === 0) {
    return Effect.fail(TextCompletionConfigError.make({ provider, key, message }));
  }
  return Effect.succeed(value);
};
/**
 * Wraps an arbitrary failure as a `TextCompletionProviderError` for
 * `provider`, using `errorMessage` to derive the message text.
 */
export const providerRuntimeError = (
  provider: string,
  error: unknown,
): TextCompletionRuntimeError => {
  const message = errorMessage(error);
  return TextCompletionProviderError.make({ provider, message });
};
/**
 * Maps a provider failure to a runtime error using its HTTP-style status.
 *
 * When the error object exposes `status` or `statusCode` equal to the number
 * 429, the result is a `TooManyRequestsError`; otherwise the failure is
 * wrapped with `providerRuntimeError`.
 *
 * @param provider - Provider name attached to the wrapped error.
 * @param error - The raw failure thrown or rejected by the provider client.
 * @returns The mapped runtime error.
 */
export const providerStatusError = (
  provider: string,
  error: unknown,
): TextCompletionRuntimeError => {
  // Both field names are read because either may carry the status.
  const status = Predicate.isObject(error) && Predicate.hasProperty(error, "status")
    ? error.status
    : undefined;
  const statusCode = Predicate.isObject(error) && Predicate.hasProperty(error, "statusCode")
    ? error.statusCode
    : undefined;
  return status === 429 || statusCode === 429
    ? TooManyRequestsError.make({ message: "Rate limit exceeded" })
    : providerRuntimeError(provider, error);
};
/**
 * Converts a non-streaming language model response into an `LlmResult`,
 * taking the generated text and the input/output token totals from the
 * response and labelling it with `model`.
 */
const languageModelResult = (
  response: LanguageModel.GenerateTextResponse<{}>,
  model: string,
): LlmResult => {
  const text = response.text;
  const inToken = usageInputTokens(response.usage);
  const outToken = usageOutputTokens(response.usage);
  return { text, inToken, outToken, model };
};
/**
 * Translates one language-model stream part into an optional `LlmChunk`.
 *
 * - `text-delta`: a non-final text chunk, or a skip when the delta is empty.
 * - `finish`: the final chunk carrying the usage token totals.
 * - `error`: fails with the mapped provider error.
 * - any other part type: a skip.
 *
 * A skip is represented as `Result.fail(undefined)` so the caller can filter
 * the part out of the stream.
 */
const languageModelStreamChunk = (
  provider: string,
  model: string,
  part: Response.StreamPart<{}>,
): Effect.Effect<Result.Result<LlmChunk, undefined>, TextCompletionRuntimeError> => {
  if (part.type === "text-delta") {
    return Effect.succeed(
      part.delta.length > 0
        ? Result.succeed(textChunk(model, part.delta))
        : Result.fail(undefined),
    );
  }
  if (part.type === "finish") {
    const totals = {
      inToken: usageInputTokens(part.usage),
      outToken: usageOutputTokens(part.usage),
    };
    return Effect.succeed(Result.succeed(finalChunk(model, totals)));
  }
  if (part.type === "error") {
    return Effect.fail(effectAiProviderError(provider, part.error));
  }
  return Effect.succeed(Result.fail(undefined));
};
/**
 * Exposes an Effect stream as an `AsyncIterable`, run with the context of the
 * given managed runtime.
 *
 * The runtime context is resolved asynchronously when iteration starts, so
 * every iterator method first awaits the underlying stream iterator.
 *
 * `return` is forwarded to the underlying iterator so that a consumer which
 * stops early (for example `break` inside `for await`) lets the stream
 * finalise instead of leaving it running.
 *
 * @param runtime - Managed runtime whose context satisfies the stream's requirements.
 * @param stream - Stream of chunks to iterate.
 * @returns An async iterable over the stream's chunks.
 */
const runLanguageModelStream = <RuntimeRequirements, StreamRequirements extends RuntimeRequirements>(
  runtime: ManagedRuntime.ManagedRuntime<RuntimeRequirements, TextCompletionRuntimeError>,
  stream: Stream.Stream<LlmChunk, TextCompletionRuntimeError, StreamRequirements>,
): AsyncIterable<LlmChunk> => ({
  [Symbol.asyncIterator]: () => {
    const iterator = runtime.context().then((context) =>
      Stream.toAsyncIterableWith(stream, context)[Symbol.asyncIterator]()
    );
    return {
      next: () => iterator.then((current) => current.next()),
      // Propagate early termination; fall back to a plain "done" result when
      // the underlying iterator has no `return` hook.
      return: (value?: unknown) =>
        iterator.then((current) =>
          current.return === undefined
            ? { done: true as const, value }
            : current.return(value)
        ),
    };
  },
});
/**
 * Builds an `LlmProvider` backed by an Effect `LanguageModel` service.
 *
 * For both one-shot and streaming generation, the call's model and temperature
 * fall back to the configured defaults when not supplied, a language model is
 * built through `options.makeLanguageModel`, and failures from the model are
 * mapped with `effectAiProviderError`.
 *
 * - `generateContent` runs through `options.runtime` and resolves to an `LlmResult`.
 * - `generateContentStream` returns an async generator of `LlmChunk`s.
 * - `supportsStreaming` always reports `true`.
 */
export const makeLanguageModelProvider = <Requirements>(
  options: LanguageModelProviderOptions<Requirements>,
): LlmProvider => {
  // Apply configured defaults to the optional per-call settings.
  const resolveRequest = (
    model: string | null | undefined,
    temperature: number | null | undefined,
  ): LanguageModelProviderRequest => ({
    model: model ?? options.defaultModel,
    temperature: temperature ?? options.defaultTemperature,
  });

  const toRuntimeError = (error: unknown) => effectAiProviderError(options.provider, error);

  return {
    generateContent: (system, prompt, model, temperature) => {
      const request = resolveRequest(model, temperature);
      const program = Effect.gen(function* () {
        const languageModel = yield* options.makeLanguageModel(request);
        const response = yield* languageModel
          .generateText({ prompt: languageModelPrompt(system, prompt) })
          .pipe(Effect.mapError(toRuntimeError));
        return languageModelResult(response, request.model);
      });
      return options.runtime.runPromise(program);
    },
    supportsStreaming: () => true,
    generateContentStream: (system, prompt, model, temperature) => {
      const request = resolveRequest(model, temperature);
      const chunks = Stream.unwrap(
        Effect.gen(function* () {
          const languageModel = yield* options.makeLanguageModel(request);
          return languageModel
            .streamText({ prompt: languageModelPrompt(system, prompt) })
            .pipe(
              Stream.mapError(toRuntimeError),
              Stream.filterMapEffect((part) =>
                languageModelStreamChunk(options.provider, request.model, part)
              ),
            );
        }),
      );
      return toAsyncGenerator(runLanguageModelStream(options.runtime, chunks), toRuntimeError);
    },
  };
};
2026-06-01 23:19:54 -05:00
/**
 * Presents an `AsyncIterable` of chunks through the `AsyncGenerator` interface.
 *
 * `next` is forwarded unchanged. `return` and `throw` are forwarded when the
 * underlying iterator implements them; otherwise `return` resolves to a
 * completed result carrying the given value, and `throw` rejects with
 * `mapError(error)`. The generator is its own async iterator.
 *
 * @param iterable - Source of chunks; its iterator is obtained once, eagerly.
 * @param mapError - Converts a thrown value when the source has no `throw` hook.
 * @returns An async generator view over `iterable`.
 */
export const toAsyncGenerator = (
  iterable: AsyncIterable<LlmChunk>,
  mapError: (error: unknown) => TextCompletionRuntimeError,
): AsyncGenerator<LlmChunk> => {
  const iterator = iterable[Symbol.asyncIterator]();
  const generator: AsyncGenerator<LlmChunk> = {
    next: (value?: unknown) => iterator.next(value),
    return: (value?: unknown) =>
      iterator.return === undefined
        ? Promise.resolve({ done: true, value })
        : iterator.return(value),
    throw: (error?: unknown) =>
      iterator.throw === undefined
        ? Promise.reject(mapError(error))
        : iterator.throw(error),
    // Self-reference is safe: the arrow is only invoked after initialisation.
    [Symbol.asyncIterator]: () => generator,
  };
  return generator;
};