diff --git a/packages/dashboard/src/components/thread-detail.tsx b/packages/dashboard/src/components/thread-detail.tsx index 27ddd0c..08d62e9 100644 --- a/packages/dashboard/src/components/thread-detail.tsx +++ b/packages/dashboard/src/components/thread-detail.tsx @@ -101,9 +101,9 @@ export function ThreadDetail({ threadId, onBack }: Props) { )} {(status === "ok" || liveActive || records.length > 0) && (
- {records.map((r, i) => ( + {records.map((r) => (
diff --git a/packages/workflow-template-develop/__tests__/develop-template.test.ts b/packages/workflow-template-develop/__tests__/develop-template.test.ts index 0b9ff36..f67adba 100644 --- a/packages/workflow-template-develop/__tests__/develop-template.test.ts +++ b/packages/workflow-template-develop/__tests__/develop-template.test.ts @@ -1,11 +1,6 @@ import { describe, expect, test } from "bun:test"; -import { - END, - type ModeratorContext, - type RoleStep, - START, - validateWorkflowDescriptor, -} from "@uncaged/workflow-runtime"; +import { END, type ModeratorContext, type RoleStep, START } from "@uncaged/workflow-runtime"; +import { validateWorkflowDescriptor } from "@uncaged/workflow"; import { buildDevelopDescriptor } from "../src/descriptor.js"; import { developModerator } from "../src/index.js"; import type { CommitterMeta, PlannerMeta } from "../src/roles/index.js"; diff --git a/packages/workflow-template-solve-issue/__tests__/solve-issue-template.test.ts b/packages/workflow-template-solve-issue/__tests__/solve-issue-template.test.ts index b10ff6f..6fb6e7f 100644 --- a/packages/workflow-template-solve-issue/__tests__/solve-issue-template.test.ts +++ b/packages/workflow-template-solve-issue/__tests__/solve-issue-template.test.ts @@ -2,14 +2,8 @@ import { afterEach, describe, expect, test } from "bun:test"; import { mkdtemp, rm } from "node:fs/promises"; import { tmpdir } from "node:os"; import { join } from "node:path"; -import { createCasStore, createExtract, createWorkflow } from "@uncaged/workflow"; -import { - END, - type ModeratorContext, - type RoleStep, - START, - validateWorkflowDescriptor, -} from "@uncaged/workflow-runtime"; +import { createCasStore, createExtract, createWorkflow, validateWorkflowDescriptor } from "@uncaged/workflow"; +import { END, type ModeratorContext, type RoleStep, START } from "@uncaged/workflow-runtime"; import { buildSolveIssueDescriptor } from "../src/descriptor.js"; import type { DeveloperMeta } from 
"../src/developer.js"; import { solveIssueModerator, solveIssueWorkflowDefinition } from "../src/index.js"; diff --git a/packages/workflow/README.md b/packages/workflow/README.md index fe84528..db319e3 100644 --- a/packages/workflow/README.md +++ b/packages/workflow/README.md @@ -29,7 +29,7 @@ import { createWorkflow, readWorkflowRegistry, executeThread } from "@uncaged/wo | **Registry** | `readWorkflowRegistry`, `writeWorkflowRegistry`, `registerWorkflowVersion`, `workflowRegistryPath`, YAML helpers | | **CAS** | `createCasStore`, Merkle helpers (`putStepMerkleNode`, `getContentMerklePayload`, …), `hashWorkflowBundleBytes` | | **Engine** | `createWorkflow`, `executeThread`, `parseThreadDataJsonl`, fork helpers, `garbageCollectCas` | -| **Extract / LLM tools** | `llmExtract`, `reactExtract`, `createExtract`, `getExtractProvider` | +| **Extract / LLM tools** | `llmExtract`, `createExtract`, `createThreadReactor`, `createLlmFn`, `getExtractProvider` | | **Agent bridge** | `workflowAsAgent` — expose a registered workflow as an agent-backed role | | **Utilities** | `createLogger`, ULID / Crockford Base32 codecs, `getDefaultWorkflowStorageRoot`, paths | diff --git a/packages/workflow/__tests__/react-extract.test.ts b/packages/workflow/__tests__/thread-reactor.test.ts similarity index 70% rename from packages/workflow/__tests__/react-extract.test.ts rename to packages/workflow/__tests__/thread-reactor.test.ts index ed3faa1..c880df1 100644 --- a/packages/workflow/__tests__/react-extract.test.ts +++ b/packages/workflow/__tests__/thread-reactor.test.ts @@ -6,7 +6,8 @@ import type { LlmProvider } from "@uncaged/workflow-runtime"; import * as z from "zod/v4"; import { createCasStore } from "../src/cas/cas.js"; import { createContentMerkleNode, serializeMerkleNode } from "../src/cas/merkle.js"; -import { reactExtract } from "../src/extract/react-extract.js"; +import { extractFunctionToolFromZodSchema } from "../src/extract/llm-extract.js"; +import { createLlmFn, 
createThreadReactor } from "../src/reactor/index.js"; const metaSchema = z.object({ seen: z.string() }); @@ -16,7 +17,57 @@ const provider: LlmProvider = { model: "test", }; -describe("reactExtract", () => { +const CAS_GET_TOOL_DEFINITION = { + type: "function" as const, + function: { + name: "cas_get", + description: "Read CAS node", + parameters: { + type: "object", + properties: { + hash: { type: "string", description: "hash" }, + }, + required: ["hash"], + }, + }, +}; + +type ThreadCtx = { cas: ReturnType }; + +function createTestReactor() { + const llm = createLlmFn(provider); + return createThreadReactor({ + llm, + maxRounds: 10, + staticTools: [CAS_GET_TOOL_DEFINITION], + structuredToolFromSchema: (schema) => { + const t = extractFunctionToolFromZodSchema(schema); + return { + name: t.name, + tool: { + type: "function" as const, + function: { + name: t.name, + description: t.description, + parameters: t.parameters, + }, + }, + }; + }, + systemPromptForStructuredTool: (structuredToolName) => + `Extract metadata. Use cas_get when needed. Call ${structuredToolName} with JSON args matching the schema, or reply with plain JSON.`, + toolHandler: async (call, thread) => { + if (call.function.name !== "cas_get") { + return `unexpected tool ${call.function.name}`; + } + const ta = JSON.parse(call.function.arguments) as { hash: string }; + const blob = await thread.cas.get(ta.hash); + return blob === null ? 
"null" : blob; + }, + }); +} + +describe("createThreadReactor (extract-shaped)", () => { let restoreFetch: (() => void) | null = null; afterEach(() => { @@ -25,7 +76,7 @@ describe("reactExtract", () => { }); test("cas_get rounds then extract tool yields validated meta", async () => { - const casDir = await mkdtemp(join(tmpdir(), "react-extract-")); + const casDir = await mkdtemp(join(tmpdir(), "thread-reactor-")); const cas = createCasStore(casDir); try { const blob = serializeMerkleNode(createContentMerkleNode("needle")); @@ -87,12 +138,12 @@ describe("reactExtract", () => { { preconnect: origFetch.preconnect.bind(origFetch) }, ) as typeof fetch; + const reactor = createTestReactor(); const text = `## Agent Output\n${h}\n## Extraction Instruction\nExtract seen from CAS.`; - const result = await reactExtract({ - text, + const result = await reactor({ + thread: { cas }, + input: text, schema: metaSchema, - provider, - cas, }); expect(result.ok).toBe(true); @@ -107,7 +158,7 @@ describe("reactExtract", () => { }); test("stops after max tool rounds when model keeps calling cas_get", async () => { - const casDir = await mkdtemp(join(tmpdir(), "react-extract-max-")); + const casDir = await mkdtemp(join(tmpdir(), "thread-reactor-max-")); const cas = createCasStore(casDir); try { const blob = serializeMerkleNode(createContentMerkleNode("x")); @@ -146,11 +197,11 @@ describe("reactExtract", () => { { preconnect: origFetch.preconnect.bind(origFetch) }, ) as typeof fetch; - const result = await reactExtract({ - text: "## Agent Output\nnoop\n## Extraction Instruction\nExtract seen.", + const reactor = createTestReactor(); + const result = await reactor({ + thread: { cas }, + input: "## Agent Output\nnoop\n## Extraction Instruction\nExtract seen.", schema: metaSchema, - provider, - cas, }); expect(result.ok).toBe(false); @@ -165,7 +216,7 @@ describe("reactExtract", () => { }); test("passthrough JSON assistant message without tool calls", async () => { - const casDir = await 
mkdtemp(join(tmpdir(), "react-extract-pass-")); + const casDir = await mkdtemp(join(tmpdir(), "thread-reactor-pass-")); const cas = createCasStore(casDir); try { const origFetch = globalThis.fetch; @@ -189,11 +240,11 @@ describe("reactExtract", () => { { preconnect: origFetch.preconnect.bind(origFetch) }, ) as typeof fetch; - const result = await reactExtract({ - text: "## Agent Output\nok\n## Extraction Instruction\nExtract.", + const reactor = createTestReactor(); + const result = await reactor({ + thread: { cas }, + input: "## Agent Output\nok\n## Extraction Instruction\nExtract.", schema: metaSchema, - provider, - cas, }); expect(result.ok).toBe(true); diff --git a/packages/workflow/src/extract/extract-fn.ts b/packages/workflow/src/extract/extract-fn.ts index 5f79380..b278562 100644 --- a/packages/workflow/src/extract/extract-fn.ts +++ b/packages/workflow/src/extract/extract-fn.ts @@ -1,12 +1,39 @@ import type { ExtractContext, ExtractFn, LlmProvider } from "@uncaged/workflow-runtime"; import type * as z from "zod/v4"; import { type CasStore, getContentMerklePayload } from "../cas/index.js"; -import { reactExtract } from "./react-extract.js"; +import { createLlmFn, createThreadReactor } from "../reactor/index.js"; +import { extractFunctionToolFromZodSchema } from "./llm-extract.js"; export type ExtractDeps = { cas: CasStore; }; +const MAX_REACT_ROUNDS = 10; + +const CAS_GET_TOOL_DEFINITION = { + type: "function" as const, + function: { + name: "cas_get", + description: + "Read a Merkle DAG node from content-addressed storage by its hash. 
Returns YAML-formatted node with type, payload, and children fields.", + parameters: { + type: "object", + properties: { + hash: { type: "string", description: "The CAS hash to retrieve" }, + }, + required: ["hash"], + }, + }, +}; + +export type ExtractThreadContext = { + cas: CasStore; +}; + +function isRecord(value: unknown): value is Record { + return typeof value === "object" && value !== null && !Array.isArray(value); +} + /** Builds the user-side extraction prompt (thread + agent output + instruction). */ export async function buildExtractUserContent( ctx: ExtractContext, @@ -46,17 +73,61 @@ export async function buildExtractUserContent( * Create an ExtractFn backed by an LLM provider. * * Internally runs a multi-turn ReAct loop with two tools (`cas_get` for traversing the - * Merkle DAG and a schema-shaped `extract` tool); the loop also accepts a plain-JSON + * Merkle DAG and a schema-shaped extract tool); the loop also accepts a plain-JSON * assistant reply as a short-circuit, which covers the legacy "single" extraction path. */ export function createExtract(provider: LlmProvider, deps: ExtractDeps): ExtractFn { + const llm = createLlmFn(provider); + const reactor = createThreadReactor({ + llm, + maxRounds: MAX_REACT_ROUNDS, + staticTools: [CAS_GET_TOOL_DEFINITION], + structuredToolFromSchema: (schema) => { + const t = extractFunctionToolFromZodSchema(schema); + return { + name: t.name, + tool: { + type: "function" as const, + function: { + name: t.name, + description: t.description, + parameters: t.parameters, + }, + }, + }; + }, + systemPromptForStructuredTool: (structuredToolName) => + `You extract structured metadata from the agent output below. Use cas_get to read Merkle DAG nodes from CAS (YAML: type, payload, children) when the agent output references hashes you must traverse. When you have the complete structured object, call the ${structuredToolName} tool with JSON arguments matching the schema. 
You may instead reply with only a JSON object (no prose) when no tools are needed.`, + toolHandler: async (call, thread) => { + if (call.function.name !== "cas_get") { + return `Unexpected tool routed to handler: ${call.function.name}`; + } + let hash: string; + try { + const ta = JSON.parse(call.function.arguments) as unknown; + if (!isRecord(ta) || typeof ta.hash !== "string") { + return 'cas_get requires a JSON object with a string "hash" field.'; + } + hash = ta.hash; + } catch { + return 'cas_get arguments were not valid JSON. Provide {"hash": ""}.'; + } + const blob = await thread.cas.get(hash); + return blob === null ? "null" : blob; + }, + }); + return async >( schema: z.ZodType, prompt: string, ctx: ExtractContext, ): Promise => { const text = await buildExtractUserContent(ctx, prompt, deps); - const result = await reactExtract({ text, schema, provider, cas: deps.cas }); + const result = await reactor({ + thread: { cas: deps.cas }, + input: text, + schema, + }); if (!result.ok) { throw new Error(`extract failed: ${result.error}`); } diff --git a/packages/workflow/src/extract/index.ts b/packages/workflow/src/extract/index.ts index cf502d8..e1069af 100644 --- a/packages/workflow/src/extract/index.ts +++ b/packages/workflow/src/extract/index.ts @@ -1,16 +1,11 @@ export { buildExtractUserContent, createExtract, + type ExtractThreadContext, } from "./extract-fn.js"; export { extractFunctionToolFromZodSchema, llmErrorToCause, llmExtract, } from "./llm-extract.js"; -export { reactExtract } from "./react-extract.js"; -export type { - ExtractFn, - LlmError, - LlmExtractArgs, - ReactExtractArgs, -} from "./types.js"; +export type { ExtractFn, LlmError, LlmExtractArgs } from "./types.js"; diff --git a/packages/workflow/src/extract/react-extract.ts b/packages/workflow/src/extract/react-extract.ts deleted file mode 100644 index 5c0c456..0000000 --- a/packages/workflow/src/extract/react-extract.ts +++ /dev/null @@ -1,343 +0,0 @@ -import type { CasStore, LlmProvider } 
from "@uncaged/workflow-runtime"; -import type * as z from "zod/v4"; -import { err, ok, type Result } from "../util/index.js"; - -import { extractFunctionToolFromZodSchema } from "./llm-extract.js"; -import type { ReactExtractArgs } from "./types.js"; - -const MAX_REACT_ROUNDS = 10; - -const CAS_GET_TOOL_DEFINITION = { - type: "function" as const, - function: { - name: "cas_get", - description: - "Read a Merkle DAG node from content-addressed storage by its hash. Returns YAML-formatted node with type, payload, and children fields.", - parameters: { - type: "object", - properties: { - hash: { type: "string", description: "The CAS hash to retrieve" }, - }, - required: ["hash"], - }, - }, -}; - -function chatCompletionsUrl(baseUrl: string): string { - const trimmed = baseUrl.replace(/\/+$/, ""); - return `${trimmed}/chat/completions`; -} - -function isRecord(value: unknown): value is Record { - return typeof value === "object" && value !== null && !Array.isArray(value); -} - -function tryParseJsonContent(content: string): unknown | null { - const trimmed = content.trim(); - const fenceMatch = /^```(?:json)?\s*([\s\S]*?)```$/m.exec(trimmed); - const payload = fenceMatch !== null ? 
fenceMatch[1].trim() : trimmed; - try { - return JSON.parse(payload) as unknown; - } catch { - return null; - } -} - -type ToolCall = { - id: string; - type: "function"; - function: { name: string; arguments: string }; -}; - -type ChatMessage = - | { role: "system"; content: string } - | { role: "user"; content: string } - | { - role: "assistant"; - content: string | null; - tool_calls: ToolCall[]; - } - | { role: "assistant"; content: string } - | { role: "tool"; tool_call_id: string; content: string }; - -type AssistantTurn = - | { kind: "plain_json"; value: T } - | { kind: "tool_calls"; calls: ToolCall[]; assistantContent: string | null }; - -function firstAssistantMessage(responseText: string): Result, string> { - let parsed: unknown; - try { - parsed = JSON.parse(responseText) as unknown; - } catch (cause) { - const message = cause instanceof Error ? cause.message : String(cause); - return err(`invalid_response_json:${message}`); - } - if (!isRecord(parsed)) { - return err("invalid_response_top_level"); - } - const choices = parsed.choices; - if (!Array.isArray(choices) || choices.length === 0) { - return err("no_choices_in_response"); - } - const firstChoice = choices[0]; - if (!isRecord(firstChoice)) { - return err("invalid_choice"); - } - const messageObj = firstChoice.message; - if (!isRecord(messageObj)) { - return err("invalid_message"); - } - return ok(messageObj); -} - -function normalizeToolCalls(toolCallsRaw: unknown[]): Result { - const toolCalls: ToolCall[] = []; - for (const tc of toolCallsRaw) { - if (!isRecord(tc)) { - return err("invalid_tool_call"); - } - const id = tc.id; - const tcType = tc.type; - const fn = tc.function; - if (typeof id !== "string" || tcType !== "function" || !isRecord(fn)) { - return err("invalid_tool_call_shape"); - } - const name = fn.name; - const argumentsStr = fn.arguments; - if (typeof name !== "string" || typeof argumentsStr !== "string") { - return err("invalid_tool_call_function"); - } - toolCalls.push({ id, 
type: "function", function: { name, arguments: argumentsStr } }); - } - return ok(toolCalls); -} - -type AssistantTurnOrCorrection> = - | AssistantTurn - | { kind: "plain_json_invalid"; rawContent: string; correction: string }; - -function classifyAssistantTurn>( - messageObj: Record, - schema: z.ZodType, -): Result, string> { - const toolCallsRaw = messageObj.tool_calls; - if (!Array.isArray(toolCallsRaw) || toolCallsRaw.length === 0) { - const content = messageObj.content; - if (typeof content !== "string") { - return err("no_tool_calls_and_no_string_content"); - } - const jsonParsed = tryParseJsonContent(content); - if (jsonParsed === null) { - return ok({ - kind: "plain_json_invalid", - rawContent: content, - correction: - "Your previous reply was not valid JSON and contained no tool calls. Reply with a single JSON object that matches the schema, or call the extract tool with the structured arguments.", - }); - } - const validated = schema.safeParse(jsonParsed); - if (!validated.success) { - return ok({ - kind: "plain_json_invalid", - rawContent: content, - correction: `Your previous JSON reply did not satisfy the schema: ${validated.error.message}. Reply again with a JSON object that matches the schema, or call the extract tool with the structured arguments.`, - }); - } - return ok({ kind: "plain_json", value: validated.data }); - } - const callsResult = normalizeToolCalls(toolCallsRaw); - if (!callsResult.ok) { - return err(callsResult.error); - } - const assistantContent = messageObj.content; - return ok({ - kind: "tool_calls", - calls: callsResult.value, - assistantContent: typeof assistantContent === "string" ? 
assistantContent : null, - }); -} - -async function appendCasGetToolResult( - tc: ToolCall, - cas: CasStore, - messages: ChatMessage[], -): Promise> { - let hash: string; - try { - const ta = JSON.parse(tc.function.arguments) as unknown; - if (!isRecord(ta) || typeof ta.hash !== "string") { - return err("cas_get_invalid_arguments"); - } - hash = ta.hash; - } catch { - return err("cas_get_arguments_not_json"); - } - const blob = await cas.get(hash); - const toolContent = blob === null ? "null" : blob; - messages.push({ - role: "tool", - tool_call_id: tc.id, - content: toolContent, - }); - return ok(null); -} - -async function appendExtractToolResult>( - tc: ToolCall, - schema: z.ZodType, - messages: ChatMessage[], -): Promise> { - let parsedArgs: unknown; - try { - parsedArgs = JSON.parse(tc.function.arguments) as unknown; - } catch { - return err("extract_tool_arguments_not_json"); - } - const validated = schema.safeParse(parsedArgs); - if (!validated.success) { - return err(`schema_validation_failed:${validated.error.message}`); - } - messages.push({ - role: "tool", - tool_call_id: tc.id, - content: '{"ok":true}', - }); - return ok(validated.data); -} - -async function appendToolResults>( - toolCalls: ToolCall[], - extractToolName: string, - schema: z.ZodType, - cas: CasStore, - messages: ChatMessage[], -): Promise> { - let extracted: T | null = null; - for (const tc of toolCalls) { - if (tc.function.name === "cas_get") { - const casRes = await appendCasGetToolResult(tc, cas, messages); - if (!casRes.ok) { - return casRes; - } - continue; - } - if (tc.function.name === extractToolName) { - const exRes = await appendExtractToolResult(tc, schema, messages); - if (!exRes.ok) { - return exRes; - } - extracted = exRes.value; - continue; - } - return err(`unknown_tool:${tc.function.name}`); - } - return ok(extracted); -} - -async function postChatCompletion( - provider: LlmProvider, - messages: ChatMessage[], - tools: readonly Record[], -): Promise> { - try { - const 
response = await fetch(chatCompletionsUrl(provider.baseUrl), { - method: "POST", - headers: { - Authorization: `Bearer ${provider.apiKey}`, - "Content-Type": "application/json", - }, - body: JSON.stringify({ - model: provider.model, - messages, - tools, - tool_choice: "auto", - }), - }); - const responseText = await response.text(); - if (!response.ok) { - return err(`http_error:${String(response.status)}:${responseText.slice(0, 4000)}`); - } - return ok(responseText); - } catch (cause) { - const message = cause instanceof Error ? cause.message : String(cause); - return err(`network_error:${message}`); - } -} - -/** - * Multi-turn ReAct extraction with `cas_get` plus a schema-shaped extract tool (OpenAI-compatible). - * Final meta comes from a successful extract tool call or from plain JSON in the assistant message. - */ -export async function reactExtract>( - args: ReactExtractArgs, -): Promise> { - const extractTool = extractFunctionToolFromZodSchema(args.schema); - const tools = [ - CAS_GET_TOOL_DEFINITION, - { - type: "function" as const, - function: { - name: extractTool.name, - description: extractTool.description, - parameters: extractTool.parameters, - }, - }, - ]; - - const systemContent = `You extract structured metadata from the agent output below. Use cas_get to read Merkle DAG nodes from CAS (YAML: type, payload, children) when the agent output references hashes you must traverse. When you have the complete structured object, call the ${extractTool.name} tool with JSON arguments matching the schema. 
You may instead reply with only a JSON object (no prose) when no tools are needed.`; - - const messages: ChatMessage[] = [ - { role: "system", content: systemContent }, - { role: "user", content: args.text }, - ]; - - for (let round = 0; round < MAX_REACT_ROUNDS; round++) { - const bodyResult = await postChatCompletion(args.provider, messages, tools); - if (!bodyResult.ok) { - return bodyResult; - } - - const msgResult = firstAssistantMessage(bodyResult.value); - if (!msgResult.ok) { - return msgResult; - } - - const classified = classifyAssistantTurn(msgResult.value, args.schema); - if (!classified.ok) { - return classified; - } - - const turn = classified.value; - if (turn.kind === "plain_json") { - return ok(turn.value); - } - - if (turn.kind === "plain_json_invalid") { - messages.push({ role: "assistant", content: turn.rawContent }); - messages.push({ role: "user", content: turn.correction }); - continue; - } - - messages.push({ - role: "assistant", - content: turn.assistantContent, - tool_calls: turn.calls, - }); - - const toolsRound = await appendToolResults( - turn.calls, - extractTool.name, - args.schema, - args.cas, - messages, - ); - if (!toolsRound.ok) { - return toolsRound; - } - if (toolsRound.value !== null) { - return ok(toolsRound.value); - } - } - - return err("max_react_rounds_exceeded"); -} diff --git a/packages/workflow/src/extract/types.ts b/packages/workflow/src/extract/types.ts index 016ab60..c5bf283 100644 --- a/packages/workflow/src/extract/types.ts +++ b/packages/workflow/src/extract/types.ts @@ -1,15 +1,8 @@ -import type { CasStore, LlmProvider } from "@uncaged/workflow-runtime"; +import type { LlmProvider } from "@uncaged/workflow-runtime"; import type * as z from "zod/v4"; export type { ExtractFn } from "@uncaged/workflow-runtime"; -export type ReactExtractArgs> = { - text: string; - schema: z.ZodType; - provider: LlmProvider; - cas: CasStore; -}; - export type LlmExtractArgs = { text: string; schema: z.ZodType; diff --git 
a/packages/workflow/src/index.ts b/packages/workflow/src/index.ts index 70897a1..896f865 100644 --- a/packages/workflow/src/index.ts +++ b/packages/workflow/src/index.ts @@ -56,12 +56,23 @@ export { export { createExtract, type ExtractFn, + type ExtractThreadContext, type LlmError, llmErrorToCause, llmExtract, - type ReactExtractArgs, - reactExtract, } from "./extract/index.js"; +export { + type ChatMessage, + createLlmFn, + createThreadReactor, + type LlmFn, + type StructuredToolSpec, + type ThreadReactorConfig, + type ThreadReactorFn, + type ThreadReactorInvokeArgs, + type ToolCall, + type ToolDefinition, +} from "./reactor/index.js"; export { getRegisteredWorkflow, listRegisteredWorkflowNames, diff --git a/packages/workflow/src/reactor/index.ts b/packages/workflow/src/reactor/index.ts new file mode 100644 index 0000000..80f2ab2 --- /dev/null +++ b/packages/workflow/src/reactor/index.ts @@ -0,0 +1,12 @@ +export { createLlmFn } from "./llm-fn.js"; +export { createThreadReactor } from "./thread-reactor.js"; +export type { + ChatMessage, + LlmFn, + StructuredToolSpec, + ThreadReactorConfig, + ThreadReactorFn, + ThreadReactorInvokeArgs, + ToolCall, + ToolDefinition, +} from "./types.js"; diff --git a/packages/workflow/src/reactor/llm-fn.ts b/packages/workflow/src/reactor/llm-fn.ts new file mode 100644 index 0000000..fd5c911 --- /dev/null +++ b/packages/workflow/src/reactor/llm-fn.ts @@ -0,0 +1,48 @@ +import type { LlmProvider } from "@uncaged/workflow-runtime"; + +import { err, ok } from "../util/index.js"; + +import type { ChatMessage, LlmFn, ToolDefinition } from "./types.js"; + +function chatCompletionsUrl(baseUrl: string): string { + const trimmed = baseUrl.replace(/\/+$/, ""); + return `${trimmed}/chat/completions`; +} + +/** + * Wraps provider credentials into an {@link LlmFn}: single POST to chat/completions, + * returns raw JSON body text or a {@link Result} error. Callers parse assistant messages. 
+ */ +export function createLlmFn(provider: LlmProvider): LlmFn { + return async ({ + messages, + tools, + }: { + messages: ChatMessage[]; + tools: readonly ToolDefinition[]; + }) => { + try { + const response = await fetch(chatCompletionsUrl(provider.baseUrl), { + method: "POST", + headers: { + Authorization: `Bearer ${provider.apiKey}`, + "Content-Type": "application/json", + }, + body: JSON.stringify({ + model: provider.model, + messages, + tools, + tool_choice: "auto", + }), + }); + const responseText = await response.text(); + if (!response.ok) { + return err(`http_error:${String(response.status)}:${responseText.slice(0, 4000)}`); + } + return ok(responseText); + } catch (cause) { + const message = cause instanceof Error ? cause.message : String(cause); + return err(`network_error:${message}`); + } + }; +} diff --git a/packages/workflow/src/reactor/thread-reactor.ts b/packages/workflow/src/reactor/thread-reactor.ts new file mode 100644 index 0000000..4931bd4 --- /dev/null +++ b/packages/workflow/src/reactor/thread-reactor.ts @@ -0,0 +1,317 @@ +import type * as z from "zod/v4"; + +import { err, ok, type Result } from "../util/index.js"; + +import type { + ChatMessage, + StructuredToolSpec, + ThreadReactorConfig, + ThreadReactorFn, + ToolCall, + ToolDefinition, +} from "./types.js"; + +function isRecord(value: unknown): value is Record { + return typeof value === "object" && value !== null && !Array.isArray(value); +} + +function tryParseJsonContent(content: string): unknown | null { + const trimmed = content.trim(); + const fenceMatch = /^```(?:json)?\s*([\s\S]*?)```$/m.exec(trimmed); + const payload = fenceMatch !== null ? fenceMatch[1].trim() : trimmed; + try { + return JSON.parse(payload) as unknown; + } catch { + return null; + } +} + +function firstAssistantMessage(responseText: string): Result, string> { + let parsed: unknown; + try { + parsed = JSON.parse(responseText) as unknown; + } catch (cause) { + const message = cause instanceof Error ? 
cause.message : String(cause); + return err(`invalid_response_json:${message}`); + } + if (!isRecord(parsed)) { + return err("invalid_response_top_level"); + } + const choices = parsed.choices; + if (!Array.isArray(choices) || choices.length === 0) { + return err("no_choices_in_response"); + } + const firstChoice = choices[0]; + if (!isRecord(firstChoice)) { + return err("invalid_choice"); + } + const messageObj = firstChoice.message; + if (!isRecord(messageObj)) { + return err("invalid_message"); + } + return ok(messageObj); +} + +function normalizeToolCalls(toolCallsRaw: unknown[]): Result { + const toolCalls: ToolCall[] = []; + for (const tc of toolCallsRaw) { + if (!isRecord(tc)) { + return err("invalid_tool_call"); + } + const id = tc.id; + const tcType = tc.type; + const fn = tc.function; + if (typeof id !== "string" || tcType !== "function" || !isRecord(fn)) { + return err("invalid_tool_call_shape"); + } + const name = fn.name; + const argumentsStr = fn.arguments; + if (typeof name !== "string" || typeof argumentsStr !== "string") { + return err("invalid_tool_call_function"); + } + toolCalls.push({ id, type: "function", function: { name, arguments: argumentsStr } }); + } + return ok(toolCalls); +} + +type AssistantTurn = + | { kind: "plain_json"; value: T } + | { kind: "tool_calls"; calls: ToolCall[]; assistantContent: string | null }; + +type AssistantTurnOrCorrection = + | AssistantTurn + | { kind: "plain_json_invalid"; rawContent: string; correction: string }; + +function classifyAssistantTurn( + messageObj: Record, + schema: z.ZodType, + structuredToolName: string, +): Result, string> { + const toolCallsRaw = messageObj.tool_calls; + if (!Array.isArray(toolCallsRaw) || toolCallsRaw.length === 0) { + const content = messageObj.content; + if (typeof content !== "string") { + return err("no_tool_calls_and_no_string_content"); + } + const jsonParsed = tryParseJsonContent(content); + if (jsonParsed === null) { + return ok({ + kind: "plain_json_invalid", + 
rawContent: content, + correction: `Your previous reply was not valid JSON and contained no tool calls. Reply with a single JSON object that matches the schema, or call the ${structuredToolName} tool with the structured arguments.`, + }); + } + const validated = schema.safeParse(jsonParsed); + if (!validated.success) { + return ok({ + kind: "plain_json_invalid", + rawContent: content, + correction: `Your previous JSON reply did not satisfy the schema: ${validated.error.message}. Reply again with a JSON object that matches the schema, or call the ${structuredToolName} tool with the structured arguments.`, + }); + } + return ok({ kind: "plain_json", value: validated.data }); + } + const callsResult = normalizeToolCalls(toolCallsRaw); + if (!callsResult.ok) { + return err(callsResult.error); + } + const assistantContent = messageObj.content; + return ok({ + kind: "tool_calls", + calls: callsResult.value, + assistantContent: typeof assistantContent === "string" ? assistantContent : null, + }); +} + +function toolNamesFromDefinitions(tools: readonly { function: { name: string } }[]): Set { + return new Set(tools.map((t) => t.function.name)); +} + +function appendStructuredToolResult( + tc: ToolCall, + schema: z.ZodType, + messages: ChatMessage[], +): T | null { + let parsedArgs: unknown; + try { + parsedArgs = JSON.parse(tc.function.arguments) as unknown; + } catch { + messages.push({ + role: "tool", + tool_call_id: tc.id, + content: + "Tool arguments were not valid JSON. Provide valid JSON object arguments matching the schema.", + }); + return null; + } + const validated = schema.safeParse(parsedArgs); + if (!validated.success) { + messages.push({ + role: "tool", + tool_call_id: tc.id, + content: `Schema validation failed: ${validated.error.message}. 
Fix the arguments and call the tool again with a JSON object that matches the schema.`, + }); + return null; + } + messages.push({ + role: "tool", + tool_call_id: tc.id, + content: '{"ok":true}', + }); + return validated.data; +} + +async function dispatchToolCall( + tc: ToolCall, + spec: StructuredToolSpec, + knownNames: Set, + schema: z.ZodType, + thread: TThread, + toolHandler: ThreadReactorConfig["toolHandler"], + messages: ChatMessage[], +): Promise { + if (!knownNames.has(tc.function.name)) { + messages.push({ + role: "tool", + tool_call_id: tc.id, + content: `Unknown tool: ${tc.function.name}. Use one of the declared tools only.`, + }); + return null; + } + if (tc.function.name === spec.name) { + return appendStructuredToolResult(tc, schema, messages); + } + let toolContent: string; + try { + toolContent = await toolHandler(tc, thread); + } catch (cause) { + const message = cause instanceof Error ? cause.message : String(cause); + toolContent = `Tool execution failed: ${message}`; + } + messages.push({ + role: "tool", + tool_call_id: tc.id, + content: toolContent, + }); + return null; +} + +async function resolveToolCallRound( + turn: Extract, { kind: "tool_calls" }>, + spec: StructuredToolSpec, + knownNames: Set, + schema: z.ZodType, + thread: TThread, + toolHandler: ThreadReactorConfig["toolHandler"], + messages: ChatMessage[], +): Promise | null> { + messages.push({ + role: "assistant", + content: turn.assistantContent, + tool_calls: turn.calls, + }); + let extractedRound: T | null = null; + for (const tc of turn.calls) { + const extracted = await dispatchToolCall( + tc, + spec, + knownNames, + schema, + thread, + toolHandler, + messages, + ); + if (extracted !== null) { + extractedRound = extracted; + } + } + if (extractedRound !== null) { + return ok(extractedRound); + } + return null; +} + +async function runOneReactRound( + config: ThreadReactorConfig, + args: { thread: TThread; schema: z.ZodType }, + tools: readonly ToolDefinition[], + knownNames: 
Set, + spec: StructuredToolSpec, + messages: ChatMessage[], +): Promise | null> { + const bodyResult = await config.llm({ messages, tools }); + if (!bodyResult.ok) { + return bodyResult; + } + + const msgResult = firstAssistantMessage(bodyResult.value); + if (!msgResult.ok) { + return msgResult; + } + + const classified = classifyAssistantTurn(msgResult.value, args.schema, spec.name); + if (!classified.ok) { + return classified; + } + + const turn = classified.value; + if (turn.kind === "plain_json") { + return ok(turn.value); + } + + /* Echo the rejected reply and the correction into the conversation; returning null yields no result for this round. */ if (turn.kind === "plain_json_invalid") { + messages.push({ role: "assistant", content: turn.rawContent }); + messages.push({ role: "user", content: turn.correction }); + return null; + } + + /* Remaining case: the turn is handed to the tool-call handler together with the shared messages array. */ return resolveToolCallRound( + turn, + spec, + knownNames, + args.schema, + args.thread, + config.toolHandler, + messages, + ); +} + +/** + * Generic ReAct loop: LLM round-trips with tools until structured output validates, + * plain JSON matches schema, or {@link ThreadReactorConfig.maxRounds} is exceeded. 
+ */ +export function createThreadReactor( + config: ThreadReactorConfig, +): ThreadReactorFn { + return async (args: { + thread: TThread; + input: string; + schema: z.ZodType; + }): Promise> => { + /* The structured tool is derived from the schema on every invocation and offered alongside the static tools. */ const spec = config.structuredToolFromSchema(args.schema); + const tools = [...config.staticTools, spec.tool]; + const knownNames = toolNamesFromDefinitions(tools); + const systemPrompt = config.systemPromptForStructuredTool(spec.name); + + /* Each invocation starts a fresh conversation: system prompt first, then the caller input. */ const messages: ChatMessage[] = [ + { role: "system", content: systemPrompt }, + { role: "user", content: args.input }, + ]; + + for (let round = 0; round < config.maxRounds; round++) { + const step = await runOneReactRound( + config, + { thread: args.thread, schema: args.schema }, + tools, + knownNames, + spec, + messages, + ); + /* A non-null step is a final Result (success or error); null means run another round. */ if (step !== null) { + return step; + } + } + + /* Every one of the config.maxRounds rounds ended without a final Result. */ return err("max_react_rounds_exceeded"); + }; +} diff --git a/packages/workflow/src/reactor/types.ts b/packages/workflow/src/reactor/types.ts new file mode 100644 index 0000000..5d9d499 --- /dev/null +++ b/packages/workflow/src/reactor/types.ts @@ -0,0 +1,62 @@ +import type * as z from "zod/v4"; + +import type { Result } from "../util/index.js"; + +/** A function tool call carried by an assistant message; function.arguments is a string, not a parsed object. */ export type ToolCall = { + id: string; + type: "function"; + function: { name: string; arguments: string }; +}; + +/** Declaration of a function tool passed to the LLM call: name, description and a parameters record. */ export type ToolDefinition = { + type: "function"; + function: { + name: string; + description: string; + parameters: Record; + }; +}; + +/** Conversation message variants, discriminated by role. Assistant messages have two shapes: one carrying tool_calls (content may be null) and one with plain string content. */ export type ChatMessage = + | { role: "system"; content: string } + | { role: "user"; content: string } + | { + role: "assistant"; + content: string | null; + tool_calls: ToolCall[]; + } + | { role: "assistant"; content: string } + | { role: "tool"; tool_call_id: string; content: string }; + +/** Asynchronous model call that receives the conversation so far and the tool definitions. */ export type LlmFn = (input: { + messages: ChatMessage[]; + tools: readonly ToolDefinition[]; +}) => Promise>; + +/** Structured tool derived from the per-invocation Zod schema (e.g. extract tool). 
*/ +export type StructuredToolSpec = { + name: string; + tool: ToolDefinition; +}; + +export type ThreadReactorConfig = { + /** Model call used for every round of the loop. */ llm: LlmFn; + /** Static tools (e.g. cas_get); structured tool is appended per invocation. */ + staticTools: readonly ToolDefinition[]; + /** Builds the schema-shaped tool and its OpenAI name for this invocation. */ + structuredToolFromSchema: (schema: z.ZodType) => StructuredToolSpec; + /** System prompt for this run; include the structured tool name for cache stability per schema. */ + systemPromptForStructuredTool: (structuredToolName: string) => string; + /** Executes any declared tool other than the structured one; the resolved string becomes the tool message content. */ toolHandler: (call: ToolCall, thread: TThread) => Promise; + /** Maximum number of LLM rounds per invocation before max_react_rounds_exceeded is returned. */ maxRounds: number; +}; + +/** Arguments for one reactor invocation: the thread handed to toolHandler, the user input text, and the schema the result must satisfy. */ export type ThreadReactorInvokeArgs = { + thread: TThread; + input: string; + schema: z.ZodType; +}; + +/** Function shape returned by createThreadReactor. */ export type ThreadReactorFn = ( + args: ThreadReactorInvokeArgs, +) => Promise>;