feat(workflow): add ThreadReactor generic ReAct loop + migrate extract (Phase 1)
- New src/reactor/ module: createThreadReactor, createLlmFn, types - Two-stage API: config (llm, systemPrompt, tools, toolHandler) + per-call (thread, input, schema) - All tool failures are recoverable (returned to LLM as error message) - Rewrite createExtract to use createThreadReactor - Delete reactExtract old implementation - Fix template test imports (START/END from runtime, validateWorkflowDescriptor from engine) 268 tests passing. Refs #139, relates #140
This commit is contained in:
@@ -101,9 +101,9 @@ export function ThreadDetail({ threadId, onBack }: Props) {
|
|||||||
)}
|
)}
|
||||||
{(status === "ok" || liveActive || records.length > 0) && (
|
{(status === "ok" || liveActive || records.length > 0) && (
|
||||||
<div className="space-y-3">
|
<div className="space-y-3">
|
||||||
{records.map((r, i) => (
|
{records.map((r) => (
|
||||||
<div
|
<div
|
||||||
key={i}
|
key={`${threadId}-${r.type}-${String(r.timestamp)}-${r.role ?? ""}-${r.content ?? ""}`}
|
||||||
className="p-3 rounded border text-sm"
|
className="p-3 rounded border text-sm"
|
||||||
style={{ background: "var(--color-surface)", borderColor: "var(--color-border)" }}
|
style={{ background: "var(--color-surface)", borderColor: "var(--color-border)" }}
|
||||||
>
|
>
|
||||||
|
|||||||
@@ -1,11 +1,6 @@
|
|||||||
import { describe, expect, test } from "bun:test";
|
import { describe, expect, test } from "bun:test";
|
||||||
import {
|
import { END, type ModeratorContext, type RoleStep, START } from "@uncaged/workflow-runtime";
|
||||||
END,
|
import { validateWorkflowDescriptor } from "@uncaged/workflow";
|
||||||
type ModeratorContext,
|
|
||||||
type RoleStep,
|
|
||||||
START,
|
|
||||||
validateWorkflowDescriptor,
|
|
||||||
} from "@uncaged/workflow-runtime";
|
|
||||||
import { buildDevelopDescriptor } from "../src/descriptor.js";
|
import { buildDevelopDescriptor } from "../src/descriptor.js";
|
||||||
import { developModerator } from "../src/index.js";
|
import { developModerator } from "../src/index.js";
|
||||||
import type { CommitterMeta, PlannerMeta } from "../src/roles/index.js";
|
import type { CommitterMeta, PlannerMeta } from "../src/roles/index.js";
|
||||||
|
|||||||
@@ -2,14 +2,8 @@ import { afterEach, describe, expect, test } from "bun:test";
|
|||||||
import { mkdtemp, rm } from "node:fs/promises";
|
import { mkdtemp, rm } from "node:fs/promises";
|
||||||
import { tmpdir } from "node:os";
|
import { tmpdir } from "node:os";
|
||||||
import { join } from "node:path";
|
import { join } from "node:path";
|
||||||
import { createCasStore, createExtract, createWorkflow } from "@uncaged/workflow";
|
import { createCasStore, createExtract, createWorkflow, validateWorkflowDescriptor } from "@uncaged/workflow";
|
||||||
import {
|
import { END, type ModeratorContext, type RoleStep, START } from "@uncaged/workflow-runtime";
|
||||||
END,
|
|
||||||
type ModeratorContext,
|
|
||||||
type RoleStep,
|
|
||||||
START,
|
|
||||||
validateWorkflowDescriptor,
|
|
||||||
} from "@uncaged/workflow-runtime";
|
|
||||||
import { buildSolveIssueDescriptor } from "../src/descriptor.js";
|
import { buildSolveIssueDescriptor } from "../src/descriptor.js";
|
||||||
import type { DeveloperMeta } from "../src/developer.js";
|
import type { DeveloperMeta } from "../src/developer.js";
|
||||||
import { solveIssueModerator, solveIssueWorkflowDefinition } from "../src/index.js";
|
import { solveIssueModerator, solveIssueWorkflowDefinition } from "../src/index.js";
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ import { createWorkflow, readWorkflowRegistry, executeThread } from "@uncaged/wo
|
|||||||
| **Registry** | `readWorkflowRegistry`, `writeWorkflowRegistry`, `registerWorkflowVersion`, `workflowRegistryPath`, YAML helpers |
|
| **Registry** | `readWorkflowRegistry`, `writeWorkflowRegistry`, `registerWorkflowVersion`, `workflowRegistryPath`, YAML helpers |
|
||||||
| **CAS** | `createCasStore`, Merkle helpers (`putStepMerkleNode`, `getContentMerklePayload`, …), `hashWorkflowBundleBytes` |
|
| **CAS** | `createCasStore`, Merkle helpers (`putStepMerkleNode`, `getContentMerklePayload`, …), `hashWorkflowBundleBytes` |
|
||||||
| **Engine** | `createWorkflow`, `executeThread`, `parseThreadDataJsonl`, fork helpers, `garbageCollectCas` |
|
| **Engine** | `createWorkflow`, `executeThread`, `parseThreadDataJsonl`, fork helpers, `garbageCollectCas` |
|
||||||
| **Extract / LLM tools** | `llmExtract`, `reactExtract`, `createExtract`, `getExtractProvider` |
|
| **Extract / LLM tools** | `llmExtract`, `createExtract`, `createThreadReactor`, `createLlmFn`, `getExtractProvider` |
|
||||||
| **Agent bridge** | `workflowAsAgent` — expose a registered workflow as an agent-backed role |
|
| **Agent bridge** | `workflowAsAgent` — expose a registered workflow as an agent-backed role |
|
||||||
| **Utilities** | `createLogger`, ULID / Crockford Base32 codecs, `getDefaultWorkflowStorageRoot`, paths |
|
| **Utilities** | `createLogger`, ULID / Crockford Base32 codecs, `getDefaultWorkflowStorageRoot`, paths |
|
||||||
|
|
||||||
|
|||||||
+68
-17
@@ -6,7 +6,8 @@ import type { LlmProvider } from "@uncaged/workflow-runtime";
|
|||||||
import * as z from "zod/v4";
|
import * as z from "zod/v4";
|
||||||
import { createCasStore } from "../src/cas/cas.js";
|
import { createCasStore } from "../src/cas/cas.js";
|
||||||
import { createContentMerkleNode, serializeMerkleNode } from "../src/cas/merkle.js";
|
import { createContentMerkleNode, serializeMerkleNode } from "../src/cas/merkle.js";
|
||||||
import { reactExtract } from "../src/extract/react-extract.js";
|
import { extractFunctionToolFromZodSchema } from "../src/extract/llm-extract.js";
|
||||||
|
import { createLlmFn, createThreadReactor } from "../src/reactor/index.js";
|
||||||
|
|
||||||
const metaSchema = z.object({ seen: z.string() });
|
const metaSchema = z.object({ seen: z.string() });
|
||||||
|
|
||||||
@@ -16,7 +17,57 @@ const provider: LlmProvider = {
|
|||||||
model: "test",
|
model: "test",
|
||||||
};
|
};
|
||||||
|
|
||||||
describe("reactExtract", () => {
|
const CAS_GET_TOOL_DEFINITION = {
|
||||||
|
type: "function" as const,
|
||||||
|
function: {
|
||||||
|
name: "cas_get",
|
||||||
|
description: "Read CAS node",
|
||||||
|
parameters: {
|
||||||
|
type: "object",
|
||||||
|
properties: {
|
||||||
|
hash: { type: "string", description: "hash" },
|
||||||
|
},
|
||||||
|
required: ["hash"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
type ThreadCtx = { cas: ReturnType<typeof createCasStore> };
|
||||||
|
|
||||||
|
function createTestReactor() {
|
||||||
|
const llm = createLlmFn(provider);
|
||||||
|
return createThreadReactor<ThreadCtx>({
|
||||||
|
llm,
|
||||||
|
maxRounds: 10,
|
||||||
|
staticTools: [CAS_GET_TOOL_DEFINITION],
|
||||||
|
structuredToolFromSchema: (schema) => {
|
||||||
|
const t = extractFunctionToolFromZodSchema(schema);
|
||||||
|
return {
|
||||||
|
name: t.name,
|
||||||
|
tool: {
|
||||||
|
type: "function" as const,
|
||||||
|
function: {
|
||||||
|
name: t.name,
|
||||||
|
description: t.description,
|
||||||
|
parameters: t.parameters,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
},
|
||||||
|
systemPromptForStructuredTool: (structuredToolName) =>
|
||||||
|
`Extract metadata. Use cas_get when needed. Call ${structuredToolName} with JSON args matching the schema, or reply with plain JSON.`,
|
||||||
|
toolHandler: async (call, thread) => {
|
||||||
|
if (call.function.name !== "cas_get") {
|
||||||
|
return `unexpected tool ${call.function.name}`;
|
||||||
|
}
|
||||||
|
const ta = JSON.parse(call.function.arguments) as { hash: string };
|
||||||
|
const blob = await thread.cas.get(ta.hash);
|
||||||
|
return blob === null ? "null" : blob;
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
describe("createThreadReactor (extract-shaped)", () => {
|
||||||
let restoreFetch: (() => void) | null = null;
|
let restoreFetch: (() => void) | null = null;
|
||||||
|
|
||||||
afterEach(() => {
|
afterEach(() => {
|
||||||
@@ -25,7 +76,7 @@ describe("reactExtract", () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
test("cas_get rounds then extract tool yields validated meta", async () => {
|
test("cas_get rounds then extract tool yields validated meta", async () => {
|
||||||
const casDir = await mkdtemp(join(tmpdir(), "react-extract-"));
|
const casDir = await mkdtemp(join(tmpdir(), "thread-reactor-"));
|
||||||
const cas = createCasStore(casDir);
|
const cas = createCasStore(casDir);
|
||||||
try {
|
try {
|
||||||
const blob = serializeMerkleNode(createContentMerkleNode("needle"));
|
const blob = serializeMerkleNode(createContentMerkleNode("needle"));
|
||||||
@@ -87,12 +138,12 @@ describe("reactExtract", () => {
|
|||||||
{ preconnect: origFetch.preconnect.bind(origFetch) },
|
{ preconnect: origFetch.preconnect.bind(origFetch) },
|
||||||
) as typeof fetch;
|
) as typeof fetch;
|
||||||
|
|
||||||
|
const reactor = createTestReactor();
|
||||||
const text = `## Agent Output\n${h}\n## Extraction Instruction\nExtract seen from CAS.`;
|
const text = `## Agent Output\n${h}\n## Extraction Instruction\nExtract seen from CAS.`;
|
||||||
const result = await reactExtract({
|
const result = await reactor({
|
||||||
text,
|
thread: { cas },
|
||||||
|
input: text,
|
||||||
schema: metaSchema,
|
schema: metaSchema,
|
||||||
provider,
|
|
||||||
cas,
|
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(result.ok).toBe(true);
|
expect(result.ok).toBe(true);
|
||||||
@@ -107,7 +158,7 @@ describe("reactExtract", () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
test("stops after max tool rounds when model keeps calling cas_get", async () => {
|
test("stops after max tool rounds when model keeps calling cas_get", async () => {
|
||||||
const casDir = await mkdtemp(join(tmpdir(), "react-extract-max-"));
|
const casDir = await mkdtemp(join(tmpdir(), "thread-reactor-max-"));
|
||||||
const cas = createCasStore(casDir);
|
const cas = createCasStore(casDir);
|
||||||
try {
|
try {
|
||||||
const blob = serializeMerkleNode(createContentMerkleNode("x"));
|
const blob = serializeMerkleNode(createContentMerkleNode("x"));
|
||||||
@@ -146,11 +197,11 @@ describe("reactExtract", () => {
|
|||||||
{ preconnect: origFetch.preconnect.bind(origFetch) },
|
{ preconnect: origFetch.preconnect.bind(origFetch) },
|
||||||
) as typeof fetch;
|
) as typeof fetch;
|
||||||
|
|
||||||
const result = await reactExtract({
|
const reactor = createTestReactor();
|
||||||
text: "## Agent Output\nnoop\n## Extraction Instruction\nExtract seen.",
|
const result = await reactor({
|
||||||
|
thread: { cas },
|
||||||
|
input: "## Agent Output\nnoop\n## Extraction Instruction\nExtract seen.",
|
||||||
schema: metaSchema,
|
schema: metaSchema,
|
||||||
provider,
|
|
||||||
cas,
|
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(result.ok).toBe(false);
|
expect(result.ok).toBe(false);
|
||||||
@@ -165,7 +216,7 @@ describe("reactExtract", () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
test("passthrough JSON assistant message without tool calls", async () => {
|
test("passthrough JSON assistant message without tool calls", async () => {
|
||||||
const casDir = await mkdtemp(join(tmpdir(), "react-extract-pass-"));
|
const casDir = await mkdtemp(join(tmpdir(), "thread-reactor-pass-"));
|
||||||
const cas = createCasStore(casDir);
|
const cas = createCasStore(casDir);
|
||||||
try {
|
try {
|
||||||
const origFetch = globalThis.fetch;
|
const origFetch = globalThis.fetch;
|
||||||
@@ -189,11 +240,11 @@ describe("reactExtract", () => {
|
|||||||
{ preconnect: origFetch.preconnect.bind(origFetch) },
|
{ preconnect: origFetch.preconnect.bind(origFetch) },
|
||||||
) as typeof fetch;
|
) as typeof fetch;
|
||||||
|
|
||||||
const result = await reactExtract({
|
const reactor = createTestReactor();
|
||||||
text: "## Agent Output\nok\n## Extraction Instruction\nExtract.",
|
const result = await reactor({
|
||||||
|
thread: { cas },
|
||||||
|
input: "## Agent Output\nok\n## Extraction Instruction\nExtract.",
|
||||||
schema: metaSchema,
|
schema: metaSchema,
|
||||||
provider,
|
|
||||||
cas,
|
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(result.ok).toBe(true);
|
expect(result.ok).toBe(true);
|
||||||
@@ -1,12 +1,39 @@
|
|||||||
import type { ExtractContext, ExtractFn, LlmProvider } from "@uncaged/workflow-runtime";
|
import type { ExtractContext, ExtractFn, LlmProvider } from "@uncaged/workflow-runtime";
|
||||||
import type * as z from "zod/v4";
|
import type * as z from "zod/v4";
|
||||||
import { type CasStore, getContentMerklePayload } from "../cas/index.js";
|
import { type CasStore, getContentMerklePayload } from "../cas/index.js";
|
||||||
import { reactExtract } from "./react-extract.js";
|
import { createLlmFn, createThreadReactor } from "../reactor/index.js";
|
||||||
|
import { extractFunctionToolFromZodSchema } from "./llm-extract.js";
|
||||||
|
|
||||||
export type ExtractDeps = {
|
export type ExtractDeps = {
|
||||||
cas: CasStore;
|
cas: CasStore;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const MAX_REACT_ROUNDS = 10;
|
||||||
|
|
||||||
|
const CAS_GET_TOOL_DEFINITION = {
|
||||||
|
type: "function" as const,
|
||||||
|
function: {
|
||||||
|
name: "cas_get",
|
||||||
|
description:
|
||||||
|
"Read a Merkle DAG node from content-addressed storage by its hash. Returns YAML-formatted node with type, payload, and children fields.",
|
||||||
|
parameters: {
|
||||||
|
type: "object",
|
||||||
|
properties: {
|
||||||
|
hash: { type: "string", description: "The CAS hash to retrieve" },
|
||||||
|
},
|
||||||
|
required: ["hash"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
export type ExtractThreadContext = {
|
||||||
|
cas: CasStore;
|
||||||
|
};
|
||||||
|
|
||||||
|
function isRecord(value: unknown): value is Record<string, unknown> {
|
||||||
|
return typeof value === "object" && value !== null && !Array.isArray(value);
|
||||||
|
}
|
||||||
|
|
||||||
/** Builds the user-side extraction prompt (thread + agent output + instruction). */
|
/** Builds the user-side extraction prompt (thread + agent output + instruction). */
|
||||||
export async function buildExtractUserContent(
|
export async function buildExtractUserContent(
|
||||||
ctx: ExtractContext,
|
ctx: ExtractContext,
|
||||||
@@ -46,17 +73,61 @@ export async function buildExtractUserContent(
|
|||||||
* Create an ExtractFn backed by an LLM provider.
|
* Create an ExtractFn backed by an LLM provider.
|
||||||
*
|
*
|
||||||
* Internally runs a multi-turn ReAct loop with two tools (`cas_get` for traversing the
|
* Internally runs a multi-turn ReAct loop with two tools (`cas_get` for traversing the
|
||||||
* Merkle DAG and a schema-shaped `extract` tool); the loop also accepts a plain-JSON
|
* Merkle DAG and a schema-shaped extract tool); the loop also accepts a plain-JSON
|
||||||
* assistant reply as a short-circuit, which covers the legacy "single" extraction path.
|
* assistant reply as a short-circuit, which covers the legacy "single" extraction path.
|
||||||
*/
|
*/
|
||||||
export function createExtract(provider: LlmProvider, deps: ExtractDeps): ExtractFn {
|
export function createExtract(provider: LlmProvider, deps: ExtractDeps): ExtractFn {
|
||||||
|
const llm = createLlmFn(provider);
|
||||||
|
const reactor = createThreadReactor<ExtractThreadContext>({
|
||||||
|
llm,
|
||||||
|
maxRounds: MAX_REACT_ROUNDS,
|
||||||
|
staticTools: [CAS_GET_TOOL_DEFINITION],
|
||||||
|
structuredToolFromSchema: (schema) => {
|
||||||
|
const t = extractFunctionToolFromZodSchema(schema);
|
||||||
|
return {
|
||||||
|
name: t.name,
|
||||||
|
tool: {
|
||||||
|
type: "function" as const,
|
||||||
|
function: {
|
||||||
|
name: t.name,
|
||||||
|
description: t.description,
|
||||||
|
parameters: t.parameters,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
},
|
||||||
|
systemPromptForStructuredTool: (structuredToolName) =>
|
||||||
|
`You extract structured metadata from the agent output below. Use cas_get to read Merkle DAG nodes from CAS (YAML: type, payload, children) when the agent output references hashes you must traverse. When you have the complete structured object, call the ${structuredToolName} tool with JSON arguments matching the schema. You may instead reply with only a JSON object (no prose) when no tools are needed.`,
|
||||||
|
toolHandler: async (call, thread) => {
|
||||||
|
if (call.function.name !== "cas_get") {
|
||||||
|
return `Unexpected tool routed to handler: ${call.function.name}`;
|
||||||
|
}
|
||||||
|
let hash: string;
|
||||||
|
try {
|
||||||
|
const ta = JSON.parse(call.function.arguments) as unknown;
|
||||||
|
if (!isRecord(ta) || typeof ta.hash !== "string") {
|
||||||
|
return 'cas_get requires a JSON object with a string "hash" field.';
|
||||||
|
}
|
||||||
|
hash = ta.hash;
|
||||||
|
} catch {
|
||||||
|
return 'cas_get arguments were not valid JSON. Provide {"hash": "<cas-hash>"}.';
|
||||||
|
}
|
||||||
|
const blob = await thread.cas.get(hash);
|
||||||
|
return blob === null ? "null" : blob;
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
return async <T extends Record<string, unknown>>(
|
return async <T extends Record<string, unknown>>(
|
||||||
schema: z.ZodType<T>,
|
schema: z.ZodType<T>,
|
||||||
prompt: string,
|
prompt: string,
|
||||||
ctx: ExtractContext,
|
ctx: ExtractContext,
|
||||||
): Promise<T> => {
|
): Promise<T> => {
|
||||||
const text = await buildExtractUserContent(ctx, prompt, deps);
|
const text = await buildExtractUserContent(ctx, prompt, deps);
|
||||||
const result = await reactExtract({ text, schema, provider, cas: deps.cas });
|
const result = await reactor({
|
||||||
|
thread: { cas: deps.cas },
|
||||||
|
input: text,
|
||||||
|
schema,
|
||||||
|
});
|
||||||
if (!result.ok) {
|
if (!result.ok) {
|
||||||
throw new Error(`extract failed: ${result.error}`);
|
throw new Error(`extract failed: ${result.error}`);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,16 +1,11 @@
|
|||||||
export {
|
export {
|
||||||
buildExtractUserContent,
|
buildExtractUserContent,
|
||||||
createExtract,
|
createExtract,
|
||||||
|
type ExtractThreadContext,
|
||||||
} from "./extract-fn.js";
|
} from "./extract-fn.js";
|
||||||
export {
|
export {
|
||||||
extractFunctionToolFromZodSchema,
|
extractFunctionToolFromZodSchema,
|
||||||
llmErrorToCause,
|
llmErrorToCause,
|
||||||
llmExtract,
|
llmExtract,
|
||||||
} from "./llm-extract.js";
|
} from "./llm-extract.js";
|
||||||
export { reactExtract } from "./react-extract.js";
|
export type { ExtractFn, LlmError, LlmExtractArgs } from "./types.js";
|
||||||
export type {
|
|
||||||
ExtractFn,
|
|
||||||
LlmError,
|
|
||||||
LlmExtractArgs,
|
|
||||||
ReactExtractArgs,
|
|
||||||
} from "./types.js";
|
|
||||||
|
|||||||
@@ -1,343 +0,0 @@
|
|||||||
import type { CasStore, LlmProvider } from "@uncaged/workflow-runtime";
|
|
||||||
import type * as z from "zod/v4";
|
|
||||||
import { err, ok, type Result } from "../util/index.js";
|
|
||||||
|
|
||||||
import { extractFunctionToolFromZodSchema } from "./llm-extract.js";
|
|
||||||
import type { ReactExtractArgs } from "./types.js";
|
|
||||||
|
|
||||||
const MAX_REACT_ROUNDS = 10;
|
|
||||||
|
|
||||||
const CAS_GET_TOOL_DEFINITION = {
|
|
||||||
type: "function" as const,
|
|
||||||
function: {
|
|
||||||
name: "cas_get",
|
|
||||||
description:
|
|
||||||
"Read a Merkle DAG node from content-addressed storage by its hash. Returns YAML-formatted node with type, payload, and children fields.",
|
|
||||||
parameters: {
|
|
||||||
type: "object",
|
|
||||||
properties: {
|
|
||||||
hash: { type: "string", description: "The CAS hash to retrieve" },
|
|
||||||
},
|
|
||||||
required: ["hash"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
function chatCompletionsUrl(baseUrl: string): string {
|
|
||||||
const trimmed = baseUrl.replace(/\/+$/, "");
|
|
||||||
return `${trimmed}/chat/completions`;
|
|
||||||
}
|
|
||||||
|
|
||||||
function isRecord(value: unknown): value is Record<string, unknown> {
|
|
||||||
return typeof value === "object" && value !== null && !Array.isArray(value);
|
|
||||||
}
|
|
||||||
|
|
||||||
function tryParseJsonContent(content: string): unknown | null {
|
|
||||||
const trimmed = content.trim();
|
|
||||||
const fenceMatch = /^```(?:json)?\s*([\s\S]*?)```$/m.exec(trimmed);
|
|
||||||
const payload = fenceMatch !== null ? fenceMatch[1].trim() : trimmed;
|
|
||||||
try {
|
|
||||||
return JSON.parse(payload) as unknown;
|
|
||||||
} catch {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type ToolCall = {
|
|
||||||
id: string;
|
|
||||||
type: "function";
|
|
||||||
function: { name: string; arguments: string };
|
|
||||||
};
|
|
||||||
|
|
||||||
type ChatMessage =
|
|
||||||
| { role: "system"; content: string }
|
|
||||||
| { role: "user"; content: string }
|
|
||||||
| {
|
|
||||||
role: "assistant";
|
|
||||||
content: string | null;
|
|
||||||
tool_calls: ToolCall[];
|
|
||||||
}
|
|
||||||
| { role: "assistant"; content: string }
|
|
||||||
| { role: "tool"; tool_call_id: string; content: string };
|
|
||||||
|
|
||||||
type AssistantTurn<T> =
|
|
||||||
| { kind: "plain_json"; value: T }
|
|
||||||
| { kind: "tool_calls"; calls: ToolCall[]; assistantContent: string | null };
|
|
||||||
|
|
||||||
function firstAssistantMessage(responseText: string): Result<Record<string, unknown>, string> {
|
|
||||||
let parsed: unknown;
|
|
||||||
try {
|
|
||||||
parsed = JSON.parse(responseText) as unknown;
|
|
||||||
} catch (cause) {
|
|
||||||
const message = cause instanceof Error ? cause.message : String(cause);
|
|
||||||
return err(`invalid_response_json:${message}`);
|
|
||||||
}
|
|
||||||
if (!isRecord(parsed)) {
|
|
||||||
return err("invalid_response_top_level");
|
|
||||||
}
|
|
||||||
const choices = parsed.choices;
|
|
||||||
if (!Array.isArray(choices) || choices.length === 0) {
|
|
||||||
return err("no_choices_in_response");
|
|
||||||
}
|
|
||||||
const firstChoice = choices[0];
|
|
||||||
if (!isRecord(firstChoice)) {
|
|
||||||
return err("invalid_choice");
|
|
||||||
}
|
|
||||||
const messageObj = firstChoice.message;
|
|
||||||
if (!isRecord(messageObj)) {
|
|
||||||
return err("invalid_message");
|
|
||||||
}
|
|
||||||
return ok(messageObj);
|
|
||||||
}
|
|
||||||
|
|
||||||
function normalizeToolCalls(toolCallsRaw: unknown[]): Result<ToolCall[], string> {
|
|
||||||
const toolCalls: ToolCall[] = [];
|
|
||||||
for (const tc of toolCallsRaw) {
|
|
||||||
if (!isRecord(tc)) {
|
|
||||||
return err("invalid_tool_call");
|
|
||||||
}
|
|
||||||
const id = tc.id;
|
|
||||||
const tcType = tc.type;
|
|
||||||
const fn = tc.function;
|
|
||||||
if (typeof id !== "string" || tcType !== "function" || !isRecord(fn)) {
|
|
||||||
return err("invalid_tool_call_shape");
|
|
||||||
}
|
|
||||||
const name = fn.name;
|
|
||||||
const argumentsStr = fn.arguments;
|
|
||||||
if (typeof name !== "string" || typeof argumentsStr !== "string") {
|
|
||||||
return err("invalid_tool_call_function");
|
|
||||||
}
|
|
||||||
toolCalls.push({ id, type: "function", function: { name, arguments: argumentsStr } });
|
|
||||||
}
|
|
||||||
return ok(toolCalls);
|
|
||||||
}
|
|
||||||
|
|
||||||
type AssistantTurnOrCorrection<T extends Record<string, unknown>> =
|
|
||||||
| AssistantTurn<T>
|
|
||||||
| { kind: "plain_json_invalid"; rawContent: string; correction: string };
|
|
||||||
|
|
||||||
function classifyAssistantTurn<T extends Record<string, unknown>>(
|
|
||||||
messageObj: Record<string, unknown>,
|
|
||||||
schema: z.ZodType<T>,
|
|
||||||
): Result<AssistantTurnOrCorrection<T>, string> {
|
|
||||||
const toolCallsRaw = messageObj.tool_calls;
|
|
||||||
if (!Array.isArray(toolCallsRaw) || toolCallsRaw.length === 0) {
|
|
||||||
const content = messageObj.content;
|
|
||||||
if (typeof content !== "string") {
|
|
||||||
return err("no_tool_calls_and_no_string_content");
|
|
||||||
}
|
|
||||||
const jsonParsed = tryParseJsonContent(content);
|
|
||||||
if (jsonParsed === null) {
|
|
||||||
return ok({
|
|
||||||
kind: "plain_json_invalid",
|
|
||||||
rawContent: content,
|
|
||||||
correction:
|
|
||||||
"Your previous reply was not valid JSON and contained no tool calls. Reply with a single JSON object that matches the schema, or call the extract tool with the structured arguments.",
|
|
||||||
});
|
|
||||||
}
|
|
||||||
const validated = schema.safeParse(jsonParsed);
|
|
||||||
if (!validated.success) {
|
|
||||||
return ok({
|
|
||||||
kind: "plain_json_invalid",
|
|
||||||
rawContent: content,
|
|
||||||
correction: `Your previous JSON reply did not satisfy the schema: ${validated.error.message}. Reply again with a JSON object that matches the schema, or call the extract tool with the structured arguments.`,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
return ok({ kind: "plain_json", value: validated.data });
|
|
||||||
}
|
|
||||||
const callsResult = normalizeToolCalls(toolCallsRaw);
|
|
||||||
if (!callsResult.ok) {
|
|
||||||
return err(callsResult.error);
|
|
||||||
}
|
|
||||||
const assistantContent = messageObj.content;
|
|
||||||
return ok({
|
|
||||||
kind: "tool_calls",
|
|
||||||
calls: callsResult.value,
|
|
||||||
assistantContent: typeof assistantContent === "string" ? assistantContent : null,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
async function appendCasGetToolResult(
|
|
||||||
tc: ToolCall,
|
|
||||||
cas: CasStore,
|
|
||||||
messages: ChatMessage[],
|
|
||||||
): Promise<Result<null, string>> {
|
|
||||||
let hash: string;
|
|
||||||
try {
|
|
||||||
const ta = JSON.parse(tc.function.arguments) as unknown;
|
|
||||||
if (!isRecord(ta) || typeof ta.hash !== "string") {
|
|
||||||
return err("cas_get_invalid_arguments");
|
|
||||||
}
|
|
||||||
hash = ta.hash;
|
|
||||||
} catch {
|
|
||||||
return err("cas_get_arguments_not_json");
|
|
||||||
}
|
|
||||||
const blob = await cas.get(hash);
|
|
||||||
const toolContent = blob === null ? "null" : blob;
|
|
||||||
messages.push({
|
|
||||||
role: "tool",
|
|
||||||
tool_call_id: tc.id,
|
|
||||||
content: toolContent,
|
|
||||||
});
|
|
||||||
return ok(null);
|
|
||||||
}
|
|
||||||
|
|
||||||
async function appendExtractToolResult<T extends Record<string, unknown>>(
|
|
||||||
tc: ToolCall,
|
|
||||||
schema: z.ZodType<T>,
|
|
||||||
messages: ChatMessage[],
|
|
||||||
): Promise<Result<T, string>> {
|
|
||||||
let parsedArgs: unknown;
|
|
||||||
try {
|
|
||||||
parsedArgs = JSON.parse(tc.function.arguments) as unknown;
|
|
||||||
} catch {
|
|
||||||
return err("extract_tool_arguments_not_json");
|
|
||||||
}
|
|
||||||
const validated = schema.safeParse(parsedArgs);
|
|
||||||
if (!validated.success) {
|
|
||||||
return err(`schema_validation_failed:${validated.error.message}`);
|
|
||||||
}
|
|
||||||
messages.push({
|
|
||||||
role: "tool",
|
|
||||||
tool_call_id: tc.id,
|
|
||||||
content: '{"ok":true}',
|
|
||||||
});
|
|
||||||
return ok(validated.data);
|
|
||||||
}
|
|
||||||
|
|
||||||
async function appendToolResults<T extends Record<string, unknown>>(
|
|
||||||
toolCalls: ToolCall[],
|
|
||||||
extractToolName: string,
|
|
||||||
schema: z.ZodType<T>,
|
|
||||||
cas: CasStore,
|
|
||||||
messages: ChatMessage[],
|
|
||||||
): Promise<Result<T | null, string>> {
|
|
||||||
let extracted: T | null = null;
|
|
||||||
for (const tc of toolCalls) {
|
|
||||||
if (tc.function.name === "cas_get") {
|
|
||||||
const casRes = await appendCasGetToolResult(tc, cas, messages);
|
|
||||||
if (!casRes.ok) {
|
|
||||||
return casRes;
|
|
||||||
}
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (tc.function.name === extractToolName) {
|
|
||||||
const exRes = await appendExtractToolResult(tc, schema, messages);
|
|
||||||
if (!exRes.ok) {
|
|
||||||
return exRes;
|
|
||||||
}
|
|
||||||
extracted = exRes.value;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
return err(`unknown_tool:${tc.function.name}`);
|
|
||||||
}
|
|
||||||
return ok(extracted);
|
|
||||||
}
|
|
||||||
|
|
||||||
async function postChatCompletion(
|
|
||||||
provider: LlmProvider,
|
|
||||||
messages: ChatMessage[],
|
|
||||||
tools: readonly Record<string, unknown>[],
|
|
||||||
): Promise<Result<string, string>> {
|
|
||||||
try {
|
|
||||||
const response = await fetch(chatCompletionsUrl(provider.baseUrl), {
|
|
||||||
method: "POST",
|
|
||||||
headers: {
|
|
||||||
Authorization: `Bearer ${provider.apiKey}`,
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
},
|
|
||||||
body: JSON.stringify({
|
|
||||||
model: provider.model,
|
|
||||||
messages,
|
|
||||||
tools,
|
|
||||||
tool_choice: "auto",
|
|
||||||
}),
|
|
||||||
});
|
|
||||||
const responseText = await response.text();
|
|
||||||
if (!response.ok) {
|
|
||||||
return err(`http_error:${String(response.status)}:${responseText.slice(0, 4000)}`);
|
|
||||||
}
|
|
||||||
return ok(responseText);
|
|
||||||
} catch (cause) {
|
|
||||||
const message = cause instanceof Error ? cause.message : String(cause);
|
|
||||||
return err(`network_error:${message}`);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Multi-turn ReAct extraction with `cas_get` plus a schema-shaped extract tool (OpenAI-compatible).
|
|
||||||
* Final meta comes from a successful extract tool call or from plain JSON in the assistant message.
|
|
||||||
*/
|
|
||||||
export async function reactExtract<T extends Record<string, unknown>>(
|
|
||||||
args: ReactExtractArgs<T>,
|
|
||||||
): Promise<Result<T, string>> {
|
|
||||||
const extractTool = extractFunctionToolFromZodSchema(args.schema);
|
|
||||||
const tools = [
|
|
||||||
CAS_GET_TOOL_DEFINITION,
|
|
||||||
{
|
|
||||||
type: "function" as const,
|
|
||||||
function: {
|
|
||||||
name: extractTool.name,
|
|
||||||
description: extractTool.description,
|
|
||||||
parameters: extractTool.parameters,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
];
|
|
||||||
|
|
||||||
const systemContent = `You extract structured metadata from the agent output below. Use cas_get to read Merkle DAG nodes from CAS (YAML: type, payload, children) when the agent output references hashes you must traverse. When you have the complete structured object, call the ${extractTool.name} tool with JSON arguments matching the schema. You may instead reply with only a JSON object (no prose) when no tools are needed.`;
|
|
||||||
|
|
||||||
const messages: ChatMessage[] = [
|
|
||||||
{ role: "system", content: systemContent },
|
|
||||||
{ role: "user", content: args.text },
|
|
||||||
];
|
|
||||||
|
|
||||||
for (let round = 0; round < MAX_REACT_ROUNDS; round++) {
|
|
||||||
const bodyResult = await postChatCompletion(args.provider, messages, tools);
|
|
||||||
if (!bodyResult.ok) {
|
|
||||||
return bodyResult;
|
|
||||||
}
|
|
||||||
|
|
||||||
const msgResult = firstAssistantMessage(bodyResult.value);
|
|
||||||
if (!msgResult.ok) {
|
|
||||||
return msgResult;
|
|
||||||
}
|
|
||||||
|
|
||||||
const classified = classifyAssistantTurn(msgResult.value, args.schema);
|
|
||||||
if (!classified.ok) {
|
|
||||||
return classified;
|
|
||||||
}
|
|
||||||
|
|
||||||
const turn = classified.value;
|
|
||||||
if (turn.kind === "plain_json") {
|
|
||||||
return ok(turn.value);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (turn.kind === "plain_json_invalid") {
|
|
||||||
messages.push({ role: "assistant", content: turn.rawContent });
|
|
||||||
messages.push({ role: "user", content: turn.correction });
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
messages.push({
|
|
||||||
role: "assistant",
|
|
||||||
content: turn.assistantContent,
|
|
||||||
tool_calls: turn.calls,
|
|
||||||
});
|
|
||||||
|
|
||||||
const toolsRound = await appendToolResults(
|
|
||||||
turn.calls,
|
|
||||||
extractTool.name,
|
|
||||||
args.schema,
|
|
||||||
args.cas,
|
|
||||||
messages,
|
|
||||||
);
|
|
||||||
if (!toolsRound.ok) {
|
|
||||||
return toolsRound;
|
|
||||||
}
|
|
||||||
if (toolsRound.value !== null) {
|
|
||||||
return ok(toolsRound.value);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return err("max_react_rounds_exceeded");
|
|
||||||
}
|
|
||||||
@@ -1,15 +1,8 @@
|
|||||||
import type { CasStore, LlmProvider } from "@uncaged/workflow-runtime";
|
import type { LlmProvider } from "@uncaged/workflow-runtime";
|
||||||
import type * as z from "zod/v4";
|
import type * as z from "zod/v4";
|
||||||
|
|
||||||
export type { ExtractFn } from "@uncaged/workflow-runtime";
|
export type { ExtractFn } from "@uncaged/workflow-runtime";
|
||||||
|
|
||||||
export type ReactExtractArgs<T extends Record<string, unknown>> = {
|
|
||||||
text: string;
|
|
||||||
schema: z.ZodType<T>;
|
|
||||||
provider: LlmProvider;
|
|
||||||
cas: CasStore;
|
|
||||||
};
|
|
||||||
|
|
||||||
export type LlmExtractArgs<T> = {
|
export type LlmExtractArgs<T> = {
|
||||||
text: string;
|
text: string;
|
||||||
schema: z.ZodType<T>;
|
schema: z.ZodType<T>;
|
||||||
|
|||||||
@@ -56,12 +56,23 @@ export {
|
|||||||
export {
|
export {
|
||||||
createExtract,
|
createExtract,
|
||||||
type ExtractFn,
|
type ExtractFn,
|
||||||
|
type ExtractThreadContext,
|
||||||
type LlmError,
|
type LlmError,
|
||||||
llmErrorToCause,
|
llmErrorToCause,
|
||||||
llmExtract,
|
llmExtract,
|
||||||
type ReactExtractArgs,
|
|
||||||
reactExtract,
|
|
||||||
} from "./extract/index.js";
|
} from "./extract/index.js";
|
||||||
|
export {
|
||||||
|
type ChatMessage,
|
||||||
|
createLlmFn,
|
||||||
|
createThreadReactor,
|
||||||
|
type LlmFn,
|
||||||
|
type StructuredToolSpec,
|
||||||
|
type ThreadReactorConfig,
|
||||||
|
type ThreadReactorFn,
|
||||||
|
type ThreadReactorInvokeArgs,
|
||||||
|
type ToolCall,
|
||||||
|
type ToolDefinition,
|
||||||
|
} from "./reactor/index.js";
|
||||||
export {
|
export {
|
||||||
getRegisteredWorkflow,
|
getRegisteredWorkflow,
|
||||||
listRegisteredWorkflowNames,
|
listRegisteredWorkflowNames,
|
||||||
|
|||||||
@@ -0,0 +1,12 @@
|
|||||||
|
export { createLlmFn } from "./llm-fn.js";
|
||||||
|
export { createThreadReactor } from "./thread-reactor.js";
|
||||||
|
export type {
|
||||||
|
ChatMessage,
|
||||||
|
LlmFn,
|
||||||
|
StructuredToolSpec,
|
||||||
|
ThreadReactorConfig,
|
||||||
|
ThreadReactorFn,
|
||||||
|
ThreadReactorInvokeArgs,
|
||||||
|
ToolCall,
|
||||||
|
ToolDefinition,
|
||||||
|
} from "./types.js";
|
||||||
@@ -0,0 +1,48 @@
|
|||||||
|
import type { LlmProvider } from "@uncaged/workflow-runtime";
|
||||||
|
|
||||||
|
import { err, ok } from "../util/index.js";
|
||||||
|
|
||||||
|
import type { ChatMessage, LlmFn, ToolDefinition } from "./types.js";
|
||||||
|
|
||||||
|
function chatCompletionsUrl(baseUrl: string): string {
|
||||||
|
const trimmed = baseUrl.replace(/\/+$/, "");
|
||||||
|
return `${trimmed}/chat/completions`;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Wraps provider credentials into an {@link LlmFn}: single POST to chat/completions,
|
||||||
|
* returns raw JSON body text or a {@link Result} error. Callers parse assistant messages.
|
||||||
|
*/
|
||||||
|
export function createLlmFn(provider: LlmProvider): LlmFn {
|
||||||
|
return async ({
|
||||||
|
messages,
|
||||||
|
tools,
|
||||||
|
}: {
|
||||||
|
messages: ChatMessage[];
|
||||||
|
tools: readonly ToolDefinition[];
|
||||||
|
}) => {
|
||||||
|
try {
|
||||||
|
const response = await fetch(chatCompletionsUrl(provider.baseUrl), {
|
||||||
|
method: "POST",
|
||||||
|
headers: {
|
||||||
|
Authorization: `Bearer ${provider.apiKey}`,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
model: provider.model,
|
||||||
|
messages,
|
||||||
|
tools,
|
||||||
|
tool_choice: "auto",
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
const responseText = await response.text();
|
||||||
|
if (!response.ok) {
|
||||||
|
return err(`http_error:${String(response.status)}:${responseText.slice(0, 4000)}`);
|
||||||
|
}
|
||||||
|
return ok(responseText);
|
||||||
|
} catch (cause) {
|
||||||
|
const message = cause instanceof Error ? cause.message : String(cause);
|
||||||
|
return err(`network_error:${message}`);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
@@ -0,0 +1,317 @@
|
|||||||
|
import type * as z from "zod/v4";
|
||||||
|
|
||||||
|
import { err, ok, type Result } from "../util/index.js";
|
||||||
|
|
||||||
|
import type {
|
||||||
|
ChatMessage,
|
||||||
|
StructuredToolSpec,
|
||||||
|
ThreadReactorConfig,
|
||||||
|
ThreadReactorFn,
|
||||||
|
ToolCall,
|
||||||
|
ToolDefinition,
|
||||||
|
} from "./types.js";
|
||||||
|
|
||||||
|
function isRecord(value: unknown): value is Record<string, unknown> {
|
||||||
|
return typeof value === "object" && value !== null && !Array.isArray(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
function tryParseJsonContent(content: string): unknown | null {
|
||||||
|
const trimmed = content.trim();
|
||||||
|
const fenceMatch = /^```(?:json)?\s*([\s\S]*?)```$/m.exec(trimmed);
|
||||||
|
const payload = fenceMatch !== null ? fenceMatch[1].trim() : trimmed;
|
||||||
|
try {
|
||||||
|
return JSON.parse(payload) as unknown;
|
||||||
|
} catch {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function firstAssistantMessage(responseText: string): Result<Record<string, unknown>, string> {
|
||||||
|
let parsed: unknown;
|
||||||
|
try {
|
||||||
|
parsed = JSON.parse(responseText) as unknown;
|
||||||
|
} catch (cause) {
|
||||||
|
const message = cause instanceof Error ? cause.message : String(cause);
|
||||||
|
return err(`invalid_response_json:${message}`);
|
||||||
|
}
|
||||||
|
if (!isRecord(parsed)) {
|
||||||
|
return err("invalid_response_top_level");
|
||||||
|
}
|
||||||
|
const choices = parsed.choices;
|
||||||
|
if (!Array.isArray(choices) || choices.length === 0) {
|
||||||
|
return err("no_choices_in_response");
|
||||||
|
}
|
||||||
|
const firstChoice = choices[0];
|
||||||
|
if (!isRecord(firstChoice)) {
|
||||||
|
return err("invalid_choice");
|
||||||
|
}
|
||||||
|
const messageObj = firstChoice.message;
|
||||||
|
if (!isRecord(messageObj)) {
|
||||||
|
return err("invalid_message");
|
||||||
|
}
|
||||||
|
return ok(messageObj);
|
||||||
|
}
|
||||||
|
|
||||||
|
function normalizeToolCalls(toolCallsRaw: unknown[]): Result<ToolCall[], string> {
|
||||||
|
const toolCalls: ToolCall[] = [];
|
||||||
|
for (const tc of toolCallsRaw) {
|
||||||
|
if (!isRecord(tc)) {
|
||||||
|
return err("invalid_tool_call");
|
||||||
|
}
|
||||||
|
const id = tc.id;
|
||||||
|
const tcType = tc.type;
|
||||||
|
const fn = tc.function;
|
||||||
|
if (typeof id !== "string" || tcType !== "function" || !isRecord(fn)) {
|
||||||
|
return err("invalid_tool_call_shape");
|
||||||
|
}
|
||||||
|
const name = fn.name;
|
||||||
|
const argumentsStr = fn.arguments;
|
||||||
|
if (typeof name !== "string" || typeof argumentsStr !== "string") {
|
||||||
|
return err("invalid_tool_call_function");
|
||||||
|
}
|
||||||
|
toolCalls.push({ id, type: "function", function: { name, arguments: argumentsStr } });
|
||||||
|
}
|
||||||
|
return ok(toolCalls);
|
||||||
|
}
|
||||||
|
|
||||||
|
type AssistantTurn<T> =
|
||||||
|
| { kind: "plain_json"; value: T }
|
||||||
|
| { kind: "tool_calls"; calls: ToolCall[]; assistantContent: string | null };
|
||||||
|
|
||||||
|
type AssistantTurnOrCorrection<T> =
|
||||||
|
| AssistantTurn<T>
|
||||||
|
| { kind: "plain_json_invalid"; rawContent: string; correction: string };
|
||||||
|
|
||||||
|
function classifyAssistantTurn<T>(
|
||||||
|
messageObj: Record<string, unknown>,
|
||||||
|
schema: z.ZodType<T>,
|
||||||
|
structuredToolName: string,
|
||||||
|
): Result<AssistantTurnOrCorrection<T>, string> {
|
||||||
|
const toolCallsRaw = messageObj.tool_calls;
|
||||||
|
if (!Array.isArray(toolCallsRaw) || toolCallsRaw.length === 0) {
|
||||||
|
const content = messageObj.content;
|
||||||
|
if (typeof content !== "string") {
|
||||||
|
return err("no_tool_calls_and_no_string_content");
|
||||||
|
}
|
||||||
|
const jsonParsed = tryParseJsonContent(content);
|
||||||
|
if (jsonParsed === null) {
|
||||||
|
return ok({
|
||||||
|
kind: "plain_json_invalid",
|
||||||
|
rawContent: content,
|
||||||
|
correction: `Your previous reply was not valid JSON and contained no tool calls. Reply with a single JSON object that matches the schema, or call the ${structuredToolName} tool with the structured arguments.`,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
const validated = schema.safeParse(jsonParsed);
|
||||||
|
if (!validated.success) {
|
||||||
|
return ok({
|
||||||
|
kind: "plain_json_invalid",
|
||||||
|
rawContent: content,
|
||||||
|
correction: `Your previous JSON reply did not satisfy the schema: ${validated.error.message}. Reply again with a JSON object that matches the schema, or call the ${structuredToolName} tool with the structured arguments.`,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
return ok({ kind: "plain_json", value: validated.data });
|
||||||
|
}
|
||||||
|
const callsResult = normalizeToolCalls(toolCallsRaw);
|
||||||
|
if (!callsResult.ok) {
|
||||||
|
return err(callsResult.error);
|
||||||
|
}
|
||||||
|
const assistantContent = messageObj.content;
|
||||||
|
return ok({
|
||||||
|
kind: "tool_calls",
|
||||||
|
calls: callsResult.value,
|
||||||
|
assistantContent: typeof assistantContent === "string" ? assistantContent : null,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
function toolNamesFromDefinitions(tools: readonly { function: { name: string } }[]): Set<string> {
|
||||||
|
return new Set(tools.map((t) => t.function.name));
|
||||||
|
}
|
||||||
|
|
||||||
|
function appendStructuredToolResult<T>(
|
||||||
|
tc: ToolCall,
|
||||||
|
schema: z.ZodType<T>,
|
||||||
|
messages: ChatMessage[],
|
||||||
|
): T | null {
|
||||||
|
let parsedArgs: unknown;
|
||||||
|
try {
|
||||||
|
parsedArgs = JSON.parse(tc.function.arguments) as unknown;
|
||||||
|
} catch {
|
||||||
|
messages.push({
|
||||||
|
role: "tool",
|
||||||
|
tool_call_id: tc.id,
|
||||||
|
content:
|
||||||
|
"Tool arguments were not valid JSON. Provide valid JSON object arguments matching the schema.",
|
||||||
|
});
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
const validated = schema.safeParse(parsedArgs);
|
||||||
|
if (!validated.success) {
|
||||||
|
messages.push({
|
||||||
|
role: "tool",
|
||||||
|
tool_call_id: tc.id,
|
||||||
|
content: `Schema validation failed: ${validated.error.message}. Fix the arguments and call the tool again with a JSON object that matches the schema.`,
|
||||||
|
});
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
messages.push({
|
||||||
|
role: "tool",
|
||||||
|
tool_call_id: tc.id,
|
||||||
|
content: '{"ok":true}',
|
||||||
|
});
|
||||||
|
return validated.data;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function dispatchToolCall<T, TThread>(
|
||||||
|
tc: ToolCall,
|
||||||
|
spec: StructuredToolSpec,
|
||||||
|
knownNames: Set<string>,
|
||||||
|
schema: z.ZodType<T>,
|
||||||
|
thread: TThread,
|
||||||
|
toolHandler: ThreadReactorConfig<TThread>["toolHandler"],
|
||||||
|
messages: ChatMessage[],
|
||||||
|
): Promise<T | null> {
|
||||||
|
if (!knownNames.has(tc.function.name)) {
|
||||||
|
messages.push({
|
||||||
|
role: "tool",
|
||||||
|
tool_call_id: tc.id,
|
||||||
|
content: `Unknown tool: ${tc.function.name}. Use one of the declared tools only.`,
|
||||||
|
});
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
if (tc.function.name === spec.name) {
|
||||||
|
return appendStructuredToolResult(tc, schema, messages);
|
||||||
|
}
|
||||||
|
let toolContent: string;
|
||||||
|
try {
|
||||||
|
toolContent = await toolHandler(tc, thread);
|
||||||
|
} catch (cause) {
|
||||||
|
const message = cause instanceof Error ? cause.message : String(cause);
|
||||||
|
toolContent = `Tool execution failed: ${message}`;
|
||||||
|
}
|
||||||
|
messages.push({
|
||||||
|
role: "tool",
|
||||||
|
tool_call_id: tc.id,
|
||||||
|
content: toolContent,
|
||||||
|
});
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function resolveToolCallRound<T, TThread>(
|
||||||
|
turn: Extract<AssistantTurn<T>, { kind: "tool_calls" }>,
|
||||||
|
spec: StructuredToolSpec,
|
||||||
|
knownNames: Set<string>,
|
||||||
|
schema: z.ZodType<T>,
|
||||||
|
thread: TThread,
|
||||||
|
toolHandler: ThreadReactorConfig<TThread>["toolHandler"],
|
||||||
|
messages: ChatMessage[],
|
||||||
|
): Promise<Result<T, string> | null> {
|
||||||
|
messages.push({
|
||||||
|
role: "assistant",
|
||||||
|
content: turn.assistantContent,
|
||||||
|
tool_calls: turn.calls,
|
||||||
|
});
|
||||||
|
let extractedRound: T | null = null;
|
||||||
|
for (const tc of turn.calls) {
|
||||||
|
const extracted = await dispatchToolCall(
|
||||||
|
tc,
|
||||||
|
spec,
|
||||||
|
knownNames,
|
||||||
|
schema,
|
||||||
|
thread,
|
||||||
|
toolHandler,
|
||||||
|
messages,
|
||||||
|
);
|
||||||
|
if (extracted !== null) {
|
||||||
|
extractedRound = extracted;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (extractedRound !== null) {
|
||||||
|
return ok(extractedRound);
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function runOneReactRound<T, TThread>(
|
||||||
|
config: ThreadReactorConfig<TThread>,
|
||||||
|
args: { thread: TThread; schema: z.ZodType<T> },
|
||||||
|
tools: readonly ToolDefinition[],
|
||||||
|
knownNames: Set<string>,
|
||||||
|
spec: StructuredToolSpec,
|
||||||
|
messages: ChatMessage[],
|
||||||
|
): Promise<Result<T, string> | null> {
|
||||||
|
const bodyResult = await config.llm({ messages, tools });
|
||||||
|
if (!bodyResult.ok) {
|
||||||
|
return bodyResult;
|
||||||
|
}
|
||||||
|
|
||||||
|
const msgResult = firstAssistantMessage(bodyResult.value);
|
||||||
|
if (!msgResult.ok) {
|
||||||
|
return msgResult;
|
||||||
|
}
|
||||||
|
|
||||||
|
const classified = classifyAssistantTurn(msgResult.value, args.schema, spec.name);
|
||||||
|
if (!classified.ok) {
|
||||||
|
return classified;
|
||||||
|
}
|
||||||
|
|
||||||
|
const turn = classified.value;
|
||||||
|
if (turn.kind === "plain_json") {
|
||||||
|
return ok(turn.value);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (turn.kind === "plain_json_invalid") {
|
||||||
|
messages.push({ role: "assistant", content: turn.rawContent });
|
||||||
|
messages.push({ role: "user", content: turn.correction });
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
return resolveToolCallRound(
|
||||||
|
turn,
|
||||||
|
spec,
|
||||||
|
knownNames,
|
||||||
|
args.schema,
|
||||||
|
args.thread,
|
||||||
|
config.toolHandler,
|
||||||
|
messages,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generic ReAct loop: LLM round-trips with tools until structured output validates,
|
||||||
|
* plain JSON matches schema, or {@link ThreadReactorConfig.maxRounds} is exceeded.
|
||||||
|
*/
|
||||||
|
export function createThreadReactor<TThread>(
|
||||||
|
config: ThreadReactorConfig<TThread>,
|
||||||
|
): ThreadReactorFn<TThread> {
|
||||||
|
return async <T>(args: {
|
||||||
|
thread: TThread;
|
||||||
|
input: string;
|
||||||
|
schema: z.ZodType<T>;
|
||||||
|
}): Promise<Result<T, string>> => {
|
||||||
|
const spec = config.structuredToolFromSchema(args.schema);
|
||||||
|
const tools = [...config.staticTools, spec.tool];
|
||||||
|
const knownNames = toolNamesFromDefinitions(tools);
|
||||||
|
const systemPrompt = config.systemPromptForStructuredTool(spec.name);
|
||||||
|
|
||||||
|
const messages: ChatMessage[] = [
|
||||||
|
{ role: "system", content: systemPrompt },
|
||||||
|
{ role: "user", content: args.input },
|
||||||
|
];
|
||||||
|
|
||||||
|
for (let round = 0; round < config.maxRounds; round++) {
|
||||||
|
const step = await runOneReactRound(
|
||||||
|
config,
|
||||||
|
{ thread: args.thread, schema: args.schema },
|
||||||
|
tools,
|
||||||
|
knownNames,
|
||||||
|
spec,
|
||||||
|
messages,
|
||||||
|
);
|
||||||
|
if (step !== null) {
|
||||||
|
return step;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return err("max_react_rounds_exceeded");
|
||||||
|
};
|
||||||
|
}
|
||||||
@@ -0,0 +1,62 @@
|
|||||||
|
import type * as z from "zod/v4";
|
||||||
|
|
||||||
|
import type { Result } from "../util/index.js";
|
||||||
|
|
||||||
|
export type ToolCall = {
|
||||||
|
id: string;
|
||||||
|
type: "function";
|
||||||
|
function: { name: string; arguments: string };
|
||||||
|
};
|
||||||
|
|
||||||
|
export type ToolDefinition = {
|
||||||
|
type: "function";
|
||||||
|
function: {
|
||||||
|
name: string;
|
||||||
|
description: string;
|
||||||
|
parameters: Record<string, unknown>;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
export type ChatMessage =
|
||||||
|
| { role: "system"; content: string }
|
||||||
|
| { role: "user"; content: string }
|
||||||
|
| {
|
||||||
|
role: "assistant";
|
||||||
|
content: string | null;
|
||||||
|
tool_calls: ToolCall[];
|
||||||
|
}
|
||||||
|
| { role: "assistant"; content: string }
|
||||||
|
| { role: "tool"; tool_call_id: string; content: string };
|
||||||
|
|
||||||
|
export type LlmFn = (input: {
|
||||||
|
messages: ChatMessage[];
|
||||||
|
tools: readonly ToolDefinition[];
|
||||||
|
}) => Promise<Result<string, string>>;
|
||||||
|
|
||||||
|
/** Structured tool derived from the per-invocation Zod schema (e.g. extract tool). */
|
||||||
|
export type StructuredToolSpec = {
|
||||||
|
name: string;
|
||||||
|
tool: ToolDefinition;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type ThreadReactorConfig<TThread> = {
|
||||||
|
llm: LlmFn;
|
||||||
|
/** Static tools (e.g. cas_get); structured tool is appended per invocation. */
|
||||||
|
staticTools: readonly ToolDefinition[];
|
||||||
|
/** Builds the schema-shaped tool and its OpenAI name for this invocation. */
|
||||||
|
structuredToolFromSchema: (schema: z.ZodType<unknown>) => StructuredToolSpec;
|
||||||
|
/** System prompt for this run; include the structured tool name for cache stability per schema. */
|
||||||
|
systemPromptForStructuredTool: (structuredToolName: string) => string;
|
||||||
|
toolHandler: (call: ToolCall, thread: TThread) => Promise<string>;
|
||||||
|
maxRounds: number;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type ThreadReactorInvokeArgs<TThread, T> = {
|
||||||
|
thread: TThread;
|
||||||
|
input: string;
|
||||||
|
schema: z.ZodType<T>;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type ThreadReactorFn<TThread> = <T>(
|
||||||
|
args: ThreadReactorInvokeArgs<TThread, T>,
|
||||||
|
) => Promise<Result<T, string>>;
|
||||||
Reference in New Issue
Block a user