summary refs log tree commit
path: root/lib/rack/mock.rb
diff options
context:
space:
mode:
Diffstat (limited to 'lib/rack/mock.rb')
-rw-r--r--lib/rack/mock.rb97
1 files changed, 74 insertions, 23 deletions
diff --git a/lib/rack/mock.rb b/lib/rack/mock.rb
index 4ebc4df1..f2b94832 100644
--- a/lib/rack/mock.rb
+++ b/lib/rack/mock.rb
@@ -1,9 +1,12 @@
+# frozen_string_literal: true
+
 require 'uri'
 require 'stringio'
 require 'rack'
 require 'rack/lint'
 require 'rack/utils'
 require 'rack/response'
+require 'cgi/cookie'
 
 module Rack
   # Rack::MockRequest helps testing your Rack application without
@@ -53,16 +56,16 @@ module Rack
       @app = app
     end
 
-    def get(uri, opts={})     request(GET, uri, opts)     end
-    def post(uri, opts={})    request(POST, uri, opts)    end
-    def put(uri, opts={})     request(PUT, uri, opts)     end
-    def patch(uri, opts={})   request(PATCH, uri, opts)   end
-    def delete(uri, opts={})  request(DELETE, uri, opts)  end
-    def head(uri, opts={})    request(HEAD, uri, opts)    end
-    def options(uri, opts={}) request(OPTIONS, uri, opts) end
+    def get(uri, opts = {})     request(GET, uri, opts)     end
+    def post(uri, opts = {})    request(POST, uri, opts)    end
+    def put(uri, opts = {})     request(PUT, uri, opts)     end
+    def patch(uri, opts = {})   request(PATCH, uri, opts)   end
+    def delete(uri, opts = {})  request(DELETE, uri, opts)  end
+    def head(uri, opts = {})    request(HEAD, uri, opts)    end
+    def options(uri, opts = {}) request(OPTIONS, uri, opts) end
 
-    def request(method=GET, uri="", opts={})
-      env = self.class.env_for(uri, opts.merge(:method => method))
+    def request(method = GET, uri = "", opts = {})
+      env = self.class.env_for(uri, opts.merge(method: method))
 
       if opts[:lint]
         app = Rack::Lint.new(@app)
@@ -71,7 +74,7 @@ module Rack
       end
 
       errors = env[RACK_ERRORS]
-      status, headers, body  = app.call(env)
+      status, headers, body = app.call(env)
       MockResponse.new(status, headers, body, errors)
     ensure
       body.close if body.respond_to?(:close)
@@ -85,19 +88,19 @@ module Rack
     end
 
     # Return the Rack environment used for a request to +uri+.
-    def self.env_for(uri="", opts={})
+    def self.env_for(uri = "", opts = {})
       uri = parse_uri_rfc2396(uri)
       uri.path = "/#{uri.path}" unless uri.path[0] == ?/
 
       env = DEFAULT_ENV.dup
 
-      env[REQUEST_METHOD]  = opts[:method] ? opts[:method].to_s.upcase : GET
-      env[SERVER_NAME]     = uri.host || "example.org"
-      env[SERVER_PORT]     = uri.port ? uri.port.to_s : "80"
-      env[QUERY_STRING]    = uri.query.to_s
-      env[PATH_INFO]       = (!uri.path || uri.path.empty?) ? "/" : uri.path
-      env[RACK_URL_SCHEME] = uri.scheme || "http"
-      env[HTTPS]           = env[RACK_URL_SCHEME] == "https" ? "on" : "off"
+      env[REQUEST_METHOD]  = (opts[:method] ? opts[:method].to_s.upcase : GET).b
+      env[SERVER_NAME]     = (uri.host || "example.org").b
+      env[SERVER_PORT]     = (uri.port ? uri.port.to_s : "80").b
+      env[QUERY_STRING]    = (uri.query.to_s).b
+      env[PATH_INFO]       = ((!uri.path || uri.path.empty?) ? "/" : uri.path).b
+      env[RACK_URL_SCHEME] = (uri.scheme || "http").b
+      env[HTTPS]           = (env[RACK_URL_SCHEME] == "https" ? "on" : "off").b
 
       env[SCRIPT_NAME] = opts[:script_name] || ""
 
@@ -128,7 +131,7 @@ module Rack
         end
       end
 
-      empty_str = String.new.force_encoding(Encoding::ASCII_8BIT)
+      empty_str = String.new
       opts[:input] ||= empty_str
       if String === opts[:input]
         rack_input = StringIO.new(opts[:input])
@@ -139,7 +142,7 @@ module Rack
       rack_input.set_encoding(Encoding::BINARY)
       env[RACK_INPUT] = rack_input
 
-      env["CONTENT_LENGTH"] ||= env[RACK_INPUT].length.to_s
+      env["CONTENT_LENGTH"] ||= env[RACK_INPUT].size.to_s if env[RACK_INPUT].respond_to?(:size)
 
       opts.each { |field, value|
         env[field] = value  if String === field
@@ -155,14 +158,15 @@ module Rack
 
   class MockResponse < Rack::Response
     # Headers
-    attr_reader :original_headers
+    attr_reader :original_headers, :cookies
 
     # Errors
     attr_accessor :errors
 
-    def initialize(status, headers, body, errors=StringIO.new(""))
+    def initialize(status, headers, body, errors = StringIO.new(""))
       @original_headers = headers
       @errors           = errors.string if errors.respond_to?(:string)
+      @cookies = parse_cookies_from_header
 
       super(body, status, headers)
     end
@@ -190,7 +194,54 @@ module Rack
     end
 
     def empty?
-      [201, 204, 205, 304].include? status
+      [201, 204, 304].include? status
+    end
+
+    def cookie(name)
+      cookies.fetch(name, nil)
+    end
+
+    private
+
+    def parse_cookies_from_header
+      cookies = Hash.new
+      if original_headers.has_key? 'Set-Cookie'
+        set_cookie_header = original_headers.fetch('Set-Cookie')
+        set_cookie_header.split("\n").each do |cookie|
+          cookie_name, cookie_filling = cookie.split('=', 2)
+          cookie_attributes = identify_cookie_attributes cookie_filling
+          parsed_cookie = CGI::Cookie.new(
+            'name' => cookie_name.strip,
+            'value' => cookie_attributes.fetch('value'),
+            'path' => cookie_attributes.fetch('path', nil),
+            'domain' => cookie_attributes.fetch('domain', nil),
+            'expires' => cookie_attributes.fetch('expires', nil),
+            'secure' => cookie_attributes.fetch('secure', false)
+          )
+          cookies.store(cookie_name, parsed_cookie)
+        end
+      end
+      cookies
     end
+
+    def identify_cookie_attributes(cookie_filling)
+      cookie_bits = cookie_filling.split(';')
+      cookie_attributes = Hash.new
+      cookie_attributes.store('value', cookie_bits[0].strip)
+      cookie_bits.each do |bit|
+        if bit.include? '='
+          cookie_attribute, attribute_value = bit.split('=')
+          cookie_attributes.store(cookie_attribute.strip, attribute_value.strip)
+          if cookie_attribute.include? 'max-age'
+            cookie_attributes.store('expires', Time.now + attribute_value.strip.to_i)
+          end
+        end
+        if bit.include? 'secure'
+          cookie_attributes.store('secure', true)
+        end
+      end
+      cookie_attributes
+    end
+
   end
 end