From 43a6600378407096f6dcce5fb815d1eae49ae9ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B0=8F=E6=A9=98?= Date: Thu, 7 May 2026 13:28:00 +0000 Subject: [PATCH] feat: ReAct ExtractFn with tool-use - RoleDefinition.extractMode: "single" | "react" - reactExtract: multi-turn LLM with cas_get tool for DAG traversal - Max 10 tool-call rounds, schema validation on final output - create-workflow routes to reactExtract when extractMode is "react" - All existing roles set to "single" (no behavior change) - 162 tests passing Fixes #44 --- examples/hello-world.ts | 2 + packages/workflow-role-coder/src/coder.ts | 1 + .../workflow-role-committer/src/committer.ts | 1 + packages/workflow-role-planner/src/planner.ts | 1 + .../workflow-role-preparer/src/preparer.ts | 1 + .../workflow-role-reviewer/src/reviewer.ts | 1 + .../__tests__/solve-issue-template.test.ts | 3 +- .../src/index.ts | 9 +- .../__tests__/build-descriptor.test.ts | 1 + packages/workflow/__tests__/engine.test.ts | 170 ++++++++- .../workflow/__tests__/react-extract.test.ts | 209 +++++++++++ .../workflow/__tests__/refs-tracking.test.ts | 2 + .../workflow-as-agent-integration.test.ts | 2 + packages/workflow/src/create-workflow.ts | 54 ++- packages/workflow/src/extract-fn.ts | 62 ++-- packages/workflow/src/index.ts | 2 + packages/workflow/src/llm-extract.ts | 28 +- packages/workflow/src/react-extract.ts | 330 ++++++++++++++++++ packages/workflow/src/types.ts | 4 + 19 files changed, 838 insertions(+), 45 deletions(-) create mode 100644 packages/workflow/__tests__/react-extract.test.ts create mode 100644 packages/workflow/src/react-extract.ts diff --git a/examples/hello-world.ts b/examples/hello-world.ts index 590b963..5991056 100644 --- a/examples/hello-world.ts +++ b/examples/hello-world.ts @@ -29,6 +29,7 @@ const greeter: RoleDefinition = { extractPrompt: "Extract the greeting string produced for the user.", schema: greeterMetaSchema, extractRefs: null, + extractMode: "single", }; const extract = createExtract({ @@ 
-48,4 +49,5 @@ export const run = createWorkflow( agent: async (ctx) => `Hello, ${ctx.start.content}`, }, extract, + null, ); diff --git a/packages/workflow-role-coder/src/coder.ts b/packages/workflow-role-coder/src/coder.ts index b5f6bac..4a40897 100644 --- a/packages/workflow-role-coder/src/coder.ts +++ b/packages/workflow-role-coder/src/coder.ts @@ -39,4 +39,5 @@ export const coderRole: RoleDefinition = { "Extract completedPhase: the planner phase hash finished this round (exact hash string from the plan). If multiple phases were finished in one round, use the last finished phase hash. Extract filesChanged and a summary of the work.", schema: coderMetaSchema, extractRefs: (meta) => [meta.completedPhase], + extractMode: "single", }; diff --git a/packages/workflow-role-committer/src/committer.ts b/packages/workflow-role-committer/src/committer.ts index c1d5562..061fa9e 100644 --- a/packages/workflow-role-committer/src/committer.ts +++ b/packages/workflow-role-committer/src/committer.ts @@ -32,4 +32,5 @@ export const committerRole: RoleDefinition = { "Extract the commit result: committed (with branch and SHA), recoverable failure, or unrecoverable failure. Include error details and log references if applicable.", schema: committerMetaSchema, extractRefs: null, + extractMode: "single", }; diff --git a/packages/workflow-role-planner/src/planner.ts b/packages/workflow-role-planner/src/planner.ts index eb5101b..f77faf1 100644 --- a/packages/workflow-role-planner/src/planner.ts +++ b/packages/workflow-role-planner/src/planner.ts @@ -50,4 +50,5 @@ export const plannerRole: RoleDefinition = { "Extract the implementation phases from the agent's output. 
Each phase has a hash (the CAS content-hash returned by the cas put command) and a title (one-line summary).", schema: plannerMetaSchema, extractRefs: (meta) => meta.phases.map((p) => p.hash), + extractMode: "single", }; diff --git a/packages/workflow-role-preparer/src/preparer.ts b/packages/workflow-role-preparer/src/preparer.ts index d79f45e..5b4a1fe 100644 --- a/packages/workflow-role-preparer/src/preparer.ts +++ b/packages/workflow-role-preparer/src/preparer.ts @@ -48,4 +48,5 @@ export const preparerRole: RoleDefinition = { "Extract repoPath (absolute path), defaultBranch, conventions (summary string or null), and toolchain (packageManager, testCommand, lintCommand, buildCommand — each string or null).", schema: preparerMetaSchema, extractRefs: null, + extractMode: "single", }; diff --git a/packages/workflow-role-reviewer/src/reviewer.ts b/packages/workflow-role-reviewer/src/reviewer.ts index 695d771..e8513d1 100644 --- a/packages/workflow-role-reviewer/src/reviewer.ts +++ b/packages/workflow-role-reviewer/src/reviewer.ts @@ -22,4 +22,5 @@ export const reviewerRole: RoleDefinition = { "Extract the review verdict: approved or rejected. 
If rejected, list the blocking issues.", schema: reviewerMetaSchema, extractRefs: null, + extractMode: "single", }; diff --git a/packages/workflow-template-solve-issue/__tests__/solve-issue-template.test.ts b/packages/workflow-template-solve-issue/__tests__/solve-issue-template.test.ts index 492771e..c85aa30 100644 --- a/packages/workflow-template-solve-issue/__tests__/solve-issue-template.test.ts +++ b/packages/workflow-template-solve-issue/__tests__/solve-issue-template.test.ts @@ -313,7 +313,7 @@ describe("createSolveIssueRun", () => { casDir = await mkdtemp(join(tmpdir(), "solve-issue-cas-")); const cas = createCasStore(casDir); - const run = createSolveIssueRun({ agent: async () => "" }, stubExtract); + const run = createSolveIssueRun({ agent: async () => "" }, stubExtract, null); const gen = run( { prompt: "task", steps: [] }, { threadId: "01TEST000000000000000000TR", maxRounds: 20, depth: 0, cas }, @@ -374,6 +374,7 @@ describe("createSolveIssueRun", () => { }, }, stubExtract, + null, ); const gen = run( { prompt: "task", steps: [] }, diff --git a/packages/workflow-template-solve-issue/src/index.ts b/packages/workflow-template-solve-issue/src/index.ts index 640cd18..317414d 100644 --- a/packages/workflow-template-solve-issue/src/index.ts +++ b/packages/workflow-template-solve-issue/src/index.ts @@ -2,6 +2,7 @@ import { type AgentBinding, createWorkflow, type ExtractFn, + type LlmProvider, type WorkflowDefinition, type WorkflowFn, } from "@uncaged/workflow"; @@ -50,6 +51,10 @@ export const solveIssueWorkflowDefinition: WorkflowDefinition = moderator: solveIssueModerator, }; -export function createSolveIssueRun(binding: AgentBinding, extract: ExtractFn): WorkflowFn { - return createWorkflow(solveIssueWorkflowDefinition, binding, extract); +export function createSolveIssueRun( + binding: AgentBinding, + extract: ExtractFn, + llmProvider: LlmProvider | null, +): WorkflowFn { + return createWorkflow(solveIssueWorkflowDefinition, binding, extract, llmProvider); } 
diff --git a/packages/workflow/__tests__/build-descriptor.test.ts b/packages/workflow/__tests__/build-descriptor.test.ts index 40b6244..46fd71a 100644 --- a/packages/workflow/__tests__/build-descriptor.test.ts +++ b/packages/workflow/__tests__/build-descriptor.test.ts @@ -23,6 +23,7 @@ describe("buildDescriptor", () => { extractPrompt: "Extract title and count from the analysis.", schema, extractRefs: null, + extractMode: "single", }, }, moderator: () => END, diff --git a/packages/workflow/__tests__/engine.test.ts b/packages/workflow/__tests__/engine.test.ts index cf16281..5e4de10 100644 --- a/packages/workflow/__tests__/engine.test.ts +++ b/packages/workflow/__tests__/engine.test.ts @@ -15,7 +15,7 @@ import { parseMerkleNode, serializeMerkleNode, } from "../src/merkle.js"; -import { END } from "../src/types.js"; +import { END, type LlmProvider } from "../src/types.js"; const plannerMetaSchema = z.object({ plan: z.string(), @@ -97,6 +97,7 @@ const demoWorkflow = createWorkflow( extractPrompt: "Extract plan text and affected files list.", schema: plannerMetaSchema, extractRefs: null, + extractMode: "single", }, coder: { description: "Demo coder", @@ -104,6 +105,7 @@ const demoWorkflow = createWorkflow( extractPrompt: "Extract the code diff summary.", schema: coderMetaSchema, extractRefs: null, + extractMode: "single", }, }, moderator: (ctx) => { @@ -124,6 +126,7 @@ const demoWorkflow = createWorkflow( }, }, demoExtract, + null, ); describe("executeThread", () => { @@ -445,4 +448,169 @@ describe("executeThread", () => { await rm(root, { recursive: true, force: true }); } }); + + test("extractMode react traverses CAS DAG via cas_get during extraction", async () => { + const dagMetaSchema = z.object({ leafPayload: z.string() }); + type DagDemoMeta = { walker: z.infer }; + + const origFetch = globalThis.fetch; + restoreFetch = () => { + globalThis.fetch = origFetch; + }; + let fetchRound = 0; + + const root = await mkdtemp(join(tmpdir(), "wf-engine-react-")); + try { + 
const cas = createCasStore(join(root, "cas")); + const leafYaml = serializeMerkleNode(createContentMerkleNode("needle-from-leaf")); + const leafHash = await cas.put(leafYaml); + const rootYaml = serializeMerkleNode({ + type: "thread", + payload: { + workflow: "dag-demo", + threadId: "01DAG00000000000000000001", + result: { returnCode: 0, summary: "" }, + }, + children: [leafHash], + }); + const dagRootHash = await cas.put(rootYaml); + + globalThis.fetch = Object.assign( + async (_input: Parameters[0], _init?: RequestInit) => { + fetchRound += 1; + if (fetchRound === 1) { + return new Response( + JSON.stringify({ + choices: [ + { + message: { + tool_calls: [ + { + id: "c1", + type: "function", + function: { + name: "cas_get", + arguments: JSON.stringify({ hash: dagRootHash }), + }, + }, + ], + }, + }, + ], + }), + { status: 200, headers: { "Content-Type": "application/json" } }, + ); + } + if (fetchRound === 2) { + return new Response( + JSON.stringify({ + choices: [ + { + message: { + tool_calls: [ + { + id: "c2", + type: "function", + function: { + name: "cas_get", + arguments: JSON.stringify({ hash: leafHash }), + }, + }, + ], + }, + }, + ], + }), + { status: 200, headers: { "Content-Type": "application/json" } }, + ); + } + return new Response( + JSON.stringify({ + choices: [ + { + message: { + tool_calls: [ + { + id: "c3", + type: "function", + function: { + name: "extract", + arguments: JSON.stringify({ leafPayload: "needle-from-leaf" }), + }, + }, + ], + }, + }, + ], + }), + { status: 200, headers: { "Content-Type": "application/json" } }, + ); + }, + { preconnect: origFetch.preconnect.bind(origFetch) }, + ) as typeof fetch; + + const llm: LlmProvider = { baseUrl: "http://127.0.0.1:9", apiKey: "test", model: "test" }; + const extractFn = createExtract(llm); + + const dagWorkflow = createWorkflow( + { + roles: { + walker: { + description: "DAG walker", + systemPrompt: "Output only the root CAS hash.", + extractPrompt: + "Set leafPayload to the string payload 
of the content Merkle node under the root.", + schema: dagMetaSchema, + extractRefs: null, + extractMode: "react", + }, + }, + moderator: (ctx) => (ctx.steps.length === 0 ? "walker" : END), + }, + { agent: async () => dagRootHash }, + extractFn, + llm, + ); + + const threadId = "01KQXKW18CT8G75T53R8F4G7YG"; + const hash = "C9NMV6V2TQT81"; + const dataPath = join(root, "logs", hash, `${threadId}.data.jsonl`); + const infoPath = join(root, "logs", hash, `${threadId}.info.jsonl`); + await mkdir(join(root, "logs", hash), { recursive: true }); + + const logger = createLogger({ sink: { kind: "file", path: infoPath } }); + const ac = new AbortController(); + + const result = await executeThread( + dagWorkflow, + "dag-demo", + { prompt: "traverse", steps: [] }, + { + maxRounds: 5, + depth: 0, + signal: ac.signal, + awaitAfterEachYield: async () => {}, + forkSourceThreadId: null, + prefilledDiskSteps: null, + }, + { threadId, hash, dataJsonlPath: dataPath, infoJsonlPath: infoPath, cas }, + logger, + ); + + expect(result.returnCode).toBe(0); + expect(fetchRound).toBe(3); + + const dataText = await readFile(dataPath, "utf8"); + const lines = dataText + .trim() + .split("\n") + .filter((l) => l !== ""); + const roleRec = JSON.parse(lines[1] ?? 
"{}") as Record; + expect(roleRec.role).toBe("walker"); + expect(roleRec.meta).toEqual({ leafPayload: "needle-from-leaf" }); + } finally { + globalThis.fetch = origFetch; + await rm(root, { recursive: true, force: true }); + } + }); }); diff --git a/packages/workflow/__tests__/react-extract.test.ts b/packages/workflow/__tests__/react-extract.test.ts new file mode 100644 index 0000000..9555fa5 --- /dev/null +++ b/packages/workflow/__tests__/react-extract.test.ts @@ -0,0 +1,209 @@ +import { afterEach, describe, expect, test } from "bun:test"; +import { mkdtemp, rm } from "node:fs/promises"; +import { tmpdir } from "node:os"; +import { join } from "node:path"; +import * as z from "zod/v4"; + +import { createCasStore } from "../src/cas.js"; +import { createContentMerkleNode, serializeMerkleNode } from "../src/merkle.js"; +import { reactExtract } from "../src/react-extract.js"; +import type { LlmProvider } from "../src/types.js"; + +const metaSchema = z.object({ seen: z.string() }); + +const provider: LlmProvider = { + baseUrl: "http://127.0.0.1:9", + apiKey: "test", + model: "test", +}; + +describe("reactExtract", () => { + let restoreFetch: (() => void) | null = null; + + afterEach(() => { + restoreFetch?.(); + restoreFetch = null; + }); + + test("cas_get rounds then extract tool yields validated meta", async () => { + const casDir = await mkdtemp(join(tmpdir(), "react-extract-")); + const cas = createCasStore(casDir); + try { + const blob = serializeMerkleNode(createContentMerkleNode("needle")); + const h = await cas.put(blob); + + const origFetch = globalThis.fetch; + let round = 0; + restoreFetch = () => { + globalThis.fetch = origFetch; + }; + globalThis.fetch = Object.assign( + async (_input: Parameters[0], _init?: RequestInit) => { + round += 1; + if (round === 1) { + return new Response( + JSON.stringify({ + choices: [ + { + message: { + tool_calls: [ + { + id: "t1", + type: "function", + function: { + name: "cas_get", + arguments: JSON.stringify({ hash: h }), 
+ }, + }, + ], + }, + }, + ], + }), + { status: 200, headers: { "Content-Type": "application/json" } }, + ); + } + return new Response( + JSON.stringify({ + choices: [ + { + message: { + tool_calls: [ + { + id: "t2", + type: "function", + function: { + name: "extract", + arguments: JSON.stringify({ seen: "needle" }), + }, + }, + ], + }, + }, + ], + }), + { status: 200, headers: { "Content-Type": "application/json" } }, + ); + }, + { preconnect: origFetch.preconnect.bind(origFetch) }, + ) as typeof fetch; + + const text = `## Agent Output\n${h}\n## Extraction Instruction\nExtract seen from CAS.`; + const result = await reactExtract({ + text, + schema: metaSchema, + provider, + cas, + }); + + expect(result.ok).toBe(true); + if (!result.ok) { + return; + } + expect(result.value).toEqual({ seen: "needle" }); + expect(round).toBe(2); + } finally { + await rm(casDir, { recursive: true, force: true }); + } + }); + + test("stops after max tool rounds when model keeps calling cas_get", async () => { + const casDir = await mkdtemp(join(tmpdir(), "react-extract-max-")); + const cas = createCasStore(casDir); + try { + const blob = serializeMerkleNode(createContentMerkleNode("x")); + const h = await cas.put(blob); + + const origFetch = globalThis.fetch; + let round = 0; + restoreFetch = () => { + globalThis.fetch = origFetch; + }; + globalThis.fetch = Object.assign( + async (_input: Parameters[0], _init?: RequestInit) => { + round += 1; + return new Response( + JSON.stringify({ + choices: [ + { + message: { + tool_calls: [ + { + id: `loop-${round}`, + type: "function", + function: { + name: "cas_get", + arguments: JSON.stringify({ hash: h }), + }, + }, + ], + }, + }, + ], + }), + { status: 200, headers: { "Content-Type": "application/json" } }, + ); + }, + { preconnect: origFetch.preconnect.bind(origFetch) }, + ) as typeof fetch; + + const result = await reactExtract({ + text: "## Agent Output\nnoop\n## Extraction Instruction\nExtract seen.", + schema: metaSchema, + provider, + 
cas, + }); + + expect(result.ok).toBe(false); + if (result.ok) { + return; + } + expect(result.error).toBe("max_react_rounds_exceeded"); + expect(round).toBe(10); + } finally { + await rm(casDir, { recursive: true, force: true }); + } + }); + + test("passthrough JSON assistant message without tool calls", async () => { + const casDir = await mkdtemp(join(tmpdir(), "react-extract-pass-")); + const cas = createCasStore(casDir); + try { + const origFetch = globalThis.fetch; + restoreFetch = () => { + globalThis.fetch = origFetch; + }; + globalThis.fetch = Object.assign( + async (_input: Parameters[0], _init?: RequestInit) => + new Response( + JSON.stringify({ + choices: [ + { + message: { + content: '{"seen":"direct"}', + }, + }, + ], + }), + { status: 200, headers: { "Content-Type": "application/json" } }, + ), + { preconnect: origFetch.preconnect.bind(origFetch) }, + ) as typeof fetch; + + const result = await reactExtract({ + text: "## Agent Output\nok\n## Extraction Instruction\nExtract.", + schema: metaSchema, + provider, + cas, + }); + + expect(result.ok).toBe(true); + if (!result.ok) { + return; + } + expect(result.value).toEqual({ seen: "direct" }); + } finally { + await rm(casDir, { recursive: true, force: true }); + } + }); +}); diff --git a/packages/workflow/__tests__/refs-tracking.test.ts b/packages/workflow/__tests__/refs-tracking.test.ts index d705650..521b2f2 100644 --- a/packages/workflow/__tests__/refs-tracking.test.ts +++ b/packages/workflow/__tests__/refs-tracking.test.ts @@ -91,6 +91,7 @@ const refsDemoWorkflow = createWorkflow( extractPrompt: "Extract phases with CAS hashes.", schema: plannerMetaSchema, extractRefs: (meta) => meta.phases.map((p) => p.hash), + extractMode: "single", }, }, moderator: (ctx) => (ctx.steps.length === 0 ? 
"planner" : END), @@ -99,6 +100,7 @@ const refsDemoWorkflow = createWorkflow( agent: async () => "plan-output", }, refsDemoExtract, + null, ); describe("RoleStep refs tracking", () => { diff --git a/packages/workflow/__tests__/workflow-as-agent-integration.test.ts b/packages/workflow/__tests__/workflow-as-agent-integration.test.ts index 73dc18f..fdb643e 100644 --- a/packages/workflow/__tests__/workflow-as-agent-integration.test.ts +++ b/packages/workflow/__tests__/workflow-as-agent-integration.test.ts @@ -142,12 +142,14 @@ describe("workflowAsAgent integration", () => { extractPrompt: "extract done flag", schema: callerMetaSchema, extractRefs: null, + extractMode: "single", }, }, moderator: (ctx) => (ctx.steps.length === 0 ? "caller" : END), }, { agent: workflowAsAgent("child-wf", { storageRoot: root }) }, parentExtract, + null, ); const threadId = "01KQXKW18CT8G75T53R8F4G7YG"; diff --git a/packages/workflow/src/create-workflow.ts b/packages/workflow/src/create-workflow.ts index ee0539e..8d0090e 100644 --- a/packages/workflow/src/create-workflow.ts +++ b/packages/workflow/src/create-workflow.ts @@ -1,11 +1,14 @@ -import type { ExtractFn } from "./extract-fn.js"; +import type { CasStore } from "./cas.js"; +import { buildExtractUserContent, type ExtractFn } from "./extract-fn.js"; import { putContentMerkleNode } from "./merkle.js"; +import { reactExtract } from "./react-extract.js"; import { mergeRefsWithContentHash } from "./refs-field.js"; import { type AgentBinding, type AgentContext, END, type ExtractContext, + type LlmProvider, type ModeratorContext, type RoleDefinition, type RoleMeta, @@ -36,14 +39,51 @@ function resolveExtractedRefs( return extractRefsFn(meta as Record); } +async function resolveRoleMeta( + roleDef: RoleDefinition>, + extractCtx: ExtractContext, + extract: ExtractFn, + llmProvider: LlmProvider | null, + cas: CasStore, +): Promise> { + if (roleDef.extractMode === "react") { + if (llmProvider === null) { + throw new Error( + 'createWorkflow: 
llmProvider is required when a role uses extractMode "react"', + ); + } + const text = await buildExtractUserContent( + extractCtx as unknown as ExtractContext, + roleDef.extractPrompt, + ); + const reactResult = await reactExtract({ + text, + schema: roleDef.schema, + provider: llmProvider, + cas, + }); + if (!reactResult.ok) { + throw new Error(`react extract failed: ${reactResult.error}`); + } + return reactResult.value as Record; + } + return (await extract( + roleDef.schema, + roleDef.extractPrompt, + extractCtx as unknown as ExtractContext, + )) as Record; +} + /** * Binds pure role definitions + moderator to runtime agents and structured extraction. - * Assign with `export const run = createWorkflow(def, binding, extract)`. + * Assign with `export const run = createWorkflow(def, binding, extract, llmProvider)`. + * Pass the same {@link LlmProvider} as {@link createExtract} when any role uses `extractMode: "react"`. */ export function createWorkflow( def: Pick, "roles" | "moderator">, binding: AgentBinding, extract: ExtractFn, + llmProvider: LlmProvider | null, ): WorkflowFn { return async function* workflowLoop( input: ThreadInput, @@ -107,10 +147,12 @@ export function createWorkflow( agentContent: raw, }; - const meta = await extract( - roleDef.schema, - roleDef.extractPrompt, - extractCtx as unknown as ExtractContext, + const meta = await resolveRoleMeta( + roleDef as unknown as RoleDefinition>, + extractCtx, + extract, + llmProvider, + options.cas, ); const contentHash = await putContentMerkleNode(options.cas, raw); diff --git a/packages/workflow/src/extract-fn.ts b/packages/workflow/src/extract-fn.ts index 4c46a2b..bb0baf4 100644 --- a/packages/workflow/src/extract-fn.ts +++ b/packages/workflow/src/extract-fn.ts @@ -10,6 +10,40 @@ export type ExtractFn = >( ctx: ExtractContext, ) => Promise; +/** Builds the user-side extraction prompt (thread + agent output + instruction). 
*/ +export async function buildExtractUserContent( + ctx: ExtractContext, + prompt: string, +): Promise { + const lines: string[] = []; + lines.push(`## Role: ${ctx.currentRole.name}`); + lines.push(ctx.currentRole.systemPrompt); + lines.push(""); + lines.push("## Task"); + lines.push(ctx.start.content); + lines.push(""); + if (ctx.steps.length > 0) { + lines.push("## Thread History"); + for (const step of ctx.steps) { + const body = await getContentMerklePayload(ctx.cas, step.contentHash); + if (body === null) { + throw new Error(`extract: missing CAS blob for step ${step.role}: ${step.contentHash}`); + } + lines.push(`### ${step.role}`); + lines.push(body); + lines.push(`Meta: ${JSON.stringify(step.meta)}`); + lines.push(""); + } + } + lines.push("## Agent Output"); + lines.push(ctx.agentContent); + lines.push(""); + lines.push("## Extraction Instruction"); + lines.push(prompt); + + return lines.join("\n"); +} + /** * Create an ExtractFn backed by an LLM provider. * Builds prompt text from {@link ExtractContext} plus `prompt` and calls structured extraction. 
@@ -20,33 +54,7 @@ export function createExtract(provider: LlmProvider): ExtractFn { prompt: string, ctx: ExtractContext, ): Promise => { - const lines: string[] = []; - lines.push(`## Role: ${ctx.currentRole.name}`); - lines.push(ctx.currentRole.systemPrompt); - lines.push(""); - lines.push("## Task"); - lines.push(ctx.start.content); - lines.push(""); - if (ctx.steps.length > 0) { - lines.push("## Thread History"); - for (const step of ctx.steps) { - const body = await getContentMerklePayload(ctx.cas, step.contentHash); - if (body === null) { - throw new Error(`extract: missing CAS blob for step ${step.role}: ${step.contentHash}`); - } - lines.push(`### ${step.role}`); - lines.push(body); - lines.push(`Meta: ${JSON.stringify(step.meta)}`); - lines.push(""); - } - } - lines.push("## Agent Output"); - lines.push(ctx.agentContent); - lines.push(""); - lines.push("## Extraction Instruction"); - lines.push(prompt); - - const text = lines.join("\n"); + const text = await buildExtractUserContent(ctx, prompt); const result = await llmExtractWithRetry({ text, schema, provider }); if (!result.ok) { throw new Error(`extract failed: ${JSON.stringify(result.error)}`); diff --git a/packages/workflow/src/index.ts b/packages/workflow/src/index.ts index 6db5579..aa201d6 100644 --- a/packages/workflow/src/index.ts +++ b/packages/workflow/src/index.ts @@ -54,6 +54,7 @@ export { serializeMerkleNode, type ThreadMerklePayload, } from "./merkle.js"; +export { type ReactExtractArgs, reactExtract } from "./react-extract.js"; export { type ExtractProviderConfig, getRegisteredWorkflow, @@ -80,6 +81,7 @@ export { type AgentFn, END, type ExtractContext, + type ExtractMode, type LlmProvider, type Moderator, type ModeratorContext, diff --git a/packages/workflow/src/llm-extract.ts b/packages/workflow/src/llm-extract.ts index 35e8ac9..85f25c6 100644 --- a/packages/workflow/src/llm-extract.ts +++ b/packages/workflow/src/llm-extract.ts @@ -47,6 +47,21 @@ function 
readToolDescription(parametersSchema: Record): string return "Extract structured data from the input text."; } +/** Builds OpenAI function-tool metadata from a Zod meta schema (same naming rules as single-shot extract). */ +export function extractFunctionToolFromZodSchema(schema: z.ZodType): { + name: string; + description: string; + parameters: Record; +} { + const rawJsonSchema = z.toJSONSchema(schema) as Record; + const parameters = stripJsonSchemaMeta(rawJsonSchema); + return { + name: readToolName(parameters), + description: readToolDescription(parameters), + parameters, + }; +} + function readToolArgumentsJson(parsed: unknown, previewSource: string): Result { if (!isRecord(parsed)) { return err({ kind: "invalid_response_json", message: "Top-level JSON is not an object" }); @@ -124,10 +139,7 @@ export function llmErrorToCause(error: LlmError): Error { async function performLlmExtract( options: LlmExtractArgs & { userContent: string }, ): Promise> { - const rawJsonSchema = z.toJSONSchema(options.schema) as Record; - const parameters = stripJsonSchemaMeta(rawJsonSchema); - const toolName = readToolName(parameters); - const toolDescription = readToolDescription(parameters); + const extractTool = extractFunctionToolFromZodSchema(options.schema); const body = { model: options.provider.model, @@ -142,13 +154,13 @@ async function performLlmExtract( { type: "function" as const, function: { - name: toolName, - description: toolDescription, - parameters, + name: extractTool.name, + description: extractTool.description, + parameters: extractTool.parameters, }, }, ], - tool_choice: { type: "function" as const, function: { name: toolName } }, + tool_choice: { type: "function" as const, function: { name: extractTool.name } }, }; let response: Response; diff --git a/packages/workflow/src/react-extract.ts b/packages/workflow/src/react-extract.ts new file mode 100644 index 0000000..6c038a5 --- /dev/null +++ b/packages/workflow/src/react-extract.ts @@ -0,0 +1,330 @@ +import 
type * as z from "zod/v4"; + +import type { CasStore } from "./cas.js"; +import { extractFunctionToolFromZodSchema } from "./llm-extract.js"; +import { err, ok, type Result } from "./result.js"; +import type { LlmProvider } from "./types.js"; + +export type ReactExtractArgs> = { + text: string; + schema: z.ZodType; + provider: LlmProvider; + cas: CasStore; +}; + +const MAX_REACT_ROUNDS = 10; + +const CAS_GET_TOOL_DEFINITION = { + type: "function" as const, + function: { + name: "cas_get", + description: + "Read a Merkle DAG node from content-addressed storage by its hash. Returns YAML-formatted node with type, payload, and children fields.", + parameters: { + type: "object", + properties: { + hash: { type: "string", description: "The CAS hash to retrieve" }, + }, + required: ["hash"], + }, + }, +}; + +function chatCompletionsUrl(baseUrl: string): string { + const trimmed = baseUrl.replace(/\/+$/, ""); + return `${trimmed}/chat/completions`; +} + +function isRecord(value: unknown): value is Record { + return typeof value === "object" && value !== null && !Array.isArray(value); +} + +function tryParseJsonContent(content: string): unknown | null { + const trimmed = content.trim(); + const fenceMatch = /^```(?:json)?\s*([\s\S]*?)```$/m.exec(trimmed); + const payload = fenceMatch !== null ? 
fenceMatch[1].trim() : trimmed; + try { + return JSON.parse(payload) as unknown; + } catch { + return null; + } +} + +type ToolCall = { + id: string; + type: "function"; + function: { name: string; arguments: string }; +}; + +type ChatMessage = + | { role: "system"; content: string } + | { role: "user"; content: string } + | { + role: "assistant"; + content: string | null; + tool_calls: ToolCall[]; + } + | { role: "tool"; tool_call_id: string; content: string }; + +type AssistantTurn = + | { kind: "plain_json"; value: T } + | { kind: "tool_calls"; calls: ToolCall[]; assistantContent: string | null }; + +function firstAssistantMessage(responseText: string): Result, string> { + let parsed: unknown; + try { + parsed = JSON.parse(responseText) as unknown; + } catch (cause) { + const message = cause instanceof Error ? cause.message : String(cause); + return err(`invalid_response_json:${message}`); + } + if (!isRecord(parsed)) { + return err("invalid_response_top_level"); + } + const choices = parsed.choices; + if (!Array.isArray(choices) || choices.length === 0) { + return err("no_choices_in_response"); + } + const firstChoice = choices[0]; + if (!isRecord(firstChoice)) { + return err("invalid_choice"); + } + const messageObj = firstChoice.message; + if (!isRecord(messageObj)) { + return err("invalid_message"); + } + return ok(messageObj); +} + +function normalizeToolCalls(toolCallsRaw: unknown[]): Result { + const toolCalls: ToolCall[] = []; + for (const tc of toolCallsRaw) { + if (!isRecord(tc)) { + return err("invalid_tool_call"); + } + const id = tc.id; + const tcType = tc.type; + const fn = tc.function; + if (typeof id !== "string" || tcType !== "function" || !isRecord(fn)) { + return err("invalid_tool_call_shape"); + } + const name = fn.name; + const argumentsStr = fn.arguments; + if (typeof name !== "string" || typeof argumentsStr !== "string") { + return err("invalid_tool_call_function"); + } + toolCalls.push({ id, type: "function", function: { name, arguments: 
argumentsStr } }); + } + return ok(toolCalls); +} + +function classifyAssistantTurn>( + messageObj: Record, + schema: z.ZodType, +): Result, string> { + const toolCallsRaw = messageObj.tool_calls; + if (!Array.isArray(toolCallsRaw) || toolCallsRaw.length === 0) { + const content = messageObj.content; + if (typeof content !== "string") { + return err("no_tool_calls_and_no_string_content"); + } + const jsonParsed = tryParseJsonContent(content); + if (jsonParsed === null) { + return err("no_tool_calls_and_content_not_json"); + } + const validated = schema.safeParse(jsonParsed); + if (!validated.success) { + return err(`schema_validation_failed:${validated.error.message}`); + } + return ok({ kind: "plain_json", value: validated.data }); + } + const callsResult = normalizeToolCalls(toolCallsRaw); + if (!callsResult.ok) { + return err(callsResult.error); + } + const assistantContent = messageObj.content; + return ok({ + kind: "tool_calls", + calls: callsResult.value, + assistantContent: typeof assistantContent === "string" ? assistantContent : null, + }); +} + +async function appendCasGetToolResult( + tc: ToolCall, + cas: CasStore, + messages: ChatMessage[], +): Promise> { + let hash: string; + try { + const ta = JSON.parse(tc.function.arguments) as unknown; + if (!isRecord(ta) || typeof ta.hash !== "string") { + return err("cas_get_invalid_arguments"); + } + hash = ta.hash; + } catch { + return err("cas_get_arguments_not_json"); + } + const blob = await cas.get(hash); + const toolContent = blob === null ? 
"null" : blob; + messages.push({ + role: "tool", + tool_call_id: tc.id, + content: toolContent, + }); + return ok(null); +} + +async function appendExtractToolResult>( + tc: ToolCall, + schema: z.ZodType, + messages: ChatMessage[], +): Promise> { + let parsedArgs: unknown; + try { + parsedArgs = JSON.parse(tc.function.arguments) as unknown; + } catch { + return err("extract_tool_arguments_not_json"); + } + const validated = schema.safeParse(parsedArgs); + if (!validated.success) { + return err(`schema_validation_failed:${validated.error.message}`); + } + messages.push({ + role: "tool", + tool_call_id: tc.id, + content: '{"ok":true}', + }); + return ok(validated.data); +} + +async function appendToolResults>( + toolCalls: ToolCall[], + extractToolName: string, + schema: z.ZodType, + cas: CasStore, + messages: ChatMessage[], +): Promise> { + let extracted: T | null = null; + for (const tc of toolCalls) { + if (tc.function.name === "cas_get") { + const casRes = await appendCasGetToolResult(tc, cas, messages); + if (!casRes.ok) { + return casRes; + } + continue; + } + if (tc.function.name === extractToolName) { + const exRes = await appendExtractToolResult(tc, schema, messages); + if (!exRes.ok) { + return exRes; + } + extracted = exRes.value; + continue; + } + return err(`unknown_tool:${tc.function.name}`); + } + return ok(extracted); +} + +async function postChatCompletion( + provider: LlmProvider, + messages: ChatMessage[], + tools: readonly Record[], +): Promise> { + try { + const response = await fetch(chatCompletionsUrl(provider.baseUrl), { + method: "POST", + headers: { + Authorization: `Bearer ${provider.apiKey}`, + "Content-Type": "application/json", + }, + body: JSON.stringify({ + model: provider.model, + messages, + tools, + tool_choice: "auto", + }), + }); + const responseText = await response.text(); + if (!response.ok) { + return err(`http_error:${String(response.status)}:${responseText.slice(0, 4000)}`); + } + return ok(responseText); + } catch (cause) { 
+ const message = cause instanceof Error ? cause.message : String(cause); + return err(`network_error:${message}`); + } +} + +/** + * Multi-turn ReAct extraction with `cas_get` plus a schema-shaped extract tool (OpenAI-compatible). + * Final meta comes from a successful extract tool call or from plain JSON in the assistant message. Any error from the request, the assistant-message lookup, the turn classification, or the tool handling is returned as-is; if no final result is produced within MAX_REACT_ROUNDS rounds the result is max_react_rounds_exceeded. + */ +export async function reactExtract>( + args: ReactExtractArgs, +): Promise> { + const extractTool = extractFunctionToolFromZodSchema(args.schema); + const tools = [ + CAS_GET_TOOL_DEFINITION, + { + type: "function" as const, + function: { + name: extractTool.name, + description: extractTool.description, + parameters: extractTool.parameters, + }, + }, + ]; + + const systemContent = `You extract structured metadata from the agent output below. Use cas_get to read Merkle DAG nodes from CAS (YAML: type, payload, children) when the agent output references hashes you must traverse. When you have the complete structured object, call the ${extractTool.name} tool with JSON arguments matching the schema. 
You may instead reply with only a JSON object (no prose) when no tools are needed.`; + + const messages: ChatMessage[] = [ + { role: "system", content: systemContent }, + { role: "user", content: args.text }, + ]; + + for (let round = 0; round < MAX_REACT_ROUNDS; round++) { /* each round is one model request plus handling of the tool calls it makes */ + const bodyResult = await postChatCompletion(args.provider, messages, tools); + if (!bodyResult.ok) { + return bodyResult; + } + + const msgResult = firstAssistantMessage(bodyResult.value); + if (!msgResult.ok) { + return msgResult; + } + + const classified = classifyAssistantTurn(msgResult.value, args.schema); + if (!classified.ok) { + return classified; + } + + const turn = classified.value; + if (turn.kind === "plain_json") { /* the assistant replied with plain JSON instead of tool calls; that value is the final result */ + return ok(turn.value); + } + + /* add the assistant tool-call turn to the transcript before the tool results are appended after it */ messages.push({ + role: "assistant", + content: turn.assistantContent, + tool_calls: turn.calls, + }); + + const toolsRound = await appendToolResults( + turn.calls, + extractTool.name, + args.schema, + args.cas, + messages, + ); + if (!toolsRound.ok) { + return toolsRound; + } + if (toolsRound.value !== null) { /* non-null means an extract tool call validated this round, so extraction is done */ + return ok(toolsRound.value); + } + } + + /* the round budget ran out without a plain JSON reply or a validated extract call */ return err("max_react_rounds_exceeded"); +} diff --git a/packages/workflow/src/types.ts b/packages/workflow/src/types.ts index 42ad51f..faa5dfd 100644 --- a/packages/workflow/src/types.ts +++ b/packages/workflow/src/types.ts @@ -16,6 +16,9 @@ export type LlmProvider = { model: string; }; +/** How the engine runs meta extraction for a role after the agent phase. */ +export type ExtractMode = "single" | "react"; + /** What each generator yield produces — one role's output (engine adds `timestamp` when persisting). */ export type RoleOutput = { role: string; @@ -121,6 +124,7 @@ export type RoleDefinition> = { schema: z.ZodType; /** When non-null, produces CAS hashes to persist on this role's steps (see `RoleOutput.refs`). */ extractRefs: ((meta: Meta) => string[]) | null; + extractMode: ExtractMode; }; /**