From ec3c97b2005fc7b853004218bedb397a3a4fb22f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=98=9F=E6=9C=88?= Date: Tue, 12 May 2026 14:42:19 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20WS=20request=20proxy=20=E2=80=94=20Gate?= =?UTF-8?q?way=20proxies=20HTTP=20via=20WebSocket=20(#210=20Phase=202)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add ws-protocol.ts with WsRequest/WsResponse types + parsers - AgentSocket DO: proxy POST handler, pending request map, 30s timeout - /api/agents/:agent/* routes through DO WS when connected, falls back to HTTP - ws-client handles incoming WsRequest, fetches local serve, returns WsResponse - startGatewayWsClient accepts localPort for request handling Testing: #213 --- packages/cli-workflow/package.json | 1 + .../cli-workflow/src/commands/serve/serve.ts | 1 + .../src/commands/serve/ws-client.ts | 69 ++++++++- packages/workflow-gateway/package.json | 4 + packages/workflow-gateway/src/agent-socket.ts | 138 +++++++++++++++-- packages/workflow-gateway/src/index.ts | 145 ++++++++++++++++-- packages/workflow-gateway/src/ws-protocol.ts | 93 +++++++++++ 7 files changed, 413 insertions(+), 38 deletions(-) create mode 100644 packages/workflow-gateway/src/ws-protocol.ts diff --git a/packages/cli-workflow/package.json b/packages/cli-workflow/package.json index 1da2774..9da6fd9 100644 --- a/packages/cli-workflow/package.json +++ b/packages/cli-workflow/package.json @@ -6,6 +6,7 @@ "uncaged-workflow": "src/cli.ts" }, "dependencies": { + "@uncaged/workflow-gateway": "workspace:*", "@uncaged/workflow-protocol": "workspace:*", "@uncaged/workflow-util": "workspace:*", "@uncaged/workflow-cas": "workspace:*", diff --git a/packages/cli-workflow/src/commands/serve/serve.ts b/packages/cli-workflow/src/commands/serve/serve.ts index 9456ee5..b41f3e7 100644 --- a/packages/cli-workflow/src/commands/serve/serve.ts +++ b/packages/cli-workflow/src/commands/serve/serve.ts @@ -128,6 +128,7 @@ export async function dispatchServe(storageRoot: string, argv: string[]): Promis gatewayUrl: options.gatewayUrl, name: options.name, secret: options.gatewaySecret, + localPort: options.port, log, }); printCliLine("gateway WebSocket reverse connection (no cloudflared)"); diff --git a/packages/cli-workflow/src/commands/serve/ws-client.ts b/packages/cli-workflow/src/commands/serve/ws-client.ts index b6d2e94..b58e49c 100644 --- a/packages/cli-workflow/src/commands/serve/ws-client.ts +++ b/packages/cli-workflow/src/commands/serve/ws-client.ts @@ -1,9 +1,11 @@ +import { parseWsRequestJson, type WsResponse } from "@uncaged/workflow-gateway/ws-protocol"; import type { LogFn } from "@uncaged/workflow-util"; export type GatewayWsClientParams = { gatewayUrl: string; name: string; secret: string; + localPort: number; log: LogFn; }; @@ -24,6 +26,58 @@ export function buildGatewayWsConnectUrl(gatewayUrl: string, name: string, secre return u.href; } +function headersToRecord(h: Headers): Record { + const out: Record = {}; + for (const [k, v] of h) { + out[k] = v; + } + return out; +} + +async function handleGatewayMessage( + ws: WebSocket, + raw: string, + params: GatewayWsClientParams, +): Promise { + const req = parseWsRequestJson(raw); + if (req === null) { + params.log("ZM8K2PQ1", "gateway WebSocket dropped non-request message"); + return; + } + const localUrl = `http://127.0.0.1:${String(params.localPort)}${req.path}`; + const initHeaders = new Headers(); + for (const [k, v] of Object.entries(req.headers)) { + initHeaders.set(k, v); + } + let resp: Response; + try { + resp = await fetch(localUrl, { + method: req.method, + headers: initHeaders, + body: req.body === null ? undefined : req.body, + }); + } catch (e) { + params.log("R4N7BQ3C", `local proxy fetch failed: ${String(e)}`); + const errBody: WsResponse = { + id: req.id, + status: 502, + headers: { "content-type": "application/json" }, + body: JSON.stringify({ error: "local fetch failed", detail: String(e) }), + }; + ws.send(JSON.stringify(errBody)); + return; + } + const bodyText = await resp.text(); + const headerRecord = headersToRecord(resp.headers); + const out: WsResponse = { + id: req.id, + status: resp.status, + headers: headerRecord, + body: bodyText, + }; + ws.send(JSON.stringify(out)); +} + /** Maintains a reverse WebSocket to the workflow gateway; reconnects with exponential backoff. */ export function startGatewayWsClient(params: GatewayWsClientParams): () => void { const wsUrl = buildGatewayWsConnectUrl(params.gatewayUrl, params.name, params.secret); @@ -87,15 +141,14 @@ export function startGatewayWsClient(params: GatewayWsClientParams): () => void }); ws.addEventListener("message", (ev) => { - let preview: string; - if (typeof ev.data === "string") { - preview = ev.data; - } else if (ev.data instanceof ArrayBuffer) { - preview = `[binary ${String(ev.data.byteLength)} bytes]`; - } else { - preview = "[non-text message]"; + const data = ev.data; + if (typeof data !== "string") { + params.log("T9W2KL5H", "gateway WebSocket non-text frame ignored"); + return; } - params.log("3FHK5NDJ", `gateway → agent (phase 2 stub): ${preview.slice(0, 500)}`); + void handleGatewayMessage(ws, data, params).catch((e: unknown) => { + params.log("V7KX2M9P", `gateway WebSocket handler error: ${String(e)}`); + }); }); }; diff --git a/packages/workflow-gateway/package.json b/packages/workflow-gateway/package.json index a5defe6..3b99b1d 100644 --- a/packages/workflow-gateway/package.json +++ b/packages/workflow-gateway/package.json @@ -3,6 +3,10 @@ "version": "0.1.0", "private": true, "type": "module", + "exports": { + ".": "./src/index.ts", + "./ws-protocol": "./src/ws-protocol.ts" + }, "scripts": { "dev": "wrangler dev", "deploy": "wrangler deploy" diff --git a/packages/workflow-gateway/src/agent-socket.ts b/packages/workflow-gateway/src/agent-socket.ts index a684801..a872529 100644 --- a/packages/workflow-gateway/src/agent-socket.ts +++ b/packages/workflow-gateway/src/agent-socket.ts @@ -1,29 +1,111 @@ /** One Durable Object instance per agent name; holds the reverse WebSocket from the agent CLI. */ import { DurableObject } from "cloudflare:workers"; +import { parseWsRequestJson, parseWsResponseJson, type WsResponse } from "./ws-protocol.js"; + type AgentSocketEnv = { GATEWAY_SECRET: string; }; export const AGENT_SOCKET_INTERNAL_STATUS_PATH = "/internal/agent-socket/status"; +export const AGENT_SOCKET_INTERNAL_PROXY_PATH = "/internal/agent-socket/proxy"; + +const PROXY_TIMEOUT_MS = 30_000; + +type PendingEntry = { + resolve: (r: Response) => void; + timer: ReturnType; +}; + +function jsonResponse(status: number, body: unknown): Response { + return new Response(JSON.stringify(body), { + status, + headers: { "Content-Type": "application/json" }, + }); +} + +function wsResponseToHttp(wr: WsResponse): Response { + const headers = new Headers(); + for (const [k, v] of Object.entries(wr.headers)) { + headers.set(k, v); + } + return new Response(wr.body, { status: wr.status, headers }); +} export class AgentSocket extends DurableObject { + private readonly pending = new Map(); + + private requireAuth(request: Request): Response | null { + const auth = request.headers.get("Authorization"); + if (auth !== `Bearer ${this.env.GATEWAY_SECRET}`) { + return jsonResponse(401, { error: "unauthorized" }); + } + return null; + } + + private handleStatusGet(request: Request): Response { + const denied = this.requireAuth(request); + if (denied !== null) { + return denied; + } + const sockets = this.ctx.getWebSockets(); + const connected = sockets.length > 0; + return new Response(JSON.stringify({ connected, connectedCount: sockets.length }), { + headers: { "Content-Type": "application/json" }, + }); + } + + private async handleProxyPost(request: Request): Promise { + const denied = this.requireAuth(request); + if (denied !== null) { + return denied; + } + const raw = await request.text(); + const wsRequest = parseWsRequestJson(raw); + if (wsRequest === null) { + return jsonResponse(400, { error: "invalid proxy body" }); + } + + const sockets = this.ctx.getWebSockets(); + const ws = sockets[0]; + if (ws === undefined) { + return jsonResponse(503, { error: "no active websocket" }); + } + + return await new Promise((resolve) => { + const timer = setTimeout(() => { + this.pending.delete(wsRequest.id); + resolve(jsonResponse(504, { error: "gateway timeout" })); + }, PROXY_TIMEOUT_MS); + + this.pending.set(wsRequest.id, { + resolve: (r: Response) => { + clearTimeout(timer); + this.pending.delete(wsRequest.id); + resolve(r); + }, + timer, + }); + + try { + ws.send(JSON.stringify(wsRequest)); + } catch { + clearTimeout(timer); + this.pending.delete(wsRequest.id); + resolve(jsonResponse(502, { error: "websocket send failed" })); + } + }); + } + async fetch(request: Request): Promise { const url = new URL(request.url); if (url.pathname === AGENT_SOCKET_INTERNAL_STATUS_PATH && request.method === "GET") { - const auth = request.headers.get("Authorization"); - if (auth !== `Bearer ${this.env.GATEWAY_SECRET}`) { - return new Response(JSON.stringify({ error: "unauthorized" }), { - status: 401, - headers: { "Content-Type": "application/json" }, - }); - } - const sockets = this.ctx.getWebSockets(); - const connected = sockets.length > 0; - return new Response(JSON.stringify({ connected, connectedCount: sockets.length }), { - headers: { "Content-Type": "application/json" }, - }); + return this.handleStatusGet(request); + } + + if (url.pathname === AGENT_SOCKET_INTERNAL_PROXY_PATH && request.method === "POST") { + return this.handleProxyPost(request); } if (request.headers.get("Upgrade") !== "websocket") { @@ -41,14 +123,40 @@ export class AgentSocket extends DurableObject { return new Response(null, { status: 101, webSocket: client }); } - async webSocketMessage(_ws: WebSocket, _message: string | ArrayBuffer): Promise {} + async webSocketMessage(_ws: WebSocket, message: string | ArrayBuffer): Promise { + const text = typeof message === "string" ? message : new TextDecoder().decode(message); + const wr = parseWsResponseJson(text); + if (wr === null) { + return; + } + const entry = this.pending.get(wr.id); + if (entry === undefined) { + return; + } + clearTimeout(entry.timer); + this.pending.delete(wr.id); + entry.resolve(wsResponseToHttp(wr)); + } async webSocketClose( _ws: WebSocket, _code: number, _reason: string, _wasClean: boolean, - ): Promise {} + ): Promise { + this.rejectAllPending("agent websocket closed"); + } - async webSocketError(_ws: WebSocket, _error: unknown): Promise {} + async webSocketError(_ws: WebSocket, _error: unknown): Promise { + this.rejectAllPending("agent websocket error"); + } + + private rejectAllPending(message: string): void { + const entries = [...this.pending.values()]; + this.pending.clear(); + for (const entry of entries) { + clearTimeout(entry.timer); + entry.resolve(jsonResponse(502, { error: message })); + } + } } diff --git a/packages/workflow-gateway/src/index.ts b/packages/workflow-gateway/src/index.ts index ce2f61a..a8d21ca 100644 --- a/packages/workflow-gateway/src/index.ts +++ b/packages/workflow-gateway/src/index.ts @@ -1,7 +1,12 @@ import { Hono } from "hono"; import { cors } from "hono/cors"; -import { AGENT_SOCKET_INTERNAL_STATUS_PATH, AgentSocket } from "./agent-socket.js"; +import { + AGENT_SOCKET_INTERNAL_PROXY_PATH, + AGENT_SOCKET_INTERNAL_STATUS_PATH, + AgentSocket, +} from "./agent-socket.js"; +import type { WsRequest } from "./ws-protocol.js"; export { AgentSocket }; @@ -47,6 +52,97 @@ function isLocalAgentUrl(url: string): boolean { } } +function buildForwardHeaders(raw: Headers, agentToken: string): Record { + const out: Record = {}; + for (const [key, value] of raw) { + const lower = key.toLowerCase(); + if (lower === "host" || lower === "authorization") { + continue; + } + if ( + lower === "connection" || + lower === "keep-alive" || + lower === "proxy-connection" || + lower === "transfer-encoding" || + lower === "upgrade" + ) { + continue; + } + out[key] = value; + } + if (agentToken !== "") { + out["X-Agent-Token"] = agentToken; + } + return out; +} + +function buildDashboardProxyHeaders(raw: Headers, token: string): Headers { + const headers = new Headers(raw); + headers.delete("host"); + headers.delete("Authorization"); + if (token !== "") { + headers.set("X-Agent-Token", token); + } + return headers; +} + +async function readBodyForWsProxy(method: string, req: Request): Promise { + if (method === "GET" || method === "HEAD") { + return null; + } + const buf = await req.arrayBuffer(); + return buf.byteLength === 0 ? null : new TextDecoder().decode(buf); +} + +async function fetchThroughAgentSocket( + bindings: Env["Bindings"], + agent: string, + gateSecret: string, + wsRequest: WsRequest, +): Promise { + const stub = bindings.AGENT_SOCKET.get(bindings.AGENT_SOCKET.idFromName(agent)); + return stub.fetch( + new Request(`https://do.internal${AGENT_SOCKET_INTERNAL_PROXY_PATH}`, { + method: "POST", + headers: { + Authorization: `Bearer ${gateSecret}`, + "Content-Type": "application/json", + }, + body: JSON.stringify(wsRequest), + }), + ); +} + +async function fetchAgentWithRecordHeaders( + targetUrl: string, + method: string, + forwardRecord: Record, + bodyStr: string | null, +): Promise { + const headers = new Headers(); + for (const [k, v] of Object.entries(forwardRecord)) { + headers.set(k, v); + } + return fetch(targetUrl, { + method, + headers, + body: method !== "GET" && method !== "HEAD" ? (bodyStr ?? undefined) : undefined, + }); +} + +async function fetchAgentWithDashboardHeaders( + targetUrl: string, + method: string, + headers: Headers, + rawBody: BodyInit | null | undefined, +): Promise { + return fetch(targetUrl, { + method, + headers, + body: method !== "GET" && method !== "HEAD" ? rawBody : undefined, + }); +} + async function fetchAgentSocketStatus( env: Env["Bindings"], name: string, @@ -181,7 +277,7 @@ gateway.get("/endpoints", async (c) => { app.route("/api/gateway", gateway); -// ── API proxy: /api/agents/:agent/* → agent's tunnel URL (dashboard auth) ── +// ── API proxy: /api/agents/:agent/* → WebSocket (preferred) or agent tunnel URL (dashboard auth) ── app.all("/api/agents/:agent/*", async (c) => { if (!checkDashboardAuth(c)) return c.json({ error: "unauthorized" }, 401); const agent = c.req.param("agent"); @@ -191,26 +287,45 @@ app.all("/api/agents/:agent/*", async (c) => { return c.json({ error: "agent not found" }, 404); } - // Build target URL: strip /api/:agent prefix, forward the rest const url = new URL(c.req.url); const pathAfterAgent = url.pathname.replace(`/api/agents/${agent}`, ""); const targetUrl = `${record.url}/api${pathAfterAgent}${url.search}`; + const proxyPath = `/api${pathAfterAgent}${url.search}`; + const method = c.req.method; + const token = record.agentToken ?? ""; + const forwardRecord = buildForwardHeaders(c.req.raw.headers, token); - const headers = new Headers(c.req.raw.headers); - headers.delete("host"); - headers.delete("Authorization"); // don't forward dashboard key to agent - if (record.agentToken) { - headers.set("X-Agent-Token", record.agentToken); + const doStatus = await fetchAgentSocketStatus(c.env, agent); + if (doStatus.ok && doStatus.connected) { + const bodyStr = await readBodyForWsProxy(method, c.req.raw); + const wsRequest: WsRequest = { + id: crypto.randomUUID(), + method, + path: proxyPath, + headers: forwardRecord, + body: bodyStr, + }; + const proxyResp = await fetchThroughAgentSocket(c.env, agent, c.env.GATEWAY_SECRET, wsRequest); + if (proxyResp.status !== 503) { + return new Response(proxyResp.body, { + status: proxyResp.status, + headers: proxyResp.headers, + }); + } + try { + const resp = await fetchAgentWithRecordHeaders(targetUrl, method, forwardRecord, bodyStr); + return new Response(resp.body, { + status: resp.status, + headers: resp.headers, + }); + } catch (err) { + return c.json({ error: "agent unreachable", detail: String(err) }, 502); + } } + const headers = buildDashboardProxyHeaders(c.req.raw.headers, token); try { - const resp = await fetch(targetUrl, { - method: c.req.method, - headers, - body: c.req.method !== "GET" && c.req.method !== "HEAD" ? c.req.raw.body : undefined, - }); - - // Stream response back + const resp = await fetchAgentWithDashboardHeaders(targetUrl, method, headers, c.req.raw.body); return new Response(resp.body, { status: resp.status, headers: resp.headers, diff --git a/packages/workflow-gateway/src/ws-protocol.ts b/packages/workflow-gateway/src/ws-protocol.ts new file mode 100644 index 0000000..643663d --- /dev/null +++ b/packages/workflow-gateway/src/ws-protocol.ts @@ -0,0 +1,93 @@ +/** Wire format for HTTP-over-WebSocket proxy between gateway Durable Object and local serve. */ + +export type WsRequest = { + id: string; + method: string; + path: string; + headers: Record; + body: string | null; +}; + +export type WsResponse = { + id: string; + status: number; + headers: Record; + body: string; +}; + +function isRecord(value: unknown): value is Record { + return typeof value === "object" && value !== null && !Array.isArray(value); +} + +function isNonEmptyString(value: unknown): value is string { + return typeof value === "string" && value.length > 0; +} + +/** Parse and validate a JSON payload as {@link WsRequest}. */ +export function parseWsRequestJson(raw: string): WsRequest | null { + let parsed: unknown; + try { + parsed = JSON.parse(raw) as unknown; + } catch { + return null; + } + if (!isRecord(parsed)) { + return null; + } + const id = parsed.id; + const method = parsed.method; + const path = parsed.path; + const headers = parsed.headers; + const body = parsed.body; + if (!isNonEmptyString(id) || !isNonEmptyString(method) || !isNonEmptyString(path)) { + return null; + } + if (!isRecord(headers)) { + return null; + } + const headerRecord: Record = {}; + for (const [k, v] of Object.entries(headers)) { + if (typeof v !== "string") { + return null; + } + headerRecord[k] = v; + } + if (body !== null && typeof body !== "string") { + return null; + } + return { id, method, path, headers: headerRecord, body: body === null ? null : body }; +} + +/** Parse and validate a JSON payload as {@link WsResponse}. */ +export function parseWsResponseJson(raw: string): WsResponse | null { + let parsed: unknown; + try { + parsed = JSON.parse(raw) as unknown; + } catch { + return null; + } + if (!isRecord(parsed)) { + return null; + } + const id = parsed.id; + const status = parsed.status; + const headers = parsed.headers; + const respBody = parsed.body; + if (!isNonEmptyString(id) || typeof status !== "number" || !Number.isFinite(status)) { + return null; + } + if (!isRecord(headers)) { + return null; + } + const headerRecord: Record = {}; + for (const [k, v] of Object.entries(headers)) { + if (typeof v !== "string") { + return null; + } + headerRecord[k] = v; + } + if (typeof respBody !== "string") { + return null; + } + return { id, status: Math.trunc(status), headers: headerRecord, body: respBody }; +}