[Tarantool-patches] [PATCH v2 3/4] sql: introduce custom aggregate functions
imeevma at tarantool.org
imeevma at tarantool.org
Tue Feb 1 16:37:25 MSK 2022
This patch introduces user-defined aggregate functions in SQL.
Part of #2579
---
src/box/alter.cc | 13 ++
src/box/lua/schema.lua | 2 +-
src/box/sql/expr.c | 17 ++-
src/box/sql/func.c | 14 ++-
src/box/sql/select.c | 31 +++--
test/sql-tap/CMakeLists.txt | 2 +
test/sql-tap/gh-2579-custom-aggregate.c | 28 +++++
.../sql-tap/gh-2579-custom-aggregate.test.lua | 113 ++++++++++++++++++
8 files changed, 200 insertions(+), 20 deletions(-)
create mode 100644 test/sql-tap/gh-2579-custom-aggregate.c
create mode 100755 test/sql-tap/gh-2579-custom-aggregate.test.lua
diff --git a/src/box/alter.cc b/src/box/alter.cc
index 65c1cb952..b85d279e3 100644
--- a/src/box/alter.cc
+++ b/src/box/alter.cc
@@ -3460,6 +3460,13 @@ func_def_new_from_tuple(struct tuple *tuple)
def->name, "invalid aggregate value");
return NULL;
}
+ if (def->aggregate == FUNC_AGGREGATE_GROUP &&
+ def->exports.lua) {
+ diag_set(ClientError, ER_CREATE_FUNCTION, def->name,
+ "aggregate function can only be accessed in "
+ "SQL");
+ return NULL;
+ }
const char *param_list = tuple_field_with_type(tuple,
BOX_FUNC_FIELD_PARAM_LIST, MP_ARRAY);
if (param_list == NULL)
@@ -3482,6 +3489,12 @@ func_def_new_from_tuple(struct tuple *tuple)
return NULL;
}
}
+ if (def->aggregate == FUNC_AGGREGATE_GROUP && argc == 0) {
+ diag_set(ClientError, ER_CREATE_FUNCTION, def->name,
+ "aggregate function must have at least one "
+ "argument");
+ return NULL;
+ }
def->param_count = argc;
const char *opts = tuple_field(tuple, BOX_FUNC_FIELD_OPTS);
if (opts_decode(&def->opts, func_opts_reg, &opts,
diff --git a/src/box/lua/schema.lua b/src/box/lua/schema.lua
index 2c02949c5..23d6d0f64 100644
--- a/src/box/lua/schema.lua
+++ b/src/box/lua/schema.lua
@@ -2603,7 +2603,7 @@ box.schema.func.create = function(name, opts)
language = 'string', body = 'string',
is_deterministic = 'boolean',
is_sandboxed = 'boolean',
- is_multikey = 'boolean',
+ is_multikey = 'boolean', aggregate = 'string',
takes_raw_args = 'boolean',
comment = 'string',
param_list = 'table', returns = 'string',
diff --git a/src/box/sql/expr.c b/src/box/sql/expr.c
index eb169aeb8..920ad9d08 100644
--- a/src/box/sql/expr.c
+++ b/src/box/sql/expr.c
@@ -5469,6 +5469,17 @@ analyzeAggregate(Walker * pWalker, Expr * pExpr)
(pExpr, EP_xIsSelect));
pItem = &pAggInfo->aFunc[i];
pItem->pExpr = pExpr;
+ int n = pExpr->x.pList == NULL ?
+ 0 : pExpr->x.pList->nExpr;
+ /*
+ * Allocate n MEMs for arguments
+ * and one more MEM for
+ * accumulator. This makes it
+ * easier to pass these n + 1
+ * MEMs to the user-defined
+ * aggregate function.
+ */
+ pParse->nMem += n;
pItem->iMem = ++pParse->nMem;
assert(!ExprHasProperty
(pExpr, EP_IntValue));
@@ -5479,12 +5490,6 @@ analyzeAggregate(Walker * pWalker, Expr * pExpr)
true;
return WRC_Abort;
}
- assert(pItem->func->def->
- language ==
- FUNC_LANGUAGE_SQL_BUILTIN &&
- pItem->func->def->
- aggregate ==
- FUNC_AGGREGATE_GROUP);
if (pExpr->flags & EP_Distinct) {
pItem->iDistinct =
pParse->nTab++;
diff --git a/src/box/sql/func.c b/src/box/sql/func.c
index b69bf7fd6..cda872194 100644
--- a/src/box/sql/func.c
+++ b/src/box/sql/func.c
@@ -2060,9 +2060,12 @@ sql_func_find(struct Expr *expr)
return NULL;
}
int n = expr->x.pList != NULL ? expr->x.pList->nExpr : 0;
- if (func->def->param_count != n) {
+ int argc = func->def->aggregate == FUNC_AGGREGATE_GROUP ?
+ func->def->param_count - 1 : func->def->param_count;
+ assert(argc >= 0);
+ if (argc != n) {
diag_set(ClientError, ER_FUNC_WRONG_ARG_COUNT, name,
- tt_sprintf("%d", func->def->param_count), n);
+ tt_sprintf("%d", argc), n);
return NULL;
}
return func;
@@ -2072,9 +2075,12 @@ uint32_t
sql_func_flags(const char *name)
{
struct sql_func_dictionary *dict = built_in_func_get(name);
- if (dict == NULL)
+ if (dict != NULL)
+ return dict->flags;
+ struct func *func = func_by_name(name, strlen(name));
+ if (func == NULL || func->def->aggregate != FUNC_AGGREGATE_GROUP)
return 0;
- return dict->flags;
+ return SQL_FUNC_AGG;
}
static struct func_vtab func_sql_builtin_vtab;
diff --git a/src/box/sql/select.c b/src/box/sql/select.c
index 6159a9670..dcea48c6e 100644
--- a/src/box/sql/select.c
+++ b/src/box/sql/select.c
@@ -4648,8 +4648,6 @@ is_simple_count(struct Select *select, struct AggInfo *agg_info)
return NULL;
if (NEVER(agg_info->nFunc == 0))
return NULL;
- assert(agg_info->aFunc->func->def->language ==
- FUNC_LANGUAGE_SQL_BUILTIN);
if (strcmp(agg_info->aFunc->func->def->name, "COUNT") != 0 ||
(agg_info->aFunc->pExpr->x.pList != NULL &&
agg_info->aFunc->pExpr->x.pList->nExpr > 0))
@@ -5562,6 +5560,15 @@ resetAccumulator(Parse * pParse, AggInfo * pAggInfo)
}
}
+/**
+ * Emit the finalization opcode for a single aggregate function.
+ *
+ * OP_AggFinal (with the function attached as P4) is generated only
+ * for built-in SQL functions. For a function written in any other
+ * language nothing is emitted here.
+ *
+ * @param vdbe VDBE program being generated.
+ * @param agg_func Description of the aggregate; its iMem is the
+ *                 register passed to OP_AggFinal.
+ */
+static inline void
+finalize_agg_function(struct Vdbe *vdbe, const struct AggInfo_func *agg_func)
+{
+	struct func *func = agg_func->func;
+	if (func->def->language != FUNC_LANGUAGE_SQL_BUILTIN)
+		return;
+	sqlVdbeAddOp1(vdbe, OP_AggFinal, agg_func->iMem);
+	sqlVdbeAppendP4(vdbe, func, P4_FUNC);
+}
+
/*
* Invoke the OP_AggFinalize opcode for every aggregate function
* in the AggInfo structure.
@@ -5574,8 +5581,7 @@ finalizeAggFunctions(Parse * pParse, AggInfo * pAggInfo)
struct AggInfo_func *pF;
for (i = 0, pF = pAggInfo->aFunc; i < pAggInfo->nFunc; i++, pF++) {
assert(!ExprHasProperty(pF->pExpr, EP_xIsSelect));
- sqlVdbeAddOp1(v, OP_AggFinal, pF->iMem);
- sqlVdbeAppendP4(v, pF->func, P4_FUNC);
+ finalize_agg_function(v, pF);
}
}
@@ -5602,7 +5608,7 @@ updateAccumulator(Parse * pParse, AggInfo * pAggInfo)
assert(!ExprHasProperty(pF->pExpr, EP_xIsSelect));
if (pList) {
nArg = pList->nExpr;
- regAgg = sqlGetTempRange(pParse, nArg);
+ regAgg = pF->iMem - nArg;
sqlExprCodeExprList(pParse, pList, regAgg, 0,
SQL_ECEL_DUP);
} else {
@@ -5642,10 +5648,17 @@ updateAccumulator(Parse * pParse, AggInfo * pAggInfo)
pParse->is_aborted = true;
return;
}
- sqlVdbeAddOp3(v, OP_AggStep, nArg, regAgg, pF->iMem);
- sqlVdbeAppendP4(v, ctx, P4_FUNCCTX);
- sql_expr_type_cache_change(pParse, regAgg, nArg);
- sqlReleaseTempRange(pParse, regAgg, nArg);
+ if (pF->func->def->language == FUNC_LANGUAGE_SQL_BUILTIN) {
+ sqlVdbeAddOp3(v, OP_AggStep, nArg, regAgg, pF->iMem);
+ sqlVdbeAppendP4(v, ctx, P4_FUNCCTX);
+ } else {
+ const char *name = pF->func->def->name;
+ uint32_t len = pF->func->def->name_len;
+ const char *str = sqlDbStrNDup(pParse->db, name, len);
+ assert(regAgg == pF->iMem - nArg);
+ sqlVdbeAddOp4(v, OP_FunctionByName, nArg + 1, regAgg,
+ pF->iMem, str, P4_DYNAMIC);
+ }
if (addrNext) {
sqlVdbeResolveLabel(v, addrNext);
sqlExprCacheClear(pParse);
diff --git a/test/sql-tap/CMakeLists.txt b/test/sql-tap/CMakeLists.txt
index c4ec1214a..136a517d4 100644
--- a/test/sql-tap/CMakeLists.txt
+++ b/test/sql-tap/CMakeLists.txt
@@ -3,6 +3,8 @@ build_module(gh-5938-wrong-string-length gh-5938-wrong-string-length.c)
target_link_libraries(gh-5938-wrong-string-length msgpuck)
build_module(gh-6024-funcs-return-bin gh-6024-funcs-return-bin.c)
target_link_libraries(gh-6024-funcs-return-bin msgpuck)
+build_module(gh-2579-custom-aggregate gh-2579-custom-aggregate.c)
+target_link_libraries(gh-2579-custom-aggregate msgpuck)
build_module(sql_uuid sql_uuid.c)
target_link_libraries(sql_uuid msgpuck core)
build_module(decimal decimal.c)
diff --git a/test/sql-tap/gh-2579-custom-aggregate.c b/test/sql-tap/gh-2579-custom-aggregate.c
new file mode 100644
index 000000000..f7d8a70a4
--- /dev/null
+++ b/test/sql-tap/gh-2579-custom-aggregate.c
@@ -0,0 +1,28 @@
+#include "msgpuck.h"
+#include "module.h"
+
+enum {
+ BUF_SIZE = 512,
+};
+
+/**
+ * Step function of a test aggregate that accumulates a sum of
+ * squares.
+ *
+ * The arguments are a MsgPack array of exactly two elements: the
+ * next value (unsigned integer) and the current accumulator
+ * (unsigned integer, or nil when there is no accumulated value).
+ * The new accumulator, accumulator + value * value, is returned
+ * through the call context.
+ *
+ * @param ctx Call context used to return the result.
+ * @param args MsgPack-encoded array of arguments.
+ * @param args_end End of the encoded arguments (unused).
+ *
+ * @retval 0 on success; otherwise the value returned by
+ *         box_error_set() or box_return_mp().
+ */
+int
+f3(box_function_ctx_t *ctx, const char *args, const char *args_end)
+{
+	(void)args_end;
+	uint32_t arg_count = mp_decode_array(&args);
+	if (arg_count != 2) {
+		return box_error_set(__FILE__, __LINE__, ER_PROC_C,
+				     "invalid argument count");
+	}
+	/* mp_decode_uint() may only be applied to an MP_UINT value. */
+	if (mp_typeof(*args) != MP_UINT) {
+		return box_error_set(__FILE__, __LINE__, ER_PROC_C,
+				     "invalid argument type");
+	}
+	/*
+	 * Keep the full 64-bit width of the decoded values: unsigned
+	 * arithmetic wraps around instead of overflowing a signed int.
+	 */
+	uint64_t num = mp_decode_uint(&args);
+	uint64_t sum = 0;
+	switch (mp_typeof(*args)) {
+	case MP_UINT:
+		sum = mp_decode_uint(&args);
+		break;
+	case MP_NIL:
+		/* No accumulated value yet: start from zero. */
+		mp_decode_nil(&args);
+		break;
+	default:
+		return box_error_set(__FILE__, __LINE__, ER_PROC_C,
+				     "invalid accumulator type");
+	}
+	sum += num * num;
+	char res[BUF_SIZE];
+	char *end = mp_encode_uint(res, sum);
+	/* Propagate a failure to return the result to the caller. */
+	return box_return_mp(ctx, res, end);
+}
diff --git a/test/sql-tap/gh-2579-custom-aggregate.test.lua b/test/sql-tap/gh-2579-custom-aggregate.test.lua
new file mode 100755
index 000000000..213e2e870
--- /dev/null
+++ b/test/sql-tap/gh-2579-custom-aggregate.test.lua
@@ -0,0 +1,113 @@
+#!/usr/bin/env tarantool
+-- Checks that user-defined aggregate functions (gh-2579) can be
+-- created and then called from SQL.
+
+-- Extend package.cpath with the test/sql-tap directory of the build
+-- tree, so that .so/.dylib modules located there can be found.
+local build_path = os.getenv("BUILDDIR")
+package.cpath = build_path..'/test/sql-tap/?.so;'..build_path..'/test/sql-tap/?.dylib;'..package.cpath
+
+local test = require("sqltester")
+test:plan(5)
+
+-- Table with the values 1..5 that every aggregate below is run on.
+test:execsql([[
+ CREATE TABLE t (i INT PRIMARY KEY);
+ INSERT INTO t VALUES(1), (2), (3), (4), (5);
+ ]])
+
+-- Make sure that persistent aggregate functions work as intended.
+-- The function takes the next value and the accumulated state and
+-- returns the new state; a nil state is replaced with an empty one.
+box.schema.func.create("F1", {
+ language = "Lua",
+ body = [[
+ function(x, state)
+ if state == nil then
+ state = {sum = 0, count = 0}
+ end
+ state.sum = state.sum + x
+ state.count = state.count + 1
+ return state
+ end
+ ]],
+ param_list = {"integer", "map"},
+ returns = "map",
+ aggregate = "group",
+ exports = {"SQL"},
+})
+
+-- In SQL the function is called with the value only.
+test:do_execsql_test(
+ "gh-2579-1",
+ [[
+ SELECT f1(i) from t;
+ ]], {
+ {sum = 15, count = 5}
+ })
+
+-- Make sure that non-persistent aggregate functions work as intended.
+local f2 = function(x, state)
+ if state == nil then
+ state = {}
+ end
+ table.insert(state, x)
+ return state
+end
+
+-- Publish f2 as the global F2: the function created below has no
+-- body of its own.
+rawset(_G, 'F2', f2)
+
+box.schema.func.create("F2", {
+ language = "Lua",
+ param_list = {"integer", "array"},
+ returns = "array",
+ aggregate = "group",
+ exports = {"SQL"},
+})
+
+test:do_execsql_test(
+ "gh-2579-2",
+ [[
+ SELECT f2(i) from t;
+ ]], {
+ {1, 2, 3, 4, 5}
+ })
+
+-- Make sure that C aggregate functions work as intended.
+box.schema.func.create("gh-2579-custom-aggregate.f3", {
+ language = "C",
+ param_list = {"integer", "integer"},
+ returns = "integer",
+ aggregate = "group",
+ exports = {"SQL"},
+})
+
+-- f3 adds the square of each value to the accumulator, so the
+-- expected result is 1 + 4 + 9 + 16 + 25 = 55.
+test:do_execsql_test(
+ "gh-2579-3",
+ [[
+ SELECT "gh-2579-custom-aggregate.f3"(i) from t;
+ ]], {
+ 55
+ })
+
+-- Make sure aggregate functions can't be called in Lua.
+-- The definition exports the function to both Lua and SQL, so the
+-- creation itself is expected to fail.
+test:do_test(
+ "gh-2579-4",
+ function()
+ local def = {aggregate = 'group', exports = {'LUA', 'SQL'}}
+ local res = {pcall(box.schema.func.create, 'F4', def)}
+ return {tostring(res[2])}
+ end, {
+ "Failed to create function 'F4': aggregate function can only be "..
+ "accessed in SQL"
+ })
+
+-- Make sure aggregate functions can't have fewer than 1 argument.
+-- The definition has no param_list at all.
+test:do_test(
+ "gh-2579-5",
+ function()
+ local def = {aggregate = 'group', exports = {'SQL'}}
+ local res = {pcall(box.schema.func.create, 'F4', def)}
+ return {tostring(res[2])}
+ end, {
+ "Failed to create function 'F4': aggregate function must have at "..
+ "least one argument"
+ })
+
+-- Drop everything created by the test.
+box.schema.func.drop('gh-2579-custom-aggregate.f3')
+box.schema.func.drop('F2')
+box.schema.func.drop('F1')
+test:execsql([[DROP TABLE t;]])
+
+test:finish_test()
--
2.25.1
More information about the Tarantool-patches
mailing list