about summary refs log tree commit homepage
path: root/ext/raindrops/linux_inet_diag.c
diff options
context:
space:
mode:
Diffstat (limited to 'ext/raindrops/linux_inet_diag.c')
-rw-r--r-- ext/raindrops/linux_inet_diag.c 169
1 file changed, 90 insertions, 79 deletions
diff --git a/ext/raindrops/linux_inet_diag.c b/ext/raindrops/linux_inet_diag.c
index cabd427..79f24bb 100644
--- a/ext/raindrops/linux_inet_diag.c
+++ b/ext/raindrops/linux_inet_diag.c
@@ -1,6 +1,5 @@
 #include <ruby.h>
 #include <stdarg.h>
-#include <ruby/st.h>
 #include "my_fileno.h"
 #ifdef __linux__
 
@@ -54,12 +53,23 @@ struct listen_stats {
         uint32_t listener_p;
 };
 
+/* override khashl.h defaults, these run w/o GVL */
+#define kcalloc(N,Z) xcalloc(N,Z)
+#define kmalloc(Z) xmalloc(Z)
+#define krealloc(P,Z) abort() /* never called, we use ruby_xrealloc2 */
+#define kfree(P) xfree(P)
+
+#include "khashl.h"
+KHASHL_CMAP_INIT(KH_LOCAL, addr2stats /* type */, a2s /* pfx */,
+                char * /* key */, struct listen_stats * /* val */,
+                kh_hash_str, kh_eq_str)
+
 #define OPLEN (sizeof(struct inet_diag_bc_op) + \
                sizeof(struct inet_diag_hostcond) + \
                sizeof(struct sockaddr_storage))
 
 struct nogvl_args {
-        st_table *table;
+        addr2stats *a2s;
         struct iovec iov[3]; /* last iov holds inet_diag bytecode */
         struct listen_stats stats;
         int fd;
@@ -106,14 +116,6 @@ static VALUE rb_listen_stats(struct listen_stats *stats)
         return rb_struct_new(cListenStats, active, queued);
 }
 
-static int st_free_data(st_data_t key, st_data_t value, st_data_t ignored)
-{
-        xfree((void *)key);
-        xfree((void *)value);
-
-        return ST_DELETE;
-}
-
 /*
  * call-seq:
  *      remove_scope_id(ip_address)
@@ -151,36 +153,6 @@ static VALUE remove_scope_id(const char *addr)
         return rv;
 }
 
-static int st_to_hash(st_data_t key, st_data_t value, VALUE hash)
-{
-        struct listen_stats *stats = (struct listen_stats *)value;
-
-        if (stats->listener_p) {
-                VALUE k = remove_scope_id((const char *)key);
-                VALUE v = rb_listen_stats(stats);
-
-                OBJ_FREEZE(k);
-                rb_hash_aset(hash, k, v);
-        }
-        return st_free_data(key, value, 0);
-}
-
-static int st_AND_hash(st_data_t key, st_data_t value, VALUE hash)
-{
-        struct listen_stats *stats = (struct listen_stats *)value;
-
-        if (stats->listener_p) {
-                VALUE k = remove_scope_id((const char *)key);
-
-                if (rb_hash_lookup(hash, k) == Qtrue) {
-                        VALUE v = rb_listen_stats(stats);
-                        OBJ_FREEZE(k);
-                        rb_hash_aset(hash, k, v);
-                }
-        }
-        return st_free_data(key, value, 0);
-}
-
 static const char *addr_any(sa_family_t family)
 {
         static const char ipv4[] = "0.0.0.0";
@@ -209,32 +181,36 @@ static void bug_warn_nogvl(const char *fmt, ...)
         fflush(stderr);
 }
 
-static struct listen_stats *stats_for(st_table *table, struct inet_diag_msg *r)
+static struct listen_stats *stats_for(addr2stats *a2s, struct inet_diag_msg *r)
 {
         char *host, *key, *port, *old_key;
-        size_t alloca_len;
         struct listen_stats *stats;
         socklen_t hostlen;
         socklen_t portlen = (socklen_t)sizeof("65535");
-        int n;
+        int n, absent;
         const void *src = r->id.idiag_src;
+        char buf[INET6_ADDRSTRLEN];
+        size_t buf_len;
+        khint_t ki;
 
         switch (r->idiag_family) {
         case AF_INET: {
                 hostlen = INET_ADDRSTRLEN;
-                alloca_len = hostlen + portlen;
-                host = key = alloca(alloca_len);
+                buf_len = hostlen + portlen;
+                host = key = buf;
                 break;
                 }
         case AF_INET6: {
                 hostlen = INET6_ADDRSTRLEN;
-                alloca_len = 1 + hostlen + 1 + portlen;
-                key = alloca(alloca_len);
+                buf_len = 1 + hostlen + 1 + portlen;
+                key = buf;
                 host = key + 1;
                 break;
                 }
         default:
-                assert(0 && "unsupported address family, could that be IPv7?!");
+                fprintf(stderr, "unsupported .idiag_family: %u\n",
+                        (unsigned)r->idiag_family);
+                return NULL; /* can't raise w/o GVL */
         }
         if (!inet_ntop(r->idiag_family, src, host, hostlen)) {
                 bug_warn_nogvl("BUG: inet_ntop: %s\n", strerror(errno));
@@ -254,7 +230,8 @@ static struct listen_stats *stats_for(st_table *table, struct inet_diag_msg *r)
                 port = host + hostlen + 2;
                 break;
         default:
-                assert(0 && "unsupported address family, could that be IPv7?!");
+                assert(0 && "should never get here (returned above)");
+                abort();
         }
 
         n = snprintf(port, portlen, "%u", ntohs(r->id.idiag_sport));
@@ -263,21 +240,24 @@ static struct listen_stats *stats_for(st_table *table, struct inet_diag_msg *r)
                 *key = '\0';
         }
 
-        if (st_lookup(table, (st_data_t)key, (st_data_t *)&stats))
-                return stats;
+        ki = a2s_get(a2s, key);
+        if (ki < kh_end(a2s))
+                return kh_val(a2s, ki);
 
         old_key = key;
 
         if (r->idiag_state == TCP_ESTABLISHED) {
-                n = snprintf(key, alloca_len, "%s:%u",
+                n = snprintf(key, buf_len, "%s:%u",
                                  addr_any(r->idiag_family),
                                  ntohs(r->id.idiag_sport));
                 if (n <= 0) {
                         bug_warn_nogvl("BUG: snprintf: %d\n", n);
                         *key = '\0';
                 }
-                if (st_lookup(table, (st_data_t)key, (st_data_t *)&stats))
-                        return stats;
+
+                ki = a2s_get(a2s, key);
+                if (ki < kh_end(a2s))
+                        return kh_val(a2s, ki);
                 if (n <= 0) {
                         key = xmalloc(1);
                         *key = '\0';
@@ -292,21 +272,25 @@ static struct listen_stats *stats_for(st_table *table, struct inet_diag_msg *r)
                 memcpy(key, old_key, old_len);
         }
         stats = xcalloc(1, sizeof(struct listen_stats));
-        st_insert(table, (st_data_t)key, (st_data_t)stats);
+        ki = a2s_put(a2s, key, &absent); /* fails on OOM due to xrealloc */
+        assert(absent > 0 && "redundant put");
+        kh_val(a2s, ki) = stats;
         return stats;
 }
 
-static void table_incr_active(st_table *table, struct inet_diag_msg *r)
+static void table_incr_active(addr2stats *a2s, struct inet_diag_msg *r)
 {
-        struct listen_stats *stats = stats_for(table, r);
+        struct listen_stats *stats = stats_for(a2s, r);
+        if (!stats) return;
         ++stats->active;
 }
 
-static void table_set_queued(st_table *table, struct inet_diag_msg *r)
+static void table_set_queued(addr2stats *a2s, struct inet_diag_msg *r)
 {
-        struct listen_stats *stats = stats_for(table, r);
+        struct listen_stats *stats = stats_for(a2s, r);
+        if (!stats) return;
         stats->listener_p = 1;
-        stats->queued = r->idiag_rqueue;
+        stats->queued += r->idiag_rqueue;
 }
 
 /* inner loop of inet_diag, called for every socket returned by netlink */
@@ -320,15 +304,15 @@ static inline void r_acc(struct nogvl_args *args, struct inet_diag_msg *r)
         if (r->idiag_inode == 0)
                 return;
         if (r->idiag_state == TCP_ESTABLISHED) {
-                if (args->table)
-                        table_incr_active(args->table, r);
+                if (args->a2s)
+                        table_incr_active(args->a2s, r);
                 else
                         args->stats.active++;
         } else { /* if (r->idiag_state == TCP_LISTEN) */
-                if (args->table)
-                        table_set_queued(args->table, r);
+                if (args->a2s)
+                        table_set_queued(args->a2s, r);
                 else
-                        args->stats.queued = r->idiag_rqueue;
+                        args->stats.queued += r->idiag_rqueue;
         }
         /*
          * we wont get anything else because of the idiag_states filter
@@ -444,11 +428,18 @@ static VALUE diag(void *ptr)
         }
 out:
         /* prepare to raise, free memory before reacquiring GVL */
-        if (err && args->table) {
+        if (err && args->a2s) {
                 int save_errno = errno;
+                khint_t ki;
+
+                /* no kh_foreach* in khashl.h (unlike original khash.h) */
+                for (ki = 0; ki < kh_end(args->a2s); ki++) {
+                        if (!kh_exist(args->a2s, ki)) continue;
 
-                st_foreach(args->table, st_free_data, 0);
-                st_free_table(args->table);
+                        xfree(kh_key(args->a2s, ki));
+                        xfree(kh_val(args->a2s, ki));
+                }
+                a2s_destroy(args->a2s);
                 errno = save_errno;
         }
         return (VALUE)err;
@@ -564,7 +555,7 @@ static void gen_bytecode(struct iovec *iov, union any_addr *inet)
 
 /*
  * n.b. we may safely raise here because an error will cause diag()
- * to free args->table
+ * to free args->a2s
  */
 static void nl_errcheck(VALUE r)
 {
@@ -591,6 +582,7 @@ static VALUE tcp_stats(struct nogvl_args *args, VALUE addr)
         return rb_listen_stats(&args->stats);
 }
 
+/* part of the Ruby rb_hash_* API still relies on st_data_t... */
 static int drop_placeholders(st_data_t k, st_data_t v, st_data_t ign)
 {
         if ((VALUE)v == Qtrue)
@@ -615,7 +607,10 @@ static VALUE tcp_listener_stats(int argc, VALUE *argv, VALUE self)
 {
         VALUE rv = rb_hash_new();
         struct nogvl_args args;
-        VALUE addrs, sock;
+        VALUE addrs, sock, buf;
+        khint_t ki;
+        struct listen_stats *stats;
+        char *key;
 
         rb_scan_args(argc, argv, "02", &addrs, &sock);
 
@@ -624,17 +619,18 @@ static VALUE tcp_listener_stats(int argc, VALUE *argv, VALUE self)
          * buffer for recvmsg() later, we already checked for
          * OPLEN <= page_size at initialization
          */
+        buf = rb_str_buf_new(page_size);
         args.iov[2].iov_len = OPLEN;
-        args.iov[2].iov_base = alloca(page_size);
-        args.table = NULL;
-        if (NIL_P(sock))
-                sock = rb_funcall(cIDSock, id_new, 0);
+        args.iov[2].iov_base = RSTRING_PTR(buf);
+        args.a2s = NULL;
+        sock = NIL_P(sock) ? rb_funcall(cIDSock, id_new, 0)
+                        : rb_io_get_io(sock);
         args.fd = my_fileno(sock);
 
         switch (TYPE(addrs)) {
         case T_STRING:
                 rb_hash_aset(rv, addrs, tcp_stats(&args, addrs));
-                return rv;
+                goto out;
         case T_ARRAY: {
                 long i;
                 long len = RARRAY_LEN(addrs);
@@ -643,7 +639,7 @@ static VALUE tcp_listener_stats(int argc, VALUE *argv, VALUE self)
                         VALUE cur = rb_ary_entry(addrs, 0);
 
                         rb_hash_aset(rv, cur, tcp_stats(&args, cur));
-                        return rv;
+                        goto out;
                 }
                 for (i = 0; i < len; i++) {
                         union any_addr check;
@@ -655,23 +651,38 @@ static VALUE tcp_listener_stats(int argc, VALUE *argv, VALUE self)
                 /* fall through */
         }
         case T_NIL:
-                args.table = st_init_strtable();
+                args.a2s = a2s_init();
                 gen_bytecode_all(&args.iov[2]);
                 break;
         default:
+                if (argc < 2) rb_io_close(sock);
                 rb_raise(rb_eArgError,
                          "addr must be an array of strings, a string, or nil");
         }
 
         nl_errcheck(rd_fd_region(diag, &args, args.fd));
 
-        st_foreach(args.table, NIL_P(addrs) ? st_to_hash : st_AND_hash, rv);
-        st_free_table(args.table);
+        /* no kh_foreach* in khashl.h (unlike original khash.h) */
+        for (ki = 0; ki < kh_end(args.a2s); ki++) {
+                if (!kh_exist(args.a2s, ki)) continue;
+                key = kh_key(args.a2s, ki);
+                stats = kh_val(args.a2s, ki);
+                if (stats->listener_p) {
+                        VALUE k = remove_scope_id(key);
+                        if (NIL_P(addrs) || rb_hash_lookup(rv, k) == Qtrue)
+                                rb_hash_aset(rv, k, rb_listen_stats(stats));
+                }
+                xfree(key);
+                xfree(stats);
+        }
+        a2s_destroy(args.a2s);
 
         if (RHASH_SIZE(rv) > 1)
                 rb_hash_foreach(rv, drop_placeholders, Qfalse);
 
+out:
         /* let GC deal with corner cases */
+        rb_str_resize(buf, 0);
         if (argc < 2) rb_io_close(sock);
         return rv;
 }