diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..0b8e040 --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +.eunit +*.beam +deps +ebin +log +TEST-*xml + diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..e5a790f --- /dev/null +++ b/Makefile @@ -0,0 +1,10 @@ +compile: + ./rebar compile + +eunit: + ./rebar skip_deps=true eunit + +run: + erl -pa ebin -pa deps/lager/ebin -s espdy_test_server + +.PHONY: eunit compile diff --git a/README.md b/README.md new file mode 100644 index 0000000..c22598b --- /dev/null +++ b/README.md @@ -0,0 +1,73 @@ +#Status + +This project is still very alpha. Enough of the protocol has been implemented to view pages in google chrome. Check out espdy_test_server.erl + +#Running +/Applications/Google\ Chrome.app/Contents/MacOS/Google\ Chrome --use-spdy=ssl +make compile +./test_server +http://localhost:8443 + +# {active, false} socket like API +This makes the code much more complicated :( and I'm wondering if it is worth having. If you use a callback module or {active, true} style message sending +then the code becomes much simpler. + +However, {active, false} gives us flow control (with spdy/3) and really simple integration with erlang web frameworks that use tcp like sockets and do recv to process POSTs. + +We also support multiple concurrent send() or recv() operations on the same stream which makes the code a little bit more complicated. I'm not sure how useful this behaviour is because you have to make sure you are send()'ing and recv()'ing atomic blocks for it to be useful otherwise you might interleave messages corrupting the stream. gen_tcp supports this behaviour. i'm not sure if the ssl module supports this behaviour. + +#Issues + +## NPN Support in Erlang + +There is no NPN support for SSL in Erlang. I'm currently working on a patch to fix this. Whether Erlang +would accept a patch for a draft extensions is another question. + +## Write Queueing + +We queue up writes in a write process. If the reader at the other end is not reading data fast enough then our write process will start consuming lots of memory. We need to kill off connections that are misbehaving. send() calls block until the write has returned so there is _some_ flow control on the send() side. once the os buffer has filled up send() will start blocking. however, we will send data in response to PING and protocol errors and it is possible a misbehaving client could trick us into allocating lots of memory for the write queue. + +## Read Buffering + +There is no flow control in spdy draft 2. When we receive data we add it to a buffer. We keep adding data to the buffer. If we receive data faster than we are recv()'ing it then we can start chewing up memory. The plan is to add a max_recv_buffer setting and to stop reading from the socket if this limit is met. + +## Goaway Implementation + +We don't correctly send {error, closed_not_processed} in a lot of places. This hasn't been implemented in send/recv. Also, we kill the gen_server too early and a client may receive {error, closed} because the gen_server is not around instead of {error, closed_not_processed}. + +## No settings + +Have not implemented sending or receiving of the settings frame. + +## Partial implementation of headers + +No support for receiving header frames. I'm not sure how this would look in the api :( + +## StreamId limit + +No support for stopping connections when the StreamId limit has been reached + +## Not checking version + +We are currently not checking the version field on control or data frames :( + +## No backlimit for accepts + +We continue to add connections to the accept buffer without bound. + +## Communicating errors back to the client + +If a client isn't send'ing or recv'ing on a stream then they won't receive error notifications about previous failed sends. I think this is how unix +tcp works so i don't think it is that terrible :). However, the spdy http protocol is the client writes to the server and then half-closes, then the +server writes to the client and then half-closes. There is no server waiting for an ack from the client. If the servers response to the client is not +received it will not be aware of this. I'm not sure how big of a problem this is. Generally http servers don't care if they fail to deliver a response +to the client. + +## No chunking of sends + +If you do a large send() then it will block all the other channels using the spdy socket. We should probably chunk sends() to a reasonable value. + +# Code duplication + +There is some code duplication between espdy_server and espdy_frame for control frame decoding + diff --git a/include/espdy_frame.hrl b/include/espdy_frame.hrl new file mode 100644 index 0000000..c4a6616 --- /dev/null +++ b/include/espdy_frame.hrl @@ -0,0 +1,53 @@ +-define(BYTE(X), (X):8/unsigned-big-integer). +-define(UINT16(X), (X):16/unsigned-big-integer). +-define(UINT24(X), (X):24/unsigned-big-integer). +-define(UINT32(X), (X):32/unsigned-big-integer). +-define(UINT64(X), (X):64/unsigned-big-integer). + +-define(NOOP, 5). +-define(PING, 6). +-define(SYN_STREAM, 1). +-define(SYN_REPLY, 2). +-define(RST_STREAM, 3). +-define(HEADERS, 8). +-define(GOAWAY, 7). + +-define(SPDY_VERSION, 2). +-define(CONTROL_FRAME, 1:1/unsigned-big-integer). +-define(DATA_FRAME, 0:1/unsigned-big-integer). +-define(TYPE(X), (X):16/unsigned-big-integer). +-define(VERSION(X), (X):15/unsigned-big-integer). +-define(LENGTH(X), ?UINT24(X)). +-define(FLAGS(X), ?BYTE(X)). +-define(RESERVED_BIT, _:1). +-define(MAKE_RESERVED_BIT, 0:1). +-define(STREAM_ID(X), (X):31/unsigned-big-integer). +-define(PRIORITY(X), (X):2/unsigned-big-integer). +-define(UNUSED(X), _:(X)). +-define(MAKE_UNUSED(X), 0:(X)). + +-define(UNUSED_SYN_REPLY, 16). +-define(UNUSED_SYN_STREAM, 14). +-define(UNUSED_HEADERS, 16). + +-define(FLAG_FIN, 1). +-define(FLAG_UNIDIRECTIONAL, 2). + +%% RST_STREAM STATUSES +-define(PROTOCOL_ERROR, 1). +-define(INVALID_STREAM, 2). +-define(REFUSED_STREAM, 3). +-define(UNSUPPORTED_VERSION, 4). +-define(CANCEL, 5). +-define(FLOW_CONTROL_ERROR, 6). +-define(STREAM_IN_USE, 7). +-define(STREAM_ALREADY_CLOSED, 8). + +-record(control_frame, {version, type, flags, data}). +-record(rst_stream, {stream_id, status}). +-record(headers, {headers, stream_id}). +-record(data_frame, {stream_id, flags, data}). +-record(syn_reply, {stream_id, flags, headers}). +-record(syn_stream, {stream_id, flags, headers, associated_stream_id}). +-record(goaway, {last_good_stream_id}). +-record(noop, {}). diff --git a/include/espdy_server.hrl b/include/espdy_server.hrl new file mode 100644 index 0000000..b7d7643 --- /dev/null +++ b/include/espdy_server.hrl @@ -0,0 +1 @@ +-record(espdy_socket, {pid, stream_id}). diff --git a/rebar b/rebar new file mode 100755 index 0000000..c3e01cc Binary files /dev/null and b/rebar differ diff --git a/rebar.config b/rebar.config new file mode 100644 index 0000000..5ca0a58 --- /dev/null +++ b/rebar.config @@ -0,0 +1,6 @@ +{deps, [ + {lager, "0.9.*", {git, "git://github.com/basho/lager", {branch, "master"}}} +]}. +{cover_enabled, true}. +{eunit_opts, [verbose, + {report, {eunit_surefire, [{dir, "."}]}}]}. diff --git a/server.crt b/server.crt new file mode 100644 index 0000000..ee391df --- /dev/null +++ b/server.crt @@ -0,0 +1,14 @@ +-----BEGIN CERTIFICATE----- +MIICKTCCAZICCQC/DgnBN+wo7jANBgkqhkiG9w0BAQUFADBZMQswCQYDVQQGEwJB +VTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50ZXJuZXQgV2lkZ2l0 +cyBQdHkgTHRkMRIwEAYDVQQDDAlsb2NhbGhvc3QwHhcNMTExMDA2MjM1NjEzWhcN +MTIxMDA1MjM1NjEzWjBZMQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0 +ZTEhMB8GA1UECgwYSW50ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMRIwEAYDVQQDDAls +b2NhbGhvc3QwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAKZOG7/aDqeEVr+R +73/SZs9Xzw+XYTPkPNykPe5/7pL4LDuNc0seFnthMI8/qBAZ14pBnR7pIx1s6eqy +hSH+yxghc9T6yot+cM71EsVydPaN0IbjJONwR4RMStiydSLsVECAaDXt4bSwDuCf +qoGsx5q4bBwv7SR8gy0Wqx3+PH6/AgMBAAEwDQYJKoZIhvcNAQEFBQADgYEAYtrw +KlJT23M50BOfC6KEvl9so0OzkyaCoLVPcM1UcNi5AfjMuXcxf2foFn0SmorGnJQ6 +FkXXwu/nF6R/ngDDgrLBv7vVLQY/JfU1+eUjblkfdpeykwpK+bd1N5UzdYM6LX5Z +2IMU4eYaBkaISkeA7IAS09lRyDSs9MsxgvfGCjk= +-----END CERTIFICATE----- diff --git a/server.key b/server.key new file mode 100644 index 0000000..687211a --- /dev/null +++ b/server.key @@ -0,0 +1,15 @@ +-----BEGIN RSA PRIVATE KEY----- +MIICXAIBAAKBgQCmThu/2g6nhFa/ke9/0mbPV88Pl2Ez5DzcpD3uf+6S+Cw7jXNL +HhZ7YTCPP6gQGdeKQZ0e6SMdbOnqsoUh/ssYIXPU+sqLfnDO9RLFcnT2jdCG4yTj +cEeETErYsnUi7FRAgGg17eG0sA7gn6qBrMeauGwcL+0kfIMtFqsd/jx+vwIDAQAB +AoGAa+ceBhrbi0FIb7+mX48KedmFVZ5oyRx7iMVgEZEcIGu5d2JNvh1lhOQId8lb +qCa8PM5ZvaaSzBLQqyDtVKNW7eNt8z1cCUUCXy74dfUAqT0gk13cMs+InCqNWheU +DaFttpgerPUHW4puaNoKJr+UKDgkEQqgyla/ixBHFNfW6dECQQDRa+cDf00g9aB0 +Kq6Yurf/KIf3LsJJHYizYKucirQdWhIo87mD9tez/UK9iTCjC2SM+sIklGYiW2ns +ZfxVJtYTAkEAy0s71pJBsbV7b7PQf8p2P1ICX/QN7/wALYmd6iJqWC+599GzF8j6 +Fq3pNUED6wy61W715YcFziIdyRNcKUv6JQJABQUkJYZQsACTVxWK1+hp7rjnAXri +d2Q42avwkTEV/johg0/MW6h4JT1l8ystukrUnziHnN7dz+cHE/6h3NywdwJAdGE3 +elj0PtXUOlhITkALfahnL6M5r18mgus7eeQF2UJJRjPIQR+O/BjHXvM/WPpKoxEI +uEOZ8S3au1fX9NRH+QJBAJKPv47Nr3Old03Imdwu7JPjwwXtm6gJb6Id4MRDiP6k +bUPOw1GtM79o2xjc33+5e1ekaMAP5+d6+04vAlZ4Ijg= +-----END RSA PRIVATE KEY----- diff --git a/src/espdy.app.src b/src/espdy.app.src new file mode 100644 index 0000000..0e5b7ec --- /dev/null +++ b/src/espdy.app.src @@ -0,0 +1,8 @@ +{application, espdy, + [{description, "SPDY Server"}, + {vsn, "0.0.1"}, + {modules, []}, + {registered, []}, + {applications, [kernel, stdlib]}, + {build_dependencies, []}, + {env, []}]}. diff --git a/src/espdy_frame.erl b/src/espdy_frame.erl new file mode 100644 index 0000000..6eee44c --- /dev/null +++ b/src/espdy_frame.erl @@ -0,0 +1,308 @@ +-module(espdy_frame). +-include("espdy_frame.hrl"). +-compile([{parse_transform, lager_transform}]). + + +-export([decode_name_value_header_block/1, encode_name_value_header_block/1, + decode_name_value_header_block/2, encode_name_value_header_block/2, + initialize_zlib_for_deflate/0, initialize_zlib_for_inflate/0,dictionary/0, + read_frame/2, encode_ping/1, + encode_syn_stream/6, + encode_syn_stream/5, + encode_syn_stream_with_raw_uncompressed_name_value_pairs/5, + encode_data_frame/3, + encode_rst_stream/2, + encode_headers/3, + encode_headers_with_raw_uncompressed_name_value_pairs/3, + encode_headers_with_raw_uncompressed_name_value_pairs/4, + encode_headers/4, + encode_syn_reply/3, + encode_syn_reply/4, + encode_noop/0, + encode_goaway/1, + decode_exactly_one_frame/1, + decode_control_frame/1, + decode_control_frame/2, + encode_control_frame/3, + decode_frame/1]). + +read_frame(Buffer, <<>>) -> + decode_frame(Buffer); +read_frame(<<>>, Data) -> + decode_frame(Data); +read_frame(Buffer, Data) -> + decode_frame(<>). + +decode_frame(<>) -> + {#control_frame{version = Version, type = Type, flags = Flags, data = Data}, Rest}; +decode_frame(<>) -> + {#data_frame{stream_id = StreamID, flags = Flags, data = Data}, Rest}; +decode_frame(Data) -> + Data. + +decode_exactly_one_frame(Data) -> + {Frame, <<>>} = decode_frame(Data), + Frame. + + +decode_control_frame(Frame) -> + ZLib = initialize_zlib_for_inflate(), + ControlFrame = decode_control_frame(ZLib, Frame), + zlib:close(ZLib), + ControlFrame. + + +decode_control_frame(_ZLib, #control_frame{type = ?GOAWAY, data = + <>}) -> + #goaway{last_good_stream_id = LastGoodStreamId}; + +decode_control_frame(_ZLib, #control_frame{type = ?RST_STREAM, data = + <>}) -> + #rst_stream{stream_id = StreamId, status = StatusCode}; + + +decode_control_frame(_ZLib, #control_frame{type = ?NOOP, data = <<>>}) -> + #noop{}; + +decode_control_frame(ZLib, #control_frame{type = ?HEADERS, data = + <>}) -> + Headers = decode_name_value_header_block(ZLib, NameValueHeaderBlock), + #headers{stream_id = StreamId, headers = Headers}; + +decode_control_frame(ZLib, #control_frame{flags = Flags, type = ?SYN_REPLY, data = + <>}) -> + Headers = decode_name_value_header_block(ZLib, NameValueHeaderBlock), + #syn_reply{flags = Flags, stream_id = StreamId, headers = Headers}; + +decode_control_frame(ZLib, #control_frame{flags = Flags, type = ?SYN_STREAM, data = + <>}) -> + Headers = decode_name_value_header_block(ZLib, NameValueHeaderBlock), + #syn_stream{flags = Flags, stream_id = StreamId, headers = Headers, associated_stream_id = AssociatedToStreamId}. + +validate_valid_values(Pairs) -> + fold_error(fun validate_value/1, Pairs). + +validate_value({_Name, <<>>}) -> + {error, empty_value}; +validate_value({_Name, <<0, _Rest/binary>>}) -> + {error, value_starts_with_null}; +validate_value({_Name, Value}) -> + case binary:last(Value) =:= 0 of + true -> {error, value_ends_with_null}; + _ -> + case binary:match(Value, <<0,0>>) of + nomatch -> ok; + _ -> {error, empty_value} + end + end. + + +validate_no_dups(Pairs) -> + SortedPairs = lists:ukeysort(1, Pairs), + case length(SortedPairs) =:= length(Pairs) of + true -> ok; + false -> {error, duplicate_header_names} + end. + +expand_headers(Pairs) -> + lists:map(fun expand_header/1, Pairs). + +expand_header({Name, <<>>}) -> + {Name, []}; +expand_header({Name, Value}) -> + Values = binary:split(Value, <<0>>), + {Name, Values}. + +chain_validations([], _Data) -> + ok; +chain_validations([F|Rest], Data) -> + case F(Data) of + ok -> chain_validations(Rest, Data); + Error -> Error + end. + +fold_error(_Fun, []) -> + ok; +fold_error(Fun, [H|Rest]) -> + case Fun(H) of + ok -> fold_error(Fun, Rest); + Error -> Error + end. + +validate_headers_names_valid(Pairs) -> + fold_error(fun validate_headers_name_valid/1, Pairs). + +validate_headers_name_valid({<<>>, _Value}) -> + {error, empty_header_name}; +validate_headers_name_valid(_) -> + ok. + +make_headers(Pairs) -> + case chain_validations([fun validate_no_dups/1, fun validate_valid_values/1, fun validate_headers_names_valid/1], Pairs) of + ok -> + expand_headers(Pairs); + Error -> Error + end. + +decode_name_value_header_block(ZLib, Binary) -> + Iodata = + try + zlib:inflate(ZLib, Binary) + catch + error : {need_dictionary,3751956914} -> + zlib:inflateSetDictionary(ZLib, dictionary()), + zlib:inflate(ZLib, []) + end, + decode_name_value_header_block(list_to_binary(Iodata)). + +encode_name_value_header_block(ZLib, NameValuePairs) -> + Binary = encode_name_value_header_block(NameValuePairs), + zlib:deflate(ZLib, Binary, sync). + +encode_name_value_header_block(NameValuesPairs) -> + lists:foldl(fun (E, Acc) -> <> end, + <>, + NameValuesPairs). + +encode_name_values({Name, Values}) -> + EncodedValues = encode_values(Values), + <>. + +encode_values(Values) -> + lists:foldl( + fun (<<>>, _) -> + throw(empty_header_value); + (E, <<>>) -> + <>; + (E, Acc) -> + <> + end, + <<>>, Values). + +decode_name_value_header_block(<>) -> + make_headers(decode_name_value(Repeats, NumberOfNameValuePairs, [])). + +decode_name_value(_Binary, 0, Acc) -> + Acc; +decode_name_value(Binary0, N, Acc) -> + {Value, Binary} = decode_name_value(Binary0), + decode_name_value(Binary, N - 1, [Value | Acc]). + +decode_name_value(<>) -> + {{Name, Value}, Rest}. + +dictionary() -> + <<"optionsgetheadpostputdeletetraceacceptaccept-charsetaccept-encodingaccept-" + "languageauthorizationexpectfromhostif-modified-sinceif-matchif-none-matchi" + "f-rangeif-unmodifiedsincemax-forwardsproxy-authorizationrangerefererteuser" + "-agent10010120020120220320420520630030130230330430530630740040140240340440" + "5406407408409410411412413414415416417500501502503504505accept-rangesageeta" + "glocationproxy-authenticatepublicretry-afterservervarywarningwww-authentic" + "ateallowcontent-basecontent-encodingcache-controlconnectiondatetrailertran" + "sfer-encodingupgradeviawarningcontent-languagecontent-lengthcontent-locati" + "oncontent-md5content-rangecontent-typeetagexpireslast-modifiedset-cookieMo" + "ndayTuesdayWednesdayThursdayFridaySaturdaySundayJanFebMarAprMayJunJulAugSe" + "pOctNovDecchunkedtext/htmlimage/pngimage/jpgimage/gifapplication/xmlapplic" + "ation/xhtmltext/plainpublicmax-agecharset=iso-8859-1utf-8gzipdeflateHTTP/1" + ".1statusversionurl", 0>>. + + +initialize_zlib_for_deflate() -> + Z = zlib:open(), + zlib:deflateInit(Z, default, deflated, 15, 8, default), + zlib:deflateSetDictionary(Z, dictionary()), + Z. + +initialize_zlib_for_inflate() -> + Z = zlib:open(), + zlib:inflateInit(Z, 15), + Z. + +encode_data_frame(StreamId, Flags, Data) -> + <>. + +encode_control_frame(Type, Flags, Data) -> + <>. + +encode_noop() -> + encode_control_frame(?NOOP, 0, <<>>). + +encode_goaway(LastStreamId) -> + encode_control_frame(?GOAWAY, 0, <>). + +encode_ping(ID) -> + encode_control_frame(?PING, 0, <>). + +encode_rst_stream(StreamId, StatusCode) -> + encode_control_frame(?RST_STREAM, 0, + <<0:1, ?STREAM_ID(StreamId), + ?UINT32(StatusCode)>>). + +encode_syn_stream_with_raw_uncompressed_name_value_pairs(Flags, StreamId, AssociatedToStreamId, Priority, RawBlock) -> + ZLib = initialize_zlib_for_deflate(), + NameValueHeaderBlock = list_to_binary(zlib:deflate(ZLib, RawBlock, sync)), + encode_syn_stream_with_raw_name_value_pairs(Flags, StreamId, AssociatedToStreamId, Priority, NameValueHeaderBlock). + +encode_syn_stream_with_raw_name_value_pairs(Flags, StreamId, AssociatedToStreamId, Priority, NameValueHeaderBlock) -> + encode_control_frame(?SYN_STREAM, Flags, + <>). + +encode_syn_stream(Flags, StreamId, AssociatedToStreamId, Priority, NameValuePairs) -> + ZLib = initialize_zlib_for_deflate(), + Frame = encode_syn_stream(ZLib, Flags, StreamId, AssociatedToStreamId, Priority, NameValuePairs), + zlib:close(ZLib), + Frame. + +encode_syn_stream(ZLib, Flags, StreamId, AssociatedToStreamId, Priority, NameValuePairs) -> + NameValueHeaderBlock = list_to_binary(encode_name_value_header_block(ZLib, NameValuePairs)), + encode_syn_stream_with_raw_name_value_pairs(Flags, StreamId, AssociatedToStreamId, Priority, NameValueHeaderBlock). + +encode_syn_reply(StreamId, Flags, Headers) -> + ZLib = initialize_zlib_for_deflate(), + encode_syn_reply(StreamId, Flags, Headers), + zlib:close(ZLib). + +encode_syn_reply(ZLib, StreamId, Flags, Headers) -> + NameValueHeaderBlock = list_to_binary(encode_name_value_header_block(ZLib, Headers)), + Frame = << + ?MAKE_RESERVED_BIT, ?STREAM_ID(StreamId), + ?MAKE_UNUSED(?UNUSED_SYN_REPLY), + NameValueHeaderBlock/binary>>, + encode_control_frame(?SYN_REPLY, Flags, Frame). + +encode_headers(StreamId, Flags, Headers) -> + ZLib = initialize_zlib_for_deflate(), + Frame = encode_headers(ZLib, StreamId, Flags, Headers), + zlib:close(ZLib), + Frame. + +encode_headers_with_raw_uncompressed_name_value_pairs(ZLib, StreamId, Flags, Headers) -> + encode_headers_with_compressed_name_value_header_block(StreamId, Flags, list_to_binary(zlib:deflate(ZLib, Headers, sync))). + +encode_headers_with_raw_uncompressed_name_value_pairs(StreamId, Flags, Headers) -> + ZLib = initialize_zlib_for_deflate(), + Frame = encode_headers_with_raw_uncompressed_name_value_pairs(StreamId, Flags, Headers), + zlib:close(ZLib), + Frame. + +encode_headers(ZLib, StreamId, Flags, Headers) -> + NameValueHeaderBlock = list_to_binary(encode_name_value_header_block(ZLib, Headers)), + encode_headers_with_compressed_name_value_header_block(StreamId, Flags, NameValueHeaderBlock). + +encode_headers_with_compressed_name_value_header_block(StreamId, Flags, NameValueHeaderBlock) -> + Frame = << + ?MAKE_RESERVED_BIT, ?STREAM_ID(StreamId), + ?MAKE_UNUSED(?UNUSED_HEADERS), + NameValueHeaderBlock/binary>>, + encode_control_frame(?HEADERS, Flags, Frame). \ No newline at end of file diff --git a/src/espdy_mock_socket.erl b/src/espdy_mock_socket.erl new file mode 100644 index 0000000..58c6fd9 --- /dev/null +++ b/src/espdy_mock_socket.erl @@ -0,0 +1,30 @@ +-module(espdy_mock_socket, [PacketReceiver]). + +-export([send/1, close/0, shutdown/0, setopts/1, controlling_process/1, data_tag/0, error_tag/0, close_tag/0]). + +close_tag() -> + tcp_closed. + +error_tag() -> + tcp_error. + +data_tag() -> + tcp. + +send(Packet) -> + PacketReceiver ! {packet, self(), Packet}, + receive + {mock_packet_result, Result} -> Result + end. + +close() -> + ok. + +setopts(Options) -> + ok. + +shutdown() -> + ok. + +controlling_process(Pid) -> + ok. \ No newline at end of file diff --git a/src/espdy_server.erl b/src/espdy_server.erl new file mode 100644 index 0000000..17dc2d2 --- /dev/null +++ b/src/espdy_server.erl @@ -0,0 +1,799 @@ +-module(espdy_server). +-compile([{parse_transform, lager_transform}]). +-behaviour(gen_server). + +-include("espdy_frame.hrl"). +-include("espdy_server.hrl"). + +-export([start_link/2]). +-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3, socket_writer/1, graceful_close/1]). + + +-export([accept/1, + connect/1, + connect/3, + connect_for_push/3, + send_headers/2, + send/2, + close/1, + send_and_close/2, + recv/3, recv/2, + statistics/1, + stream_id/1]). + +% TESTING API ONLY. NOT SAFE TO USE +-export([async_accept/1, async_send_headers/2, async_send/2, async_close/1, async_send_and_close/2, + async_recv/3, async_recv/2]). + +make_socket(Pid, StreamId) -> + #espdy_socket{pid = Pid, stream_id = StreamId}. + +accept(Pid) -> + case send_call(Pid, accept) of + {ok, StreamId, Headers} -> {ok, make_socket(Pid, StreamId), Headers}; + Error -> Error + end. + +connect(Pid) -> + connect(Pid, false, []). + +connect_for_push(#espdy_socket{pid = Pid, stream_id = StreamId}, SendFin, Headers) -> + connect(Pid, SendFin, Headers, StreamId). + +connect(Pid, SendFin, Headers) -> + connect(Pid, SendFin, Headers, 0). + +connect(Pid, SendFin, Headers, AssociatedStreamId) -> + case send_call(Pid, {connect, Headers, SendFin, AssociatedStreamId}) of + {ok, StreamId} -> {ok, make_socket(Pid, StreamId)}; + Error -> Error + end. + +graceful_close(Pid) -> + gen_server:cast(Pid, graceful_close). + +send_headers(#espdy_socket{pid = Pid, stream_id = StreamId}, Headers) -> + send_call(Pid, {send_headers, StreamId, Headers}). + +send(#espdy_socket{pid = Pid, stream_id = StreamId}, Data) -> + send_call(Pid, {send, StreamId, Data}). + +send_and_close(#espdy_socket{pid = Pid, stream_id = StreamId}, Data) -> + send_call(Pid, {send_and_close, StreamId, Data}). + +close(#espdy_socket{pid = Pid, stream_id = StreamId}) -> + send_call(Pid, {close, StreamId}). + +recv(#espdy_socket{pid = Pid, stream_id = StreamId}, Length, Timeout) -> + send_call(Pid, {recv, StreamId, Length, Timeout}). + + + +recv(Socket, Length) -> + recv(Socket, Length, infinity). + +statistics(Pid) -> + send_call(Pid, statistics). + +stream_id(#espdy_socket{stream_id = StreamId}) -> + StreamId. + +async_accept(Pid) -> + send_async(Pid, accept). + +async_send(#espdy_socket{pid = Pid, stream_id = StreamId}, Data) -> + send_async(Pid, {send, StreamId, Data}). + +async_recv(#espdy_socket{pid = Pid, stream_id = StreamId}, Length, Timeout) -> + send_async(Pid, {recv, StreamId, Length, Timeout}). + +async_recv(Socket, Length) -> + async_recv(Socket, Length, infinity). + +async_send_headers(#espdy_socket{pid = Pid, stream_id = StreamId}, Headers) -> + send_async(Pid, {send_headers, StreamId, Headers}). + +async_close(#espdy_socket{pid = Pid, stream_id = StreamId}) -> + send_async(Pid, {close, StreamId}). + +async_send_and_close(#espdy_socket{pid = Pid, stream_id = StreamId}, Data) -> + send_async(Pid, {send_and_close, StreamId, Data}). + +send_async(Pid, Msg) -> + Ref = make_ref(), + Pid ! {'$gen_call', {self(), Ref}, Msg}, + Ref. + +send_call(Pid, Request) -> + try + gen_server:call(Pid, Request, infinity) + catch + % not sure if this is too broad... + exit:_ -> + {error, closed} + end. + +-record(state, { + last_stream_accepted = 0, + next_syn_stream_id, + closing = false, + other_side_closing = false, + data_tag, + close_tag, + role, + socket, + socket_writer = undefined, + zlib_read_nv_context = undefined, + zlib_write_nv_context = undefined, + read_buffer = <<>>, + streams = dict:new(), + accept_buffer = [], + waiting_acceptors = []}). + +-record(stream, { + + read_open = true, + write_closing = false, + write_open = true, + read_buffer = <<>>, + recv_waiters = [], + send_waiters = [], + send_header_waiters = [], + syn_reply_sent = false}). + +stream_to_proplist(#stream{} = Rec) -> + lists:zip(record_info(fields, stream), tl(tuple_to_list(Rec))). + +debug_streams(#state{streams = Streams}) -> + lists:map(fun({N, S}) -> {N, stream_to_proplist(S) } end, dict:to_list(Streams)). + +start_link(Socket, Role) -> + case gen_server:start_link(espdy_server, [Socket, Role], []) of + {ok, Pid} = Result -> + ok = Socket:controlling_process(Pid), + ok = gen_server:call(Pid, start), + Result; + Error -> Error + end. + +initial_stream_id(server) -> + 2; +initial_stream_id(client) -> + 1. + +init([Socket, Role]) -> + {ok, #state{ + next_syn_stream_id = initial_stream_id(Role), + data_tag = Socket:data_tag(), + close_tag = Socket:close_tag(), + socket = Socket, + socket_writer = spawn_link(?MODULE, socket_writer, [Socket]), + role = Role, + zlib_read_nv_context = espdy_frame:initialize_zlib_for_inflate(), + zlib_write_nv_context = espdy_frame:initialize_zlib_for_deflate() + }}. + +socket_writer(Socket) -> + receive + {send, Data, Callback} -> + lager:debug("Sending Data ~p", [Data]), + Callback(Socket:send(Data)), + socket_writer(Socket); + _ -> + socket_writer(Socket) + end. + +socket_write(State, Data) -> + socket_write(State, Data, fun(_) -> ok end). + +socket_write(State, Data, CallbackFun) -> + (State#state.socket_writer) ! {send, Data, CallbackFun}. + +split_at_or_at_end(N, Binary) when byte_size(Binary) > N -> + split_at(N, Binary); +split_at_or_at_end(_N, Binary) -> + {Binary, <<>>}. + +split_at(N, Binary) -> + case Binary of + <> -> {Bytes, Rest} + end. + +handle_receive_need_more_data(_StreamId, #stream{read_open = false}, _Length, _From, State) -> + {reply, {error, closed}, State}; +handle_receive_need_more_data(StreamId, StreamRecord, Length, From, State) -> + RecvWaiters = StreamRecord#stream.recv_waiters, + NewStreams = dict:store(StreamId, StreamRecord#stream{recv_waiters = [{From, Length} | RecvWaiters]}, State#state.streams), + {noreply, State#state{streams = NewStreams}}. + +handle_receive(StreamId, StreamRecord, Length, From, State) -> + ReadBuffer = StreamRecord#stream.read_buffer, + case byte_size(ReadBuffer) >= Length of + true -> + {First, Rest} = split_at(Length, ReadBuffer), + NewStreams = dict:store(StreamId, StreamRecord#stream{read_buffer = Rest}, State#state.streams), + {reply, First, State#state{streams = NewStreams}}; + _ -> + handle_receive_need_more_data(StreamId, StreamRecord, Length, From, State) + end. + +write_headers(State, StreamId, Stream, Data, From) -> + Self = self(), + + WriteFun = fun(Result) -> + case Result of + ok -> + Self ! {header_finished, {ok, StreamId}}; + Reason -> + Self ! {write_error, Reason} + end + end, + + socket_write(State, Data, WriteFun), + Stream#stream{syn_reply_sent = true, send_header_waiters = Stream#stream.send_header_waiters ++ [From]}. + + + + +handle_send_headers(State, StreamId, #stream{syn_reply_sent = true} = StreamRecord, Headers, From) -> + Packet = espdy_frame:encode_headers(State#state.zlib_write_nv_context, StreamId, 0, Headers), + write_headers(State, StreamId, StreamRecord, Packet, From); + +handle_send_headers(State, StreamId, #stream{syn_reply_sent = false} = StreamRecord, Headers, From) -> + SynReply = espdy_frame:encode_syn_reply(State#state.zlib_write_nv_context, StreamId, 0, Headers), + write_headers(State, StreamId, StreamRecord, SynReply, From). + +ensure_syn_reply_sent(_State, _StreamId, #stream{syn_reply_sent = true} = Stream) -> + Stream; +ensure_syn_reply_sent(State, StreamId, Stream) -> + Data = espdy_frame:encode_syn_reply(State#state.zlib_write_nv_context, StreamId, 0, []), + WriteFun = fun(_Result) -> + ok % TODO: handle write error + end, + + socket_write(State, Data, WriteFun), + Stream#stream{syn_reply_sent = true}. + +write_flags(true) -> + ?FLAG_FIN; +write_flags(_) -> + 0. + +update_write_closing_flag(_Stream, true) -> + true; +update_write_closing_flag(#stream{write_closing = Closing}, _Fin) -> + Closing. + +perform_write(State, StreamId, Stream0, Buffer, Fin, From) -> + Stream = ensure_syn_reply_sent(State, StreamId, Stream0), + perform_write_prime(State, StreamId, Stream, Buffer, Fin, From). + +perform_write_prime(State, StreamId, Stream, BufferToSend, Fin, From) -> + Self = self(), + Data = espdy_frame:encode_data_frame(StreamId, write_flags(Fin), BufferToSend), + + WriteFun = fun(Result) -> + + case Result of + ok -> + Self ! {write_finished, {ok, StreamId, Fin, true}}; + Reason -> + Self ! {write_error, Reason} + end + end, + + socket_write(State, Data, WriteFun), + Stream#stream{write_closing = update_write_closing_flag(Stream, Fin), send_waiters = Stream#stream.send_waiters ++ [From]}. + + +handle_send(StreamId, StreamRecord, Packet, Fin, From, State) -> + StreamRecord1 = perform_write(State, StreamId, StreamRecord, Packet, Fin, From), + update_stream(StreamId, StreamRecord1, State). + +update_stream(StreamId, StreamRecord, State) -> + NewStreams = dict:store(StreamId, StreamRecord, State#state.streams), + State#state{streams = NewStreams}. + +handle_accept(_From, #state{closing = true} = State) -> + {reply, {error, closed}, State}; +handle_accept(_From, #state{accept_buffer = [{StreamId, Headers} | Rest]} = State) -> + {reply, {ok, StreamId, Headers}, State#state{accept_buffer = Rest}}; +handle_accept(From, State) -> + {noreply, State#state{waiting_acceptors = State#state.waiting_acceptors ++ [From]}}. + +send_fin(StreamId, Stream, State) -> + Data = espdy_frame:encode_data_frame(StreamId, ?FLAG_FIN, <<>>), + Self = self(), + WriteFun = fun(Result) -> + + case Result of + ok -> + Self ! {write_finished, {ok, StreamId, true, false}}; + Reason -> + Self ! {write_error, Reason} + end + end, + + socket_write(State, Data, WriteFun), + + Stream#stream{write_closing = true}. + +handle_close(StreamId, Stream, State) -> + send_fin(StreamId, Stream, State). + +syn_flags(WithFin, StreamId) -> + syn_flags_fin(WithFin) bor syn_flags_uni(StreamId). + +syn_flags_fin(true) -> + ?FLAG_FIN; +syn_flags_fin(_WithFin) -> + 0. + +syn_flags_uni(0) -> + 0; +syn_flags_uni(_) -> + ?FLAG_UNIDIRECTIONAL. + +handle_connect(Headers, WithFin, AssociatedStreamId, _From, State) -> + StreamId = State#state.next_syn_stream_id, + + Frame = espdy_frame:encode_syn_stream(State#state.zlib_write_nv_context, syn_flags(WithFin, AssociatedStreamId), StreamId, AssociatedStreamId, 0, Headers), + lager:debug("Sending connect frame StreamId: ~p AssociatedStreamId: ~p SynFlags: ~p Headers: ~p", [StreamId, AssociatedStreamId, syn_flags(WithFin, AssociatedStreamId), Headers]), + + socket_write(State, Frame), + % TODO: FIX fully closed streams here + State1 = update_stream(StreamId, #stream{syn_reply_sent = true, write_open = (not WithFin), read_open = (AssociatedStreamId =:= 0)}, State), + {reply, {ok, StreamId}, State1#state{next_syn_stream_id = StreamId + 2}}. + + +run_with_stream(StreamId, State, Function) -> + case dict:find(StreamId, State#state.streams) of + error -> + lager:debug("Stream not found ~p ~p", [StreamId, debug_streams(State)]), + {reply, {error, closed}, State}; + {ok, StreamRecord} -> + Function(StreamRecord) + end. + +is_write_open(#stream{write_closing = false, write_open = true}) -> + true; +is_write_open(_) -> + false. + +run_with_stream_write_open(StreamId, State, Function) -> + run_with_stream(StreamId, State, + fun(Stream) -> + case is_write_open(Stream) of + true -> Function(Stream); + _ -> {reply, {error, closed}, State} + end + end). + + + +is_closing(#state{closing = true}) -> + true; +is_closing(#state{other_side_closing = true}) -> + true; +is_closing(_) -> + false. + +handle_call({connect, Headers, WithFin, AssociatedStreamId}, From, State) -> + case is_closing(State) of + true -> + {reply, {error, closed_not_processed}, State}; + _ -> + handle_connect(Headers, WithFin, AssociatedStreamId, From, State) + end; +handle_call({close, StreamId}, _From, State) -> + run_with_stream(StreamId, State, fun(StreamRecord) -> + NewStream = handle_close(StreamId, StreamRecord, State), + {reply, ok, update_stream(StreamId, NewStream, State)} + end); + +handle_call(statistics, _From, State) -> + lager:debug("Stream status for statistics ~p", [debug_streams(State)]), + {reply, get_statistics(State), State}; + +handle_call(accept, From, State) -> + handle_accept(From, State); + +handle_call({send_headers, StreamId, Headers}, From, State) -> + run_with_stream_write_open(StreamId, State, fun(StreamRecord) -> + NewStream = handle_send_headers(State, StreamId, StreamRecord, Headers, From), + {noreply, update_stream(StreamId, NewStream, State)} + end); + +handle_call({send_and_close, StreamId, Binary}, From, State) -> + run_with_stream_write_open(StreamId, State, fun(StreamRecord) -> + {noreply, handle_send(StreamId, StreamRecord, Binary, true, From, State)} + end); +handle_call({send, StreamId, Binary}, From, State) -> + run_with_stream_write_open(StreamId, State, fun(StreamRecord) -> + {noreply, handle_send(StreamId, StreamRecord, Binary, false, From, State)} + end); + +handle_call({recv, StreamId, Length, _Timeout}, From, State) -> + run_with_stream(StreamId, State, + fun(StreamRecord) -> + handle_receive(StreamId, StreamRecord, Length, From, State) + end); + +handle_call(start, _From, State) -> + (State#state.socket):setopts([{active, once}, {packet, raw}, {mode, binary}]), + {reply, ok, State}; + +handle_call(Call, _From, State) -> + lager:debug("Received invalid call ~p", [Call]), + {reply, {error, unknown_call}, State}. + +handle_cast(graceful_close, #state{closing = true} = State) -> + {noreply, State}; +handle_cast(graceful_close, #state{closing = false} = State) -> + socket_write(State, espdy_frame:encode_goaway(State#state.last_stream_accepted)), + case dict:size(State#state.streams) of + 0 -> + {stop, normal, State}; + _ -> + {noreply, State#state{closing = true}} + end; + + +handle_cast(_Call, State) -> + {noreply, State}. + +is_ping_reply(ID, Role) -> + (ID rem 2 =:= 0) =:= (Role =:= server). + +handle_ping_reply(ID, State) -> + lager:debug("Received PING reply ~p", [ID]), + State. + +handle_ping(ID, State) -> + lager:debug("Received PING ~p", [ID]), + socket_write(State, espdy_frame:encode_ping(ID)), + State. + + +stream_error_for_open_stream(StreamId, Stream, Status, State) -> + socket_write(State, espdy_frame:encode_rst_stream(StreamId, Status)), + stream_abnormally_closed(StreamId, Stream, State). + +stream_error(StreamId, Status, State) -> + socket_write(State, espdy_frame:encode_rst_stream(StreamId, Status)), + State. + +is_flag_fin(Flags) -> + (Flags band ?FLAG_FIN) =:= ?FLAG_FIN. + +send_closed(From) -> + lager:debug("Sending {error, closed} to ~p ", [From]), + gen_server:reply(From, {error, closed}). + +stream_abnormally_closed(StreamId, Stream, State) -> + stream_read_closed(Stream), + stream_write_abnormally_closed(Stream), + NewStreams = dict:erase(StreamId, State#state.streams), + State#state{streams = NewStreams}. + +stream_write_abnormally_closed(Stream) -> + lists:foreach(fun send_closed/1, Stream#stream.send_waiters). + +stream_read_closed(Stream) -> + [send_closed(From) || {From, _Length} <- Stream#stream.recv_waiters]. + + + +add_stream_data(Stream, Data) -> + ReadBuffer = Stream#stream.read_buffer, + lager:debug("Reply To Receivers ~p ~p ~n", [Data, Stream#stream.recv_waiters]), + {NewBuffer, NewReceivers} = reply_to_receivers(<>, lists:reverse(Stream#stream.recv_waiters)), + Stream#stream{read_buffer = NewBuffer, recv_waiters = NewReceivers}. + +reply_to_receivers(Buffer, []) -> + {Buffer, []}; +reply_to_receivers(Binary, [{From, Length} | Rest] = Receivers) -> + case Binary of + <> -> + lager:debug("Reply To Receiver ~p ~p ~n", [From, Packet]), + gen_server:reply(From, Packet), + reply_to_receivers(RestOfBuffer, Rest); + _ -> + {Binary, Receivers} + end. + +update_open_stream(Stream, Flags) -> + case is_flag_fin(Flags) of + true -> + lager:debug("Closing Stream", []), + stream_read_closed(Stream), + Stream#stream{read_open = false, recv_waiters = []}; + _ -> + Stream + end. + + +handle_valid_data_frame(StreamId, #stream{read_open = false} = Stream, _Flags, _Data, State) -> + stream_error_for_open_stream(StreamId, Stream, ?INVALID_STREAM, State); + +handle_valid_data_frame(StreamId, Stream, Flags, Data, State) -> + Stream1 = add_stream_data(Stream, Data), + Stream2 = update_open_stream(Stream1, Flags), + check_if_stream_is_fully_closed_and_update_stream(StreamId, Stream2, State). + +new_stream(StreamId, Headers, #state{waiting_acceptors = [H|Rest]} = State) -> + gen_server:reply(H, {ok, StreamId, Headers}), + State#state{waiting_acceptors = Rest}; +new_stream(StreamId, Headers, State) -> + State#state{accept_buffer = State#state.accept_buffer ++ [{StreamId, Headers}]}. + +expected_stream_modulus(#state{role = server}) -> + 1; +expected_stream_modulus(#state{role = client}) -> + 0. + +can_accept_stream(State) -> + State#state.closing =:= false. + +is_valid_stream_id(State, StreamId) -> + (StreamId =/= 0) and (expected_stream_modulus(State) =:= (StreamId rem 2)) and (State#state.last_stream_accepted < StreamId). + +check_headers(StreamId, {error, Reason}, _Flags, State) -> + lager:debug("Receive invalid header ~p", [Reason]), + stream_error(StreamId, ?PROTOCOL_ERROR, State); + +check_headers(StreamId, Headers, Flags, State) -> + NewStreams = dict:store(StreamId, #stream{read_open = not is_flag_fin(Flags)}, State#state.streams), + new_stream(StreamId, Headers, State#state{streams = NewStreams, last_stream_accepted = StreamId}). + +accept_stream(State, StreamId, Flags, Headers) -> + + case is_valid_stream_id(State, StreamId) of + true -> + check_headers(StreamId, Headers, Flags, State); + false -> + stream_error(StreamId, ?PROTOCOL_ERROR, State), + State + end. + +refuse_stream(State, StreamId) -> + stream_error(StreamId, ?REFUSED_STREAM, State), + State. + +with_stream(StreamId, State, Fun) -> + case dict:find(StreamId, State#state.streams) of + error -> + lager:debug("Stream Error: INVALID STREAM", []), + stream_error(StreamId, ?INVALID_STREAM, State); + {ok, Stream} -> + Fun(Stream) + end. + +validate_header_frame(Stream, Headers) -> + case {Headers, Stream#stream.read_open} of + {{error, _Reason}, _} -> + {error, ?PROTOCOL_ERROR}; + {_, false } -> + {error, ?INVALID_STREAM}; + _ -> + ok + end. + +handle_frame(#control_frame{type = ?GOAWAY, data = <>}, State) -> + State#state{other_side_closing = true}; +handle_frame(#control_frame{type = ?HEADERS, data = + <>}, State) -> + %% must always decode headers + Headers = espdy_frame:decode_name_value_header_block(State#state.zlib_read_nv_context, NameValueHeaderBlock), + lager:debug("Receive headers ~p", [Headers]), + + with_stream(StreamId, State, fun(Stream) -> + case validate_header_frame(Stream, Headers) of + {error, Status} -> + stream_error(StreamId, Status, State); + _ -> + lager:debug("Received valid headers but we ignore headers..."), + State + end + end); + +handle_frame(#control_frame{type = ?NOOP}, State) -> + State; + +handle_frame(#control_frame{type = ?PING, data = <>}, State) -> + case is_ping_reply(ID, State#state.role) of + true -> handle_ping_reply(ID, State); + _ -> handle_ping(ID, State) + end; + +handle_frame(#control_frame{flags = Flags, type = ?SYN_STREAM, data = + <>}, State) -> + + lager:debug("Received SYN STREAM for ~p with NameValueHeaderBlock size ~p", [StreamId, byte_size(NameValueHeaderBlock)]), + % always need to decode headers to stop streams getting of sync + + Headers = espdy_frame:decode_name_value_header_block(State#state.zlib_read_nv_context, NameValueHeaderBlock), + + case can_accept_stream(State) of + true -> + accept_stream(State, StreamId, Flags, Headers); + false -> + refuse_stream(State, StreamId) + end; + +handle_frame(#control_frame{type = ?RST_STREAM, data = + <>}, State) -> + case dict:find(StreamId, State#state.streams) of + error -> + lager:debug("Received RST_STREAM for unknown stream: ~p (~p)", [StreamId, StatusCode]), + State; + {ok, Stream} -> + lager:debug("Received RST_STREAM for stream: ~p (~p)", [StreamId, StatusCode]), + stream_abnormally_closed(StreamId, Stream, State) + end; + +handle_frame(#data_frame{stream_id = StreamId, flags = Flags, data = Data}, State) -> + with_stream(StreamId, State, fun(Stream) -> + handle_valid_data_frame(StreamId, Stream, Flags, Data, State) + end); + +handle_frame(Frame, State) -> + lager:debug("Received Unknown Frame: ~p", [Frame]), + State. + +handle_data(Data, State0) -> + case espdy_frame:read_frame(State0#state.read_buffer, Data) of + {Frame, NewBuffer} -> + lager:debug("Received frame ~p", [Frame]), + State = handle_frame(Frame, State0), + handle_data(<<>>, State#state{read_buffer = NewBuffer}); + NewBuffer -> + lager:debug("Buffering... ~p", [NewBuffer]), + (State0#state.socket):setopts([{active, once}]), + State0#state{read_buffer = NewBuffer} + end. + +notify_write_finished(Stream, true) -> + case Stream#stream.send_waiters of + [Top | Rest] -> + gen_server:reply(Top, ok), + Stream#stream{send_waiters = Rest}; + _ -> + %WTF + Stream + end; + +notify_write_finished(Stream, false) -> + Stream. + +stream_finished(StreamId, _Stream, State) -> + lager:debug("Stream fully closed ~p", [StreamId]), + State#state{streams = dict:erase(StreamId, State#state.streams)}. + +check_if_stream_is_fully_closed_and_update_stream(StreamId, #stream{read_open = false, write_open = false} = Stream, State) -> + stream_finished(StreamId, Stream, State); +check_if_stream_is_fully_closed_and_update_stream(StreamId, Stream, State) -> + update_stream(StreamId, Stream, State). + +check_for_graceful_shutdown(State) -> + case (State#state.closing =:= true) and (dict:size(State#state.streams) =:= 0) of + true -> + lager:debug("scheduled shutdown", []), + {stop, normal, State}; + _ -> + {noreply, State} + end. + +handle_write_finished(StreamId, Stream0, Fin, Notify, State) -> + Stream1 = notify_write_finished(Stream0, Notify), + Stream2 = case Fin of + true -> + Stream1#stream{write_open = false}; + _ -> + Stream1 + end, + + check_if_stream_is_fully_closed_and_update_stream(StreamId, Stream2, State). + +handle_header_finished(StreamId, Stream, State) -> + NewStream = case Stream#stream.send_header_waiters of + [Top | Rest] -> + gen_server:reply(Top, ok), + Stream#stream{send_header_waiters = Rest}; + _ -> + %WTF + Stream + end, + + update_stream(StreamId, NewStream, State). + + +handle_info({header_finished, {ok, StreamId}}, State) -> + lager:debug("header_finished ~p", [StreamId]), + case dict:find(StreamId, State#state.streams) of + {ok, Stream} -> + {noreply, handle_header_finished(StreamId, Stream, State)}; + _ -> + lager:debug("Received header_finished notification for stream that no longer exists ~p", [StreamId]), + {noreply, State} + end; +handle_info({write_error, Reason}, State) -> + lager:debug("Received write_error notification ~p", [Reason]), + {stop, normal, State}; + +handle_info({write_finished, {ok, StreamId, Fin, Notify}}, State) -> + lager:debug("write_finished ~p", [StreamId]), + case dict:find(StreamId, State#state.streams) of + {ok, Stream} -> + State1 = handle_write_finished(StreamId, Stream, Fin, Notify, State), + lager:debug("Stream status ~p", [debug_streams(State1)]), + check_for_graceful_shutdown(State1); + _ -> + lager:debug("Received write_finished notification for stream that no longer exists ~p", [StreamId]), + {noreply, State} + end; + +handle_info({CloseTag, _Socket}, #state{close_tag = CloseTag} = State0) -> + lager:debug("Received closed ~p", [self()]), + %% mmmm... i think this is wrong. need to add tests for half closed tcp + + {stop, normal, State0}; + +handle_info({DataTag, _Socket, Data}, #state{data_tag = DataTag} = State0) -> + lager:debug("Received data ~p Buffer ~p ~n", [Data, State0#state.read_buffer]), + State = handle_data(Data, State0), + lager:debug("Stream status ~p", [debug_streams(State)]), + check_for_graceful_shutdown(State); + + +handle_info(Msg, State) -> + lager:debug("Received unknown message ~p", [Msg]), + {noreply, State}. + +clean_up_listeners(Stream) -> + stream_read_closed(Stream), + stream_write_abnormally_closed(Stream). + +dict_for_each_value(Fun, Dict) -> + dict:fold( + fun(_K, V, Acc) -> + Fun(V), + Acc + end, 0, Dict). + +clean_up_all_listeners(State) -> + dict_for_each_value(fun clean_up_listeners/1, State#state.streams). + + +get_statistics(State) -> + [RO, WO, BO, Z] = dict:fold( + fun (_K, #stream{read_open = true} = Stream, [ReadOpen, WriteOpen, BothOpen, Zombies]) -> + case is_write_open(Stream) of + true -> + [ReadOpen, WriteOpen, BothOpen + 1, Zombies]; + _ -> + [ReadOpen + 1, WriteOpen, BothOpen, Zombies] + end; + (_K, #stream{read_open = false} = Stream , [ReadOpen, WriteOpen, BothOpen, Zombies]) -> + case is_write_open(Stream) of + true -> + [ReadOpen, WriteOpen + 1, BothOpen, Zombies]; + _ -> + [ReadOpen, WriteOpen, BothOpen, Zombies + 1] + end + end, [0, 0, 0, 0], State#state.streams), + [{read_open, RO}, {write_open, WO}, {both_open, BO}, {zombies, Z}]. + + +terminate(_Reason, State) -> + zlib:close(State#state.zlib_read_nv_context), + zlib:close(State#state.zlib_write_nv_context), + clean_up_all_listeners(State), + ok. + +code_change(_OldVsn, State, _Extra) -> + {ok, State}. \ No newline at end of file diff --git a/src/espdy_ssl_socket.erl b/src/espdy_ssl_socket.erl new file mode 100644 index 0000000..0209b0f --- /dev/null +++ b/src/espdy_ssl_socket.erl @@ -0,0 +1,27 @@ +-module(espdy_ssl_socket, [Socket]). + +-export([send/1, close/0, shutdown/0, setopts/1, controlling_process/1, data_tag/0, close_tag/0, error_tag/0]). + +data_tag() -> + ssl. + +close_tag() -> + ssl_closed. + +error_tag() -> + ssl_error. + +send(Packet) -> + ssl:send(Socket, Packet). + +close() -> + ssl:close(Socket). + +setopts(Options) -> + ssl:setopts(Socket, Options). + +shutdown() -> + ssl:shutdown(Socket). + +controlling_process(Pid) -> + ssl:controlling_process(Socket, Pid). \ No newline at end of file diff --git a/src/espdy_tcp_socket.erl b/src/espdy_tcp_socket.erl new file mode 100644 index 0000000..6cf9d08 --- /dev/null +++ b/src/espdy_tcp_socket.erl @@ -0,0 +1,27 @@ +-module(espdy_tcp_socket, [Socket]). + +-export([send/1, close/0, shutdown/0, setopts/1, controlling_process/1, data_tag/0, error_tag/0, close_tag/0]). + +close_tag() -> + tcp_closed. + +error_tag() -> + tcp_error. + +data_tag() -> + tcp. + +send(Packet) -> + gen_tcp:send(Socket, Packet). + +close() -> + gen_tcp:close(Socket). + +setopts(Options) -> + inet:setopts(Socket, Options). + +shutdown() -> + gen_tcp:shutdown(Socket). + +controlling_process(Pid) -> + gen_tcp:controlling_process(Socket, Pid). \ No newline at end of file diff --git a/src/espdy_test_server.erl b/src/espdy_test_server.erl new file mode 100644 index 0000000..3dbff65 --- /dev/null +++ b/src/espdy_test_server.erl @@ -0,0 +1,133 @@ +-module(espdy_test_server). +-compile([{parse_transform, lager_transform}]). + +-export([start/0]). + +start() -> + lager:start(), + lager:set_loglevel(lager_console_backend, debug), + crypto:start(), + ssl:start(), + {ok, Listen} = ssl:listen(8443, [ + {keyfile, "server.key"}, + {certfile, "server.crt"}, + {reuseaddr, true}, + {ssl_imp, new}, + binary, + {packet, raw}, + {active, false} + ]), + + lager:info("Listening on port 80", []), + + loop(Listen). + +loop(Listen) -> + case ssl:transport_accept(Listen) of + {ok, Socket} -> + case ssl:ssl_accept(Socket) of + ok -> + spawn(fun() -> loop(Listen) end), + accept(Socket); + {error, Reason} -> + lager:info("Error ssl_accept socket", [Reason]), + loop(Listen) + end; + {error, Reason} -> + lager:info("Error accepting socket", [Reason]), + loop(Listen) + end. + +accept(Socket) -> + lager:info("Accepted Socket ~p", [Socket]), + {ok, Pid} = espdy_server:start_link(espdy_ssl_socket:new(Socket), server), + espdy_loop(Pid). + +espdy_loop(Pid) -> + {ok, Socket, Headers} = espdy_server:accept(Pid), + spawn(fun() -> espdy_loop(Pid) end), + espdy_accept(Socket, Headers). + +get_header(Header, Headers) -> + + case lists:keyfind(Header, 1, Headers) of + {_, [Url]} -> {ok, Url}; + _ -> notfound + end. + +espdy_accept(Socket, Headers) -> + io:format(user, "Headers ~p", [Headers]), + ResponseHeaders = [ + {<<"status">>, [<<"200 OK">>]}, + {<<"version">>, [<<"HTTP/1.1">>]}, + {<<"Set-Cookie">>, [<<"cookie1=value1">>, <<"cookie2=value2">>]} + ], + + + case get_header(<<"url">>, Headers) of + {ok, <<"/style.css">>} -> + timer:sleep(1000), + espdy_server:send_headers(Socket, ResponseHeaders), + espdy_server:send(Socket, css_file1()), + espdy_server:close(Socket); + {ok, <<"/style2.css">>} -> + espdy_server:send_headers(Socket, ResponseHeaders), + timer:sleep(500), + espdy_server:send(Socket, css_file2()), + espdy_server:close(Socket); + {ok, <<"/">>} -> + espdy_server:send_headers(Socket, ResponseHeaders), + {ok, Scheme} = get_header(<<"scheme">>, Headers), + {ok, Host} = get_header(<<"host">>, Headers), + Url = list_to_binary([Scheme, "://", Host, "/style3.css"]), + + PushHeaders = [{<<"status">>, <<"200 OK">>}, {<<"version">>, <<"HTTP/1.1">>}, {<<"url">>, Url}, {<<"content-type">>, <<"text/css">>}], + {ok, PushSocket} = espdy_server:connect_for_push(Socket, false, simple_headers(PushHeaders)), + espdy_server:send_and_close(PushSocket, css_file3()), + espdy_server:send_and_close(Socket, html_page()); + _ -> + espdy_server:send_headers(Socket, simple_headers([{<<"status">>, <<"404 Not Found">>}, {<<"version">>, <<"HTTP/1.1">>}])), + espdy_server:close(Socket) + end. + +simple_header({Name, Value}) -> + {Name, [Value]}. + +simple_headers(Headers) -> + lists:map(fun simple_header/1, Headers). + +css_file1() -> + <<".css1 {}">>. +css_file2() -> + <<".css2 {}">>. +css_file3() -> + <<".css3 {}">>. + + +html_page() -> + <<"" + "" + " " + "" + "" + "hello from erlang" + "" + "">>. + diff --git a/test/espdy_frame_test.erl b/test/espdy_frame_test.erl new file mode 100644 index 0000000..877443b --- /dev/null +++ b/test/espdy_frame_test.erl @@ -0,0 +1,131 @@ +-module(espdy_frame_test). +-include_lib("eunit/include/eunit.hrl"). +-include("espdy_frame.hrl"). + +decode_name_value_header_block_test() -> + Headers = espdy_frame:decode_name_value_header_block(<< + ?UINT16(2), + ?UINT16(5), "name1", ?UINT16(4), "valu", + ?UINT16(4), "nam2", ?UINT16(3), "val" + >>), + ?assertEqual(lists:sort([{<<"name1">>, [<<"valu">>]}, {<<"nam2">>, [<<"val">>]}]), lists:sort(Headers)). + +decode_name_value_header_block_with_empty_value_test() -> + Headers = espdy_frame:decode_name_value_header_block(<< + ?UINT16(1), + ?UINT16(5), "name1", ?UINT16(0) + >>), + ?assertMatch({error, _Reason}, Headers). + +decode_name_value_header_block_with_multiple_values_test() -> + Headers = espdy_frame:decode_name_value_header_block(<< + ?UINT16(1), + ?UINT16(5), "name1", ?UINT16(3), "h", 0, "i" + >>), + ?assertEqual([{<<"name1">>, [<<"h">>, <<"i">>]}], Headers). + +decode_name_value_when_starting_with_null_value_test() -> + Headers = espdy_frame:decode_name_value_header_block(<< + ?UINT16(1), + ?UINT16(5), "name1", ?UINT16(2), 0, "h" + >>), + ?assertMatch({error, _Reason}, Headers). + +decode_name_value_when_ending_with_null_value_test() -> + Headers = espdy_frame:decode_name_value_header_block(<< + ?UINT16(1), + ?UINT16(5), "name1", ?UINT16(2), "h", 0 + >>), + ?assertMatch({error, _Reason}, Headers). + +decode_name_value_when_ending_with_null_value2_test() -> + Headers = espdy_frame:decode_name_value_header_block(<< + ?UINT16(1), + ?UINT16(5), "name1", ?UINT16(4), "foo", 0 + >>), + ?assertMatch({error, _Reason}, Headers). + +decode_name_value_with_empty_value_test() -> + Headers = espdy_frame:decode_name_value_header_block(<< + ?UINT16(1), + ?UINT16(5), "name1", ?UINT16(4), "h", 0, 0, "i" + >>), + ?assertMatch({error, _Reason}, Headers). + +decode_name_value_with_empty_header_name_test() -> + Headers = espdy_frame:decode_name_value_header_block(<< + ?UINT16(1), + ?UINT16(0), ?UINT16(5), "hello" + >>), + ?assertMatch({error, _Reason}, Headers). + +encode_name_value_header_block_test() -> + Frame = espdy_frame:encode_name_value_header_block([{<<"hello">>, [<<"world">>]}, {<<"foo">>, [<<"one">>, <<"two">>]}]), + Expected = <>, + ?assertEqual(Expected, Frame). + +encode_name_value_header_block_with_empty_test() -> + Frame = espdy_frame:encode_name_value_header_block([{<<"hello">>, []}]), + Expected = <>, + ?assertEqual(Expected, Frame). + +encode_name_value_header_block_with_empty_value_test() -> + ?assertException(throw, empty_header_value, espdy_frame:encode_name_value_header_block([{<<"hello">>, [<<>>]}])). + + +encode_then_decode_compressed_name_value_header_block_test() -> + ZLibDeflate = espdy_frame:initialize_zlib_for_deflate(), + Frame = list_to_binary(espdy_frame:encode_name_value_header_block(ZLibDeflate, [{<<"host">>, [<<"hostname">>]}])), + ZLibInflate = espdy_frame:initialize_zlib_for_inflate(), + Headers = espdy_frame:decode_name_value_header_block(ZLibInflate, Frame), + ?assertEqual([{<<"host">>, [<<"hostname">>]}], Headers). + + +decode_ping_frame_test() -> + Frame = espdy_frame:encode_ping(1), + {DecodedFrame, Data} = espdy_frame:read_frame(<<>>, Frame), + ?assertEqual(<<>>, Data), + ?assertEqual(#control_frame{version = ?SPDY_VERSION, type = ?PING, flags = 0, data = <<0,0,0,1>>}, DecodedFrame). + +decode_split_ping_frame_test() -> + Frame = espdy_frame:encode_ping(1), + <> = Frame, + {DecodedFrame, Data} = espdy_frame:read_frame(Part1, Part2), + ?assertEqual(<<>>, Data), + ?assertEqual(#control_frame{version = ?SPDY_VERSION, type = ?PING, flags = 0, data = <<0,0,0,1>>}, DecodedFrame). + +chrome_frame_test() -> + Buffer = <<128,2,0,1,1,0,1,31,0,0,0,1,0,0,0,0,0,0,56,234,223,162,81,178,98,224,102,96,131,164,23,6,123,184,11,117,48,44,214,174,64,23,205,205,177,46,180,53,208,179,212,209,210,215,2,179,44,24,248,80,115,44,131,156,103,176,63,212,61,58,96,7,129,213,153,235,64,212,27,51,240,163,229,105,6,65,144,139,117,160,78,214,41,78,73,206,128,171,129,37,3,6,190,212,60,221,208,96,157,212,60,168,165,188,40,137,141,129,19,26,36,182,6,12,44,160,220,207,192,9,74,34,57,96,38,91,46,176,192,201,79,97,96,118,119,13,97,96,43,6,106,203,77,5,170,42,41,41,96,96,6,133,5,163,62,3,23,34,3,51,148,250,230,87,101,230,228,36,234,155,234,25,40,104,0,228,155,152,156,153,87,146,95,156,97,173,224,9,76,83,57,10,64,1,5,255,96,133,8,5,67,131,120,243,120,3,77,5,71,96,240,164,134,167,38,121,103,150,232,155,26,155,234,25,42,104,120,123,132,248,250,232,40,228,100,102,167,42,184,167,38,103,231,107,42,56,103,0,203,165,84,125,67,19,61,160,235,129,170,128,101,131,66,112,98,90,98,81,38,68,19,3,59,52,118,24,56,96,145,6,0,0,0,255,255>> , + {Frame, <<>>} = espdy_frame:decode_frame(Buffer), + DecodedFrame = espdy_frame:decode_control_frame(Frame). + +encode_two_syn_replies_test() -> + Headers = [ + {<<"status">>, [<<"200 OK">>]}, + {<<"version">>, [<<"HTTP/1.1">>]} + ], + + ZLib = espdy_frame:initialize_zlib_for_deflate(), + Syn1 = espdy_frame:encode_syn_reply(ZLib, 1, 0, Headers), + + Syn2 = espdy_frame:encode_syn_reply(ZLib, 3, 0, Headers), + + Inflate = espdy_frame:initialize_zlib_for_inflate(), + + DecodeFrame1 = espdy_frame:decode_exactly_one_frame(Syn1), + Syn1Decoded = espdy_frame:decode_control_frame(Inflate, DecodeFrame1), + + DecodeFrame2 = espdy_frame:decode_exactly_one_frame(Syn2), + Syn2Decoded = espdy_frame:decode_control_frame(Inflate, DecodeFrame2). + +decode_our_headers_test() -> + ZLib = espdy_frame:initialize_zlib_for_inflate(), + Buffer = <<128,2,0,2,0,0,0,39,0,0,0,1,0,0,120,187,223,162,81,178,98,96,98,96,131,8,50,176,1,19,171,130,191,55,3,59,84,154,129,3,166,11,0,0,0,255,255>>, + Frame = espdy_frame:decode_exactly_one_frame(Buffer), + DecodedFrame = espdy_frame:decode_control_frame(ZLib, Frame), + + Buffer2 = <<128,2,0,2,0,0,0,14,0,0,0,3,0,0,34,74,17,0,0,0,255,255>>, + Frame2= espdy_frame:decode_exactly_one_frame(Buffer2), + DecodedFrame2 = espdy_frame:decode_control_frame(ZLib, Frame2). + + diff --git a/test/espdy_server_test.erl b/test/espdy_server_test.erl new file mode 100644 index 0000000..fc7fb9f --- /dev/null +++ b/test/espdy_server_test.erl @@ -0,0 +1,706 @@ +-module(espdy_server_test). +-compile([{parse_transform, lager_transform}]). + +-include_lib("eunit/include/eunit.hrl"). +-include("espdy_frame.hrl"). +-include("espdy_server.hrl"). + +-define(TEST(X), {setup, local, fun setup/0, fun tear_down/1, {atom_to_list(X), fun X/0}}). + +% there has to be an easier way of doing this... right ? + +packet_sim_test_() -> + {spawn, + [ + ?TEST(close_after_send_tst), + ?TEST(syn_reply_tst), + ?TEST(send_tst), + ?TEST(rst_while_waiting_for_read_tst), + ?TEST(rst_read_tst), + ?TEST(flag_fin_read_ok_tst), + ?TEST(flag_fin_tst), + ?TEST(flag_fin_during_tst), + ?TEST(syn_stream_tst), + ?TEST(syn_stream_buffer_tst), + ?TEST(syn_stream_blocking_tst), + ?TEST(server_ping_tst), + ?TEST(client_ping_tst), + ?TEST(server_receive_ping_reply_tst), + ?TEST(client_receive_ping_reply_tst), + ?TEST(statistics_tst), + ?TEST(statistics_read_open_tst), + ?TEST(handle_accept_queue_tst), + ?TEST(handle_receive_write_after_close_from_other_side_tst), + ?TEST(handle_non_existant_data_frame_tst), + ?TEST(handle_write_finished_for_stream_that_no_longer_exists_tst), + ?TEST(handle_header_finished_for_stream_that_no_longer_exists_tst), + ?TEST(receive_closed_when_the_underlying_connection_is_closed_beneath_us_tst), + ?TEST(send_headers_tst), + ?TEST(handle_split_frame_tst), + ?TEST(noop_tst), + ?TEST(receive_protocol_error_for_invalid_syn_tst), + ?TEST(receive_protocol_error_for_invalid_syn_when_zero_tst), + ?TEST(receive_protocol_error_for_invalid_syn_when_going_backwards_tst), + ?TEST(graceful_close_closes_server_when_there_are_no_connections_tst), + ?TEST(graceful_close_closes_server_when_last_connection_is_closed_tst), + ?TEST(graceful_close_close_twice_tst), + ?TEST(unknown_cast_tst), + ?TEST(unknown_call_tst), + ?TEST(statistics_zombie_tst), + ?TEST(zero_length_name_receives_rst_tst), + ?TEST(zero_length_value_receives_rst_tst), + ?TEST(zero_length_multi_value_receives_rst_tst), + ?TEST(zero_length_name_receives_rst_when_receiving_headers_tst), + ?TEST(receive_rst_after_sending_on_closed_channel_tst), + ?TEST(refuse_new_connections_after_receiving_goaway_tst), + ?TEST(refuse_new_connects_after_sending_goaway_tst), + ?TEST(refuse_new_accepts_after_sending_goaway_tst), + ?TEST(connect_and_recv_tst), + ?TEST(connect_for_push_and_get_error_closed_when_recving_tst), + ?TEST(connect_and_send_syn_fin_tst) + + ] + } + . + +setup() -> + application:start(lager), + lager:set_loglevel(lager_console_backend, error). + +tear_down(_) -> + check_for_messages(). + +check_for_messages() -> + receive + Msg -> + lager:info("Received message ~p", [Msg]), + ?assert(false) + after 100 -> + ok + end. + +start_socket(Role) -> + Socket = espdy_mock_socket:new(self()), + {ok, Pid} = espdy_server:start_link(Socket, Role), + Pid. + +send_syn_stream(ZLib, Pid, StreamId, Headers) -> + Frame = espdy_frame:encode_syn_stream(ZLib, 0, StreamId, 0, 0, Headers), + send_frame(Pid, Frame). + +send_syn_stream(ZLib, Pid, StreamId) -> + send_syn_stream(ZLib, Pid, StreamId, []). + +send_syn_stream(Pid, StreamId) -> + Frame = espdy_frame:encode_syn_stream(0, StreamId, 0, 0, []), + send_frame(Pid, Frame). + +start_channel(ZLib, Pid, N) -> + send_syn_stream(ZLib, Pid, N), + {ok, Socket, _Headers} = espdy_server:accept(Pid), + Socket. + +start_channel(Pid, N) -> + send_syn_stream(Pid, N), + {ok, Socket, _Headers} = espdy_server:accept(Pid), + Socket. + +start_channel_one(Pid) -> + start_channel(Pid, 1). + +start_channel_one() -> + start_channel_one(start_socket(server)). + + +receive_packet() -> + receive + {packet, PacketPid, Packet} -> + PacketPid ! {mock_packet_result, ok}, + timer:sleep(50) % DODGY HACK + end, + Packet. + +receive_data_frame() -> + receive + {packet, PacketPid, Packet} -> PacketPid ! {mock_packet_result, ok} + end, + Frame = espdy_frame:decode_exactly_one_frame(Packet), + ?assertMatch(#data_frame{}, Frame), + Frame. + +receive_control_frame() -> + receive + {packet, PacketPid, Packet} -> PacketPid ! {mock_packet_result, ok} + end, + Frame = espdy_frame:decode_exactly_one_frame(Packet), + ?assertMatch(#control_frame{}, Frame), + Frame. + + + +send_frame(#espdy_socket{pid = Pid}, Packet) -> + Pid ! {tcp, undefined, Packet}; + +send_frame(Pid, Packet) -> + Pid ! {tcp, undefined, Packet}. + +assert_dont_receive_packet() -> + receive + {packet, _, _} -> ?assert(false) + after 10 -> + ok + end. + +receive_response(Ref) -> + receive + {Ref, Resp} -> ok + end, + Resp. + + + +statistics_zombie_tst() -> + Pid = start_socket(server), + Socket = start_channel_one(Pid), + + send_frame(Socket, espdy_frame:encode_data_frame(1, ?FLAG_FIN, <<>>)), + + Ref1 = espdy_server:async_send_and_close(Socket, <<"hello">>), + ?assertMatch(#control_frame{type = ?SYN_REPLY}, receive_control_frame()), + + % ZOOOMBIIIESSS. this probably shouldn't be a zombie. but the idea is we need to keep + % streams around until they have written all their crap out. + ?assertEqual([{read_open, 0}, {write_open, 0}, {both_open, 0}, {zombies, 1}], espdy_server:statistics(Pid)), + ?assertMatch(#data_frame{flags = ?FLAG_FIN}, receive_data_frame()), + ok = receive_response(Ref1). + + +statistics_read_open_tst() -> + Pid = start_socket(server), + Socket = start_channel_one(Pid), + + ?assertEqual([{read_open, 0}, {write_open, 0}, {both_open, 1}, {zombies, 0}], espdy_server:statistics(Pid)), + + Ref1 = espdy_server:async_send_and_close(Socket, <<"hello">>), + ?assertMatch(#control_frame{type = ?SYN_REPLY}, receive_control_frame()), + ?assertMatch(#data_frame{flags = ?FLAG_FIN}, receive_data_frame()), + ok = receive_response(Ref1), + + ?assertEqual([{read_open, 1}, {write_open, 0}, {both_open, 0}, {zombies, 0}], espdy_server:statistics(Pid)). + + +statistics_tst() -> + Pid = start_socket(server), + Socket = start_channel_one(Pid), + + ?assertEqual([{read_open, 0}, {write_open, 0}, {both_open, 1}, {zombies, 0}], espdy_server:statistics(Pid)), + + send_frame(Socket, espdy_frame:encode_data_frame(1, ?FLAG_FIN, <<>>)), + + ?assertEqual([{read_open, 0}, {write_open, 1}, {both_open, 0}, {zombies, 0}], espdy_server:statistics(Pid)), + + Ref1 = espdy_server:async_send_and_close(Socket, <<"hello">>), + + ?assertMatch(#control_frame{type = ?SYN_REPLY}, receive_control_frame()), + ?assertMatch(#data_frame{flags = ?FLAG_FIN}, receive_data_frame()), + + ok = receive_response(Ref1), + + ?assertEqual([{read_open, 0}, {write_open, 0}, {both_open, 0}, {zombies, 0}], espdy_server:statistics(Pid)) + + . + +close_after_send_tst() -> + + Socket = start_channel_one(), + Ref1 = espdy_server:async_send(Socket, <<"hello">>), + + %SYN_REPLY + receive_control_frame(), + + receive_data_frame(), + ok = receive_response(Ref1), + + Ref2 = espdy_server:async_close(Socket), + + DataFrame = receive_data_frame(), + ?assertEqual(?FLAG_FIN, DataFrame#data_frame.flags), + + ok = receive_response(Ref2). + + +syn_reply_tst() -> + Headers = [{<<"foo">>, [<<"bar">>]}], + + Socket = start_channel_one(), + Ref1 = espdy_server:async_send_headers(Socket, Headers), + ControlFrame = receive_control_frame(), + SynReply = espdy_frame:decode_control_frame(ControlFrame), + ?assertEqual(1, SynReply#syn_reply.stream_id), + ?assertEqual(Headers, SynReply#syn_reply.headers), + ?assertEqual(0, SynReply#syn_reply.flags), + + ok = receive_response(Ref1). + +send_tst() -> + + Socket = start_channel_one(), + Ref1 = espdy_server:async_send(Socket, <<"helloworld">>), + Ref2 = espdy_server:async_send(Socket, <<"lollercopter">>), + + ControlFrame = receive_control_frame(), + SynReply = espdy_frame:decode_control_frame(ControlFrame), + ?assertEqual(#syn_reply{stream_id = 1, headers = [], flags = 0}, SynReply), + + DataFrame = receive_data_frame(), + ?assertEqual(1, DataFrame#data_frame.stream_id), + ?assertEqual(<<"helloworld">>, DataFrame#data_frame.data), + + ?assertEqual(#data_frame{stream_id = 1, data = <<"lollercopter">>, flags = 0}, receive_data_frame()), + + ok = receive_response(Ref1), + ok = receive_response(Ref2). + +rst_while_waiting_for_read_tst() -> + + Socket = start_channel_one(), + DataFrame = espdy_frame:encode_data_frame(1, 0, <<"1234">>), + RstFrame = espdy_frame:encode_rst_stream(1, ?PROTOCOL_ERROR), + + send_frame(Socket, DataFrame), + + Ref1 = espdy_server:async_recv(Socket, 6), + + send_frame(Socket, RstFrame), + + {error, closed} = receive_response(Ref1). + +rst_read_tst() -> + + Socket = start_channel_one(), + + DataFrame = espdy_frame:encode_data_frame(1, 0, <<"12345">>), + + send_frame(Socket, DataFrame), + RstFrame = espdy_frame:encode_rst_stream(1, ?PROTOCOL_ERROR), + + send_frame(Socket, RstFrame), + + ?assertEqual({error, closed}, espdy_server:recv(Socket, 4)). + + +flag_fin_read_ok_tst() -> + + Socket = start_channel_one(), + DataFrame = espdy_frame:encode_data_frame(1, ?FLAG_FIN, <<"1234">>), + send_frame(Socket, DataFrame), + + ?assertEqual(<<"1234">>, espdy_server:recv(Socket, 4)). + +flag_fin_tst() -> + + Socket = start_channel_one(), + DataFrame = espdy_frame:encode_data_frame(1, ?FLAG_FIN, <<"1234">>), + send_frame(Socket, DataFrame), + + ?assertEqual({error, closed}, espdy_server:recv(Socket, 6)). + +flag_fin_during_tst() -> + + Socket = start_channel_one(), + DataFrame = espdy_frame:encode_data_frame(1, 0, <<"1234">>), + send_frame(Socket, DataFrame), + Ref1 = espdy_server:async_recv(Socket, 6), + + DataFrame2 = espdy_frame:encode_data_frame(1, ?FLAG_FIN, <<"5">>), + send_frame(Socket, DataFrame2), + ?assertEqual({error, closed}, receive_response(Ref1)). + +syn_stream_tst() -> + + Socket = start_channel_one(), + DataFrame = espdy_frame:encode_data_frame(1, 0, <<"helloxworld">>), + send_frame(Socket, DataFrame), + + Binary = espdy_server:recv(Socket, 6), + ?assertEqual(<<"hellox">>, Binary), + Binary2 = espdy_server:recv(Socket, 5), + ?assertEqual(<<"world">>, Binary2). + +syn_stream_buffer_tst() -> + + Socket = start_channel_one(), + + DataFrame = espdy_frame:encode_data_frame(1, 0, <<"1234">>), + DataFrame2 = espdy_frame:encode_data_frame(1, 0, <<"5678">>), + send_frame(Socket, DataFrame), + send_frame(Socket, DataFrame2), + + Binary = espdy_server:recv(Socket, 6), + ?assertEqual(<<"123456">>, Binary), + Binary2 = espdy_server:recv(Socket, 2), + ?assertEqual(<<"78">>, Binary2). + +syn_stream_blocking_tst() -> + + Socket = start_channel_one(), + + Ref1 = espdy_server:async_recv(Socket, 6), + Ref2 = espdy_server:async_recv(Socket, 5), + + DataFrame = espdy_frame:encode_data_frame(1, 0, <<"helloxworld">>), + + send_frame(Socket, DataFrame), + + ?assertEqual(<<"hellox">>, receive_response(Ref1)), + ?assertEqual(<<"world">>, receive_response(Ref2)). + + +receive_accept(Pid, Ref) -> + {ok, StreamId, Headers} = receive_response(Ref), + {ok, #espdy_socket{pid = Pid, stream_id = StreamId}, Headers}. + +handle_accept_queue_tst() -> + Socket = start_socket(server), + + Ref1 = espdy_server:async_accept(Socket), + Ref2 = espdy_server:async_accept(Socket), + + ZLib = espdy_frame:initialize_zlib_for_deflate(), + send_syn_stream(ZLib, Socket, 1), + send_syn_stream(ZLib, Socket, 3), + zlib:close(ZLib), + + {ok, Socket1, []} = receive_accept(Socket, Ref1), + ?assertEqual(1, espdy_server:stream_id(Socket1)), + {ok, Socket2, []} = receive_accept(Socket, Ref2), + ?assertEqual(3, espdy_server:stream_id(Socket2)). + +receive_rst_stream(StreamId, Status) -> + ControlFrame = receive_control_frame(), + ?assertMatch(#control_frame{type = ?RST_STREAM}, ControlFrame), + ?assertMatch(#rst_stream{stream_id = StreamId, status = Status}, espdy_frame:decode_control_frame(ControlFrame)). + +handle_receive_write_after_close_from_other_side_tst() -> + Socket = start_channel_one(), + send_frame(Socket, espdy_frame:encode_data_frame(1, ?FLAG_FIN, <<>>)), + send_frame(Socket, espdy_frame:encode_data_frame(1, ?FLAG_FIN, <<>>)), + + receive_rst_stream(1, ?INVALID_STREAM), + + {error, closed} = espdy_server:send(Socket, <<"hello">>). + +handle_non_existant_data_frame_tst() -> + Socket = start_channel_one(), + send_frame(Socket, espdy_frame:encode_data_frame(2, ?FLAG_FIN, <<>>)), + receive_rst_stream(2, ?INVALID_STREAM). + + +handle_write_finished_for_stream_that_no_longer_exists_tst() -> + % not sure if this is actually possible.. but we have tests for it... + Socket = start_socket(server), + + Socket ! {write_finished, {ok, 1, false, true}}. + +handle_header_finished_for_stream_that_no_longer_exists_tst() -> + % not sure if this is actually possible.. but we have tests for it... + Socket = start_socket(server), + + Socket ! {header_finished, {ok, 1}}. + +tcp_close(#espdy_socket{pid = Pid}) -> + Pid ! {tcp_closed, undefined}. + +receive_closed_when_the_underlying_connection_is_closed_beneath_us_tst() -> + Socket = start_channel_one(), + Ref1 = espdy_server:async_recv(Socket, 5), + tcp_close(Socket), + {error, closed} = receive_response(Ref1). + +receive_syn_reply(ZLib) -> + ControlFrame = receive_control_frame(), + SynReply = espdy_frame:decode_control_frame(ZLib, ControlFrame), + ?assertMatch(#syn_reply{}, SynReply), + SynReply. + +receive_headers(ZLib) -> + ControlFrame = receive_control_frame(), + Headers = espdy_frame:decode_control_frame(ZLib, ControlFrame), + ?assertMatch(#headers{}, Headers), + Headers. + +send_headers_tst() -> + Socket = start_channel_one(), + FirstHeaders = [{<<"foo">>, [<<"bar">>]}], + Ref1 = espdy_server:async_send_headers(Socket, FirstHeaders), + ZLib = espdy_frame:initialize_zlib_for_inflate(), + ?assertMatch(#syn_reply{headers = FirstHeaders}, receive_syn_reply(ZLib)), + + SecondHeaders = [{<<"bar">>, [<<"foo">>]}], + Ref2 = espdy_server:async_send_headers(Socket, SecondHeaders), + ?assertMatch(#headers{headers = SecondHeaders}, receive_headers(ZLib)), + + zlib:close(ZLib), + + ok = receive_response(Ref1), + ok = receive_response(Ref2). + +handle_split_frame_tst() -> + Socket = start_socket(server), + Frame = espdy_frame:encode_syn_stream(0, 1, 0, 0, []), + <> = Frame, + send_frame(Socket, Part1), + send_frame(Socket, Part2), + {ok, _StreamSock, []} = espdy_server:accept(Socket). + +noop_tst() -> + Socket = start_socket(server), + Noop = espdy_frame:encode_noop(), + send_frame(Socket, Noop). + +receive_protocol_error_for_invalid_syn_tst() -> + Socket = start_socket(server), + ZLib = espdy_frame:initialize_zlib_for_deflate(), + Headers = [{<<"hello">>, [<<"world">>]}], + + send_syn_stream(ZLib, Socket, 2, Headers), + + + + ControlFrame = receive_control_frame(), + RstStream = espdy_frame:decode_control_frame(ControlFrame), + ?assertMatch(#rst_stream{stream_id = 2, status = ?PROTOCOL_ERROR}, RstStream), + + send_syn_stream(ZLib, Socket, 3, Headers), + + {ok, Socket1, Headers} = espdy_server:accept(Socket). + +receive_protocol_error_for_invalid_syn_when_zero_tst() -> + Socket = start_socket(client), + Frame = espdy_frame:encode_syn_stream(0, 0, 0, 0, []), + send_frame(Socket, Frame), + + ControlFrame = receive_control_frame(), + RstStream = espdy_frame:decode_control_frame(ControlFrame), + ?assertMatch(#rst_stream{stream_id = 0, status = ?PROTOCOL_ERROR}, RstStream). + + +receive_protocol_error_for_invalid_syn_when_going_backwards_tst() -> + Socket = start_socket(server), + ZLib = espdy_frame:initialize_zlib_for_deflate(), + Socket3 = start_channel(ZLib, Socket, 3), + send_syn_stream(ZLib, Socket, 1, []), + + ControlFrame = receive_control_frame(), + RstStream = espdy_frame:decode_control_frame(ControlFrame), + ?assertMatch(#rst_stream{stream_id = 1, status = ?PROTOCOL_ERROR}, RstStream). + +zero_length_name_receives_rst_tst() -> + Socket = start_socket(server), + Frame = espdy_frame:encode_syn_stream_with_raw_uncompressed_name_value_pairs(0, 1, 0, 0, <>), + send_frame(Socket, Frame), + RstStream = espdy_frame:decode_control_frame(receive_control_frame()), + ?assertMatch(#rst_stream{stream_id = 1, status = ?PROTOCOL_ERROR}, RstStream). + +zero_length_value_receives_rst_tst() -> + Socket = start_socket(server), + Frame = espdy_frame:encode_syn_stream_with_raw_uncompressed_name_value_pairs(0, 1, 0, 0, <>), + send_frame(Socket, Frame), + RstStream = espdy_frame:decode_control_frame(receive_control_frame()), + ?assertMatch(#rst_stream{stream_id = 1, status = ?PROTOCOL_ERROR}, RstStream). + +zero_length_multi_value_receives_rst_tst() -> + Socket = start_socket(server), + Frame = espdy_frame:encode_syn_stream_with_raw_uncompressed_name_value_pairs(0, 1, 0, 0, <>), + send_frame(Socket, Frame), + RstStream = espdy_frame:decode_control_frame(receive_control_frame()), + ?assertMatch(#rst_stream{stream_id = 1, status = ?PROTOCOL_ERROR}, RstStream). + + +zero_length_name_receives_rst_when_receiving_headers_tst() -> + ZLib = espdy_frame:initialize_zlib_for_deflate(), + Socket = start_socket(server), + send_syn_stream(ZLib, Socket, 1), + + Frame = espdy_frame:encode_headers_with_raw_uncompressed_name_value_pairs(ZLib, 1, 0, <>), + send_frame(Socket, Frame), + RstStream = espdy_frame:decode_control_frame(receive_control_frame()), + ?assertMatch(#rst_stream{stream_id = 1, status = ?PROTOCOL_ERROR}, RstStream). + +% receive_headers_tst() -> +% ZLib = espdy_frame:initialize_zlib_for_deflate(), +% Socket = start_socket(server), +% Headers = [{<<"name">>, [<<"value">>]}], +% send_syn_stream(ZLib, Socket, 1), +% {ok, Socket1, _} = espdy_server:accept(Socket), +% send_frame(Socket, espdy_frame:encode_headers(ZLib, 1, 0, Headers)), +% +% ?assertMatch({header, Headers}, espdy_server:recv_header_or_data(Socket1)). + +graceful_close_closes_server_when_there_are_no_connections_tst() -> + Socket = start_socket(server), + espdy_server:graceful_close(Socket), + ControlFrame = receive_control_frame(), + GoAway = espdy_frame:decode_control_frame(ControlFrame), + ?assertMatch(#goaway{last_good_stream_id = 0}, GoAway), + ?assertEqual({error, closed}, espdy_server:statistics(Socket)). + +assert_cant_start_channel_with_refused(ZLib, Socket, N) -> + send_syn_stream(ZLib, Socket, 3), + ControlFrame = receive_control_frame(), + RstStream = espdy_frame:decode_control_frame(ControlFrame), + ?assertMatch(#rst_stream{stream_id = N, status = ?REFUSED_STREAM}, RstStream). + +graceful_close_close_twice_tst() -> + Socket = start_socket(server), + Socket1 = start_channel(Socket, 1), + espdy_server:graceful_close(Socket), + espdy_server:graceful_close(Socket), + ControlFrame = receive_control_frame(), + GoAway = espdy_frame:decode_control_frame(ControlFrame), + ?assertMatch(#goaway{last_good_stream_id = 1}, GoAway). + +graceful_close_closes_server_when_last_connection_is_closed_tst() -> + ZLib = espdy_frame:initialize_zlib_for_deflate(), + Msg1 = <<"omgdoesitwork?">>, + Socket = start_socket(server), + Socket1 = start_channel(ZLib, Socket, 1), + espdy_server:graceful_close(Socket), + ControlFrame = receive_control_frame(), + GoAway = espdy_frame:decode_control_frame(ControlFrame), + ?assertMatch(#goaway{last_good_stream_id = 1}, GoAway), + + assert_cant_start_channel_with_refused(ZLib, Socket, 3), + + send_frame(Socket1, espdy_frame:encode_data_frame(1, ?FLAG_FIN, Msg1)), + ?assertEqual(Msg1, espdy_server:recv(Socket1, byte_size(Msg1))), + + Ref1 = espdy_server:async_close(Socket1), + receive_packet(), + ok = receive_response(Ref1), + + ?assertEqual({error, closed}, espdy_server:statistics(Socket)). + +receive_rst_after_sending_on_closed_channel_tst() -> + ZLib = espdy_frame:initialize_zlib_for_deflate(), + Socket = start_socket(server), + Headers = [{<<"name">>, [<<"value">>]}], + send_syn_stream(ZLib, Socket, 1), + + {ok, Socket1, _} = espdy_server:accept(Socket), + + send_frame(Socket, espdy_frame:encode_data_frame(1, ?FLAG_FIN, <<>>)), + send_frame(Socket, espdy_frame:encode_headers(ZLib, 1, 0, Headers)), + + ?assertMatch(#rst_stream{stream_id = 1, status = ?INVALID_STREAM}, espdy_frame:decode_control_frame(receive_control_frame())), + + send_syn_stream(ZLib, Socket, 3, Headers), + + {ok, Socket2, Headers} = espdy_server:accept(Socket) + + + % this can happen naturally not just from broken clients because there is a race + % between cancelling/refusing a stream and the other party being aware of this + % + % ie: refuse a stream and then a headers frame pops down. + % we always need to make sure we decode headers with zlib no matter what or the zlib + % dictionaries will get out of sync which will break ALL subsequent SYN or HEADER + % frames. + + . + + +refuse_new_connects_after_sending_goaway_tst() -> + Socket = start_socket(server), + Channel = start_channel_one(Socket), + espdy_server:graceful_close(Socket), + + GoAway = receive_control_frame(), + + + ?assertMatch({error, closed_not_processed}, espdy_server:connect(Socket)). + + +refuse_new_accepts_after_sending_goaway_tst() -> + Socket = start_socket(server), + Channel = start_channel_one(Socket), + espdy_server:graceful_close(Socket), + + GoAway = receive_control_frame(), + + ?assertMatch({error, closed}, espdy_server:accept(Socket)). + +refuse_new_connections_after_receiving_goaway_tst() -> + Socket = start_socket(server), + Channel = start_channel_one(Socket), + send_frame(Socket, espdy_frame:encode_goaway(1)), + + + ?assertMatch({error, closed_not_processed}, espdy_server:connect(Socket)). + +connect_and_recv_tst() -> + Socket = start_socket(server), + {ok, Channel} = espdy_server:connect(Socket), + SynStream = espdy_frame:decode_control_frame(receive_control_frame()), + ?assertMatch(#syn_stream{associated_stream_id = 0, stream_id = 2, flags = 0, headers = []}, SynStream), + send_frame(Socket, espdy_frame:encode_data_frame(Channel#espdy_socket.stream_id, 0, <<"hello">>)), + ?assertEqual(<<"hello">>, espdy_server:recv(Channel, 5)). + +connect_and_send_syn_fin_tst() -> + Socket = start_socket(server), + Headers = [{<<"hello">>, [<<"world">>]}], + {ok, Channel} = espdy_server:connect(Socket, true, Headers), + SynStream = espdy_frame:decode_control_frame(receive_control_frame()), + ?assertMatch(#syn_stream{associated_stream_id = 0, stream_id = 2, flags = ?FLAG_FIN, headers = Headers}, SynStream), + ?assertEqual({error, closed}, espdy_server:send(Channel, <<"more_data">>)). + +connect_for_push_and_get_error_closed_when_recving_tst() -> + Socket = start_socket(server), + ChannelOne = start_channel_one(Socket), + Headers = [{<<"hello">>, [<<"world">>]}], + + {ok, Channel} = espdy_server:connect_for_push(ChannelOne, false, Headers), + SynStream = espdy_frame:decode_control_frame(receive_control_frame()), + ?assertMatch(#syn_stream{associated_stream_id = 1, stream_id = 2, flags = ?FLAG_UNIDIRECTIONAL, headers = Headers}, SynStream), + ?assertEqual({error, closed}, espdy_server:recv(Channel, <<"more_data">>)). + +run_ping(Role, ID) -> + Socket = start_socket(Role), + Frame = espdy_frame:encode_ping(ID), + send_frame(Socket, Frame), + + ?assertEqual(Frame, receive_packet()). + +run_receive_ping(Role, ID) -> + Socket = start_socket(Role), + Frame = espdy_frame:encode_ping(ID), + send_frame(Socket, Frame), + + assert_dont_receive_packet(). + + +server_ping_tst() -> + + run_ping(server, 1). + +client_ping_tst() -> + + run_ping(client, 2). + +server_receive_ping_reply_tst() -> + + run_receive_ping(server, 2). + +client_receive_ping_reply_tst() -> + + run_receive_ping(client, 1). + +unknown_cast_tst() -> + gen_server:cast(start_socket(server), lols). + +unknown_call_tst() -> + ?assertEqual({error, unknown_call}, gen_server:call(start_socket(server), lols, infinity)). + +code_change_test() -> + ?assertEqual({ok, state}, espdy_server:code_change(old_vsn, state, extra)). + \ No newline at end of file diff --git a/test_server b/test_server new file mode 100755 index 0000000..f6cd1ca --- /dev/null +++ b/test_server @@ -0,0 +1 @@ +erl -pa ebin -pa deps/lager/ebin -s espdy_test_server