From 18e3dc76033a8502ba05d2cc3428e327dda15aa2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=98=9F=E6=9C=88?= Date: Tue, 12 May 2026 14:21:49 +0800 Subject: [PATCH 1/2] feat: WebSocket reverse-connection gateway Phase 1 (#210) - Add AgentSocket Durable Object (holds one WS per agent name) - Add /ws/connect route with GATEWAY_SECRET auth - Add ws-client.ts with auto-reconnect (exponential backoff 1s-30s) - serve defaults to WS mode (no cloudflared needed) - Keep --tunnel-url and --no-tunnel as fallback options - Endpoints list merges KV heartbeat + DO WebSocket status Testing: #211 --- .../cli-workflow/src/commands/serve/serve.ts | 66 +++++++---- .../cli-workflow/src/commands/serve/types.ts | 1 + .../src/commands/serve/ws-client.ts | 112 ++++++++++++++++++ packages/workflow-gateway/src/agent-socket.ts | 54 +++++++++ packages/workflow-gateway/src/index.ts | 76 +++++++++++- packages/workflow-gateway/wrangler.toml | 7 ++ 6 files changed, 289 insertions(+), 27 deletions(-) create mode 100644 packages/cli-workflow/src/commands/serve/ws-client.ts create mode 100644 packages/workflow-gateway/src/agent-socket.ts diff --git a/packages/cli-workflow/src/commands/serve/serve.ts b/packages/cli-workflow/src/commands/serve/serve.ts index 684f340..9456ee5 100644 --- a/packages/cli-workflow/src/commands/serve/serve.ts +++ b/packages/cli-workflow/src/commands/serve/serve.ts @@ -1,17 +1,14 @@ import { randomUUID } from "node:crypto"; import { hostname as osHostname } from "node:os"; import { err, ok, type Result } from "@uncaged/workflow-protocol"; +import { createLogger } from "@uncaged/workflow-util"; import { serve } from "bun"; import { printCliLine } from "../../cli-output.js"; import { createApp } from "./app.js"; -import { - registerWithGateway, - startHeartbeat, - startTunnel, - unregisterFromGateway, -} from "./tunnel.js"; +import { registerWithGateway, startHeartbeat, unregisterFromGateway } from "./tunnel.js"; import type { ServeOptions } from "./types.js"; +import { startGatewayWsClient } from "./ws-client.js"; const DEFAULT_GATEWAY_URL = "https://workflow-gateway.shazhou.workers.dev"; const HEARTBEAT_INTERVAL_MS = 60_000; @@ -56,6 +53,7 @@ function parseServeArgv(argv: string[]): Result { let hostname = "127.0.0.1"; let name = osHostname().split(".")[0].toLowerCase(); let noTunnel = false; + let tunnelUrl: string | null = null; let gatewayUrl = DEFAULT_GATEWAY_URL; const gatewaySecret = process.env.WORKFLOW_GATEWAY_SECRET ?? ""; const stringFlags: Record void> = { @@ -68,6 +66,9 @@ function parseServeArgv(argv: string[]): Result { "--gateway": (v) => { gatewayUrl = v; }, + "--tunnel-url": (v) => { + tunnelUrl = v; + }, }; for (let i = 0; i < argv.length; i++) { @@ -87,7 +88,7 @@ function parseServeArgv(argv: string[]): Result { } } - return ok({ port, hostname, name, noTunnel, gatewayUrl, gatewaySecret }); + return ok({ port, hostname, name, noTunnel, tunnelUrl, gatewayUrl, gatewaySecret }); } export async function dispatchServe(storageRoot: string, argv: string[]): Promise { @@ -107,47 +108,63 @@ export async function dispatchServe(storageRoot: string, argv: string[]): Promis return 0; } - // Start cloudflared quick tunnel - printCliLine("starting cloudflared quick tunnel..."); - const tunnel = await startTunnel(options.port); + let resolvedTunnelUrl: string; + let stopWsClient: (() => void) | null = null; - if (!tunnel) { - printCliLine("failed to create tunnel — continuing without gateway registration"); - await new Promise(() => {}); - return 0; + if (options.tunnelUrl !== null) { + resolvedTunnelUrl = options.tunnelUrl; + printCliLine(`using tunnel URL: ${resolvedTunnelUrl}`); + } else { + if (options.gatewaySecret === "") { + printCliLine( + "WORKFLOW_GATEWAY_SECRET not set — cannot use WebSocket gateway connection (set env or pass --tunnel-url)", + ); + await new Promise(() => {}); + return 0; + } + resolvedTunnelUrl = `http://127.0.0.1:${options.port}`; + const log = createLogger({ sink: { kind: "stderr" } }); + stopWsClient = startGatewayWsClient({ + gatewayUrl: options.gatewayUrl, + name: options.name, + secret: options.gatewaySecret, + log, + }); + printCliLine("gateway WebSocket reverse connection (no cloudflared)"); } - printCliLine(`tunnel: ${tunnel.url}`); - - // Register with gateway if (options.gatewaySecret) { + if (agentToken === null) { + printCliLine("internal error: agent token missing"); + await new Promise(() => {}); + return 1; + } + const token = agentToken; const registered = await registerWithGateway( options.gatewayUrl, options.name, - tunnel.url, + resolvedTunnelUrl, options.gatewaySecret, - agentToken!, + token, ); if (registered) { printCliLine(`registered with gateway as "${options.name}"`); } - // Start heartbeat const heartbeatTimer = startHeartbeat( options.gatewayUrl, options.name, - tunnel.url, + resolvedTunnelUrl, options.gatewaySecret, - agentToken!, + token, HEARTBEAT_INTERVAL_MS, ); - // Cleanup on exit const cleanup = async () => { clearInterval(heartbeatTimer); + stopWsClient?.(); printCliLine("unregistering from gateway..."); await unregisterFromGateway(options.gatewayUrl, options.name, options.gatewaySecret); - tunnel.process.kill(); process.exit(0); }; @@ -157,7 +174,6 @@ export async function dispatchServe(storageRoot: string, argv: string[]): Promis printCliLine("WORKFLOW_GATEWAY_SECRET not set — skipping gateway registration"); } - // Keep process alive await new Promise(() => {}); return 0; } diff --git a/packages/cli-workflow/src/commands/serve/types.ts b/packages/cli-workflow/src/commands/serve/types.ts index 541269c..8c19cd7 100644 --- a/packages/cli-workflow/src/commands/serve/types.ts +++ b/packages/cli-workflow/src/commands/serve/types.ts @@ -3,6 +3,7 @@ export type ServeOptions = { hostname: string; name: string; noTunnel: boolean; + tunnelUrl: string | null; gatewayUrl: string; gatewaySecret: string; }; diff --git a/packages/cli-workflow/src/commands/serve/ws-client.ts b/packages/cli-workflow/src/commands/serve/ws-client.ts new file mode 100644 index 0000000..b6d2e94 --- /dev/null +++ b/packages/cli-workflow/src/commands/serve/ws-client.ts @@ -0,0 +1,112 @@ +import type { LogFn } from "@uncaged/workflow-util"; + +export type GatewayWsClientParams = { + gatewayUrl: string; + name: string; + secret: string; + log: LogFn; +}; + +const INITIAL_BACKOFF_MS = 1000; +const MAX_BACKOFF_MS = 30_000; + +export function buildGatewayWsConnectUrl(gatewayUrl: string, name: string, secret: string): string { + const u = new URL(gatewayUrl); + if (u.protocol === "https:") { + u.protocol = "wss:"; + } else if (u.protocol === "http:") { + u.protocol = "ws:"; + } + u.pathname = "/ws/connect"; + u.search = ""; + u.searchParams.set("name", name); + u.searchParams.set("secret", secret); + return u.href; +} + +/** Maintains a reverse WebSocket to the workflow gateway; reconnects with exponential backoff. */ +export function startGatewayWsClient(params: GatewayWsClientParams): () => void { + const wsUrl = buildGatewayWsConnectUrl(params.gatewayUrl, params.name, params.secret); + let socket: WebSocket | null = null; + let reconnectTimer: ReturnType | null = null; + let stopped = false; + let attempt = 0; + + const clearReconnectTimer = (): void => { + if (reconnectTimer !== null) { + clearTimeout(reconnectTimer); + reconnectTimer = null; + } + }; + + const scheduleReconnect = (): void => { + if (stopped) { + return; + } + clearReconnectTimer(); + const delayMs = Math.min(INITIAL_BACKOFF_MS * 2 ** attempt, MAX_BACKOFF_MS); + attempt++; + params.log("6CJX2RLP", `gateway WebSocket reconnect in ${delayMs}ms (attempt ${attempt})`); + reconnectTimer = setTimeout(connect, delayMs); + }; + + const connect = (): void => { + if (stopped) { + return; + } + clearReconnectTimer(); + params.log("2XK7HM9Q", "gateway WebSocket connecting..."); + try { + socket = new WebSocket(wsUrl); + } catch (e) { + params.log("7NQW4HBT", `gateway WebSocket create failed: ${String(e)}`); + scheduleReconnect(); + return; + } + + const ws = socket; + + ws.addEventListener("open", () => { + attempt = 0; + params.log("4PWN3V82", "gateway WebSocket connected"); + }); + + ws.addEventListener("close", (ev) => { + socket = null; + params.log( + "8QTR6ZKC", + `gateway WebSocket closed code=${String(ev.code)} reason=${ev.reason} wasClean=${String(ev.wasClean)}`, + ); + if (!stopped) { + scheduleReconnect(); + } + }); + + ws.addEventListener("error", () => { + params.log("9BWS1M7F", "gateway WebSocket error"); + }); + + ws.addEventListener("message", (ev) => { + let preview: string; + if (typeof ev.data === "string") { + preview = ev.data; + } else if (ev.data instanceof ArrayBuffer) { + preview = `[binary ${String(ev.data.byteLength)} bytes]`; + } else { + preview = "[non-text message]"; + } + params.log("3FHK5NDJ", `gateway → agent (phase 2 stub): ${preview.slice(0, 500)}`); + }); + }; + + connect(); + + return (): void => { + stopped = true; + clearReconnectTimer(); + if (socket !== null && socket.readyState === WebSocket.OPEN) { + socket.close(1000, "shutdown"); + } + socket = null; + }; +} diff --git a/packages/workflow-gateway/src/agent-socket.ts b/packages/workflow-gateway/src/agent-socket.ts new file mode 100644 index 0000000..a684801 --- /dev/null +++ b/packages/workflow-gateway/src/agent-socket.ts @@ -0,0 +1,54 @@ +/** One Durable Object instance per agent name; holds the reverse WebSocket from the agent CLI. */ +import { DurableObject } from "cloudflare:workers"; + +type AgentSocketEnv = { + GATEWAY_SECRET: string; +}; + +export const AGENT_SOCKET_INTERNAL_STATUS_PATH = "/internal/agent-socket/status"; + +export class AgentSocket extends DurableObject { + async fetch(request: Request): Promise { + const url = new URL(request.url); + + if (url.pathname === AGENT_SOCKET_INTERNAL_STATUS_PATH && request.method === "GET") { + const auth = request.headers.get("Authorization"); + if (auth !== `Bearer ${this.env.GATEWAY_SECRET}`) { + return new Response(JSON.stringify({ error: "unauthorized" }), { + status: 401, + headers: { "Content-Type": "application/json" }, + }); + } + const sockets = this.ctx.getWebSockets(); + const connected = sockets.length > 0; + return new Response(JSON.stringify({ connected, connectedCount: sockets.length }), { + headers: { "Content-Type": "application/json" }, + }); + } + + if (request.headers.get("Upgrade") !== "websocket") { + return new Response("expected WebSocket upgrade", { status: 426 }); + } + + for (const ws of this.ctx.getWebSockets()) { + ws.close(1000, "replaced by new connection"); + } + + const pair = new WebSocketPair(); + const client = pair[0]; + const server = pair[1]; + this.ctx.acceptWebSocket(server); + return new Response(null, { status: 101, webSocket: client }); + } + + async webSocketMessage(_ws: WebSocket, _message: string | ArrayBuffer): Promise {} + + async webSocketClose( + _ws: WebSocket, + _code: number, + _reason: string, + _wasClean: boolean, + ): Promise {} + + async webSocketError(_ws: WebSocket, _error: unknown): Promise {} +} diff --git a/packages/workflow-gateway/src/index.ts b/packages/workflow-gateway/src/index.ts index e303790..ce2f61a 100644 --- a/packages/workflow-gateway/src/index.ts +++ b/packages/workflow-gateway/src/index.ts @@ -1,11 +1,16 @@ import { Hono } from "hono"; import { cors } from "hono/cors"; +import { AGENT_SOCKET_INTERNAL_STATUS_PATH, AgentSocket } from "./agent-socket.js"; + +export { AgentSocket }; + type Env = { Bindings: { ENDPOINTS: KVNamespace; GATEWAY_SECRET: string; DASHBOARD_API_KEY: string; + AGENT_SOCKET: DurableObjectNamespace; }; }; @@ -33,9 +38,74 @@ function checkDashboardAuth(c: { return key === c.env.DASHBOARD_API_KEY; } +function isLocalAgentUrl(url: string): boolean { + try { + const u = new URL(url); + return u.hostname === "localhost" || u.hostname === "127.0.0.1"; + } catch { + return false; + } +} + +async function fetchAgentSocketStatus( + env: Env["Bindings"], + name: string, +): Promise<{ ok: true; connected: boolean } | { ok: false }> { + try { + const id = env.AGENT_SOCKET.idFromName(name); + const stub = env.AGENT_SOCKET.get(id); + const resp = await stub.fetch( + new Request(`https://do${AGENT_SOCKET_INTERNAL_STATUS_PATH}`, { + method: "GET", + headers: { Authorization: `Bearer ${env.GATEWAY_SECRET}` }, + }), + ); + if (!resp.ok) { + return { ok: false }; + } + const body = (await resp.json()) as { connected: boolean }; + return { ok: true, connected: body.connected }; + } catch { + return { ok: false }; + } +} + +function endpointStatusFromKvAndDo(record: EndpointRecord, doConnected: boolean | null): string { + if (doConnected === true) { + return "online"; + } + if (doConnected === false) { + if (isLocalAgentUrl(record.url)) { + return "offline"; + } + const age = Date.now() - record.lastHeartbeat; + return age < TTL_SECONDS * 1000 ? "online" : "offline"; + } + const age = Date.now() - record.lastHeartbeat; + return age < TTL_SECONDS * 1000 ? "online" : "offline"; +} + // ── Health ────────────────────────────────────────────────────────── app.get("/healthz", (c) => c.json({ ok: true })); +// ── Agent reverse WebSocket (GATEWAY_SECRET query param) ──────────── +app.get("/ws/connect", async (c) => { + const secret = c.req.query("secret"); + const name = c.req.query("name"); + if (name === undefined || name === "") { + return c.json({ error: "name required" }, 400); + } + if (secret !== c.env.GATEWAY_SECRET) { + return c.json({ error: "unauthorized" }, 401); + } + if (c.req.header("Upgrade") !== "websocket") { + return c.text("expected WebSocket upgrade", 426); + } + const id = c.env.AGENT_SOCKET.idFromName(name); + const stub = c.env.AGENT_SOCKET.get(id); + return stub.fetch(c.req.raw); +}); + // ── Gateway management (GATEWAY_SECRET auth) ──────────────────────── const gateway = new Hono(); @@ -95,11 +165,12 @@ gateway.get("/endpoints", async (c) => { for (const key of list.keys) { const record = await c.env.ENDPOINTS.get(key.name, "json"); if (record) { - const age = Date.now() - record.lastHeartbeat; + const doStatus = await fetchAgentSocketStatus(c.env, record.name); + const doConnected = doStatus.ok ? doStatus.connected : null; endpoints.push({ name: record.name, url: record.url, - status: age < TTL_SECONDS * 1000 ? "online" : "offline", + status: endpointStatusFromKvAndDo(record, doConnected), lastHeartbeat: record.lastHeartbeat, }); } @@ -149,4 +220,5 @@ app.all("/api/agents/:agent/*", async (c) => { } }); +// biome-ignore lint/style/noDefaultExport: Cloudflare Workers entry expects default export export default app; diff --git a/packages/workflow-gateway/wrangler.toml b/packages/workflow-gateway/wrangler.toml index 9688a1d..e8d6a6d 100644 --- a/packages/workflow-gateway/wrangler.toml +++ b/packages/workflow-gateway/wrangler.toml @@ -6,4 +6,11 @@ compatibility_date = "2025-04-01" binding = "ENDPOINTS" id = "88b118d1cfab4c049f9c1684848811a3" +[durable_objects] +bindings = [{ name = "AGENT_SOCKET", class_name = "AgentSocket" }] + +[[migrations]] +tag = "add-agent-socket" +new_sqlite_classes = ["AgentSocket"] + # GATEWAY_SECRET is set via `wrangler secret put` -- 2.43.0 From ec3c97b2005fc7b853004218bedb397a3a4fb22f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=98=9F=E6=9C=88?= Date: Tue, 12 May 2026 14:42:19 +0800 Subject: [PATCH 2/2] =?UTF-8?q?feat:=20WS=20request=20proxy=20=E2=80=94=20?= =?UTF-8?q?Gateway=20proxies=20HTTP=20via=20WebSocket=20(#210=20Phase=202)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add ws-protocol.ts with WsRequest/WsResponse types + parsers - AgentSocket DO: proxy POST handler, pending request map, 30s timeout - /api/agents/:agent/* routes through DO WS when connected, falls back to HTTP - ws-client handles incoming WsRequest, fetches local serve, returns WsResponse - startGatewayWsClient accepts localPort for request handling Testing: #213 --- packages/cli-workflow/package.json | 1 + .../cli-workflow/src/commands/serve/serve.ts | 1 + .../src/commands/serve/ws-client.ts | 69 ++++++++- packages/workflow-gateway/package.json | 4 + packages/workflow-gateway/src/agent-socket.ts | 138 +++++++++++++++-- packages/workflow-gateway/src/index.ts | 145 ++++++++++++++++-- packages/workflow-gateway/src/ws-protocol.ts | 93 +++++++++++ 7 files changed, 413 insertions(+), 38 deletions(-) create mode 100644 packages/workflow-gateway/src/ws-protocol.ts diff --git a/packages/cli-workflow/package.json b/packages/cli-workflow/package.json index 1da2774..9da6fd9 100644 --- a/packages/cli-workflow/package.json +++ b/packages/cli-workflow/package.json @@ -6,6 +6,7 @@ "uncaged-workflow": "src/cli.ts" }, "dependencies": { + "@uncaged/workflow-gateway": "workspace:*", "@uncaged/workflow-protocol": "workspace:*", "@uncaged/workflow-util": "workspace:*", "@uncaged/workflow-cas": "workspace:*", diff --git a/packages/cli-workflow/src/commands/serve/serve.ts b/packages/cli-workflow/src/commands/serve/serve.ts index 9456ee5..b41f3e7 100644 --- a/packages/cli-workflow/src/commands/serve/serve.ts +++ b/packages/cli-workflow/src/commands/serve/serve.ts @@ -128,6 +128,7 @@ export async function dispatchServe(storageRoot: string, argv: string[]): Promis gatewayUrl: options.gatewayUrl, name: options.name, secret: options.gatewaySecret, + localPort: options.port, log, }); printCliLine("gateway WebSocket reverse connection (no cloudflared)"); diff --git a/packages/cli-workflow/src/commands/serve/ws-client.ts b/packages/cli-workflow/src/commands/serve/ws-client.ts index b6d2e94..b58e49c 100644 --- a/packages/cli-workflow/src/commands/serve/ws-client.ts +++ b/packages/cli-workflow/src/commands/serve/ws-client.ts @@ -1,9 +1,11 @@ +import { parseWsRequestJson, type WsResponse } from "@uncaged/workflow-gateway/ws-protocol"; import type { LogFn } from "@uncaged/workflow-util"; export type GatewayWsClientParams = { gatewayUrl: string; name: string; secret: string; + localPort: number; log: LogFn; }; @@ -24,6 +26,58 @@ export function buildGatewayWsConnectUrl(gatewayUrl: string, name: string, secre return u.href; } +function headersToRecord(h: Headers): Record { + const out: Record = {}; + for (const [k, v] of h) { + out[k] = v; + } + return out; +} + +async function handleGatewayMessage( + ws: WebSocket, + raw: string, + params: GatewayWsClientParams, +): Promise { + const req = parseWsRequestJson(raw); + if (req === null) { + params.log("ZM8K2PQ1", "gateway WebSocket dropped non-request message"); + return; + } + const localUrl = `http://127.0.0.1:${String(params.localPort)}${req.path}`; + const initHeaders = new Headers(); + for (const [k, v] of Object.entries(req.headers)) { + initHeaders.set(k, v); + } + let resp: Response; + try { + resp = await fetch(localUrl, { + method: req.method, + headers: initHeaders, + body: req.body === null ? undefined : req.body, + }); + } catch (e) { + params.log("R4N7BQ3C", `local proxy fetch failed: ${String(e)}`); + const errBody: WsResponse = { + id: req.id, + status: 502, + headers: { "content-type": "application/json" }, + body: JSON.stringify({ error: "local fetch failed", detail: String(e) }), + }; + ws.send(JSON.stringify(errBody)); + return; + } + const bodyText = await resp.text(); + const headerRecord = headersToRecord(resp.headers); + const out: WsResponse = { + id: req.id, + status: resp.status, + headers: headerRecord, + body: bodyText, + }; + ws.send(JSON.stringify(out)); +} + /** Maintains a reverse WebSocket to the workflow gateway; reconnects with exponential backoff. */ export function startGatewayWsClient(params: GatewayWsClientParams): () => void { const wsUrl = buildGatewayWsConnectUrl(params.gatewayUrl, params.name, params.secret); @@ -87,15 +141,14 @@ export function startGatewayWsClient(params: GatewayWsClientParams): () => void }); ws.addEventListener("message", (ev) => { - let preview: string; - if (typeof ev.data === "string") { - preview = ev.data; - } else if (ev.data instanceof ArrayBuffer) { - preview = `[binary ${String(ev.data.byteLength)} bytes]`; - } else { - preview = "[non-text message]"; + const data = ev.data; + if (typeof data !== "string") { + params.log("T9W2KL5H", "gateway WebSocket non-text frame ignored"); + return; } - params.log("3FHK5NDJ", `gateway → agent (phase 2 stub): ${preview.slice(0, 500)}`); + void handleGatewayMessage(ws, data, params).catch((e: unknown) => { + params.log("V7KX2M9P", `gateway WebSocket handler error: ${String(e)}`); + }); }); }; diff --git a/packages/workflow-gateway/package.json b/packages/workflow-gateway/package.json index a5defe6..3b99b1d 100644 --- a/packages/workflow-gateway/package.json +++ b/packages/workflow-gateway/package.json @@ -3,6 +3,10 @@ "version": "0.1.0", "private": true, "type": "module", + "exports": { + ".": "./src/index.ts", + "./ws-protocol": "./src/ws-protocol.ts" + }, "scripts": { "dev": "wrangler dev", "deploy": "wrangler deploy" diff --git a/packages/workflow-gateway/src/agent-socket.ts b/packages/workflow-gateway/src/agent-socket.ts index a684801..a872529 100644 --- a/packages/workflow-gateway/src/agent-socket.ts +++ b/packages/workflow-gateway/src/agent-socket.ts @@ -1,29 +1,111 @@ /** One Durable Object instance per agent name; holds the reverse WebSocket from the agent CLI. */ import { DurableObject } from "cloudflare:workers"; +import { parseWsRequestJson, parseWsResponseJson, type WsResponse } from "./ws-protocol.js"; + type AgentSocketEnv = { GATEWAY_SECRET: string; }; export const AGENT_SOCKET_INTERNAL_STATUS_PATH = "/internal/agent-socket/status"; +export const AGENT_SOCKET_INTERNAL_PROXY_PATH = "/internal/agent-socket/proxy"; + +const PROXY_TIMEOUT_MS = 30_000; + +type PendingEntry = { + resolve: (r: Response) => void; + timer: ReturnType; +}; + +function jsonResponse(status: number, body: unknown): Response { + return new Response(JSON.stringify(body), { + status, + headers: { "Content-Type": "application/json" }, + }); +} + +function wsResponseToHttp(wr: WsResponse): Response { + const headers = new Headers(); + for (const [k, v] of Object.entries(wr.headers)) { + headers.set(k, v); + } + return new Response(wr.body, { status: wr.status, headers }); +} export class AgentSocket extends DurableObject { + private readonly pending = new Map(); + + private requireAuth(request: Request): Response | null { + const auth = request.headers.get("Authorization"); + if (auth !== `Bearer ${this.env.GATEWAY_SECRET}`) { + return jsonResponse(401, { error: "unauthorized" }); + } + return null; + } + + private handleStatusGet(request: Request): Response { + const denied = this.requireAuth(request); + if (denied !== null) { + return denied; + } + const sockets = this.ctx.getWebSockets(); + const connected = sockets.length > 0; + return new Response(JSON.stringify({ connected, connectedCount: sockets.length }), { + headers: { "Content-Type": "application/json" }, + }); + } + + private async handleProxyPost(request: Request): Promise { + const denied = this.requireAuth(request); + if (denied !== null) { + return denied; + } + const raw = await request.text(); + const wsRequest = parseWsRequestJson(raw); + if (wsRequest === null) { + return jsonResponse(400, { error: "invalid proxy body" }); + } + + const sockets = this.ctx.getWebSockets(); + const ws = sockets[0]; + if (ws === undefined) { + return jsonResponse(503, { error: "no active websocket" }); + } + + return await new Promise((resolve) => { + const timer = setTimeout(() => { + this.pending.delete(wsRequest.id); + resolve(jsonResponse(504, { error: "gateway timeout" })); + }, PROXY_TIMEOUT_MS); + + this.pending.set(wsRequest.id, { + resolve: (r: Response) => { + clearTimeout(timer); + this.pending.delete(wsRequest.id); + resolve(r); + }, + timer, + }); + + try { + ws.send(JSON.stringify(wsRequest)); + } catch { + clearTimeout(timer); + this.pending.delete(wsRequest.id); + resolve(jsonResponse(502, { error: "websocket send failed" })); + } + }); + } + async fetch(request: Request): Promise { const url = new URL(request.url); if (url.pathname === AGENT_SOCKET_INTERNAL_STATUS_PATH && request.method === "GET") { - const auth = request.headers.get("Authorization"); - if (auth !== `Bearer ${this.env.GATEWAY_SECRET}`) { - return new Response(JSON.stringify({ error: "unauthorized" }), { - status: 401, - headers: { "Content-Type": "application/json" }, - }); - } - const sockets = this.ctx.getWebSockets(); - const connected = sockets.length > 0; - return new Response(JSON.stringify({ connected, connectedCount: sockets.length }), { - headers: { "Content-Type": "application/json" }, - }); + return this.handleStatusGet(request); + } + + if (url.pathname === AGENT_SOCKET_INTERNAL_PROXY_PATH && request.method === "POST") { + return this.handleProxyPost(request); } if (request.headers.get("Upgrade") !== "websocket") { @@ -41,14 +123,40 @@ export class AgentSocket extends DurableObject { return new Response(null, { status: 101, webSocket: client }); } - async webSocketMessage(_ws: WebSocket, _message: string | ArrayBuffer): Promise {} + async webSocketMessage(_ws: WebSocket, message: string | ArrayBuffer): Promise { + const text = typeof message === "string" ? message : new TextDecoder().decode(message); + const wr = parseWsResponseJson(text); + if (wr === null) { + return; + } + const entry = this.pending.get(wr.id); + if (entry === undefined) { + return; + } + clearTimeout(entry.timer); + this.pending.delete(wr.id); + entry.resolve(wsResponseToHttp(wr)); + } async webSocketClose( _ws: WebSocket, _code: number, _reason: string, _wasClean: boolean, - ): Promise {} + ): Promise { + this.rejectAllPending("agent websocket closed"); + } - async webSocketError(_ws: WebSocket, _error: unknown): Promise {} + async webSocketError(_ws: WebSocket, _error: unknown): Promise { + this.rejectAllPending("agent websocket error"); + } + + private rejectAllPending(message: string): void { + const entries = [...this.pending.values()]; + this.pending.clear(); + for (const entry of entries) { + clearTimeout(entry.timer); + entry.resolve(jsonResponse(502, { error: message })); + } + } } diff --git a/packages/workflow-gateway/src/index.ts b/packages/workflow-gateway/src/index.ts index ce2f61a..a8d21ca 100644 --- a/packages/workflow-gateway/src/index.ts +++ b/packages/workflow-gateway/src/index.ts @@ -1,7 +1,12 @@ import { Hono } from "hono"; import { cors } from "hono/cors"; -import { AGENT_SOCKET_INTERNAL_STATUS_PATH, AgentSocket } from "./agent-socket.js"; +import { + AGENT_SOCKET_INTERNAL_PROXY_PATH, + AGENT_SOCKET_INTERNAL_STATUS_PATH, + AgentSocket, +} from "./agent-socket.js"; +import type { WsRequest } from "./ws-protocol.js"; export { AgentSocket }; @@ -47,6 +52,97 @@ function isLocalAgentUrl(url: string): boolean { } } +function buildForwardHeaders(raw: Headers, agentToken: string): Record { + const out: Record = {}; + for (const [key, value] of raw) { + const lower = key.toLowerCase(); + if (lower === "host" || lower === "authorization") { + continue; + } + if ( + lower === "connection" || + lower === "keep-alive" || + lower === "proxy-connection" || + lower === "transfer-encoding" || + lower === "upgrade" + ) { + continue; + } + out[key] = value; + } + if (agentToken !== "") { + out["X-Agent-Token"] = agentToken; + } + return out; +} + +function buildDashboardProxyHeaders(raw: Headers, token: string): Headers { + const headers = new Headers(raw); + headers.delete("host"); + headers.delete("Authorization"); + if (token !== "") { + headers.set("X-Agent-Token", token); + } + return headers; +} + +async function readBodyForWsProxy(method: string, req: Request): Promise { + if (method === "GET" || method === "HEAD") { + return null; + } + const buf = await req.arrayBuffer(); + return buf.byteLength === 0 ? null : new TextDecoder().decode(buf); +} + +async function fetchThroughAgentSocket( + bindings: Env["Bindings"], + agent: string, + gateSecret: string, + wsRequest: WsRequest, +): Promise { + const stub = bindings.AGENT_SOCKET.get(bindings.AGENT_SOCKET.idFromName(agent)); + return stub.fetch( + new Request(`https://do.internal${AGENT_SOCKET_INTERNAL_PROXY_PATH}`, { + method: "POST", + headers: { + Authorization: `Bearer ${gateSecret}`, + "Content-Type": "application/json", + }, + body: JSON.stringify(wsRequest), + }), + ); +} + +async function fetchAgentWithRecordHeaders( + targetUrl: string, + method: string, + forwardRecord: Record, + bodyStr: string | null, +): Promise { + const headers = new Headers(); + for (const [k, v] of Object.entries(forwardRecord)) { + headers.set(k, v); + } + return fetch(targetUrl, { + method, + headers, + body: method !== "GET" && method !== "HEAD" ? (bodyStr ?? undefined) : undefined, + }); +} + +async function fetchAgentWithDashboardHeaders( + targetUrl: string, + method: string, + headers: Headers, + rawBody: BodyInit | null | undefined, +): Promise { + return fetch(targetUrl, { + method, + headers, + body: method !== "GET" && method !== "HEAD" ? rawBody : undefined, + }); +} + async function fetchAgentSocketStatus( env: Env["Bindings"], name: string, @@ -181,7 +277,7 @@ gateway.get("/endpoints", async (c) => { app.route("/api/gateway", gateway); -// ── API proxy: /api/agents/:agent/* → agent's tunnel URL (dashboard auth) ── +// ── API proxy: /api/agents/:agent/* → WebSocket (preferred) or agent tunnel URL (dashboard auth) ── app.all("/api/agents/:agent/*", async (c) => { if (!checkDashboardAuth(c)) return c.json({ error: "unauthorized" }, 401); const agent = c.req.param("agent"); @@ -191,26 +287,45 @@ app.all("/api/agents/:agent/*", async (c) => { return c.json({ error: "agent not found" }, 404); } - // Build target URL: strip /api/:agent prefix, forward the rest const url = new URL(c.req.url); const pathAfterAgent = url.pathname.replace(`/api/agents/${agent}`, ""); const targetUrl = `${record.url}/api${pathAfterAgent}${url.search}`; + const proxyPath = `/api${pathAfterAgent}${url.search}`; + const method = c.req.method; + const token = record.agentToken ?? ""; + const forwardRecord = buildForwardHeaders(c.req.raw.headers, token); - const headers = new Headers(c.req.raw.headers); - headers.delete("host"); - headers.delete("Authorization"); // don't forward dashboard key to agent - if (record.agentToken) { - headers.set("X-Agent-Token", record.agentToken); + const doStatus = await fetchAgentSocketStatus(c.env, agent); + if (doStatus.ok && doStatus.connected) { + const bodyStr = await readBodyForWsProxy(method, c.req.raw); + const wsRequest: WsRequest = { + id: crypto.randomUUID(), + method, + path: proxyPath, + headers: forwardRecord, + body: bodyStr, + }; + const proxyResp = await fetchThroughAgentSocket(c.env, agent, c.env.GATEWAY_SECRET, wsRequest); + if (proxyResp.status !== 503) { + return new Response(proxyResp.body, { + status: proxyResp.status, + headers: proxyResp.headers, + }); + } + try { + const resp = await fetchAgentWithRecordHeaders(targetUrl, method, forwardRecord, bodyStr); + return new Response(resp.body, { + status: resp.status, + headers: resp.headers, + }); + } catch (err) { + return c.json({ error: "agent unreachable", detail: String(err) }, 502); + } } + const headers = buildDashboardProxyHeaders(c.req.raw.headers, token); try { - const resp = await fetch(targetUrl, { - method: c.req.method, - headers, - body: c.req.method !== "GET" && c.req.method !== "HEAD" ? c.req.raw.body : undefined, - }); - - // Stream response back + const resp = await fetchAgentWithDashboardHeaders(targetUrl, method, headers, c.req.raw.body); return new Response(resp.body, { status: resp.status, headers: resp.headers, diff --git a/packages/workflow-gateway/src/ws-protocol.ts b/packages/workflow-gateway/src/ws-protocol.ts new file mode 100644 index 0000000..643663d --- /dev/null +++ b/packages/workflow-gateway/src/ws-protocol.ts @@ -0,0 +1,93 @@ +/** Wire format for HTTP-over-WebSocket proxy between gateway Durable Object and local serve. */ + +export type WsRequest = { + id: string; + method: string; + path: string; + headers: Record; + body: string | null; +}; + +export type WsResponse = { + id: string; + status: number; + headers: Record; + body: string; +}; + +function isRecord(value: unknown): value is Record { + return typeof value === "object" && value !== null && !Array.isArray(value); +} + +function isNonEmptyString(value: unknown): value is string { + return typeof value === "string" && value.length > 0; +} + +/** Parse and validate a JSON payload as {@link WsRequest}. */ +export function parseWsRequestJson(raw: string): WsRequest | null { + let parsed: unknown; + try { + parsed = JSON.parse(raw) as unknown; + } catch { + return null; + } + if (!isRecord(parsed)) { + return null; + } + const id = parsed.id; + const method = parsed.method; + const path = parsed.path; + const headers = parsed.headers; + const body = parsed.body; + if (!isNonEmptyString(id) || !isNonEmptyString(method) || !isNonEmptyString(path)) { + return null; + } + if (!isRecord(headers)) { + return null; + } + const headerRecord: Record = {}; + for (const [k, v] of Object.entries(headers)) { + if (typeof v !== "string") { + return null; + } + headerRecord[k] = v; + } + if (body !== null && typeof body !== "string") { + return null; + } + return { id, method, path, headers: headerRecord, body: body === null ? null : body }; +} + +/** Parse and validate a JSON payload as {@link WsResponse}. */ +export function parseWsResponseJson(raw: string): WsResponse | null { + let parsed: unknown; + try { + parsed = JSON.parse(raw) as unknown; + } catch { + return null; + } + if (!isRecord(parsed)) { + return null; + } + const id = parsed.id; + const status = parsed.status; + const headers = parsed.headers; + const respBody = parsed.body; + if (!isNonEmptyString(id) || typeof status !== "number" || !Number.isFinite(status)) { + return null; + } + if (!isRecord(headers)) { + return null; + } + const headerRecord: Record = {}; + for (const [k, v] of Object.entries(headers)) { + if (typeof v !== "string") { + return null; + } + headerRecord[k] = v; + } + if (typeof respBody !== "string") { + return null; + } + return { id, status: Math.trunc(status), headers: headerRecord, body: respBody }; +} -- 2.43.0