diff --git a/http/server.ts b/http/server.ts index 723a345232..400171fc5b 100644 --- a/http/server.ts +++ b/http/server.ts @@ -1,5 +1,5 @@ // Copyright 2018-2019 the Deno authors. All rights reserved. MIT license. -import { listen, Conn, toAsyncIterator, Reader, copy } from "deno"; +import { listen, Conn, toAsyncIterator, Reader, Writer, copy } from "deno"; import { BufReader, BufState, BufWriter } from "../io/bufio.ts"; import { TextProtoReader } from "../textproto/mod.ts"; import { STATUS_TEXT } from "./http_status.ts"; @@ -40,6 +40,7 @@ interface ServeEnv { function serveConn(env: ServeEnv, conn: Conn, bufr?: BufReader) { readRequest(conn, bufr).then(maybeHandleReq.bind(null, env, conn)); } + function maybeHandleReq(env: ServeEnv, conn: Conn, maybeReq: any) { const [req, _err] = maybeReq; if (_err) { @@ -210,70 +211,77 @@ export class ServerRequest { return readAllIterator(this.bodyStream()); } - private async _streamBody(body: Reader, bodyLength: number) { - const n = await copy(this.w, body); - assert(n == bodyLength); - } - - private async _streamChunkedBody(body: Reader) { - const encoder = new TextEncoder(); - - for await (const chunk of toAsyncIterator(body)) { - const start = encoder.encode(`${chunk.byteLength.toString(16)}\r\n`); - const end = encoder.encode("\r\n"); - await this.w.write(start); - await this.w.write(chunk); - await this.w.write(end); - } - - const endChunk = encoder.encode("0\r\n\r\n"); - await this.w.write(endChunk); - } - async respond(r: Response): Promise { - const protoMajor = 1; - const protoMinor = 1; - const statusCode = r.status || 200; - const statusText = STATUS_TEXT.get(statusCode); - if (!statusText) { - throw Error("bad status code"); - } - - let out = `HTTP/${protoMajor}.${protoMinor} ${statusCode} ${statusText}\r\n`; - - setContentLength(r); - - if (r.headers) { - for (const [key, value] of r.headers) { - out += `${key}: ${value}\r\n`; - } - } - out += "\r\n"; - - const header = new TextEncoder().encode(out); - let n = await this.w.write(header); - assert(header.byteLength == n); - - if (r.body) { - if (r.body instanceof Uint8Array) { - n = await this.w.write(r.body); - assert(r.body.byteLength == n); - } else { - if (r.headers.has("content-length")) { - await this._streamBody( - r.body, - parseInt(r.headers.get("content-length")) - ); - } else { - await this._streamChunkedBody(r.body); - } - } - } - - await this.w.flush(); + return writeResponse(this.w, r); } } +function bufWriter(w: Writer): BufWriter { + if (w instanceof BufWriter) { + return w; + } else { + return new BufWriter(w); + } +} + +export async function writeResponse(w: Writer, r: Response): Promise { + const protoMajor = 1; + const protoMinor = 1; + const statusCode = r.status || 200; + const statusText = STATUS_TEXT.get(statusCode); + const writer = bufWriter(w); + if (!statusText) { + throw Error("bad status code"); + } + + let out = `HTTP/${protoMajor}.${protoMinor} ${statusCode} ${statusText}\r\n`; + + setContentLength(r); + + if (r.headers) { + for (const [key, value] of r.headers) { + out += `${key}: ${value}\r\n`; + } + } + out += "\r\n"; + + const header = new TextEncoder().encode(out); + let n = await writer.write(header); + assert(header.byteLength == n); + + if (r.body) { + if (r.body instanceof Uint8Array) { + n = await writer.write(r.body); + assert(r.body.byteLength == n); + } else { + if (r.headers.has("content-length")) { + const bodyLength = parseInt(r.headers.get("content-length")); + const n = await copy(writer, r.body); + assert(n == bodyLength); + } else { + await writeChunkedBody(writer, r.body); + } + } + } + await writer.flush(); +} + +async function writeChunkedBody(w: Writer, r: Reader) { + const writer = bufWriter(w); + const encoder = new TextEncoder(); + + for await (const chunk of toAsyncIterator(r)) { + const start = encoder.encode(`${chunk.byteLength.toString(16)}\r\n`); + const end = encoder.encode("\r\n"); + await writer.write(start); + await writer.write(chunk); + await writer.write(end); + } + + const endChunk = encoder.encode("0\r\n\r\n"); + await writer.write(endChunk); +} + async function readRequest( c: Conn, bufr?: BufReader diff --git a/http/server_test.ts b/http/server_test.ts index 5fdb63ceba..099547d0c7 100644 --- a/http/server_test.ts +++ b/http/server_test.ts @@ -6,14 +6,9 @@ // https://github.com/golang/go/blob/master/src/net/http/responsewrite_test.go import { Buffer } from "deno"; -import { test, assert, assertEqual } from "../testing/mod.ts"; -import { - listenAndServe, - ServerRequest, - setContentLength, - Response -} from "./server.ts"; -import { BufWriter, BufReader } from "../io/bufio.ts"; +import { assertEqual, test } from "../testing/mod.ts"; +import { Response, ServerRequest } from "./server.ts"; +import { BufReader, BufWriter } from "../io/bufio.ts"; interface ResponseTest { response: Response; diff --git a/ws/mod.ts b/ws/mod.ts index ca47bf5b89..6433a75d14 100644 --- a/ws/mod.ts +++ b/ws/mod.ts @@ -1,9 +1,9 @@ // Copyright 2018-2019 the Deno authors. All rights reserved. MIT license. import { Buffer, Writer, Conn } from "deno"; -import { ServerRequest } from "../http/server.ts"; import { BufReader, BufWriter } from "../io/bufio.ts"; import { readLong, readShort, sliceLongToBytes } from "../io/ioutil.ts"; import { Sha1 } from "./sha1.ts"; +import { writeResponse } from "../http/server.ts"; export enum OpCode { Continue = 0x0, @@ -71,6 +71,7 @@ export type WebSocket = { class WebSocketImpl implements WebSocket { encoder = new TextEncoder(); + constructor(private conn: Conn, private mask?: Uint8Array) {} async *receive(): AsyncIterableIterator { @@ -278,19 +279,24 @@ export function unmask(payload: Uint8Array, mask?: Uint8Array) { } } -export function acceptable(req: ServerRequest): boolean { +export function acceptable(req: { headers: Headers }): boolean { return ( req.headers.get("upgrade") === "websocket" && - req.headers.has("sec-websocket-key") + req.headers.has("sec-websocket-key") && + req.headers.get("sec-websocket-key").length > 0 ); } -export async function acceptWebSocket(req: ServerRequest): Promise { +export async function acceptWebSocket(req: { + conn: Conn; + headers: Headers; +}): Promise { + const { conn, headers } = req; if (acceptable(req)) { - const sock = new WebSocketImpl(req.conn); - const secKey = req.headers.get("sec-websocket-key"); + const sock = new WebSocketImpl(conn); + const secKey = headers.get("sec-websocket-key"); const secAccept = createSecAccept(secKey); - await req.respond({ + await writeResponse(conn, { status: 101, headers: new Headers({ Upgrade: "websocket", diff --git a/ws/test.ts b/ws/test.ts index 6a78d9fe07..684c43002b 100644 --- a/ws/test.ts +++ b/ws/test.ts @@ -3,9 +3,14 @@ import "./sha1_test.ts"; import { Buffer } from "deno"; import { BufReader } from "../io/bufio.ts"; -import { test, assert, assertEqual } from "../testing/mod.ts"; -import { createSecAccept, OpCode, readFrame, unmask } from "./mod.ts"; -import { serve } from "../http/server.ts"; +import { assert, assertEqual, test } from "../testing/mod.ts"; +import { + acceptable, + createSecAccept, + OpCode, + readFrame, + unmask +} from "./mod.ts"; test(async function testReadUnmaskedTextFrame() { // unmasked single text frame with payload "Hello" @@ -129,3 +134,29 @@ test(async function testCreateSecAccept() { const d = createSecAccept(nonce); assertEqual(d, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="); }); + +test(function testAcceptable() { + const ret = acceptable({ + headers: new Headers({ + upgrade: "websocket", + "sec-websocket-key": "aaa" + }) + }); + assertEqual(ret, true); +}); + +const invalidHeaders = [ + { "sec-websocket-key": "aaa" }, + { upgrade: "websocket" }, + { upgrade: "invalid", "sec-websocket-key": "aaa" }, + { upgrade: "websocket", "sec-websocket-ky": "" } +]; + +test(function testAcceptableInvalid() { + for (const pat of invalidHeaders) { + const ret = acceptable({ + headers: new Headers(pat) + }); + assertEqual(ret, false); + } +});