[Tarantool-patches] [PATCH v1 6/8] sql: rework POSITION() function
Mergen Imeev
imeevma at tarantool.org
Wed Oct 20 20:08:24 MSK 2021
Thank you for the review! My answers, diff and new patch below. Also, I replaced
self-created functions by ucnv_getNextUChar().
On Fri, Oct 08, 2021 at 11:58:32PM +0200, Vladislav Shpilevoy wrote:
> Thanks for the patch!
>
> See 3 comments below.
>
> > diff --git a/src/box/sql/func.c b/src/box/sql/func.c
> > index 1d1a8b0cd..415a92738 100644
> > --- a/src/box/sql/func.c
> > +++ b/src/box/sql/func.c
> > @@ -530,6 +530,68 @@ func_trim_str(struct sql_context *ctx, int argc, struct Mem *argv)
> > ctx->is_aborted = true;
> > }
> >
> > +/** Implementation of the POSITION() function. */
> > +static void
> > +func_position_octets(struct sql_context *ctx, int argc, struct Mem *argv)
> > +{
> > + assert(argc == 2);
> > + (void)argc;
> > + if (mem_is_null(&argv[0]) || mem_is_null(&argv[1]))
>
> 1. There is mem_is_any_null(). The same in the next function.
>
Thanks, fixed.
> > + return;
> > + assert(mem_is_bytes(&argv[0]) && mem_is_bytes(&argv[1]));
> > +
> > + const char *key = argv[0].z;
> > + const char *str = argv[1].z;
> > + int key_size = argv[0].n;
> > + int str_size = argv[1].n;
> > + if (key_size <= 0)
> > + return mem_set_uint(ctx->pOut, 1);
> > + /* Matching time O(n * m). */
> > + for (int i = 0; i <= str_size - key_size; ++i) {
> > + if (memcmp(&str[i], key, key_size) == 0)
> > + return mem_set_uint(ctx->pOut, i + 1);
> > + }
>
> 2. There is memmem().
>
Thanks, fixed.
> > + return mem_set_uint(ctx->pOut, 0);
> > +}
> > diff --git a/test/sql-tap/position.test.lua b/test/sql-tap/position.test.lua
> > index 6a96ed9bc..5f62c7f54 100755
> > --- a/test/sql-tap/position.test.lua
> > +++ b/test/sql-tap/position.test.lua
> > @@ -858,4 +858,14 @@ test:do_catchsql_test(
> > }
> > )
> >
> > +-- gh-4145: Make sure that POSITION() can wirk with VARBINARY.
>
> 3. wirk -> work.
>
Fixed.
> > +test:do_execsql_test(
> > + "position-2",
> > + [[
> > + SELECT POSITION(x'313233', x'30313231323334353132333435');
> > + ]], {
> > + 4
> > + }
> > +)
> > +
> > test:finish_test()
> >
Diff:
diff --git a/src/box/sql/func.c b/src/box/sql/func.c
index d145e9cc0..80b075dcf 100644
--- a/src/box/sql/func.c
+++ b/src/box/sql/func.c
@@ -522,7 +522,7 @@ func_position_octets(struct sql_context *ctx, int argc, struct Mem *argv)
{
assert(argc == 2);
(void)argc;
- if (mem_is_null(&argv[0]) || mem_is_null(&argv[1]))
+ if (mem_is_any_null(&argv[0], &argv[1]))
return;
assert(mem_is_bytes(&argv[0]) && mem_is_bytes(&argv[1]));
@@ -532,12 +532,8 @@ func_position_octets(struct sql_context *ctx, int argc, struct Mem *argv)
int str_size = argv[1].n;
if (key_size <= 0)
return mem_set_uint(ctx->pOut, 1);
- /* Matching time O(n * m). */
- for (int i = 0; i <= str_size - key_size; ++i) {
- if (memcmp(&str[i], key, key_size) == 0)
- return mem_set_uint(ctx->pOut, i + 1);
- }
- return mem_set_uint(ctx->pOut, 0);
+ const char *pos = memmem(str, str_size, key, key_size);
+ return mem_set_uint(ctx->pOut, pos == NULL ? 0 : pos - str + 1);
}
static void
@@ -545,7 +541,7 @@ func_position_characters(struct sql_context *ctx, int argc, struct Mem *argv)
{
assert(argc == 2);
(void)argc;
- if (mem_is_null(&argv[0]) || mem_is_null(&argv[1]))
+ if (mem_is_any_null(&argv[0], &argv[1]))
return;
assert(mem_is_str(&argv[0]) && mem_is_str(&argv[1]));
@@ -555,26 +551,29 @@ func_position_characters(struct sql_context *ctx, int argc, struct Mem *argv)
int str_size = argv[1].n;
if (key_size <= 0)
return mem_set_uint(ctx->pOut, 1);
- int key_len = utf8_len_str(key, key_size);
- int start = 0;
- int end = 0;
- for (int i = 0; i < key_len && end <= str_size; ++i)
- end += utf8_len_char(str[end]);
- if (end > str_size)
- return mem_set_uint(ctx->pOut, 0);
+ UErrorCode err = U_ZERO_ERROR;
+ const char *pos = str;
+ const char *cur = str;
+ const char *end = str + str_size;
+ const char *tmp_pos = key;
+ const char *tmp_end = key + key_size;
+ assert(icu_utf8_conv != NULL);
+ while (tmp_pos < tmp_end && err == U_ZERO_ERROR) {
+ ucnv_getNextUChar(icu_utf8_conv, &tmp_pos, tmp_end, &err);
+ ucnv_getNextUChar(icu_utf8_conv, &cur, end, &err);
+ }
+
int i = 0;
- while (end <= str_size) {
+ while (err == U_ZERO_ERROR) {
struct coll *coll = ctx->coll;
- const char *s = &str[start];
- if (coll->cmp(key, key_size, s, end - start, coll) == 0)
+ if (coll->cmp(key, key_size, pos, cur - pos, coll) == 0)
return mem_set_uint(ctx->pOut, i + 1);
- start += utf8_len_char(str[start]);
- if (end == str_size)
- break;
- end += utf8_len_char(str[end]);
+ ucnv_getNextUChar(icu_utf8_conv, &pos, end, &err);
+ ucnv_getNextUChar(icu_utf8_conv, &cur, end, &err);
++i;
}
+ assert(err == U_INDEX_OUTOFBOUNDS_ERROR && cur == end);
return mem_set_uint(ctx->pOut, 0);
}
diff --git a/test/sql-tap/position.test.lua b/test/sql-tap/position.test.lua
index 5f62c7f54..e49f4665a 100755
--- a/test/sql-tap/position.test.lua
+++ b/test/sql-tap/position.test.lua
@@ -858,7 +858,7 @@ test:do_catchsql_test(
}
)
--- gh-4145: Make sure that POSITION() can wirk with VARBINARY.
+-- gh-4145: Make sure POSITION() can work with VARBINARY.
test:do_execsql_test(
"position-2",
[[
New patch:
commit cdc02ef02866bdc603f8389e09d3ac0078c1e782
Author: Mergen Imeev <imeevma at gmail.com>
Date: Wed Sep 22 14:36:40 2021 +0300
sql: rework POSITION() function
This patch is a refactoring of POSITION(). In addition, VARBINARY
arguments can now be used in this function. In addition, POSITION() now
uses ICU functions instead of self-created.
Part of #4145
diff --git a/src/box/sql/func.c b/src/box/sql/func.c
index 1294ff5b3..80b075dcf 100644
--- a/src/box/sql/func.c
+++ b/src/box/sql/func.c
@@ -516,6 +516,67 @@ func_trim_str(struct sql_context *ctx, int argc, struct Mem *argv)
ctx->is_aborted = true;
}
+/** Implementation of the POSITION() function. */
+static void
+func_position_octets(struct sql_context *ctx, int argc, struct Mem *argv)
+{
+ assert(argc == 2);
+ (void)argc;
+ if (mem_is_any_null(&argv[0], &argv[1]))
+ return;
+ assert(mem_is_bytes(&argv[0]) && mem_is_bytes(&argv[1]));
+
+ const char *key = argv[0].z;
+ const char *str = argv[1].z;
+ int key_size = argv[0].n;
+ int str_size = argv[1].n;
+ if (key_size <= 0)
+ return mem_set_uint(ctx->pOut, 1);
+ const char *pos = memmem(str, str_size, key, key_size);
+ return mem_set_uint(ctx->pOut, pos == NULL ? 0 : pos - str + 1);
+}
+
+static void
+func_position_characters(struct sql_context *ctx, int argc, struct Mem *argv)
+{
+ assert(argc == 2);
+ (void)argc;
+ if (mem_is_any_null(&argv[0], &argv[1]))
+ return;
+ assert(mem_is_str(&argv[0]) && mem_is_str(&argv[1]));
+
+ const char *key = argv[0].z;
+ const char *str = argv[1].z;
+ int key_size = argv[0].n;
+ int str_size = argv[1].n;
+ if (key_size <= 0)
+ return mem_set_uint(ctx->pOut, 1);
+
+ UErrorCode err = U_ZERO_ERROR;
+ const char *pos = str;
+ const char *cur = str;
+ const char *end = str + str_size;
+ const char *tmp_pos = key;
+ const char *tmp_end = key + key_size;
+ assert(icu_utf8_conv != NULL);
+ while (tmp_pos < tmp_end && err == U_ZERO_ERROR) {
+ ucnv_getNextUChar(icu_utf8_conv, &tmp_pos, tmp_end, &err);
+ ucnv_getNextUChar(icu_utf8_conv, &cur, end, &err);
+ }
+
+ int i = 0;
+ while (err == U_ZERO_ERROR) {
+ struct coll *coll = ctx->coll;
+ if (coll->cmp(key, key_size, pos, cur - pos, coll) == 0)
+ return mem_set_uint(ctx->pOut, i + 1);
+ ucnv_getNextUChar(icu_utf8_conv, &pos, end, &err);
+ ucnv_getNextUChar(icu_utf8_conv, &cur, end, &err);
+ ++i;
+ }
+ assert(err == U_INDEX_OUTOFBOUNDS_ERROR && cur == end);
+ return mem_set_uint(ctx->pOut, 0);
+}
+
static const unsigned char *
mem_as_ustr(struct Mem *mem)
{
@@ -679,141 +740,6 @@ lengthFunc(struct sql_context *context, int argc, struct Mem *argv)
}
}
-/**
- * Implementation of the position() function.
- *
- * position(needle, haystack) finds the first occurrence of needle
- * in haystack and returns the number of previous characters
- * plus 1, or 0 if needle does not occur within haystack.
- *
- * If both haystack and needle are BLOBs, then the result is one
- * more than the number of bytes in haystack prior to the first
- * occurrence of needle, or 0 if needle never occurs in haystack.
- */
-static void
-position_func(struct sql_context *context, int argc, struct Mem *argv)
-{
- UNUSED_PARAMETER(argc);
- struct Mem *needle = &argv[0];
- struct Mem *haystack = &argv[1];
- enum mp_type needle_type = sql_value_type(needle);
- enum mp_type haystack_type = sql_value_type(haystack);
-
- if (haystack_type == MP_NIL || needle_type == MP_NIL)
- return;
- /*
- * Position function can be called only with string
- * or blob params.
- */
- struct Mem *inconsistent_type_arg = NULL;
- if (needle_type != MP_STR && needle_type != MP_BIN)
- inconsistent_type_arg = needle;
- if (haystack_type != MP_STR && haystack_type != MP_BIN)
- inconsistent_type_arg = haystack;
- if (inconsistent_type_arg != NULL) {
- diag_set(ClientError, ER_INCONSISTENT_TYPES,
- "string or varbinary", mem_str(inconsistent_type_arg));
- context->is_aborted = true;
- return;
- }
- /*
- * Both params of Position function must be of the same
- * type.
- */
- if (haystack_type != needle_type) {
- diag_set(ClientError, ER_INCONSISTENT_TYPES,
- mem_type_to_str(needle), mem_str(haystack));
- context->is_aborted = true;
- return;
- }
-
- int n_needle_bytes = mem_len_unsafe(needle);
- int n_haystack_bytes = mem_len_unsafe(haystack);
- int position = 1;
- if (n_needle_bytes > 0) {
- const unsigned char *haystack_str;
- const unsigned char *needle_str;
- if (haystack_type == MP_BIN) {
- needle_str = mem_as_bin(needle);
- haystack_str = mem_as_bin(haystack);
- assert(needle_str != NULL);
- assert(haystack_str != NULL || n_haystack_bytes == 0);
- /*
- * Naive implementation of substring
- * searching: matching time O(n * m).
- * Can be improved.
- */
- while (n_needle_bytes <= n_haystack_bytes &&
- memcmp(haystack_str, needle_str, n_needle_bytes) != 0) {
- position++;
- n_haystack_bytes--;
- haystack_str++;
- }
- if (n_needle_bytes > n_haystack_bytes)
- position = 0;
- } else {
- /*
- * Code below handles not only simple
- * cases like position('a', 'bca'), but
- * also more complex ones:
- * position('a', 'bcá' COLLATE "unicode_ci")
- * To do so, we need to use comparison
- * window, which has constant character
- * size, but variable byte size.
- * Character size is equal to
- * needle char size.
- */
- haystack_str = mem_as_ustr(haystack);
- needle_str = mem_as_ustr(needle);
-
- int n_needle_chars =
- sql_utf8_char_count(needle_str, n_needle_bytes);
- int n_haystack_chars =
- sql_utf8_char_count(haystack_str,
- n_haystack_bytes);
-
- if (n_haystack_chars < n_needle_chars) {
- position = 0;
- goto finish;
- }
- /*
- * Comparison window is determined by
- * beg_offset and end_offset. beg_offset
- * is offset in bytes from haystack
- * beginning to window beginning.
- * end_offset is offset in bytes from
- * haystack beginning to window end.
- */
- int end_offset = 0;
- for (int c = 0; c < n_needle_chars; c++) {
- SQL_UTF8_FWD_1(haystack_str, end_offset,
- n_haystack_bytes);
- }
- int beg_offset = 0;
- struct coll *coll = context->coll;
- int c;
- for (c = 0; c + n_needle_chars <= n_haystack_chars; c++) {
- if (coll->cmp((const char *) haystack_str + beg_offset,
- end_offset - beg_offset,
- (const char *) needle_str,
- n_needle_bytes, coll) == 0)
- goto finish;
- position++;
- /* Update offsets. */
- SQL_UTF8_FWD_1(haystack_str, beg_offset,
- n_haystack_bytes);
- SQL_UTF8_FWD_1(haystack_str, end_offset,
- n_haystack_bytes);
- }
- /* Needle was not found in the haystack. */
- position = 0;
- }
- }
-finish:
- assert(position >= 0);
- sql_result_uint(context, position);
-}
-
/*
* Implementation of the printf() function.
*/
@@ -1989,7 +1915,9 @@ static struct sql_func_definition definitions[] = {
{"NULLIF", 2, {FIELD_TYPE_ANY, FIELD_TYPE_ANY}, FIELD_TYPE_SCALAR,
func_nullif, NULL},
{"POSITION", 2, {FIELD_TYPE_STRING, FIELD_TYPE_STRING},
- FIELD_TYPE_INTEGER, position_func, NULL},
+ FIELD_TYPE_INTEGER, func_position_characters, NULL},
+ {"POSITION", 2, {FIELD_TYPE_VARBINARY, FIELD_TYPE_VARBINARY},
+ FIELD_TYPE_INTEGER, func_position_octets, NULL},
{"PRINTF", -1, {FIELD_TYPE_ANY}, FIELD_TYPE_STRING, printfFunc,
NULL},
{"QUOTE", 1, {FIELD_TYPE_ANY}, FIELD_TYPE_STRING, quoteFunc, NULL},
diff --git a/test/sql-tap/position.test.lua b/test/sql-tap/position.test.lua
index 6a96ed9bc..e49f4665a 100755
--- a/test/sql-tap/position.test.lua
+++ b/test/sql-tap/position.test.lua
@@ -1,6 +1,6 @@
#!/usr/bin/env tarantool
local test = require("sqltester")
-test:plan(80)
+test:plan(81)
test:do_test(
"position-1.1",
@@ -305,130 +305,130 @@ test:do_test(
test:do_test(
"position-1.31",
function()
- return test:catchsql "SELECT position(x'01', x'0102030405');"
+ return test:execsql "SELECT position(x'01', x'0102030405');"
end, {
-- <position-1.31>
- 1, "Failed to execute SQL statement: wrong arguments for function POSITION()"
+ 1
-- </position-1.31>
})
test:do_test(
"position-1.32",
function()
- return test:catchsql "SELECT position(x'02', x'0102030405');"
+ return test:execsql "SELECT position(x'02', x'0102030405');"
end, {
-- <position-1.32>
- 1, "Failed to execute SQL statement: wrong arguments for function POSITION()"
+ 2
-- </position-1.32>
})
test:do_test(
"position-1.33",
function()
- return test:catchsql "SELECT position(x'03', x'0102030405');"
+ return test:execsql "SELECT position(x'03', x'0102030405');"
end, {
-- <position-1.33>
- 1, "Failed to execute SQL statement: wrong arguments for function POSITION()"
+ 3
-- </position-1.33>
})
test:do_test(
"position-1.34",
function()
- return test:catchsql "SELECT position(x'04', x'0102030405');"
+ return test:execsql "SELECT position(x'04', x'0102030405');"
end, {
-- <position-1.34>
- 1, "Failed to execute SQL statement: wrong arguments for function POSITION()"
+ 4
-- </position-1.34>
})
test:do_test(
"position-1.35",
function()
- return test:catchsql "SELECT position(x'05', x'0102030405');"
+ return test:execsql "SELECT position(x'05', x'0102030405');"
end, {
-- <position-1.35>
- 1, "Failed to execute SQL statement: wrong arguments for function POSITION()"
+ 5
-- </position-1.35>
})
test:do_test(
"position-1.36",
function()
- return test:catchsql "SELECT position(x'06', x'0102030405');"
+ return test:execsql "SELECT position(x'06', x'0102030405');"
end, {
-- <position-1.36>
- 1, "Failed to execute SQL statement: wrong arguments for function POSITION()"
+ 0
-- </position-1.36>
})
test:do_test(
"position-1.37",
function()
- return test:catchsql "SELECT position(x'0102030405', x'0102030405');"
+ return test:execsql "SELECT position(x'0102030405', x'0102030405');"
end, {
-- <position-1.37>
- 1, "Failed to execute SQL statement: wrong arguments for function POSITION()"
+ 1
-- </position-1.37>
})
test:do_test(
"position-1.38",
function()
- return test:catchsql "SELECT position(x'02030405', x'0102030405');"
+ return test:execsql "SELECT position(x'02030405', x'0102030405');"
end, {
-- <position-1.38>
- 1, "Failed to execute SQL statement: wrong arguments for function POSITION()"
+ 2
-- </position-1.38>
})
test:do_test(
"position-1.39",
function()
- return test:catchsql "SELECT position(x'030405', x'0102030405');"
+ return test:execsql "SELECT position(x'030405', x'0102030405');"
end, {
-- <position-1.39>
- 1, "Failed to execute SQL statement: wrong arguments for function POSITION()"
+ 3
-- </position-1.39>
})
test:do_test(
"position-1.40",
function()
- return test:catchsql "SELECT position(x'0405', x'0102030405');"
+ return test:execsql "SELECT position(x'0405', x'0102030405');"
end, {
-- <position-1.40>
- 1, "Failed to execute SQL statement: wrong arguments for function POSITION()"
+ 4
-- </position-1.40>
})
test:do_test(
"position-1.41",
function()
- return test:catchsql "SELECT position(x'0506', x'0102030405');"
+ return test:execsql "SELECT position(x'0506', x'0102030405');"
end, {
-- <position-1.41>
- 1, "Failed to execute SQL statement: wrong arguments for function POSITION()"
+ 0
-- </position-1.41>
})
test:do_test(
"position-1.42",
function()
- return test:catchsql "SELECT position(x'', x'0102030405');"
+ return test:execsql "SELECT position(x'', x'0102030405');"
end, {
-- <position-1.42>
- 1, "Failed to execute SQL statement: wrong arguments for function POSITION()"
+ 1
-- </position-1.42>
})
test:do_test(
"position-1.43",
function()
- return test:catchsql "SELECT position(x'', x'');"
+ return test:execsql "SELECT position(x'', x'');"
end, {
-- <position-1.43>
- 1, "Failed to execute SQL statement: wrong arguments for function POSITION()"
+ 1
-- </position-1.43>
})
@@ -571,40 +571,40 @@ test:do_test(
test:do_test(
"position-1.56.1",
function()
- return test:catchsql "SELECT position(x'79', x'78c3a4e282ac79');"
+ return test:execsql "SELECT position(x'79', x'78c3a4e282ac79');"
end, {
-- <position-1.56.1>
- 1, "Failed to execute SQL statement: wrong arguments for function POSITION()"
+ 7
-- </position-1.56.1>
})
test:do_test(
"position-1.56.2",
function()
- return test:catchsql "SELECT position(x'7a', x'78c3a4e282ac79');"
+ return test:execsql "SELECT position(x'7a', x'78c3a4e282ac79');"
end, {
-- <position-1.56.2>
- 1, "Failed to execute SQL statement: wrong arguments for function POSITION()"
+ 0
-- </position-1.56.2>
})
test:do_test(
"position-1.56.3",
function()
- return test:catchsql "SELECT position(x'78', x'78c3a4e282ac79');"
+ return test:execsql "SELECT position(x'78', x'78c3a4e282ac79');"
end, {
-- <position-1.56.3>
- 1, "Failed to execute SQL statement: wrong arguments for function POSITION()"
+ 1
-- </position-1.56.3>
})
test:do_test(
"position-1.56.3",
function()
- return test:catchsql "SELECT position(x'a4', x'78c3a4e282ac79');"
+ return test:execsql "SELECT position(x'a4', x'78c3a4e282ac79');"
end, {
-- <position-1.56.3>
- 1, "Failed to execute SQL statement: wrong arguments for function POSITION()"
+ 3
-- </position-1.56.3>
})
@@ -858,4 +858,14 @@ test:do_catchsql_test(
}
)
+-- gh-4145: Make sure POSITION() can work with VARBINARY.
+test:do_execsql_test(
+ "position-2",
+ [[
+ SELECT POSITION(x'313233', x'30313231323334353132333435');
+ ]], {
+ 4
+ }
+)
+
test:finish_test()
More information about the Tarantool-patches
mailing list