From mboxrd@z Thu Jan 1 00:00:00 1970 Return-Path: Received: from smtpng2.m.smailru.net (smtpng2.m.smailru.net [94.100.179.3]) (using TLSv1.2 with cipher ECDHE-RSA-AES256-GCM-SHA384 (256/256 bits)) (No client certificate requested) by dev.tarantool.org (Postfix) with ESMTPS id C1A8046970F for ; Thu, 21 Nov 2019 10:08:45 +0300 (MSK) From: Leonid Date: Thu, 21 Nov 2019 10:08:42 +0300 Message-Id: Subject: [Tarantool-patches] [PATCH] Align the lua sockets API to documentation List-Id: Tarantool development patches List-Unsubscribe: , List-Archive: List-Post: List-Help: List-Subscribe: , To: alexander.turenko@tarantool.org Cc: tarantool-patches@dev.tarantool.org https://github.com/tarantool/tarantool/issues/4087 https://github.com/tarantool/tarantool/tree/lvasiliev/gh-4087-fix-socket-stuff --- src/lua/socket.lua | 182 +++++++++++++++++++++++++++++++++++---- test/app/socket.result | 5 +- test/app/socket.test.lua | 2 +- 3 files changed, 170 insertions(+), 19 deletions(-) diff --git a/src/lua/socket.lua b/src/lua/socket.lua index a334ad45b..70a0f50f4 100644 --- a/src/lua/socket.lua +++ b/src/lua/socket.lua @@ -80,11 +80,13 @@ local function check_socket(socket) if fd >= 0 then return fd else - error("attempt to use closed socket") + return nil end else local msg = "Usage: socket:method()" - if socket ~= nil then msg = msg .. ", called with non-socket" end + if socket ~= nil then + msg = msg .. 
", called with non-socket" + end error(msg) end end @@ -102,6 +104,12 @@ local gc_socket_sentinel = ffi.new(gc_socket_t, { fd = -1 }) local function socket_close(socket) local fd = check_socket(socket) + + if fd == nil then + socket._errno = boxerrno.EBADF + return false + end + socket._errno = nil local r = ffi.C.coio_close(fd) -- .fd is const to prevent tampering @@ -110,6 +118,7 @@ local function socket_close(socket) socket._errno = boxerrno() return false end + return true end @@ -131,6 +140,12 @@ local soname_mt = { local function socket_name(self) local fd = check_socket(self) + + if fd == nil then + self._errno = boxerrno.EBADF + return nil + end + local aka = internal.name(fd) if aka == nil then self._errno = boxerrno() @@ -143,6 +158,12 @@ end local function socket_peer(self) local fd = check_socket(self) + + if fd == nil then + self._errno = boxerrno.EBADF + return false + end + local peer = internal.peer(fd) if peer == nil then self._errno = boxerrno() @@ -204,7 +225,7 @@ local function getprotobyname(name) end local function socket_errno(self) - check_socket(self) + check_socket(self) -- validate type if self['_errno'] == nil then return 0 else @@ -213,7 +234,7 @@ local function socket_errno(self) end local function socket_error(self) - check_socket(self) + check_socket(self) -- validate type if self['_errno'] == nil then return nil else @@ -227,6 +248,12 @@ local addr = ffi.cast('struct sockaddr *', addrbuf) local addr_len = ffi.new('socklen_t[1]') local function socket_sysconnect(self, host, port) local fd = check_socket(self) + + if fd == nil then + self._errno = boxerrno.EBADF + return false + end + self._errno = nil host = tostring(host) @@ -246,6 +273,12 @@ end local function syswrite(self, charptr, size) local fd = check_socket(self) + + if fd == nil then + self._errno = boxerrno.EBADF + return nil + end + self._errno = nil local done = ffi.C.write(fd, charptr, size) if done < 0 then @@ -270,6 +303,11 @@ end local function sysread(self, charptr, 
size) local fd = check_socket(self) + if fd == nil then + self._errno = boxerrno.EBADF + return nil + end + self._errno = nil local res = ffi.C.read(fd, charptr, size) if res < 0 then @@ -308,6 +346,12 @@ end local function socket_nonblock(self, nb) local fd = check_socket(self) + + if fd == nil then + self._errno = boxerrno.EBADF + return nil + end + self._errno = nil local res @@ -335,6 +379,11 @@ end local function do_wait(self, what, timeout) local fd = check_socket(self) + if fd == nil then + self._errno = boxerrno.EBADF + return 0 + end + self._errno = nil timeout = timeout or TIMEOUT_INFINITY @@ -352,19 +401,28 @@ local function do_wait(self, what, timeout) end local function socket_readable(self, timeout) - return do_wait(self, 1, timeout) ~= 0 + local wait_result = do_wait(self, 'R', timeout) + return check_socket(self) ~= nil and wait_result ~= 0 end local function socket_writable(self, timeout) - return do_wait(self, 2, timeout) ~= 0 + local wait_result = do_wait(self, 'W', timeout) + return check_socket(self) ~= nil and wait_result ~= 0 end local function socket_wait(self, timeout) - return do_wait(self, 'RW', timeout) + local wait_result = do_wait(self, 'RW', timeout) + return check_socket(self) and wait_result end local function socket_listen(self, backlog) local fd = check_socket(self) + + if fd == nil then + self._errno = boxerrno.EBADF + return false + end + self._errno = nil if backlog == nil then backlog = 256 @@ -379,6 +437,12 @@ end local function socket_bind(self, host, port) local fd = check_socket(self) + + if fd == nil then + self._errno = boxerrno.EBADF + return false + end + self._errno = nil host = tostring(host) @@ -449,6 +513,11 @@ end local function socket_setsockopt(self, level, name, value) local fd = check_socket(self) + if fd == nil then + self._errno = boxerrno.EBADF + return false + end + level = getsol(level) if level == nil then self._errno = boxerrno() @@ -505,6 +574,11 @@ end local function socket_getsockopt(self, level, name) local fd = 
check_socket(self) + if fd == nil then + self._errno = boxerrno.EBADF + return nil + end + level = getsol(level) if level == nil then self._errno = boxerrno() @@ -557,6 +631,11 @@ end local function socket_linger(self, active, timeout) local fd = check_socket(self) + if fd == nil then + self._errno = boxerrno.EBADF + return nil + end + local level = internal.SOL_SOCKET local info = internal.SO_OPT[level].SO_LINGER self._errno = nil @@ -602,6 +681,12 @@ end local function socket_accept(self) local server_fd = check_socket(self) + + if server_fd == nil then + self._errno = boxerrno.EBADF + return nil + end + self._errno = nil local client_fd, from = internal.accept(server_fd) @@ -727,7 +812,13 @@ local function read(self, limit, timeout, check, ...) end local function socket_read(self, opts, timeout) - check_socket(self) + local fd = check_socket(self) + + if fd == nil then + self._errno = boxerrno.EBADF + return nil + end + timeout = timeout or TIMEOUT_INFINITY if type(opts) == 'number' then return read(self, opts, timeout, check_limit) @@ -748,7 +839,13 @@ local function socket_read(self, opts, timeout) end local function socket_write(self, octets, timeout) - check_socket(self) + local fd = check_socket(self) + + if fd == nil then + self._errno = boxerrno.EBADF + return nil + end + if timeout == nil then timeout = TIMEOUT_INFINITY end @@ -781,6 +878,12 @@ end local function socket_send(self, octets, flags) local fd = check_socket(self) + + if fd == nil then + self._errno = boxerrno.EBADF + return nil + end + local iflags = get_iflags(internal.SEND_FLAGS, flags) self._errno = nil @@ -843,6 +946,12 @@ end local function socket_recv(self, size, flags) local fd = check_socket(self) + + if fd == nil then + self._errno = boxerrno.EBADF + return nil + end + local iflags = get_iflags(internal.SEND_FLAGS, flags) if iflags == nil then self._errno = boxerrno.EINVAL @@ -867,6 +976,12 @@ end local function socket_recvfrom(self, size, flags) local fd = check_socket(self) + + if 
fd == nil then + self._errno = boxerrno.EBADF + return nil + end + local iflags = get_iflags(internal.SEND_FLAGS, flags) if iflags == nil then self._errno = boxerrno.EINVAL @@ -889,6 +1004,12 @@ end local function socket_sendto(self, host, port, octets, flags) local fd = check_socket(self) + + if fd == nil then + self._errno = boxerrno.EBADF + return nil + end + local iflags = get_iflags(internal.SEND_FLAGS, flags) if iflags == nil then @@ -1285,7 +1406,11 @@ local lsocket_tcp_client_mt local function lsocket_tcp_tostring(self) local fd = check_socket(self) - return string.format("tcp{master}: fd=%d", fd) + if fd == nil then + return string.format("tcp{master}: invalid socket") + else + return string.format("tcp{master}: fd=%d", fd) + end end local function lsocket_tcp_close(self) @@ -1312,7 +1437,6 @@ local function lsocket_tcp_getpeername(self) end local function lsocket_tcp_settimeout(self, value, mode) - check_socket(self) self.timeout = value -- mode is effectively ignored return 1 @@ -1355,7 +1479,13 @@ local function lsocket_tcp_listen(self, backlog) end local function lsocket_tcp_connect(self, host, port) - check_socket(self) + local fd = check_socket(self) + + if fd == nil then + self._errno = boxerrno.EBADF + return nil, socket_error(self) + end + local deadline = fiber.clock() + (self.timeout or TIMEOUT_INFINITY) -- This function is broken by design local ga_opts = { family = 'AF_INET', type = 'SOCK_STREAM' } @@ -1395,11 +1525,21 @@ lsocket_tcp_mt = { local function lsocket_tcp_server_tostring(self) local fd = check_socket(self) - return string.format("tcp{server}: fd=%d", fd) + if fd == nil then + return string.format("tcp{server}: invalid socket") + else + return string.format("tcp{server}: fd=%d", fd) + end end local function lsocket_tcp_accept(self) - check_socket(self) + local fd = check_socket(self) + + if fd == nil then + self._errno = boxerrno.EBADF + return nil, socket_error(self) + end + local deadline = fiber.clock() + (self.timeout or 
TIMEOUT_INFINITY) repeat local client = socket_accept(self) @@ -1434,11 +1574,21 @@ lsocket_tcp_server_mt = { local function lsocket_tcp_client_tostring(self) local fd = check_socket(self) - return string.format("tcp{client}: fd=%d", fd) + if fd == nil then + return string.format("tcp{client}: invalid socket") + else + return string.format("tcp{client}: fd=%d", fd) + end end local function lsocket_tcp_receive(self, pattern, prefix) - check_socket(self) + local fd = check_socket(self) + + if fd == nil then + self._errno = boxerrno.EBADF + return nil, socket_error(self) + end + prefix = prefix or '' local timeout = self.timeout or TIMEOUT_INFINITY local data diff --git a/test/app/socket.result b/test/app/socket.result index fd299424c..c61bf9cfa 100644 --- a/test/app/socket.result +++ b/test/app/socket.result @@ -1697,10 +1697,11 @@ s:close() --- - 1 ... --- Sic: incompatible with Lua Socket +-- Second close returns nil and an EBADF error s:close() --- -- error: 'builtin/socket.lua: attempt to use closed socket' +- null +- Bad file descriptor ... s = socket.tcp() --- diff --git a/test/app/socket.test.lua b/test/app/socket.test.lua index c72d41763..6e2ca3ef1 100644 --- a/test/app/socket.test.lua +++ b/test/app/socket.test.lua @@ -585,7 +585,7 @@ test_run:cmd("push filter 'fd=([0-9]+)' to 'fd='") s = socket.tcp() s s:close() --- Sic: incompatible with Lua Socket +-- Second close returns nil and an EBADF error s:close() s = socket.tcp() -- 2.17.1