Merge pull request 'feat(workflow): ThreadReactor — generic ReAct loop + extract/supervisor migration' (#142) from feat/139-thread-reactor into main
This commit is contained in:
@@ -101,9 +101,9 @@ export function ThreadDetail({ threadId, onBack }: Props) {
|
|||||||
)}
|
)}
|
||||||
{(status === "ok" || liveActive || records.length > 0) && (
|
{(status === "ok" || liveActive || records.length > 0) && (
|
||||||
<div className="space-y-3">
|
<div className="space-y-3">
|
||||||
{records.map((r, i) => (
|
{records.map((r) => (
|
||||||
<div
|
<div
|
||||||
key={i}
|
key={`${threadId}-${r.type}-${String(r.timestamp)}-${r.role ?? ""}-${r.content ?? ""}`}
|
||||||
className="p-3 rounded border text-sm"
|
className="p-3 rounded border text-sm"
|
||||||
style={{ background: "var(--color-surface)", borderColor: "var(--color-border)" }}
|
style={{ background: "var(--color-surface)", borderColor: "var(--color-border)" }}
|
||||||
>
|
>
|
||||||
|
|||||||
@@ -1,11 +1,6 @@
|
|||||||
import { describe, expect, test } from "bun:test";
|
import { describe, expect, test } from "bun:test";
|
||||||
import {
|
import { validateWorkflowDescriptor } from "@uncaged/workflow";
|
||||||
END,
|
import { END, type ModeratorContext, type RoleStep, START } from "@uncaged/workflow-runtime";
|
||||||
type ModeratorContext,
|
|
||||||
type RoleStep,
|
|
||||||
START,
|
|
||||||
validateWorkflowDescriptor,
|
|
||||||
} from "@uncaged/workflow-runtime";
|
|
||||||
import { buildDevelopDescriptor } from "../src/descriptor.js";
|
import { buildDevelopDescriptor } from "../src/descriptor.js";
|
||||||
import { developModerator } from "../src/index.js";
|
import { developModerator } from "../src/index.js";
|
||||||
import type { CommitterMeta, PlannerMeta } from "../src/roles/index.js";
|
import type { CommitterMeta, PlannerMeta } from "../src/roles/index.js";
|
||||||
|
|||||||
@@ -2,14 +2,13 @@ import { afterEach, describe, expect, test } from "bun:test";
|
|||||||
import { mkdtemp, rm } from "node:fs/promises";
|
import { mkdtemp, rm } from "node:fs/promises";
|
||||||
import { tmpdir } from "node:os";
|
import { tmpdir } from "node:os";
|
||||||
import { join } from "node:path";
|
import { join } from "node:path";
|
||||||
import { createCasStore, createExtract, createWorkflow } from "@uncaged/workflow";
|
|
||||||
import {
|
import {
|
||||||
END,
|
createCasStore,
|
||||||
type ModeratorContext,
|
createExtract,
|
||||||
type RoleStep,
|
createWorkflow,
|
||||||
START,
|
|
||||||
validateWorkflowDescriptor,
|
validateWorkflowDescriptor,
|
||||||
} from "@uncaged/workflow-runtime";
|
} from "@uncaged/workflow";
|
||||||
|
import { END, type ModeratorContext, type RoleStep, START } from "@uncaged/workflow-runtime";
|
||||||
import { buildSolveIssueDescriptor } from "../src/descriptor.js";
|
import { buildSolveIssueDescriptor } from "../src/descriptor.js";
|
||||||
import type { DeveloperMeta } from "../src/developer.js";
|
import type { DeveloperMeta } from "../src/developer.js";
|
||||||
import { solveIssueModerator, solveIssueWorkflowDefinition } from "../src/index.js";
|
import { solveIssueModerator, solveIssueWorkflowDefinition } from "../src/index.js";
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ import { createWorkflow, readWorkflowRegistry, executeThread } from "@uncaged/wo
|
|||||||
| **Registry** | `readWorkflowRegistry`, `writeWorkflowRegistry`, `registerWorkflowVersion`, `workflowRegistryPath`, YAML helpers |
|
| **Registry** | `readWorkflowRegistry`, `writeWorkflowRegistry`, `registerWorkflowVersion`, `workflowRegistryPath`, YAML helpers |
|
||||||
| **CAS** | `createCasStore`, Merkle helpers (`putStepMerkleNode`, `getContentMerklePayload`, …), `hashWorkflowBundleBytes` |
|
| **CAS** | `createCasStore`, Merkle helpers (`putStepMerkleNode`, `getContentMerklePayload`, …), `hashWorkflowBundleBytes` |
|
||||||
| **Engine** | `createWorkflow`, `executeThread`, `parseThreadDataJsonl`, fork helpers, `garbageCollectCas` |
|
| **Engine** | `createWorkflow`, `executeThread`, `parseThreadDataJsonl`, fork helpers, `garbageCollectCas` |
|
||||||
| **Extract / LLM tools** | `llmExtract`, `reactExtract`, `createExtract`, `getExtractProvider` |
|
| **Extract / LLM tools** | `llmExtract`, `createExtract`, `createThreadReactor`, `createLlmFn`, `getExtractProvider` |
|
||||||
| **Agent bridge** | `workflowAsAgent` — expose a registered workflow as an agent-backed role |
|
| **Agent bridge** | `workflowAsAgent` — expose a registered workflow as an agent-backed role |
|
||||||
| **Utilities** | `createLogger`, ULID / Crockford Base32 codecs, `getDefaultWorkflowStorageRoot`, paths |
|
| **Utilities** | `createLogger`, ULID / Crockford Base32 codecs, `getDefaultWorkflowStorageRoot`, paths |
|
||||||
|
|
||||||
|
|||||||
@@ -101,10 +101,10 @@ async function writeRegistryYaml(storageRoot: string, yaml: string): Promise<voi
|
|||||||
await writeFile(join(storageRoot, "workflow.yaml"), yaml, "utf8");
|
await writeFile(join(storageRoot, "workflow.yaml"), yaml, "utf8");
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Extract rounds reply with schema-shaped JSON in `content`; supervisor uses plain `content` (no tools advertised). */
|
/** Extract and supervisor both run via {@link createThreadReactor}; differentiate by `body.model`. */
|
||||||
function installMockExtractThenSupervisor(params: {
|
function installMockExtractThenSupervisor(params: {
|
||||||
extractArgs: ReadonlyArray<Record<string, unknown>>;
|
extractArgs: ReadonlyArray<Record<string, unknown>>;
|
||||||
supervisorContent: string;
|
supervisorDecision: "continue" | "stop";
|
||||||
onSupervisorCall?: () => void;
|
onSupervisorCall?: () => void;
|
||||||
}): () => void {
|
}): () => void {
|
||||||
const origFetch = globalThis.fetch;
|
const origFetch = globalThis.fetch;
|
||||||
@@ -114,9 +114,9 @@ function installMockExtractThenSupervisor(params: {
|
|||||||
init?: RequestInit,
|
init?: RequestInit,
|
||||||
): Promise<Response> => {
|
): Promise<Response> => {
|
||||||
const body = init?.body ? (JSON.parse(String(init.body)) as Record<string, unknown>) : {};
|
const body = init?.body ? (JSON.parse(String(init.body)) as Record<string, unknown>) : {};
|
||||||
const tools = body.tools;
|
const model = typeof body.model === "string" ? body.model : "";
|
||||||
const hasTools = Array.isArray(tools) && tools.length > 0;
|
const isSupervisor = model.startsWith("supervisor-");
|
||||||
if (hasTools) {
|
if (!isSupervisor) {
|
||||||
const args =
|
const args =
|
||||||
params.extractArgs[extractI] ?? params.extractArgs[params.extractArgs.length - 1];
|
params.extractArgs[extractI] ?? params.extractArgs[params.extractArgs.length - 1];
|
||||||
if (args === undefined) {
|
if (args === undefined) {
|
||||||
@@ -133,7 +133,9 @@ function installMockExtractThenSupervisor(params: {
|
|||||||
params.onSupervisorCall?.();
|
params.onSupervisorCall?.();
|
||||||
return new Response(
|
return new Response(
|
||||||
JSON.stringify({
|
JSON.stringify({
|
||||||
choices: [{ message: { content: params.supervisorContent } }],
|
choices: [
|
||||||
|
{ message: { content: JSON.stringify({ decision: params.supervisorDecision }) } },
|
||||||
|
],
|
||||||
}),
|
}),
|
||||||
{ status: 200, headers: { "Content-Type": "application/json" } },
|
{ status: 200, headers: { "Content-Type": "application/json" } },
|
||||||
);
|
);
|
||||||
@@ -674,7 +676,7 @@ describe("executeThread", () => {
|
|||||||
test("supervisor stops thread when interval elapses and model returns stop", async () => {
|
test("supervisor stops thread when interval elapses and model returns stop", async () => {
|
||||||
restoreFetch = installMockExtractThenSupervisor({
|
restoreFetch = installMockExtractThenSupervisor({
|
||||||
extractArgs: [{ plan: "do-it", files: ["a.ts"] }, { diff: "+ok" }],
|
extractArgs: [{ plan: "do-it", files: ["a.ts"] }, { diff: "+ok" }],
|
||||||
supervisorContent: "stop",
|
supervisorDecision: "stop",
|
||||||
});
|
});
|
||||||
|
|
||||||
const root = await mkdtemp(join(tmpdir(), "wf-engine-sup-stop-"));
|
const root = await mkdtemp(join(tmpdir(), "wf-engine-sup-stop-"));
|
||||||
@@ -725,7 +727,7 @@ describe("executeThread", () => {
|
|||||||
let supervisorCalls = 0;
|
let supervisorCalls = 0;
|
||||||
restoreFetch = installMockExtractThenSupervisor({
|
restoreFetch = installMockExtractThenSupervisor({
|
||||||
extractArgs: [{ plan: "do-it", files: ["a.ts"] }, { diff: "+ok" }],
|
extractArgs: [{ plan: "do-it", files: ["a.ts"] }, { diff: "+ok" }],
|
||||||
supervisorContent: "stop",
|
supervisorDecision: "stop",
|
||||||
onSupervisorCall: () => {
|
onSupervisorCall: () => {
|
||||||
supervisorCalls += 1;
|
supervisorCalls += 1;
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import { afterEach, describe, expect, test } from "bun:test";
|
import { afterEach, describe, expect, test } from "bun:test";
|
||||||
|
|
||||||
import { parseSupervisorDecisionText, runSupervisor } from "../src/engine/supervisor.js";
|
import { runSupervisor } from "../src/engine/supervisor.js";
|
||||||
import type { WorkflowConfig } from "../src/registry/index.js";
|
import type { WorkflowConfig } from "../src/registry/index.js";
|
||||||
import type { LogFn } from "../src/util/index.js";
|
import type { LogFn } from "../src/util/index.js";
|
||||||
|
|
||||||
@@ -20,28 +20,23 @@ function supervisorOnlyConfig(): WorkflowConfig {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
describe("parseSupervisorDecisionText", () => {
|
function jsonResponse(body: Record<string, unknown>, status = 200): Response {
|
||||||
test("reads continue and stop case-insensitively", () => {
|
return new Response(JSON.stringify(body), {
|
||||||
expect(parseSupervisorDecisionText("continue")).toBe("continue");
|
status,
|
||||||
expect(parseSupervisorDecisionText("CONTINUE")).toBe("continue");
|
headers: { "Content-Type": "application/json" },
|
||||||
expect(parseSupervisorDecisionText("stop")).toBe("stop");
|
|
||||||
expect(parseSupervisorDecisionText("STOP.")).toBe("stop");
|
|
||||||
});
|
});
|
||||||
|
}
|
||||||
|
|
||||||
test("finds token inside a sentence", () => {
|
function installFetchMock(impl: (init?: RequestInit) => Promise<Response>): () => void {
|
||||||
expect(parseSupervisorDecisionText("Answer: continue")).toBe("continue");
|
const origFetch = globalThis.fetch;
|
||||||
expect(parseSupervisorDecisionText("I recommend stop now")).toBe("stop");
|
globalThis.fetch = Object.assign(
|
||||||
});
|
async (_input: Parameters<typeof fetch>[0], init?: RequestInit) => impl(init),
|
||||||
|
{ preconnect: origFetch.preconnect.bind(origFetch) },
|
||||||
test("when both appear, earlier token wins", () => {
|
) as typeof fetch;
|
||||||
expect(parseSupervisorDecisionText("continue then stop")).toBe("continue");
|
return () => {
|
||||||
expect(parseSupervisorDecisionText("stop then continue")).toBe("stop");
|
globalThis.fetch = origFetch;
|
||||||
});
|
};
|
||||||
|
}
|
||||||
test("defaults to continue when unclear", () => {
|
|
||||||
expect(parseSupervisorDecisionText("maybe later")).toBe("continue");
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe("runSupervisor", () => {
|
describe("runSupervisor", () => {
|
||||||
let restoreFetch: (() => void) | null = null;
|
let restoreFetch: (() => void) | null = null;
|
||||||
@@ -52,16 +47,9 @@ describe("runSupervisor", () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
test("returns continue when supervisor model cannot be resolved (no fetch)", async () => {
|
test("returns continue when supervisor model cannot be resolved (no fetch)", async () => {
|
||||||
const origFetch = globalThis.fetch;
|
restoreFetch = installFetchMock(async () => {
|
||||||
restoreFetch = () => {
|
throw new Error("fetch should not run when supervisor is not configured");
|
||||||
globalThis.fetch = origFetch;
|
});
|
||||||
};
|
|
||||||
globalThis.fetch = Object.assign(
|
|
||||||
async () => {
|
|
||||||
throw new Error("fetch should not run when supervisor is not configured");
|
|
||||||
},
|
|
||||||
{ preconnect: origFetch.preconnect.bind(origFetch) },
|
|
||||||
) as typeof fetch;
|
|
||||||
|
|
||||||
const config: WorkflowConfig = {
|
const config: WorkflowConfig = {
|
||||||
maxDepth: 1,
|
maxDepth: 1,
|
||||||
@@ -87,21 +75,27 @@ describe("runSupervisor", () => {
|
|||||||
expect(r.value).toBe("continue");
|
expect(r.value).toBe("continue");
|
||||||
});
|
});
|
||||||
|
|
||||||
test("returns stop from chat/completions assistant content", async () => {
|
test("returns stop from structured tool call", async () => {
|
||||||
const origFetch = globalThis.fetch;
|
restoreFetch = installFetchMock(async () =>
|
||||||
restoreFetch = () => {
|
jsonResponse({
|
||||||
globalThis.fetch = origFetch;
|
choices: [
|
||||||
};
|
{
|
||||||
globalThis.fetch = Object.assign(
|
message: {
|
||||||
async () =>
|
tool_calls: [
|
||||||
new Response(
|
{
|
||||||
JSON.stringify({
|
id: "t1",
|
||||||
choices: [{ message: { content: "stop" } }],
|
type: "function",
|
||||||
}),
|
function: {
|
||||||
{ status: 200, headers: { "Content-Type": "application/json" } },
|
name: "supervisor_decision",
|
||||||
),
|
arguments: JSON.stringify({ decision: "stop" }),
|
||||||
{ preconnect: origFetch.preconnect.bind(origFetch) },
|
},
|
||||||
) as typeof fetch;
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
const r = await runSupervisor({
|
const r = await runSupervisor({
|
||||||
config: supervisorOnlyConfig(),
|
config: supervisorOnlyConfig(),
|
||||||
@@ -116,14 +110,44 @@ describe("runSupervisor", () => {
|
|||||||
expect(r.value).toBe("stop");
|
expect(r.value).toBe("stop");
|
||||||
});
|
});
|
||||||
|
|
||||||
test("returns err on invalid JSON body", async () => {
|
test("returns continue from plain JSON content (reactor short-circuit)", async () => {
|
||||||
const origFetch = globalThis.fetch;
|
restoreFetch = installFetchMock(async () =>
|
||||||
restoreFetch = () => {
|
jsonResponse({
|
||||||
globalThis.fetch = origFetch;
|
choices: [{ message: { content: '{"decision":"continue"}' } }],
|
||||||
};
|
}),
|
||||||
globalThis.fetch = Object.assign(async () => new Response("not-json", { status: 200 }), {
|
);
|
||||||
preconnect: origFetch.preconnect.bind(origFetch),
|
|
||||||
}) as typeof fetch;
|
const r = await runSupervisor({
|
||||||
|
config: supervisorOnlyConfig(),
|
||||||
|
prompt: "do Y",
|
||||||
|
recentSteps: [],
|
||||||
|
logger: noopLogger,
|
||||||
|
});
|
||||||
|
expect(r.ok).toBe(true);
|
||||||
|
if (!r.ok) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
expect(r.value).toBe("continue");
|
||||||
|
});
|
||||||
|
|
||||||
|
test("returns err when reactor cannot validate the schema within max rounds", async () => {
|
||||||
|
restoreFetch = installFetchMock(async () =>
|
||||||
|
jsonResponse({
|
||||||
|
choices: [{ message: { content: "not-json" } }],
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
const r = await runSupervisor({
|
||||||
|
config: supervisorOnlyConfig(),
|
||||||
|
prompt: "p",
|
||||||
|
recentSteps: [],
|
||||||
|
logger: noopLogger,
|
||||||
|
});
|
||||||
|
expect(r.ok).toBe(false);
|
||||||
|
});
|
||||||
|
|
||||||
|
test("returns err on HTTP failure", async () => {
|
||||||
|
restoreFetch = installFetchMock(async () => new Response("boom", { status: 500 }));
|
||||||
|
|
||||||
const r = await runSupervisor({
|
const r = await runSupervisor({
|
||||||
config: supervisorOnlyConfig(),
|
config: supervisorOnlyConfig(),
|
||||||
|
|||||||
+68
-17
@@ -6,7 +6,8 @@ import type { LlmProvider } from "@uncaged/workflow-runtime";
|
|||||||
import * as z from "zod/v4";
|
import * as z from "zod/v4";
|
||||||
import { createCasStore } from "../src/cas/cas.js";
|
import { createCasStore } from "../src/cas/cas.js";
|
||||||
import { createContentMerkleNode, serializeMerkleNode } from "../src/cas/merkle.js";
|
import { createContentMerkleNode, serializeMerkleNode } from "../src/cas/merkle.js";
|
||||||
import { reactExtract } from "../src/extract/react-extract.js";
|
import { extractFunctionToolFromZodSchema } from "../src/extract/llm-extract.js";
|
||||||
|
import { createLlmFn, createThreadReactor } from "../src/reactor/index.js";
|
||||||
|
|
||||||
const metaSchema = z.object({ seen: z.string() });
|
const metaSchema = z.object({ seen: z.string() });
|
||||||
|
|
||||||
@@ -16,7 +17,57 @@ const provider: LlmProvider = {
|
|||||||
model: "test",
|
model: "test",
|
||||||
};
|
};
|
||||||
|
|
||||||
describe("reactExtract", () => {
|
const CAS_GET_TOOL_DEFINITION = {
|
||||||
|
type: "function" as const,
|
||||||
|
function: {
|
||||||
|
name: "cas_get",
|
||||||
|
description: "Read CAS node",
|
||||||
|
parameters: {
|
||||||
|
type: "object",
|
||||||
|
properties: {
|
||||||
|
hash: { type: "string", description: "hash" },
|
||||||
|
},
|
||||||
|
required: ["hash"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
type ThreadCtx = { cas: ReturnType<typeof createCasStore> };
|
||||||
|
|
||||||
|
function createTestReactor() {
|
||||||
|
const llm = createLlmFn(provider);
|
||||||
|
return createThreadReactor<ThreadCtx>({
|
||||||
|
llm,
|
||||||
|
maxRounds: 10,
|
||||||
|
staticTools: [CAS_GET_TOOL_DEFINITION],
|
||||||
|
structuredToolFromSchema: (schema) => {
|
||||||
|
const t = extractFunctionToolFromZodSchema(schema);
|
||||||
|
return {
|
||||||
|
name: t.name,
|
||||||
|
tool: {
|
||||||
|
type: "function" as const,
|
||||||
|
function: {
|
||||||
|
name: t.name,
|
||||||
|
description: t.description,
|
||||||
|
parameters: t.parameters,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
},
|
||||||
|
systemPromptForStructuredTool: (structuredToolName) =>
|
||||||
|
`Extract metadata. Use cas_get when needed. Call ${structuredToolName} with JSON args matching the schema, or reply with plain JSON.`,
|
||||||
|
toolHandler: async (call, thread) => {
|
||||||
|
if (call.function.name !== "cas_get") {
|
||||||
|
return `unexpected tool ${call.function.name}`;
|
||||||
|
}
|
||||||
|
const ta = JSON.parse(call.function.arguments) as { hash: string };
|
||||||
|
const blob = await thread.cas.get(ta.hash);
|
||||||
|
return blob === null ? "null" : blob;
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
describe("createThreadReactor (extract-shaped)", () => {
|
||||||
let restoreFetch: (() => void) | null = null;
|
let restoreFetch: (() => void) | null = null;
|
||||||
|
|
||||||
afterEach(() => {
|
afterEach(() => {
|
||||||
@@ -25,7 +76,7 @@ describe("reactExtract", () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
test("cas_get rounds then extract tool yields validated meta", async () => {
|
test("cas_get rounds then extract tool yields validated meta", async () => {
|
||||||
const casDir = await mkdtemp(join(tmpdir(), "react-extract-"));
|
const casDir = await mkdtemp(join(tmpdir(), "thread-reactor-"));
|
||||||
const cas = createCasStore(casDir);
|
const cas = createCasStore(casDir);
|
||||||
try {
|
try {
|
||||||
const blob = serializeMerkleNode(createContentMerkleNode("needle"));
|
const blob = serializeMerkleNode(createContentMerkleNode("needle"));
|
||||||
@@ -87,12 +138,12 @@ describe("reactExtract", () => {
|
|||||||
{ preconnect: origFetch.preconnect.bind(origFetch) },
|
{ preconnect: origFetch.preconnect.bind(origFetch) },
|
||||||
) as typeof fetch;
|
) as typeof fetch;
|
||||||
|
|
||||||
|
const reactor = createTestReactor();
|
||||||
const text = `## Agent Output\n${h}\n## Extraction Instruction\nExtract seen from CAS.`;
|
const text = `## Agent Output\n${h}\n## Extraction Instruction\nExtract seen from CAS.`;
|
||||||
const result = await reactExtract({
|
const result = await reactor({
|
||||||
text,
|
thread: { cas },
|
||||||
|
input: text,
|
||||||
schema: metaSchema,
|
schema: metaSchema,
|
||||||
provider,
|
|
||||||
cas,
|
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(result.ok).toBe(true);
|
expect(result.ok).toBe(true);
|
||||||
@@ -107,7 +158,7 @@ describe("reactExtract", () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
test("stops after max tool rounds when model keeps calling cas_get", async () => {
|
test("stops after max tool rounds when model keeps calling cas_get", async () => {
|
||||||
const casDir = await mkdtemp(join(tmpdir(), "react-extract-max-"));
|
const casDir = await mkdtemp(join(tmpdir(), "thread-reactor-max-"));
|
||||||
const cas = createCasStore(casDir);
|
const cas = createCasStore(casDir);
|
||||||
try {
|
try {
|
||||||
const blob = serializeMerkleNode(createContentMerkleNode("x"));
|
const blob = serializeMerkleNode(createContentMerkleNode("x"));
|
||||||
@@ -146,11 +197,11 @@ describe("reactExtract", () => {
|
|||||||
{ preconnect: origFetch.preconnect.bind(origFetch) },
|
{ preconnect: origFetch.preconnect.bind(origFetch) },
|
||||||
) as typeof fetch;
|
) as typeof fetch;
|
||||||
|
|
||||||
const result = await reactExtract({
|
const reactor = createTestReactor();
|
||||||
text: "## Agent Output\nnoop\n## Extraction Instruction\nExtract seen.",
|
const result = await reactor({
|
||||||
|
thread: { cas },
|
||||||
|
input: "## Agent Output\nnoop\n## Extraction Instruction\nExtract seen.",
|
||||||
schema: metaSchema,
|
schema: metaSchema,
|
||||||
provider,
|
|
||||||
cas,
|
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(result.ok).toBe(false);
|
expect(result.ok).toBe(false);
|
||||||
@@ -165,7 +216,7 @@ describe("reactExtract", () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
test("passthrough JSON assistant message without tool calls", async () => {
|
test("passthrough JSON assistant message without tool calls", async () => {
|
||||||
const casDir = await mkdtemp(join(tmpdir(), "react-extract-pass-"));
|
const casDir = await mkdtemp(join(tmpdir(), "thread-reactor-pass-"));
|
||||||
const cas = createCasStore(casDir);
|
const cas = createCasStore(casDir);
|
||||||
try {
|
try {
|
||||||
const origFetch = globalThis.fetch;
|
const origFetch = globalThis.fetch;
|
||||||
@@ -189,11 +240,11 @@ describe("reactExtract", () => {
|
|||||||
{ preconnect: origFetch.preconnect.bind(origFetch) },
|
{ preconnect: origFetch.preconnect.bind(origFetch) },
|
||||||
) as typeof fetch;
|
) as typeof fetch;
|
||||||
|
|
||||||
const result = await reactExtract({
|
const reactor = createTestReactor();
|
||||||
text: "## Agent Output\nok\n## Extraction Instruction\nExtract.",
|
const result = await reactor({
|
||||||
|
thread: { cas },
|
||||||
|
input: "## Agent Output\nok\n## Extraction Instruction\nExtract.",
|
||||||
schema: metaSchema,
|
schema: metaSchema,
|
||||||
provider,
|
|
||||||
cas,
|
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(result.ok).toBe(true);
|
expect(result.ok).toBe(true);
|
||||||
@@ -1,67 +1,27 @@
|
|||||||
|
import * as z from "zod/v4";
|
||||||
|
|
||||||
import { resolveModel } from "../config/index.js";
|
import { resolveModel } from "../config/index.js";
|
||||||
|
import { extractFunctionToolFromZodSchema } from "../extract/index.js";
|
||||||
|
import { createLlmFn, createThreadReactor } from "../reactor/index.js";
|
||||||
import type { WorkflowConfig } from "../registry/index.js";
|
import type { WorkflowConfig } from "../registry/index.js";
|
||||||
import { err, type LogFn, ok, type Result } from "../util/index.js";
|
import { err, type LogFn, ok, type Result } from "../util/index.js";
|
||||||
|
|
||||||
import type { SupervisorDecision } from "./types.js";
|
import type { SupervisorDecision } from "./types.js";
|
||||||
|
|
||||||
const SUPERVISOR_RECENT_STEP_LIMIT = 12;
|
const SUPERVISOR_RECENT_STEP_LIMIT = 12;
|
||||||
|
const SUPERVISOR_MAX_REACT_ROUNDS = 4;
|
||||||
|
|
||||||
function chatCompletionsUrl(baseUrl: string): string {
|
const supervisorDecisionSchema = z
|
||||||
const trimmed = baseUrl.replace(/\/+$/, "");
|
.object({
|
||||||
return `${trimmed}/chat/completions`;
|
decision: z.enum(["continue", "stop"]),
|
||||||
}
|
})
|
||||||
|
.meta({
|
||||||
|
title: "supervisor_decision",
|
||||||
|
description:
|
||||||
|
'Workflow supervisor decision. "continue" when the thread is making progress; "stop" when done, looping, or stuck.',
|
||||||
|
});
|
||||||
|
|
||||||
function isRecord(value: unknown): value is Record<string, unknown> {
|
type SupervisorThreadContext = Record<string, never>;
|
||||||
return typeof value === "object" && value !== null && !Array.isArray(value);
|
|
||||||
}
|
|
||||||
|
|
||||||
function readAssistantContent(parsed: unknown): string | null {
|
|
||||||
if (!isRecord(parsed)) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
const choices = parsed.choices;
|
|
||||||
if (!Array.isArray(choices) || choices.length === 0) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
const first = choices[0];
|
|
||||||
if (!isRecord(first)) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
const messageObj = first.message;
|
|
||||||
if (!isRecord(messageObj)) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
const content = messageObj.content;
|
|
||||||
if (typeof content !== "string") {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
return content;
|
|
||||||
}
|
|
||||||
|
|
||||||
/** Lenient: accepts STOP/stop/stop. as prose; prefers {@link SupervisorDecision.stop} when both tokens appear. */
|
|
||||||
export function parseSupervisorDecisionText(text: string): SupervisorDecision {
|
|
||||||
const lower = text.toLowerCase();
|
|
||||||
const stopWord = /\bstop\b/.test(lower);
|
|
||||||
const continueWord = /\bcontinue\b/.test(lower);
|
|
||||||
if (stopWord && continueWord) {
|
|
||||||
const si = lower.search(/\bstop\b/);
|
|
||||||
const ci = lower.search(/\bcontinue\b/);
|
|
||||||
return si <= ci ? "stop" : "continue";
|
|
||||||
}
|
|
||||||
if (stopWord) {
|
|
||||||
return "stop";
|
|
||||||
}
|
|
||||||
if (continueWord) {
|
|
||||||
return "continue";
|
|
||||||
}
|
|
||||||
if (lower.includes("stop")) {
|
|
||||||
return "stop";
|
|
||||||
}
|
|
||||||
if (lower.includes("continue")) {
|
|
||||||
return "continue";
|
|
||||||
}
|
|
||||||
return "continue";
|
|
||||||
}
|
|
||||||
|
|
||||||
type RunSupervisorArgs = {
|
type RunSupervisorArgs = {
|
||||||
config: WorkflowConfig;
|
config: WorkflowConfig;
|
||||||
@@ -70,7 +30,13 @@ type RunSupervisorArgs = {
|
|||||||
logger: LogFn;
|
logger: LogFn;
|
||||||
};
|
};
|
||||||
|
|
||||||
/** Calls the `supervisor` scene LLM; opt-out when {@link resolveModel} fails (returns ok(`continue`)). */
|
function buildSupervisorInput(args: RunSupervisorArgs): string {
|
||||||
|
const recent = args.recentSteps.slice(-SUPERVISOR_RECENT_STEP_LIMIT);
|
||||||
|
const stepsBlock = recent.map((s, index) => `${index + 1}. [${s.role}] ${s.summary}`).join("\n");
|
||||||
|
return `Original task:\n${args.prompt}\n\nRecent steps (oldest first):\n${stepsBlock === "" ? "(none)" : stepsBlock}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Calls the `supervisor` scene via {@link createThreadReactor}; opt-out when {@link resolveModel} fails (returns ok(`continue`)). */
|
||||||
export async function runSupervisor(
|
export async function runSupervisor(
|
||||||
args: RunSupervisorArgs,
|
args: RunSupervisorArgs,
|
||||||
): Promise<Result<SupervisorDecision, string>> {
|
): Promise<Result<SupervisorDecision, string>> {
|
||||||
@@ -78,63 +44,42 @@ export async function runSupervisor(
|
|||||||
if (!resolved.ok) {
|
if (!resolved.ok) {
|
||||||
return ok("continue");
|
return ok("continue");
|
||||||
}
|
}
|
||||||
const provider = resolved.value;
|
|
||||||
const recent = args.recentSteps.slice(-SUPERVISOR_RECENT_STEP_LIMIT);
|
|
||||||
const stepsBlock = recent.map((s, index) => `${index + 1}. [${s.role}] ${s.summary}`).join("\n");
|
|
||||||
|
|
||||||
const body = {
|
const reactor = createThreadReactor<SupervisorThreadContext>({
|
||||||
model: provider.model,
|
llm: createLlmFn(resolved.value),
|
||||||
messages: [
|
maxRounds: SUPERVISOR_MAX_REACT_ROUNDS,
|
||||||
{
|
staticTools: [],
|
||||||
role: "system" as const,
|
structuredToolFromSchema: (schema) => {
|
||||||
content:
|
const t = extractFunctionToolFromZodSchema(schema);
|
||||||
'You supervise a multi-step workflow. Decide if the thread should keep running or halt.\n\nReply with exactly one token: either "continue" (progress toward the goal, not obviously stuck) or "stop" (done, looping, or no progress). Do not add explanation.',
|
return {
|
||||||
},
|
name: t.name,
|
||||||
{
|
tool: {
|
||||||
role: "user" as const,
|
type: "function" as const,
|
||||||
content: `Original task:\n${args.prompt}\n\nRecent steps (oldest first):\n${stepsBlock === "" ? "(none)" : stepsBlock}`,
|
function: {
|
||||||
},
|
name: t.name,
|
||||||
],
|
description: t.description,
|
||||||
};
|
parameters: t.parameters,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
},
|
||||||
|
systemPromptForStructuredTool: (structuredToolName) =>
|
||||||
|
`You supervise a multi-step workflow. Decide whether the thread should keep running or halt. Reply with "continue" when the thread is making progress toward the task, or "stop" when it is finished, looping, or no longer making progress. Call the ${structuredToolName} tool with JSON arguments matching the schema, or reply with only a JSON object such as {"decision":"stop"}.`,
|
||||||
|
toolHandler: async (call) => `Unknown tool: ${call.function.name}`,
|
||||||
|
});
|
||||||
|
|
||||||
let response: Response;
|
const result = await reactor({
|
||||||
try {
|
thread: {} as SupervisorThreadContext,
|
||||||
response = await fetch(chatCompletionsUrl(provider.baseUrl), {
|
input: buildSupervisorInput(args),
|
||||||
method: "POST",
|
schema: supervisorDecisionSchema,
|
||||||
headers: {
|
});
|
||||||
Authorization: `Bearer ${provider.apiKey}`,
|
|
||||||
"Content-Type": "application/json",
|
if (!result.ok) {
|
||||||
},
|
args.logger("R9CW4PLM", `supervisor failed: ${result.error}`);
|
||||||
body: JSON.stringify(body),
|
return err(`supervisor: ${result.error}`);
|
||||||
});
|
|
||||||
} catch (cause) {
|
|
||||||
const message = cause instanceof Error ? cause.message : String(cause);
|
|
||||||
args.logger("R9CW4PLM", `supervisor request failed: ${message}`);
|
|
||||||
return err(`supervisor network error: ${message}`);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const responseText = await response.text();
|
const decision: SupervisorDecision = result.value.decision;
|
||||||
if (!response.ok) {
|
|
||||||
args.logger("T3HN8VKQ", `supervisor HTTP ${response.status}: ${responseText.slice(0, 200)}`);
|
|
||||||
return err(`supervisor HTTP ${response.status}: ${responseText.slice(0, 500)}`);
|
|
||||||
}
|
|
||||||
|
|
||||||
let parsed: unknown;
|
|
||||||
try {
|
|
||||||
parsed = JSON.parse(responseText) as unknown;
|
|
||||||
} catch (cause) {
|
|
||||||
const message = cause instanceof Error ? cause.message : String(cause);
|
|
||||||
args.logger("W7BQ2NXM", `supervisor response is not JSON: ${message}`);
|
|
||||||
return err(`supervisor invalid JSON: ${message}`);
|
|
||||||
}
|
|
||||||
|
|
||||||
const content = readAssistantContent(parsed);
|
|
||||||
if (content === null || content.trim() === "") {
|
|
||||||
args.logger("Y4JX9PKW", "supervisor returned empty assistant content");
|
|
||||||
return err("supervisor empty assistant content");
|
|
||||||
}
|
|
||||||
|
|
||||||
const decision = parseSupervisorDecisionText(content);
|
|
||||||
args.logger("Z8KM5QWT", `supervisor says ${decision}`);
|
args.logger("Z8KM5QWT", `supervisor says ${decision}`);
|
||||||
return ok(decision);
|
return ok(decision);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +1,39 @@
|
|||||||
import type { ExtractContext, ExtractFn, LlmProvider } from "@uncaged/workflow-runtime";
|
import type { ExtractContext, ExtractFn, LlmProvider } from "@uncaged/workflow-runtime";
|
||||||
import type * as z from "zod/v4";
|
import type * as z from "zod/v4";
|
||||||
import { type CasStore, getContentMerklePayload } from "../cas/index.js";
|
import { type CasStore, getContentMerklePayload } from "../cas/index.js";
|
||||||
import { reactExtract } from "./react-extract.js";
|
import { createLlmFn, createThreadReactor } from "../reactor/index.js";
|
||||||
|
import { extractFunctionToolFromZodSchema } from "./llm-extract.js";
|
||||||
|
|
||||||
export type ExtractDeps = {
|
export type ExtractDeps = {
|
||||||
cas: CasStore;
|
cas: CasStore;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const MAX_REACT_ROUNDS = 10;
|
||||||
|
|
||||||
|
const CAS_GET_TOOL_DEFINITION = {
|
||||||
|
type: "function" as const,
|
||||||
|
function: {
|
||||||
|
name: "cas_get",
|
||||||
|
description:
|
||||||
|
"Read a Merkle DAG node from content-addressed storage by its hash. Returns YAML-formatted node with type, payload, and children fields.",
|
||||||
|
parameters: {
|
||||||
|
type: "object",
|
||||||
|
properties: {
|
||||||
|
hash: { type: "string", description: "The CAS hash to retrieve" },
|
||||||
|
},
|
||||||
|
required: ["hash"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
export type ExtractThreadContext = {
|
||||||
|
cas: CasStore;
|
||||||
|
};
|
||||||
|
|
||||||
|
function isRecord(value: unknown): value is Record<string, unknown> {
|
||||||
|
return typeof value === "object" && value !== null && !Array.isArray(value);
|
||||||
|
}
|
||||||
|
|
||||||
/** Builds the user-side extraction prompt (thread + agent output + instruction). */
|
/** Builds the user-side extraction prompt (thread + agent output + instruction). */
|
||||||
export async function buildExtractUserContent(
|
export async function buildExtractUserContent(
|
||||||
ctx: ExtractContext,
|
ctx: ExtractContext,
|
||||||
@@ -46,17 +73,61 @@ export async function buildExtractUserContent(
|
|||||||
* Create an ExtractFn backed by an LLM provider.
|
* Create an ExtractFn backed by an LLM provider.
|
||||||
*
|
*
|
||||||
* Internally runs a multi-turn ReAct loop with two tools (`cas_get` for traversing the
|
* Internally runs a multi-turn ReAct loop with two tools (`cas_get` for traversing the
|
||||||
* Merkle DAG and a schema-shaped `extract` tool); the loop also accepts a plain-JSON
|
* Merkle DAG and a schema-shaped extract tool); the loop also accepts a plain-JSON
|
||||||
* assistant reply as a short-circuit, which covers the legacy "single" extraction path.
|
* assistant reply as a short-circuit, which covers the legacy "single" extraction path.
|
||||||
*/
|
*/
|
||||||
export function createExtract(provider: LlmProvider, deps: ExtractDeps): ExtractFn {
|
export function createExtract(provider: LlmProvider, deps: ExtractDeps): ExtractFn {
|
||||||
|
const llm = createLlmFn(provider);
|
||||||
|
const reactor = createThreadReactor<ExtractThreadContext>({
|
||||||
|
llm,
|
||||||
|
maxRounds: MAX_REACT_ROUNDS,
|
||||||
|
staticTools: [CAS_GET_TOOL_DEFINITION],
|
||||||
|
structuredToolFromSchema: (schema) => {
|
||||||
|
const t = extractFunctionToolFromZodSchema(schema);
|
||||||
|
return {
|
||||||
|
name: t.name,
|
||||||
|
tool: {
|
||||||
|
type: "function" as const,
|
||||||
|
function: {
|
||||||
|
name: t.name,
|
||||||
|
description: t.description,
|
||||||
|
parameters: t.parameters,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
},
|
||||||
|
systemPromptForStructuredTool: (structuredToolName) =>
|
||||||
|
`You extract structured metadata from the agent output below. Use cas_get to read Merkle DAG nodes from CAS (YAML: type, payload, children) when the agent output references hashes you must traverse. When you have the complete structured object, call the ${structuredToolName} tool with JSON arguments matching the schema. You may instead reply with only a JSON object (no prose) when no tools are needed.`,
|
||||||
|
toolHandler: async (call, thread) => {
|
||||||
|
if (call.function.name !== "cas_get") {
|
||||||
|
return `Unexpected tool routed to handler: ${call.function.name}`;
|
||||||
|
}
|
||||||
|
let hash: string;
|
||||||
|
try {
|
||||||
|
const ta = JSON.parse(call.function.arguments) as unknown;
|
||||||
|
if (!isRecord(ta) || typeof ta.hash !== "string") {
|
||||||
|
return 'cas_get requires a JSON object with a string "hash" field.';
|
||||||
|
}
|
||||||
|
hash = ta.hash;
|
||||||
|
} catch {
|
||||||
|
return 'cas_get arguments were not valid JSON. Provide {"hash": "<cas-hash>"}.';
|
||||||
|
}
|
||||||
|
const blob = await thread.cas.get(hash);
|
||||||
|
return blob === null ? "null" : blob;
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
return async <T extends Record<string, unknown>>(
|
return async <T extends Record<string, unknown>>(
|
||||||
schema: z.ZodType<T>,
|
schema: z.ZodType<T>,
|
||||||
prompt: string,
|
prompt: string,
|
||||||
ctx: ExtractContext,
|
ctx: ExtractContext,
|
||||||
): Promise<T> => {
|
): Promise<T> => {
|
||||||
const text = await buildExtractUserContent(ctx, prompt, deps);
|
const text = await buildExtractUserContent(ctx, prompt, deps);
|
||||||
const result = await reactExtract({ text, schema, provider, cas: deps.cas });
|
const result = await reactor({
|
||||||
|
thread: { cas: deps.cas },
|
||||||
|
input: text,
|
||||||
|
schema,
|
||||||
|
});
|
||||||
if (!result.ok) {
|
if (!result.ok) {
|
||||||
throw new Error(`extract failed: ${result.error}`);
|
throw new Error(`extract failed: ${result.error}`);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,16 +1,11 @@
|
|||||||
export {
|
export {
|
||||||
buildExtractUserContent,
|
buildExtractUserContent,
|
||||||
createExtract,
|
createExtract,
|
||||||
|
type ExtractThreadContext,
|
||||||
} from "./extract-fn.js";
|
} from "./extract-fn.js";
|
||||||
export {
|
export {
|
||||||
extractFunctionToolFromZodSchema,
|
extractFunctionToolFromZodSchema,
|
||||||
llmErrorToCause,
|
llmErrorToCause,
|
||||||
llmExtract,
|
llmExtract,
|
||||||
} from "./llm-extract.js";
|
} from "./llm-extract.js";
|
||||||
export { reactExtract } from "./react-extract.js";
|
export type { ExtractFn, LlmError, LlmExtractArgs } from "./types.js";
|
||||||
export type {
|
|
||||||
ExtractFn,
|
|
||||||
LlmError,
|
|
||||||
LlmExtractArgs,
|
|
||||||
ReactExtractArgs,
|
|
||||||
} from "./types.js";
|
|
||||||
|
|||||||
@@ -1,343 +0,0 @@
|
|||||||
import type { CasStore, LlmProvider } from "@uncaged/workflow-runtime";
|
|
||||||
import type * as z from "zod/v4";
|
|
||||||
import { err, ok, type Result } from "../util/index.js";
|
|
||||||
|
|
||||||
import { extractFunctionToolFromZodSchema } from "./llm-extract.js";
|
|
||||||
import type { ReactExtractArgs } from "./types.js";
|
|
||||||
|
|
||||||
const MAX_REACT_ROUNDS = 10;
|
|
||||||
|
|
||||||
const CAS_GET_TOOL_DEFINITION = {
|
|
||||||
type: "function" as const,
|
|
||||||
function: {
|
|
||||||
name: "cas_get",
|
|
||||||
description:
|
|
||||||
"Read a Merkle DAG node from content-addressed storage by its hash. Returns YAML-formatted node with type, payload, and children fields.",
|
|
||||||
parameters: {
|
|
||||||
type: "object",
|
|
||||||
properties: {
|
|
||||||
hash: { type: "string", description: "The CAS hash to retrieve" },
|
|
||||||
},
|
|
||||||
required: ["hash"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
function chatCompletionsUrl(baseUrl: string): string {
|
|
||||||
const trimmed = baseUrl.replace(/\/+$/, "");
|
|
||||||
return `${trimmed}/chat/completions`;
|
|
||||||
}
|
|
||||||
|
|
||||||
function isRecord(value: unknown): value is Record<string, unknown> {
|
|
||||||
return typeof value === "object" && value !== null && !Array.isArray(value);
|
|
||||||
}
|
|
||||||
|
|
||||||
function tryParseJsonContent(content: string): unknown | null {
|
|
||||||
const trimmed = content.trim();
|
|
||||||
const fenceMatch = /^```(?:json)?\s*([\s\S]*?)```$/m.exec(trimmed);
|
|
||||||
const payload = fenceMatch !== null ? fenceMatch[1].trim() : trimmed;
|
|
||||||
try {
|
|
||||||
return JSON.parse(payload) as unknown;
|
|
||||||
} catch {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type ToolCall = {
|
|
||||||
id: string;
|
|
||||||
type: "function";
|
|
||||||
function: { name: string; arguments: string };
|
|
||||||
};
|
|
||||||
|
|
||||||
type ChatMessage =
|
|
||||||
| { role: "system"; content: string }
|
|
||||||
| { role: "user"; content: string }
|
|
||||||
| {
|
|
||||||
role: "assistant";
|
|
||||||
content: string | null;
|
|
||||||
tool_calls: ToolCall[];
|
|
||||||
}
|
|
||||||
| { role: "assistant"; content: string }
|
|
||||||
| { role: "tool"; tool_call_id: string; content: string };
|
|
||||||
|
|
||||||
type AssistantTurn<T> =
|
|
||||||
| { kind: "plain_json"; value: T }
|
|
||||||
| { kind: "tool_calls"; calls: ToolCall[]; assistantContent: string | null };
|
|
||||||
|
|
||||||
function firstAssistantMessage(responseText: string): Result<Record<string, unknown>, string> {
|
|
||||||
let parsed: unknown;
|
|
||||||
try {
|
|
||||||
parsed = JSON.parse(responseText) as unknown;
|
|
||||||
} catch (cause) {
|
|
||||||
const message = cause instanceof Error ? cause.message : String(cause);
|
|
||||||
return err(`invalid_response_json:${message}`);
|
|
||||||
}
|
|
||||||
if (!isRecord(parsed)) {
|
|
||||||
return err("invalid_response_top_level");
|
|
||||||
}
|
|
||||||
const choices = parsed.choices;
|
|
||||||
if (!Array.isArray(choices) || choices.length === 0) {
|
|
||||||
return err("no_choices_in_response");
|
|
||||||
}
|
|
||||||
const firstChoice = choices[0];
|
|
||||||
if (!isRecord(firstChoice)) {
|
|
||||||
return err("invalid_choice");
|
|
||||||
}
|
|
||||||
const messageObj = firstChoice.message;
|
|
||||||
if (!isRecord(messageObj)) {
|
|
||||||
return err("invalid_message");
|
|
||||||
}
|
|
||||||
return ok(messageObj);
|
|
||||||
}
|
|
||||||
|
|
||||||
function normalizeToolCalls(toolCallsRaw: unknown[]): Result<ToolCall[], string> {
|
|
||||||
const toolCalls: ToolCall[] = [];
|
|
||||||
for (const tc of toolCallsRaw) {
|
|
||||||
if (!isRecord(tc)) {
|
|
||||||
return err("invalid_tool_call");
|
|
||||||
}
|
|
||||||
const id = tc.id;
|
|
||||||
const tcType = tc.type;
|
|
||||||
const fn = tc.function;
|
|
||||||
if (typeof id !== "string" || tcType !== "function" || !isRecord(fn)) {
|
|
||||||
return err("invalid_tool_call_shape");
|
|
||||||
}
|
|
||||||
const name = fn.name;
|
|
||||||
const argumentsStr = fn.arguments;
|
|
||||||
if (typeof name !== "string" || typeof argumentsStr !== "string") {
|
|
||||||
return err("invalid_tool_call_function");
|
|
||||||
}
|
|
||||||
toolCalls.push({ id, type: "function", function: { name, arguments: argumentsStr } });
|
|
||||||
}
|
|
||||||
return ok(toolCalls);
|
|
||||||
}
|
|
||||||
|
|
||||||
type AssistantTurnOrCorrection<T extends Record<string, unknown>> =
|
|
||||||
| AssistantTurn<T>
|
|
||||||
| { kind: "plain_json_invalid"; rawContent: string; correction: string };
|
|
||||||
|
|
||||||
function classifyAssistantTurn<T extends Record<string, unknown>>(
|
|
||||||
messageObj: Record<string, unknown>,
|
|
||||||
schema: z.ZodType<T>,
|
|
||||||
): Result<AssistantTurnOrCorrection<T>, string> {
|
|
||||||
const toolCallsRaw = messageObj.tool_calls;
|
|
||||||
if (!Array.isArray(toolCallsRaw) || toolCallsRaw.length === 0) {
|
|
||||||
const content = messageObj.content;
|
|
||||||
if (typeof content !== "string") {
|
|
||||||
return err("no_tool_calls_and_no_string_content");
|
|
||||||
}
|
|
||||||
const jsonParsed = tryParseJsonContent(content);
|
|
||||||
if (jsonParsed === null) {
|
|
||||||
return ok({
|
|
||||||
kind: "plain_json_invalid",
|
|
||||||
rawContent: content,
|
|
||||||
correction:
|
|
||||||
"Your previous reply was not valid JSON and contained no tool calls. Reply with a single JSON object that matches the schema, or call the extract tool with the structured arguments.",
|
|
||||||
});
|
|
||||||
}
|
|
||||||
const validated = schema.safeParse(jsonParsed);
|
|
||||||
if (!validated.success) {
|
|
||||||
return ok({
|
|
||||||
kind: "plain_json_invalid",
|
|
||||||
rawContent: content,
|
|
||||||
correction: `Your previous JSON reply did not satisfy the schema: ${validated.error.message}. Reply again with a JSON object that matches the schema, or call the extract tool with the structured arguments.`,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
return ok({ kind: "plain_json", value: validated.data });
|
|
||||||
}
|
|
||||||
const callsResult = normalizeToolCalls(toolCallsRaw);
|
|
||||||
if (!callsResult.ok) {
|
|
||||||
return err(callsResult.error);
|
|
||||||
}
|
|
||||||
const assistantContent = messageObj.content;
|
|
||||||
return ok({
|
|
||||||
kind: "tool_calls",
|
|
||||||
calls: callsResult.value,
|
|
||||||
assistantContent: typeof assistantContent === "string" ? assistantContent : null,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
async function appendCasGetToolResult(
|
|
||||||
tc: ToolCall,
|
|
||||||
cas: CasStore,
|
|
||||||
messages: ChatMessage[],
|
|
||||||
): Promise<Result<null, string>> {
|
|
||||||
let hash: string;
|
|
||||||
try {
|
|
||||||
const ta = JSON.parse(tc.function.arguments) as unknown;
|
|
||||||
if (!isRecord(ta) || typeof ta.hash !== "string") {
|
|
||||||
return err("cas_get_invalid_arguments");
|
|
||||||
}
|
|
||||||
hash = ta.hash;
|
|
||||||
} catch {
|
|
||||||
return err("cas_get_arguments_not_json");
|
|
||||||
}
|
|
||||||
const blob = await cas.get(hash);
|
|
||||||
const toolContent = blob === null ? "null" : blob;
|
|
||||||
messages.push({
|
|
||||||
role: "tool",
|
|
||||||
tool_call_id: tc.id,
|
|
||||||
content: toolContent,
|
|
||||||
});
|
|
||||||
return ok(null);
|
|
||||||
}
|
|
||||||
|
|
||||||
async function appendExtractToolResult<T extends Record<string, unknown>>(
|
|
||||||
tc: ToolCall,
|
|
||||||
schema: z.ZodType<T>,
|
|
||||||
messages: ChatMessage[],
|
|
||||||
): Promise<Result<T, string>> {
|
|
||||||
let parsedArgs: unknown;
|
|
||||||
try {
|
|
||||||
parsedArgs = JSON.parse(tc.function.arguments) as unknown;
|
|
||||||
} catch {
|
|
||||||
return err("extract_tool_arguments_not_json");
|
|
||||||
}
|
|
||||||
const validated = schema.safeParse(parsedArgs);
|
|
||||||
if (!validated.success) {
|
|
||||||
return err(`schema_validation_failed:${validated.error.message}`);
|
|
||||||
}
|
|
||||||
messages.push({
|
|
||||||
role: "tool",
|
|
||||||
tool_call_id: tc.id,
|
|
||||||
content: '{"ok":true}',
|
|
||||||
});
|
|
||||||
return ok(validated.data);
|
|
||||||
}
|
|
||||||
|
|
||||||
async function appendToolResults<T extends Record<string, unknown>>(
|
|
||||||
toolCalls: ToolCall[],
|
|
||||||
extractToolName: string,
|
|
||||||
schema: z.ZodType<T>,
|
|
||||||
cas: CasStore,
|
|
||||||
messages: ChatMessage[],
|
|
||||||
): Promise<Result<T | null, string>> {
|
|
||||||
let extracted: T | null = null;
|
|
||||||
for (const tc of toolCalls) {
|
|
||||||
if (tc.function.name === "cas_get") {
|
|
||||||
const casRes = await appendCasGetToolResult(tc, cas, messages);
|
|
||||||
if (!casRes.ok) {
|
|
||||||
return casRes;
|
|
||||||
}
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (tc.function.name === extractToolName) {
|
|
||||||
const exRes = await appendExtractToolResult(tc, schema, messages);
|
|
||||||
if (!exRes.ok) {
|
|
||||||
return exRes;
|
|
||||||
}
|
|
||||||
extracted = exRes.value;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
return err(`unknown_tool:${tc.function.name}`);
|
|
||||||
}
|
|
||||||
return ok(extracted);
|
|
||||||
}
|
|
||||||
|
|
||||||
async function postChatCompletion(
|
|
||||||
provider: LlmProvider,
|
|
||||||
messages: ChatMessage[],
|
|
||||||
tools: readonly Record<string, unknown>[],
|
|
||||||
): Promise<Result<string, string>> {
|
|
||||||
try {
|
|
||||||
const response = await fetch(chatCompletionsUrl(provider.baseUrl), {
|
|
||||||
method: "POST",
|
|
||||||
headers: {
|
|
||||||
Authorization: `Bearer ${provider.apiKey}`,
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
},
|
|
||||||
body: JSON.stringify({
|
|
||||||
model: provider.model,
|
|
||||||
messages,
|
|
||||||
tools,
|
|
||||||
tool_choice: "auto",
|
|
||||||
}),
|
|
||||||
});
|
|
||||||
const responseText = await response.text();
|
|
||||||
if (!response.ok) {
|
|
||||||
return err(`http_error:${String(response.status)}:${responseText.slice(0, 4000)}`);
|
|
||||||
}
|
|
||||||
return ok(responseText);
|
|
||||||
} catch (cause) {
|
|
||||||
const message = cause instanceof Error ? cause.message : String(cause);
|
|
||||||
return err(`network_error:${message}`);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Multi-turn ReAct extraction with `cas_get` plus a schema-shaped extract tool (OpenAI-compatible).
|
|
||||||
* Final meta comes from a successful extract tool call or from plain JSON in the assistant message.
|
|
||||||
*/
|
|
||||||
export async function reactExtract<T extends Record<string, unknown>>(
|
|
||||||
args: ReactExtractArgs<T>,
|
|
||||||
): Promise<Result<T, string>> {
|
|
||||||
const extractTool = extractFunctionToolFromZodSchema(args.schema);
|
|
||||||
const tools = [
|
|
||||||
CAS_GET_TOOL_DEFINITION,
|
|
||||||
{
|
|
||||||
type: "function" as const,
|
|
||||||
function: {
|
|
||||||
name: extractTool.name,
|
|
||||||
description: extractTool.description,
|
|
||||||
parameters: extractTool.parameters,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
];
|
|
||||||
|
|
||||||
const systemContent = `You extract structured metadata from the agent output below. Use cas_get to read Merkle DAG nodes from CAS (YAML: type, payload, children) when the agent output references hashes you must traverse. When you have the complete structured object, call the ${extractTool.name} tool with JSON arguments matching the schema. You may instead reply with only a JSON object (no prose) when no tools are needed.`;
|
|
||||||
|
|
||||||
const messages: ChatMessage[] = [
|
|
||||||
{ role: "system", content: systemContent },
|
|
||||||
{ role: "user", content: args.text },
|
|
||||||
];
|
|
||||||
|
|
||||||
for (let round = 0; round < MAX_REACT_ROUNDS; round++) {
|
|
||||||
const bodyResult = await postChatCompletion(args.provider, messages, tools);
|
|
||||||
if (!bodyResult.ok) {
|
|
||||||
return bodyResult;
|
|
||||||
}
|
|
||||||
|
|
||||||
const msgResult = firstAssistantMessage(bodyResult.value);
|
|
||||||
if (!msgResult.ok) {
|
|
||||||
return msgResult;
|
|
||||||
}
|
|
||||||
|
|
||||||
const classified = classifyAssistantTurn(msgResult.value, args.schema);
|
|
||||||
if (!classified.ok) {
|
|
||||||
return classified;
|
|
||||||
}
|
|
||||||
|
|
||||||
const turn = classified.value;
|
|
||||||
if (turn.kind === "plain_json") {
|
|
||||||
return ok(turn.value);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (turn.kind === "plain_json_invalid") {
|
|
||||||
messages.push({ role: "assistant", content: turn.rawContent });
|
|
||||||
messages.push({ role: "user", content: turn.correction });
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
messages.push({
|
|
||||||
role: "assistant",
|
|
||||||
content: turn.assistantContent,
|
|
||||||
tool_calls: turn.calls,
|
|
||||||
});
|
|
||||||
|
|
||||||
const toolsRound = await appendToolResults(
|
|
||||||
turn.calls,
|
|
||||||
extractTool.name,
|
|
||||||
args.schema,
|
|
||||||
args.cas,
|
|
||||||
messages,
|
|
||||||
);
|
|
||||||
if (!toolsRound.ok) {
|
|
||||||
return toolsRound;
|
|
||||||
}
|
|
||||||
if (toolsRound.value !== null) {
|
|
||||||
return ok(toolsRound.value);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return err("max_react_rounds_exceeded");
|
|
||||||
}
|
|
||||||
@@ -1,15 +1,8 @@
|
|||||||
import type { CasStore, LlmProvider } from "@uncaged/workflow-runtime";
|
import type { LlmProvider } from "@uncaged/workflow-runtime";
|
||||||
import type * as z from "zod/v4";
|
import type * as z from "zod/v4";
|
||||||
|
|
||||||
export type { ExtractFn } from "@uncaged/workflow-runtime";
|
export type { ExtractFn } from "@uncaged/workflow-runtime";
|
||||||
|
|
||||||
export type ReactExtractArgs<T extends Record<string, unknown>> = {
|
|
||||||
text: string;
|
|
||||||
schema: z.ZodType<T>;
|
|
||||||
provider: LlmProvider;
|
|
||||||
cas: CasStore;
|
|
||||||
};
|
|
||||||
|
|
||||||
export type LlmExtractArgs<T> = {
|
export type LlmExtractArgs<T> = {
|
||||||
text: string;
|
text: string;
|
||||||
schema: z.ZodType<T>;
|
schema: z.ZodType<T>;
|
||||||
|
|||||||
@@ -56,12 +56,23 @@ export {
|
|||||||
export {
|
export {
|
||||||
createExtract,
|
createExtract,
|
||||||
type ExtractFn,
|
type ExtractFn,
|
||||||
|
type ExtractThreadContext,
|
||||||
type LlmError,
|
type LlmError,
|
||||||
llmErrorToCause,
|
llmErrorToCause,
|
||||||
llmExtract,
|
llmExtract,
|
||||||
type ReactExtractArgs,
|
|
||||||
reactExtract,
|
|
||||||
} from "./extract/index.js";
|
} from "./extract/index.js";
|
||||||
|
export {
|
||||||
|
type ChatMessage,
|
||||||
|
createLlmFn,
|
||||||
|
createThreadReactor,
|
||||||
|
type LlmFn,
|
||||||
|
type StructuredToolSpec,
|
||||||
|
type ThreadReactorConfig,
|
||||||
|
type ThreadReactorFn,
|
||||||
|
type ThreadReactorInvokeArgs,
|
||||||
|
type ToolCall,
|
||||||
|
type ToolDefinition,
|
||||||
|
} from "./reactor/index.js";
|
||||||
export {
|
export {
|
||||||
getRegisteredWorkflow,
|
getRegisteredWorkflow,
|
||||||
listRegisteredWorkflowNames,
|
listRegisteredWorkflowNames,
|
||||||
|
|||||||
@@ -0,0 +1,12 @@
|
|||||||
|
export { createLlmFn } from "./llm-fn.js";
|
||||||
|
export { createThreadReactor } from "./thread-reactor.js";
|
||||||
|
export type {
|
||||||
|
ChatMessage,
|
||||||
|
LlmFn,
|
||||||
|
StructuredToolSpec,
|
||||||
|
ThreadReactorConfig,
|
||||||
|
ThreadReactorFn,
|
||||||
|
ThreadReactorInvokeArgs,
|
||||||
|
ToolCall,
|
||||||
|
ToolDefinition,
|
||||||
|
} from "./types.js";
|
||||||
@@ -0,0 +1,48 @@
|
|||||||
|
import type { LlmProvider } from "@uncaged/workflow-runtime";
|
||||||
|
|
||||||
|
import { err, ok } from "../util/index.js";
|
||||||
|
|
||||||
|
import type { ChatMessage, LlmFn, ToolDefinition } from "./types.js";
|
||||||
|
|
||||||
|
function chatCompletionsUrl(baseUrl: string): string {
|
||||||
|
const trimmed = baseUrl.replace(/\/+$/, "");
|
||||||
|
return `${trimmed}/chat/completions`;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Wraps provider credentials into an {@link LlmFn}: single POST to chat/completions,
|
||||||
|
* returns raw JSON body text or a {@link Result} error. Callers parse assistant messages.
|
||||||
|
*/
|
||||||
|
export function createLlmFn(provider: LlmProvider): LlmFn {
|
||||||
|
return async ({
|
||||||
|
messages,
|
||||||
|
tools,
|
||||||
|
}: {
|
||||||
|
messages: ChatMessage[];
|
||||||
|
tools: readonly ToolDefinition[];
|
||||||
|
}) => {
|
||||||
|
try {
|
||||||
|
const response = await fetch(chatCompletionsUrl(provider.baseUrl), {
|
||||||
|
method: "POST",
|
||||||
|
headers: {
|
||||||
|
Authorization: `Bearer ${provider.apiKey}`,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
model: provider.model,
|
||||||
|
messages,
|
||||||
|
tools,
|
||||||
|
tool_choice: "auto",
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
const responseText = await response.text();
|
||||||
|
if (!response.ok) {
|
||||||
|
return err(`http_error:${String(response.status)}:${responseText.slice(0, 4000)}`);
|
||||||
|
}
|
||||||
|
return ok(responseText);
|
||||||
|
} catch (cause) {
|
||||||
|
const message = cause instanceof Error ? cause.message : String(cause);
|
||||||
|
return err(`network_error:${message}`);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
@@ -0,0 +1,317 @@
|
|||||||
|
import type * as z from "zod/v4";
|
||||||
|
|
||||||
|
import { err, ok, type Result } from "../util/index.js";
|
||||||
|
|
||||||
|
import type {
|
||||||
|
ChatMessage,
|
||||||
|
StructuredToolSpec,
|
||||||
|
ThreadReactorConfig,
|
||||||
|
ThreadReactorFn,
|
||||||
|
ToolCall,
|
||||||
|
ToolDefinition,
|
||||||
|
} from "./types.js";
|
||||||
|
|
||||||
|
function isRecord(value: unknown): value is Record<string, unknown> {
|
||||||
|
return typeof value === "object" && value !== null && !Array.isArray(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
function tryParseJsonContent(content: string): unknown | null {
|
||||||
|
const trimmed = content.trim();
|
||||||
|
const fenceMatch = /^```(?:json)?\s*([\s\S]*?)```$/m.exec(trimmed);
|
||||||
|
const payload = fenceMatch !== null ? fenceMatch[1].trim() : trimmed;
|
||||||
|
try {
|
||||||
|
return JSON.parse(payload) as unknown;
|
||||||
|
} catch {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function firstAssistantMessage(responseText: string): Result<Record<string, unknown>, string> {
|
||||||
|
let parsed: unknown;
|
||||||
|
try {
|
||||||
|
parsed = JSON.parse(responseText) as unknown;
|
||||||
|
} catch (cause) {
|
||||||
|
const message = cause instanceof Error ? cause.message : String(cause);
|
||||||
|
return err(`invalid_response_json:${message}`);
|
||||||
|
}
|
||||||
|
if (!isRecord(parsed)) {
|
||||||
|
return err("invalid_response_top_level");
|
||||||
|
}
|
||||||
|
const choices = parsed.choices;
|
||||||
|
if (!Array.isArray(choices) || choices.length === 0) {
|
||||||
|
return err("no_choices_in_response");
|
||||||
|
}
|
||||||
|
const firstChoice = choices[0];
|
||||||
|
if (!isRecord(firstChoice)) {
|
||||||
|
return err("invalid_choice");
|
||||||
|
}
|
||||||
|
const messageObj = firstChoice.message;
|
||||||
|
if (!isRecord(messageObj)) {
|
||||||
|
return err("invalid_message");
|
||||||
|
}
|
||||||
|
return ok(messageObj);
|
||||||
|
}
|
||||||
|
|
||||||
|
function normalizeToolCalls(toolCallsRaw: unknown[]): Result<ToolCall[], string> {
|
||||||
|
const toolCalls: ToolCall[] = [];
|
||||||
|
for (const tc of toolCallsRaw) {
|
||||||
|
if (!isRecord(tc)) {
|
||||||
|
return err("invalid_tool_call");
|
||||||
|
}
|
||||||
|
const id = tc.id;
|
||||||
|
const tcType = tc.type;
|
||||||
|
const fn = tc.function;
|
||||||
|
if (typeof id !== "string" || tcType !== "function" || !isRecord(fn)) {
|
||||||
|
return err("invalid_tool_call_shape");
|
||||||
|
}
|
||||||
|
const name = fn.name;
|
||||||
|
const argumentsStr = fn.arguments;
|
||||||
|
if (typeof name !== "string" || typeof argumentsStr !== "string") {
|
||||||
|
return err("invalid_tool_call_function");
|
||||||
|
}
|
||||||
|
toolCalls.push({ id, type: "function", function: { name, arguments: argumentsStr } });
|
||||||
|
}
|
||||||
|
return ok(toolCalls);
|
||||||
|
}
|
||||||
|
|
||||||
|
type AssistantTurn<T> =
|
||||||
|
| { kind: "plain_json"; value: T }
|
||||||
|
| { kind: "tool_calls"; calls: ToolCall[]; assistantContent: string | null };
|
||||||
|
|
||||||
|
type AssistantTurnOrCorrection<T> =
|
||||||
|
| AssistantTurn<T>
|
||||||
|
| { kind: "plain_json_invalid"; rawContent: string; correction: string };
|
||||||
|
|
||||||
|
function classifyAssistantTurn<T>(
|
||||||
|
messageObj: Record<string, unknown>,
|
||||||
|
schema: z.ZodType<T>,
|
||||||
|
structuredToolName: string,
|
||||||
|
): Result<AssistantTurnOrCorrection<T>, string> {
|
||||||
|
const toolCallsRaw = messageObj.tool_calls;
|
||||||
|
if (!Array.isArray(toolCallsRaw) || toolCallsRaw.length === 0) {
|
||||||
|
const content = messageObj.content;
|
||||||
|
if (typeof content !== "string") {
|
||||||
|
return err("no_tool_calls_and_no_string_content");
|
||||||
|
}
|
||||||
|
const jsonParsed = tryParseJsonContent(content);
|
||||||
|
if (jsonParsed === null) {
|
||||||
|
return ok({
|
||||||
|
kind: "plain_json_invalid",
|
||||||
|
rawContent: content,
|
||||||
|
correction: `Your previous reply was not valid JSON and contained no tool calls. Reply with a single JSON object that matches the schema, or call the ${structuredToolName} tool with the structured arguments.`,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
const validated = schema.safeParse(jsonParsed);
|
||||||
|
if (!validated.success) {
|
||||||
|
return ok({
|
||||||
|
kind: "plain_json_invalid",
|
||||||
|
rawContent: content,
|
||||||
|
correction: `Your previous JSON reply did not satisfy the schema: ${validated.error.message}. Reply again with a JSON object that matches the schema, or call the ${structuredToolName} tool with the structured arguments.`,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
return ok({ kind: "plain_json", value: validated.data });
|
||||||
|
}
|
||||||
|
const callsResult = normalizeToolCalls(toolCallsRaw);
|
||||||
|
if (!callsResult.ok) {
|
||||||
|
return err(callsResult.error);
|
||||||
|
}
|
||||||
|
const assistantContent = messageObj.content;
|
||||||
|
return ok({
|
||||||
|
kind: "tool_calls",
|
||||||
|
calls: callsResult.value,
|
||||||
|
assistantContent: typeof assistantContent === "string" ? assistantContent : null,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
function toolNamesFromDefinitions(tools: readonly { function: { name: string } }[]): Set<string> {
|
||||||
|
return new Set(tools.map((t) => t.function.name));
|
||||||
|
}
|
||||||
|
|
||||||
|
function appendStructuredToolResult<T>(
|
||||||
|
tc: ToolCall,
|
||||||
|
schema: z.ZodType<T>,
|
||||||
|
messages: ChatMessage[],
|
||||||
|
): T | null {
|
||||||
|
let parsedArgs: unknown;
|
||||||
|
try {
|
||||||
|
parsedArgs = JSON.parse(tc.function.arguments) as unknown;
|
||||||
|
} catch {
|
||||||
|
messages.push({
|
||||||
|
role: "tool",
|
||||||
|
tool_call_id: tc.id,
|
||||||
|
content:
|
||||||
|
"Tool arguments were not valid JSON. Provide valid JSON object arguments matching the schema.",
|
||||||
|
});
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
const validated = schema.safeParse(parsedArgs);
|
||||||
|
if (!validated.success) {
|
||||||
|
messages.push({
|
||||||
|
role: "tool",
|
||||||
|
tool_call_id: tc.id,
|
||||||
|
content: `Schema validation failed: ${validated.error.message}. Fix the arguments and call the tool again with a JSON object that matches the schema.`,
|
||||||
|
});
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
messages.push({
|
||||||
|
role: "tool",
|
||||||
|
tool_call_id: tc.id,
|
||||||
|
content: '{"ok":true}',
|
||||||
|
});
|
||||||
|
return validated.data;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function dispatchToolCall<T, TThread>(
|
||||||
|
tc: ToolCall,
|
||||||
|
spec: StructuredToolSpec,
|
||||||
|
knownNames: Set<string>,
|
||||||
|
schema: z.ZodType<T>,
|
||||||
|
thread: TThread,
|
||||||
|
toolHandler: ThreadReactorConfig<TThread>["toolHandler"],
|
||||||
|
messages: ChatMessage[],
|
||||||
|
): Promise<T | null> {
|
||||||
|
if (!knownNames.has(tc.function.name)) {
|
||||||
|
messages.push({
|
||||||
|
role: "tool",
|
||||||
|
tool_call_id: tc.id,
|
||||||
|
content: `Unknown tool: ${tc.function.name}. Use one of the declared tools only.`,
|
||||||
|
});
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
if (tc.function.name === spec.name) {
|
||||||
|
return appendStructuredToolResult(tc, schema, messages);
|
||||||
|
}
|
||||||
|
let toolContent: string;
|
||||||
|
try {
|
||||||
|
toolContent = await toolHandler(tc, thread);
|
||||||
|
} catch (cause) {
|
||||||
|
const message = cause instanceof Error ? cause.message : String(cause);
|
||||||
|
toolContent = `Tool execution failed: ${message}`;
|
||||||
|
}
|
||||||
|
messages.push({
|
||||||
|
role: "tool",
|
||||||
|
tool_call_id: tc.id,
|
||||||
|
content: toolContent,
|
||||||
|
});
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function resolveToolCallRound<T, TThread>(
|
||||||
|
turn: Extract<AssistantTurn<T>, { kind: "tool_calls" }>,
|
||||||
|
spec: StructuredToolSpec,
|
||||||
|
knownNames: Set<string>,
|
||||||
|
schema: z.ZodType<T>,
|
||||||
|
thread: TThread,
|
||||||
|
toolHandler: ThreadReactorConfig<TThread>["toolHandler"],
|
||||||
|
messages: ChatMessage[],
|
||||||
|
): Promise<Result<T, string> | null> {
|
||||||
|
messages.push({
|
||||||
|
role: "assistant",
|
||||||
|
content: turn.assistantContent,
|
||||||
|
tool_calls: turn.calls,
|
||||||
|
});
|
||||||
|
let extractedRound: T | null = null;
|
||||||
|
for (const tc of turn.calls) {
|
||||||
|
const extracted = await dispatchToolCall(
|
||||||
|
tc,
|
||||||
|
spec,
|
||||||
|
knownNames,
|
||||||
|
schema,
|
||||||
|
thread,
|
||||||
|
toolHandler,
|
||||||
|
messages,
|
||||||
|
);
|
||||||
|
if (extracted !== null) {
|
||||||
|
extractedRound = extracted;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (extractedRound !== null) {
|
||||||
|
return ok(extractedRound);
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function runOneReactRound<T, TThread>(
|
||||||
|
config: ThreadReactorConfig<TThread>,
|
||||||
|
args: { thread: TThread; schema: z.ZodType<T> },
|
||||||
|
tools: readonly ToolDefinition[],
|
||||||
|
knownNames: Set<string>,
|
||||||
|
spec: StructuredToolSpec,
|
||||||
|
messages: ChatMessage[],
|
||||||
|
): Promise<Result<T, string> | null> {
|
||||||
|
const bodyResult = await config.llm({ messages, tools });
|
||||||
|
if (!bodyResult.ok) {
|
||||||
|
return bodyResult;
|
||||||
|
}
|
||||||
|
|
||||||
|
const msgResult = firstAssistantMessage(bodyResult.value);
|
||||||
|
if (!msgResult.ok) {
|
||||||
|
return msgResult;
|
||||||
|
}
|
||||||
|
|
||||||
|
const classified = classifyAssistantTurn(msgResult.value, args.schema, spec.name);
|
||||||
|
if (!classified.ok) {
|
||||||
|
return classified;
|
||||||
|
}
|
||||||
|
|
||||||
|
const turn = classified.value;
|
||||||
|
if (turn.kind === "plain_json") {
|
||||||
|
return ok(turn.value);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (turn.kind === "plain_json_invalid") {
|
||||||
|
messages.push({ role: "assistant", content: turn.rawContent });
|
||||||
|
messages.push({ role: "user", content: turn.correction });
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
return resolveToolCallRound(
|
||||||
|
turn,
|
||||||
|
spec,
|
||||||
|
knownNames,
|
||||||
|
args.schema,
|
||||||
|
args.thread,
|
||||||
|
config.toolHandler,
|
||||||
|
messages,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generic ReAct loop: LLM round-trips with tools until structured output validates,
|
||||||
|
* plain JSON matches schema, or {@link ThreadReactorConfig.maxRounds} is exceeded.
|
||||||
|
*/
|
||||||
|
export function createThreadReactor<TThread>(
|
||||||
|
config: ThreadReactorConfig<TThread>,
|
||||||
|
): ThreadReactorFn<TThread> {
|
||||||
|
return async <T>(args: {
|
||||||
|
thread: TThread;
|
||||||
|
input: string;
|
||||||
|
schema: z.ZodType<T>;
|
||||||
|
}): Promise<Result<T, string>> => {
|
||||||
|
const spec = config.structuredToolFromSchema(args.schema);
|
||||||
|
const tools = [...config.staticTools, spec.tool];
|
||||||
|
const knownNames = toolNamesFromDefinitions(tools);
|
||||||
|
const systemPrompt = config.systemPromptForStructuredTool(spec.name);
|
||||||
|
|
||||||
|
const messages: ChatMessage[] = [
|
||||||
|
{ role: "system", content: systemPrompt },
|
||||||
|
{ role: "user", content: args.input },
|
||||||
|
];
|
||||||
|
|
||||||
|
for (let round = 0; round < config.maxRounds; round++) {
|
||||||
|
const step = await runOneReactRound(
|
||||||
|
config,
|
||||||
|
{ thread: args.thread, schema: args.schema },
|
||||||
|
tools,
|
||||||
|
knownNames,
|
||||||
|
spec,
|
||||||
|
messages,
|
||||||
|
);
|
||||||
|
if (step !== null) {
|
||||||
|
return step;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return err("max_react_rounds_exceeded");
|
||||||
|
};
|
||||||
|
}
|
||||||
@@ -0,0 +1,62 @@
|
|||||||
|
import type * as z from "zod/v4";
|
||||||
|
|
||||||
|
import type { Result } from "../util/index.js";
|
||||||
|
|
||||||
|
export type ToolCall = {
|
||||||
|
id: string;
|
||||||
|
type: "function";
|
||||||
|
function: { name: string; arguments: string };
|
||||||
|
};
|
||||||
|
|
||||||
|
export type ToolDefinition = {
|
||||||
|
type: "function";
|
||||||
|
function: {
|
||||||
|
name: string;
|
||||||
|
description: string;
|
||||||
|
parameters: Record<string, unknown>;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
export type ChatMessage =
|
||||||
|
| { role: "system"; content: string }
|
||||||
|
| { role: "user"; content: string }
|
||||||
|
| {
|
||||||
|
role: "assistant";
|
||||||
|
content: string | null;
|
||||||
|
tool_calls: ToolCall[];
|
||||||
|
}
|
||||||
|
| { role: "assistant"; content: string }
|
||||||
|
| { role: "tool"; tool_call_id: string; content: string };
|
||||||
|
|
||||||
|
export type LlmFn = (input: {
|
||||||
|
messages: ChatMessage[];
|
||||||
|
tools: readonly ToolDefinition[];
|
||||||
|
}) => Promise<Result<string, string>>;
|
||||||
|
|
||||||
|
/** Structured tool derived from the per-invocation Zod schema (e.g. extract tool). */
|
||||||
|
export type StructuredToolSpec = {
|
||||||
|
name: string;
|
||||||
|
tool: ToolDefinition;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type ThreadReactorConfig<TThread> = {
|
||||||
|
llm: LlmFn;
|
||||||
|
/** Static tools (e.g. cas_get); structured tool is appended per invocation. */
|
||||||
|
staticTools: readonly ToolDefinition[];
|
||||||
|
/** Builds the schema-shaped tool and its OpenAI name for this invocation. */
|
||||||
|
structuredToolFromSchema: (schema: z.ZodType<unknown>) => StructuredToolSpec;
|
||||||
|
/** System prompt for this run; include the structured tool name for cache stability per schema. */
|
||||||
|
systemPromptForStructuredTool: (structuredToolName: string) => string;
|
||||||
|
toolHandler: (call: ToolCall, thread: TThread) => Promise<string>;
|
||||||
|
maxRounds: number;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type ThreadReactorInvokeArgs<TThread, T> = {
|
||||||
|
thread: TThread;
|
||||||
|
input: string;
|
||||||
|
schema: z.ZodType<T>;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type ThreadReactorFn<TThread> = <T>(
|
||||||
|
args: ThreadReactorInvokeArgs<TThread, T>,
|
||||||
|
) => Promise<Result<T, string>>;
|
||||||
Reference in New Issue
Block a user