[Tarantool-patches] [PATCH v1 5/8] sql: rework TRIM() function

imeevma at tarantool.org imeevma at tarantool.org
Thu Nov 11 13:45:32 MSK 2021


This patch refactors TRIM() and fixes an issue with incorrect trimming
of some VARBINARY values. Also, TRIM() now uses ICU functions instead of
self-written ones.

Part of #4415
---
 src/box/sql/func.c            | 361 ++++++++++++++++------------------
 test/sql-tap/badutf1.test.lua |  41 ++--
 2 files changed, 198 insertions(+), 204 deletions(-)

diff --git a/src/box/sql/func.c b/src/box/sql/func.c
index 7d54a39cd..ba6b9246d 100644
--- a/src/box/sql/func.c
+++ b/src/box/sql/func.c
@@ -343,6 +343,172 @@ func_nullif(struct sql_context *ctx, int argc, struct Mem *argv)
 		ctx->is_aborted = true;
 }
 
+/** Return the end of @a str left after trimming trailing @a octets. */
+static inline int
+trim_bin_end(const char *str, int end, const char *octets, int octets_size,
+	     int flags)
+{
+	if ((flags & TRIM_TRAILING) == 0)
+		return end;
+	/* Drop the last byte for as long as it belongs to the octet set. */
+	for (; end > 0; --end) {
+		const char last = str[end - 1];
+		int pos = 0;
+		while (pos < octets_size && octets[pos] != last)
+			++pos;
+		if (pos >= octets_size)
+			break;
+	}
+	return end;
+}
+
+/** Return the start of @a str left after trimming leading @a octets. */
+static inline int
+trim_bin_start(const char *str, int end, const char *octets, int octets_size,
+	       int flags)
+{
+	if ((flags & TRIM_LEADING) == 0)
+		return 0;
+	int start = 0;
+	for (; start < end; ++start) {
+		const char first = str[start];
+		int pos = 0;
+		while (pos < octets_size && octets[pos] != first)
+			++pos;
+		if (pos >= octets_size)
+			break;
+	}
+	return start;
+}
+
+/**
+ * TRIM() for VARBINARY: argv[0] is the value, argv[1] the side mask and the
+ * optional argv[2] the set of octets to trim (the zero byte by default).
+ */
+static void
+func_trim_bin(struct sql_context *ctx, int argc, struct Mem *argv)
+{
+	if (mem_is_null(&argv[0]) || (argc == 3 && mem_is_null(&argv[2])))
+		return;
+	assert(argc == 2 || (argc == 3 && mem_is_bin(&argv[2])));
+	assert(mem_is_bin(&argv[0]) && mem_is_uint(&argv[1]));
+	const char *str = argv[0].z;
+	int size = argv[0].n;
+	const char *octets = "\0";
+	int octets_size = 1;
+	if (argc == 3) {
+		octets = argv[2].z;
+		octets_size = argv[2].n;
+	}
+
+	int flags = argv[1].u.u;
+	int end = trim_bin_end(str, size, octets, octets_size, flags);
+	int start = trim_bin_start(str, end, octets, octets_size, flags);
+	if (start >= end)
+		mem_set_bin_static(ctx->pOut, "", 0);
+	else if (mem_copy_bin(ctx->pOut, &str[start], end - start) != 0)
+		ctx->is_aborted = true;
+}
+
+/** Return the end of @a str left after trimming trailing @a chars. */
+static inline int
+trim_str_end(const char *str, int end, const char *chars, uint8_t *chars_len,
+	     int chars_count, int flags)
+{
+	if ((flags & TRIM_TRAILING) == 0)
+		return end;
+	while (end > 0) {
+		const char *c = chars;
+		int i, len = 0;
+		for (i = 0; i < chars_count; c += len, ++i) {
+			len = chars_len[i];
+			/* Compare only if the character fits into str. */
+			if (len <= end && memcmp(c, str + end - len, len) == 0)
+				break;
+		}
+		if (i >= chars_count)
+			break;
+		assert(len > 0);
+		end -= len;
+	}
+	return end;
+}
+
+/** Return the start of @a str left after trimming leading @a chars. */
+static inline int
+trim_str_start(const char *str, int end, const char *chars, uint8_t *chars_len,
+	       int chars_count, int flags)
+{
+	if ((flags & TRIM_LEADING) == 0)
+		return 0;
+	int start = 0;
+	while (start < end) {
+		const char *c = chars;
+		int i, len = 0;
+		for (i = 0; i < chars_count; c += len, ++i) {
+			len = chars_len[i];
+			/* Compare only if the character fits. */
+			if (start + len <= end &&
+			    memcmp(c, str + start, len) == 0)
+				break;
+		}
+		if (i >= chars_count)
+			break;
+		assert(len > 0);
+		start += len;
+	}
+	return start;
+}
+
+/**
+ * TRIM() for STRING: argv[0] is the value, argv[1] the side mask and the
+ * optional argv[2] the set of characters to trim (a space by default).
+ */
+static void
+func_trim_str(struct sql_context *ctx, int argc, struct Mem *argv)
+{
+	if (mem_is_null(&argv[0]) || (argc == 3 && mem_is_null(&argv[2])))
+		return;
+	assert(argc == 2 || (argc == 3 && mem_is_str(&argv[2])));
+	assert(mem_is_str(&argv[0]) && mem_is_uint(&argv[1]));
+	const char *str = argv[0].z;
+	int size = argv[0].n;
+	const char *chars = " ";
+	int chars_size = 1;
+	if (argc == 3) {
+		chars = argv[2].z;
+		chars_size = argv[2].n;
+	}
+
+	struct region *region = &fiber()->gc;
+	size_t svp = region_used(region);
+	uint8_t *chars_len = region_alloc(region, chars_size);
+	if (chars_len == NULL) {
+		ctx->is_aborted = true;
+		diag_set(OutOfMemory, chars_size, "region_alloc", "chars_len");
+		return;
+	}
+	int chars_count = 0;
+	/* Store the byte length of every character of the trimming set. */
+	int offset = 0;
+	while (offset < chars_size) {
+		UChar32 c;
+		int prev = offset;
+		U8_NEXT((uint8_t *)chars, offset, chars_size, c);
+		chars_len[chars_count++] = offset - prev;
+	}
+
+	int flags = argv[1].u.u;
+	int end = trim_str_end(str, size, chars, chars_len, chars_count, flags);
+	int start = trim_str_start(str, end, chars, chars_len, chars_count,
+				   flags);
+	region_truncate(region, svp);
+	if (start >= end)
+		mem_set_str0_static(ctx->pOut, "");
+	else if (mem_copy_str(ctx->pOut, &str[start], end - start) != 0)
+		ctx->is_aborted = true;
+}
+
 static const unsigned char *
 mem_as_ustr(struct Mem *mem)
 {
@@ -1527,193 +1693,6 @@ replaceFunc(struct sql_context *context, int argc, struct Mem *argv)
 		mem_set_bin_dynamic(context->pOut, (char *)zOut, j);
 }
 
-/**
- * Remove characters included in @a trim_set from @a input_str
- * until encounter a character that doesn't belong to @a trim_set.
- * Remove from the side specified by @a flags.
- * @param context SQL context.
- * @param flags Trim specification: left, right or both.
- * @param trim_set The set of characters for trimming.
- * @param char_len Lengths of each UTF-8 character in @a trim_set.
- * @param char_cnt A number of UTF-8 characters in @a trim_set.
- * @param input_str Input string for trimming.
- * @param input_str_sz Input string size in bytes.
- */
-static void
-trim_procedure(struct sql_context *context, enum trim_side_mask flags,
-	       const unsigned char *trim_set, const uint8_t *char_len,
-	       int char_cnt, const unsigned char *input_str, int input_str_sz)
-{
-	if (char_cnt == 0)
-		goto finish;
-	int i, len;
-	const unsigned char *z;
-	if ((flags & TRIM_LEADING) != 0) {
-		while (input_str_sz > 0) {
-			z = trim_set;
-			for (i = 0; i < char_cnt; ++i, z += len) {
-				len = char_len[i];
-				if (len <= input_str_sz
-				    && memcmp(input_str, z, len) == 0)
-					break;
-			}
-			if (i >= char_cnt)
-				break;
-			input_str += len;
-			input_str_sz -= len;
-		}
-	}
-	if ((flags & TRIM_TRAILING) != 0) {
-		while (input_str_sz > 0) {
-			z = trim_set;
-			for (i = 0; i < char_cnt; ++i, z += len) {
-				len = char_len[i];
-				if (len <= input_str_sz
-				    && memcmp(&input_str[input_str_sz - len],
-					      z, len) == 0)
-					break;
-			}
-			if (i >= char_cnt)
-				break;
-			input_str_sz -= len;
-		}
-	}
-finish:
-	if (context->func->def->returns == FIELD_TYPE_STRING)
-		mem_copy_str(context->pOut, (char *)input_str, input_str_sz);
-	else
-		mem_copy_bin(context->pOut, (char *)input_str, input_str_sz);
-}
-
-/**
- * Prepare arguments for trimming procedure. Allocate memory for
- * @a char_len (array of lengths each character in @a trim_set)
- * and fill it.
- *
- * @param context SQL context.
- * @param trim_set The set of characters for trimming.
- * @param[out] char_len Lengths of each character in @ trim_set.
- * @retval >=0 A number of UTF-8 characters in @a trim_set.
- * @retval -1 Memory allocation error.
- */
-static int
-trim_prepare_char_len(struct sql_context *context,
-		      const unsigned char *trim_set, int trim_set_sz,
-		      uint8_t **char_len)
-{
-	/*
-	 * Count the number of UTF-8 characters passing through
-	 * the entire char set, but not up to the '\0' or X'00'
-	 * character. This allows to handle trimming set
-	 * containing such characters.
-	 */
-	int char_cnt = sql_utf8_char_count(trim_set, trim_set_sz);
-	if (char_cnt == 0) {
-		*char_len = NULL;
-		return 0;
-	}
-
-	if ((*char_len = (uint8_t *)contextMalloc(context, char_cnt)) == NULL)
-		return -1;
-
-	int i = 0, j = 0;
-	while(j < char_cnt) {
-		int old_i = i;
-		SQL_UTF8_FWD_1(trim_set, i, trim_set_sz);
-		(*char_len)[j++] = i - old_i;
-	}
-
-	return char_cnt;
-}
-
-/**
- * Normalize args from @a argv input array when it has two args.
- *
- * Case: TRIM(<str>)
- * Call trimming procedure with TRIM_BOTH as the flags and " " as
- * the trimming set.
- *
- * Case: TRIM(LEADING/TRAILING/BOTH FROM <str>)
- * If user has specified side keyword only, then call trimming
- * procedure with the specified side and " " as the trimming set.
- */
-static void
-trim_func_two_args(struct sql_context *context, sql_value *arg1,
-		   sql_value *arg2)
-{
-	const unsigned char *trim_set;
-	if (mem_is_bin(arg1))
-		trim_set = (const unsigned char *)"\0";
-	else
-		trim_set = (const unsigned char *)" ";
-	const unsigned char *input_str;
-	if ((input_str = mem_as_ustr(arg1)) == NULL)
-		return;
-
-	int input_str_sz = mem_len_unsafe(arg1);
-	assert(arg2->type == MEM_TYPE_UINT);
-	uint8_t len_one = 1;
-	trim_procedure(context, arg2->u.u, trim_set,
-		       &len_one, 1, input_str, input_str_sz);
-}
-
-/**
- * Normalize args from @a argv input array when it has three args.
- *
- * Case: TRIM(<character_set> FROM <str>)
- * If user has specified <character_set> only, call trimming procedure with
- * TRIM_BOTH as the flags and that trimming set.
- *
- * Case: TRIM(LEADING/TRAILING/BOTH <character_set> FROM <str>)
- * If user has specified side keyword and <character_set>, then
- * call trimming procedure with that args.
- */
-static void
-trim_func_three_args(struct sql_context *context, sql_value *arg1,
-		     sql_value *arg2, sql_value *arg3)
-{
-	assert(arg2->type == MEM_TYPE_UINT);
-	const unsigned char *input_str, *trim_set;
-	if ((input_str = mem_as_ustr(arg1)) == NULL ||
-	    (trim_set = mem_as_ustr(arg3)) == NULL)
-		return;
-
-	int trim_set_sz = mem_len_unsafe(arg3);
-	int input_str_sz = mem_len_unsafe(arg1);
-	uint8_t *char_len;
-	int char_cnt = trim_prepare_char_len(context, trim_set, trim_set_sz,
-					     &char_len);
-	if (char_cnt == -1)
-		return;
-	trim_procedure(context, arg2->u.u, trim_set, char_len,
-		       char_cnt, input_str, input_str_sz);
-	sql_free(char_len);
-}
-
-/**
- * Normalize args from @a argv input array when it has one,
- * two or three args.
- *
- * This is a dispatcher function that calls corresponding
- * implementation depending on the number of arguments.
-*/
-static void
-trim_func(struct sql_context *context, int argc, struct Mem *argv)
-{
-	switch (argc) {
-	case 2:
-		trim_func_two_args(context, &argv[0], &argv[1]);
-		break;
-	case 3:
-		trim_func_three_args(context, &argv[0], &argv[1], &argv[2]);
-		break;
-	default:
-		diag_set(ClientError, ER_FUNC_WRONG_ARG_COUNT, "TRIM",
-			"2 or 3", argc);
-		context->is_aborted = true;
-	}
-}
-
 /*
  * Compute the soundex encoding of a word.
  *
@@ -2040,14 +2019,14 @@ static struct sql_func_definition definitions[] = {
 	 fin_total},
 
 	{"TRIM", 2, {FIELD_TYPE_STRING, FIELD_TYPE_INTEGER},
-	 FIELD_TYPE_STRING, trim_func, NULL},
+	 FIELD_TYPE_STRING, func_trim_str, NULL},
 	{"TRIM", 3, {FIELD_TYPE_STRING, FIELD_TYPE_INTEGER, FIELD_TYPE_STRING},
-	 FIELD_TYPE_STRING, trim_func, NULL},
+	 FIELD_TYPE_STRING, func_trim_str, NULL},
 	{"TRIM", 2, {FIELD_TYPE_VARBINARY, FIELD_TYPE_INTEGER},
-	 FIELD_TYPE_VARBINARY, trim_func, NULL},
+	 FIELD_TYPE_VARBINARY, func_trim_bin, NULL},
 	{"TRIM", 3,
 	 {FIELD_TYPE_VARBINARY, FIELD_TYPE_INTEGER, FIELD_TYPE_VARBINARY},
-	 FIELD_TYPE_VARBINARY, trim_func, NULL},
+	 FIELD_TYPE_VARBINARY, func_trim_bin, NULL},
 
 	{"TYPEOF", 1, {FIELD_TYPE_ANY}, FIELD_TYPE_STRING, typeofFunc, NULL},
 	{"UNICODE", 1, {FIELD_TYPE_STRING}, FIELD_TYPE_INTEGER, unicodeFunc,
diff --git a/test/sql-tap/badutf1.test.lua b/test/sql-tap/badutf1.test.lua
index ce8354840..d1e17ca3e 100755
--- a/test/sql-tap/badutf1.test.lua
+++ b/test/sql-tap/badutf1.test.lua
@@ -1,6 +1,6 @@
 #!/usr/bin/env tarantool
 local test = require("sqltester")
-test:plan(19)
+test:plan(20)
 
 --!./tcltestrunner.lua
 -- 2007 May 15
@@ -296,47 +296,62 @@ test:do_test(
 test:do_test(
     "badutf-4.4",
     function()
-        return test:execsql2([[SELECT hex(CAST(TRIM(x'ff80' FROM ]]..
-                             [[x'808080f0808080ff') AS VARBINARY)) AS x]])
+        return test:execsql2([[
+            SELECT hex(TRIM(x'ff80' FROM x'808080f0808080ff')) AS x;
+        ]])
     end, {
         -- <badutf-4.4>
-        "X", "808080F0808080FF"
+        "X", "F0"
         -- </badutf-4.4>
     })
 
 test:do_test(
     "badutf-4.5",
     function()
-        return test:execsql2([[SELECT hex(CAST(TRIM(x'ff80' FROM ]]..
-                             [[x'ff8080f0808080ff') AS VARBINARY)) AS x]])
+        return test:execsql2([[
+            SELECT hex(TRIM(x'ff80' FROM x'ff8080f0808080ff')) AS x;
+        ]])
     end, {
         -- <badutf-4.5>
-        "X", "80F0808080FF"
+        "X", "F0"
         -- </badutf-4.5>
     })
 
 test:do_test(
     "badutf-4.6",
     function()
-        return test:execsql2([[SELECT hex(CAST(TRIM(x'ff80' FROM ]]..
-                             [[x'ff80f0808080ff') AS VARBINARY)) AS x]])
+        return test:execsql2([[
+            SELECT hex(TRIM(x'ff80' FROM x'ff80f0808080ff')) AS x;
+        ]])
     end, {
         -- <badutf-4.6>
-        "X", "F0808080FF"
+        "X", "F0"
         -- </badutf-4.6>
     })
 
 test:do_test(
     "badutf-4.7",
     function()
-        return test:execsql2([[SELECT hex(CAST(TRIM(x'ff8080' FROM ]]..
-                             [[x'ff80f0808080ff') AS VARBINARY)) AS x]])
+        return test:execsql2([[
+            SELECT hex(TRIM(x'ff8080' FROM x'ff80f0808080ff')) AS x;
+        ]])
     end, {
         -- <badutf-4.7>
-        "X", "FF80F0808080FF"
+        "X", "F0"
         -- </badutf-4.7>
     })
 
+-- gh-4145: Make sure that TRIM() works properly with VARBINARY.
+test:do_execsql_test(
+    "badutf-5",
+    [[
+        SELECT HEX(TRIM(x'ff1234' from x'1234125678123412'));
+    ]],
+    {
+        '5678'
+    }
+)
+
 --db2("close")
 
 
-- 
2.25.1



More information about the Tarantool-patches mailing list