Merge pull request 'feat: ReAct ExtractFn with tool-use' (#53) from feat/44-react-extract into main
This commit is contained in:
@@ -29,6 +29,7 @@ const greeter: RoleDefinition<Roles["greeter"]> = {
|
||||
extractPrompt: "Extract the greeting string produced for the user.",
|
||||
schema: greeterMetaSchema,
|
||||
extractRefs: null,
|
||||
extractMode: "single",
|
||||
};
|
||||
|
||||
const extract = createExtract({
|
||||
@@ -48,4 +49,5 @@ export const run = createWorkflow<Roles>(
|
||||
agent: async (ctx) => `Hello, ${ctx.start.content}`,
|
||||
},
|
||||
extract,
|
||||
null,
|
||||
);
|
||||
|
||||
@@ -39,4 +39,5 @@ export const coderRole: RoleDefinition<CoderMeta> = {
|
||||
"Extract completedPhase: the planner phase hash finished this round (exact hash string from the plan). If multiple phases were finished in one round, use the last finished phase hash. Extract filesChanged and a summary of the work.",
|
||||
schema: coderMetaSchema,
|
||||
extractRefs: (meta) => [meta.completedPhase],
|
||||
extractMode: "single",
|
||||
};
|
||||
|
||||
@@ -32,4 +32,5 @@ export const committerRole: RoleDefinition<CommitterMeta> = {
|
||||
"Extract the commit result: committed (with branch and SHA), recoverable failure, or unrecoverable failure. Include error details and log references if applicable.",
|
||||
schema: committerMetaSchema,
|
||||
extractRefs: null,
|
||||
extractMode: "single",
|
||||
};
|
||||
|
||||
@@ -50,4 +50,5 @@ export const plannerRole: RoleDefinition<PlannerMeta> = {
|
||||
"Extract the implementation phases from the agent's output. Each phase has a hash (the CAS content-hash returned by the cas put command) and a title (one-line summary).",
|
||||
schema: plannerMetaSchema,
|
||||
extractRefs: (meta) => meta.phases.map((p) => p.hash),
|
||||
extractMode: "single",
|
||||
};
|
||||
|
||||
@@ -48,4 +48,5 @@ export const preparerRole: RoleDefinition<PreparerMeta> = {
|
||||
"Extract repoPath (absolute path), defaultBranch, conventions (summary string or null), and toolchain (packageManager, testCommand, lintCommand, buildCommand — each string or null).",
|
||||
schema: preparerMetaSchema,
|
||||
extractRefs: null,
|
||||
extractMode: "single",
|
||||
};
|
||||
|
||||
@@ -22,4 +22,5 @@ export const reviewerRole: RoleDefinition<ReviewerMeta> = {
|
||||
"Extract the review verdict: approved or rejected. If rejected, list the blocking issues.",
|
||||
schema: reviewerMetaSchema,
|
||||
extractRefs: null,
|
||||
extractMode: "single",
|
||||
};
|
||||
|
||||
@@ -313,7 +313,7 @@ describe("createSolveIssueRun", () => {
|
||||
casDir = await mkdtemp(join(tmpdir(), "solve-issue-cas-"));
|
||||
const cas = createCasStore(casDir);
|
||||
|
||||
const run = createSolveIssueRun({ agent: async () => "" }, stubExtract);
|
||||
const run = createSolveIssueRun({ agent: async () => "" }, stubExtract, null);
|
||||
const gen = run(
|
||||
{ prompt: "task", steps: [] },
|
||||
{ threadId: "01TEST000000000000000000TR", maxRounds: 20, depth: 0, cas },
|
||||
@@ -374,6 +374,7 @@ describe("createSolveIssueRun", () => {
|
||||
},
|
||||
},
|
||||
stubExtract,
|
||||
null,
|
||||
);
|
||||
const gen = run(
|
||||
{ prompt: "task", steps: [] },
|
||||
|
||||
@@ -2,6 +2,7 @@ import {
|
||||
type AgentBinding,
|
||||
createWorkflow,
|
||||
type ExtractFn,
|
||||
type LlmProvider,
|
||||
type WorkflowDefinition,
|
||||
type WorkflowFn,
|
||||
} from "@uncaged/workflow";
|
||||
@@ -50,6 +51,10 @@ export const solveIssueWorkflowDefinition: WorkflowDefinition<SolveIssueMeta> =
|
||||
moderator: solveIssueModerator,
|
||||
};
|
||||
|
||||
export function createSolveIssueRun(binding: AgentBinding, extract: ExtractFn): WorkflowFn {
|
||||
return createWorkflow(solveIssueWorkflowDefinition, binding, extract);
|
||||
export function createSolveIssueRun(
|
||||
binding: AgentBinding,
|
||||
extract: ExtractFn,
|
||||
llmProvider: LlmProvider | null,
|
||||
): WorkflowFn {
|
||||
return createWorkflow(solveIssueWorkflowDefinition, binding, extract, llmProvider);
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ describe("buildDescriptor", () => {
|
||||
extractPrompt: "Extract title and count from the analysis.",
|
||||
schema,
|
||||
extractRefs: null,
|
||||
extractMode: "single",
|
||||
},
|
||||
},
|
||||
moderator: () => END,
|
||||
|
||||
@@ -15,7 +15,7 @@ import {
|
||||
parseMerkleNode,
|
||||
serializeMerkleNode,
|
||||
} from "../src/merkle.js";
|
||||
import { END } from "../src/types.js";
|
||||
import { END, type LlmProvider } from "../src/types.js";
|
||||
|
||||
const plannerMetaSchema = z.object({
|
||||
plan: z.string(),
|
||||
@@ -97,6 +97,7 @@ const demoWorkflow = createWorkflow<DemoMeta>(
|
||||
extractPrompt: "Extract plan text and affected files list.",
|
||||
schema: plannerMetaSchema,
|
||||
extractRefs: null,
|
||||
extractMode: "single",
|
||||
},
|
||||
coder: {
|
||||
description: "Demo coder",
|
||||
@@ -104,6 +105,7 @@ const demoWorkflow = createWorkflow<DemoMeta>(
|
||||
extractPrompt: "Extract the code diff summary.",
|
||||
schema: coderMetaSchema,
|
||||
extractRefs: null,
|
||||
extractMode: "single",
|
||||
},
|
||||
},
|
||||
moderator: (ctx) => {
|
||||
@@ -124,6 +126,7 @@ const demoWorkflow = createWorkflow<DemoMeta>(
|
||||
},
|
||||
},
|
||||
demoExtract,
|
||||
null,
|
||||
);
|
||||
|
||||
describe("executeThread", () => {
|
||||
@@ -445,4 +448,169 @@ describe("executeThread", () => {
|
||||
await rm(root, { recursive: true, force: true });
|
||||
}
|
||||
});
|
||||
|
||||
test("extractMode react traverses CAS DAG via cas_get during extraction", async () => {
|
||||
const dagMetaSchema = z.object({ leafPayload: z.string() });
|
||||
type DagDemoMeta = { walker: z.infer<typeof dagMetaSchema> };
|
||||
|
||||
const origFetch = globalThis.fetch;
|
||||
restoreFetch = () => {
|
||||
globalThis.fetch = origFetch;
|
||||
};
|
||||
let fetchRound = 0;
|
||||
|
||||
const root = await mkdtemp(join(tmpdir(), "wf-engine-react-"));
|
||||
try {
|
||||
const cas = createCasStore(join(root, "cas"));
|
||||
const leafYaml = serializeMerkleNode(createContentMerkleNode("needle-from-leaf"));
|
||||
const leafHash = await cas.put(leafYaml);
|
||||
const rootYaml = serializeMerkleNode({
|
||||
type: "thread",
|
||||
payload: {
|
||||
workflow: "dag-demo",
|
||||
threadId: "01DAG00000000000000000001",
|
||||
result: { returnCode: 0, summary: "" },
|
||||
},
|
||||
children: [leafHash],
|
||||
});
|
||||
const dagRootHash = await cas.put(rootYaml);
|
||||
|
||||
globalThis.fetch = Object.assign(
|
||||
async (_input: Parameters<typeof fetch>[0], _init?: RequestInit) => {
|
||||
fetchRound += 1;
|
||||
if (fetchRound === 1) {
|
||||
return new Response(
|
||||
JSON.stringify({
|
||||
choices: [
|
||||
{
|
||||
message: {
|
||||
tool_calls: [
|
||||
{
|
||||
id: "c1",
|
||||
type: "function",
|
||||
function: {
|
||||
name: "cas_get",
|
||||
arguments: JSON.stringify({ hash: dagRootHash }),
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
],
|
||||
}),
|
||||
{ status: 200, headers: { "Content-Type": "application/json" } },
|
||||
);
|
||||
}
|
||||
if (fetchRound === 2) {
|
||||
return new Response(
|
||||
JSON.stringify({
|
||||
choices: [
|
||||
{
|
||||
message: {
|
||||
tool_calls: [
|
||||
{
|
||||
id: "c2",
|
||||
type: "function",
|
||||
function: {
|
||||
name: "cas_get",
|
||||
arguments: JSON.stringify({ hash: leafHash }),
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
],
|
||||
}),
|
||||
{ status: 200, headers: { "Content-Type": "application/json" } },
|
||||
);
|
||||
}
|
||||
return new Response(
|
||||
JSON.stringify({
|
||||
choices: [
|
||||
{
|
||||
message: {
|
||||
tool_calls: [
|
||||
{
|
||||
id: "c3",
|
||||
type: "function",
|
||||
function: {
|
||||
name: "extract",
|
||||
arguments: JSON.stringify({ leafPayload: "needle-from-leaf" }),
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
],
|
||||
}),
|
||||
{ status: 200, headers: { "Content-Type": "application/json" } },
|
||||
);
|
||||
},
|
||||
{ preconnect: origFetch.preconnect.bind(origFetch) },
|
||||
) as typeof fetch;
|
||||
|
||||
const llm: LlmProvider = { baseUrl: "http://127.0.0.1:9", apiKey: "test", model: "test" };
|
||||
const extractFn = createExtract(llm);
|
||||
|
||||
const dagWorkflow = createWorkflow<DagDemoMeta>(
|
||||
{
|
||||
roles: {
|
||||
walker: {
|
||||
description: "DAG walker",
|
||||
systemPrompt: "Output only the root CAS hash.",
|
||||
extractPrompt:
|
||||
"Set leafPayload to the string payload of the content Merkle node under the root.",
|
||||
schema: dagMetaSchema,
|
||||
extractRefs: null,
|
||||
extractMode: "react",
|
||||
},
|
||||
},
|
||||
moderator: (ctx) => (ctx.steps.length === 0 ? "walker" : END),
|
||||
},
|
||||
{ agent: async () => dagRootHash },
|
||||
extractFn,
|
||||
llm,
|
||||
);
|
||||
|
||||
const threadId = "01KQXKW18CT8G75T53R8F4G7YG";
|
||||
const hash = "C9NMV6V2TQT81";
|
||||
const dataPath = join(root, "logs", hash, `${threadId}.data.jsonl`);
|
||||
const infoPath = join(root, "logs", hash, `${threadId}.info.jsonl`);
|
||||
await mkdir(join(root, "logs", hash), { recursive: true });
|
||||
|
||||
const logger = createLogger({ sink: { kind: "file", path: infoPath } });
|
||||
const ac = new AbortController();
|
||||
|
||||
const result = await executeThread(
|
||||
dagWorkflow,
|
||||
"dag-demo",
|
||||
{ prompt: "traverse", steps: [] },
|
||||
{
|
||||
maxRounds: 5,
|
||||
depth: 0,
|
||||
signal: ac.signal,
|
||||
awaitAfterEachYield: async () => {},
|
||||
forkSourceThreadId: null,
|
||||
prefilledDiskSteps: null,
|
||||
},
|
||||
{ threadId, hash, dataJsonlPath: dataPath, infoJsonlPath: infoPath, cas },
|
||||
logger,
|
||||
);
|
||||
|
||||
expect(result.returnCode).toBe(0);
|
||||
expect(fetchRound).toBe(3);
|
||||
|
||||
const dataText = await readFile(dataPath, "utf8");
|
||||
const lines = dataText
|
||||
.trim()
|
||||
.split("\n")
|
||||
.filter((l) => l !== "");
|
||||
const roleRec = JSON.parse(lines[1] ?? "{}") as Record<string, unknown>;
|
||||
expect(roleRec.role).toBe("walker");
|
||||
expect(roleRec.meta).toEqual({ leafPayload: "needle-from-leaf" });
|
||||
} finally {
|
||||
globalThis.fetch = origFetch;
|
||||
await rm(root, { recursive: true, force: true });
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
@@ -0,0 +1,209 @@
|
||||
import { afterEach, describe, expect, test } from "bun:test";
|
||||
import { mkdtemp, rm } from "node:fs/promises";
|
||||
import { tmpdir } from "node:os";
|
||||
import { join } from "node:path";
|
||||
import * as z from "zod/v4";
|
||||
|
||||
import { createCasStore } from "../src/cas.js";
|
||||
import { createContentMerkleNode, serializeMerkleNode } from "../src/merkle.js";
|
||||
import { reactExtract } from "../src/react-extract.js";
|
||||
import type { LlmProvider } from "../src/types.js";
|
||||
|
||||
const metaSchema = z.object({ seen: z.string() });
|
||||
|
||||
const provider: LlmProvider = {
|
||||
baseUrl: "http://127.0.0.1:9",
|
||||
apiKey: "test",
|
||||
model: "test",
|
||||
};
|
||||
|
||||
describe("reactExtract", () => {
|
||||
let restoreFetch: (() => void) | null = null;
|
||||
|
||||
afterEach(() => {
|
||||
restoreFetch?.();
|
||||
restoreFetch = null;
|
||||
});
|
||||
|
||||
test("cas_get rounds then extract tool yields validated meta", async () => {
|
||||
const casDir = await mkdtemp(join(tmpdir(), "react-extract-"));
|
||||
const cas = createCasStore(casDir);
|
||||
try {
|
||||
const blob = serializeMerkleNode(createContentMerkleNode("needle"));
|
||||
const h = await cas.put(blob);
|
||||
|
||||
const origFetch = globalThis.fetch;
|
||||
let round = 0;
|
||||
restoreFetch = () => {
|
||||
globalThis.fetch = origFetch;
|
||||
};
|
||||
globalThis.fetch = Object.assign(
|
||||
async (_input: Parameters<typeof fetch>[0], _init?: RequestInit) => {
|
||||
round += 1;
|
||||
if (round === 1) {
|
||||
return new Response(
|
||||
JSON.stringify({
|
||||
choices: [
|
||||
{
|
||||
message: {
|
||||
tool_calls: [
|
||||
{
|
||||
id: "t1",
|
||||
type: "function",
|
||||
function: {
|
||||
name: "cas_get",
|
||||
arguments: JSON.stringify({ hash: h }),
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
],
|
||||
}),
|
||||
{ status: 200, headers: { "Content-Type": "application/json" } },
|
||||
);
|
||||
}
|
||||
return new Response(
|
||||
JSON.stringify({
|
||||
choices: [
|
||||
{
|
||||
message: {
|
||||
tool_calls: [
|
||||
{
|
||||
id: "t2",
|
||||
type: "function",
|
||||
function: {
|
||||
name: "extract",
|
||||
arguments: JSON.stringify({ seen: "needle" }),
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
],
|
||||
}),
|
||||
{ status: 200, headers: { "Content-Type": "application/json" } },
|
||||
);
|
||||
},
|
||||
{ preconnect: origFetch.preconnect.bind(origFetch) },
|
||||
) as typeof fetch;
|
||||
|
||||
const text = `## Agent Output\n${h}\n## Extraction Instruction\nExtract seen from CAS.`;
|
||||
const result = await reactExtract({
|
||||
text,
|
||||
schema: metaSchema,
|
||||
provider,
|
||||
cas,
|
||||
});
|
||||
|
||||
expect(result.ok).toBe(true);
|
||||
if (!result.ok) {
|
||||
return;
|
||||
}
|
||||
expect(result.value).toEqual({ seen: "needle" });
|
||||
expect(round).toBe(2);
|
||||
} finally {
|
||||
await rm(casDir, { recursive: true, force: true });
|
||||
}
|
||||
});
|
||||
|
||||
test("stops after max tool rounds when model keeps calling cas_get", async () => {
|
||||
const casDir = await mkdtemp(join(tmpdir(), "react-extract-max-"));
|
||||
const cas = createCasStore(casDir);
|
||||
try {
|
||||
const blob = serializeMerkleNode(createContentMerkleNode("x"));
|
||||
const h = await cas.put(blob);
|
||||
|
||||
const origFetch = globalThis.fetch;
|
||||
let round = 0;
|
||||
restoreFetch = () => {
|
||||
globalThis.fetch = origFetch;
|
||||
};
|
||||
globalThis.fetch = Object.assign(
|
||||
async (_input: Parameters<typeof fetch>[0], _init?: RequestInit) => {
|
||||
round += 1;
|
||||
return new Response(
|
||||
JSON.stringify({
|
||||
choices: [
|
||||
{
|
||||
message: {
|
||||
tool_calls: [
|
||||
{
|
||||
id: `loop-${round}`,
|
||||
type: "function",
|
||||
function: {
|
||||
name: "cas_get",
|
||||
arguments: JSON.stringify({ hash: h }),
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
],
|
||||
}),
|
||||
{ status: 200, headers: { "Content-Type": "application/json" } },
|
||||
);
|
||||
},
|
||||
{ preconnect: origFetch.preconnect.bind(origFetch) },
|
||||
) as typeof fetch;
|
||||
|
||||
const result = await reactExtract({
|
||||
text: "## Agent Output\nnoop\n## Extraction Instruction\nExtract seen.",
|
||||
schema: metaSchema,
|
||||
provider,
|
||||
cas,
|
||||
});
|
||||
|
||||
expect(result.ok).toBe(false);
|
||||
if (result.ok) {
|
||||
return;
|
||||
}
|
||||
expect(result.error).toBe("max_react_rounds_exceeded");
|
||||
expect(round).toBe(10);
|
||||
} finally {
|
||||
await rm(casDir, { recursive: true, force: true });
|
||||
}
|
||||
});
|
||||
|
||||
test("passthrough JSON assistant message without tool calls", async () => {
|
||||
const casDir = await mkdtemp(join(tmpdir(), "react-extract-pass-"));
|
||||
const cas = createCasStore(casDir);
|
||||
try {
|
||||
const origFetch = globalThis.fetch;
|
||||
restoreFetch = () => {
|
||||
globalThis.fetch = origFetch;
|
||||
};
|
||||
globalThis.fetch = Object.assign(
|
||||
async (_input: Parameters<typeof fetch>[0], _init?: RequestInit) =>
|
||||
new Response(
|
||||
JSON.stringify({
|
||||
choices: [
|
||||
{
|
||||
message: {
|
||||
content: '{"seen":"direct"}',
|
||||
},
|
||||
},
|
||||
],
|
||||
}),
|
||||
{ status: 200, headers: { "Content-Type": "application/json" } },
|
||||
),
|
||||
{ preconnect: origFetch.preconnect.bind(origFetch) },
|
||||
) as typeof fetch;
|
||||
|
||||
const result = await reactExtract({
|
||||
text: "## Agent Output\nok\n## Extraction Instruction\nExtract.",
|
||||
schema: metaSchema,
|
||||
provider,
|
||||
cas,
|
||||
});
|
||||
|
||||
expect(result.ok).toBe(true);
|
||||
if (!result.ok) {
|
||||
return;
|
||||
}
|
||||
expect(result.value).toEqual({ seen: "direct" });
|
||||
} finally {
|
||||
await rm(casDir, { recursive: true, force: true });
|
||||
}
|
||||
});
|
||||
});
|
||||
@@ -91,6 +91,7 @@ const refsDemoWorkflow = createWorkflow<RefsDemoMeta>(
|
||||
extractPrompt: "Extract phases with CAS hashes.",
|
||||
schema: plannerMetaSchema,
|
||||
extractRefs: (meta) => meta.phases.map((p) => p.hash),
|
||||
extractMode: "single",
|
||||
},
|
||||
},
|
||||
moderator: (ctx) => (ctx.steps.length === 0 ? "planner" : END),
|
||||
@@ -99,6 +100,7 @@ const refsDemoWorkflow = createWorkflow<RefsDemoMeta>(
|
||||
agent: async () => "plan-output",
|
||||
},
|
||||
refsDemoExtract,
|
||||
null,
|
||||
);
|
||||
|
||||
describe("RoleStep refs tracking", () => {
|
||||
|
||||
@@ -142,12 +142,14 @@ describe("workflowAsAgent integration", () => {
|
||||
extractPrompt: "extract done flag",
|
||||
schema: callerMetaSchema,
|
||||
extractRefs: null,
|
||||
extractMode: "single",
|
||||
},
|
||||
},
|
||||
moderator: (ctx) => (ctx.steps.length === 0 ? "caller" : END),
|
||||
},
|
||||
{ agent: workflowAsAgent("child-wf", { storageRoot: root }) },
|
||||
parentExtract,
|
||||
null,
|
||||
);
|
||||
|
||||
const threadId = "01KQXKW18CT8G75T53R8F4G7YG";
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
import type { ExtractFn } from "./extract-fn.js";
|
||||
import type { CasStore } from "./cas.js";
|
||||
import { buildExtractUserContent, type ExtractFn } from "./extract-fn.js";
|
||||
import { putContentMerkleNode } from "./merkle.js";
|
||||
import { reactExtract } from "./react-extract.js";
|
||||
import { mergeRefsWithContentHash } from "./refs-field.js";
|
||||
import {
|
||||
type AgentBinding,
|
||||
type AgentContext,
|
||||
END,
|
||||
type ExtractContext,
|
||||
type LlmProvider,
|
||||
type ModeratorContext,
|
||||
type RoleDefinition,
|
||||
type RoleMeta,
|
||||
@@ -36,14 +39,51 @@ function resolveExtractedRefs(
|
||||
return extractRefsFn(meta as Record<string, unknown>);
|
||||
}
|
||||
|
||||
async function resolveRoleMeta<M extends RoleMeta>(
|
||||
roleDef: RoleDefinition<Record<string, unknown>>,
|
||||
extractCtx: ExtractContext<M>,
|
||||
extract: ExtractFn,
|
||||
llmProvider: LlmProvider | null,
|
||||
cas: CasStore,
|
||||
): Promise<Record<string, unknown>> {
|
||||
if (roleDef.extractMode === "react") {
|
||||
if (llmProvider === null) {
|
||||
throw new Error(
|
||||
'createWorkflow: llmProvider is required when a role uses extractMode "react"',
|
||||
);
|
||||
}
|
||||
const text = await buildExtractUserContent(
|
||||
extractCtx as unknown as ExtractContext,
|
||||
roleDef.extractPrompt,
|
||||
);
|
||||
const reactResult = await reactExtract({
|
||||
text,
|
||||
schema: roleDef.schema,
|
||||
provider: llmProvider,
|
||||
cas,
|
||||
});
|
||||
if (!reactResult.ok) {
|
||||
throw new Error(`react extract failed: ${reactResult.error}`);
|
||||
}
|
||||
return reactResult.value as Record<string, unknown>;
|
||||
}
|
||||
return (await extract(
|
||||
roleDef.schema,
|
||||
roleDef.extractPrompt,
|
||||
extractCtx as unknown as ExtractContext,
|
||||
)) as Record<string, unknown>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Binds pure role definitions + moderator to runtime agents and structured extraction.
|
||||
* Assign with `export const run = createWorkflow(def, binding, extract)`.
|
||||
* Assign with `export const run = createWorkflow(def, binding, extract, llmProvider)`.
|
||||
* Pass the same {@link LlmProvider} as {@link createExtract} when any role uses `extractMode: "react"`.
|
||||
*/
|
||||
export function createWorkflow<M extends RoleMeta>(
|
||||
def: Pick<WorkflowDefinition<M>, "roles" | "moderator">,
|
||||
binding: AgentBinding,
|
||||
extract: ExtractFn,
|
||||
llmProvider: LlmProvider | null,
|
||||
): WorkflowFn {
|
||||
return async function* workflowLoop(
|
||||
input: ThreadInput,
|
||||
@@ -107,10 +147,12 @@ export function createWorkflow<M extends RoleMeta>(
|
||||
agentContent: raw,
|
||||
};
|
||||
|
||||
const meta = await extract(
|
||||
roleDef.schema,
|
||||
roleDef.extractPrompt,
|
||||
extractCtx as unknown as ExtractContext,
|
||||
const meta = await resolveRoleMeta(
|
||||
roleDef as unknown as RoleDefinition<Record<string, unknown>>,
|
||||
extractCtx,
|
||||
extract,
|
||||
llmProvider,
|
||||
options.cas,
|
||||
);
|
||||
|
||||
const contentHash = await putContentMerkleNode(options.cas, raw);
|
||||
|
||||
@@ -10,6 +10,40 @@ export type ExtractFn = <T extends Record<string, unknown>>(
|
||||
ctx: ExtractContext,
|
||||
) => Promise<T>;
|
||||
|
||||
/** Builds the user-side extraction prompt (thread + agent output + instruction). */
|
||||
export async function buildExtractUserContent(
|
||||
ctx: ExtractContext,
|
||||
prompt: string,
|
||||
): Promise<string> {
|
||||
const lines: string[] = [];
|
||||
lines.push(`## Role: ${ctx.currentRole.name}`);
|
||||
lines.push(ctx.currentRole.systemPrompt);
|
||||
lines.push("");
|
||||
lines.push("## Task");
|
||||
lines.push(ctx.start.content);
|
||||
lines.push("");
|
||||
if (ctx.steps.length > 0) {
|
||||
lines.push("## Thread History");
|
||||
for (const step of ctx.steps) {
|
||||
const body = await getContentMerklePayload(ctx.cas, step.contentHash);
|
||||
if (body === null) {
|
||||
throw new Error(`extract: missing CAS blob for step ${step.role}: ${step.contentHash}`);
|
||||
}
|
||||
lines.push(`### ${step.role}`);
|
||||
lines.push(body);
|
||||
lines.push(`Meta: ${JSON.stringify(step.meta)}`);
|
||||
lines.push("");
|
||||
}
|
||||
}
|
||||
lines.push("## Agent Output");
|
||||
lines.push(ctx.agentContent);
|
||||
lines.push("");
|
||||
lines.push("## Extraction Instruction");
|
||||
lines.push(prompt);
|
||||
|
||||
return lines.join("\n");
|
||||
}
|
||||
|
||||
/**
|
||||
* Create an ExtractFn backed by an LLM provider.
|
||||
* Builds prompt text from {@link ExtractContext} plus `prompt` and calls structured extraction.
|
||||
@@ -20,33 +54,7 @@ export function createExtract(provider: LlmProvider): ExtractFn {
|
||||
prompt: string,
|
||||
ctx: ExtractContext,
|
||||
): Promise<T> => {
|
||||
const lines: string[] = [];
|
||||
lines.push(`## Role: ${ctx.currentRole.name}`);
|
||||
lines.push(ctx.currentRole.systemPrompt);
|
||||
lines.push("");
|
||||
lines.push("## Task");
|
||||
lines.push(ctx.start.content);
|
||||
lines.push("");
|
||||
if (ctx.steps.length > 0) {
|
||||
lines.push("## Thread History");
|
||||
for (const step of ctx.steps) {
|
||||
const body = await getContentMerklePayload(ctx.cas, step.contentHash);
|
||||
if (body === null) {
|
||||
throw new Error(`extract: missing CAS blob for step ${step.role}: ${step.contentHash}`);
|
||||
}
|
||||
lines.push(`### ${step.role}`);
|
||||
lines.push(body);
|
||||
lines.push(`Meta: ${JSON.stringify(step.meta)}`);
|
||||
lines.push("");
|
||||
}
|
||||
}
|
||||
lines.push("## Agent Output");
|
||||
lines.push(ctx.agentContent);
|
||||
lines.push("");
|
||||
lines.push("## Extraction Instruction");
|
||||
lines.push(prompt);
|
||||
|
||||
const text = lines.join("\n");
|
||||
const text = await buildExtractUserContent(ctx, prompt);
|
||||
const result = await llmExtractWithRetry({ text, schema, provider });
|
||||
if (!result.ok) {
|
||||
throw new Error(`extract failed: ${JSON.stringify(result.error)}`);
|
||||
|
||||
@@ -54,6 +54,7 @@ export {
|
||||
serializeMerkleNode,
|
||||
type ThreadMerklePayload,
|
||||
} from "./merkle.js";
|
||||
export { type ReactExtractArgs, reactExtract } from "./react-extract.js";
|
||||
export {
|
||||
type ExtractProviderConfig,
|
||||
getRegisteredWorkflow,
|
||||
@@ -80,6 +81,7 @@ export {
|
||||
type AgentFn,
|
||||
END,
|
||||
type ExtractContext,
|
||||
type ExtractMode,
|
||||
type LlmProvider,
|
||||
type Moderator,
|
||||
type ModeratorContext,
|
||||
|
||||
@@ -47,6 +47,21 @@ function readToolDescription(parametersSchema: Record<string, unknown>): string
|
||||
return "Extract structured data from the input text.";
|
||||
}
|
||||
|
||||
/** Builds OpenAI function-tool metadata from a Zod meta schema (same naming rules as single-shot extract). */
|
||||
export function extractFunctionToolFromZodSchema(schema: z.ZodType<unknown>): {
|
||||
name: string;
|
||||
description: string;
|
||||
parameters: Record<string, unknown>;
|
||||
} {
|
||||
const rawJsonSchema = z.toJSONSchema(schema) as Record<string, unknown>;
|
||||
const parameters = stripJsonSchemaMeta(rawJsonSchema);
|
||||
return {
|
||||
name: readToolName(parameters),
|
||||
description: readToolDescription(parameters),
|
||||
parameters,
|
||||
};
|
||||
}
|
||||
|
||||
function readToolArgumentsJson(parsed: unknown, previewSource: string): Result<string, LlmError> {
|
||||
if (!isRecord(parsed)) {
|
||||
return err({ kind: "invalid_response_json", message: "Top-level JSON is not an object" });
|
||||
@@ -124,10 +139,7 @@ export function llmErrorToCause(error: LlmError): Error {
|
||||
async function performLlmExtract<T>(
|
||||
options: LlmExtractArgs<T> & { userContent: string },
|
||||
): Promise<Result<T, LlmError>> {
|
||||
const rawJsonSchema = z.toJSONSchema(options.schema) as Record<string, unknown>;
|
||||
const parameters = stripJsonSchemaMeta(rawJsonSchema);
|
||||
const toolName = readToolName(parameters);
|
||||
const toolDescription = readToolDescription(parameters);
|
||||
const extractTool = extractFunctionToolFromZodSchema(options.schema);
|
||||
|
||||
const body = {
|
||||
model: options.provider.model,
|
||||
@@ -142,13 +154,13 @@ async function performLlmExtract<T>(
|
||||
{
|
||||
type: "function" as const,
|
||||
function: {
|
||||
name: toolName,
|
||||
description: toolDescription,
|
||||
parameters,
|
||||
name: extractTool.name,
|
||||
description: extractTool.description,
|
||||
parameters: extractTool.parameters,
|
||||
},
|
||||
},
|
||||
],
|
||||
tool_choice: { type: "function" as const, function: { name: toolName } },
|
||||
tool_choice: { type: "function" as const, function: { name: extractTool.name } },
|
||||
};
|
||||
|
||||
let response: Response;
|
||||
|
||||
@@ -0,0 +1,330 @@
|
||||
import type * as z from "zod/v4";
|
||||
|
||||
import type { CasStore } from "./cas.js";
|
||||
import { extractFunctionToolFromZodSchema } from "./llm-extract.js";
|
||||
import { err, ok, type Result } from "./result.js";
|
||||
import type { LlmProvider } from "./types.js";
|
||||
|
||||
export type ReactExtractArgs<T extends Record<string, unknown>> = {
|
||||
text: string;
|
||||
schema: z.ZodType<T>;
|
||||
provider: LlmProvider;
|
||||
cas: CasStore;
|
||||
};
|
||||
|
||||
const MAX_REACT_ROUNDS = 10;
|
||||
|
||||
const CAS_GET_TOOL_DEFINITION = {
|
||||
type: "function" as const,
|
||||
function: {
|
||||
name: "cas_get",
|
||||
description:
|
||||
"Read a Merkle DAG node from content-addressed storage by its hash. Returns YAML-formatted node with type, payload, and children fields.",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
hash: { type: "string", description: "The CAS hash to retrieve" },
|
||||
},
|
||||
required: ["hash"],
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
function chatCompletionsUrl(baseUrl: string): string {
|
||||
const trimmed = baseUrl.replace(/\/+$/, "");
|
||||
return `${trimmed}/chat/completions`;
|
||||
}
|
||||
|
||||
function isRecord(value: unknown): value is Record<string, unknown> {
|
||||
return typeof value === "object" && value !== null && !Array.isArray(value);
|
||||
}
|
||||
|
||||
function tryParseJsonContent(content: string): unknown | null {
|
||||
const trimmed = content.trim();
|
||||
const fenceMatch = /^```(?:json)?\s*([\s\S]*?)```$/m.exec(trimmed);
|
||||
const payload = fenceMatch !== null ? fenceMatch[1].trim() : trimmed;
|
||||
try {
|
||||
return JSON.parse(payload) as unknown;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
type ToolCall = {
|
||||
id: string;
|
||||
type: "function";
|
||||
function: { name: string; arguments: string };
|
||||
};
|
||||
|
||||
type ChatMessage =
|
||||
| { role: "system"; content: string }
|
||||
| { role: "user"; content: string }
|
||||
| {
|
||||
role: "assistant";
|
||||
content: string | null;
|
||||
tool_calls: ToolCall[];
|
||||
}
|
||||
| { role: "tool"; tool_call_id: string; content: string };
|
||||
|
||||
type AssistantTurn<T> =
|
||||
| { kind: "plain_json"; value: T }
|
||||
| { kind: "tool_calls"; calls: ToolCall[]; assistantContent: string | null };
|
||||
|
||||
function firstAssistantMessage(responseText: string): Result<Record<string, unknown>, string> {
|
||||
let parsed: unknown;
|
||||
try {
|
||||
parsed = JSON.parse(responseText) as unknown;
|
||||
} catch (cause) {
|
||||
const message = cause instanceof Error ? cause.message : String(cause);
|
||||
return err(`invalid_response_json:${message}`);
|
||||
}
|
||||
if (!isRecord(parsed)) {
|
||||
return err("invalid_response_top_level");
|
||||
}
|
||||
const choices = parsed.choices;
|
||||
if (!Array.isArray(choices) || choices.length === 0) {
|
||||
return err("no_choices_in_response");
|
||||
}
|
||||
const firstChoice = choices[0];
|
||||
if (!isRecord(firstChoice)) {
|
||||
return err("invalid_choice");
|
||||
}
|
||||
const messageObj = firstChoice.message;
|
||||
if (!isRecord(messageObj)) {
|
||||
return err("invalid_message");
|
||||
}
|
||||
return ok(messageObj);
|
||||
}
|
||||
|
||||
function normalizeToolCalls(toolCallsRaw: unknown[]): Result<ToolCall[], string> {
|
||||
const toolCalls: ToolCall[] = [];
|
||||
for (const tc of toolCallsRaw) {
|
||||
if (!isRecord(tc)) {
|
||||
return err("invalid_tool_call");
|
||||
}
|
||||
const id = tc.id;
|
||||
const tcType = tc.type;
|
||||
const fn = tc.function;
|
||||
if (typeof id !== "string" || tcType !== "function" || !isRecord(fn)) {
|
||||
return err("invalid_tool_call_shape");
|
||||
}
|
||||
const name = fn.name;
|
||||
const argumentsStr = fn.arguments;
|
||||
if (typeof name !== "string" || typeof argumentsStr !== "string") {
|
||||
return err("invalid_tool_call_function");
|
||||
}
|
||||
toolCalls.push({ id, type: "function", function: { name, arguments: argumentsStr } });
|
||||
}
|
||||
return ok(toolCalls);
|
||||
}
|
||||
|
||||
function classifyAssistantTurn<T extends Record<string, unknown>>(
|
||||
messageObj: Record<string, unknown>,
|
||||
schema: z.ZodType<T>,
|
||||
): Result<AssistantTurn<T>, string> {
|
||||
const toolCallsRaw = messageObj.tool_calls;
|
||||
if (!Array.isArray(toolCallsRaw) || toolCallsRaw.length === 0) {
|
||||
const content = messageObj.content;
|
||||
if (typeof content !== "string") {
|
||||
return err("no_tool_calls_and_no_string_content");
|
||||
}
|
||||
const jsonParsed = tryParseJsonContent(content);
|
||||
if (jsonParsed === null) {
|
||||
return err("no_tool_calls_and_content_not_json");
|
||||
}
|
||||
const validated = schema.safeParse(jsonParsed);
|
||||
if (!validated.success) {
|
||||
return err(`schema_validation_failed:${validated.error.message}`);
|
||||
}
|
||||
return ok({ kind: "plain_json", value: validated.data });
|
||||
}
|
||||
const callsResult = normalizeToolCalls(toolCallsRaw);
|
||||
if (!callsResult.ok) {
|
||||
return err(callsResult.error);
|
||||
}
|
||||
const assistantContent = messageObj.content;
|
||||
return ok({
|
||||
kind: "tool_calls",
|
||||
calls: callsResult.value,
|
||||
assistantContent: typeof assistantContent === "string" ? assistantContent : null,
|
||||
});
|
||||
}
|
||||
|
||||
async function appendCasGetToolResult(
|
||||
tc: ToolCall,
|
||||
cas: CasStore,
|
||||
messages: ChatMessage[],
|
||||
): Promise<Result<null, string>> {
|
||||
let hash: string;
|
||||
try {
|
||||
const ta = JSON.parse(tc.function.arguments) as unknown;
|
||||
if (!isRecord(ta) || typeof ta.hash !== "string") {
|
||||
return err("cas_get_invalid_arguments");
|
||||
}
|
||||
hash = ta.hash;
|
||||
} catch {
|
||||
return err("cas_get_arguments_not_json");
|
||||
}
|
||||
const blob = await cas.get(hash);
|
||||
const toolContent = blob === null ? "null" : blob;
|
||||
messages.push({
|
||||
role: "tool",
|
||||
tool_call_id: tc.id,
|
||||
content: toolContent,
|
||||
});
|
||||
return ok(null);
|
||||
}
|
||||
|
||||
async function appendExtractToolResult<T extends Record<string, unknown>>(
|
||||
tc: ToolCall,
|
||||
schema: z.ZodType<T>,
|
||||
messages: ChatMessage[],
|
||||
): Promise<Result<T, string>> {
|
||||
let parsedArgs: unknown;
|
||||
try {
|
||||
parsedArgs = JSON.parse(tc.function.arguments) as unknown;
|
||||
} catch {
|
||||
return err("extract_tool_arguments_not_json");
|
||||
}
|
||||
const validated = schema.safeParse(parsedArgs);
|
||||
if (!validated.success) {
|
||||
return err(`schema_validation_failed:${validated.error.message}`);
|
||||
}
|
||||
messages.push({
|
||||
role: "tool",
|
||||
tool_call_id: tc.id,
|
||||
content: '{"ok":true}',
|
||||
});
|
||||
return ok(validated.data);
|
||||
}
|
||||
|
||||
async function appendToolResults<T extends Record<string, unknown>>(
|
||||
toolCalls: ToolCall[],
|
||||
extractToolName: string,
|
||||
schema: z.ZodType<T>,
|
||||
cas: CasStore,
|
||||
messages: ChatMessage[],
|
||||
): Promise<Result<T | null, string>> {
|
||||
let extracted: T | null = null;
|
||||
for (const tc of toolCalls) {
|
||||
if (tc.function.name === "cas_get") {
|
||||
const casRes = await appendCasGetToolResult(tc, cas, messages);
|
||||
if (!casRes.ok) {
|
||||
return casRes;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (tc.function.name === extractToolName) {
|
||||
const exRes = await appendExtractToolResult(tc, schema, messages);
|
||||
if (!exRes.ok) {
|
||||
return exRes;
|
||||
}
|
||||
extracted = exRes.value;
|
||||
continue;
|
||||
}
|
||||
return err(`unknown_tool:${tc.function.name}`);
|
||||
}
|
||||
return ok(extracted);
|
||||
}
|
||||
|
||||
async function postChatCompletion(
|
||||
provider: LlmProvider,
|
||||
messages: ChatMessage[],
|
||||
tools: readonly Record<string, unknown>[],
|
||||
): Promise<Result<string, string>> {
|
||||
try {
|
||||
const response = await fetch(chatCompletionsUrl(provider.baseUrl), {
|
||||
method: "POST",
|
||||
headers: {
|
||||
Authorization: `Bearer ${provider.apiKey}`,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
model: provider.model,
|
||||
messages,
|
||||
tools,
|
||||
tool_choice: "auto",
|
||||
}),
|
||||
});
|
||||
const responseText = await response.text();
|
||||
if (!response.ok) {
|
||||
return err(`http_error:${String(response.status)}:${responseText.slice(0, 4000)}`);
|
||||
}
|
||||
return ok(responseText);
|
||||
} catch (cause) {
|
||||
const message = cause instanceof Error ? cause.message : String(cause);
|
||||
return err(`network_error:${message}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Multi-turn ReAct extraction with `cas_get` plus a schema-shaped extract tool (OpenAI-compatible).
|
||||
* Final meta comes from a successful extract tool call or from plain JSON in the assistant message.
|
||||
*/
|
||||
export async function reactExtract<T extends Record<string, unknown>>(
|
||||
args: ReactExtractArgs<T>,
|
||||
): Promise<Result<T, string>> {
|
||||
const extractTool = extractFunctionToolFromZodSchema(args.schema);
|
||||
const tools = [
|
||||
CAS_GET_TOOL_DEFINITION,
|
||||
{
|
||||
type: "function" as const,
|
||||
function: {
|
||||
name: extractTool.name,
|
||||
description: extractTool.description,
|
||||
parameters: extractTool.parameters,
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
const systemContent = `You extract structured metadata from the agent output below. Use cas_get to read Merkle DAG nodes from CAS (YAML: type, payload, children) when the agent output references hashes you must traverse. When you have the complete structured object, call the ${extractTool.name} tool with JSON arguments matching the schema. You may instead reply with only a JSON object (no prose) when no tools are needed.`;
|
||||
|
||||
const messages: ChatMessage[] = [
|
||||
{ role: "system", content: systemContent },
|
||||
{ role: "user", content: args.text },
|
||||
];
|
||||
|
||||
for (let round = 0; round < MAX_REACT_ROUNDS; round++) {
|
||||
const bodyResult = await postChatCompletion(args.provider, messages, tools);
|
||||
if (!bodyResult.ok) {
|
||||
return bodyResult;
|
||||
}
|
||||
|
||||
const msgResult = firstAssistantMessage(bodyResult.value);
|
||||
if (!msgResult.ok) {
|
||||
return msgResult;
|
||||
}
|
||||
|
||||
const classified = classifyAssistantTurn(msgResult.value, args.schema);
|
||||
if (!classified.ok) {
|
||||
return classified;
|
||||
}
|
||||
|
||||
const turn = classified.value;
|
||||
if (turn.kind === "plain_json") {
|
||||
return ok(turn.value);
|
||||
}
|
||||
|
||||
messages.push({
|
||||
role: "assistant",
|
||||
content: turn.assistantContent,
|
||||
tool_calls: turn.calls,
|
||||
});
|
||||
|
||||
const toolsRound = await appendToolResults(
|
||||
turn.calls,
|
||||
extractTool.name,
|
||||
args.schema,
|
||||
args.cas,
|
||||
messages,
|
||||
);
|
||||
if (!toolsRound.ok) {
|
||||
return toolsRound;
|
||||
}
|
||||
if (toolsRound.value !== null) {
|
||||
return ok(toolsRound.value);
|
||||
}
|
||||
}
|
||||
|
||||
return err("max_react_rounds_exceeded");
|
||||
}
|
||||
@@ -16,6 +16,9 @@ export type LlmProvider = {
|
||||
model: string;
|
||||
};
|
||||
|
||||
/** How the engine runs meta extraction for a role after the agent phase. */
|
||||
export type ExtractMode = "single" | "react";
|
||||
|
||||
/** What each generator yield produces — one role's output (engine adds `timestamp` when persisting). */
|
||||
export type RoleOutput = {
|
||||
role: string;
|
||||
@@ -121,6 +124,7 @@ export type RoleDefinition<Meta extends Record<string, unknown>> = {
|
||||
schema: z.ZodType<Meta>;
|
||||
/** When non-null, produces CAS hashes to persist on this role's steps (see `RoleOutput.refs`). */
|
||||
extractRefs: ((meta: Meta) => string[]) | null;
|
||||
extractMode: ExtractMode;
|
||||
};
|
||||
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user