[Tarantool-patches] [PATCH v1 5/8] sql: rework TRIM() function
Mergen Imeev
imeevma at tarantool.org
Mon Nov 1 13:35:09 MSK 2021
Thank you for the review! My answers, the diff, and the new patch are below. Also, I
replaced ucnv_getNextUChar() with U8_NEXT().
On Fri, Oct 29, 2021 at 12:12:25AM +0200, Vladislav Shpilevoy wrote:
> Thanks for the fixes!
>
> See 3 comments below.
>
> > diff --git a/src/box/sql/func.c b/src/box/sql/func.c
> > index d36c83501..1294ff5b3 100644
> > --- a/src/box/sql/func.c
> > +++ b/src/box/sql/func.c
> > @@ -344,6 +344,178 @@ func_nullif(struct sql_context *ctx, int argc, struct Mem *argv)
> > ctx->is_aborted = true;
> > }
> >
> > +static inline void
> > +return_empty_str(struct sql_context *ctx, bool is_str)
>
> 1. It is called in a single place with is_str = false. Could we maybe
> inline it?
>
There are a couple more places in the next patches, but I agree that this
function is not needed. Dropped.
> > +{
> > + return is_str ? mem_set_str_static(ctx->pOut, "", 0) :
> > + mem_set_bin_static(ctx->pOut, "", 0);
> > +}
>
> ...
>
> > +
> > +static void
> > +func_trim_str(struct sql_context *ctx, int argc, struct Mem *argv)
> > +{
> > + if (mem_is_null(&argv[0]) || (argc == 3 && mem_is_null(&argv[2])))
> > + return;
> > + assert(argc == 2 || (argc == 3 && mem_is_str(&argv[2])));
> > + assert(mem_is_str(&argv[0]) && mem_is_uint(&argv[1]));
> > + const char *str = argv[0].z;
> > + int size = argv[0].n;
> > + const char *chars;
> > + int chars_size;
> > + if (argc == 3) {
> > + chars = argv[2].z;
> > + chars_size = argv[2].n;
> > + } else {
> > + chars = " ";
> > + chars_size = 1;
> > + }
> > +
> > + uint8_t *chars_len = sqlDbMallocRawNN(sql_get(),
> > + chars_size * sizeof(uint8_t));
>
> 2. Could use fiber region here. Up to you.
>
Fixed.
> > + if (chars_len == NULL) {
> > + ctx->is_aborted = true;
> > + return;
> > + }
> > + int chars_count = 0;
> > +
> > + UErrorCode err = U_ZERO_ERROR;
> > + const char *pos_start = chars;
> > + const char *pos_end = chars + chars_size;
> > + while (pos_start < pos_end) {
> > + const char *cur = pos_start;
> > + ucnv_getNextUChar(icu_utf8_conv, &pos_start, pos_end, &err);
> > + chars_len[chars_count++] = pos_start - cur;
> > + }
> > +
> > + uint64_t flags = argv[1].u.u;
> > + int end = trim_str_end(str, size, chars, chars_len, chars_count, flags);
> > + int start = trim_str_start(str, end, chars, chars_len, chars_count,
> > + flags);
>
> 3. The second line of the call is misaligned a bit.
Fixed.
diff --git a/src/box/sql/func.c b/src/box/sql/func.c
index 0ad6ac966..ba6b9246d 100644
--- a/src/box/sql/func.c
+++ b/src/box/sql/func.c
@@ -343,13 +343,6 @@ func_nullif(struct sql_context *ctx, int argc, struct Mem *argv)
ctx->is_aborted = true;
}
-static inline void
-return_empty_str(struct sql_context *ctx, bool is_str)
-{
- return is_str ? mem_set_str_static(ctx->pOut, "", 0) :
- mem_set_bin_static(ctx->pOut, "", 0);
-}
-
/** Implementation of the TRIM() function. */
static inline int
trim_bin_end(const char *str, int end, const char *octets, int octets_size,
@@ -412,7 +405,7 @@ func_trim_bin(struct sql_context *ctx, int argc, struct Mem *argv)
int start = trim_bin_start(str, end, octets, octets_size, flags);
if (start >= end)
- return return_empty_str(ctx, false);
+ return mem_set_bin_static(ctx->pOut, "", 0);
if (mem_copy_bin(ctx->pOut, &str[start], end - start) != 0)
ctx->is_aborted = true;
}
@@ -486,28 +479,29 @@ func_trim_str(struct sql_context *ctx, int argc, struct Mem *argv)
chars_size = 1;
}
- uint8_t *chars_len = sqlDbMallocRawNN(sql_get(),
- chars_size * sizeof(uint8_t));
+ struct region *region = &fiber()->gc;
+ size_t svp = region_used(region);
+ uint8_t *chars_len = region_alloc(region, chars_size);
if (chars_len == NULL) {
ctx->is_aborted = true;
+ diag_set(OutOfMemory, chars_size, "region_alloc", "chars_len");
return;
}
int chars_count = 0;
- UErrorCode err = U_ZERO_ERROR;
- const char *pos_start = chars;
- const char *pos_end = chars + chars_size;
- while (pos_start < pos_end) {
- const char *cur = pos_start;
- ucnv_getNextUChar(icu_utf8_conv, &pos_start, pos_end, &err);
- chars_len[chars_count++] = pos_start - cur;
+ int offset = 0;
+ while (offset < chars_size) {
+ UChar32 c;
+ int prev = offset;
+ U8_NEXT((uint8_t *)chars, offset, chars_size, c);
+ chars_len[chars_count++] = offset - prev;
}
uint64_t flags = argv[1].u.u;
int end = trim_str_end(str, size, chars, chars_len, chars_count, flags);
int start = trim_str_start(str, end, chars, chars_len, chars_count,
- flags);
- sqlDbFree(sql_get(), chars_len);
+ flags);
+ region_truncate(region, svp);
if (start >= end)
return mem_set_str0_static(ctx->pOut, "");
New patch:
commit 2e74fe281daee22a67d9ec61659ccee569f7fd65
Author: Mergen Imeev <imeevma at gmail.com>
Date: Tue Sep 21 19:45:36 2021 +0300
sql: rework TRIM() function
This patch refactors TRIM() and fixes an issue with incorrect trimming
of some VARBINARY values. Also, TRIM() now uses ICU functions instead of
self-written ones.
Part of #4415
diff --git a/src/box/sql/func.c b/src/box/sql/func.c
index 7d54a39cd..ba6b9246d 100644
--- a/src/box/sql/func.c
+++ b/src/box/sql/func.c
@@ -343,6 +343,172 @@ func_nullif(struct sql_context *ctx, int argc, struct Mem *argv)
ctx->is_aborted = true;
}
+/**
+ * Helper of TRIM() for VARBINARY: move @a end backwards over every
+ * trailing byte of @a str that occurs in @a octets.
+ *
+ * @param str Value being trimmed.
+ * @param end Size of @a str in bytes.
+ * @param octets Set of bytes to remove.
+ * @param octets_size Number of bytes in @a octets.
+ * @param flags Trim side mask; nothing is removed unless TRIM_TRAILING
+ *        is set.
+ * @retval Size of @a str after the trailing bytes are cut off.
+ */
+static inline int
+trim_bin_end(const char *str, int end, const char *octets, int octets_size,
+	     int flags)
+{
+	if ((flags & TRIM_TRAILING) != 0) {
+		while (end > 0 &&
+		       memchr(octets, str[end - 1], octets_size) != NULL)
+			--end;
+	}
+	return end;
+}
+
+/**
+ * Helper of TRIM() for VARBINARY: skip every leading byte of
+ * @a str[0, end) that occurs in @a octets.
+ *
+ * @param str Value being trimmed.
+ * @param end Number of bytes of @a str to consider.
+ * @param octets Set of bytes to remove.
+ * @param octets_size Number of bytes in @a octets.
+ * @param flags Trim side mask; nothing is skipped unless TRIM_LEADING
+ *        is set.
+ * @retval Offset of the first byte that is kept.
+ */
+static inline int
+trim_bin_start(const char *str, int end, const char *octets, int octets_size,
+	       int flags)
+{
+	int start = 0;
+	if ((flags & TRIM_LEADING) == 0)
+		return start;
+	while (start < end && memchr(octets, str[start], octets_size) != NULL)
+		++start;
+	return start;
+}
+
+/**
+ * Implementation of TRIM() for VARBINARY arguments.
+ *
+ * argv[0] is the value to trim and argv[1] is the trim side mask. The
+ * optional argv[2] is the set of bytes to remove; without it the zero
+ * byte X'00' is removed. If the value or the set is NULL, the result
+ * is left unset (NULL).
+ */
+static void
+func_trim_bin(struct sql_context *ctx, int argc, struct Mem *argv)
+{
+	struct Mem *value = &argv[0];
+	bool has_octets = argc == 3;
+	if (mem_is_null(value) || (has_octets && mem_is_null(&argv[2])))
+		return;
+	assert(argc == 2 || (argc == 3 && mem_is_bin(&argv[2])));
+	assert(mem_is_bin(value) && mem_is_uint(&argv[1]));
+
+	const char *octets = has_octets ? argv[2].z : "\0";
+	int octets_size = has_octets ? argv[2].n : 1;
+	int flags = argv[1].u.u;
+
+	int end = trim_bin_end(value->z, value->n, octets, octets_size, flags);
+	int start = trim_bin_start(value->z, end, octets, octets_size, flags);
+	if (start >= end) {
+		mem_set_bin_static(ctx->pOut, "", 0);
+		return;
+	}
+	if (mem_copy_bin(ctx->pOut, value->z + start, end - start) != 0)
+		ctx->is_aborted = true;
+}
+
+/**
+ * Helper of TRIM() for STRING: move @a end backwards over every trailing
+ * character of @a str that matches one of the characters of @a chars.
+ *
+ * @param str String being trimmed.
+ * @param end Size of @a str in bytes.
+ * @param chars Concatenated trimming characters.
+ * @param chars_len Byte length of each character in @a chars.
+ * @param chars_count Number of characters in @a chars.
+ * @param flags Trim side mask; nothing is removed unless TRIM_TRAILING
+ *        is set.
+ * @retval Size of @a str after the trailing characters are cut off.
+ */
+static inline int
+trim_str_end(const char *str, int end, const char *chars,
+	     const uint8_t *chars_len, int chars_count, int flags)
+{
+	if ((flags & TRIM_TRAILING) == 0)
+		return end;
+	while (end > 0) {
+		/* Byte length of the matched character, 0 if none matched. */
+		int matched = 0;
+		const char *c = chars;
+		for (int i = 0; i < chars_count; ++i) {
+			int len = chars_len[i];
+			assert(len > 0);
+			/*
+			 * Check the bound before forming str + end - len:
+			 * a pointer before the start of the buffer is UB.
+			 */
+			if (len <= end &&
+			    memcmp(c, str + end - len, len) == 0) {
+				matched = len;
+				break;
+			}
+			c += len;
+		}
+		if (matched == 0)
+			break;
+		end -= matched;
+	}
+	return end;
+}
+
+/**
+ * Helper of TRIM() for STRING: skip every leading character of
+ * @a str[0, end) that matches one of the characters of @a chars.
+ *
+ * @param str String being trimmed.
+ * @param end Number of bytes of @a str to consider.
+ * @param chars Concatenated trimming characters.
+ * @param chars_len Byte length of each character in @a chars.
+ * @param chars_count Number of characters in @a chars.
+ * @param flags Trim side mask; nothing is skipped unless TRIM_LEADING
+ *        is set.
+ * @retval Byte offset of the first character that is kept.
+ */
+static inline int
+trim_str_start(const char *str, int end, const char *chars,
+	       const uint8_t *chars_len, int chars_count, int flags)
+{
+	if ((flags & TRIM_LEADING) == 0)
+		return 0;
+	int start = 0;
+	while (start < end) {
+		/* Byte length of the matched character, 0 if none matched. */
+		int matched = 0;
+		const char *c = chars;
+		for (int i = 0; i < chars_count; ++i) {
+			int len = chars_len[i];
+			assert(len > 0);
+			if (len <= end - start &&
+			    memcmp(c, str + start, len) == 0) {
+				matched = len;
+				break;
+			}
+			c += len;
+		}
+		if (matched == 0)
+			break;
+		start += matched;
+	}
+	return start;
+}
+
+/**
+ * Implementation of TRIM() for STRING arguments.
+ *
+ * argv[0] is the string to trim and argv[1] is the trim side mask. The
+ * optional argv[2] is the set of characters to remove; without it only the
+ * space character is removed. If the string or the set is NULL, the result
+ * is left unset (NULL). The set is split into UTF-8 characters, so
+ * multibyte characters are matched as a whole.
+ */
+static void
+func_trim_str(struct sql_context *ctx, int argc, struct Mem *argv)
+{
+	if (mem_is_null(&argv[0]) || (argc == 3 && mem_is_null(&argv[2])))
+		return;
+	assert(argc == 2 || (argc == 3 && mem_is_str(&argv[2])));
+	assert(mem_is_str(&argv[0]) && mem_is_uint(&argv[1]));
+	const char *str = argv[0].z;
+	int size = argv[0].n;
+	const char *chars;
+	int chars_size;
+	if (argc == 3) {
+		chars = argv[2].z;
+		chars_size = argv[2].n;
+	} else {
+		chars = " ";
+		chars_size = 1;
+	}
+
+	struct region *region = &fiber()->gc;
+	size_t svp = region_used(region);
+	/*
+	 * Every character takes at least one byte, so chars_size is an
+	 * upper bound for the number of characters in the set.
+	 */
+	uint8_t *chars_len = region_alloc(region, chars_size);
+	if (chars_len == NULL) {
+		diag_set(OutOfMemory, chars_size, "region_alloc", "chars_len");
+		ctx->is_aborted = true;
+		return;
+	}
+	int chars_count = 0;
+	int offset = 0;
+	while (offset < chars_size) {
+		UChar32 c;
+		int prev = offset;
+		/* U8_NEXT() only reads the buffer, so do not cast away const. */
+		U8_NEXT((const uint8_t *)chars, offset, chars_size, c);
+		chars_len[chars_count++] = offset - prev;
+	}
+
+	uint64_t flags = argv[1].u.u;
+	int end = trim_str_end(str, size, chars, chars_len, chars_count, flags);
+	int start = trim_str_start(str, end, chars, chars_len, chars_count,
+				   flags);
+	region_truncate(region, svp);
+
+	if (start >= end) {
+		mem_set_str0_static(ctx->pOut, "");
+		return;
+	}
+	if (mem_copy_str(ctx->pOut, &str[start], end - start) != 0)
+		ctx->is_aborted = true;
+}
+
static const unsigned char *
mem_as_ustr(struct Mem *mem)
{
@@ -1527,193 +1693,6 @@ replaceFunc(struct sql_context *context, int argc, struct Mem *argv)
mem_set_bin_dynamic(context->pOut, (char *)zOut, j);
}
-/**
- * Remove characters included in @a trim_set from @a input_str
- * until encounter a character that doesn't belong to @a trim_set.
- * Remove from the side specified by @a flags.
- * @param context SQL context.
- * @param flags Trim specification: left, right or both.
- * @param trim_set The set of characters for trimming.
- * @param char_len Lengths of each UTF-8 character in @a trim_set.
- * @param char_cnt A number of UTF-8 characters in @a trim_set.
- * @param input_str Input string for trimming.
- * @param input_str_sz Input string size in bytes.
- */
-static void
-trim_procedure(struct sql_context *context, enum trim_side_mask flags,
- const unsigned char *trim_set, const uint8_t *char_len,
- int char_cnt, const unsigned char *input_str, int input_str_sz)
-{
- if (char_cnt == 0)
- goto finish;
- int i, len;
- const unsigned char *z;
- if ((flags & TRIM_LEADING) != 0) {
- while (input_str_sz > 0) {
- z = trim_set;
- for (i = 0; i < char_cnt; ++i, z += len) {
- len = char_len[i];
- if (len <= input_str_sz
- && memcmp(input_str, z, len) == 0)
- break;
- }
- if (i >= char_cnt)
- break;
- input_str += len;
- input_str_sz -= len;
- }
- }
- if ((flags & TRIM_TRAILING) != 0) {
- while (input_str_sz > 0) {
- z = trim_set;
- for (i = 0; i < char_cnt; ++i, z += len) {
- len = char_len[i];
- if (len <= input_str_sz
- && memcmp(&input_str[input_str_sz - len],
- z, len) == 0)
- break;
- }
- if (i >= char_cnt)
- break;
- input_str_sz -= len;
- }
- }
-finish:
- if (context->func->def->returns == FIELD_TYPE_STRING)
- mem_copy_str(context->pOut, (char *)input_str, input_str_sz);
- else
- mem_copy_bin(context->pOut, (char *)input_str, input_str_sz);
-}
-
-/**
- * Prepare arguments for trimming procedure. Allocate memory for
- * @a char_len (array of lengths each character in @a trim_set)
- * and fill it.
- *
- * @param context SQL context.
- * @param trim_set The set of characters for trimming.
- * @param[out] char_len Lengths of each character in @ trim_set.
- * @retval >=0 A number of UTF-8 characters in @a trim_set.
- * @retval -1 Memory allocation error.
- */
-static int
-trim_prepare_char_len(struct sql_context *context,
- const unsigned char *trim_set, int trim_set_sz,
- uint8_t **char_len)
-{
- /*
- * Count the number of UTF-8 characters passing through
- * the entire char set, but not up to the '\0' or X'00'
- * character. This allows to handle trimming set
- * containing such characters.
- */
- int char_cnt = sql_utf8_char_count(trim_set, trim_set_sz);
- if (char_cnt == 0) {
- *char_len = NULL;
- return 0;
- }
-
- if ((*char_len = (uint8_t *)contextMalloc(context, char_cnt)) == NULL)
- return -1;
-
- int i = 0, j = 0;
- while(j < char_cnt) {
- int old_i = i;
- SQL_UTF8_FWD_1(trim_set, i, trim_set_sz);
- (*char_len)[j++] = i - old_i;
- }
-
- return char_cnt;
-}
-
-/**
- * Normalize args from @a argv input array when it has two args.
- *
- * Case: TRIM(<str>)
- * Call trimming procedure with TRIM_BOTH as the flags and " " as
- * the trimming set.
- *
- * Case: TRIM(LEADING/TRAILING/BOTH FROM <str>)
- * If user has specified side keyword only, then call trimming
- * procedure with the specified side and " " as the trimming set.
- */
-static void
-trim_func_two_args(struct sql_context *context, sql_value *arg1,
- sql_value *arg2)
-{
- const unsigned char *trim_set;
- if (mem_is_bin(arg1))
- trim_set = (const unsigned char *)"\0";
- else
- trim_set = (const unsigned char *)" ";
- const unsigned char *input_str;
- if ((input_str = mem_as_ustr(arg1)) == NULL)
- return;
-
- int input_str_sz = mem_len_unsafe(arg1);
- assert(arg2->type == MEM_TYPE_UINT);
- uint8_t len_one = 1;
- trim_procedure(context, arg2->u.u, trim_set,
- &len_one, 1, input_str, input_str_sz);
-}
-
-/**
- * Normalize args from @a argv input array when it has three args.
- *
- * Case: TRIM(<character_set> FROM <str>)
- * If user has specified <character_set> only, call trimming procedure with
- * TRIM_BOTH as the flags and that trimming set.
- *
- * Case: TRIM(LEADING/TRAILING/BOTH <character_set> FROM <str>)
- * If user has specified side keyword and <character_set>, then
- * call trimming procedure with that args.
- */
-static void
-trim_func_three_args(struct sql_context *context, sql_value *arg1,
- sql_value *arg2, sql_value *arg3)
-{
- assert(arg2->type == MEM_TYPE_UINT);
- const unsigned char *input_str, *trim_set;
- if ((input_str = mem_as_ustr(arg1)) == NULL ||
- (trim_set = mem_as_ustr(arg3)) == NULL)
- return;
-
- int trim_set_sz = mem_len_unsafe(arg3);
- int input_str_sz = mem_len_unsafe(arg1);
- uint8_t *char_len;
- int char_cnt = trim_prepare_char_len(context, trim_set, trim_set_sz,
- &char_len);
- if (char_cnt == -1)
- return;
- trim_procedure(context, arg2->u.u, trim_set, char_len,
- char_cnt, input_str, input_str_sz);
- sql_free(char_len);
-}
-
-/**
- * Normalize args from @a argv input array when it has one,
- * two or three args.
- *
- * This is a dispatcher function that calls corresponding
- * implementation depending on the number of arguments.
-*/
-static void
-trim_func(struct sql_context *context, int argc, struct Mem *argv)
-{
- switch (argc) {
- case 2:
- trim_func_two_args(context, &argv[0], &argv[1]);
- break;
- case 3:
- trim_func_three_args(context, &argv[0], &argv[1], &argv[2]);
- break;
- default:
- diag_set(ClientError, ER_FUNC_WRONG_ARG_COUNT, "TRIM",
- "2 or 3", argc);
- context->is_aborted = true;
- }
-}
-
/*
* Compute the soundex encoding of a word.
*
@@ -2040,14 +2019,14 @@ static struct sql_func_definition definitions[] = {
fin_total},
{"TRIM", 2, {FIELD_TYPE_STRING, FIELD_TYPE_INTEGER},
- FIELD_TYPE_STRING, trim_func, NULL},
+ FIELD_TYPE_STRING, func_trim_str, NULL},
{"TRIM", 3, {FIELD_TYPE_STRING, FIELD_TYPE_INTEGER, FIELD_TYPE_STRING},
- FIELD_TYPE_STRING, trim_func, NULL},
+ FIELD_TYPE_STRING, func_trim_str, NULL},
{"TRIM", 2, {FIELD_TYPE_VARBINARY, FIELD_TYPE_INTEGER},
- FIELD_TYPE_VARBINARY, trim_func, NULL},
+ FIELD_TYPE_VARBINARY, func_trim_bin, NULL},
{"TRIM", 3,
{FIELD_TYPE_VARBINARY, FIELD_TYPE_INTEGER, FIELD_TYPE_VARBINARY},
- FIELD_TYPE_VARBINARY, trim_func, NULL},
+ FIELD_TYPE_VARBINARY, func_trim_bin, NULL},
{"TYPEOF", 1, {FIELD_TYPE_ANY}, FIELD_TYPE_STRING, typeofFunc, NULL},
{"UNICODE", 1, {FIELD_TYPE_STRING}, FIELD_TYPE_INTEGER, unicodeFunc,
diff --git a/test/sql-tap/badutf1.test.lua b/test/sql-tap/badutf1.test.lua
index ce8354840..d1e17ca3e 100755
--- a/test/sql-tap/badutf1.test.lua
+++ b/test/sql-tap/badutf1.test.lua
@@ -1,6 +1,6 @@
#!/usr/bin/env tarantool
local test = require("sqltester")
-test:plan(19)
+test:plan(20)
--!./tcltestrunner.lua
-- 2007 May 15
@@ -296,47 +296,62 @@ test:do_test(
test:do_test(
"badutf-4.4",
function()
- return test:execsql2([[SELECT hex(CAST(TRIM(x'ff80' FROM ]]..
- [[x'808080f0808080ff') AS VARBINARY)) AS x]])
+ return test:execsql2([[
+ SELECT hex(TRIM(x'ff80' FROM x'808080f0808080ff')) AS x;
+ ]])
end, {
-- <badutf-4.4>
- "X", "808080F0808080FF"
+ "X", "F0"
-- </badutf-4.4>
})
test:do_test(
"badutf-4.5",
function()
- return test:execsql2([[SELECT hex(CAST(TRIM(x'ff80' FROM ]]..
- [[x'ff8080f0808080ff') AS VARBINARY)) AS x]])
+ return test:execsql2([[
+ SELECT hex(TRIM(x'ff80' FROM x'ff8080f0808080ff')) AS x;
+ ]])
end, {
-- <badutf-4.5>
- "X", "80F0808080FF"
+ "X", "F0"
-- </badutf-4.5>
})
test:do_test(
"badutf-4.6",
function()
- return test:execsql2([[SELECT hex(CAST(TRIM(x'ff80' FROM ]]..
- [[x'ff80f0808080ff') AS VARBINARY)) AS x]])
+ return test:execsql2([[
+ SELECT hex(TRIM(x'ff80' FROM x'ff80f0808080ff')) AS x;
+ ]])
end, {
-- <badutf-4.6>
- "X", "F0808080FF"
+ "X", "F0"
-- </badutf-4.6>
})
test:do_test(
"badutf-4.7",
function()
- return test:execsql2([[SELECT hex(CAST(TRIM(x'ff8080' FROM ]]..
- [[x'ff80f0808080ff') AS VARBINARY)) AS x]])
+ return test:execsql2([[
+ SELECT hex(TRIM(x'ff8080' FROM x'ff80f0808080ff')) AS x;
+ ]])
end, {
-- <badutf-4.7>
- "X", "FF80F0808080FF"
+ "X", "F0"
-- </badutf-4.7>
})
+-- gh-4145: Make sure that TRIM() works properly with VARBINARY. Each byte
+-- of the set x'ff1234' is matched on its own, so the runs 12 34 12 on
+-- both sides are removed and only 56 78 remains.
+test:do_execsql_test(
+ "badutf-5",
+ [[
+ SELECT HEX(TRIM(x'ff1234' from x'1234125678123412'));
+ ]],
+ {
+ '5678'
+ }
+)
+
--db2("close")
More information about the Tarantool-patches
mailing list