[Tarantool-patches] [PATCH] Align the Lua socket API with the documentation

Leonid lvasiliev at tarantool.org
Thu Nov 21 10:08:42 MSK 2019


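Methods called on a closed socket used to raise "attempt to use closed
socket". Now they set the socket errno to EBADF and return a failure value
(nil, false or 0, depending on the method) instead.

https://github.com/tarantool/tarantool/issues/4087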
https://github.com/tarantool/tarantool/tree/lvasiliev/gh-4087-fix-socket-stuff

---
 src/lua/socket.lua       | 182 +++++++++++++++++++++++++++++++++++----
 test/app/socket.result   |   5 +-
 test/app/socket.test.lua |   2 +-
 3 files changed, 170 insertions(+), 19 deletions(-)

diff --git a/src/lua/socket.lua b/src/lua/socket.lua
index a334ad45b..70a0f50f4 100644
--- a/src/lua/socket.lua
+++ b/src/lua/socket.lua
@@ -80,11 +80,13 @@ local function check_socket(socket)
         if fd >= 0 then
             return fd
         else
-            error("attempt to use closed socket")
+            return nil
         end
     else
         local msg = "Usage: socket:method()"
-        if socket ~= nil then msg = msg .. ", called with non-socket" end
+        if socket ~= nil then
+            msg = msg .. ", called with non-socket"
+        end
         error(msg)
     end
 end
@@ -102,6 +104,12 @@ local gc_socket_sentinel = ffi.new(gc_socket_t, { fd = -1 })
 
 local function socket_close(socket)
     local fd = check_socket(socket)
+
+    if fd == nil then
+        socket._errno = boxerrno.EBADF
+        return false
+    end
+
     socket._errno = nil
     local r = ffi.C.coio_close(fd)
     -- .fd is const to prevent tampering
@@ -110,6 +118,7 @@ local function socket_close(socket)
         socket._errno = boxerrno()
         return false
     end
+
     return true
 end
 
@@ -131,6 +140,12 @@ local soname_mt = {
 
 local function socket_name(self)
     local fd = check_socket(self)
+
+    if fd == nil then
+        self._errno = boxerrno.EBADF
+        return nil
+    end
+
     local aka = internal.name(fd)
     if aka == nil then
         self._errno = boxerrno()
@@ -143,6 +158,12 @@ end
 
 local function socket_peer(self)
     local fd = check_socket(self)
+
+    if fd == nil then
+        self._errno = boxerrno.EBADF
+        return nil -- socket_name() also returns nil for a closed socket
+    end
+
     local peer = internal.peer(fd)
     if peer == nil then
         self._errno = boxerrno()
@@ -204,7 +225,7 @@ local function getprotobyname(name)
 end
 
 local function socket_errno(self)
-    check_socket(self)
+    check_socket(self) -- validate type
     if self['_errno'] == nil then
         return 0
     else
@@ -213,7 +234,7 @@ local function socket_errno(self)
 end
 
 local function socket_error(self)
-    check_socket(self)
+    check_socket(self) -- validate type
     if self['_errno'] == nil then
         return nil
     else
@@ -227,6 +248,12 @@ local addr = ffi.cast('struct sockaddr *', addrbuf)
 local addr_len = ffi.new('socklen_t[1]')
 local function socket_sysconnect(self, host, port)
     local fd = check_socket(self)
+
+    if fd == nil then
+        self._errno = boxerrno.EBADF
+        return false
+    end
+
     self._errno = nil
 
     host = tostring(host)
@@ -246,6 +273,12 @@ end
 
 local function syswrite(self, charptr, size)
     local fd = check_socket(self)
+
+    if fd == nil then
+        self._errno = boxerrno.EBADF
+        return nil
+    end
+
     self._errno = nil
     local done = ffi.C.write(fd, charptr, size)
     if done < 0 then
@@ -270,6 +303,11 @@ end
 local function sysread(self, charptr, size)
     local fd = check_socket(self)
 
+    if fd == nil then
+        self._errno = boxerrno.EBADF
+        return nil
+    end
+
     self._errno = nil
     local res = ffi.C.read(fd, charptr, size)
     if res < 0 then
@@ -308,6 +346,12 @@ end
 
 local function socket_nonblock(self, nb)
     local fd = check_socket(self)
+
+    if fd == nil then
+        self._errno = boxerrno.EBADF
+        return nil
+    end
+
     self._errno = nil
 
     local res
@@ -335,6 +379,11 @@ end
 local function do_wait(self, what, timeout)
     local fd = check_socket(self)
 
+    if fd == nil then
+        self._errno = boxerrno.EBADF
+        return 0
+    end
+
     self._errno = nil
     timeout = timeout or TIMEOUT_INFINITY
 
@@ -352,19 +401,28 @@ local function do_wait(self, what, timeout)
 end
 
 local function socket_readable(self, timeout)
-    return do_wait(self, 1, timeout) ~= 0
+    local wait_result = do_wait(self, 'R', timeout)
+    return check_socket(self) ~= nil and wait_result ~= 0 -- closed => false
 end
 
 local function socket_writable(self, timeout)
-    return do_wait(self, 2, timeout) ~= 0
+    local wait_result = do_wait(self, 'W', timeout)
+    return check_socket(self) ~= nil and wait_result ~= 0 -- closed => false
 end
 
 local function socket_wait(self, timeout)
-    return do_wait(self, 'RW', timeout)
+    local wait_result = do_wait(self, 'RW', timeout)
+    return check_socket(self) and wait_result -- nil once the socket is closed
 end
 
 local function socket_listen(self, backlog)
     local fd = check_socket(self)
+
+    if fd == nil then
+        self._errno = boxerrno.EBADF
+        return false
+    end
+
     self._errno = nil
     if backlog == nil then
         backlog = 256
@@ -379,6 +437,12 @@ end
 
 local function socket_bind(self, host, port)
     local fd = check_socket(self)
+
+    if fd == nil then
+        self._errno = boxerrno.EBADF
+        return false
+    end
+
     self._errno = nil
 
     host = tostring(host)
@@ -449,6 +513,11 @@ end
 local function socket_setsockopt(self, level, name, value)
     local fd = check_socket(self)
 
+    if fd == nil then
+        self._errno = boxerrno.EBADF
+        return false
+    end
+
     level = getsol(level)
     if level == nil then
         self._errno = boxerrno()
@@ -505,6 +574,11 @@ end
 local function socket_getsockopt(self, level, name)
     local fd = check_socket(self)
 
+    if fd == nil then
+        self._errno = boxerrno.EBADF
+        return nil
+    end
+
     level = getsol(level)
     if level == nil then
         self._errno = boxerrno()
@@ -557,6 +631,11 @@ end
 local function socket_linger(self, active, timeout)
     local fd = check_socket(self)
 
+    if fd == nil then
+        self._errno = boxerrno.EBADF
+        return nil
+    end
+
     local level = internal.SOL_SOCKET
     local info = internal.SO_OPT[level].SO_LINGER
     self._errno = nil
@@ -602,6 +681,12 @@ end
 
 local function socket_accept(self)
     local server_fd = check_socket(self)
+
+    if server_fd == nil then
+        self._errno = boxerrno.EBADF
+        return nil
+    end
+
     self._errno = nil
 
     local client_fd, from = internal.accept(server_fd)
@@ -727,7 +812,13 @@ local function read(self, limit, timeout, check, ...)
 end
 
 local function socket_read(self, opts, timeout)
-    check_socket(self)
+    local fd = check_socket(self)
+
+    if fd == nil then
+        self._errno = boxerrno.EBADF
+        return nil
+    end
+
     timeout = timeout or TIMEOUT_INFINITY
     if type(opts) == 'number' then
         return read(self, opts, timeout, check_limit)
@@ -748,7 +839,13 @@ local function socket_read(self, opts, timeout)
 end
 
 local function socket_write(self, octets, timeout)
-    check_socket(self)
+    local fd = check_socket(self)
+
+    if fd == nil then
+        self._errno = boxerrno.EBADF
+        return nil
+    end
+
     if timeout == nil then
         timeout = TIMEOUT_INFINITY
     end
@@ -781,6 +878,12 @@ end
 
 local function socket_send(self, octets, flags)
     local fd = check_socket(self)
+
+    if fd == nil then
+        self._errno = boxerrno.EBADF
+        return nil
+    end
+
     local iflags = get_iflags(internal.SEND_FLAGS, flags)
 
     self._errno = nil
@@ -843,6 +946,12 @@ end
 
 local function socket_recv(self, size, flags)
     local fd = check_socket(self)
+
+    if fd == nil then
+        self._errno = boxerrno.EBADF
+        return nil
+    end
+
     local iflags = get_iflags(internal.SEND_FLAGS, flags)
     if iflags == nil then
         self._errno = boxerrno.EINVAL
@@ -867,6 +976,12 @@ end
 
 local function socket_recvfrom(self, size, flags)
     local fd = check_socket(self)
+
+    if fd == nil then
+        self._errno = boxerrno.EBADF
+        return nil
+    end
+
     local iflags = get_iflags(internal.SEND_FLAGS, flags)
     if iflags == nil then
         self._errno = boxerrno.EINVAL
@@ -889,6 +1004,12 @@ end
 
 local function socket_sendto(self, host, port, octets, flags)
     local fd = check_socket(self)
+
+    if fd == nil then
+        self._errno = boxerrno.EBADF
+        return nil
+    end
+
     local iflags = get_iflags(internal.SEND_FLAGS, flags)
 
     if iflags == nil then
@@ -1285,7 +1406,11 @@ local lsocket_tcp_client_mt
 
 local function lsocket_tcp_tostring(self)
     local fd = check_socket(self)
-    return string.format("tcp{master}: fd=%d", fd)
+    if fd ~= nil then
+        return string.format("tcp{master}: fd=%d", fd)
+    end
+    -- check_socket() returns nil for a closed socket (fd < 0)
+    return "tcp{master}: invalid socket"
 end
 
 local function lsocket_tcp_close(self)
@@ -1312,7 +1437,6 @@ local function lsocket_tcp_getpeername(self)
 end
 
 local function lsocket_tcp_settimeout(self, value, mode)
-    check_socket(self)
     self.timeout = value
     -- mode is effectively ignored
     return 1
@@ -1355,7 +1479,13 @@ local function lsocket_tcp_listen(self, backlog)
 end
 
 local function lsocket_tcp_connect(self, host, port)
-    check_socket(self)
+    local fd = check_socket(self)
+
+    if fd == nil then
+        self._errno = boxerrno.EBADF
+        return nil, socket_error(self)
+    end
+
     local deadline = fiber.clock() + (self.timeout or TIMEOUT_INFINITY)
     -- This function is broken by design
     local ga_opts = { family = 'AF_INET', type = 'SOCK_STREAM' }
@@ -1395,11 +1525,21 @@ lsocket_tcp_mt = {
 
 local function lsocket_tcp_server_tostring(self)
     local fd = check_socket(self)
-    return string.format("tcp{server}: fd=%d", fd)
+    if fd ~= nil then
+        return string.format("tcp{server}: fd=%d", fd)
+    end
+    -- check_socket() returns nil for a closed socket (fd < 0)
+    return "tcp{server}: invalid socket"
 end
 
 local function lsocket_tcp_accept(self)
-    check_socket(self)
+    local fd = check_socket(self)
+
+    if fd == nil then
+        self._errno = boxerrno.EBADF
+        return nil, socket_error(self)
+    end
+
     local deadline = fiber.clock() + (self.timeout or TIMEOUT_INFINITY)
     repeat
         local client = socket_accept(self)
@@ -1434,11 +1574,21 @@ lsocket_tcp_server_mt = {
 
 local function lsocket_tcp_client_tostring(self)
     local fd = check_socket(self)
-    return string.format("tcp{client}: fd=%d", fd)
+    if fd ~= nil then
+        return string.format("tcp{client}: fd=%d", fd)
+    end
+    -- check_socket() returns nil for a closed socket (fd < 0)
+    return "tcp{client}: invalid socket"
 end
 
 local function lsocket_tcp_receive(self, pattern, prefix)
-    check_socket(self)
+    local fd = check_socket(self) -- nil for a closed socket
+
+    if fd == nil then
+        self._errno = boxerrno.EBADF -- socket_error() reads self._errno
+        return nil, socket_error(self)
+    end
+
     prefix = prefix or ''
     local timeout = self.timeout or TIMEOUT_INFINITY
     local data
diff --git a/test/app/socket.result b/test/app/socket.result
index fd299424c..c61bf9cfa 100644
--- a/test/app/socket.result
+++ b/test/app/socket.result
@@ -1697,10 +1697,11 @@ s:close()
 ---
 - 1
 ...
--- Sic: incompatible with Lua Socket
+-- Second close returns nil and an EBADF error message
 s:close()
 ---
-- error: 'builtin/socket.lua: attempt to use closed socket'
+- null
+- Bad file descriptor
 ...
 s = socket.tcp()
 ---
diff --git a/test/app/socket.test.lua b/test/app/socket.test.lua
index c72d41763..6e2ca3ef1 100644
--- a/test/app/socket.test.lua
+++ b/test/app/socket.test.lua
@@ -585,7 +585,7 @@ test_run:cmd("push filter 'fd=([0-9]+)' to 'fd=<FD>'")
 s = socket.tcp()
 s
 s:close()
--- Sic: incompatible with Lua Socket
+-- Second close returns nil and an EBADF error message
 s:close()
 
 s = socket.tcp()
-- 
2.17.1



More information about the Tarantool-patches mailing list