Tarantool development patches archive
 help / color / mirror / Atom feed
From: Mergen Imeev via Tarantool-patches <tarantool-patches@dev.tarantool.org>
To: kyukhin@tarantool.org
Cc: tarantool-patches@dev.tarantool.org
Subject: [Tarantool-patches] [PATCH v2 3/4] sql: introduce custom aggregate functions
Date: Thu, 10 Feb 2022 12:14:15 +0300	[thread overview]
Message-ID: <d23ca922ff5dc5c493dcf40a0e74df3747337b58.1644484227.git.imeevma@gmail.com> (raw)
In-Reply-To: <cover.1644484227.git.imeevma@gmail.com>

This patch introduces user-defined aggregate functions in SQL.

Part of #2579
---
 src/box/alter.cc                              |  13 ++
 src/box/lua/schema.lua                        |   2 +-
 src/box/sql/expr.c                            |  17 ++-
 src/box/sql/func.c                            |  18 ++-
 src/box/sql/select.c                          |  31 +++--
 test/sql-tap/CMakeLists.txt                   |   2 +
 test/sql-tap/gh-2579-custom-aggregate.c       |  28 +++++
 .../sql-tap/gh-2579-custom-aggregate.test.lua | 114 ++++++++++++++++++
 8 files changed, 205 insertions(+), 20 deletions(-)
 create mode 100644 test/sql-tap/gh-2579-custom-aggregate.c
 create mode 100755 test/sql-tap/gh-2579-custom-aggregate.test.lua

diff --git a/src/box/alter.cc b/src/box/alter.cc
index 65c1cb952..b85d279e3 100644
--- a/src/box/alter.cc
+++ b/src/box/alter.cc
@@ -3460,6 +3460,13 @@ func_def_new_from_tuple(struct tuple *tuple)
 				  def->name, "invalid aggregate value");
 			return NULL;
 		}
+		if (def->aggregate == FUNC_AGGREGATE_GROUP &&
+		    def->exports.lua) {
+			diag_set(ClientError, ER_CREATE_FUNCTION, def->name,
+				 "aggregate function can only be accessed in "
+				 "SQL");
+			return NULL;
+		}
 		const char *param_list = tuple_field_with_type(tuple,
 			BOX_FUNC_FIELD_PARAM_LIST, MP_ARRAY);
 		if (param_list == NULL)
@@ -3482,6 +3489,12 @@ func_def_new_from_tuple(struct tuple *tuple)
 				return NULL;
 			}
 		}
+		if (def->aggregate == FUNC_AGGREGATE_GROUP && argc == 0) {
+			diag_set(ClientError, ER_CREATE_FUNCTION, def->name,
+				 "aggregate function must have at least one "
+				 "argument");
+			return NULL;
+		}
 		def->param_count = argc;
 		const char *opts = tuple_field(tuple, BOX_FUNC_FIELD_OPTS);
 		if (opts_decode(&def->opts, func_opts_reg, &opts,
diff --git a/src/box/lua/schema.lua b/src/box/lua/schema.lua
index 2c02949c5..23d6d0f64 100644
--- a/src/box/lua/schema.lua
+++ b/src/box/lua/schema.lua
@@ -2603,7 +2603,7 @@ box.schema.func.create = function(name, opts)
                               language = 'string', body = 'string',
                               is_deterministic = 'boolean',
                               is_sandboxed = 'boolean',
-                              is_multikey = 'boolean',
+                              is_multikey = 'boolean', aggregate = 'string',
                               takes_raw_args = 'boolean',
                               comment = 'string',
                               param_list = 'table', returns = 'string',
diff --git a/src/box/sql/expr.c b/src/box/sql/expr.c
index eb169aeb8..920ad9d08 100644
--- a/src/box/sql/expr.c
+++ b/src/box/sql/expr.c
@@ -5469,6 +5469,17 @@ analyzeAggregate(Walker * pWalker, Expr * pExpr)
 						       (pExpr, EP_xIsSelect));
 						pItem = &pAggInfo->aFunc[i];
 						pItem->pExpr = pExpr;
+						int n = pExpr->x.pList == NULL ?
+							0 : pExpr->x.pList->nExpr;
+						/*
+						 * Allocate n MEMs for arguments
+						 * and one more MEM for
+						 * accumulator. This makes it
+						 * easier to pass these n + 1
+						 * MEMs to the user-defined
+						 * aggregate function.
+						 */
+						pParse->nMem += n;
 						pItem->iMem = ++pParse->nMem;
 						assert(!ExprHasProperty
 						       (pExpr, EP_IntValue));
@@ -5479,12 +5490,6 @@ analyzeAggregate(Walker * pWalker, Expr * pExpr)
 								true;
 							return WRC_Abort;
 						}
-						assert(pItem->func->def->
-						       language ==
-						       FUNC_LANGUAGE_SQL_BUILTIN &&
-						       pItem->func->def->
-						       aggregate ==
-						       FUNC_AGGREGATE_GROUP);
 						if (pExpr->flags & EP_Distinct) {
 							pItem->iDistinct =
 								pParse->nTab++;
diff --git a/src/box/sql/func.c b/src/box/sql/func.c
index b69bf7fd6..dc2fd771a 100644
--- a/src/box/sql/func.c
+++ b/src/box/sql/func.c
@@ -2060,9 +2060,16 @@ sql_func_find(struct Expr *expr)
 		return NULL;
 	}
 	int n = expr->x.pList != NULL ? expr->x.pList->nExpr : 0;
-	if (func->def->param_count != n) {
+	/*
+	 * An aggregate function takes its state as the last argument, so
+	 * in SQL it accepts one argument fewer than its definition says.
+	 */
+	int argc = func->def->aggregate == FUNC_AGGREGATE_GROUP ?
+		   func->def->param_count - 1 : func->def->param_count;
+	assert(argc >= 0);
+	if (argc != n) {
 		diag_set(ClientError, ER_FUNC_WRONG_ARG_COUNT, name,
-			 tt_sprintf("%d", func->def->param_count), n);
+			 tt_sprintf("%d", argc), n);
 		return NULL;
 	}
 	return func;
@@ -2072,9 +2079,12 @@ uint32_t
 sql_func_flags(const char *name)
 {
 	struct sql_func_dictionary *dict = built_in_func_get(name);
-	if (dict == NULL)
+	if (dict != NULL)
+		return dict->flags;
+	struct func *func = func_by_name(name, strlen(name));
+	if (func == NULL || func->def->aggregate != FUNC_AGGREGATE_GROUP)
 		return 0;
-	return dict->flags;
+	return SQL_FUNC_AGG;
 }
 
 static struct func_vtab func_sql_builtin_vtab;
diff --git a/src/box/sql/select.c b/src/box/sql/select.c
index 6159a9670..dcea48c6e 100644
--- a/src/box/sql/select.c
+++ b/src/box/sql/select.c
@@ -4648,8 +4648,6 @@ is_simple_count(struct Select *select, struct AggInfo *agg_info)
 		return NULL;
 	if (NEVER(agg_info->nFunc == 0))
 		return NULL;
-	assert(agg_info->aFunc->func->def->language ==
-	       FUNC_LANGUAGE_SQL_BUILTIN);
 	if (strcmp(agg_info->aFunc->func->def->name, "COUNT") != 0 ||
 	    (agg_info->aFunc->pExpr->x.pList != NULL &&
 	     agg_info->aFunc->pExpr->x.pList->nExpr > 0))
@@ -5562,6 +5560,15 @@ resetAccumulator(Parse * pParse, AggInfo * pAggInfo)
 	}
 }
 
+static inline void
+finalize_agg_function(struct Vdbe *vdbe, const struct AggInfo_func *agg_func)
+{
+	if (agg_func->func->def->language == FUNC_LANGUAGE_SQL_BUILTIN) {
+		sqlVdbeAddOp1(vdbe, OP_AggFinal, agg_func->iMem);
+		sqlVdbeAppendP4(vdbe, agg_func->func, P4_FUNC);
+	}
+}
+
 /*
  * Invoke the OP_AggFinalize opcode for every aggregate function
  * in the AggInfo structure.
@@ -5574,8 +5581,7 @@ finalizeAggFunctions(Parse * pParse, AggInfo * pAggInfo)
 	struct AggInfo_func *pF;
 	for (i = 0, pF = pAggInfo->aFunc; i < pAggInfo->nFunc; i++, pF++) {
 		assert(!ExprHasProperty(pF->pExpr, EP_xIsSelect));
-		sqlVdbeAddOp1(v, OP_AggFinal, pF->iMem);
-		sqlVdbeAppendP4(v, pF->func, P4_FUNC);
+		finalize_agg_function(v, pF);
 	}
 }
 
@@ -5602,7 +5608,7 @@ updateAccumulator(Parse * pParse, AggInfo * pAggInfo)
 		assert(!ExprHasProperty(pF->pExpr, EP_xIsSelect));
 		if (pList) {
 			nArg = pList->nExpr;
-			regAgg = sqlGetTempRange(pParse, nArg);
+			regAgg = pF->iMem - nArg;
 			sqlExprCodeExprList(pParse, pList, regAgg, 0,
 						SQL_ECEL_DUP);
 		} else {
@@ -5642,10 +5648,17 @@ updateAccumulator(Parse * pParse, AggInfo * pAggInfo)
 			pParse->is_aborted = true;
 			return;
 		}
-		sqlVdbeAddOp3(v, OP_AggStep, nArg, regAgg, pF->iMem);
-		sqlVdbeAppendP4(v, ctx, P4_FUNCCTX);
-		sql_expr_type_cache_change(pParse, regAgg, nArg);
-		sqlReleaseTempRange(pParse, regAgg, nArg);
+		if (pF->func->def->language == FUNC_LANGUAGE_SQL_BUILTIN) {
+			sqlVdbeAddOp3(v, OP_AggStep, nArg, regAgg, pF->iMem);
+			sqlVdbeAppendP4(v, ctx, P4_FUNCCTX);
+		} else {
+			const char *name = pF->func->def->name;
+			uint32_t len = pF->func->def->name_len;
+			const char *str = sqlDbStrNDup(pParse->db, name, len);
+			assert(regAgg == pF->iMem - nArg);
+			sqlVdbeAddOp4(v, OP_FunctionByName, nArg + 1, regAgg,
+				      pF->iMem, str, P4_DYNAMIC);
+		}
 		if (addrNext) {
 			sqlVdbeResolveLabel(v, addrNext);
 			sqlExprCacheClear(pParse);
diff --git a/test/sql-tap/CMakeLists.txt b/test/sql-tap/CMakeLists.txt
index c4ec1214a..136a517d4 100644
--- a/test/sql-tap/CMakeLists.txt
+++ b/test/sql-tap/CMakeLists.txt
@@ -3,6 +3,8 @@ build_module(gh-5938-wrong-string-length gh-5938-wrong-string-length.c)
 target_link_libraries(gh-5938-wrong-string-length msgpuck)
 build_module(gh-6024-funcs-return-bin gh-6024-funcs-return-bin.c)
 target_link_libraries(gh-6024-funcs-return-bin msgpuck)
+build_module(gh-2579-custom-aggregate gh-2579-custom-aggregate.c)
+target_link_libraries(gh-2579-custom-aggregate msgpuck)
 build_module(sql_uuid sql_uuid.c)
 target_link_libraries(sql_uuid msgpuck core)
 build_module(decimal decimal.c)
diff --git a/test/sql-tap/gh-2579-custom-aggregate.c b/test/sql-tap/gh-2579-custom-aggregate.c
new file mode 100644
index 000000000..f7d8a70a4
--- /dev/null
+++ b/test/sql-tap/gh-2579-custom-aggregate.c
@@ -0,0 +1,28 @@
+#include "msgpuck.h"
+#include "module.h"
+
+enum {
+	BUF_SIZE = 512,
+};
+
+int
+f3(box_function_ctx_t *ctx, const char *args, const char *args_end)
+{
+	(void)args_end;
+	uint32_t arg_count = mp_decode_array(&args);
+	if (arg_count != 2) {
+		return box_error_set(__FILE__, __LINE__, ER_PROC_C,
+				     "invalid argument count");
+	}
+	int num = mp_decode_uint(&args);
+	int sum = 0;
+	if (mp_typeof(*args) != MP_UINT)
+		mp_decode_nil(&args);
+	else
+		sum = mp_decode_uint(&args);
+	sum += num * num;
+	char res[BUF_SIZE];
+	char *end = mp_encode_uint(res, sum);
+	box_return_mp(ctx, res, end);
+	return 0;
+}
diff --git a/test/sql-tap/gh-2579-custom-aggregate.test.lua b/test/sql-tap/gh-2579-custom-aggregate.test.lua
new file mode 100755
index 000000000..afd08888e
--- /dev/null
+++ b/test/sql-tap/gh-2579-custom-aggregate.test.lua
@@ -0,0 +1,114 @@
+#!/usr/bin/env tarantool
+local build_path = os.getenv("BUILDDIR")
+package.cpath = build_path..'/test/sql-tap/?.so;'..build_path..
+                '/test/sql-tap/?.dylib;'..package.cpath
+
+local test = require("sqltester")
+test:plan(5)
+
+test:execsql([[
+    CREATE TABLE t (i INT PRIMARY KEY);
+    INSERT INTO t VALUES(1), (2), (3), (4), (5);
+]])
+
+-- Make sure that persistent aggregate functions work as intended.
+box.schema.func.create("F1", {
+    language = "Lua",
+    body = [[
+        function(x, state)
+            if state == nil then
+                state = {sum = 0, count = 0}
+            end
+            state.sum = state.sum + x
+            state.count = state.count + 1
+            return state
+        end
+    ]],
+    param_list = {"integer", "map"},
+    returns = "map",
+    aggregate = "group",
+    exports = {"SQL"},
+})
+
+test:do_execsql_test(
+    "gh-2579-1",
+    [[
+        SELECT f1(i) from t;
+    ]], {
+        {sum = 15, count = 5}
+    })
+
+-- Make sure that non-persistent aggregate functions work as intended.
+local f2 = function(x, state)
+    if state == nil then
+        state = {}
+    end
+    table.insert(state, x)
+    return state
+end
+
+rawset(_G, 'F2', f2)
+
+box.schema.func.create("F2", {
+    language = "Lua",
+    param_list = {"integer", "array"},
+    returns = "array",
+    aggregate = "group",
+    exports = {"SQL"},
+})
+
+test:do_execsql_test(
+    "gh-2579-2",
+    [[
+        SELECT f2(i) from t;
+    ]], {
+        {1, 2, 3, 4, 5}
+    })
+
+-- Make sure that C aggregate functions work as intended.
+box.schema.func.create("gh-2579-custom-aggregate.f3", {
+    language = "C",
+    param_list = {"integer", "integer"},
+    returns = "integer",
+    aggregate = "group",
+    exports = {"SQL"},
+})
+
+test:do_execsql_test(
+    "gh-2579-3",
+    [[
+        SELECT "gh-2579-custom-aggregate.f3"(i) from t;
+    ]], {
+        55
+    })
+
+-- Make sure aggregate functions can't be called in Lua.
+test:do_test(
+    "gh-2579-4",
+    function()
+        local def = {aggregate = 'group', exports = {'LUA', 'SQL'}}
+        local res = {pcall(box.schema.func.create, 'F4', def)}
+        return {tostring(res[2])}
+    end, {
+        "Failed to create function 'F4': aggregate function can only be "..
+        "accessed in SQL"
+    })
+
+-- Make sure aggregate functions can't have fewer than 1 argument.
+test:do_test(
+    "gh-2579-5",
+    function()
+        local def = {aggregate = 'group', exports = {'SQL'}}
+        local res = {pcall(box.schema.func.create, 'F4', def)}
+        return {tostring(res[2])}
+    end, {
+        "Failed to create function 'F4': aggregate function must have at "..
+        "least one argument"
+    })
+
+box.schema.func.drop('gh-2579-custom-aggregate.f3')
+box.schema.func.drop('F2')
+box.schema.func.drop('F1')
+test:execsql([[DROP TABLE t;]])
+
+test:finish_test()
-- 
2.25.1


  parent reply	other threads:[~2022-02-10  9:15 UTC|newest]

Thread overview: 8+ messages / expand[flat|nested]  mbox.gz  Atom feed  top
2022-02-10  9:14 [Tarantool-patches] [PATCH v2 0/4] Introduce " Mergen Imeev via Tarantool-patches
2022-02-10  9:14 ` [Tarantool-patches] [PATCH v2 1/4] sql: fix COUNT() optimization conditions Mergen Imeev via Tarantool-patches
2022-02-10  9:14 ` [Tarantool-patches] [PATCH v2 2/4] sql: drop unnecessary P2 register for OP_AggFinal Mergen Imeev via Tarantool-patches
2022-02-10  9:14 ` Mergen Imeev via Tarantool-patches [this message]
2022-02-10  9:14 ` [Tarantool-patches] [PATCH v2 4/4] sql: introduce FINALIZE for custom aggregate Mergen Imeev via Tarantool-patches
2022-02-14  8:26 ` [Tarantool-patches] [PATCH v2 0/4] Introduce custom aggregate functions Kirill Yukhin via Tarantool-patches
  -- strict thread matches above, loose matches on Subject: below --
2022-02-01 13:37 [Tarantool-patches] [PATCH v2 0/4] Introduce custom aggregate function Mergen Imeev via Tarantool-patches
2022-02-01 13:37 ` [Tarantool-patches] [PATCH v2 3/4] sql: introduce custom aggregate functions Mergen Imeev via Tarantool-patches
2022-02-03 23:29   ` Vladislav Shpilevoy via Tarantool-patches

Reply instructions:

You may reply publicly to this message via plain-text email
using any one of the following methods:

* Save the following mbox file, import it into your mail client,
  and reply-to-all from there: mbox

  Avoid top-posting and favor interleaved quoting:
  https://en.wikipedia.org/wiki/Posting_style#Interleaved_style

* Reply using the --to, --cc, and --in-reply-to
  switches of git-send-email(1):

  git send-email \
    --in-reply-to=d23ca922ff5dc5c493dcf40a0e74df3747337b58.1644484227.git.imeevma@gmail.com \
    --to=tarantool-patches@dev.tarantool.org \
    --cc=imeevma@tarantool.org \
    --cc=kyukhin@tarantool.org \
    --subject='Re: [Tarantool-patches] [PATCH v2 3/4] sql: introduce custom aggregate functions' \
    /path/to/YOUR_REPLY

  https://kernel.org/pub/software/scm/git/docs/git-send-email.html

* If your mail client supports setting the In-Reply-To header
  via mailto: links, try the mailto: link

This is a public inbox, see mirroring instructions
for how to clone and mirror all data and code used for this inbox