From 014c442ed23f27bad0765e24044088943064b10b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B0=8F=E6=A9=98?= Date: Fri, 8 May 2026 02:38:54 +0000 Subject: [PATCH] =?UTF-8?q?feat(engine):=20add=20supervisor=20scene=20?= =?UTF-8?q?=E2=80=94=20opt-in=20LLM-based=20thread=20stop=20(Phase=203)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Supervisor replaces maxRounds as primary stop mechanism. Every N rounds (configurable via supervisorInterval, default 3), the engine calls a cheap LLM to evaluate thread progress and decide continue/stop. - New engine/supervisor.ts: runSupervisor + parseSupervisorDecisionText - Supervisor is opt-in: no models.supervisor configured = always continue - WorkflowConfig gains supervisorInterval (default 3, 0 to disable) - Engine calls supervisor after each supervisorInterval rounds - 256 tests pass, 14 new tests for supervisor logic Refs #110 --- packages/workflow/__tests__/engine.test.ts | 190 ++++++++++++++++++ packages/workflow/__tests__/registry.test.ts | 59 ++++++ .../workflow/__tests__/resolve-model.test.ts | 4 + .../workflow/__tests__/supervisor.test.ts | 136 +++++++++++++ .../__tests__/workflow-as-agent.test.ts | 1 + packages/workflow/src/engine/engine.ts | 123 ++++++++++-- packages/workflow/src/engine/index.ts | 1 + packages/workflow/src/engine/supervisor.ts | 140 +++++++++++++ packages/workflow/src/engine/types.ts | 2 + packages/workflow/src/index.ts | 1 + .../src/registry/registry-normalize.ts | 13 ++ packages/workflow/src/registry/types.ts | 2 + 12 files changed, 655 insertions(+), 17 deletions(-) create mode 100644 packages/workflow/__tests__/supervisor.test.ts create mode 100644 packages/workflow/src/engine/supervisor.ts diff --git a/packages/workflow/__tests__/engine.test.ts b/packages/workflow/__tests__/engine.test.ts index 78f04b1..1e668ed 100644 --- a/packages/workflow/__tests__/engine.test.ts +++ b/packages/workflow/__tests__/engine.test.ts @@ -96,6 +96,98 @@ async function 
writeExtractRegistryConfig(storageRoot: string): Promise { await writeFile(join(storageRoot, "workflow.yaml"), EXTRACT_REGISTRY_YAML, "utf8"); } +const SUPERVISOR_INTERVAL_REGISTRY_YAML = `config: + maxDepth: 3 + supervisorInterval: 2 + providers: + stub: + baseUrl: http://127.0.0.1:9 + apiKey: test + models: + extract: stub/model + supervisor: stub/supervisor-cheap +workflows: {} +`; + +const SUPERVISOR_LONG_INTERVAL_REGISTRY_YAML = `config: + maxDepth: 3 + supervisorInterval: 10 + providers: + stub: + baseUrl: http://127.0.0.1:9 + apiKey: test + models: + extract: stub/model + supervisor: stub/supervisor-cheap +workflows: {} +`; + +async function writeRegistryYaml(storageRoot: string, yaml: string): Promise { + await writeFile(join(storageRoot, "workflow.yaml"), yaml, "utf8"); +} + +/** Extract rounds use tool_calls; supervisor uses plain `content` (no tools). */ +function installMockExtractThenSupervisor(params: { + extractArgs: ReadonlyArray>; + supervisorContent: string; + onSupervisorCall?: () => void; +}): () => void { + const origFetch = globalThis.fetch; + let extractI = 0; + const mockFetch = async ( + _input: Parameters[0], + init?: RequestInit, + ): Promise => { + const body = init?.body ? (JSON.parse(String(init.body)) as Record) : {}; + const tools = body.tools; + const hasTools = Array.isArray(tools) && tools.length > 0; + if (hasTools) { + const args = + params.extractArgs[extractI] ?? params.extractArgs[params.extractArgs.length - 1]; + if (args === undefined) { + throw new Error("installMockExtractThenSupervisor: empty extractArgs"); + } + extractI += 1; + const firstTool = tools[0] as Record; + const fn = firstTool.function as Record | undefined; + const toolName = typeof fn?.name === "string" ? 
fn.name : "extract"; + return new Response( + JSON.stringify({ + choices: [ + { + message: { + tool_calls: [ + { + type: "function", + function: { + name: toolName, + arguments: JSON.stringify(args), + }, + }, + ], + }, + }, + ], + }), + { status: 200, headers: { "Content-Type": "application/json" } }, + ); + } + params.onSupervisorCall?.(); + return new Response( + JSON.stringify({ + choices: [{ message: { content: params.supervisorContent } }], + }), + { status: 200, headers: { "Content-Type": "application/json" } }, + ); + }; + globalThis.fetch = Object.assign(mockFetch, { + preconnect: origFetch.preconnect.bind(origFetch), + }) as typeof fetch; + return () => { + globalThis.fetch = origFetch; + }; +} + const demoWorkflow = createWorkflow( { roles: { @@ -623,4 +715,102 @@ describe("executeThread", () => { await rm(root, { recursive: true, force: true }); } }); + + test("supervisor stops thread when interval elapses and model returns stop", async () => { + restoreFetch = installMockExtractThenSupervisor({ + extractArgs: [{ plan: "do-it", files: ["a.ts"] }, { diff: "+ok" }], + supervisorContent: "stop", + }); + + const root = await mkdtemp(join(tmpdir(), "wf-engine-sup-stop-")); + try { + const threadId = "01KQXKW18CT8G75T53R8F4G7YG"; + const hash = "C9NMV6V2TQT81"; + const dataPath = join(root, "logs", hash, `${threadId}.data.jsonl`); + const infoPath = join(root, "logs", hash, `${threadId}.info.jsonl`); + await mkdir(join(root, "logs", hash), { recursive: true }); + await writeRegistryYaml(root, SUPERVISOR_INTERVAL_REGISTRY_YAML); + const cas = createCasStore(join(root, "cas")); + + const logger = createLogger({ sink: { kind: "file", path: infoPath } }); + const ac = new AbortController(); + + const result = await executeThread( + demoWorkflow, + "demo-flow", + { prompt: "supervisor-stop-case", steps: [] }, + { + maxRounds: 20, + depth: 0, + signal: ac.signal, + awaitAfterEachYield: async () => {}, + forkSourceThreadId: null, + prefilledDiskSteps: null, + 
storageRoot: root, + }, + { threadId, hash, dataJsonlPath: dataPath, infoJsonlPath: infoPath, cas }, + logger, + ); + + expect(result.returnCode).toBe(0); + expect(result.summary).toBe("completed: supervisor stopped thread"); + + const dataText = await readFile(dataPath, "utf8"); + const lines = dataText + .trim() + .split("\n") + .filter((l) => l !== ""); + expect(lines.length).toBe(3); + } finally { + await rm(root, { recursive: true, force: true }); + } + }); + + test("supervisor is not invoked before supervisorInterval rounds", async () => { + let supervisorCalls = 0; + restoreFetch = installMockExtractThenSupervisor({ + extractArgs: [{ plan: "do-it", files: ["a.ts"] }, { diff: "+ok" }], + supervisorContent: "stop", + onSupervisorCall: () => { + supervisorCalls += 1; + }, + }); + + const root = await mkdtemp(join(tmpdir(), "wf-engine-sup-skip-")); + try { + const threadId = "01KQXKW18CT8G75T53R8F4G7YG"; + const hash = "C9NMV6V2TQT81"; + const dataPath = join(root, "logs", hash, `${threadId}.data.jsonl`); + const infoPath = join(root, "logs", hash, `${threadId}.info.jsonl`); + await mkdir(join(root, "logs", hash), { recursive: true }); + await writeRegistryYaml(root, SUPERVISOR_LONG_INTERVAL_REGISTRY_YAML); + const cas = createCasStore(join(root, "cas")); + + const logger = createLogger({ sink: { kind: "file", path: infoPath } }); + const ac = new AbortController(); + + const result = await executeThread( + demoWorkflow, + "demo-flow", + { prompt: "no-supervisor-yet", steps: [] }, + { + maxRounds: 20, + depth: 0, + signal: ac.signal, + awaitAfterEachYield: async () => {}, + forkSourceThreadId: null, + prefilledDiskSteps: null, + storageRoot: root, + }, + { threadId, hash, dataJsonlPath: dataPath, infoJsonlPath: infoPath, cas }, + logger, + ); + + expect(supervisorCalls).toBe(0); + expect(result.returnCode).toBe(0); + expect(result.summary).toBe("completed: moderator returned END"); + } finally { + await rm(root, { recursive: true, force: true }); + } + }); }); 
diff --git a/packages/workflow/__tests__/registry.test.ts b/packages/workflow/__tests__/registry.test.ts index d9b568b..20a7642 100644 --- a/packages/workflow/__tests__/registry.test.ts +++ b/packages/workflow/__tests__/registry.test.ts @@ -132,6 +132,65 @@ workflows: expect(r.value.config.providers.dashscope?.apiKey).toBe("secret-key"); expect(r.value.config.models.extract).toBe("dashscope/qwen-plus"); expect(r.value.config.models.default).toBe("dashscope/qwen-turbo"); + expect(r.value.config.supervisorInterval).toBe(3); + }); + + test("defaults supervisorInterval to 3 when omitted", () => { + const yaml = ` +config: + maxDepth: 0 + providers: + p: + baseUrl: https://example.com + apiKey: k + models: + default: p/m +workflows: {} +`; + const r = parseWorkflowRegistryYaml(yaml); + expect(r.ok).toBe(true); + if (!r.ok || r.value.config === null) { + return; + } + expect(r.value.config.supervisorInterval).toBe(3); + }); + + test("parses explicit supervisorInterval", () => { + const yaml = ` +config: + maxDepth: 0 + supervisorInterval: 7 + providers: + p: + baseUrl: https://example.com + apiKey: k + models: + default: p/m +workflows: {} +`; + const r = parseWorkflowRegistryYaml(yaml); + expect(r.ok).toBe(true); + if (!r.ok || r.value.config === null) { + return; + } + expect(r.value.config.supervisorInterval).toBe(7); + }); + + test("parse errors when supervisorInterval is negative", () => { + const yaml = ` +config: + maxDepth: 0 + supervisorInterval: -1 + providers: + p: + baseUrl: https://example.com + apiKey: k + models: + default: p/m +workflows: {} +`; + const r = parseWorkflowRegistryYaml(yaml); + expect(r.ok).toBe(false); }); test("parses config apiKey env: prefix from process.env", () => { diff --git a/packages/workflow/__tests__/resolve-model.test.ts b/packages/workflow/__tests__/resolve-model.test.ts index f369e49..7612cc0 100644 --- a/packages/workflow/__tests__/resolve-model.test.ts +++ b/packages/workflow/__tests__/resolve-model.test.ts @@ -6,6 +6,7 @@ 
import type { WorkflowConfig } from "../src/registry/index.js"; function sampleConfig(): WorkflowConfig { return { maxDepth: 3, + supervisorInterval: 3, providers: { dashscope: { baseUrl: "https://dashscope.aliyuncs.com/compatible-mode/v1", @@ -50,6 +51,7 @@ describe("resolveModel", () => { test("errs when scene missing and no default", () => { const config: WorkflowConfig = { maxDepth: 1, + supervisorInterval: 3, providers: { p: { baseUrl: "https://x", apiKey: "k" }, }, @@ -69,6 +71,7 @@ describe("resolveModel", () => { test("errs when provider is unknown", () => { const config: WorkflowConfig = { maxDepth: 1, + supervisorInterval: 3, providers: { p: { baseUrl: "https://x", apiKey: "k" }, }, @@ -87,6 +90,7 @@ describe("resolveModel", () => { test("errs on invalid model reference shape", () => { const config: WorkflowConfig = { maxDepth: 1, + supervisorInterval: 3, providers: { p: { baseUrl: "https://x", apiKey: "k" }, }, diff --git a/packages/workflow/__tests__/supervisor.test.ts b/packages/workflow/__tests__/supervisor.test.ts new file mode 100644 index 0000000..e2a9186 --- /dev/null +++ b/packages/workflow/__tests__/supervisor.test.ts @@ -0,0 +1,136 @@ +import { afterEach, describe, expect, test } from "bun:test"; + +import { parseSupervisorDecisionText, runSupervisor } from "../src/engine/supervisor.js"; +import type { WorkflowConfig } from "../src/registry/index.js"; +import type { LogFn } from "../src/util/index.js"; + +const noopLogger: LogFn = () => {}; + +function supervisorOnlyConfig(): WorkflowConfig { + return { + maxDepth: 3, + supervisorInterval: 3, + providers: { + stub: { baseUrl: "http://127.0.0.1:9/v1", apiKey: "k" }, + }, + models: { + extract: "stub/extract-model", + supervisor: "stub/supervisor-model", + }, + }; +} + +describe("parseSupervisorDecisionText", () => { + test("reads continue and stop case-insensitively", () => { + expect(parseSupervisorDecisionText("continue")).toBe("continue"); + 
expect(parseSupervisorDecisionText("CONTINUE")).toBe("continue"); + expect(parseSupervisorDecisionText("stop")).toBe("stop"); + expect(parseSupervisorDecisionText("STOP.")).toBe("stop"); + }); + + test("finds token inside a sentence", () => { + expect(parseSupervisorDecisionText("Answer: continue")).toBe("continue"); + expect(parseSupervisorDecisionText("I recommend stop now")).toBe("stop"); + }); + + test("when both appear, earlier token wins", () => { + expect(parseSupervisorDecisionText("continue then stop")).toBe("continue"); + expect(parseSupervisorDecisionText("stop then continue")).toBe("stop"); + }); + + test("defaults to continue when unclear", () => { + expect(parseSupervisorDecisionText("maybe later")).toBe("continue"); + }); +}); + +describe("runSupervisor", () => { + let restoreFetch: (() => void) | null = null; + + afterEach(() => { + restoreFetch?.(); + restoreFetch = null; + }); + + test("returns continue when supervisor model cannot be resolved (no fetch)", async () => { + const origFetch = globalThis.fetch; + restoreFetch = () => { + globalThis.fetch = origFetch; + }; + globalThis.fetch = Object.assign( + async () => { + throw new Error("fetch should not run when supervisor is not configured"); + }, + { preconnect: origFetch.preconnect.bind(origFetch) }, + ) as typeof fetch; + + const config: WorkflowConfig = { + maxDepth: 1, + supervisorInterval: 3, + providers: { + stub: { baseUrl: "http://127.0.0.1:9/v1", apiKey: "k" }, + }, + models: { + extract: "stub/m", + }, + }; + + const r = await runSupervisor({ + config, + prompt: "task", + recentSteps: [{ role: "planner", summary: "{}" }], + logger: noopLogger, + }); + expect(r.ok).toBe(true); + if (!r.ok) { + return; + } + expect(r.value).toBe("continue"); + }); + + test("returns stop from chat/completions assistant content", async () => { + const origFetch = globalThis.fetch; + restoreFetch = () => { + globalThis.fetch = origFetch; + }; + globalThis.fetch = Object.assign( + async () => + new 
Response( + JSON.stringify({ + choices: [{ message: { content: "stop" } }], + }), + { status: 200, headers: { "Content-Type": "application/json" } }, + ), + { preconnect: origFetch.preconnect.bind(origFetch) }, + ) as typeof fetch; + + const r = await runSupervisor({ + config: supervisorOnlyConfig(), + prompt: "do X", + recentSteps: [{ role: "a", summary: "{}" }], + logger: noopLogger, + }); + expect(r.ok).toBe(true); + if (!r.ok) { + return; + } + expect(r.value).toBe("stop"); + }); + + test("returns err on invalid JSON body", async () => { + const origFetch = globalThis.fetch; + restoreFetch = () => { + globalThis.fetch = origFetch; + }; + globalThis.fetch = Object.assign(async () => new Response("not-json", { status: 200 }), { + preconnect: origFetch.preconnect.bind(origFetch), + }) as typeof fetch; + + const r = await runSupervisor({ + config: supervisorOnlyConfig(), + prompt: "p", + recentSteps: [], + logger: noopLogger, + }); + expect(r.ok).toBe(false); + }); +}); diff --git a/packages/workflow/__tests__/workflow-as-agent.test.ts b/packages/workflow/__tests__/workflow-as-agent.test.ts index 7e3d1d0..8237d73 100644 --- a/packages/workflow/__tests__/workflow-as-agent.test.ts +++ b/packages/workflow/__tests__/workflow-as-agent.test.ts @@ -155,6 +155,7 @@ workflows: {} ...reg.value, config: { maxDepth: 2, + supervisorInterval: 3, providers: { local: { baseUrl: "http://127.0.0.1:9", diff --git a/packages/workflow/src/engine/engine.ts b/packages/workflow/src/engine/engine.ts index f603fb0..03cb5b7 100644 --- a/packages/workflow/src/engine/engine.ts +++ b/packages/workflow/src/engine/engine.ts @@ -9,7 +9,7 @@ import { } from "../cas/index.js"; import { resolveModel } from "../config/index.js"; import { createExtract } from "../extract/index.js"; -import { readWorkflowRegistry } from "../registry/index.js"; +import { readWorkflowRegistry, type WorkflowConfig } from "../registry/index.js"; import type { LlmProvider, ThreadInput, @@ -20,12 +20,18 @@ import type { } 
from "../types.js"; import { err, type LogFn, normalizeRefsField, ok, type Result } from "../util/index.js"; +import { runSupervisor } from "./supervisor.js"; import type { ExecuteThreadIo, ExecuteThreadOptions } from "./types.js"; -async function resolveExtractRuntime( - storageRoot: string, -): Promise< - Result<{ extract: ReturnType; llmProvider: LlmProvider }, string> +async function resolveEngineRegistryRuntime(storageRoot: string): Promise< + Result< + { + extract: ReturnType; + llmProvider: LlmProvider; + workflowConfig: WorkflowConfig; + }, + string + > > { const reg = await readWorkflowRegistry(storageRoot); if (!reg.ok) { @@ -45,7 +51,7 @@ async function resolveExtractRuntime( apiKey: ex.apiKey, model: ex.model, }; - return ok({ extract: createExtract(llmProvider), llmProvider }); + return ok({ extract: createExtract(llmProvider), llmProvider, workflowConfig: cfg }); } async function appendDataLine(path: string, record: unknown): Promise { @@ -79,9 +85,66 @@ async function finalizeThreadResult(params: { }; } +async function finalizeAbortedThread(params: { + cas: CasStore; + workflowName: string; + threadId: string; + stepMerkleHashes: string[]; + logger: LogFn; + abortLogTag: string; +}): Promise { + params.logger(params.abortLogTag, `thread ${params.threadId} aborted`); + return finalizeThreadResult({ + cas: params.cas, + workflowName: params.workflowName, + threadId: params.threadId, + stepMerkleHashes: params.stepMerkleHashes, + completion: { returnCode: 130, summary: "thread aborted" }, + }); +} + +async function maybeSupervisorHaltsThread(params: { + workflowConfig: WorkflowConfig; + input: ThreadInput; + written: number; + recentSupervisorSteps: readonly { role: string; summary: string }[]; + logger: LogFn; + threadId: string; + cas: CasStore; + workflowName: string; + stepMerkleHashes: string[]; +}): Promise { + const interval = params.workflowConfig.supervisorInterval; + if (interval <= 0 || params.written % interval !== 0) { + return null; + } + 
const sup = await runSupervisor({ + config: params.workflowConfig, + prompt: params.input.prompt, + recentSteps: params.recentSupervisorSteps, + logger: params.logger, + }); + if (!sup.ok) { + params.logger("K6PW9NYT", `supervisor skipped: ${sup.error}`); + return null; + } + if (sup.value !== "stop") { + return null; + } + params.logger("M4QX8VHN", `thread ${params.threadId} stopped by supervisor`); + return finalizeThreadResult({ + cas: params.cas, + workflowName: params.workflowName, + threadId: params.threadId, + stepMerkleHashes: params.stepMerkleHashes, + completion: { returnCode: 0, summary: "completed: supervisor stopped thread" }, + }); +} + async function driveWorkflowGenerator(params: { fn: WorkflowFn; workflowName: string; + workflowConfig: WorkflowConfig; input: ThreadInput; bundleOptions: WorkflowFnOptions; executeOptions: ExecuteThreadOptions; @@ -94,6 +157,7 @@ async function driveWorkflowGenerator(params: { const { fn, workflowName, + workflowConfig, input, bundleOptions, executeOptions, @@ -105,16 +169,20 @@ async function driveWorkflowGenerator(params: { } = params; const gen = fn(input, bundleOptions); let written = 0; + const recentSupervisorSteps: { role: string; summary: string }[] = input.steps.map((s) => ({ + role: s.role, + summary: JSON.stringify(s.meta), + })); while (true) { if (executeOptions.signal.aborted) { - logger("V8JX4NP2", `thread ${threadId} aborted`); - return await finalizeThreadResult({ + return await finalizeAbortedThread({ cas, workflowName, threadId, stepMerkleHashes, - completion: { returnCode: 130, summary: "thread aborted" }, + logger, + abortLogTag: "V8JX4NP2", }); } @@ -172,6 +240,11 @@ async function driveWorkflowGenerator(params: { logger("N7BW4YHQ", `thread ${threadId} wrote role ${step.role}`); + recentSupervisorSteps.push({ + role: step.role, + summary: JSON.stringify(step.meta), + }); + await Promise.race([ executeOptions.awaitAfterEachYield(), new Promise((resolve) => { @@ -184,15 +257,30 @@ async function 
driveWorkflowGenerator(params: { ]); if (executeOptions.signal.aborted) { - logger("V8JX4NP4", `thread ${threadId} aborted`); - return await finalizeThreadResult({ + return await finalizeAbortedThread({ cas, workflowName, threadId, stepMerkleHashes, - completion: { returnCode: 130, summary: "thread aborted" }, + logger, + abortLogTag: "V8JX4NP4", }); } + + const supervised = await maybeSupervisorHaltsThread({ + workflowConfig, + input, + written, + recentSupervisorSteps, + logger, + threadId, + cas, + workflowName, + stepMerkleHashes, + }); + if (supervised !== null) { + return supervised; + } } } @@ -280,9 +368,9 @@ export async function executeThread( }); } - const extractRuntime = await resolveExtractRuntime(options.storageRoot); - if (!extractRuntime.ok) { - throw new Error(extractRuntime.error); + const registryRuntime = await resolveEngineRegistryRuntime(options.storageRoot); + if (!registryRuntime.ok) { + throw new Error(registryRuntime.error); } const bundleOptions: WorkflowFnOptions = { @@ -290,13 +378,14 @@ export async function executeThread( maxRounds: options.maxRounds, depth: options.depth, cas: io.cas, - extract: extractRuntime.value.extract, - llmProvider: extractRuntime.value.llmProvider, + extract: registryRuntime.value.extract, + llmProvider: registryRuntime.value.llmProvider, }; return await driveWorkflowGenerator({ fn, workflowName, + workflowConfig: registryRuntime.value.workflowConfig, input, bundleOptions, executeOptions: options, diff --git a/packages/workflow/src/engine/index.ts b/packages/workflow/src/engine/index.ts index 3ad776b..f7b6d74 100644 --- a/packages/workflow/src/engine/index.ts +++ b/packages/workflow/src/engine/index.ts @@ -17,6 +17,7 @@ export type { GcResult, ParsedThreadStartRecord, PrefilledDiskStep, + SupervisorDecision, ThreadPauseGate, } from "./types.js"; export { getWorkerHostScriptPath } from "./worker-entry-path.js"; diff --git a/packages/workflow/src/engine/supervisor.ts 
b/packages/workflow/src/engine/supervisor.ts new file mode 100644 index 0000000..9b151fa --- /dev/null +++ b/packages/workflow/src/engine/supervisor.ts @@ -0,0 +1,140 @@ +import { resolveModel } from "../config/index.js"; +import type { WorkflowConfig } from "../registry/index.js"; +import { err, type LogFn, ok, type Result } from "../util/index.js"; + +import type { SupervisorDecision } from "./types.js"; + +const SUPERVISOR_RECENT_STEP_LIMIT = 12; + +function chatCompletionsUrl(baseUrl: string): string { + const trimmed = baseUrl.replace(/\/+$/, ""); + return `${trimmed}/chat/completions`; +} + +function isRecord(value: unknown): value is Record { + return typeof value === "object" && value !== null && !Array.isArray(value); +} + +function readAssistantContent(parsed: unknown): string | null { + if (!isRecord(parsed)) { + return null; + } + const choices = parsed.choices; + if (!Array.isArray(choices) || choices.length === 0) { + return null; + } + const first = choices[0]; + if (!isRecord(first)) { + return null; + } + const messageObj = first.message; + if (!isRecord(messageObj)) { + return null; + } + const content = messageObj.content; + if (typeof content !== "string") { + return null; + } + return content; +} + +/** Lenient: accepts STOP/stop/"stop." embedded in prose; when both tokens appear, whichever occurs first in the text wins; defaults to "continue" when neither token is found. */ +export function parseSupervisorDecisionText(text: string): SupervisorDecision { + const lower = text.toLowerCase(); + const stopWord = /\bstop\b/.test(lower); + const continueWord = /\bcontinue\b/.test(lower); + if (stopWord && continueWord) { + const si = lower.search(/\bstop\b/); + const ci = lower.search(/\bcontinue\b/); + return si <= ci ? 
"stop" : "continue"; + } + if (stopWord) { + return "stop"; + } + if (continueWord) { + return "continue"; + } + if (lower.includes("stop")) { + return "stop"; + } + if (lower.includes("continue")) { + return "continue"; + } + return "continue"; +} + +type RunSupervisorArgs = { + config: WorkflowConfig; + prompt: string; + recentSteps: readonly { role: string; summary: string }[]; + logger: LogFn; +}; + +/** Calls the `supervisor` scene LLM. When {@link resolveModel} cannot resolve the `supervisor` scene, the supervisor is treated as not configured and ok(`continue`) is returned without sending any request. NOTE(review): if resolveModel falls back to `models.default`, the supervisor would run even without `models.supervisor`; confirm this matches the opt-in intent. */ +export async function runSupervisor( + args: RunSupervisorArgs, +): Promise> { + const resolved = resolveModel(args.config, "supervisor"); + if (!resolved.ok) { + return ok("continue"); + } + const provider = resolved.value; + const recent = args.recentSteps.slice(-SUPERVISOR_RECENT_STEP_LIMIT); + const stepsBlock = recent.map((s, index) => `${index + 1}. [${s.role}] ${s.summary}`).join("\n"); + + const body = { + model: provider.model, + messages: [ + { + role: "system" as const, + content: + 'You supervise a multi-step workflow. Decide if the thread should keep running or halt.\n\nReply with exactly one token: either "continue" (progress toward the goal, not obviously stuck) or "stop" (done, looping, or no progress). Do not add explanation.', + }, + { + role: "user" as const, + content: `Original task:\n${args.prompt}\n\nRecent steps (oldest first):\n${stepsBlock === "" ? "(none)" : stepsBlock}`, + }, + ], + }; + + let response: Response; + try { + response = await fetch(chatCompletionsUrl(provider.baseUrl), { + method: "POST", + headers: { + Authorization: `Bearer ${provider.apiKey}`, + "Content-Type": "application/json", + }, + body: JSON.stringify(body), + }); + } catch (cause) { + const message = cause instanceof Error ? 
cause.message : String(cause); + args.logger("R9CW4PLM", `supervisor request failed: ${message}`); + return err(`supervisor network error: ${message}`); + } + + const responseText = await response.text(); + if (!response.ok) { + args.logger("T3HN8VKQ", `supervisor HTTP ${response.status}: ${responseText.slice(0, 200)}`); + return err(`supervisor HTTP ${response.status}: ${responseText.slice(0, 500)}`); + } + + let parsed: unknown; + try { + parsed = JSON.parse(responseText) as unknown; + } catch (cause) { + const message = cause instanceof Error ? cause.message : String(cause); + args.logger("W7BQ2NXM", `supervisor response is not JSON: ${message}`); + return err(`supervisor invalid JSON: ${message}`); + } + + const content = readAssistantContent(parsed); + if (content === null || content.trim() === "") { + args.logger("Y4JX9PKW", "supervisor returned empty assistant content"); + return err("supervisor empty assistant content"); + } + + const decision = parseSupervisorDecisionText(content); + args.logger("Z8KM5QWT", `supervisor says ${decision}`); + return ok(decision); +} diff --git a/packages/workflow/src/engine/types.ts b/packages/workflow/src/engine/types.ts index d0c1391..8c2d645 100644 --- a/packages/workflow/src/engine/types.ts +++ b/packages/workflow/src/engine/types.ts @@ -2,6 +2,8 @@ import type { CasStore } from "../cas/index.js"; import type { RoleOutput } from "../types.js"; import type { Result } from "../util/index.js"; +export type SupervisorDecision = "continue" | "stop"; + export type ExecuteThreadIo = { threadId: string; hash: string; diff --git a/packages/workflow/src/index.ts b/packages/workflow/src/index.ts index 617ba9e..b274d7d 100644 --- a/packages/workflow/src/index.ts +++ b/packages/workflow/src/index.ts @@ -48,6 +48,7 @@ export { type ParsedThreadStartRecord, type PrefilledDiskStep, parseThreadDataJsonl, + type SupervisorDecision, selectForkHistoricalSteps, type ThreadPauseGate, tryParseRoleStepRecord, diff --git 
a/packages/workflow/src/registry/registry-normalize.ts b/packages/workflow/src/registry/registry-normalize.ts index 3927941..5f493fc 100644 --- a/packages/workflow/src/registry/registry-normalize.ts +++ b/packages/workflow/src/registry/registry-normalize.ts @@ -110,11 +110,23 @@ function normalizeWorkflowConfig(raw: unknown): Result { } const c = raw as Record; const maxDepth = c.maxDepth; + const supervisorIntervalRaw = c.supervisorInterval; const providersRaw = c.providers; const modelsRaw = c.models; if (typeof maxDepth !== "number" || !Number.isInteger(maxDepth) || maxDepth < 0) { return err(new Error("config.maxDepth must be a non-negative integer")); } + let supervisorInterval = 3; + if (supervisorIntervalRaw !== undefined) { + if ( + typeof supervisorIntervalRaw !== "number" || + !Number.isInteger(supervisorIntervalRaw) || + supervisorIntervalRaw < 0 + ) { + return err(new Error("config.supervisorInterval must be a non-negative integer")); + } + supervisorInterval = supervisorIntervalRaw; + } const providersResult = normalizeProviders(providersRaw); if (!providersResult.ok) { return providersResult; @@ -125,6 +137,7 @@ function normalizeWorkflowConfig(raw: unknown): Result { } return ok({ maxDepth, + supervisorInterval, providers: providersResult.value, models: modelsResult.value, }); diff --git a/packages/workflow/src/registry/types.ts b/packages/workflow/src/registry/types.ts index 606ed15..8cba10f 100644 --- a/packages/workflow/src/registry/types.ts +++ b/packages/workflow/src/registry/types.ts @@ -13,6 +13,8 @@ export type WorkflowRegistryEntry = { export type WorkflowConfig = { maxDepth: number; + /** Run supervisor LLM every N completed role rounds (0 = disabled). Default from YAML: 3. */ + supervisorInterval: number; providers: Record; models: Record; };