about summary refs log tree commit homepage
diff options
context:
space:
mode:
authorEric Wong <e@yhbt.net>2010-12-25 22:44:53 +0000
committerEric Wong <e@yhbt.net>2010-12-25 22:46:55 +0000
commitb859c4a12905cbd71d19cde2aaa9f88ec0374cc5 (patch)
tree9c6ae8e68de32b255c628f9ea2cfaec5b18392af
parentef069ece624906b3946248421620d8458bcef605 (diff)
downloadkgio-b859c4a12905cbd71d19cde2aaa9f88ec0374cc5.tar.gz
This is preferred as we no longer have to rely on a global
constant.
-rw-r--r--ext/kgio/accept.c72
-rw-r--r--test/test_accept_class.rb10
2 files changed, 68 insertions, 14 deletions
diff --git a/ext/kgio/accept.c b/ext/kgio/accept.c
index 00a6563..4762676 100644
--- a/ext/kgio/accept.c
+++ b/ext/kgio/accept.c
@@ -20,6 +20,11 @@ struct accept_args {
         socklen_t *addrlen;
 };
 
+/*
+ * Sets the default class for newly accepted sockets.  This is
+ * legacy behavior, kgio_accept and kgio_tryaccept now take optional
+ * class arguments to override this value.
+ */
 static VALUE set_accepted(VALUE klass, VALUE aclass)
 {
         VALUE tmp;
@@ -39,6 +44,10 @@ static VALUE set_accepted(VALUE klass, VALUE aclass)
         return aclass;
 }
 
+/*
+ * Returns the default class for newly accepted sockets when kgio_accept
+ * or kgio_tryaccept are not passed arguments
+ */
 static VALUE get_accepted(VALUE klass)
 {
         return cClientSocket;
@@ -107,8 +116,19 @@ static int thread_accept(struct accept_args *a, int force_nonblock)
 #define set_blocking_or_block(fd) (void)rb_io_wait_readable(fd)
 #endif /* ! HAVE_RB_THREAD_BLOCKING_REGION */
 
+static VALUE acceptor(int argc, const VALUE *argv)
+{
+        if (argc == 0)
+                return cClientSocket; /* default, legacy behavior */
+        else if (argc == 1)
+                return argv[0];
+
+        rb_raise(rb_eArgError, "wrong number of arguments (%d for 1)", argc);
+}
+
 static VALUE
-my_accept(VALUE io, struct sockaddr *addr, socklen_t *addrlen, int nonblock)
+my_accept(VALUE io, VALUE klass,
+          struct sockaddr *addr, socklen_t *addrlen, int nonblock)
 {
         int client;
         struct accept_args a;
@@ -148,7 +168,7 @@ retry:
                         rb_sys_fail("accept");
                 }
         }
-        return sock_for_fd(cClientSocket, client);
+        return sock_for_fd(klass, client);
 }
 
 static void in_addr_set(VALUE io, struct sockaddr_in *addr)
@@ -175,12 +195,18 @@ static void in_addr_set(VALUE io, struct sockaddr_in *addr)
  * connected client on success.
  *
  * Returns nil on EAGAIN, and raises on other errors.
+ *
+ * An optional class argument may be specified to override the
+ * Kgio::Socket-class return value:
+ *
+ *      server.kgio_accept(MySocket) -> MySocket
  */
-static VALUE tcp_tryaccept(VALUE io)
+static VALUE tcp_tryaccept(int argc, VALUE *argv, VALUE io)
 {
         struct sockaddr_in addr;
         socklen_t addrlen = sizeof(struct sockaddr_in);
-        VALUE rv = my_accept(io, (struct sockaddr *)&addr, &addrlen, 1);
+        VALUE klass = acceptor(argc, argv);
+        VALUE rv = my_accept(io, klass, (struct sockaddr *)&addr, &addrlen, 1);
 
         if (!NIL_P(rv))
                 in_addr_set(rv, &addr);
@@ -199,12 +225,18 @@ static VALUE tcp_tryaccept(VALUE io)
  *
  * On Ruby implementations using native threads, this can use a blocking
  * accept(2) (or accept4(2)) system call to avoid thundering herds.
+ *
+ * An optional class argument may be specified to override the
+ * Kgio::Socket-class return value:
+ *
+ *      server.kgio_accept(MySocket) -> MySocket
  */
-static VALUE tcp_accept(VALUE io)
+static VALUE tcp_accept(int argc, VALUE *argv, VALUE io)
 {
         struct sockaddr_in addr;
         socklen_t addrlen = sizeof(struct sockaddr_in);
-        VALUE rv = my_accept(io, (struct sockaddr *)&addr, &addrlen, 0);
+        VALUE klass = acceptor(argc, argv);
+        VALUE rv = my_accept(io, klass, (struct sockaddr *)&addr, &addrlen, 0);
 
         in_addr_set(rv, &addr);
         return rv;
@@ -221,10 +253,16 @@ static VALUE tcp_accept(VALUE io)
  * Kgio::LOCALHOST) on success.
  *
  * Returns nil on EAGAIN, and raises on other errors.
+ *
+ * An optional class argument may be specified to override the
+ * Kgio::Socket-class return value:
+ *
+ *      server.kgio_tryaccept(MySocket) -> MySocket
  */
-static VALUE unix_tryaccept(VALUE io)
+static VALUE unix_tryaccept(int argc, VALUE *argv, VALUE io)
 {
-        VALUE rv = my_accept(io, NULL, NULL, 1);
+        VALUE klass = acceptor(argc, argv);
+        VALUE rv = my_accept(io, klass, NULL, NULL, 1);
 
         if (!NIL_P(rv))
                 rb_ivar_set(rv, iv_kgio_addr, localhost);
@@ -243,10 +281,16 @@ static VALUE unix_tryaccept(VALUE io)
  *
  * On Ruby implementations using native threads, this can use a blocking
  * accept(2) (or accept4(2)) system call to avoid thundering herds.
+ *
+ * An optional class argument may be specified to override the
+ * Kgio::Socket-class return value:
+ *
+ *      server.kgio_accept(MySocket) -> MySocket
  */
-static VALUE unix_accept(VALUE io)
+static VALUE unix_accept(int argc, VALUE *argv, VALUE io)
 {
-        VALUE rv = my_accept(io, NULL, NULL, 0);
+        VALUE klass = acceptor(argc, argv);
+        VALUE rv = my_accept(io, klass, NULL, NULL, 0);
 
         rb_ivar_set(rv, iv_kgio_addr, localhost);
         return rv;
@@ -360,13 +404,13 @@ void init_kgio_accept(void)
 
         cUNIXServer = rb_const_get(rb_cObject, rb_intern("UNIXServer"));
         cUNIXServer = rb_define_class_under(mKgio, "UNIXServer", cUNIXServer);
-        rb_define_method(cUNIXServer, "kgio_tryaccept", unix_tryaccept, 0);
-        rb_define_method(cUNIXServer, "kgio_accept", unix_accept, 0);
+        rb_define_method(cUNIXServer, "kgio_tryaccept", unix_tryaccept, -1);
+        rb_define_method(cUNIXServer, "kgio_accept", unix_accept, -1);
 
         cTCPServer = rb_const_get(rb_cObject, rb_intern("TCPServer"));
         cTCPServer = rb_define_class_under(mKgio, "TCPServer", cTCPServer);
-        rb_define_method(cTCPServer, "kgio_tryaccept", tcp_tryaccept, 0);
-        rb_define_method(cTCPServer, "kgio_accept", tcp_accept, 0);
+        rb_define_method(cTCPServer, "kgio_tryaccept", tcp_tryaccept, -1);
+        rb_define_method(cTCPServer, "kgio_accept", tcp_accept, -1);
         init_sock_for_fd();
         iv_kgio_addr = rb_intern("@kgio_addr");
 }
diff --git a/test/test_accept_class.rb b/test/test_accept_class.rb
index 3b5d343..cf59a2f 100644
--- a/test/test_accept_class.rb
+++ b/test/test_accept_class.rb
@@ -4,6 +4,9 @@ $-w = true
 require 'kgio'
 
 class TestAcceptClass < Test::Unit::TestCase
+  class FooSocket < Kgio::Socket
+  end
+
   def setup
     assert_equal Kgio::Socket, Kgio.accept_class
   end
@@ -48,5 +51,12 @@ class TestAcceptClass < Test::Unit::TestCase
     client = TCPSocket.new(@host, @port)
     IO.select([@srv])
     assert_instance_of Kgio::UNIXSocket, @srv.kgio_tryaccept
+
+    client = TCPSocket.new(@host, @port)
+    assert_instance_of FooSocket, @srv.kgio_accept(FooSocket)
+
+    client = TCPSocket.new(@host, @port)
+    IO.select([@srv])
+    assert_instance_of FooSocket, @srv.kgio_tryaccept(FooSocket)
   end
 end