about summary refs log tree commit homepage
diff options
context:
space:
mode:
authorEric Wong <e@yhbt.net>2011-02-02 21:33:28 +0000
committerEric Wong <normalperson@yhbt.net>2011-02-02 13:42:33 -0800
commit17abe6ce8f01810022b948c71de0026b4ac89597 (patch)
treeca4cee3b39ecba1fcf0a51336f2399717c96750b
parent879f2f0ee9133f34ec3e24141bdb4936e3408d3a (diff)
downloadkgio-17abe6ce8f01810022b948c71de0026b4ac89597.tar.gz
No extra #ifdefs, we just won't support old systems without
getaddrinfo() and friends anymore.  I doubt anybody still has
them...
-rw-r--r--ext/kgio/accept.c43
-rw-r--r--ext/kgio/connect.c43
-rw-r--r--ext/kgio/extconf.rb7
-rw-r--r--ext/kgio/kgio.h1
-rw-r--r--test/lib_read_write.rb4
-rw-r--r--test/test_no_dns_on_tcp_connect.rb13
-rw-r--r--test/test_tcp6_client_read_server_write.rb23
7 files changed, 103 insertions, 31 deletions
diff --git a/ext/kgio/accept.c b/ext/kgio/accept.c
index afb44a2..f61e820 100644
--- a/ext/kgio/accept.c
+++ b/ext/kgio/accept.c
@@ -187,16 +187,29 @@ retry:
         return client_io;
 }
 
-static void in_addr_set(VALUE io, struct sockaddr_in *addr)
+static void in_addr_set(VALUE io, struct sockaddr_storage *addr, socklen_t len)
 {
-        VALUE host = rb_str_new(0, INET_ADDRSTRLEN);
-        socklen_t addrlen = (socklen_t)INET_ADDRSTRLEN;
-        const char *name;
-
-        name = inet_ntop(AF_INET, &addr->sin_addr, RSTRING_PTR(host), addrlen);
-        if (name == NULL)
-                rb_sys_fail("inet_ntop");
-        rb_str_set_len(host, strlen(name));
+        VALUE host;
+        int host_len, rc;
+        char *host_ptr;
+
+        switch (addr->ss_family) {
+        case AF_INET:
+                host_len = (long)INET_ADDRSTRLEN;
+                break;
+        case AF_INET6:
+                host_len = (long)INET6_ADDRSTRLEN;
+                break;
+        default:
+                rb_raise(rb_eRuntimeError, "unsupported address family");
+        }
+        host = rb_str_new(NULL, host_len);
+        host_ptr = RSTRING_PTR(host);
+        rc = getnameinfo((struct sockaddr *)addr, len,
+                         host_ptr, host_len, NULL, 0, NI_NUMERICHOST);
+        if (rc != 0)
+                rb_raise(rb_eRuntimeError, "getnameinfo: %s", gai_strerror(rc));
+        rb_str_set_len(host, strlen(host_ptr));
         rb_ivar_set(io, iv_kgio_addr, host);
 }
 
@@ -219,13 +232,13 @@ static void in_addr_set(VALUE io, struct sockaddr_in *addr)
  */
 static VALUE tcp_tryaccept(int argc, VALUE *argv, VALUE io)
 {
-        struct sockaddr_in addr;
-        socklen_t addrlen = sizeof(struct sockaddr_in);
+        struct sockaddr_storage addr;
+        socklen_t addrlen = sizeof(struct sockaddr_storage);
         VALUE klass = acceptor(argc, argv);
         VALUE rv = my_accept(io, klass, (struct sockaddr *)&addr, &addrlen, 1);
 
         if (!NIL_P(rv))
-                in_addr_set(rv, &addr);
+                in_addr_set(rv, &addr, addrlen);
         return rv;
 }
 
@@ -249,12 +262,12 @@ static VALUE tcp_tryaccept(int argc, VALUE *argv, VALUE io)
  */
 static VALUE tcp_accept(int argc, VALUE *argv, VALUE io)
 {
-        struct sockaddr_in addr;
-        socklen_t addrlen = sizeof(struct sockaddr_in);
+        struct sockaddr_storage addr;
+        socklen_t addrlen = sizeof(struct sockaddr_storage);
         VALUE klass = acceptor(argc, argv);
         VALUE rv = my_accept(io, klass, (struct sockaddr *)&addr, &addrlen, 0);
 
-        in_addr_set(rv, &addr);
+        in_addr_set(rv, &addr, addrlen);
         return rv;
 }
 
diff --git a/ext/kgio/connect.c b/ext/kgio/connect.c
index 9642429..0c8a9b2 100644
--- a/ext/kgio/connect.c
+++ b/ext/kgio/connect.c
@@ -57,20 +57,37 @@ my_connect(VALUE klass, int io_wait, int domain, void *addr, socklen_t addrlen)
 
 static VALUE tcp_connect(VALUE klass, VALUE ip, VALUE port, int io_wait)
 {
-        struct sockaddr_in addr = { 0 };
+        struct addrinfo hints;
+        struct sockaddr_storage addr;
+        int rc;
+        struct addrinfo *res;
+        VALUE sock;
+        const char *ipname = StringValuePtr(ip);
+        char ipport[6];
+        unsigned uport = FIX2UINT(port);
 
-        addr.sin_family = AF_INET;
-        addr.sin_port = htons((unsigned short)NUM2INT(port));
+        rc = snprintf(ipport, sizeof(ipport), "%u", uport);
+        if (rc >= (int)sizeof(ipport) || rc <= 0)
+                rb_raise(rb_eArgError, "invalid TCP port: %u", uport);
+        hints.ai_family = AF_UNSPEC;
+        hints.ai_socktype = SOCK_STREAM;
+        hints.ai_protocol = IPPROTO_TCP;
+        /* disallow non-deterministic DNS lookups */
+        hints.ai_flags = AI_NUMERICHOST | AI_NUMERICSERV;
 
-        switch (inet_pton(AF_INET, StringValuePtr(ip), &addr.sin_addr)) {
-        case 1:
-                return my_connect(klass, io_wait, PF_INET, &addr, sizeof(addr));
-        case -1:
-                rb_sys_fail("inet_pton");
-        }
-        rb_raise(rb_eArgError, "invalid address: %s", StringValuePtr(ip));
+        rc = getaddrinfo(ipname, ipport, &hints, &res);
+        if (rc != 0)
+                rb_raise(rb_eArgError, "getaddrinfo(%s:%s): %s",
+                         ipname, ipport, gai_strerror(rc));
+
+        /* copy needed data and free ASAP to avoid needing rb_ensure */
+        hints.ai_family = res->ai_family;
+        hints.ai_addrlen = res->ai_addrlen;
+        memcpy(&addr, res->ai_addr, res->ai_addrlen);
+        freeaddrinfo(res);
 
-        return Qnil;
+        return my_connect(klass, io_wait, hints.ai_family,
+                          &addr, hints.ai_addrlen);
 }
 
 /*
@@ -173,12 +190,10 @@ static VALUE stream_connect(VALUE klass, VALUE addr, int io_wait)
         } else {
                 rb_raise(rb_eTypeError, "invalid address");
         }
-        switch (((struct sockaddr_in *)(sockaddr))->sin_family) {
+        switch (((struct sockaddr_storage *)(sockaddr))->ss_family) {
         case AF_UNIX: domain = PF_UNIX; break;
         case AF_INET: domain = PF_INET; break;
-#ifdef AF_INET6 /* IPv6 support incomplete */
         case AF_INET6: domain = PF_INET6; break;
-#endif /* AF_INET6 */
         default:
                 rb_raise(rb_eArgError, "invalid address family");
         }
diff --git a/ext/kgio/extconf.rb b/ext/kgio/extconf.rb
index e7220a4..3758e92 100644
--- a/ext/kgio/extconf.rb
+++ b/ext/kgio/extconf.rb
@@ -1,6 +1,13 @@
 require 'mkmf'
 $CPPFLAGS << ' -D_GNU_SOURCE'
+$CPPFLAGS << ' -DPOSIX_C_SOURCE=1'
 
+have_func("getaddrinfo", %w(sys/types.h sys/socket.h netdb.h)) or
+  abort "getaddrinfo required"
+have_func("getnameinfo", %w(sys/socket.h netdb.h)) or
+  abort "getnameinfo required"
+have_type("struct sockaddr_storage", %w(sys/types.h sys/socket.h)) or
+  abort "struct sockaddr_storage required"
 have_func('accept4', %w(sys/socket.h))
 if have_header('ruby/io.h')
   rubyio = %w(ruby.h ruby/io.h)
diff --git a/ext/kgio/kgio.h b/ext/kgio/kgio.h
index 3711061..244bae5 100644
--- a/ext/kgio/kgio.h
+++ b/ext/kgio/kgio.h
@@ -16,6 +16,7 @@
 #include <unistd.h>
 #include <arpa/inet.h>
 #include <assert.h>
+#include <netdb.h>
 
 #include "ancient_ruby.h"
 #include "nonblock.h"
diff --git a/test/lib_read_write.rb b/test/lib_read_write.rb
index b3c6f17..593a9e9 100644
--- a/test/lib_read_write.rb
+++ b/test/lib_read_write.rb
@@ -10,8 +10,8 @@ module LibReadWriteTest
 
   def teardown
     assert_nothing_raised do
-      @rd.close unless @rd.closed?
-      @wr.close unless @wr.closed?
+      @rd.close if defined?(@rd) && ! @rd.closed?
+      @wr.close if defined?(@wr) && ! @wr.closed?
     end
   end
 
diff --git a/test/test_no_dns_on_tcp_connect.rb b/test/test_no_dns_on_tcp_connect.rb
new file mode 100644
index 0000000..d296826
--- /dev/null
+++ b/test/test_no_dns_on_tcp_connect.rb
@@ -0,0 +1,13 @@
+require 'test/unit'
+$-w = true
+require 'kgio'
+
+class TestNoDnsOnTcpConnect < Test::Unit::TestCase
+  def test_connect_remote
+    assert_raises(ArgumentError) { Kgio::TCPSocket.new("example.com", 666) }
+  end
+
+  def test_connect_localhost
+    assert_raises(ArgumentError) { Kgio::TCPSocket.new("localhost", 666) }
+  end
+end
diff --git a/test/test_tcp6_client_read_server_write.rb b/test/test_tcp6_client_read_server_write.rb
new file mode 100644
index 0000000..9438fcc
--- /dev/null
+++ b/test/test_tcp6_client_read_server_write.rb
@@ -0,0 +1,23 @@
+require './test/lib_read_write'
+
+begin
+  tmp = TCPServer.new(ENV["TEST_HOST6"] || '::1')
+  ipv6_enabled = true
+rescue => e
+  warn "skipping IPv6 tests, host does not seem to be IPv6 enabled:"
+  warn "  #{e}"
+  ipv6_enabled = false
+end
+
+class TestTcp6ClientReadServerWrite < Test::Unit::TestCase
+  def setup
+    @host = ENV["TEST_HOST6"] || '::1'
+    @srv = Kgio::TCPServer.new(@host, 0)
+    @port = @srv.addr[1]
+    @wr = Kgio::TCPSocket.new(@host, @port)
+    @rd = @srv.kgio_accept
+    assert_equal Socket.unpack_sockaddr_in(@rd.getpeername)[-1], @rd.kgio_addr
+  end
+
+  include LibReadWriteTest
+end if ipv6_enabled