about summary refs log tree commit homepage
diff options
context:
space:
mode:
authorEric Wong <normalperson@yhbt.net>2011-03-20 16:27:55 -0700
committerEric Wong <normalperson@yhbt.net>2011-03-20 16:27:55 -0700
commit59782a15d0be87130934cbecb34ed639be68b44a (patch)
treed8f8a0b5cccf979bd1ea5a4853a1937e95ec807e
parent0ef079617b7d71cc26574247918c4a3e18454b21 (diff)
downloadkgio-59782a15d0be87130934cbecb34ed639be68b44a.tar.gz
This allows applications to not rely on global accept4_flags
-rw-r--r--ext/kgio/accept.c162
-rw-r--r--test/lib_server_accept.rb19
2 files changed, 111 insertions, 70 deletions
diff --git a/ext/kgio/accept.c b/ext/kgio/accept.c
index b8efba5..c16ba36 100644
--- a/ext/kgio/accept.c
+++ b/ext/kgio/accept.c
@@ -19,6 +19,8 @@ struct accept_args {
         int flags;
         struct sockaddr *addr;
         socklen_t *addrlen;
+        VALUE accept_io;
+        VALUE accepted_class;
 };
 
 /*
@@ -54,6 +56,10 @@ static VALUE get_accepted(VALUE klass)
         return cClientSocket;
 }
 
+/*
+ * accept() wrapper that'll fall back on accept() if we were built on
+ * a system with accept4() but run on a system without accept4()
+ */
 static VALUE xaccept(void *ptr)
 {
         struct accept_args *a = ptr;
@@ -125,16 +131,56 @@ static int thread_accept(struct accept_args *a, int force_nonblock)
 #define set_blocking_or_block(fd) (void)rb_io_wait_readable(fd)
 #endif /* ! HAVE_RB_THREAD_BLOCKING_REGION */
 
-static VALUE acceptor(int argc, const VALUE *argv)
+static void
+prepare_accept(struct accept_args *a, VALUE self, int argc, const VALUE *argv)
 {
-        if (argc == 0)
-                return cClientSocket; /* default, legacy behavior */
-        else if (argc == 1)
-                return argv[0];
+        a->fd = my_fileno(self);
+        a->accept_io = self;
+
+        switch (argc) {
+        case 2:
+                a->flags = NUM2INT(argv[1]);
+                a->accepted_class = NIL_P(argv[0]) ? cClientSocket : argv[0];
+                return;
+        case 0: /* default, legacy behavior */
+                a->flags = accept4_flags;
+                a->accepted_class = cClientSocket;
+                return;
+        case 1:
+                a->flags = accept4_flags;
+                a->accepted_class = NIL_P(argv[0]) ? cClientSocket : argv[0];
+                return;
+        }
 
         rb_raise(rb_eArgError, "wrong number of arguments (%d for 1)", argc);
 }
 
+static VALUE in_addr_set(VALUE io, struct sockaddr_storage *addr, socklen_t len)
+{
+        VALUE host;
+        int host_len, rc;
+        char *host_ptr;
+
+        switch (addr->ss_family) {
+        case AF_INET:
+                host_len = (long)INET_ADDRSTRLEN;
+                break;
+        case AF_INET6:
+                host_len = (long)INET6_ADDRSTRLEN;
+                break;
+        default:
+                rb_raise(rb_eRuntimeError, "unsupported address family");
+        }
+        host = rb_str_new(NULL, host_len);
+        host_ptr = RSTRING_PTR(host);
+        rc = getnameinfo((struct sockaddr *)addr, len,
+                         host_ptr, host_len, NULL, 0, NI_NUMERICHOST);
+        if (rc != 0)
+                rb_raise(rb_eRuntimeError, "getnameinfo: %s", gai_strerror(rc));
+        rb_str_set_len(host, strlen(host_ptr));
+        return rb_ivar_set(io, iv_kgio_addr, host);
+}
+
 #if defined(__linux__)
 #  define post_accept kgio_autopush_accept
 #else
@@ -142,25 +188,19 @@ static VALUE acceptor(int argc, const VALUE *argv)
 #endif
 
 static VALUE
-my_accept(VALUE accept_io, VALUE klass,
-          struct sockaddr *addr, socklen_t *addrlen, int nonblock)
+my_accept(struct accept_args *a, int force_nonblock)
 {
-        int client;
+        int client_fd;
         VALUE client_io;
-        struct accept_args a;
 
-        a.fd = my_fileno(accept_io);
-        a.addr = addr;
-        a.addrlen = addrlen;
-        a.flags = accept4_flags;
 retry:
-        client = thread_accept(&a, nonblock);
-        if (client == -1) {
+        client_fd = thread_accept(a, force_nonblock);
+        if (client_fd == -1) {
                 switch (errno) {
                 case EAGAIN:
-                        if (nonblock)
+                        if (force_nonblock)
                                 return Qnil;
-                        set_blocking_or_block(a.fd);
+                        set_blocking_or_block(a->fd);
 #ifdef ECONNABORTED
                 case ECONNABORTED:
 #endif /* ECONNABORTED */
@@ -177,45 +217,25 @@ retry:
 #endif /* ENOBUFS */
                         errno = 0;
                         rb_gc();
-                        client = thread_accept(&a, nonblock);
+                        client_fd = thread_accept(a, force_nonblock);
                 }
-                if (client == -1) {
+                if (client_fd == -1) {
                         if (errno == EINTR)
                                 goto retry;
                         rb_sys_fail("accept");
                 }
         }
-        client_io = sock_for_fd(klass, client);
-        post_accept(accept_io, client_io);
+        client_io = sock_for_fd(a->accepted_class, client_fd);
+        post_accept(a->accept_io, client_io);
+
+        if (a->addr)
+                in_addr_set(client_io,
+                            (struct sockaddr_storage *)a->addr, *a->addrlen);
+        else
+                rb_ivar_set(client_io, iv_kgio_addr, localhost);
         return client_io;
 }
 
-static VALUE in_addr_set(VALUE io, struct sockaddr_storage *addr, socklen_t len)
-{
-        VALUE host;
-        int host_len, rc;
-        char *host_ptr;
-
-        switch (addr->ss_family) {
-        case AF_INET:
-                host_len = (long)INET_ADDRSTRLEN;
-                break;
-        case AF_INET6:
-                host_len = (long)INET6_ADDRSTRLEN;
-                break;
-        default:
-                rb_raise(rb_eRuntimeError, "unsupported address family");
-        }
-        host = rb_str_new(NULL, host_len);
-        host_ptr = RSTRING_PTR(host);
-        rc = getnameinfo((struct sockaddr *)addr, len,
-                         host_ptr, host_len, NULL, 0, NI_NUMERICHOST);
-        if (rc != 0)
-                rb_raise(rb_eRuntimeError, "getnameinfo: %s", gai_strerror(rc));
-        rb_str_set_len(host, strlen(host_ptr));
-        return rb_ivar_set(io, iv_kgio_addr, host);
-}
-
 /*
  * call-seq:
  *
@@ -253,16 +273,16 @@ static VALUE addr_bang(VALUE io)
  *
  *      server.kgio_tryaccept(MySocket) -> MySocket
  */
-static VALUE tcp_tryaccept(int argc, VALUE *argv, VALUE io)
+static VALUE tcp_tryaccept(int argc, VALUE *argv, VALUE self)
 {
         struct sockaddr_storage addr;
         socklen_t addrlen = sizeof(struct sockaddr_storage);
-        VALUE klass = acceptor(argc, argv);
-        VALUE rv = my_accept(io, klass, (struct sockaddr *)&addr, &addrlen, 1);
+        struct accept_args a;
 
-        if (!NIL_P(rv))
-                in_addr_set(rv, &addr, addrlen);
-        return rv;
+        a.addr = (struct sockaddr *)&addr;
+        a.addrlen = &addrlen;
+        prepare_accept(&a, self, argc, argv);
+        return my_accept(&a, 1);
 }
 
 /*
@@ -283,15 +303,16 @@ static VALUE tcp_tryaccept(int argc, VALUE *argv, VALUE io)
  *
  *      server.kgio_accept(MySocket) -> MySocket
  */
-static VALUE tcp_accept(int argc, VALUE *argv, VALUE io)
+static VALUE tcp_accept(int argc, VALUE *argv, VALUE self)
 {
         struct sockaddr_storage addr;
         socklen_t addrlen = sizeof(struct sockaddr_storage);
-        VALUE klass = acceptor(argc, argv);
-        VALUE rv = my_accept(io, klass, (struct sockaddr *)&addr, &addrlen, 0);
+        struct accept_args a;
 
-        in_addr_set(rv, &addr, addrlen);
-        return rv;
+        a.addr = (struct sockaddr *)&addr;
+        a.addrlen = &addrlen;
+        prepare_accept(&a, self, argc, argv);
+        return my_accept(&a, 0);
 }
 
 /*
@@ -311,14 +332,14 @@ static VALUE tcp_accept(int argc, VALUE *argv, VALUE io)
  *
  *      server.kgio_tryaccept(MySocket) -> MySocket
  */
-static VALUE unix_tryaccept(int argc, VALUE *argv, VALUE io)
+static VALUE unix_tryaccept(int argc, VALUE *argv, VALUE self)
 {
-        VALUE klass = acceptor(argc, argv);
-        VALUE rv = my_accept(io, klass, NULL, NULL, 1);
+        struct accept_args a;
 
-        if (!NIL_P(rv))
-                rb_ivar_set(rv, iv_kgio_addr, localhost);
-        return rv;
+        a.addr = NULL;
+        a.addrlen = NULL;
+        prepare_accept(&a, self, argc, argv);
+        return my_accept(&a, 1);
 }
 
 /*
@@ -339,13 +360,14 @@ static VALUE unix_tryaccept(int argc, VALUE *argv, VALUE io)
  *
  *      server.kgio_accept(MySocket) -> MySocket
  */
-static VALUE unix_accept(int argc, VALUE *argv, VALUE io)
+static VALUE unix_accept(int argc, VALUE *argv, VALUE self)
 {
-        VALUE klass = acceptor(argc, argv);
-        VALUE rv = my_accept(io, klass, NULL, NULL, 0);
+        struct accept_args a;
 
-        rb_ivar_set(rv, iv_kgio_addr, localhost);
-        return rv;
+        a.addr = NULL;
+        a.addrlen = NULL;
+        prepare_accept(&a, self, argc, argv);
+        return my_accept(&a, 0);
 }
 
 /*
@@ -386,7 +408,7 @@ static VALUE get_nonblock(VALUE mod)
  * TCPServer#kgio_tryaccept,
  * UNIXServer#kgio_accept,
  * and UNIXServer#kgio_tryaccept
- * are created with the FD_CLOEXEC file descriptor flag.
+ * default to being created with the FD_CLOEXEC file descriptor flag.
  *
  * This is on by default, as there is little reason to deal to enable
  * it for client sockets on a socket server.
diff --git a/test/lib_server_accept.rb b/test/lib_server_accept.rb
index 1e6bf24..6ea461b 100644
--- a/test/lib_server_accept.rb
+++ b/test/lib_server_accept.rb
@@ -1,4 +1,5 @@
 require 'test/unit'
+require 'fcntl'
 require 'io/nonblock'
 $-w = true
 require 'kgio'
@@ -19,6 +20,24 @@ module LibServerAccept
     assert_equal @host, b.kgio_addr
   end
 
+  def test_tryaccept_flags
+    a = client_connect
+    IO.select([@srv])
+    b = @srv.kgio_tryaccept nil, 0
+    assert_kind_of Kgio::Socket, b
+    assert_equal false, b.nonblock?
+    assert_equal 0, b.fcntl(Fcntl::F_GETFD)
+  end
+
+  def test_blocking_accept_flags
+    a = client_connect
+    IO.select([@srv])
+    b = @srv.kgio_accept nil, 0
+    assert_kind_of Kgio::Socket, b
+    assert_equal false, b.nonblock?
+    assert_equal 0, b.fcntl(Fcntl::F_GETFD)
+  end
+
   def test_tryaccept_fail
     assert_equal nil, @srv.kgio_tryaccept
   end