From a7171f05f6826e090e55b3a35eb7c68edfe3b001 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B0=8F=E6=A9=98?= Date: Sat, 9 May 2026 02:15:38 +0000 Subject: [PATCH 1/2] feat(workflow): add ThreadReactor generic ReAct loop + migrate extract (Phase 1) - New src/reactor/ module: createThreadReactor, createLlmFn, types - Two-stage API: config (llm, systemPrompt, tools, toolHandler) + per-call (thread, input, schema) - All tool failures are recoverable (returned to LLM as error message) - Rewrite createExtract to use createThreadReactor - Delete reactExtract old implementation - Fix template test imports (START/END from runtime, validateWorkflowDescriptor from engine) 268 tests passing. Refs #139, relates #140 --- .../src/components/thread-detail.tsx | 4 +- .../__tests__/develop-template.test.ts | 9 +- .../__tests__/solve-issue-template.test.ts | 10 +- packages/workflow/README.md | 2 +- ...extract.test.ts => thread-reactor.test.ts} | 85 ++++- packages/workflow/src/extract/extract-fn.ts | 77 +++- packages/workflow/src/extract/index.ts | 9 +- .../workflow/src/extract/react-extract.ts | 343 ------------------ packages/workflow/src/extract/types.ts | 9 +- packages/workflow/src/index.ts | 15 +- packages/workflow/src/reactor/index.ts | 12 + packages/workflow/src/reactor/llm-fn.ts | 48 +++ .../workflow/src/reactor/thread-reactor.ts | 317 ++++++++++++++++ packages/workflow/src/reactor/types.ts | 62 ++++ 14 files changed, 604 insertions(+), 398 deletions(-) rename packages/workflow/__tests__/{react-extract.test.ts => thread-reactor.test.ts} (70%) delete mode 100644 packages/workflow/src/extract/react-extract.ts create mode 100644 packages/workflow/src/reactor/index.ts create mode 100644 packages/workflow/src/reactor/llm-fn.ts create mode 100644 packages/workflow/src/reactor/thread-reactor.ts create mode 100644 packages/workflow/src/reactor/types.ts diff --git a/packages/dashboard/src/components/thread-detail.tsx b/packages/dashboard/src/components/thread-detail.tsx index 27ddd0c..08d62e9 100644 --- a/packages/dashboard/src/components/thread-detail.tsx +++ b/packages/dashboard/src/components/thread-detail.tsx @@ -101,9 +101,9 @@ export function ThreadDetail({ threadId, onBack }: Props) { )} {(status === "ok" || liveActive || records.length > 0) && (
- {records.map((r, i) => ( + {records.map((r) => (
diff --git a/packages/workflow-template-develop/__tests__/develop-template.test.ts b/packages/workflow-template-develop/__tests__/develop-template.test.ts index 0b9ff36..f67adba 100644 --- a/packages/workflow-template-develop/__tests__/develop-template.test.ts +++ b/packages/workflow-template-develop/__tests__/develop-template.test.ts @@ -1,11 +1,6 @@ import { describe, expect, test } from "bun:test"; -import { - END, - type ModeratorContext, - type RoleStep, - START, - validateWorkflowDescriptor, -} from "@uncaged/workflow-runtime"; +import { END, type ModeratorContext, type RoleStep, START } from "@uncaged/workflow-runtime"; +import { validateWorkflowDescriptor } from "@uncaged/workflow"; import { buildDevelopDescriptor } from "../src/descriptor.js"; import { developModerator } from "../src/index.js"; import type { CommitterMeta, PlannerMeta } from "../src/roles/index.js"; diff --git a/packages/workflow-template-solve-issue/__tests__/solve-issue-template.test.ts b/packages/workflow-template-solve-issue/__tests__/solve-issue-template.test.ts index b10ff6f..6fb6e7f 100644 --- a/packages/workflow-template-solve-issue/__tests__/solve-issue-template.test.ts +++ b/packages/workflow-template-solve-issue/__tests__/solve-issue-template.test.ts @@ -2,14 +2,8 @@ import { afterEach, describe, expect, test } from "bun:test"; import { mkdtemp, rm } from "node:fs/promises"; import { tmpdir } from "node:os"; import { join } from "node:path"; -import { createCasStore, createExtract, createWorkflow } from "@uncaged/workflow"; -import { - END, - type ModeratorContext, - type RoleStep, - START, - validateWorkflowDescriptor, -} from "@uncaged/workflow-runtime"; +import { createCasStore, createExtract, createWorkflow, validateWorkflowDescriptor } from "@uncaged/workflow"; +import { END, type ModeratorContext, type RoleStep, START } from "@uncaged/workflow-runtime"; import { buildSolveIssueDescriptor } from "../src/descriptor.js"; import type { DeveloperMeta } from "../src/developer.js"; import { solveIssueModerator, solveIssueWorkflowDefinition } from "../src/index.js"; diff --git a/packages/workflow/README.md b/packages/workflow/README.md index fe84528..db319e3 100644 --- a/packages/workflow/README.md +++ b/packages/workflow/README.md @@ -29,7 +29,7 @@ import { createWorkflow, readWorkflowRegistry, executeThread } from "@uncaged/wo | **Registry** | `readWorkflowRegistry`, `writeWorkflowRegistry`, `registerWorkflowVersion`, `workflowRegistryPath`, YAML helpers | | **CAS** | `createCasStore`, Merkle helpers (`putStepMerkleNode`, `getContentMerklePayload`, …), `hashWorkflowBundleBytes` | | **Engine** | `createWorkflow`, `executeThread`, `parseThreadDataJsonl`, fork helpers, `garbageCollectCas` | -| **Extract / LLM tools** | `llmExtract`, `reactExtract`, `createExtract`, `getExtractProvider` | +| **Extract / LLM tools** | `llmExtract`, `createExtract`, `createThreadReactor`, `createLlmFn`, `getExtractProvider` | | **Agent bridge** | `workflowAsAgent` — expose a registered workflow as an agent-backed role | | **Utilities** | `createLogger`, ULID / Crockford Base32 codecs, `getDefaultWorkflowStorageRoot`, paths | diff --git a/packages/workflow/__tests__/react-extract.test.ts b/packages/workflow/__tests__/thread-reactor.test.ts similarity index 70% rename from packages/workflow/__tests__/react-extract.test.ts rename to packages/workflow/__tests__/thread-reactor.test.ts index ed3faa1..c880df1 100644 --- a/packages/workflow/__tests__/react-extract.test.ts +++ b/packages/workflow/__tests__/thread-reactor.test.ts @@ -6,7 +6,8 @@ import type { LlmProvider } from "@uncaged/workflow-runtime"; import * as z from "zod/v4"; import { createCasStore } from "../src/cas/cas.js"; import { createContentMerkleNode, serializeMerkleNode } from "../src/cas/merkle.js"; -import { reactExtract } from "../src/extract/react-extract.js"; +import { extractFunctionToolFromZodSchema } from "../src/extract/llm-extract.js"; +import { createLlmFn, createThreadReactor } from "../src/reactor/index.js"; const metaSchema = z.object({ seen: z.string() }); @@ -16,7 +17,57 @@ const provider: LlmProvider = { model: "test", }; -describe("reactExtract", () => { +const CAS_GET_TOOL_DEFINITION = { + type: "function" as const, + function: { + name: "cas_get", + description: "Read CAS node", + parameters: { + type: "object", + properties: { + hash: { type: "string", description: "hash" }, + }, + required: ["hash"], + }, + }, +}; + +type ThreadCtx = { cas: ReturnType }; + +function createTestReactor() { + const llm = createLlmFn(provider); + return createThreadReactor({ + llm, + maxRounds: 10, + staticTools: [CAS_GET_TOOL_DEFINITION], + structuredToolFromSchema: (schema) => { + const t = extractFunctionToolFromZodSchema(schema); + return { + name: t.name, + tool: { + type: "function" as const, + function: { + name: t.name, + description: t.description, + parameters: t.parameters, + }, + }, + }; + }, + systemPromptForStructuredTool: (structuredToolName) => + `Extract metadata. Use cas_get when needed. Call ${structuredToolName} with JSON args matching the schema, or reply with plain JSON.`, + toolHandler: async (call, thread) => { + if (call.function.name !== "cas_get") { + return `unexpected tool ${call.function.name}`; + } + const ta = JSON.parse(call.function.arguments) as { hash: string }; + const blob = await thread.cas.get(ta.hash); + return blob === null ? "null" : blob; + }, + }); +} + +describe("createThreadReactor (extract-shaped)", () => { let restoreFetch: (() => void) | null = null; afterEach(() => { @@ -25,7 +76,7 @@ describe("reactExtract", () => { }); test("cas_get rounds then extract tool yields validated meta", async () => { - const casDir = await mkdtemp(join(tmpdir(), "react-extract-")); + const casDir = await mkdtemp(join(tmpdir(), "thread-reactor-")); const cas = createCasStore(casDir); try { const blob = serializeMerkleNode(createContentMerkleNode("needle")); @@ -87,12 +138,12 @@ describe("reactExtract", () => { { preconnect: origFetch.preconnect.bind(origFetch) }, ) as typeof fetch; + const reactor = createTestReactor(); const text = `## Agent Output\n${h}\n## Extraction Instruction\nExtract seen from CAS.`; - const result = await reactExtract({ - text, + const result = await reactor({ + thread: { cas }, + input: text, schema: metaSchema, - provider, - cas, }); expect(result.ok).toBe(true); @@ -107,7 +158,7 @@ describe("reactExtract", () => { }); test("stops after max tool rounds when model keeps calling cas_get", async () => { - const casDir = await mkdtemp(join(tmpdir(), "react-extract-max-")); + const casDir = await mkdtemp(join(tmpdir(), "thread-reactor-max-")); const cas = createCasStore(casDir); try { const blob = serializeMerkleNode(createContentMerkleNode("x")); @@ -146,11 +197,11 @@ describe("reactExtract", () => { { preconnect: origFetch.preconnect.bind(origFetch) }, ) as typeof fetch; - const result = await reactExtract({ - text: "## Agent Output\nnoop\n## Extraction Instruction\nExtract seen.", + const reactor = createTestReactor(); + const result = await reactor({ + thread: { cas }, + input: "## Agent Output\nnoop\n## Extraction Instruction\nExtract seen.", schema: metaSchema, - provider, - cas, }); expect(result.ok).toBe(false); @@ -165,7 +216,7 @@ describe("reactExtract", () => { }); test("passthrough JSON assistant message without tool calls", async () => { - const casDir = await mkdtemp(join(tmpdir(), "react-extract-pass-")); + const casDir = await mkdtemp(join(tmpdir(), "thread-reactor-pass-")); const cas = createCasStore(casDir); try { const origFetch = globalThis.fetch; @@ -189,11 +240,11 @@ describe("reactExtract", () => { { preconnect: origFetch.preconnect.bind(origFetch) }, ) as typeof fetch; - const result = await reactExtract({ - text: "## Agent Output\nok\n## Extraction Instruction\nExtract.", + const reactor = createTestReactor(); + const result = await reactor({ + thread: { cas }, + input: "## Agent Output\nok\n## Extraction Instruction\nExtract.", schema: metaSchema, - provider, - cas, }); expect(result.ok).toBe(true); diff --git a/packages/workflow/src/extract/extract-fn.ts b/packages/workflow/src/extract/extract-fn.ts index 5f79380..b278562 100644 --- a/packages/workflow/src/extract/extract-fn.ts +++ b/packages/workflow/src/extract/extract-fn.ts @@ -1,12 +1,39 @@ import type { ExtractContext, ExtractFn, LlmProvider } from "@uncaged/workflow-runtime"; import type * as z from "zod/v4"; import { type CasStore, getContentMerklePayload } from "../cas/index.js"; -import { reactExtract } from "./react-extract.js"; +import { createLlmFn, createThreadReactor } from "../reactor/index.js"; +import { extractFunctionToolFromZodSchema } from "./llm-extract.js"; export type ExtractDeps = { cas: CasStore; }; +const MAX_REACT_ROUNDS = 10; + +const CAS_GET_TOOL_DEFINITION = { + type: "function" as const, + function: { + name: "cas_get", + description: + "Read a Merkle DAG node from content-addressed storage by its hash. Returns YAML-formatted node with type, payload, and children fields.", + parameters: { + type: "object", + properties: { + hash: { type: "string", description: "The CAS hash to retrieve" }, + }, + required: ["hash"], + }, + }, +}; + +export type ExtractThreadContext = { + cas: CasStore; +}; + +function isRecord(value: unknown): value is Record { + return typeof value === "object" && value !== null && !Array.isArray(value); +} + /** Builds the user-side extraction prompt (thread + agent output + instruction). */ export async function buildExtractUserContent( ctx: ExtractContext, @@ -46,17 +73,61 @@ export async function buildExtractUserContent( * Create an ExtractFn backed by an LLM provider. * * Internally runs a multi-turn ReAct loop with two tools (`cas_get` for traversing the - * Merkle DAG and a schema-shaped `extract` tool); the loop also accepts a plain-JSON + * Merkle DAG and a schema-shaped extract tool); the loop also accepts a plain-JSON * assistant reply as a short-circuit, which covers the legacy "single" extraction path. */ export function createExtract(provider: LlmProvider, deps: ExtractDeps): ExtractFn { + const llm = createLlmFn(provider); + const reactor = createThreadReactor({ + llm, + maxRounds: MAX_REACT_ROUNDS, + staticTools: [CAS_GET_TOOL_DEFINITION], + structuredToolFromSchema: (schema) => { + const t = extractFunctionToolFromZodSchema(schema); + return { + name: t.name, + tool: { + type: "function" as const, + function: { + name: t.name, + description: t.description, + parameters: t.parameters, + }, + }, + }; + }, + systemPromptForStructuredTool: (structuredToolName) => + `You extract structured metadata from the agent output below. Use cas_get to read Merkle DAG nodes from CAS (YAML: type, payload, children) when the agent output references hashes you must traverse. When you have the complete structured object, call the ${structuredToolName} tool with JSON arguments matching the schema. You may instead reply with only a JSON object (no prose) when no tools are needed.`, + toolHandler: async (call, thread) => { + if (call.function.name !== "cas_get") { + return `Unexpected tool routed to handler: ${call.function.name}`; + } + let hash: string; + try { + const ta = JSON.parse(call.function.arguments) as unknown; + if (!isRecord(ta) || typeof ta.hash !== "string") { + return 'cas_get requires a JSON object with a string "hash" field.'; + } + hash = ta.hash; + } catch { + return 'cas_get arguments were not valid JSON. Provide {"hash": ""}.'; + } + const blob = await thread.cas.get(hash); + return blob === null ? "null" : blob; + }, + }); + return async >( schema: z.ZodType, prompt: string, ctx: ExtractContext, ): Promise => { const text = await buildExtractUserContent(ctx, prompt, deps); - const result = await reactExtract({ text, schema, provider, cas: deps.cas }); + const result = await reactor({ + thread: { cas: deps.cas }, + input: text, + schema, + }); if (!result.ok) { throw new Error(`extract failed: ${result.error}`); } diff --git a/packages/workflow/src/extract/index.ts b/packages/workflow/src/extract/index.ts index cf502d8..e1069af 100644 --- a/packages/workflow/src/extract/index.ts +++ b/packages/workflow/src/extract/index.ts @@ -1,16 +1,11 @@ export { buildExtractUserContent, createExtract, + type ExtractThreadContext, } from "./extract-fn.js"; export { extractFunctionToolFromZodSchema, llmErrorToCause, llmExtract, } from "./llm-extract.js"; -export { reactExtract } from "./react-extract.js"; -export type { - ExtractFn, - LlmError, - LlmExtractArgs, - ReactExtractArgs, -} from "./types.js"; +export type { ExtractFn, LlmError, LlmExtractArgs } from "./types.js"; diff --git a/packages/workflow/src/extract/react-extract.ts b/packages/workflow/src/extract/react-extract.ts deleted file mode 100644 index 5c0c456..0000000 --- a/packages/workflow/src/extract/react-extract.ts +++ /dev/null @@ -1,343 +0,0 @@ -import type { CasStore, LlmProvider } from "@uncaged/workflow-runtime"; -import type * as z from "zod/v4"; -import { err, ok, type Result } from "../util/index.js"; - -import { extractFunctionToolFromZodSchema } from "./llm-extract.js"; -import type { ReactExtractArgs } from "./types.js"; - -const MAX_REACT_ROUNDS = 10; - -const CAS_GET_TOOL_DEFINITION = { - type: "function" as const, - function: { - name: "cas_get", - description: - "Read a Merkle DAG node from content-addressed storage by its hash. Returns YAML-formatted node with type, payload, and children fields.", - parameters: { - type: "object", - properties: { - hash: { type: "string", description: "The CAS hash to retrieve" }, - }, - required: ["hash"], - }, - }, -}; - -function chatCompletionsUrl(baseUrl: string): string { - const trimmed = baseUrl.replace(/\/+$/, ""); - return `${trimmed}/chat/completions`; -} - -function isRecord(value: unknown): value is Record { - return typeof value === "object" && value !== null && !Array.isArray(value); -} - -function tryParseJsonContent(content: string): unknown | null { - const trimmed = content.trim(); - const fenceMatch = /^```(?:json)?\s*([\s\S]*?)```$/m.exec(trimmed); - const payload = fenceMatch !== null ? fenceMatch[1].trim() : trimmed; - try { - return JSON.parse(payload) as unknown; - } catch { - return null; - } -} - -type ToolCall = { - id: string; - type: "function"; - function: { name: string; arguments: string }; -}; - -type ChatMessage = - | { role: "system"; content: string } - | { role: "user"; content: string } - | { - role: "assistant"; - content: string | null; - tool_calls: ToolCall[]; - } - | { role: "assistant"; content: string } - | { role: "tool"; tool_call_id: string; content: string }; - -type AssistantTurn = - | { kind: "plain_json"; value: T } - | { kind: "tool_calls"; calls: ToolCall[]; assistantContent: string | null }; - -function firstAssistantMessage(responseText: string): Result, string> { - let parsed: unknown; - try { - parsed = JSON.parse(responseText) as unknown; - } catch (cause) { - const message = cause instanceof Error ? cause.message : String(cause); - return err(`invalid_response_json:${message}`); - } - if (!isRecord(parsed)) { - return err("invalid_response_top_level"); - } - const choices = parsed.choices; - if (!Array.isArray(choices) || choices.length === 0) { - return err("no_choices_in_response"); - } - const firstChoice = choices[0]; - if (!isRecord(firstChoice)) { - return err("invalid_choice"); - } - const messageObj = firstChoice.message; - if (!isRecord(messageObj)) { - return err("invalid_message"); - } - return ok(messageObj); -} - -function normalizeToolCalls(toolCallsRaw: unknown[]): Result { - const toolCalls: ToolCall[] = []; - for (const tc of toolCallsRaw) { - if (!isRecord(tc)) { - return err("invalid_tool_call"); - } - const id = tc.id; - const tcType = tc.type; - const fn = tc.function; - if (typeof id !== "string" || tcType !== "function" || !isRecord(fn)) { - return err("invalid_tool_call_shape"); - } - const name = fn.name; - const argumentsStr = fn.arguments; - if (typeof name !== "string" || typeof argumentsStr !== "string") { - return err("invalid_tool_call_function"); - } - toolCalls.push({ id, type: "function", function: { name, arguments: argumentsStr } }); - } - return ok(toolCalls); -} - -type AssistantTurnOrCorrection> = - | AssistantTurn - | { kind: "plain_json_invalid"; rawContent: string; correction: string }; - -function classifyAssistantTurn>( - messageObj: Record, - schema: z.ZodType, -): Result, string> { - const toolCallsRaw = messageObj.tool_calls; - if (!Array.isArray(toolCallsRaw) || toolCallsRaw.length === 0) { - const content = messageObj.content; - if (typeof content !== "string") { - return err("no_tool_calls_and_no_string_content"); - } - const jsonParsed = tryParseJsonContent(content); - if (jsonParsed === null) { - return ok({ - kind: "plain_json_invalid", - rawContent: content, - correction: - "Your previous reply was not valid JSON and contained no tool calls. Reply with a single JSON object that matches the schema, or call the extract tool with the structured arguments.", - }); - } - const validated = schema.safeParse(jsonParsed); - if (!validated.success) { - return ok({ - kind: "plain_json_invalid", - rawContent: content, - correction: `Your previous JSON reply did not satisfy the schema: ${validated.error.message}. Reply again with a JSON object that matches the schema, or call the extract tool with the structured arguments.`, - }); - } - return ok({ kind: "plain_json", value: validated.data }); - } - const callsResult = normalizeToolCalls(toolCallsRaw); - if (!callsResult.ok) { - return err(callsResult.error); - } - const assistantContent = messageObj.content; - return ok({ - kind: "tool_calls", - calls: callsResult.value, - assistantContent: typeof assistantContent === "string" ? assistantContent : null, - }); -} - -async function appendCasGetToolResult( - tc: ToolCall, - cas: CasStore, - messages: ChatMessage[], -): Promise> { - let hash: string; - try { - const ta = JSON.parse(tc.function.arguments) as unknown; - if (!isRecord(ta) || typeof ta.hash !== "string") { - return err("cas_get_invalid_arguments"); - } - hash = ta.hash; - } catch { - return err("cas_get_arguments_not_json"); - } - const blob = await cas.get(hash); - const toolContent = blob === null ? "null" : blob; - messages.push({ - role: "tool", - tool_call_id: tc.id, - content: toolContent, - }); - return ok(null); -} - -async function appendExtractToolResult>( - tc: ToolCall, - schema: z.ZodType, - messages: ChatMessage[], -): Promise> { - let parsedArgs: unknown; - try { - parsedArgs = JSON.parse(tc.function.arguments) as unknown; - } catch { - return err("extract_tool_arguments_not_json"); - } - const validated = schema.safeParse(parsedArgs); - if (!validated.success) { - return err(`schema_validation_failed:${validated.error.message}`); - } - messages.push({ - role: "tool", - tool_call_id: tc.id, - content: '{"ok":true}', - }); - return ok(validated.data); -} - -async function appendToolResults>( - toolCalls: ToolCall[], - extractToolName: string, - schema: z.ZodType, - cas: CasStore, - messages: ChatMessage[], -): Promise> { - let extracted: T | null = null; - for (const tc of toolCalls) { - if (tc.function.name === "cas_get") { - const casRes = await appendCasGetToolResult(tc, cas, messages); - if (!casRes.ok) { - return casRes; - } - continue; - } - if (tc.function.name === extractToolName) { - const exRes = await appendExtractToolResult(tc, schema, messages); - if (!exRes.ok) { - return exRes; - } - extracted = exRes.value; - continue; - } - return err(`unknown_tool:${tc.function.name}`); - } - return ok(extracted); -} - -async function postChatCompletion( - provider: LlmProvider, - messages: ChatMessage[], - tools: readonly Record[], -): Promise> { - try { - const response = await fetch(chatCompletionsUrl(provider.baseUrl), { - method: "POST", - headers: { - Authorization: `Bearer ${provider.apiKey}`, - "Content-Type": "application/json", - }, - body: JSON.stringify({ - model: provider.model, - messages, - tools, - tool_choice: "auto", - }), - }); - const responseText = await response.text(); - if (!response.ok) { - return err(`http_error:${String(response.status)}:${responseText.slice(0, 4000)}`); - } - return ok(responseText); - } catch (cause) { - const message = cause instanceof Error ? cause.message : String(cause); - return err(`network_error:${message}`); - } -} - -/** - * Multi-turn ReAct extraction with `cas_get` plus a schema-shaped extract tool (OpenAI-compatible). - * Final meta comes from a successful extract tool call or from plain JSON in the assistant message. - */ -export async function reactExtract>( - args: ReactExtractArgs, -): Promise> { - const extractTool = extractFunctionToolFromZodSchema(args.schema); - const tools = [ - CAS_GET_TOOL_DEFINITION, - { - type: "function" as const, - function: { - name: extractTool.name, - description: extractTool.description, - parameters: extractTool.parameters, - }, - }, - ]; - - const systemContent = `You extract structured metadata from the agent output below. Use cas_get to read Merkle DAG nodes from CAS (YAML: type, payload, children) when the agent output references hashes you must traverse. When you have the complete structured object, call the ${extractTool.name} tool with JSON arguments matching the schema. You may instead reply with only a JSON object (no prose) when no tools are needed.`; - - const messages: ChatMessage[] = [ - { role: "system", content: systemContent }, - { role: "user", content: args.text }, - ]; - - for (let round = 0; round < MAX_REACT_ROUNDS; round++) { - const bodyResult = await postChatCompletion(args.provider, messages, tools); - if (!bodyResult.ok) { - return bodyResult; - } - - const msgResult = firstAssistantMessage(bodyResult.value); - if (!msgResult.ok) { - return msgResult; - } - - const classified = classifyAssistantTurn(msgResult.value, args.schema); - if (!classified.ok) { - return classified; - } - - const turn = classified.value; - if (turn.kind === "plain_json") { - return ok(turn.value); - } - - if (turn.kind === "plain_json_invalid") { - messages.push({ role: "assistant", content: turn.rawContent }); - messages.push({ role: "user", content: turn.correction }); - continue; - } - - messages.push({ - role: "assistant", - content: turn.assistantContent, - tool_calls: turn.calls, - }); - - const toolsRound = await appendToolResults( - turn.calls, - extractTool.name, - args.schema, - args.cas, - messages, - ); - if (!toolsRound.ok) { - return toolsRound; - } - if (toolsRound.value !== null) { - return ok(toolsRound.value); - } - } - - return err("max_react_rounds_exceeded"); -} diff --git a/packages/workflow/src/extract/types.ts b/packages/workflow/src/extract/types.ts index 016ab60..c5bf283 100644 --- a/packages/workflow/src/extract/types.ts +++ b/packages/workflow/src/extract/types.ts @@ -1,15 +1,8 @@ -import type { CasStore, LlmProvider } from "@uncaged/workflow-runtime"; +import type { LlmProvider } from "@uncaged/workflow-runtime"; import type * as z from "zod/v4"; export type { ExtractFn } from "@uncaged/workflow-runtime"; -export type ReactExtractArgs> = { - text: string; - schema: z.ZodType; - provider: LlmProvider; - cas: CasStore; -}; - export type LlmExtractArgs = { text: string; schema: z.ZodType; diff --git a/packages/workflow/src/index.ts b/packages/workflow/src/index.ts index 70897a1..896f865 100644 --- a/packages/workflow/src/index.ts +++ b/packages/workflow/src/index.ts @@ -56,12 +56,23 @@ export { export { createExtract, type ExtractFn, + type ExtractThreadContext, type LlmError, llmErrorToCause, llmExtract, - type ReactExtractArgs, - reactExtract, } from "./extract/index.js"; +export { + type ChatMessage, + createLlmFn, + createThreadReactor, + type LlmFn, + type StructuredToolSpec, + type ThreadReactorConfig, + type ThreadReactorFn, + type ThreadReactorInvokeArgs, + type ToolCall, + type ToolDefinition, +} from "./reactor/index.js"; export { getRegisteredWorkflow, listRegisteredWorkflowNames, diff --git a/packages/workflow/src/reactor/index.ts b/packages/workflow/src/reactor/index.ts new file mode 100644 index 0000000..80f2ab2 --- /dev/null +++ b/packages/workflow/src/reactor/index.ts @@ -0,0 +1,12 @@ +export { createLlmFn } from "./llm-fn.js"; +export { createThreadReactor } from "./thread-reactor.js"; +export type { + ChatMessage, + LlmFn, + StructuredToolSpec, + ThreadReactorConfig, + ThreadReactorFn, + ThreadReactorInvokeArgs, + ToolCall, + ToolDefinition, +} from "./types.js"; diff --git a/packages/workflow/src/reactor/llm-fn.ts b/packages/workflow/src/reactor/llm-fn.ts new file mode 100644 index 0000000..fd5c911 --- /dev/null +++ b/packages/workflow/src/reactor/llm-fn.ts @@ -0,0 +1,48 @@ +import type { LlmProvider } from "@uncaged/workflow-runtime"; + +import { err, ok } from "../util/index.js"; + +import type { ChatMessage, LlmFn, ToolDefinition } from "./types.js"; + +function chatCompletionsUrl(baseUrl: string): string { + const trimmed = baseUrl.replace(/\/+$/, ""); + return `${trimmed}/chat/completions`; +} + +/** + * Wraps provider credentials into an {@link LlmFn}: single POST to chat/completions, + * returns raw JSON body text or a {@link Result} error. Callers parse assistant messages. + */ +export function createLlmFn(provider: LlmProvider): LlmFn { + return async ({ + messages, + tools, + }: { + messages: ChatMessage[]; + tools: readonly ToolDefinition[]; + }) => { + try { + const response = await fetch(chatCompletionsUrl(provider.baseUrl), { + method: "POST", + headers: { + Authorization: `Bearer ${provider.apiKey}`, + "Content-Type": "application/json", + }, + body: JSON.stringify({ + model: provider.model, + messages, + tools, + tool_choice: "auto", + }), + }); + const responseText = await response.text(); + if (!response.ok) { + return err(`http_error:${String(response.status)}:${responseText.slice(0, 4000)}`); + } + return ok(responseText); + } catch (cause) { + const message = cause instanceof Error ? cause.message : String(cause); + return err(`network_error:${message}`); + } + }; +} diff --git a/packages/workflow/src/reactor/thread-reactor.ts b/packages/workflow/src/reactor/thread-reactor.ts new file mode 100644 index 0000000..4931bd4 --- /dev/null +++ b/packages/workflow/src/reactor/thread-reactor.ts @@ -0,0 +1,317 @@ +import type * as z from "zod/v4"; + +import { err, ok, type Result } from "../util/index.js"; + +import type { + ChatMessage, + StructuredToolSpec, + ThreadReactorConfig, + ThreadReactorFn, + ToolCall, + ToolDefinition, +} from "./types.js"; + +function isRecord(value: unknown): value is Record { + return typeof value === "object" && value !== null && !Array.isArray(value); +} + +function tryParseJsonContent(content: string): unknown | null { + const trimmed = content.trim(); + const fenceMatch = /^```(?:json)?\s*([\s\S]*?)```$/m.exec(trimmed); + const payload = fenceMatch !== null ? fenceMatch[1].trim() : trimmed; + try { + return JSON.parse(payload) as unknown; + } catch { + return null; + } +} + +function firstAssistantMessage(responseText: string): Result, string> { + let parsed: unknown; + try { + parsed = JSON.parse(responseText) as unknown; + } catch (cause) { + const message = cause instanceof Error ? cause.message : String(cause); + return err(`invalid_response_json:${message}`); + } + if (!isRecord(parsed)) { + return err("invalid_response_top_level"); + } + const choices = parsed.choices; + if (!Array.isArray(choices) || choices.length === 0) { + return err("no_choices_in_response"); + } + const firstChoice = choices[0]; + if (!isRecord(firstChoice)) { + return err("invalid_choice"); + } + const messageObj = firstChoice.message; + if (!isRecord(messageObj)) { + return err("invalid_message"); + } + return ok(messageObj); +} + +function normalizeToolCalls(toolCallsRaw: unknown[]): Result { + const toolCalls: ToolCall[] = []; + for (const tc of toolCallsRaw) { + if (!isRecord(tc)) { + return err("invalid_tool_call"); + } + const id = tc.id; + const tcType = tc.type; + const fn = tc.function; + if (typeof id !== "string" || tcType !== "function" || !isRecord(fn)) { + return err("invalid_tool_call_shape"); + } + const name = fn.name; + const argumentsStr = fn.arguments; + if (typeof name !== "string" || typeof argumentsStr !== "string") { + return err("invalid_tool_call_function"); + } + toolCalls.push({ id, type: "function", function: { name, arguments: argumentsStr } }); + } + return ok(toolCalls); +} + +type AssistantTurn = + | { kind: "plain_json"; value: T } + | { kind: "tool_calls"; calls: ToolCall[]; assistantContent: string | null }; + +type AssistantTurnOrCorrection = + | AssistantTurn + | { kind: "plain_json_invalid"; rawContent: string; correction: string }; + +function classifyAssistantTurn( + messageObj: Record, + schema: z.ZodType, + structuredToolName: string, +): Result, string> { + const toolCallsRaw = messageObj.tool_calls; + if (!Array.isArray(toolCallsRaw) || toolCallsRaw.length === 0) { + const content = messageObj.content; + if (typeof content !== "string") { + return err("no_tool_calls_and_no_string_content"); + } + const jsonParsed = tryParseJsonContent(content); + if (jsonParsed === null) { + return ok({ + kind: "plain_json_invalid", + rawContent: content, + correction: `Your previous reply was not valid JSON and contained no tool calls. Reply with a single JSON object that matches the schema, or call the ${structuredToolName} tool with the structured arguments.`, + }); + } + const validated = schema.safeParse(jsonParsed); + if (!validated.success) { + return ok({ + kind: "plain_json_invalid", + rawContent: content, + correction: `Your previous JSON reply did not satisfy the schema: ${validated.error.message}. Reply again with a JSON object that matches the schema, or call the ${structuredToolName} tool with the structured arguments.`, + }); + } + return ok({ kind: "plain_json", value: validated.data }); + } + const callsResult = normalizeToolCalls(toolCallsRaw); + if (!callsResult.ok) { + return err(callsResult.error); + } + const assistantContent = messageObj.content; + return ok({ + kind: "tool_calls", + calls: callsResult.value, + assistantContent: typeof assistantContent === "string" ? assistantContent : null, + }); +} + +function toolNamesFromDefinitions(tools: readonly { function: { name: string } }[]): Set { + return new Set(tools.map((t) => t.function.name)); +} + +function appendStructuredToolResult( + tc: ToolCall, + schema: z.ZodType, + messages: ChatMessage[], +): T | null { + let parsedArgs: unknown; + try { + parsedArgs = JSON.parse(tc.function.arguments) as unknown; + } catch { + messages.push({ + role: "tool", + tool_call_id: tc.id, + content: + "Tool arguments were not valid JSON. Provide valid JSON object arguments matching the schema.", + }); + return null; + } + const validated = schema.safeParse(parsedArgs); + if (!validated.success) { + messages.push({ + role: "tool", + tool_call_id: tc.id, + content: `Schema validation failed: ${validated.error.message}. Fix the arguments and call the tool again with a JSON object that matches the schema.`, + }); + return null; + } + messages.push({ + role: "tool", + tool_call_id: tc.id, + content: '{"ok":true}', + }); + return validated.data; +} + +async function dispatchToolCall( + tc: ToolCall, + spec: StructuredToolSpec, + knownNames: Set, + schema: z.ZodType, + thread: TThread, + toolHandler: ThreadReactorConfig["toolHandler"], + messages: ChatMessage[], +): Promise { + if (!knownNames.has(tc.function.name)) { + messages.push({ + role: "tool", + tool_call_id: tc.id, + content: `Unknown tool: ${tc.function.name}. Use one of the declared tools only.`, + }); + return null; + } + if (tc.function.name === spec.name) { + return appendStructuredToolResult(tc, schema, messages); + } + let toolContent: string; + try { + toolContent = await toolHandler(tc, thread); + } catch (cause) { + const message = cause instanceof Error ? cause.message : String(cause); + toolContent = `Tool execution failed: ${message}`; + } + messages.push({ + role: "tool", + tool_call_id: tc.id, + content: toolContent, + }); + return null; +} + +async function resolveToolCallRound( + turn: Extract, { kind: "tool_calls" }>, + spec: StructuredToolSpec, + knownNames: Set, + schema: z.ZodType, + thread: TThread, + toolHandler: ThreadReactorConfig["toolHandler"], + messages: ChatMessage[], +): Promise | null> { + messages.push({ + role: "assistant", + content: turn.assistantContent, + tool_calls: turn.calls, + }); + let extractedRound: T | null = null; + for (const tc of turn.calls) { + const extracted = await dispatchToolCall( + tc, + spec, + knownNames, + schema, + thread, + toolHandler, + messages, + ); + if (extracted !== null) { + extractedRound = extracted; + } + } + if (extractedRound !== null) { + return ok(extractedRound); + } + return null; +} + +async function runOneReactRound( + config: ThreadReactorConfig, + args: { thread: TThread; schema: z.ZodType }, + tools: readonly ToolDefinition[], + knownNames: Set, + spec: StructuredToolSpec, + messages: ChatMessage[], +): Promise | null> { + const bodyResult = await config.llm({ messages, tools }); + if (!bodyResult.ok) { + return bodyResult; + } + + const msgResult = firstAssistantMessage(bodyResult.value); + if (!msgResult.ok) { + return msgResult; + } + + const classified = classifyAssistantTurn(msgResult.value, args.schema, spec.name); + if (!classified.ok) { + return classified; + } + + const turn = classified.value; + if (turn.kind === "plain_json") { + return ok(turn.value); + } + + if (turn.kind === "plain_json_invalid") { + messages.push({ role: "assistant", content: turn.rawContent }); + messages.push({ role: "user", content: turn.correction }); + return null; + } + + return resolveToolCallRound( + turn, + spec, + knownNames, + args.schema, + args.thread, + config.toolHandler, + messages, + ); +} + +/** + * Generic ReAct loop: LLM round-trips with tools until structured output validates, + * plain JSON matches schema, or {@link ThreadReactorConfig.maxRounds} is exceeded. + */ +export function createThreadReactor( + config: ThreadReactorConfig, +): ThreadReactorFn { + return async (args: { + thread: TThread; + input: string; + schema: z.ZodType; + }): Promise> => { + const spec = config.structuredToolFromSchema(args.schema); + const tools = [...config.staticTools, spec.tool]; + const knownNames = toolNamesFromDefinitions(tools); + const systemPrompt = config.systemPromptForStructuredTool(spec.name); + + const messages: ChatMessage[] = [ + { role: "system", content: systemPrompt }, + { role: "user", content: args.input }, + ]; + + for (let round = 0; round < config.maxRounds; round++) { + const step = await runOneReactRound( + config, + { thread: args.thread, schema: args.schema }, + tools, + knownNames, + spec, + messages, + ); + if (step !== null) { + return step; + } + } + + return err("max_react_rounds_exceeded"); + }; +} diff --git a/packages/workflow/src/reactor/types.ts b/packages/workflow/src/reactor/types.ts new file mode 100644 index 0000000..5d9d499 --- /dev/null +++ b/packages/workflow/src/reactor/types.ts @@ -0,0 +1,62 @@ +import type * as z from "zod/v4"; + +import type { Result } from "../util/index.js"; + +export type ToolCall = { + id: string; + type: "function"; + function: { name: string; arguments: string }; +}; + +export type ToolDefinition = { + type: "function"; + function: { + name: string; + description: string; + parameters: Record; + }; +}; + +export type ChatMessage = + | { role: "system"; content: string } + | { role: "user"; content: string } + | { + role: "assistant"; + content: string | null; + tool_calls: ToolCall[]; + } + | { role: "assistant"; content: string } + | { role: "tool"; tool_call_id: string; content: string }; + +export type LlmFn = (input: { + messages: ChatMessage[]; + tools: readonly ToolDefinition[]; +}) => Promise>; + +/** Structured tool derived from the per-invocation Zod schema (e.g. extract tool). */ +export type StructuredToolSpec = { + name: string; + tool: ToolDefinition; +}; + +export type ThreadReactorConfig = { + llm: LlmFn; + /** Static tools (e.g. cas_get); structured tool is appended per invocation. */ + staticTools: readonly ToolDefinition[]; + /** Builds the schema-shaped tool and its OpenAI name for this invocation. */ + structuredToolFromSchema: (schema: z.ZodType) => StructuredToolSpec; + /** System prompt for this run; include the structured tool name for cache stability per schema. */ + systemPromptForStructuredTool: (structuredToolName: string) => string; + toolHandler: (call: ToolCall, thread: TThread) => Promise; + maxRounds: number; +}; + +export type ThreadReactorInvokeArgs = { + thread: TThread; + input: string; + schema: z.ZodType; +}; + +export type ThreadReactorFn = ( + args: ThreadReactorInvokeArgs, +) => Promise>; -- 2.43.0 From b8f9ffcb59fb997dc015cb611e7b8abf8beaeb37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B0=8F=E6=A9=98?= Date: Sat, 9 May 2026 02:26:39 +0000 Subject: [PATCH 2/2] feat(workflow): migrate supervisor to ThreadReactor (Phase 2) - Rewrite supervisor to use createThreadReactor + createLlmFn - No direct fetch/HTTP calls in supervisor - All 266 tests passing Refs #139, relates #141 --- .../__tests__/develop-template.test.ts | 2 +- .../__tests__/solve-issue-template.test.ts | 7 +- packages/workflow/__tests__/engine.test.ts | 18 +- .../workflow/__tests__/supervisor.test.ts | 132 ++++++++------ packages/workflow/src/engine/supervisor.ts | 163 ++++++------------ 5 files changed, 149 insertions(+), 173 deletions(-) diff --git a/packages/workflow-template-develop/__tests__/develop-template.test.ts b/packages/workflow-template-develop/__tests__/develop-template.test.ts index f67adba..02c37ae 100644 --- a/packages/workflow-template-develop/__tests__/develop-template.test.ts +++ b/packages/workflow-template-develop/__tests__/develop-template.test.ts @@ -1,6 +1,6 @@ import { describe, expect, test } from "bun:test"; -import { END, type ModeratorContext, type RoleStep, START } from "@uncaged/workflow-runtime"; import { validateWorkflowDescriptor } from "@uncaged/workflow"; +import { END, type ModeratorContext, type RoleStep, START } from "@uncaged/workflow-runtime"; import { buildDevelopDescriptor } from "../src/descriptor.js"; import { developModerator } from "../src/index.js"; import type { CommitterMeta, PlannerMeta } from "../src/roles/index.js"; diff --git a/packages/workflow-template-solve-issue/__tests__/solve-issue-template.test.ts b/packages/workflow-template-solve-issue/__tests__/solve-issue-template.test.ts index 6fb6e7f..463e7b2 100644 --- a/packages/workflow-template-solve-issue/__tests__/solve-issue-template.test.ts +++ b/packages/workflow-template-solve-issue/__tests__/solve-issue-template.test.ts @@ -2,7 +2,12 @@ import { afterEach, describe, expect, test } from "bun:test"; import { mkdtemp, rm } from "node:fs/promises"; import { tmpdir } from "node:os"; import { join } from "node:path"; -import { createCasStore, createExtract, createWorkflow, validateWorkflowDescriptor } from "@uncaged/workflow"; +import { + createCasStore, + createExtract, + createWorkflow, + validateWorkflowDescriptor, +} from "@uncaged/workflow"; import { END, type ModeratorContext, type RoleStep, START } from "@uncaged/workflow-runtime"; import { buildSolveIssueDescriptor } from "../src/descriptor.js"; import type { DeveloperMeta } from "../src/developer.js"; diff --git a/packages/workflow/__tests__/engine.test.ts b/packages/workflow/__tests__/engine.test.ts index 9dbaf6c..cb2d92f 100644 --- a/packages/workflow/__tests__/engine.test.ts +++ b/packages/workflow/__tests__/engine.test.ts @@ -101,10 +101,10 @@ async function writeRegistryYaml(storageRoot: string, yaml: string): Promise>; - supervisorContent: string; + supervisorDecision: "continue" | "stop"; onSupervisorCall?: () => void; }): () => void { const origFetch = globalThis.fetch; @@ -114,9 +114,9 @@ function installMockExtractThenSupervisor(params: { init?: RequestInit, ): Promise => { const body = init?.body ? (JSON.parse(String(init.body)) as Record) : {}; - const tools = body.tools; - const hasTools = Array.isArray(tools) && tools.length > 0; - if (hasTools) { + const model = typeof body.model === "string" ? body.model : ""; + const isSupervisor = model.startsWith("supervisor-"); + if (!isSupervisor) { const args = params.extractArgs[extractI] ?? params.extractArgs[params.extractArgs.length - 1]; if (args === undefined) { @@ -133,7 +133,9 @@ function installMockExtractThenSupervisor(params: { params.onSupervisorCall?.(); return new Response( JSON.stringify({ - choices: [{ message: { content: params.supervisorContent } }], + choices: [ + { message: { content: JSON.stringify({ decision: params.supervisorDecision }) } }, + ], }), { status: 200, headers: { "Content-Type": "application/json" } }, ); @@ -674,7 +676,7 @@ describe("executeThread", () => { test("supervisor stops thread when interval elapses and model returns stop", async () => { restoreFetch = installMockExtractThenSupervisor({ extractArgs: [{ plan: "do-it", files: ["a.ts"] }, { diff: "+ok" }], - supervisorContent: "stop", + supervisorDecision: "stop", }); const root = await mkdtemp(join(tmpdir(), "wf-engine-sup-stop-")); @@ -725,7 +727,7 @@ describe("executeThread", () => { let supervisorCalls = 0; restoreFetch = installMockExtractThenSupervisor({ extractArgs: [{ plan: "do-it", files: ["a.ts"] }, { diff: "+ok" }], - supervisorContent: "stop", + supervisorDecision: "stop", onSupervisorCall: () => { supervisorCalls += 1; }, diff --git a/packages/workflow/__tests__/supervisor.test.ts b/packages/workflow/__tests__/supervisor.test.ts index e2a9186..3591ae8 100644 --- a/packages/workflow/__tests__/supervisor.test.ts +++ b/packages/workflow/__tests__/supervisor.test.ts @@ -1,6 +1,6 @@ import { afterEach, describe, expect, test } from "bun:test"; -import { parseSupervisorDecisionText, runSupervisor } from "../src/engine/supervisor.js"; +import { runSupervisor } from "../src/engine/supervisor.js"; import type { WorkflowConfig } from "../src/registry/index.js"; import type { LogFn } from "../src/util/index.js"; @@ -20,28 +20,23 @@ function supervisorOnlyConfig(): WorkflowConfig { }; } -describe("parseSupervisorDecisionText", () => { - test("reads continue and stop case-insensitively", () => { - expect(parseSupervisorDecisionText("continue")).toBe("continue"); - expect(parseSupervisorDecisionText("CONTINUE")).toBe("continue"); - expect(parseSupervisorDecisionText("stop")).toBe("stop"); - expect(parseSupervisorDecisionText("STOP.")).toBe("stop"); +function jsonResponse(body: Record, status = 200): Response { + return new Response(JSON.stringify(body), { + status, + headers: { "Content-Type": "application/json" }, }); +} - test("finds token inside a sentence", () => { - expect(parseSupervisorDecisionText("Answer: continue")).toBe("continue"); - expect(parseSupervisorDecisionText("I recommend stop now")).toBe("stop"); - }); - - test("when both appear, earlier token wins", () => { - expect(parseSupervisorDecisionText("continue then stop")).toBe("continue"); - expect(parseSupervisorDecisionText("stop then continue")).toBe("stop"); - }); - - test("defaults to continue when unclear", () => { - expect(parseSupervisorDecisionText("maybe later")).toBe("continue"); - }); -}); +function installFetchMock(impl: (init?: RequestInit) => Promise): () => void { + const origFetch = globalThis.fetch; + globalThis.fetch = Object.assign( + async (_input: Parameters[0], init?: RequestInit) => impl(init), + { preconnect: origFetch.preconnect.bind(origFetch) }, + ) as typeof fetch; + return () => { + globalThis.fetch = origFetch; + }; +} describe("runSupervisor", () => { let restoreFetch: (() => void) | null = null; @@ -52,16 +47,9 @@ describe("runSupervisor", () => { }); test("returns continue when supervisor model cannot be resolved (no fetch)", async () => { - const origFetch = globalThis.fetch; - restoreFetch = () => { - globalThis.fetch = origFetch; - }; - globalThis.fetch = Object.assign( - async () => { - throw new Error("fetch should not run when supervisor is not configured"); - }, - { preconnect: origFetch.preconnect.bind(origFetch) }, - ) as typeof fetch; + restoreFetch = installFetchMock(async () => { + throw new Error("fetch should not run when supervisor is not configured"); + }); const config: WorkflowConfig = { maxDepth: 1, @@ -87,21 +75,27 @@ describe("runSupervisor", () => { expect(r.value).toBe("continue"); }); - test("returns stop from chat/completions assistant content", async () => { - const origFetch = globalThis.fetch; - restoreFetch = () => { - globalThis.fetch = origFetch; - }; - globalThis.fetch = Object.assign( - async () => - new Response( - JSON.stringify({ - choices: [{ message: { content: "stop" } }], - }), - { status: 200, headers: { "Content-Type": "application/json" } }, - ), - { preconnect: origFetch.preconnect.bind(origFetch) }, - ) as typeof fetch; + test("returns stop from structured tool call", async () => { + restoreFetch = installFetchMock(async () => + jsonResponse({ + choices: [ + { + message: { + tool_calls: [ + { + id: "t1", + type: "function", + function: { + name: "supervisor_decision", + arguments: JSON.stringify({ decision: "stop" }), + }, + }, + ], + }, + }, + ], + }), + ); const r = await runSupervisor({ config: supervisorOnlyConfig(), @@ -116,14 +110,44 @@ describe("runSupervisor", () => { expect(r.value).toBe("stop"); }); - test("returns err on invalid JSON body", async () => { - const origFetch = globalThis.fetch; - restoreFetch = () => { - globalThis.fetch = origFetch; - }; - globalThis.fetch = Object.assign(async () => new Response("not-json", { status: 200 }), { - preconnect: origFetch.preconnect.bind(origFetch), - }) as typeof fetch; + test("returns continue from plain JSON content (reactor short-circuit)", async () => { + restoreFetch = installFetchMock(async () => + jsonResponse({ + choices: [{ message: { content: '{"decision":"continue"}' } }], + }), + ); + + const r = await runSupervisor({ + config: supervisorOnlyConfig(), + prompt: "do Y", + recentSteps: [], + logger: noopLogger, + }); + expect(r.ok).toBe(true); + if (!r.ok) { + return; + } + expect(r.value).toBe("continue"); + }); + + test("returns err when reactor cannot validate the schema within max rounds", async () => { + restoreFetch = installFetchMock(async () => + jsonResponse({ + choices: [{ message: { content: "not-json" } }], + }), + ); + + const r = await runSupervisor({ + config: supervisorOnlyConfig(), + prompt: "p", + recentSteps: [], + logger: noopLogger, + }); + expect(r.ok).toBe(false); + }); + + test("returns err on HTTP failure", async () => { + restoreFetch = installFetchMock(async () => new Response("boom", { status: 500 })); const r = await runSupervisor({ config: supervisorOnlyConfig(), diff --git a/packages/workflow/src/engine/supervisor.ts b/packages/workflow/src/engine/supervisor.ts index 9b151fa..cef8777 100644 --- a/packages/workflow/src/engine/supervisor.ts +++ b/packages/workflow/src/engine/supervisor.ts @@ -1,67 +1,27 @@ +import * as z from "zod/v4"; + import { resolveModel } from "../config/index.js"; +import { extractFunctionToolFromZodSchema } from "../extract/index.js"; +import { createLlmFn, createThreadReactor } from "../reactor/index.js"; import type { WorkflowConfig } from "../registry/index.js"; import { err, type LogFn, ok, type Result } from "../util/index.js"; import type { SupervisorDecision } from "./types.js"; const SUPERVISOR_RECENT_STEP_LIMIT = 12; +const SUPERVISOR_MAX_REACT_ROUNDS = 4; -function chatCompletionsUrl(baseUrl: string): string { - const trimmed = baseUrl.replace(/\/+$/, ""); - return `${trimmed}/chat/completions`; -} +const supervisorDecisionSchema = z + .object({ + decision: z.enum(["continue", "stop"]), + }) + .meta({ + title: "supervisor_decision", + description: + 'Workflow supervisor decision. "continue" when the thread is making progress; "stop" when done, looping, or stuck.', + }); -function isRecord(value: unknown): value is Record { - return typeof value === "object" && value !== null && !Array.isArray(value); -} - -function readAssistantContent(parsed: unknown): string | null { - if (!isRecord(parsed)) { - return null; - } - const choices = parsed.choices; - if (!Array.isArray(choices) || choices.length === 0) { - return null; - } - const first = choices[0]; - if (!isRecord(first)) { - return null; - } - const messageObj = first.message; - if (!isRecord(messageObj)) { - return null; - } - const content = messageObj.content; - if (typeof content !== "string") { - return null; - } - return content; -} - -/** Lenient: accepts STOP/stop/stop. as prose; prefers {@link SupervisorDecision.stop} when both tokens appear. */ -export function parseSupervisorDecisionText(text: string): SupervisorDecision { - const lower = text.toLowerCase(); - const stopWord = /\bstop\b/.test(lower); - const continueWord = /\bcontinue\b/.test(lower); - if (stopWord && continueWord) { - const si = lower.search(/\bstop\b/); - const ci = lower.search(/\bcontinue\b/); - return si <= ci ? "stop" : "continue"; - } - if (stopWord) { - return "stop"; - } - if (continueWord) { - return "continue"; - } - if (lower.includes("stop")) { - return "stop"; - } - if (lower.includes("continue")) { - return "continue"; - } - return "continue"; -} +type SupervisorThreadContext = Record; type RunSupervisorArgs = { config: WorkflowConfig; @@ -70,7 +30,13 @@ type RunSupervisorArgs = { logger: LogFn; }; -/** Calls the `supervisor` scene LLM; opt-out when {@link resolveModel} fails (returns ok(`continue`)). */ +function buildSupervisorInput(args: RunSupervisorArgs): string { + const recent = args.recentSteps.slice(-SUPERVISOR_RECENT_STEP_LIMIT); + const stepsBlock = recent.map((s, index) => `${index + 1}. [${s.role}] ${s.summary}`).join("\n"); + return `Original task:\n${args.prompt}\n\nRecent steps (oldest first):\n${stepsBlock === "" ? "(none)" : stepsBlock}`; +} + +/** Calls the `supervisor` scene via {@link createThreadReactor}; opt-out when {@link resolveModel} fails (returns ok(`continue`)). */ export async function runSupervisor( args: RunSupervisorArgs, ): Promise> { @@ -78,63 +44,42 @@ export async function runSupervisor( if (!resolved.ok) { return ok("continue"); } - const provider = resolved.value; - const recent = args.recentSteps.slice(-SUPERVISOR_RECENT_STEP_LIMIT); - const stepsBlock = recent.map((s, index) => `${index + 1}. [${s.role}] ${s.summary}`).join("\n"); - const body = { - model: provider.model, - messages: [ - { - role: "system" as const, - content: - 'You supervise a multi-step workflow. Decide if the thread should keep running or halt.\n\nReply with exactly one token: either "continue" (progress toward the goal, not obviously stuck) or "stop" (done, looping, or no progress). Do not add explanation.', - }, - { - role: "user" as const, - content: `Original task:\n${args.prompt}\n\nRecent steps (oldest first):\n${stepsBlock === "" ? "(none)" : stepsBlock}`, - }, - ], - }; + const reactor = createThreadReactor({ + llm: createLlmFn(resolved.value), + maxRounds: SUPERVISOR_MAX_REACT_ROUNDS, + staticTools: [], + structuredToolFromSchema: (schema) => { + const t = extractFunctionToolFromZodSchema(schema); + return { + name: t.name, + tool: { + type: "function" as const, + function: { + name: t.name, + description: t.description, + parameters: t.parameters, + }, + }, + }; + }, + systemPromptForStructuredTool: (structuredToolName) => + `You supervise a multi-step workflow. Decide whether the thread should keep running or halt. Reply with "continue" when the thread is making progress toward the task, or "stop" when it is finished, looping, or no longer making progress. Call the ${structuredToolName} tool with JSON arguments matching the schema, or reply with only a JSON object such as {"decision":"stop"}.`, + toolHandler: async (call) => `Unknown tool: ${call.function.name}`, + }); - let response: Response; - try { - response = await fetch(chatCompletionsUrl(provider.baseUrl), { - method: "POST", - headers: { - Authorization: `Bearer ${provider.apiKey}`, - "Content-Type": "application/json", - }, - body: JSON.stringify(body), - }); - } catch (cause) { - const message = cause instanceof Error ? cause.message : String(cause); - args.logger("R9CW4PLM", `supervisor request failed: ${message}`); - return err(`supervisor network error: ${message}`); + const result = await reactor({ + thread: {} as SupervisorThreadContext, + input: buildSupervisorInput(args), + schema: supervisorDecisionSchema, + }); + + if (!result.ok) { + args.logger("R9CW4PLM", `supervisor failed: ${result.error}`); + return err(`supervisor: ${result.error}`); } - const responseText = await response.text(); - if (!response.ok) { - args.logger("T3HN8VKQ", `supervisor HTTP ${response.status}: ${responseText.slice(0, 200)}`); - return err(`supervisor HTTP ${response.status}: ${responseText.slice(0, 500)}`); - } - - let parsed: unknown; - try { - parsed = JSON.parse(responseText) as unknown; - } catch (cause) { - const message = cause instanceof Error ? cause.message : String(cause); - args.logger("W7BQ2NXM", `supervisor response is not JSON: ${message}`); - return err(`supervisor invalid JSON: ${message}`); - } - - const content = readAssistantContent(parsed); - if (content === null || content.trim() === "") { - args.logger("Y4JX9PKW", "supervisor returned empty assistant content"); - return err("supervisor empty assistant content"); - } - - const decision = parseSupervisorDecisionText(content); + const decision: SupervisorDecision = result.value.decision; args.logger("Z8KM5QWT", `supervisor says ${decision}`); return ok(decision); } -- 2.43.0