about summary refs log tree commit homepage
diff options
context:
space:
mode:
-rw-r--r--lib/yahns/http_client.rb14
-rw-r--r--lib/yahns/http_response.rb5
-rw-r--r--test/helper.rb3
-rw-r--r--test/test_rack_hijack.rb5
-rw-r--r--test/test_ssl.rb5
5 files changed, 24 insertions, 8 deletions
diff --git a/lib/yahns/http_client.rb b/lib/yahns/http_client.rb
index 7351171..0c656e8 100644
--- a/lib/yahns/http_client.rb
+++ b/lib/yahns/http_client.rb
@@ -193,9 +193,9 @@ class Yahns::HttpClient < Kgio::Socket # :nodoc:
       mkinput_preread # keep looping (@state == :body)
       true
     else # :lazy, false
-      r = k.app.call(env = @hs.env)
-      return :ignore if env.include?(RACK_HIJACK_IO)
-      http_response_write(*r)
+      status, headers, body = k.app.call(env = @hs.env)
+      return :ignore if app_hijacked?(env, body)
+      http_response_write(status, headers, body)
     end
   end
 
@@ -217,7 +217,7 @@ class Yahns::HttpClient < Kgio::Socket # :nodoc:
 
     # run the rack app
     status, headers, body = k.app.call(env.merge!(k.app_defaults))
-    return :ignore if env.include?(RACK_HIJACK_IO)
+    return :ignore if app_hijacked?(env, body)
     if status.to_i == 100
       rv = http_100_response(env) and return rv
       status, headers, body = k.app.call(env)
@@ -298,4 +298,10 @@ class Yahns::HttpClient < Kgio::Socket # :nodoc:
     shutdown rescue nil
     return # always drop the connection on uncaught errors
   end
+
+  def app_hijacked?(env, body)
+    return false unless env.include?(RACK_HIJACK_IO)
+    body.close if body.respond_to?(:close)
+    true
+  end
 end
diff --git a/lib/yahns/http_response.rb b/lib/yahns/http_response.rb
index 1be28bc..fabd4b7 100644
--- a/lib/yahns/http_response.rb
+++ b/lib/yahns/http_response.rb
@@ -69,7 +69,9 @@ module Yahns::HttpResponse # :nodoc:
     end
     wbuf = Yahns::Wbuf.new(body, alive, self.class.output_buffer_tmpdir, ret)
     rv = wbuf.wbuf_write(self, header)
-    body.each { |chunk| rv = wbuf.wbuf_write(self, chunk) } if body
+    if body && ! alive.respond_to?(:call) # skip body.each if hijacked
+      body.each { |chunk| rv = wbuf.wbuf_write(self, chunk) }
+    end
     wbuf_maybe(wbuf, rv)
   end
 
@@ -155,7 +157,6 @@ module Yahns::HttpResponse # :nodoc:
           buf << kv_str(key, value)
         when "rack.hijack"
           hijack = value
-          body = nil # ensure we do not close body
         else
           buf << kv_str(key, value)
         end
diff --git a/test/helper.rb b/test/helper.rb
index 3e9f535..7b8c1aa 100644
--- a/test/helper.rb
+++ b/test/helper.rb
@@ -134,12 +134,13 @@ def require_exec(cmd)
 end
 
 class DieIfUsed
+  @@n = 0
   def each
     abort "body.each called after response hijack\n"
   end
 
   def close
-    abort "body.close called after response hijack\n"
+    warn "INFO #$$ closed DieIfUsed #{@@n += 1}"
   end
 end
 
diff --git a/test/test_rack_hijack.rb b/test/test_rack_hijack.rb
index 3e382eb..5bfc31f 100644
--- a/test/test_rack_hijack.rb
+++ b/test/test_rack_hijack.rb
@@ -48,6 +48,7 @@ class TestRackHijack < Testcase
     cfg.instance_eval do
       GTL.synchronize { app(:rack, HIJACK_APP) { listen "#{host}:#{port}" } }
       logger(Logger.new(err.path))
+      stderr_path err.path
     end
     pid = mkserver(cfg)
     res = Net::HTTP.start(host, port) { |h| h.get("/hijack_req") }
@@ -61,6 +62,10 @@ class TestRackHijack < Testcase
     assert_equal "zzz", res["X-Test"]
     assert_equal "1.1", res.http_version
 
+    errs = File.readlines(err.path).grep(/DieIfUsed/)
+    assert_equal([ "INFO #{pid} closed DieIfUsed 1\n",
+                   "INFO #{pid} closed DieIfUsed 2\n" ], errs)
+
     res = Net::HTTP.start(host, port) do |h|
       hdr = { "Content-Type" => 'application/octet-stream' }
       h.put("/hijack_input", "BLAH", hdr)
diff --git a/test/test_ssl.rb b/test/test_ssl.rb
index 8f01ef7..c54cc3c 100644
--- a/test/test_ssl.rb
+++ b/test/test_ssl.rb
@@ -124,10 +124,11 @@ AQjjxMXhwULlmuR/K+WwlaZPiLIBYalLAZQ7ZbOPeVkJ8ePao0eLAgEC
               p [ :ERR, req ]
             end until s.closed?
           end
-          [ 200, DieIfUsed, DieIfUsed ]
+          [ 200, DieIfUsed.new, DieIfUsed.new ]
         end
         app(:rack, ru) { listen "#{host}:#{port}", ssl_ctx: ctx }
         logger(Logger.new(err.path))
+        stderr_path err.path
       end
     end
     client = ssl_client(host, port)
@@ -147,6 +148,8 @@ AQjjxMXhwULlmuR/K+WwlaZPiLIBYalLAZQ7ZbOPeVkJ8ePao0eLAgEC
       %w(a b c d).each { |x| client.puts(x) }
       assert_equal "abcd", client.gets.strip
     end
+    errs = File.readlines(err.path).grep(/DieIfUsed/)
+    assert_equal([ "INFO #{pid} closed DieIfUsed 1\n" ], errs)
   ensure
     client.close if client
     quit_wait(pid)