about summary refs log tree commit homepage
diff options
context:
space:
mode:
-rw-r--r--ext/unicorn_http/global_variables.h9
-rw-r--r--ext/unicorn_http/unicorn_http.rl55
-rw-r--r--test/unit/test_http_parser.rb54
-rw-r--r--test/unit/test_http_parser_ng.rb18
4 files changed, 128 insertions, 8 deletions
diff --git a/ext/unicorn_http/global_variables.h b/ext/unicorn_http/global_variables.h
index 1383ed4..b8b6221 100644
--- a/ext/unicorn_http/global_variables.h
+++ b/ext/unicorn_http/global_variables.h
@@ -15,16 +15,19 @@ static VALUE g_path_info;
 static VALUE g_server_name;
 static VALUE g_server_port;
 static VALUE g_server_protocol;
-static VALUE g_server_protocol_value;
 static VALUE g_http_host;
 static VALUE g_http_x_forwarded_proto;
 static VALUE g_http_transfer_encoding;
 static VALUE g_content_length;
 static VALUE g_http_trailer;
+static VALUE g_http_connection;
 static VALUE g_port_80;
 static VALUE g_port_443;
 static VALUE g_localhost;
 static VALUE g_http;
+static VALUE g_http_11;
+static VALUE g_GET;
+static VALUE g_HEAD;
 
 /** Defines common length and error messages for input length validation. */
 #define DEF_MAX_LENGTH(N, length) \
@@ -71,12 +74,14 @@ void init_globals(void)
   DEF_GLOBAL(server_name, "SERVER_NAME");
   DEF_GLOBAL(server_port, "SERVER_PORT");
   DEF_GLOBAL(server_protocol, "SERVER_PROTOCOL");
-  DEF_GLOBAL(server_protocol_value, "HTTP/1.1");
   DEF_GLOBAL(http_x_forwarded_proto, "HTTP_X_FORWARDED_PROTO");
   DEF_GLOBAL(port_80, "80");
   DEF_GLOBAL(port_443, "443");
   DEF_GLOBAL(localhost, "localhost");
   DEF_GLOBAL(http, "http");
+  DEF_GLOBAL(http_11, "HTTP/1.1");
+  DEF_GLOBAL(GET, "GET");
+  DEF_GLOBAL(HEAD, "HEAD");
 
   eHttpParserError =
          rb_define_class_under(mUnicorn, "HttpParserError", rb_eIOError);
diff --git a/ext/unicorn_http/unicorn_http.rl b/ext/unicorn_http/unicorn_http.rl
index ac74990..220069b 100644
--- a/ext/unicorn_http/unicorn_http.rl
+++ b/ext/unicorn_http/unicorn_http.rl
@@ -18,6 +18,10 @@
 #define UH_FL_HASTRAILER 0x8
 #define UH_FL_INTRAILER 0x10
 #define UH_FL_INCHUNK  0x20
+#define UH_FL_KAMETHOD 0x40
+#define UH_FL_KAVERSION 0x80
+
+#define UH_FL_KEEPALIVE (UH_FL_KAMETHOD | UH_FL_KAVERSION)
 
 struct http_parser {
   int cs; /* Ragel internal state */
@@ -46,6 +50,37 @@ static void finalize_header(VALUE req);
 #define PTR_TO(F) (buffer + hp->F)
 #define STR_NEW(M,FPC) rb_str_new(PTR_TO(M), LEN(M, FPC))
 
+static void
+request_method(struct http_parser *hp, VALUE req, const char *ptr, size_t len)
+{
+  VALUE v;
+
+  if (CONST_MEM_EQ("GET", ptr, len)) {
+    hp->flags |= UH_FL_KAMETHOD;
+    v = g_GET;
+  } else if (CONST_MEM_EQ("HEAD", ptr, len)) {
+    hp->flags |= UH_FL_KAMETHOD;
+    v = g_HEAD;
+  } else {
+    v = rb_str_new(ptr, len);
+  }
+  rb_hash_aset(req, g_request_method, v);
+}
+
+static void
+http_version(struct http_parser *hp, VALUE req, const char *ptr, size_t len)
+{
+  VALUE v;
+
+  if (CONST_MEM_EQ("HTTP/1.1", ptr, len)) {
+    hp->flags |= UH_FL_KAVERSION;
+    v = g_http_11;
+  } else {
+    v = rb_str_new(ptr, len);
+  }
+  rb_hash_aset(req, g_http_version, v);
+}
+
 static void invalid_if_trailer(int flags)
 {
   if (flags & UH_FL_INTRAILER)
@@ -64,6 +99,9 @@ static void write_value(VALUE req, struct http_parser *hp,
   if (f == Qnil) {
     VALIDATE_MAX_LENGTH(hp->s.field_len, FIELD_NAME);
     f = uncommon_field(PTR_TO(start.field), hp->s.field_len);
+  } else if (f == g_http_connection) {
+    if (STR_CSTR_CASE_EQ(v, "close"))
+      hp->flags &= ~UH_FL_KEEPALIVE;
   } else if (f == g_content_length) {
     hp->len.content = parse_length(RSTRING_PTR(v), RSTRING_LEN(v));
     if (hp->len.content < 0)
@@ -103,7 +141,7 @@ static void write_value(VALUE req, struct http_parser *hp,
   action start_value { MARK(mark, fpc); }
   action write_value { write_value(req, hp, buffer, fpc); }
   action request_method {
-    rb_hash_aset(req, g_request_method, STR_NEW(mark, fpc));
+    request_method(hp, req, PTR_TO(mark), LEN(mark, fpc));
   }
   action scheme {
     rb_hash_aset(req, g_rack_url_scheme, STR_NEW(mark, fpc));
@@ -136,9 +174,7 @@ static void write_value(VALUE req, struct http_parser *hp,
     VALIDATE_MAX_LENGTH(LEN(start.query, fpc), QUERY_STRING);
     rb_hash_aset(req, g_query_string, STR_NEW(start.query, fpc));
   }
-  action http_version {
-    rb_hash_aset(req, g_http_version, STR_NEW(mark, fpc));
-  }
+  action http_version { http_version(hp, req, PTR_TO(mark), LEN(mark, fpc)); }
   action request_path {
     VALUE val;
     size_t len = LEN(mark, fpc);
@@ -295,7 +331,7 @@ static void finalize_header(VALUE req)
   }
   rb_hash_aset(req, g_server_name, server_name);
   rb_hash_aset(req, g_server_port, server_port);
-  rb_hash_aset(req, g_server_protocol, g_server_protocol_value);
+  rb_hash_aset(req, g_server_protocol, g_http_11);
 
   /* rack requires QUERY_STRING */
   if (rb_hash_aref(req, g_query_string) == Qnil)
@@ -407,6 +443,13 @@ static VALUE HttpParser_body_eof(VALUE self)
   return hp->len.content == 0 ? Qtrue : Qfalse;
 }
 
+static VALUE HttpParser_keepalive(VALUE self)
+{
+  struct http_parser *hp = data_get(self);
+
+  return (hp->flags & UH_FL_KEEPALIVE) == UH_FL_KEEPALIVE ? Qtrue : Qfalse;
+}
+
 /**
  * call-seq:
  *    parser.filter_body(buf, data) -> nil/data
@@ -485,6 +528,7 @@ void Init_unicorn_http(void)
   rb_define_method(cHttpParser, "trailers", HttpParser_headers, 2);
   rb_define_method(cHttpParser, "content_length", HttpParser_content_length, 0);
   rb_define_method(cHttpParser, "body_eof?", HttpParser_body_eof, 0);
+  rb_define_method(cHttpParser, "keepalive?", HttpParser_keepalive, 0);
 
   /*
    * The maximum size a single chunk when using chunked transfer encoding.
@@ -506,5 +550,6 @@ void Init_unicorn_http(void)
   SET_GLOBAL(g_http_trailer, "TRAILER");
   SET_GLOBAL(g_http_transfer_encoding, "TRANSFER_ENCODING");
   SET_GLOBAL(g_content_length, "CONTENT_LENGTH");
+  SET_GLOBAL(g_http_connection, "CONNECTION");
 }
 #undef SET_GLOBAL
diff --git a/test/unit/test_http_parser.rb b/test/unit/test_http_parser.rb
index 7072571..6cef678 100644
--- a/test/unit/test_http_parser.rb
+++ b/test/unit/test_http_parser.rb
@@ -9,7 +9,7 @@ require 'test/test_helper'
 include Unicorn
 
 class HttpParserTest < Test::Unit::TestCase
-    
+
   def test_parse_simple
     parser = HttpParser.new
     req = {}
@@ -25,6 +25,7 @@ class HttpParserTest < Test::Unit::TestCase
     assert_nil req['FRAGMENT']
     assert_equal '', req['QUERY_STRING']
 
+    assert parser.keepalive?
     parser.reset
     req.clear
 
@@ -45,6 +46,40 @@ class HttpParserTest < Test::Unit::TestCase
     assert_nil req['FRAGMENT']
     assert_equal '', req['QUERY_STRING']
     assert_equal '', http
+    assert parser.keepalive?
+  end
+
+  def test_connection_close_no_ka
+    parser = HttpParser.new
+    req = {}
+    tmp = "GET / HTTP/1.1\r\nConnection: close\r\n\r\n"
+    assert_equal req.object_id, parser.headers(req, tmp).object_id
+    assert_equal "GET", req['REQUEST_METHOD']
+    assert ! parser.keepalive?
+  end
+
+  def test_connection_keep_alive_ka
+    parser = HttpParser.new
+    req = {}
+    tmp = "HEAD / HTTP/1.1\r\nConnection: keep-alive\r\n\r\n"
+    assert_equal req.object_id, parser.headers(req, tmp).object_id
+    assert parser.keepalive?
+  end
+
+  def test_connection_keep_alive_ka_bad_method
+    parser = HttpParser.new
+    req = {}
+    tmp = "POST / HTTP/1.1\r\nConnection: keep-alive\r\n\r\n"
+    assert_equal req.object_id, parser.headers(req, tmp).object_id
+    assert ! parser.keepalive?
+  end
+
+  def test_connection_keep_alive_ka_bad_version
+    parser = HttpParser.new
+    req = {}
+    tmp = "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n"
+    assert_equal req.object_id, parser.headers(req, tmp).object_id
+    assert ! parser.keepalive?
   end
 
   def test_parse_server_host_default_port
@@ -55,6 +90,7 @@ class HttpParserTest < Test::Unit::TestCase
     assert_equal 'foo', req['SERVER_NAME']
     assert_equal '80', req['SERVER_PORT']
     assert_equal '', tmp
+    assert parser.keepalive?
   end
 
   def test_parse_server_host_alt_port
@@ -65,6 +101,7 @@ class HttpParserTest < Test::Unit::TestCase
     assert_equal 'foo', req['SERVER_NAME']
     assert_equal '999', req['SERVER_PORT']
     assert_equal '', tmp
+    assert parser.keepalive?
   end
 
   def test_parse_server_host_empty_port
@@ -75,6 +112,7 @@ class HttpParserTest < Test::Unit::TestCase
     assert_equal 'foo', req['SERVER_NAME']
     assert_equal '80', req['SERVER_PORT']
     assert_equal '', tmp
+    assert parser.keepalive?
   end
 
   def test_parse_server_host_xfp_https
@@ -86,6 +124,7 @@ class HttpParserTest < Test::Unit::TestCase
     assert_equal 'foo', req['SERVER_NAME']
     assert_equal '443', req['SERVER_PORT']
     assert_equal '', tmp
+    assert parser.keepalive?
   end
 
   def test_parse_strange_headers
@@ -94,6 +133,7 @@ class HttpParserTest < Test::Unit::TestCase
     should_be_good = "GET / HTTP/1.1\r\naaaaaaaaaaaaa:++++++++++\r\n\r\n"
     assert_equal req, parser.headers(req, should_be_good)
     assert_equal '', should_be_good
+    assert parser.keepalive?
 
     # ref: http://thread.gmane.org/gmane.comp.lang.ruby.mongrel.devel/37/focus=45
     # (note we got 'pen' mixed up with 'pound' in that thread,
@@ -119,6 +159,7 @@ class HttpParserTest < Test::Unit::TestCase
       assert_equal req, parser.headers(req, sorta_safe)
       assert_equal path, req['REQUEST_URI']
       assert_equal '', sorta_safe
+      assert parser.keepalive?
     end
   end
   
@@ -133,6 +174,7 @@ class HttpParserTest < Test::Unit::TestCase
     parser.reset
     req.clear
     assert_equal req, parser.headers(req, "GET / HTTP/1.0\r\n\r\n")
+    assert ! parser.keepalive?
   end
 
   def test_piecemeal
@@ -153,6 +195,7 @@ class HttpParserTest < Test::Unit::TestCase
     assert_nil req['FRAGMENT']
     assert_equal '', req['QUERY_STRING']
     assert_equal "", http
+    assert ! parser.keepalive?
   end
 
   # not common, but underscores do appear in practice
@@ -170,6 +213,7 @@ class HttpParserTest < Test::Unit::TestCase
     assert_equal 'under_score.example.com', req['SERVER_NAME']
     assert_equal '80', req['SERVER_PORT']
     assert_equal "", http
+    assert ! parser.keepalive?
   end
 
   def test_absolute_uri
@@ -186,6 +230,7 @@ class HttpParserTest < Test::Unit::TestCase
     assert_equal 'example.com', req['SERVER_NAME']
     assert_equal '80', req['SERVER_PORT']
     assert_equal "", http
+    assert ! parser.keepalive?
   end
 
   # X-Forwarded-Proto is not in rfc2616, absolute URIs are, however...
@@ -204,6 +249,7 @@ class HttpParserTest < Test::Unit::TestCase
     assert_equal 'example.com', req['SERVER_NAME']
     assert_equal '443', req['SERVER_PORT']
     assert_equal "", http
+    assert parser.keepalive?
   end
 
   # Host: header should be ignored for absolute URIs
@@ -222,6 +268,7 @@ class HttpParserTest < Test::Unit::TestCase
     assert_equal 'example.com', req['SERVER_NAME']
     assert_equal '8080', req['SERVER_PORT']
     assert_equal "", http
+    assert ! parser.keepalive? # TODO: read HTTP/1.2 when it's final
   end
 
   def test_absolute_uri_with_empty_port
@@ -239,6 +286,7 @@ class HttpParserTest < Test::Unit::TestCase
     assert_equal 'example.com', req['SERVER_NAME']
     assert_equal '443', req['SERVER_PORT']
     assert_equal "", http
+    assert parser.keepalive? # TODO: read HTTP/1.2 when it's final
   end
 
   def test_put_body_oneshot
@@ -252,6 +300,7 @@ class HttpParserTest < Test::Unit::TestCase
     assert_equal 'HTTP/1.0', req['HTTP_VERSION']
     assert_equal 'HTTP/1.1', req['SERVER_PROTOCOL']
     assert_equal "abcde", http
+    assert ! parser.keepalive? # TODO: read HTTP/1.2 when it's final
   end
 
   def test_put_body_later
@@ -265,6 +314,7 @@ class HttpParserTest < Test::Unit::TestCase
     assert_equal 'HTTP/1.0', req['HTTP_VERSION']
     assert_equal 'HTTP/1.1', req['SERVER_PROTOCOL']
     assert_equal "", http
+    assert ! parser.keepalive? # TODO: read HTTP/1.2 when it's final
   end
 
   def test_unknown_methods
@@ -282,6 +332,7 @@ class HttpParserTest < Test::Unit::TestCase
       assert_equal 'page=1', req['QUERY_STRING']
       assert_equal "", s
       assert_equal m, req['REQUEST_METHOD']
+      assert ! parser.keepalive? # TODO: read HTTP/1.2 when it's final
     }
   end
 
@@ -298,6 +349,7 @@ class HttpParserTest < Test::Unit::TestCase
     assert_equal 'posts-17408', req['FRAGMENT']
     assert_equal 'page=1', req['QUERY_STRING']
     assert_equal '', get
+    assert parser.keepalive?
   end
 
   # lame random garbage maker
diff --git a/test/unit/test_http_parser_ng.rb b/test/unit/test_http_parser_ng.rb
index 8aed270..bacf2cf 100644
--- a/test/unit/test_http_parser_ng.rb
+++ b/test/unit/test_http_parser_ng.rb
@@ -20,6 +20,7 @@ class HttpParserNgTest < Test::Unit::TestCase
     assert_equal req.object_id, @parser.headers(req, str).object_id
     assert_equal '123', req['CONTENT_LENGTH']
     assert_equal 0, str.size
+    assert ! @parser.keepalive?
   end
 
   def test_identity_oneshot_header
@@ -28,6 +29,7 @@ class HttpParserNgTest < Test::Unit::TestCase
     assert_equal req.object_id, @parser.headers(req, str).object_id
     assert_equal '123', req['CONTENT_LENGTH']
     assert_equal 0, str.size
+    assert ! @parser.keepalive?
   end
 
   def test_identity_oneshot_header_with_body
@@ -45,6 +47,7 @@ class HttpParserNgTest < Test::Unit::TestCase
     assert_equal 0, str.size
     assert_equal tmp, body
     assert_equal "", @parser.filter_body(tmp, str)
+    assert ! @parser.keepalive?
   end
 
   def test_identity_oneshot_header_with_body_partial
@@ -62,6 +65,7 @@ class HttpParserNgTest < Test::Unit::TestCase
     assert_nil rv
     assert_equal "", str
     assert_equal str.object_id, @parser.filter_body(tmp, str).object_id
+    assert ! @parser.keepalive?
   end
 
   def test_identity_oneshot_header_with_body_slop
@@ -75,6 +79,7 @@ class HttpParserNgTest < Test::Unit::TestCase
     assert_equal "G", @parser.filter_body(tmp, str)
     assert_equal 1, tmp.size
     assert_equal "a", tmp
+    assert ! @parser.keepalive?
   end
 
   def test_chunked
@@ -96,6 +101,7 @@ class HttpParserNgTest < Test::Unit::TestCase
     rv = "PUT"
     assert_equal rv.object_id, @parser.filter_body(tmp, rv).object_id
     assert_equal "PUT", rv
+    assert ! @parser.keepalive?
   end
 
   def test_two_chunks
@@ -127,6 +133,7 @@ class HttpParserNgTest < Test::Unit::TestCase
     rv = @parser.filter_body(tmp, buf = "\nGET")
     assert_equal "GET", rv
     assert_equal buf.object_id, rv.object_id
+    assert ! @parser.keepalive?
   end
 
   def test_big_chunk
@@ -146,6 +153,7 @@ class HttpParserNgTest < Test::Unit::TestCase
     assert ! @parser.body_eof?
     assert_equal "", @parser.filter_body(tmp, "\r\n0\r\n")
     assert @parser.body_eof?
+    assert ! @parser.keepalive?
   end
 
   def test_two_chunks_oneshot
@@ -158,6 +166,7 @@ class HttpParserNgTest < Test::Unit::TestCase
     assert_equal 'a..', tmp
     rv = @parser.filter_body(tmp, str)
     assert_equal rv.object_id, str.object_id
+    assert ! @parser.keepalive?
   end
 
   def test_trailers
@@ -184,6 +193,7 @@ class HttpParserNgTest < Test::Unit::TestCase
     assert_nil @parser.trailers(req, str << "\r")
     assert_equal req, @parser.trailers(req, str << "\nGET / ")
     assert_equal "GET / ", str
+    assert ! @parser.keepalive?
   end
 
   def test_max_chunk
@@ -194,6 +204,7 @@ class HttpParserNgTest < Test::Unit::TestCase
     assert_equal req, @parser.headers(req, str)
     assert_nil @parser.content_length
     assert_nothing_raised { @parser.filter_body('', str) }
+    assert ! @parser.keepalive?
   end
 
   def test_max_body
@@ -202,6 +213,7 @@ class HttpParserNgTest < Test::Unit::TestCase
     req = {}
     assert_nothing_raised { @parser.headers(req, str) }
     assert_equal n, req['CONTENT_LENGTH'].to_i
+    assert ! @parser.keepalive?
   end
 
   def test_overflow_chunk
@@ -213,12 +225,14 @@ class HttpParserNgTest < Test::Unit::TestCase
     assert_equal req, @parser.headers(req, str)
     assert_nil @parser.content_length
     assert_raise(HttpParserError) { @parser.filter_body('', str) }
+    assert ! @parser.keepalive?
   end
 
   def test_overflow_content_length
     n = HttpParser::LENGTH_MAX + 1
     str = "PUT / HTTP/1.1\r\nContent-Length: #{n}\r\n\r\n"
     assert_raise(HttpParserError) { @parser.headers({}, str) }
+    assert ! @parser.keepalive?
   end
 
   def test_bad_chunk
@@ -229,11 +243,13 @@ class HttpParserNgTest < Test::Unit::TestCase
     assert_equal req, @parser.headers(req, str)
     assert_nil @parser.content_length
     assert_raise(HttpParserError) { @parser.filter_body('', str) }
+    assert ! @parser.keepalive?
   end
 
   def test_bad_content_length
     str = "PUT / HTTP/1.1\r\nContent-Length: 7ff\r\n\r\n"
     assert_raise(HttpParserError) { @parser.headers({}, str) }
+    assert ! @parser.keepalive?
   end
 
   def test_bad_trailers
@@ -250,6 +266,7 @@ class HttpParserNgTest < Test::Unit::TestCase
     assert_equal '', str
     str << "Transfer-Encoding: identity\r\n\r\n"
     assert_raise(HttpParserError) { @parser.trailers(req, str) }
+    assert ! @parser.keepalive?
   end
 
   def test_repeat_headers
@@ -261,6 +278,7 @@ class HttpParserNgTest < Test::Unit::TestCase
     req = {}
     assert_equal req, @parser.headers(req, str)
     assert_equal 'Content-MD5,Content-SHA1', req['HTTP_TRAILER']
+    assert ! @parser.keepalive?
   end
 
 end