about summary refs log tree commit homepage
diff options
context:
space:
mode:
authorEric Wong <e@yhbt.net>2010-09-26 07:51:12 +0000
committerEric Wong <e@yhbt.net>2010-09-26 08:00:09 +0000
commitfdfecc6d815bab8dfc1d8ad6758a66d44ab51e31 (patch)
tree7874f5c8a7698d69edbbd8e91bacaedc0ad625cb
parentd8ee79e1e5c6e6908009213324db25cf41c583ce (diff)
downloadkgio-fdfecc6d815bab8dfc1d8ad6758a66d44ab51e31.tar.gz
Avoid altering behavior based on globals that
Kgio.wait_{read,writ}able stored in, since that's too confusing.
The non-try variants are closer to the normal IO read/write
methods, except they can be more easily plugged into alternate
reactors and event frameworks.
-rw-r--r--ext/kgio/kgio_ext.c196
-rw-r--r--test/lib_read_write.rb94
-rw-r--r--test/test_pipe_popen.rb2
3 files changed, 198 insertions, 94 deletions
diff --git a/ext/kgio/kgio_ext.c b/ext/kgio/kgio_ext.c
index f2d30ba..e301971 100644
--- a/ext/kgio/kgio_ext.c
+++ b/ext/kgio/kgio_ext.c
@@ -27,15 +27,15 @@
  */
 #  define USE_MSG_DONTWAIT
 static int accept4_flags = SOCK_CLOEXEC;
-#else
+#else /* ! linux */
 static int accept4_flags = SOCK_CLOEXEC | SOCK_NONBLOCK;
-#endif
+#endif /* ! linux */
 
 static VALUE cSocket;
 static VALUE localhost;
 static VALUE mKgio_WaitReadable, mKgio_WaitWritable;
 static ID io_wait_rd, io_wait_wr;
-static ID iv_kgio_addr, id_ruby;
+static ID iv_kgio_addr;
 
 struct io_args {
         VALUE io;
@@ -45,38 +45,28 @@ struct io_args {
         int fd;
 };
 
-static int maybe_wait_readable(VALUE io)
+static void wait_readable(VALUE io)
 {
         if (io_wait_rd) {
-                if (io_wait_rd == id_ruby) {
-                        if (! rb_io_wait_readable(my_fileno(io)))
-                                rb_sys_fail("wait readable");
-                        errno = 0;
-                } else {
-                        errno = 0;
-                        (void)rb_funcall(io, io_wait_rd, 0, 0);
-                }
-                return 1;
+                (void)rb_funcall(io, io_wait_rd, 0, 0);
+        } else {
+                int fd = my_fileno(io);
+
+                if (!rb_io_wait_readable(fd))
+                        rb_sys_fail("wait readable");
         }
-        errno = 0;
-        return 0;
 }
 
-static int maybe_wait_writable(VALUE io)
+static void wait_writable(VALUE io)
 {
         if (io_wait_wr) {
-                if (io_wait_wr == id_ruby) {
-                        if (! rb_io_wait_writable(my_fileno(io)))
-                                rb_sys_fail("wait writable");
-                        errno = 0;
-                } else {
-                        errno = 0;
-                        (void)rb_funcall(io, io_wait_wr, 0, 0);
-                }
-                return 1;
+                (void)rb_funcall(io, io_wait_wr, 0, 0);
+        } else {
+                int fd = my_fileno(io);
+
+                if (!rb_io_wait_writable(fd))
+                        rb_sys_fail("wait writable");
         }
-        errno = 0;
-        return 0;
 }
 
 static void prepare_read(struct io_args *a, int argc, VALUE *argv, VALUE io)
@@ -96,14 +86,15 @@ static void prepare_read(struct io_args *a, int argc, VALUE *argv, VALUE io)
         a->ptr = RSTRING_PTR(a->buf);
 }
 
-static int read_check(struct io_args *a, long n, const char *msg)
+static int read_check(struct io_args *a, long n, const char *msg, int io_wait)
 {
         if (n == -1) {
                 if (errno == EINTR)
                         return -1;
                 rb_str_set_len(a->buf, 0);
                 if (errno == EAGAIN) {
-                        if (maybe_wait_readable(a->io)) {
+                        if (io_wait) {
+                                wait_readable(a->io);
                                 return -1;
                         } else {
                                 a->buf = mKgio_WaitReadable;
@@ -118,64 +109,75 @@ static int read_check(struct io_args *a, long n, const char *msg)
         return 0;
 }
 
-#ifdef USE_MSG_DONTWAIT
-
 /*
- * Document-method: Kgio::SocketMethods#kgio_read
+ * Document-method: Kgio::PipeMethods#kgio_read
  *
  * call-seq:
  *
- *        socket.kgio_read(maxlen) => buffer or Kgio::WaitReadable
- *        socket.kgio_read(maxlen, buffer) => buffer or Kgio::WaitReadable
+ *        socket.kgio_read(maxlen)  ->  buffer
+ *        socket.kgio_read(maxlen, buffer)  ->  buffer
  *
  * Reads at most maxlen bytes from the stream socket.  Returns with a
  * newly allocated buffer, or may reuse an existing buffer.  This
- * returns Kgio::WaitReadble unless Kgio.wait_readable is set, in
- * which case it will call the method referred to by Kgio.wait_readable.
+ * calls the method identified by Kgio.wait_readable, or uses
+ * the normal, thread-safe Ruby function to wait for readability.
+ * This returns nil on EOF.
+ *
+ * This behaves like read(2) and IO#readpartial, NOT fread(3) or
+ * IO#read which possess read-in-full behavior.
  */
-static VALUE kgio_recv(int argc, VALUE *argv, VALUE io)
+static VALUE my_read(int io_wait, int argc, VALUE *argv, VALUE io)
 {
         struct io_args a;
         long n;
 
         prepare_read(&a, argc, argv, io);
+        set_nonblocking(a.fd);
 retry:
-        n = (long)recv(a.fd, a.ptr, a.len, MSG_DONTWAIT);
-        if (read_check(&a, n, "recv") != 0)
+        n = (long)read(a.fd, a.ptr, a.len);
+        if (read_check(&a, n, "read", io_wait) != 0)
                 goto retry;
         return a.buf;
 }
-#else /* ! USE_MSG_DONTWAIT */
-#  define kgio_recv kgio_read
-#endif /* USE_MSG_DONTWAIT */
 
-/*
- * Document-method: Kgio::PipeMethods#kgio_read
- *
- * call-seq:
- *
- *        socket.kgio_read(maxlen)  ->  buffer or Kgio::WaitReadable
- *        socket.kgio_read(maxlen, buffer)  ->  buffer or Kgio::WaitReadable
- *
- * Reads at most maxlen bytes from the stream socket.  Returns with a
- * newly allocated buffer, or may reuse an existing buffer.  This
- * returns Kgio::WaitReadble unless Kgio.wait_readable is set, in
- * which case it will call the method referred to by Kgio.wait_readable.
- */
 static VALUE kgio_read(int argc, VALUE *argv, VALUE io)
 {
+        return my_read(1, argc, argv, io);
+}
+
+static VALUE kgio_tryread(int argc, VALUE *argv, VALUE io)
+{
+        return my_read(0, argc, argv, io);
+}
+
+#ifdef USE_MSG_DONTWAIT
+static VALUE my_recv(int io_wait, int argc, VALUE *argv, VALUE io)
+{
         struct io_args a;
         long n;
 
         prepare_read(&a, argc, argv, io);
-        set_nonblocking(a.fd);
 retry:
-        n = (long)read(a.fd, a.ptr, a.len);
-        if (read_check(&a, n, "read") != 0)
+        n = (long)recv(a.fd, a.ptr, a.len, MSG_DONTWAIT);
+        if (read_check(&a, n, "recv", io_wait) != 0)
                 goto retry;
         return a.buf;
 }
 
+static VALUE kgio_recv(int argc, VALUE *argv, VALUE io)
+{
+        return my_recv(1, argc, argv, io);
+}
+
+static VALUE kgio_tryrecv(int argc, VALUE *argv, VALUE io)
+{
+        return my_recv(0, argc, argv, io);
+}
+#else /* ! USE_MSG_DONTWAIT */
+#  define kgio_recv kgio_read
+#  define kgio_tryrecv kgio_tryread
+#endif /* USE_MSG_DONTWAIT */
+
 static void prepare_write(struct io_args *a, VALUE io, VALUE str)
 {
         a->buf = (TYPE(str) == T_STRING) ? str : rb_obj_as_string(str);
@@ -185,7 +187,7 @@ static void prepare_write(struct io_args *a, VALUE io, VALUE str)
         a->fd = my_fileno(io);
 }
 
-static int write_check(struct io_args *a, long n, const char *msg)
+static int write_check(struct io_args *a, long n, const char *msg, int io_wait)
 {
         if (a->len == n) {
                 a->buf = Qnil;
@@ -193,29 +195,28 @@ static int write_check(struct io_args *a, long n, const char *msg)
                 if (errno == EINTR)
                         return -1;
                 if (errno == EAGAIN) {
-                        if (maybe_wait_writable(a->io))
+                        if (io_wait) {
+                                wait_writable(a->io);
                                 return -1;
-                        a->buf = mKgio_WaitWritable;
-                        return 0;
+                        } else {
+                                a->buf = mKgio_WaitWritable;
+                                return 0;
+                        }
                 }
                 rb_sys_fail(msg);
         } else {
                 assert(n >= 0 && n < a->len && "write/send syscall broken?");
+                if (io_wait) {
+                        a->ptr += n;
+                        a->len -= n;
+                        return -1;
+                }
                 a->buf = rb_str_new(a->ptr + n, a->len - n);
         }
         return 0;
 }
 
-/*
- * Returns a String containing the unwritten portion if there was a
- * partial write.
- *
- * Returns true if the write was completed.
- *
- * Returns Kgio::WaitWritable if the write would block and
- * Kgio.wait_writable is not set
- */
-static VALUE kgio_write(VALUE io, VALUE str)
+static VALUE my_write(VALUE io, VALUE str, int io_wait)
 {
         struct io_args a;
         long n;
@@ -224,18 +225,40 @@ static VALUE kgio_write(VALUE io, VALUE str)
         set_nonblocking(a.fd);
 retry:
         n = (long)write(a.fd, a.ptr, a.len);
-        if (write_check(&a, n, "write") != 0)
+        if (write_check(&a, n, "write", io_wait) != 0)
                 goto retry;
         return a.buf;
 }
 
+/*
+ * Returns true if the write was completed.
+ *
+ * Calls the method Kgio.wait_writable is not set
+ */
+static VALUE kgio_write(VALUE io, VALUE str)
+{
+        return my_write(io, str, 1);
+}
+
+/*
+ * Returns a String containing the unwritten portion if there was a
+ * partial write.  Will return Kgio::WaitReadable if EAGAIN is
+ * encountered.
+ *
+ * Returns true if the write completed in full.
+ */
+static VALUE kgio_trywrite(VALUE io, VALUE str)
+{
+        return my_write(io, str, 0);
+}
+
 #ifdef USE_MSG_DONTWAIT
 /*
  * This method behaves like Kgio::PipeMethods#kgio_write, except
  * it will use send(2) with the MSG_DONTWAIT flag on sockets to
  * avoid unnecessary calls to fcntl(2).
  */
-static VALUE kgio_send(VALUE io, VALUE str)
+static VALUE my_send(VALUE io, VALUE str, int io_wait)
 {
         struct io_args a;
         long n;
@@ -243,25 +266,36 @@ static VALUE kgio_send(VALUE io, VALUE str)
         prepare_write(&a, io, str);
 retry:
         n = (long)send(a.fd, a.ptr, a.len, MSG_DONTWAIT);
-        if (write_check(&a, n, "send") != 0)
+        if (write_check(&a, n, "send", io_wait) != 0)
                 goto retry;
         return a.buf;
 }
+
+static VALUE kgio_send(VALUE io, VALUE str)
+{
+        return my_send(io, str, 1);
+}
+
+static VALUE kgio_trysend(VALUE io, VALUE str)
+{
+        return my_send(io, str, 0);
+}
 #else /* ! USE_MSG_DONTWAIT */
 #  define kgio_send kgio_write
+#  define kgio_trysend kgio_trywrite
 #endif /* ! USE_MSG_DONTWAIT */
 
 /*
  * call-seq:
  *
- *         Kgio.wait_readable = :method_name
+ *        Kgio.wait_readable = :method_name
  *
  * Sets a method for kgio_read to call when a read would block.
  * This is useful for non-blocking frameworks that use Fibers,
  * as the method referred to this may cause the current Fiber
  * to yield execution.
  *
- * A special value of ":ruby" will cause Ruby to wait using the
+ * A special value of nil will cause Ruby to wait using the
  * rb_io_wait_readable() function, giving kgio_read similar semantics to
  * IO#readpartial.
  */
@@ -439,7 +473,8 @@ my_connect(VALUE klass, int domain, void *addr, socklen_t addrlen)
                 if (errno == EINPROGRESS) {
                         VALUE io = sock_for_fd(klass, fd);
 
-                        (void)maybe_wait_writable(io);
+                        errno = EAGAIN;
+                        wait_writable(io);
                         return io;
                 }
                 rb_sys_fail("connect");
@@ -593,10 +628,14 @@ void Init_kgio_ext(void)
         mPipeMethods = rb_define_module_under(mKgio, "PipeMethods");
         rb_define_method(mPipeMethods, "kgio_read", kgio_read, -1);
         rb_define_method(mPipeMethods, "kgio_write", kgio_write, 1);
+        rb_define_method(mPipeMethods, "kgio_tryread", kgio_tryread, -1);
+        rb_define_method(mPipeMethods, "kgio_trywrite", kgio_trywrite, 1);
 
         mSocketMethods = rb_define_module_under(mKgio, "SocketMethods");
         rb_define_method(mSocketMethods, "kgio_read", kgio_recv, -1);
         rb_define_method(mSocketMethods, "kgio_write", kgio_send, 1);
+        rb_define_method(mSocketMethods, "kgio_tryread", kgio_tryrecv, -1);
+        rb_define_method(mSocketMethods, "kgio_trywrite", kgio_trysend, 1);
 
         rb_define_attr(mSocketMethods, "kgio_addr", 1, 1);
         rb_include_module(cSocket, mSocketMethods);
@@ -621,6 +660,5 @@ void Init_kgio_ext(void)
         rb_define_singleton_method(cUNIXSocket, "new", kgio_unix_connect, 1);
 
         iv_kgio_addr = rb_intern("@kgio_addr");
-        id_ruby = rb_intern("ruby");
         init_sock_for_fd();
 }
diff --git a/test/lib_read_write.rb b/test/lib_read_write.rb
index 18f13f0..b8483a0 100644
--- a/test/lib_read_write.rb
+++ b/test/lib_read_write.rb
@@ -19,6 +19,11 @@ module LibReadWriteTest
     assert_nil @rd.kgio_read(5)
   end
 
+  def test_tryread_eof
+    @wr.close
+    assert_nil @rd.kgio_tryread(5)
+  end
+
   def test_write_closed
     @rd.close
     assert_raises(Errno::EPIPE, Errno::ECONNRESET) {
@@ -26,13 +31,25 @@ module LibReadWriteTest
     }
   end
 
+  def test_trywrite_closed
+    @rd.close
+    assert_raises(Errno::EPIPE, Errno::ECONNRESET) {
+      loop { @wr.kgio_trywrite "HI" }
+    }
+  end
+
   def test_write_conv
     assert_equal nil, @wr.kgio_write(10)
     assert_equal "10", @rd.kgio_read(2)
   end
 
-  def test_read_empty
-    assert_equal Kgio::WaitReadable, @rd.kgio_read(1)
+  def test_trywrite_conv
+    assert_equal nil, @wr.kgio_trywrite(10)
+    assert_equal "10", @rd.kgio_tryread(2)
+  end
+
+  def test_tryread_empty
+    assert_equal Kgio::WaitReadable, @rd.kgio_tryread(1)
   end
 
   def test_read_too_much
@@ -40,12 +57,23 @@ module LibReadWriteTest
     assert_equal "hi", @rd.kgio_read(4)
   end
 
+  def test_tryread_too_much
+    assert_equal nil, @wr.kgio_trywrite("hi")
+    assert_equal "hi", @rd.kgio_tryread(4)
+  end
+
   def test_read_short
     assert_equal nil, @wr.kgio_write("hi")
     assert_equal "h", @rd.kgio_read(1)
     assert_equal "i", @rd.kgio_read(1)
   end
 
+  def test_tryread_short
+    assert_equal nil, @wr.kgio_trywrite("hi")
+    assert_equal "h", @rd.kgio_tryread(1)
+    assert_equal "i", @rd.kgio_tryread(1)
+  end
+
   def test_read_extra_buf
     tmp = ""
     tmp_object_id = tmp.object_id
@@ -56,9 +84,9 @@ module LibReadWriteTest
     assert_equal tmp_object_id, rv.object_id
   end
 
-  def test_write_return_wait_writable
+  def test_trywrite_return_wait_writable
     tmp = []
-    tmp << @wr.kgio_write("HI") until tmp[-1] == Kgio::WaitWritable
+    tmp << @wr.kgio_trywrite("HI") until tmp[-1] == Kgio::WaitWritable
     assert_equal Kgio::WaitWritable, tmp.pop
     assert tmp.size > 0
     penultimate = tmp.pop
@@ -67,19 +95,41 @@ module LibReadWriteTest
     tmp.each { |count| assert_equal nil, count }
   end
 
-  def test_monster_write
+  def test_monster_trywrite
     buf = "." * 1024 * 1024 * 10
-    rv = @wr.kgio_write(buf)
+    rv = @wr.kgio_trywrite(buf)
     assert_kind_of String, rv
     assert rv.size < buf.size
     assert_equal(buf, (rv + @rd.read(buf.size - rv.size)))
   end
 
-  def test_wait_readable_ruby_default
-    def @rd.ruby
-      raise RuntimeError, "Hello"
+  def test_monster_write
+    buf = "." * 1024 * 1024 * 10
+    thr = Thread.new { @wr.kgio_write(buf) }
+    readed = @rd.read(buf.size)
+    thr.join
+    assert_nil thr.value
+    assert_equal buf, readed
+  end
+
+  def test_monster_write_wait_writable
+    @wr.instance_variable_set :@nr, 0
+    def @wr.wait_writable
+      @nr += 1
+      IO.select(nil, [self])
     end
-    assert_nothing_raised { Kgio.wait_readable = :ruby }
+    Kgio.wait_writable = :wait_writable
+    buf = "." * 1024 * 1024 * 10
+    thr = Thread.new { @wr.kgio_write(buf) }
+    readed = @rd.read(buf.size)
+    thr.join
+    assert_nil thr.value
+    assert_equal buf, readed
+    assert @wr.instance_variable_get(:@nr) > 0
+  end
+
+  def test_wait_readable_ruby_default
+    assert_nothing_raised { Kgio.wait_readable = nil }
     elapsed = 0
     foo = nil
     t0 = Time.now
@@ -95,9 +145,6 @@ module LibReadWriteTest
   end
 
   def test_wait_writable_ruby_default
-    def @wr.ruby
-      raise RuntimeError, "Hello"
-    end
     buf = "." * 512
     nr = 0
     begin
@@ -105,7 +152,7 @@ module LibReadWriteTest
     rescue Errno::EAGAIN
       break
     end while true
-    assert_nothing_raised { Kgio.wait_writable = :ruby }
+    assert_nothing_raised { Kgio.wait_writable = nil }
     elapsed = 0
     foo = nil
     t0 = Time.now
@@ -138,6 +185,25 @@ module LibReadWriteTest
     assert_nil foo
   end
 
+  def test_tryread_wait_readable_method
+    def @rd.moo
+      raise "Hello"
+    end
+    assert_nothing_raised { Kgio.wait_readable = :moo }
+    assert_equal Kgio::WaitReadable, @rd.kgio_tryread(5)
+  end
+
+  def test_trywrite_wait_readable_method
+    def @wr.moo
+      raise "Hello"
+    end
+    assert_nothing_raised { Kgio.wait_writable = :moo }
+    tmp = []
+    buf = "." * 1024
+    10000.times { tmp << @wr.kgio_trywrite(buf) }
+    assert_equal Kgio::WaitWritable, tmp.pop
+  end
+
   def test_wait_writable_method
     def @wr.moo
       defined?(@z) ? raise(RuntimeError, "Hello") : @z = "HI"
diff --git a/test/test_pipe_popen.rb b/test/test_pipe_popen.rb
index 1f56979..8d1e414 100644
--- a/test/test_pipe_popen.rb
+++ b/test/test_pipe_popen.rb
@@ -6,7 +6,7 @@ require 'kgio'
 class TestPipePopen < Test::Unit::TestCase
   def test_popen
     io = Kgio::Pipe.popen("sleep 1 && echo HI")
-    assert_equal Kgio::WaitReadable, io.kgio_read(2)
+    assert_equal Kgio::WaitReadable, io.kgio_tryread(2)
     sleep 1.5
     assert_equal "HI\n", io.kgio_read(3)
     assert_nil io.kgio_read(5)