about summary refs log tree commit homepage
path: root/ext/kgio/kgio_ext.c
diff options
context:
space:
mode:
Diffstat (limited to 'ext/kgio/kgio_ext.c')
-rw-r--r--ext/kgio/kgio_ext.c196
1 files changed, 117 insertions, 79 deletions
diff --git a/ext/kgio/kgio_ext.c b/ext/kgio/kgio_ext.c
index f2d30ba..e301971 100644
--- a/ext/kgio/kgio_ext.c
+++ b/ext/kgio/kgio_ext.c
@@ -27,15 +27,15 @@
  */
 #  define USE_MSG_DONTWAIT
 static int accept4_flags = SOCK_CLOEXEC;
-#else
+#else /* ! linux */
 static int accept4_flags = SOCK_CLOEXEC | SOCK_NONBLOCK;
-#endif
+#endif /* ! linux */
 
 static VALUE cSocket;
 static VALUE localhost;
 static VALUE mKgio_WaitReadable, mKgio_WaitWritable;
 static ID io_wait_rd, io_wait_wr;
-static ID iv_kgio_addr, id_ruby;
+static ID iv_kgio_addr;
 
 struct io_args {
         VALUE io;
@@ -45,38 +45,28 @@ struct io_args {
         int fd;
 };
 
-static int maybe_wait_readable(VALUE io)
+static void wait_readable(VALUE io)
 {
         if (io_wait_rd) {
-                if (io_wait_rd == id_ruby) {
-                        if (! rb_io_wait_readable(my_fileno(io)))
-                                rb_sys_fail("wait readable");
-                        errno = 0;
-                } else {
-                        errno = 0;
-                        (void)rb_funcall(io, io_wait_rd, 0, 0);
-                }
-                return 1;
+                (void)rb_funcall(io, io_wait_rd, 0, 0);
+        } else {
+                int fd = my_fileno(io);
+
+                if (!rb_io_wait_readable(fd))
+                        rb_sys_fail("wait readable");
         }
-        errno = 0;
-        return 0;
 }
 
-static int maybe_wait_writable(VALUE io)
+static void wait_writable(VALUE io)
 {
         if (io_wait_wr) {
-                if (io_wait_wr == id_ruby) {
-                        if (! rb_io_wait_writable(my_fileno(io)))
-                                rb_sys_fail("wait writable");
-                        errno = 0;
-                } else {
-                        errno = 0;
-                        (void)rb_funcall(io, io_wait_wr, 0, 0);
-                }
-                return 1;
+                (void)rb_funcall(io, io_wait_wr, 0, 0);
+        } else {
+                int fd = my_fileno(io);
+
+                if (!rb_io_wait_writable(fd))
+                        rb_sys_fail("wait writable");
         }
-        errno = 0;
-        return 0;
 }
 
 static void prepare_read(struct io_args *a, int argc, VALUE *argv, VALUE io)
@@ -96,14 +86,15 @@ static void prepare_read(struct io_args *a, int argc, VALUE *argv, VALUE io)
         a->ptr = RSTRING_PTR(a->buf);
 }
 
-static int read_check(struct io_args *a, long n, const char *msg)
+static int read_check(struct io_args *a, long n, const char *msg, int io_wait)
 {
         if (n == -1) {
                 if (errno == EINTR)
                         return -1;
                 rb_str_set_len(a->buf, 0);
                 if (errno == EAGAIN) {
-                        if (maybe_wait_readable(a->io)) {
+                        if (io_wait) {
+                                wait_readable(a->io);
                                 return -1;
                         } else {
                                 a->buf = mKgio_WaitReadable;
@@ -118,64 +109,75 @@ static int read_check(struct io_args *a, long n, const char *msg)
         return 0;
 }
 
-#ifdef USE_MSG_DONTWAIT
-
 /*
- * Document-method: Kgio::SocketMethods#kgio_read
+ * Document-method: Kgio::PipeMethods#kgio_read
  *
  * call-seq:
  *
- *        socket.kgio_read(maxlen) => buffer or Kgio::WaitReadable
- *        socket.kgio_read(maxlen, buffer) => buffer or Kgio::WaitReadable
+ *        socket.kgio_read(maxlen)  ->  buffer
+ *        socket.kgio_read(maxlen, buffer)  ->  buffer
  *
  * Reads at most maxlen bytes from the stream socket.  Returns with a
  * newly allocated buffer, or may reuse an existing buffer.  This
- * returns Kgio::WaitReadble unless Kgio.wait_readable is set, in
- * which case it will call the method referred to by Kgio.wait_readable.
+ * calls the method identified by Kgio.wait_readable, or uses
+ * the normal, thread-safe Ruby function to wait for readability.
+ * This returns nil on EOF.
+ *
+ * This behaves like read(2) and IO#readpartial, NOT fread(3) or
+ * IO#read which possess read-in-full behavior.
  */
-static VALUE kgio_recv(int argc, VALUE *argv, VALUE io)
+static VALUE my_read(int io_wait, int argc, VALUE *argv, VALUE io)
 {
         struct io_args a;
         long n;
 
         prepare_read(&a, argc, argv, io);
+        set_nonblocking(a.fd);
 retry:
-        n = (long)recv(a.fd, a.ptr, a.len, MSG_DONTWAIT);
-        if (read_check(&a, n, "recv") != 0)
+        n = (long)read(a.fd, a.ptr, a.len);
+        if (read_check(&a, n, "read", io_wait) != 0)
                 goto retry;
         return a.buf;
 }
-#else /* ! USE_MSG_DONTWAIT */
-#  define kgio_recv kgio_read
-#endif /* USE_MSG_DONTWAIT */
 
-/*
- * Document-method: Kgio::PipeMethods#kgio_read
- *
- * call-seq:
- *
- *        socket.kgio_read(maxlen)  ->  buffer or Kgio::WaitReadable
- *        socket.kgio_read(maxlen, buffer)  ->  buffer or Kgio::WaitReadable
- *
- * Reads at most maxlen bytes from the stream socket.  Returns with a
- * newly allocated buffer, or may reuse an existing buffer.  This
- * returns Kgio::WaitReadble unless Kgio.wait_readable is set, in
- * which case it will call the method referred to by Kgio.wait_readable.
- */
 static VALUE kgio_read(int argc, VALUE *argv, VALUE io)
 {
+        return my_read(1, argc, argv, io);
+}
+
+static VALUE kgio_tryread(int argc, VALUE *argv, VALUE io)
+{
+        return my_read(0, argc, argv, io);
+}
+
+#ifdef USE_MSG_DONTWAIT
+static VALUE my_recv(int io_wait, int argc, VALUE *argv, VALUE io)
+{
         struct io_args a;
         long n;
 
         prepare_read(&a, argc, argv, io);
-        set_nonblocking(a.fd);
 retry:
-        n = (long)read(a.fd, a.ptr, a.len);
-        if (read_check(&a, n, "read") != 0)
+        n = (long)recv(a.fd, a.ptr, a.len, MSG_DONTWAIT);
+        if (read_check(&a, n, "recv", io_wait) != 0)
                 goto retry;
         return a.buf;
 }
 
+static VALUE kgio_recv(int argc, VALUE *argv, VALUE io)
+{
+        return my_recv(1, argc, argv, io);
+}
+
+static VALUE kgio_tryrecv(int argc, VALUE *argv, VALUE io)
+{
+        return my_recv(0, argc, argv, io);
+}
+#else /* ! USE_MSG_DONTWAIT */
+#  define kgio_recv kgio_read
+#  define kgio_tryrecv kgio_tryread
+#endif /* USE_MSG_DONTWAIT */
+
 static void prepare_write(struct io_args *a, VALUE io, VALUE str)
 {
         a->buf = (TYPE(str) == T_STRING) ? str : rb_obj_as_string(str);
@@ -185,7 +187,7 @@ static void prepare_write(struct io_args *a, VALUE io, VALUE str)
         a->fd = my_fileno(io);
 }
 
-static int write_check(struct io_args *a, long n, const char *msg)
+static int write_check(struct io_args *a, long n, const char *msg, int io_wait)
 {
         if (a->len == n) {
                 a->buf = Qnil;
@@ -193,29 +195,28 @@ static int write_check(struct io_args *a, long n, const char *msg)
                 if (errno == EINTR)
                         return -1;
                 if (errno == EAGAIN) {
-                        if (maybe_wait_writable(a->io))
+                        if (io_wait) {
+                                wait_writable(a->io);
                                 return -1;
-                        a->buf = mKgio_WaitWritable;
-                        return 0;
+                        } else {
+                                a->buf = mKgio_WaitWritable;
+                                return 0;
+                        }
                 }
                 rb_sys_fail(msg);
         } else {
                 assert(n >= 0 && n < a->len && "write/send syscall broken?");
+                if (io_wait) {
+                        a->ptr += n;
+                        a->len -= n;
+                        return -1;
+                }
                 a->buf = rb_str_new(a->ptr + n, a->len - n);
         }
         return 0;
 }
 
-/*
- * Returns a String containing the unwritten portion if there was a
- * partial write.
- *
- * Returns true if the write was completed.
- *
- * Returns Kgio::WaitWritable if the write would block and
- * Kgio.wait_writable is not set
- */
-static VALUE kgio_write(VALUE io, VALUE str)
+static VALUE my_write(VALUE io, VALUE str, int io_wait)
 {
         struct io_args a;
         long n;
@@ -224,18 +225,40 @@ static VALUE kgio_write(VALUE io, VALUE str)
         set_nonblocking(a.fd);
 retry:
         n = (long)write(a.fd, a.ptr, a.len);
-        if (write_check(&a, n, "write") != 0)
+        if (write_check(&a, n, "write", io_wait) != 0)
                 goto retry;
         return a.buf;
 }
 
+/*
+ * Returns true if the write was completed.
+ *
+ * Calls the method Kgio.wait_writable is not set
+ */
+static VALUE kgio_write(VALUE io, VALUE str)
+{
+        return my_write(io, str, 1);
+}
+
+/*
+ * Returns a String containing the unwritten portion if there was a
+ * partial write.  Will return Kgio::WaitReadable if EAGAIN is
+ * encountered.
+ *
+ * Returns true if the write completed in full.
+ */
+static VALUE kgio_trywrite(VALUE io, VALUE str)
+{
+        return my_write(io, str, 0);
+}
+
 #ifdef USE_MSG_DONTWAIT
 /*
  * This method behaves like Kgio::PipeMethods#kgio_write, except
  * it will use send(2) with the MSG_DONTWAIT flag on sockets to
  * avoid unnecessary calls to fcntl(2).
  */
-static VALUE kgio_send(VALUE io, VALUE str)
+static VALUE my_send(VALUE io, VALUE str, int io_wait)
 {
         struct io_args a;
         long n;
@@ -243,25 +266,36 @@ static VALUE kgio_send(VALUE io, VALUE str)
         prepare_write(&a, io, str);
 retry:
         n = (long)send(a.fd, a.ptr, a.len, MSG_DONTWAIT);
-        if (write_check(&a, n, "send") != 0)
+        if (write_check(&a, n, "send", io_wait) != 0)
                 goto retry;
         return a.buf;
 }
+
+static VALUE kgio_send(VALUE io, VALUE str)
+{
+        return my_send(io, str, 1);
+}
+
+static VALUE kgio_trysend(VALUE io, VALUE str)
+{
+        return my_send(io, str, 0);
+}
 #else /* ! USE_MSG_DONTWAIT */
 #  define kgio_send kgio_write
+#  define kgio_trysend kgio_trywrite
 #endif /* ! USE_MSG_DONTWAIT */
 
 /*
  * call-seq:
  *
- *         Kgio.wait_readable = :method_name
+ *        Kgio.wait_readable = :method_name
  *
  * Sets a method for kgio_read to call when a read would block.
  * This is useful for non-blocking frameworks that use Fibers,
  * as the method referred to this may cause the current Fiber
  * to yield execution.
  *
- * A special value of ":ruby" will cause Ruby to wait using the
+ * A special value of nil will cause Ruby to wait using the
  * rb_io_wait_readable() function, giving kgio_read similar semantics to
  * IO#readpartial.
  */
@@ -439,7 +473,8 @@ my_connect(VALUE klass, int domain, void *addr, socklen_t addrlen)
                 if (errno == EINPROGRESS) {
                         VALUE io = sock_for_fd(klass, fd);
 
-                        (void)maybe_wait_writable(io);
+                        errno = EAGAIN;
+                        wait_writable(io);
                         return io;
                 }
                 rb_sys_fail("connect");
@@ -593,10 +628,14 @@ void Init_kgio_ext(void)
         mPipeMethods = rb_define_module_under(mKgio, "PipeMethods");
         rb_define_method(mPipeMethods, "kgio_read", kgio_read, -1);
         rb_define_method(mPipeMethods, "kgio_write", kgio_write, 1);
+        rb_define_method(mPipeMethods, "kgio_tryread", kgio_tryread, -1);
+        rb_define_method(mPipeMethods, "kgio_trywrite", kgio_trywrite, 1);
 
         mSocketMethods = rb_define_module_under(mKgio, "SocketMethods");
         rb_define_method(mSocketMethods, "kgio_read", kgio_recv, -1);
         rb_define_method(mSocketMethods, "kgio_write", kgio_send, 1);
+        rb_define_method(mSocketMethods, "kgio_tryread", kgio_tryrecv, -1);
+        rb_define_method(mSocketMethods, "kgio_trywrite", kgio_trysend, 1);
 
         rb_define_attr(mSocketMethods, "kgio_addr", 1, 1);
         rb_include_module(cSocket, mSocketMethods);
@@ -621,6 +660,5 @@ void Init_kgio_ext(void)
         rb_define_singleton_method(cUNIXSocket, "new", kgio_unix_connect, 1);
 
         iv_kgio_addr = rb_intern("@kgio_addr");
-        id_ruby = rb_intern("ruby");
         init_sock_for_fd();
 }