about summary refs log tree commit homepage
diff options
context:
space:
mode:
authorEric Wong <e@yhbt.net>2010-09-28 18:04:51 -0700
committerEric Wong <e@yhbt.net>2010-09-28 18:12:54 -0700
commit911f6ab306aff1e24c9c570eeae33923fa1b99d9 (patch)
tree2b7732d29982e5641d192cd74d219583fadb5534
parent526b4bd48a20a34ef5959fdc4aa580d5f9199652 (diff)
downloadkgio-911f6ab306aff1e24c9c570eeae33923fa1b99d9.tar.gz
These can be useful for avoiding wrapper objects and
also allows users to more easily try different things
without stepping on others' toe^H^H^Hclasses.
-rw-r--r--ext/kgio/kgio_ext.c33
-rw-r--r--test/test_accept_class.rb52
2 files changed, 83 insertions, 2 deletions
diff --git a/ext/kgio/kgio_ext.c b/ext/kgio/kgio_ext.c
index 3b20064..c630ab5 100644
--- a/ext/kgio/kgio_ext.c
+++ b/ext/kgio/kgio_ext.c
@@ -32,6 +32,8 @@ static int accept4_flags = A4_SOCK_CLOEXEC;
 static int accept4_flags = A4_SOCK_CLOEXEC | A4_SOCK_NONBLOCK;
 #endif /* ! linux */
 
+static VALUE cClientSocket;
+static VALUE mSocketMethods;
 static VALUE cSocket;
 static VALUE localhost;
 static VALUE mKgio_WaitReadable, mKgio_WaitWritable;
@@ -533,7 +535,7 @@ retry:
                         rb_sys_fail("accept");
                 }
         }
-        return sock_for_fd(cSocket, client);
+        return sock_for_fd(cClientSocket, client);
 }
 
 static void in_addr_set(VALUE io, struct sockaddr_in *addr)
@@ -950,10 +952,34 @@ static VALUE kgio_start(VALUE klass, VALUE addr)
         return stream_connect(klass, addr, 0);
 }
 
+static VALUE set_accepted(VALUE klass, VALUE aclass)
+{
+        VALUE tmp;
+
+        if (NIL_P(aclass))
+                aclass = cSocket;
+
+        tmp = rb_funcall(aclass, rb_intern("included_modules"), 0, 0);
+        tmp = rb_funcall(tmp, rb_intern("include?"), 1, mSocketMethods);
+
+        if (tmp != Qtrue)
+                rb_raise(rb_eTypeError,
+                         "class must include Kgio::SocketMethods");
+
+        cClientSocket = aclass;
+
+        return aclass;
+}
+
+static VALUE get_accepted(VALUE klass)
+{
+        return cClientSocket;
+}
+
 void Init_kgio_ext(void)
 {
         VALUE mKgio = rb_define_module("Kgio");
-        VALUE mPipeMethods, mSocketMethods;
+        VALUE mPipeMethods;
         VALUE cUNIXServer, cTCPServer, cUNIXSocket, cTCPSocket;
 
         rb_require("socket");
@@ -967,6 +993,7 @@ void Init_kgio_ext(void)
          */
         cSocket = rb_const_get(rb_cObject, rb_intern("Socket"));
         cSocket = rb_define_class_under(mKgio, "Socket", cSocket);
+        cClientSocket = cSocket;
 
         localhost = rb_str_new2("127.0.0.1");
 
@@ -1001,6 +1028,8 @@ void Init_kgio_ext(void)
         rb_define_singleton_method(mKgio, "accept_cloexec=", set_cloexec, 1);
         rb_define_singleton_method(mKgio, "accept_nonblock?", get_nonblock, 0);
         rb_define_singleton_method(mKgio, "accept_nonblock=", set_nonblock, 1);
+        rb_define_singleton_method(mKgio, "accept_class=", set_accepted, 1);
+        rb_define_singleton_method(mKgio, "accept_class", get_accepted, 0);
 
         /*
          * Document-module: Kgio::PipeMethods
diff --git a/test/test_accept_class.rb b/test/test_accept_class.rb
new file mode 100644
index 0000000..3b5d343
--- /dev/null
+++ b/test/test_accept_class.rb
@@ -0,0 +1,52 @@
+require 'test/unit'
+require 'io/nonblock'
+$-w = true
+require 'kgio'
+
+class TestAcceptClass < Test::Unit::TestCase
+  def setup
+    assert_equal Kgio::Socket, Kgio.accept_class
+  end
+
+  def teardown
+    assert_nothing_raised { Kgio.accept_class = nil }
+    assert_equal Kgio::Socket, Kgio.accept_class
+  end
+
+  def test_tcp_socket
+    assert_nothing_raised { Kgio.accept_class = Kgio::TCPSocket }
+    assert_equal Kgio::TCPSocket, Kgio.accept_class
+  end
+
+  def test_invalid
+    assert_raises(TypeError) { Kgio.accept_class = TCPSocket }
+    assert_equal Kgio::Socket, Kgio.accept_class
+  end
+
+  def test_accepted_class
+    @host = ENV["TEST_HOST"] || '127.0.0.1'
+    @srv = Kgio::TCPServer.new(@host, 0)
+    @port = @srv.addr[1]
+
+    assert_nothing_raised { Kgio.accept_class = Kgio::TCPSocket }
+    client = TCPSocket.new(@host, @port)
+    assert_instance_of Kgio::TCPSocket, @srv.kgio_accept
+    client = TCPSocket.new(@host, @port)
+    IO.select([@srv])
+    assert_instance_of Kgio::TCPSocket, @srv.kgio_tryaccept
+
+    assert_nothing_raised { Kgio.accept_class = nil }
+    client = TCPSocket.new(@host, @port)
+    assert_instance_of Kgio::Socket, @srv.kgio_accept
+    client = TCPSocket.new(@host, @port)
+    IO.select([@srv])
+    assert_instance_of Kgio::Socket, @srv.kgio_tryaccept
+
+    assert_nothing_raised { Kgio.accept_class = Kgio::UNIXSocket }
+    client = TCPSocket.new(@host, @port)
+    assert_instance_of Kgio::UNIXSocket, @srv.kgio_accept
+    client = TCPSocket.new(@host, @port)
+    IO.select([@srv])
+    assert_instance_of Kgio::UNIXSocket, @srv.kgio_tryaccept
+  end
+end