Merge pull request 'feat(engine): supervisor scene — opt-in LLM thread stop (Phase 3)' (#116) from feat/110-phase3-supervisor into main
This commit is contained in:
@@ -96,6 +96,98 @@ async function writeExtractRegistryConfig(storageRoot: string): Promise<void> {
|
||||
await writeFile(join(storageRoot, "workflow.yaml"), EXTRACT_REGISTRY_YAML, "utf8");
|
||||
}
|
||||
|
||||
const SUPERVISOR_INTERVAL_REGISTRY_YAML = `config:
|
||||
maxDepth: 3
|
||||
supervisorInterval: 2
|
||||
providers:
|
||||
stub:
|
||||
baseUrl: http://127.0.0.1:9
|
||||
apiKey: test
|
||||
models:
|
||||
extract: stub/model
|
||||
supervisor: stub/supervisor-cheap
|
||||
workflows: {}
|
||||
`;
|
||||
|
||||
const SUPERVISOR_LONG_INTERVAL_REGISTRY_YAML = `config:
|
||||
maxDepth: 3
|
||||
supervisorInterval: 10
|
||||
providers:
|
||||
stub:
|
||||
baseUrl: http://127.0.0.1:9
|
||||
apiKey: test
|
||||
models:
|
||||
extract: stub/model
|
||||
supervisor: stub/supervisor-cheap
|
||||
workflows: {}
|
||||
`;
|
||||
|
||||
/**
 * Writes `yaml` as UTF-8 text to `workflow.yaml` directly under `storageRoot`.
 *
 * @param storageRoot directory that receives the registry file
 * @param yaml full registry document to write
 */
async function writeRegistryYaml(storageRoot: string, yaml: string): Promise<void> {
  const registryPath = join(storageRoot, "workflow.yaml");
  await writeFile(registryPath, yaml, "utf8");
}
|
||||
|
||||
/**
 * Swaps `globalThis.fetch` for a stub that answers two kinds of request:
 *
 * - a JSON body with a non-empty `tools` array is treated as an extract round and
 *   answered with a single `tool_calls` entry whose arguments come from the next
 *   `extractArgs` item (the last item repeats once the list is exhausted);
 * - anything else is treated as a supervisor call: `onSupervisorCall` fires and the
 *   reply carries `supervisorContent` as plain assistant `content`.
 *
 * @returns a function that puts the original `fetch` back
 * @throws inside the stub when an extract round arrives and `extractArgs` is empty
 */
function installMockExtractThenSupervisor(params: {
  extractArgs: ReadonlyArray<Record<string, unknown>>;
  supervisorContent: string;
  onSupervisorCall?: () => void;
}): () => void {
  const savedFetch = globalThis.fetch;

  // Both branches reply with HTTP 200 and a JSON content type.
  const jsonOk = (payload: unknown): Response =>
    new Response(JSON.stringify(payload), {
      status: 200,
      headers: { "Content-Type": "application/json" },
    });

  let extractCursor = 0;

  const stub = async (
    _input: Parameters<typeof fetch>[0],
    init?: RequestInit,
  ): Promise<Response> => {
    const requestBody = init?.body
      ? (JSON.parse(String(init.body)) as Record<string, unknown>)
      : {};
    const tools = requestBody.tools;

    // No tools in the request: this is the supervisor scene.
    if (!Array.isArray(tools) || tools.length === 0) {
      params.onSupervisorCall?.();
      return jsonOk({
        choices: [{ message: { content: params.supervisorContent } }],
      });
    }

    const { extractArgs } = params;
    const args = extractArgs[extractCursor] ?? extractArgs[extractArgs.length - 1];
    if (args === undefined) {
      throw new Error("installMockExtractThenSupervisor: empty extractArgs");
    }
    extractCursor += 1;

    // Echo the name of the first offered tool so the caller accepts the call.
    const firstTool = tools[0] as Record<string, unknown>;
    const fnSpec = firstTool.function as Record<string, unknown> | undefined;
    const toolName = typeof fnSpec?.name === "string" ? fnSpec.name : "extract";

    return jsonOk({
      choices: [
        {
          message: {
            tool_calls: [
              {
                type: "function",
                function: {
                  name: toolName,
                  arguments: JSON.stringify(args),
                },
              },
            ],
          },
        },
      ],
    });
  };

  globalThis.fetch = Object.assign(stub, {
    preconnect: savedFetch.preconnect.bind(savedFetch),
  }) as typeof fetch;

  return () => {
    globalThis.fetch = savedFetch;
  };
}
|
||||
|
||||
const demoWorkflow = createWorkflow<DemoMeta>(
|
||||
{
|
||||
roles: {
|
||||
@@ -623,4 +715,102 @@ describe("executeThread", () => {
|
||||
await rm(root, { recursive: true, force: true });
|
||||
}
|
||||
});
|
||||
|
||||
test("supervisor stops thread when interval elapses and model returns stop", async () => {
|
||||
restoreFetch = installMockExtractThenSupervisor({
|
||||
extractArgs: [{ plan: "do-it", files: ["a.ts"] }, { diff: "+ok" }],
|
||||
supervisorContent: "stop",
|
||||
});
|
||||
|
||||
const root = await mkdtemp(join(tmpdir(), "wf-engine-sup-stop-"));
|
||||
try {
|
||||
const threadId = "01KQXKW18CT8G75T53R8F4G7YG";
|
||||
const hash = "C9NMV6V2TQT81";
|
||||
const dataPath = join(root, "logs", hash, `${threadId}.data.jsonl`);
|
||||
const infoPath = join(root, "logs", hash, `${threadId}.info.jsonl`);
|
||||
await mkdir(join(root, "logs", hash), { recursive: true });
|
||||
await writeRegistryYaml(root, SUPERVISOR_INTERVAL_REGISTRY_YAML);
|
||||
const cas = createCasStore(join(root, "cas"));
|
||||
|
||||
const logger = createLogger({ sink: { kind: "file", path: infoPath } });
|
||||
const ac = new AbortController();
|
||||
|
||||
const result = await executeThread(
|
||||
demoWorkflow,
|
||||
"demo-flow",
|
||||
{ prompt: "supervisor-stop-case", steps: [] },
|
||||
{
|
||||
maxRounds: 20,
|
||||
depth: 0,
|
||||
signal: ac.signal,
|
||||
awaitAfterEachYield: async () => {},
|
||||
forkSourceThreadId: null,
|
||||
prefilledDiskSteps: null,
|
||||
storageRoot: root,
|
||||
},
|
||||
{ threadId, hash, dataJsonlPath: dataPath, infoJsonlPath: infoPath, cas },
|
||||
logger,
|
||||
);
|
||||
|
||||
expect(result.returnCode).toBe(0);
|
||||
expect(result.summary).toBe("completed: supervisor stopped thread");
|
||||
|
||||
const dataText = await readFile(dataPath, "utf8");
|
||||
const lines = dataText
|
||||
.trim()
|
||||
.split("\n")
|
||||
.filter((l) => l !== "");
|
||||
expect(lines.length).toBe(3);
|
||||
} finally {
|
||||
await rm(root, { recursive: true, force: true });
|
||||
}
|
||||
});
|
||||
|
||||
test("supervisor is not invoked before supervisorInterval rounds", async () => {
|
||||
let supervisorCalls = 0;
|
||||
restoreFetch = installMockExtractThenSupervisor({
|
||||
extractArgs: [{ plan: "do-it", files: ["a.ts"] }, { diff: "+ok" }],
|
||||
supervisorContent: "stop",
|
||||
onSupervisorCall: () => {
|
||||
supervisorCalls += 1;
|
||||
},
|
||||
});
|
||||
|
||||
const root = await mkdtemp(join(tmpdir(), "wf-engine-sup-skip-"));
|
||||
try {
|
||||
const threadId = "01KQXKW18CT8G75T53R8F4G7YG";
|
||||
const hash = "C9NMV6V2TQT81";
|
||||
const dataPath = join(root, "logs", hash, `${threadId}.data.jsonl`);
|
||||
const infoPath = join(root, "logs", hash, `${threadId}.info.jsonl`);
|
||||
await mkdir(join(root, "logs", hash), { recursive: true });
|
||||
await writeRegistryYaml(root, SUPERVISOR_LONG_INTERVAL_REGISTRY_YAML);
|
||||
const cas = createCasStore(join(root, "cas"));
|
||||
|
||||
const logger = createLogger({ sink: { kind: "file", path: infoPath } });
|
||||
const ac = new AbortController();
|
||||
|
||||
const result = await executeThread(
|
||||
demoWorkflow,
|
||||
"demo-flow",
|
||||
{ prompt: "no-supervisor-yet", steps: [] },
|
||||
{
|
||||
maxRounds: 20,
|
||||
depth: 0,
|
||||
signal: ac.signal,
|
||||
awaitAfterEachYield: async () => {},
|
||||
forkSourceThreadId: null,
|
||||
prefilledDiskSteps: null,
|
||||
storageRoot: root,
|
||||
},
|
||||
{ threadId, hash, dataJsonlPath: dataPath, infoJsonlPath: infoPath, cas },
|
||||
logger,
|
||||
);
|
||||
|
||||
expect(supervisorCalls).toBe(0);
|
||||
expect(result.returnCode).toBe(0);
|
||||
expect(result.summary).toBe("completed: moderator returned END");
|
||||
} finally {
|
||||
await rm(root, { recursive: true, force: true });
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
@@ -132,6 +132,65 @@ workflows:
|
||||
expect(r.value.config.providers.dashscope?.apiKey).toBe("secret-key");
|
||||
expect(r.value.config.models.extract).toBe("dashscope/qwen-plus");
|
||||
expect(r.value.config.models.default).toBe("dashscope/qwen-turbo");
|
||||
expect(r.value.config.supervisorInterval).toBe(3);
|
||||
});
|
||||
|
||||
test("defaults supervisorInterval to 3 when omitted", () => {
|
||||
const yaml = `
|
||||
config:
|
||||
maxDepth: 0
|
||||
providers:
|
||||
p:
|
||||
baseUrl: https://example.com
|
||||
apiKey: k
|
||||
models:
|
||||
default: p/m
|
||||
workflows: {}
|
||||
`;
|
||||
const r = parseWorkflowRegistryYaml(yaml);
|
||||
expect(r.ok).toBe(true);
|
||||
if (!r.ok || r.value.config === null) {
|
||||
return;
|
||||
}
|
||||
expect(r.value.config.supervisorInterval).toBe(3);
|
||||
});
|
||||
|
||||
test("parses explicit supervisorInterval", () => {
|
||||
const yaml = `
|
||||
config:
|
||||
maxDepth: 0
|
||||
supervisorInterval: 7
|
||||
providers:
|
||||
p:
|
||||
baseUrl: https://example.com
|
||||
apiKey: k
|
||||
models:
|
||||
default: p/m
|
||||
workflows: {}
|
||||
`;
|
||||
const r = parseWorkflowRegistryYaml(yaml);
|
||||
expect(r.ok).toBe(true);
|
||||
if (!r.ok || r.value.config === null) {
|
||||
return;
|
||||
}
|
||||
expect(r.value.config.supervisorInterval).toBe(7);
|
||||
});
|
||||
|
||||
test("parse errors when supervisorInterval is negative", () => {
|
||||
const yaml = `
|
||||
config:
|
||||
maxDepth: 0
|
||||
supervisorInterval: -1
|
||||
providers:
|
||||
p:
|
||||
baseUrl: https://example.com
|
||||
apiKey: k
|
||||
models:
|
||||
default: p/m
|
||||
workflows: {}
|
||||
`;
|
||||
const r = parseWorkflowRegistryYaml(yaml);
|
||||
expect(r.ok).toBe(false);
|
||||
});
|
||||
|
||||
test("parses config apiKey env: prefix from process.env", () => {
|
||||
|
||||
@@ -6,6 +6,7 @@ import type { WorkflowConfig } from "../src/registry/index.js";
|
||||
function sampleConfig(): WorkflowConfig {
|
||||
return {
|
||||
maxDepth: 3,
|
||||
supervisorInterval: 3,
|
||||
providers: {
|
||||
dashscope: {
|
||||
baseUrl: "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
@@ -50,6 +51,7 @@ describe("resolveModel", () => {
|
||||
test("errs when scene missing and no default", () => {
|
||||
const config: WorkflowConfig = {
|
||||
maxDepth: 1,
|
||||
supervisorInterval: 3,
|
||||
providers: {
|
||||
p: { baseUrl: "https://x", apiKey: "k" },
|
||||
},
|
||||
@@ -69,6 +71,7 @@ describe("resolveModel", () => {
|
||||
test("errs when provider is unknown", () => {
|
||||
const config: WorkflowConfig = {
|
||||
maxDepth: 1,
|
||||
supervisorInterval: 3,
|
||||
providers: {
|
||||
p: { baseUrl: "https://x", apiKey: "k" },
|
||||
},
|
||||
@@ -87,6 +90,7 @@ describe("resolveModel", () => {
|
||||
test("errs on invalid model reference shape", () => {
|
||||
const config: WorkflowConfig = {
|
||||
maxDepth: 1,
|
||||
supervisorInterval: 3,
|
||||
providers: {
|
||||
p: { baseUrl: "https://x", apiKey: "k" },
|
||||
},
|
||||
|
||||
@@ -0,0 +1,136 @@
|
||||
import { afterEach, describe, expect, test } from "bun:test";
|
||||
|
||||
import { parseSupervisorDecisionText, runSupervisor } from "../src/engine/supervisor.js";
|
||||
import type { WorkflowConfig } from "../src/registry/index.js";
|
||||
import type { LogFn } from "../src/util/index.js";
|
||||
|
||||
const noopLogger: LogFn = () => {};
|
||||
|
||||
/**
 * Builds a fresh config whose `models` map defines both the `extract` and the
 * `supervisor` scene on the single `stub` provider.
 */
function supervisorOnlyConfig(): WorkflowConfig {
  const providers: WorkflowConfig["providers"] = {
    stub: { baseUrl: "http://127.0.0.1:9/v1", apiKey: "k" },
  };
  const models: WorkflowConfig["models"] = {
    extract: "stub/extract-model",
    supervisor: "stub/supervisor-model",
  };
  return { maxDepth: 3, supervisorInterval: 3, providers, models };
}
|
||||
|
||||
describe("parseSupervisorDecisionText", () => {
|
||||
test("reads continue and stop case-insensitively", () => {
|
||||
expect(parseSupervisorDecisionText("continue")).toBe("continue");
|
||||
expect(parseSupervisorDecisionText("CONTINUE")).toBe("continue");
|
||||
expect(parseSupervisorDecisionText("stop")).toBe("stop");
|
||||
expect(parseSupervisorDecisionText("STOP.")).toBe("stop");
|
||||
});
|
||||
|
||||
test("finds token inside a sentence", () => {
|
||||
expect(parseSupervisorDecisionText("Answer: continue")).toBe("continue");
|
||||
expect(parseSupervisorDecisionText("I recommend stop now")).toBe("stop");
|
||||
});
|
||||
|
||||
test("when both appear, earlier token wins", () => {
|
||||
expect(parseSupervisorDecisionText("continue then stop")).toBe("continue");
|
||||
expect(parseSupervisorDecisionText("stop then continue")).toBe("stop");
|
||||
});
|
||||
|
||||
test("defaults to continue when unclear", () => {
|
||||
expect(parseSupervisorDecisionText("maybe later")).toBe("continue");
|
||||
});
|
||||
});
|
||||
|
||||
describe("runSupervisor", () => {
|
||||
let restoreFetch: (() => void) | null = null;
|
||||
|
||||
afterEach(() => {
|
||||
restoreFetch?.();
|
||||
restoreFetch = null;
|
||||
});
|
||||
|
||||
test("returns continue when supervisor model cannot be resolved (no fetch)", async () => {
|
||||
const origFetch = globalThis.fetch;
|
||||
restoreFetch = () => {
|
||||
globalThis.fetch = origFetch;
|
||||
};
|
||||
globalThis.fetch = Object.assign(
|
||||
async () => {
|
||||
throw new Error("fetch should not run when supervisor is not configured");
|
||||
},
|
||||
{ preconnect: origFetch.preconnect.bind(origFetch) },
|
||||
) as typeof fetch;
|
||||
|
||||
const config: WorkflowConfig = {
|
||||
maxDepth: 1,
|
||||
supervisorInterval: 3,
|
||||
providers: {
|
||||
stub: { baseUrl: "http://127.0.0.1:9/v1", apiKey: "k" },
|
||||
},
|
||||
models: {
|
||||
extract: "stub/m",
|
||||
},
|
||||
};
|
||||
|
||||
const r = await runSupervisor({
|
||||
config,
|
||||
prompt: "task",
|
||||
recentSteps: [{ role: "planner", summary: "{}" }],
|
||||
logger: noopLogger,
|
||||
});
|
||||
expect(r.ok).toBe(true);
|
||||
if (!r.ok) {
|
||||
return;
|
||||
}
|
||||
expect(r.value).toBe("continue");
|
||||
});
|
||||
|
||||
test("returns stop from chat/completions assistant content", async () => {
|
||||
const origFetch = globalThis.fetch;
|
||||
restoreFetch = () => {
|
||||
globalThis.fetch = origFetch;
|
||||
};
|
||||
globalThis.fetch = Object.assign(
|
||||
async () =>
|
||||
new Response(
|
||||
JSON.stringify({
|
||||
choices: [{ message: { content: "stop" } }],
|
||||
}),
|
||||
{ status: 200, headers: { "Content-Type": "application/json" } },
|
||||
),
|
||||
{ preconnect: origFetch.preconnect.bind(origFetch) },
|
||||
) as typeof fetch;
|
||||
|
||||
const r = await runSupervisor({
|
||||
config: supervisorOnlyConfig(),
|
||||
prompt: "do X",
|
||||
recentSteps: [{ role: "a", summary: "{}" }],
|
||||
logger: noopLogger,
|
||||
});
|
||||
expect(r.ok).toBe(true);
|
||||
if (!r.ok) {
|
||||
return;
|
||||
}
|
||||
expect(r.value).toBe("stop");
|
||||
});
|
||||
|
||||
test("returns err on invalid JSON body", async () => {
|
||||
const origFetch = globalThis.fetch;
|
||||
restoreFetch = () => {
|
||||
globalThis.fetch = origFetch;
|
||||
};
|
||||
globalThis.fetch = Object.assign(async () => new Response("not-json", { status: 200 }), {
|
||||
preconnect: origFetch.preconnect.bind(origFetch),
|
||||
}) as typeof fetch;
|
||||
|
||||
const r = await runSupervisor({
|
||||
config: supervisorOnlyConfig(),
|
||||
prompt: "p",
|
||||
recentSteps: [],
|
||||
logger: noopLogger,
|
||||
});
|
||||
expect(r.ok).toBe(false);
|
||||
});
|
||||
});
|
||||
@@ -155,6 +155,7 @@ workflows: {}
|
||||
...reg.value,
|
||||
config: {
|
||||
maxDepth: 2,
|
||||
supervisorInterval: 3,
|
||||
providers: {
|
||||
local: {
|
||||
baseUrl: "http://127.0.0.1:9",
|
||||
|
||||
@@ -9,7 +9,7 @@ import {
|
||||
} from "../cas/index.js";
|
||||
import { resolveModel } from "../config/index.js";
|
||||
import { createExtract } from "../extract/index.js";
|
||||
import { readWorkflowRegistry } from "../registry/index.js";
|
||||
import { readWorkflowRegistry, type WorkflowConfig } from "../registry/index.js";
|
||||
import type {
|
||||
LlmProvider,
|
||||
ThreadInput,
|
||||
@@ -20,12 +20,18 @@ import type {
|
||||
} from "../types.js";
|
||||
import { err, type LogFn, normalizeRefsField, ok, type Result } from "../util/index.js";
|
||||
|
||||
import { runSupervisor } from "./supervisor.js";
|
||||
import type { ExecuteThreadIo, ExecuteThreadOptions } from "./types.js";
|
||||
|
||||
async function resolveExtractRuntime(
|
||||
storageRoot: string,
|
||||
): Promise<
|
||||
Result<{ extract: ReturnType<typeof createExtract>; llmProvider: LlmProvider }, string>
|
||||
async function resolveEngineRegistryRuntime(storageRoot: string): Promise<
|
||||
Result<
|
||||
{
|
||||
extract: ReturnType<typeof createExtract>;
|
||||
llmProvider: LlmProvider;
|
||||
workflowConfig: WorkflowConfig;
|
||||
},
|
||||
string
|
||||
>
|
||||
> {
|
||||
const reg = await readWorkflowRegistry(storageRoot);
|
||||
if (!reg.ok) {
|
||||
@@ -45,7 +51,7 @@ async function resolveExtractRuntime(
|
||||
apiKey: ex.apiKey,
|
||||
model: ex.model,
|
||||
};
|
||||
return ok({ extract: createExtract(llmProvider), llmProvider });
|
||||
return ok({ extract: createExtract(llmProvider), llmProvider, workflowConfig: cfg });
|
||||
}
|
||||
|
||||
async function appendDataLine(path: string, record: unknown): Promise<void> {
|
||||
@@ -79,9 +85,66 @@ async function finalizeThreadResult(params: {
|
||||
};
|
||||
}
|
||||
|
||||
/**
 * Logs that the thread was aborted under the caller-supplied tag, then finalizes the
 * thread with return code 130 and the summary "thread aborted".
 *
 * @param params.abortLogTag log tag identifying which abort check fired
 * @returns whatever `finalizeThreadResult` produces for the aborted completion
 */
async function finalizeAbortedThread(params: {
  cas: CasStore;
  workflowName: string;
  threadId: string;
  stepMerkleHashes: string[];
  logger: LogFn;
  abortLogTag: string;
}): Promise<WorkflowResult> {
  const { cas, workflowName, threadId, stepMerkleHashes, logger, abortLogTag } = params;

  logger(abortLogTag, `thread ${threadId} aborted`);

  return finalizeThreadResult({
    cas,
    workflowName,
    threadId,
    stepMerkleHashes,
    completion: { returnCode: 130, summary: "thread aborted" },
  });
}
|
||||
|
||||
/**
 * Supervisor checkpoint for the thread loop.
 *
 * The supervisor is consulted only when `supervisorInterval` is positive and
 * `written` is an exact multiple of it. A failed supervisor call is logged and
 * treated as "keep going"; only an explicit "stop" decision finalizes the thread.
 *
 * @returns the finalized result (return code 0) when the supervisor says "stop",
 *          otherwise `null` so the caller keeps running
 */
async function maybeSupervisorHaltsThread(params: {
  workflowConfig: WorkflowConfig;
  input: ThreadInput;
  written: number;
  recentSupervisorSteps: readonly { role: string; summary: string }[];
  logger: LogFn;
  threadId: string;
  cas: CasStore;
  workflowName: string;
  stepMerkleHashes: string[];
}): Promise<WorkflowResult | null> {
  const { workflowConfig, written, logger, threadId } = params;

  const every = workflowConfig.supervisorInterval;
  const checkpointDue = every > 0 && written % every === 0;
  if (!checkpointDue) {
    return null;
  }

  const verdict = await runSupervisor({
    config: workflowConfig,
    prompt: params.input.prompt,
    recentSteps: params.recentSupervisorSteps,
    logger,
  });

  if (!verdict.ok) {
    // Supervisor trouble never stops the thread; record it and carry on.
    logger("K6PW9NYT", `supervisor skipped: ${verdict.error}`);
    return null;
  }

  if (verdict.value === "stop") {
    logger("M4QX8VHN", `thread ${threadId} stopped by supervisor`);
    return finalizeThreadResult({
      cas: params.cas,
      workflowName: params.workflowName,
      threadId,
      stepMerkleHashes: params.stepMerkleHashes,
      completion: { returnCode: 0, summary: "completed: supervisor stopped thread" },
    });
  }

  return null;
}
|
||||
|
||||
async function driveWorkflowGenerator(params: {
|
||||
fn: WorkflowFn;
|
||||
workflowName: string;
|
||||
workflowConfig: WorkflowConfig;
|
||||
input: ThreadInput;
|
||||
bundleOptions: WorkflowFnOptions;
|
||||
executeOptions: ExecuteThreadOptions;
|
||||
@@ -94,6 +157,7 @@ async function driveWorkflowGenerator(params: {
|
||||
const {
|
||||
fn,
|
||||
workflowName,
|
||||
workflowConfig,
|
||||
input,
|
||||
bundleOptions,
|
||||
executeOptions,
|
||||
@@ -105,16 +169,20 @@ async function driveWorkflowGenerator(params: {
|
||||
} = params;
|
||||
const gen = fn(input, bundleOptions);
|
||||
let written = 0;
|
||||
const recentSupervisorSteps: { role: string; summary: string }[] = input.steps.map((s) => ({
|
||||
role: s.role,
|
||||
summary: JSON.stringify(s.meta),
|
||||
}));
|
||||
|
||||
while (true) {
|
||||
if (executeOptions.signal.aborted) {
|
||||
logger("V8JX4NP2", `thread ${threadId} aborted`);
|
||||
return await finalizeThreadResult({
|
||||
return await finalizeAbortedThread({
|
||||
cas,
|
||||
workflowName,
|
||||
threadId,
|
||||
stepMerkleHashes,
|
||||
completion: { returnCode: 130, summary: "thread aborted" },
|
||||
logger,
|
||||
abortLogTag: "V8JX4NP2",
|
||||
});
|
||||
}
|
||||
|
||||
@@ -172,6 +240,11 @@ async function driveWorkflowGenerator(params: {
|
||||
|
||||
logger("N7BW4YHQ", `thread ${threadId} wrote role ${step.role}`);
|
||||
|
||||
recentSupervisorSteps.push({
|
||||
role: step.role,
|
||||
summary: JSON.stringify(step.meta),
|
||||
});
|
||||
|
||||
await Promise.race([
|
||||
executeOptions.awaitAfterEachYield(),
|
||||
new Promise<void>((resolve) => {
|
||||
@@ -184,15 +257,30 @@ async function driveWorkflowGenerator(params: {
|
||||
]);
|
||||
|
||||
if (executeOptions.signal.aborted) {
|
||||
logger("V8JX4NP4", `thread ${threadId} aborted`);
|
||||
return await finalizeThreadResult({
|
||||
return await finalizeAbortedThread({
|
||||
cas,
|
||||
workflowName,
|
||||
threadId,
|
||||
stepMerkleHashes,
|
||||
completion: { returnCode: 130, summary: "thread aborted" },
|
||||
logger,
|
||||
abortLogTag: "V8JX4NP4",
|
||||
});
|
||||
}
|
||||
|
||||
const supervised = await maybeSupervisorHaltsThread({
|
||||
workflowConfig,
|
||||
input,
|
||||
written,
|
||||
recentSupervisorSteps,
|
||||
logger,
|
||||
threadId,
|
||||
cas,
|
||||
workflowName,
|
||||
stepMerkleHashes,
|
||||
});
|
||||
if (supervised !== null) {
|
||||
return supervised;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -280,9 +368,9 @@ export async function executeThread(
|
||||
});
|
||||
}
|
||||
|
||||
const extractRuntime = await resolveExtractRuntime(options.storageRoot);
|
||||
if (!extractRuntime.ok) {
|
||||
throw new Error(extractRuntime.error);
|
||||
const registryRuntime = await resolveEngineRegistryRuntime(options.storageRoot);
|
||||
if (!registryRuntime.ok) {
|
||||
throw new Error(registryRuntime.error);
|
||||
}
|
||||
|
||||
const bundleOptions: WorkflowFnOptions = {
|
||||
@@ -290,13 +378,14 @@ export async function executeThread(
|
||||
maxRounds: options.maxRounds,
|
||||
depth: options.depth,
|
||||
cas: io.cas,
|
||||
extract: extractRuntime.value.extract,
|
||||
llmProvider: extractRuntime.value.llmProvider,
|
||||
extract: registryRuntime.value.extract,
|
||||
llmProvider: registryRuntime.value.llmProvider,
|
||||
};
|
||||
|
||||
return await driveWorkflowGenerator({
|
||||
fn,
|
||||
workflowName,
|
||||
workflowConfig: registryRuntime.value.workflowConfig,
|
||||
input,
|
||||
bundleOptions,
|
||||
executeOptions: options,
|
||||
|
||||
@@ -17,6 +17,7 @@ export type {
|
||||
GcResult,
|
||||
ParsedThreadStartRecord,
|
||||
PrefilledDiskStep,
|
||||
SupervisorDecision,
|
||||
ThreadPauseGate,
|
||||
} from "./types.js";
|
||||
export { getWorkerHostScriptPath } from "./worker-entry-path.js";
|
||||
|
||||
@@ -0,0 +1,140 @@
|
||||
import { resolveModel } from "../config/index.js";
|
||||
import type { WorkflowConfig } from "../registry/index.js";
|
||||
import { err, type LogFn, ok, type Result } from "../util/index.js";
|
||||
|
||||
import type { SupervisorDecision } from "./types.js";
|
||||
|
||||
const SUPERVISOR_RECENT_STEP_LIMIT = 12;
|
||||
|
||||
/**
 * Builds the chat-completions endpoint for a provider base URL: every trailing `/`
 * is dropped from `baseUrl` before `/chat/completions` is appended.
 */
function chatCompletionsUrl(baseUrl: string): string {
  let end = baseUrl.length;
  while (end > 0 && baseUrl[end - 1] === "/") {
    end -= 1;
  }
  return baseUrl.slice(0, end) + "/chat/completions";
}
|
||||
|
||||
/** Type guard: true for any non-null, non-array value whose `typeof` is "object". */
function isRecord(value: unknown): value is Record<string, unknown> {
  if (value === null || Array.isArray(value)) {
    return false;
  }
  return typeof value === "object";
}
|
||||
|
||||
/**
 * Extracts `choices[0].message.content` from a parsed chat-completions payload.
 *
 * @returns the content string, or `null` when any level of that path is missing or
 *          has the wrong shape (including a non-string `content`)
 */
function readAssistantContent(parsed: unknown): string | null {
  const choices = isRecord(parsed) ? parsed.choices : undefined;
  if (!Array.isArray(choices) || choices.length === 0) {
    return null;
  }

  const head: unknown = choices[0];
  const message = isRecord(head) ? head.message : undefined;
  const content = isRecord(message) ? message.content : undefined;

  return typeof content === "string" ? content : null;
}
|
||||
|
||||
/**
 * Leniently maps free-form supervisor output to a decision (case-insensitive).
 *
 * Resolution order:
 * 1. The first whole word `stop` or `continue` in the text decides; when both
 *    appear, whichever comes earlier wins.
 * 2. With no whole-word match, a bare substring is accepted (for example
 *    "stopped"), and `stop` is checked before `continue`.
 * 3. Text mentioning neither yields `"continue"`.
 *
 * @param text raw assistant content
 * @returns `"stop"` or `"continue"`; never throws
 */
export function parseSupervisorDecisionText(text: string): SupervisorDecision {
  const lower = text.toLowerCase();

  // One left-to-right scan finds the earliest whole-word token.
  const firstWord = /\b(stop|continue)\b/.exec(lower);
  if (firstWord !== null) {
    return firstWord[1] === "stop" ? "stop" : "continue";
  }

  // Substring fallback: anything without "stop" resolves to "continue".
  return lower.includes("stop") ? "stop" : "continue";
}
|
||||
|
||||
/** Inputs for {@link runSupervisor}. */
type RunSupervisorArgs = {
  /** Registry config; its `supervisor` scene is resolved to pick provider and model. */
  config: WorkflowConfig;
  /** Original task text, sent to the model under "Original task:". */
  prompt: string;
  /** Step summaries, oldest first; only the last SUPERVISOR_RECENT_STEP_LIMIT are sent. */
  recentSteps: readonly { role: string; summary: string }[];
  /** Receives a tag and a message for each failure and for the final decision. */
  logger: LogFn;
};

/**
 * Asks the `supervisor` scene LLM whether the thread should keep running.
 *
 * The supervisor is opt-in: when {@link resolveModel} cannot resolve the `supervisor`
 * scene, no request is made and the result is `ok("continue")`.
 *
 * Every failure after that point is reported as `err(...)` rather than thrown:
 * a rejected request or body read, a non-2xx status, a body that is not JSON, or
 * missing/blank assistant content.
 *
 * @returns `ok("continue" | "stop")` on success, `err(message)` on any failure
 */
export async function runSupervisor(
  args: RunSupervisorArgs,
): Promise<Result<SupervisorDecision, string>> {
  const resolved = resolveModel(args.config, "supervisor");
  if (!resolved.ok) {
    return ok("continue");
  }
  const provider = resolved.value;

  const recent = args.recentSteps.slice(-SUPERVISOR_RECENT_STEP_LIMIT);
  const stepsBlock = recent.map((s, index) => `${index + 1}. [${s.role}] ${s.summary}`).join("\n");

  const body = {
    model: provider.model,
    messages: [
      {
        role: "system" as const,
        content:
          'You supervise a multi-step workflow. Decide if the thread should keep running or halt.\n\nReply with exactly one token: either "continue" (progress toward the goal, not obviously stuck) or "stop" (done, looping, or no progress). Do not add explanation.',
      },
      {
        role: "user" as const,
        content: `Original task:\n${args.prompt}\n\nRecent steps (oldest first):\n${stepsBlock === "" ? "(none)" : stepsBlock}`,
      },
    ],
  };

  let response: Response;
  let responseText: string;
  try {
    response = await fetch(chatCompletionsUrl(provider.baseUrl), {
      method: "POST",
      headers: {
        Authorization: `Bearer ${provider.apiKey}`,
        "Content-Type": "application/json",
      },
      body: JSON.stringify(body),
    });
    // Reading the body can also reject (for example when the connection drops
    // mid-stream), so it stays inside the try and maps to the same err result.
    responseText = await response.text();
  } catch (cause) {
    const message = cause instanceof Error ? cause.message : String(cause);
    args.logger("R9CW4PLM", `supervisor request failed: ${message}`);
    return err(`supervisor network error: ${message}`);
  }

  if (!response.ok) {
    args.logger("T3HN8VKQ", `supervisor HTTP ${response.status}: ${responseText.slice(0, 200)}`);
    return err(`supervisor HTTP ${response.status}: ${responseText.slice(0, 500)}`);
  }

  let parsed: unknown;
  try {
    parsed = JSON.parse(responseText) as unknown;
  } catch (cause) {
    const message = cause instanceof Error ? cause.message : String(cause);
    args.logger("W7BQ2NXM", `supervisor response is not JSON: ${message}`);
    return err(`supervisor invalid JSON: ${message}`);
  }

  const content = readAssistantContent(parsed);
  if (content === null || content.trim() === "") {
    args.logger("Y4JX9PKW", "supervisor returned empty assistant content");
    return err("supervisor empty assistant content");
  }

  const decision = parseSupervisorDecisionText(content);
  args.logger("Z8KM5QWT", `supervisor says ${decision}`);
  return ok(decision);
}
|
||||
@@ -2,6 +2,8 @@ import type { CasStore } from "../cas/index.js";
|
||||
import type { RoleOutput } from "../types.js";
|
||||
import type { Result } from "../util/index.js";
|
||||
|
||||
export type SupervisorDecision = "continue" | "stop";
|
||||
|
||||
export type ExecuteThreadIo = {
|
||||
threadId: string;
|
||||
hash: string;
|
||||
|
||||
@@ -47,6 +47,7 @@ export {
|
||||
type ParsedThreadStartRecord,
|
||||
type PrefilledDiskStep,
|
||||
parseThreadDataJsonl,
|
||||
type SupervisorDecision,
|
||||
selectForkHistoricalSteps,
|
||||
type ThreadPauseGate,
|
||||
tryParseRoleStepRecord,
|
||||
|
||||
@@ -110,11 +110,23 @@ function normalizeWorkflowConfig(raw: unknown): Result<WorkflowConfig, Error> {
|
||||
}
|
||||
const c = raw as Record<string, unknown>;
|
||||
const maxDepth = c.maxDepth;
|
||||
const supervisorIntervalRaw = c.supervisorInterval;
|
||||
const providersRaw = c.providers;
|
||||
const modelsRaw = c.models;
|
||||
if (typeof maxDepth !== "number" || !Number.isInteger(maxDepth) || maxDepth < 0) {
|
||||
return err(new Error("config.maxDepth must be a non-negative integer"));
|
||||
}
|
||||
let supervisorInterval = 3;
|
||||
if (supervisorIntervalRaw !== undefined) {
|
||||
if (
|
||||
typeof supervisorIntervalRaw !== "number" ||
|
||||
!Number.isInteger(supervisorIntervalRaw) ||
|
||||
supervisorIntervalRaw < 0
|
||||
) {
|
||||
return err(new Error("config.supervisorInterval must be a non-negative integer"));
|
||||
}
|
||||
supervisorInterval = supervisorIntervalRaw;
|
||||
}
|
||||
const providersResult = normalizeProviders(providersRaw);
|
||||
if (!providersResult.ok) {
|
||||
return providersResult;
|
||||
@@ -125,6 +137,7 @@ function normalizeWorkflowConfig(raw: unknown): Result<WorkflowConfig, Error> {
|
||||
}
|
||||
return ok({
|
||||
maxDepth,
|
||||
supervisorInterval,
|
||||
providers: providersResult.value,
|
||||
models: modelsResult.value,
|
||||
});
|
||||
|
||||
@@ -13,6 +13,8 @@ export type WorkflowRegistryEntry = {
|
||||
|
||||
export type WorkflowConfig = {
|
||||
maxDepth: number;
|
||||
/** Run supervisor LLM every N completed role rounds (0 = disabled). Default from YAML: 3. */
|
||||
supervisorInterval: number;
|
||||
providers: Record<string, ProviderConfig>;
|
||||
models: Record<string, string>;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user