[Tarantool-patches] [PATCH v4 51/53] sql: introduce mem_get_string0()

imeevma at tarantool.org imeevma at tarantool.org
Tue Mar 23 12:36:50 MSK 2021


This patch introduces mem_get_string0(), which is used to get a
NUL-terminated string from a MEM.

Part of #5818
---
 src/box/sql/func.c      |  77 ++++++++++++++++++------------
 src/box/sql/mem.c       | 101 ++++++----------------------------------
 src/box/sql/mem.h       |   9 ++--
 src/box/sql/printf.c    |   6 ++-
 src/box/sql/sqlInt.h    |   4 --
 src/box/sql/vdbeapi.c   |   6 ---
 src/box/sql/whereexpr.c |   7 ++-
 7 files changed, 75 insertions(+), 135 deletions(-)

diff --git a/src/box/sql/func.c b/src/box/sql/func.c
index d033dae86..78f4ec3b5 100644
--- a/src/box/sql/func.c
+++ b/src/box/sql/func.c
@@ -54,6 +54,24 @@
 #include "lua/utils.h"
 #include "mpstream/mpstream.h"
 
+static const char *
+mem_get_str(struct Mem *mem)
+{
+	const char *str;
+	if (mem_convert_to_string0(mem) != 0 || mem_get_string0(mem, &str) != 0)
+		return NULL;
+	return str;
+}
+
+static const unsigned char *
+mem_get_ustr(struct Mem *mem)
+{
+	const char *str;
+	if (mem_convert_to_string0(mem) != 0 || mem_get_string0(mem, &str) != 0)
+		return NULL;
+	return (const unsigned char *)str;
+}
+
 /*
  * Return the collating function associated with a function.
  */
@@ -131,7 +149,7 @@ port_vdbemem_dump_lua(struct port *base, struct lua_State *L, bool is_flat)
 			break;
 		}
 		case MP_STR:
-			lua_pushstring(L, (const char *) sql_value_text(param));
+			lua_pushstring(L, mem_get_str(param));
 			break;
 		case MP_BIN:
 		case MP_ARRAY:
@@ -192,7 +210,7 @@ port_vdbemem_get_msgpack(struct port *base, uint32_t *size)
 			break;
 		}
 		case MP_STR: {
-			const char *str = (const char *) sql_value_text(param);
+			const char *str = mem_get_str(param);
 			mpstream_encode_str(&stream, str);
 			break;
 		}
@@ -497,7 +515,7 @@ lengthFunc(sql_context * context, int argc, sql_value ** argv)
 			break;
 		}
 	case MP_STR:{
-			const unsigned char *z = sql_value_text(argv[0]);
+			const unsigned char *z = mem_get_ustr(argv[0]);
 			if (z == 0)
 				return;
 			len = sql_utf8_char_count(z, sql_value_bytes(argv[0]));
@@ -652,8 +670,8 @@ position_func(struct sql_context *context, int argc, struct Mem **argv)
 			 * Character size is equal to
 			 * needle char size.
 			 */
-			haystack_str = sql_value_text(haystack);
-			needle_str = sql_value_text(needle);
+			haystack_str = mem_get_ustr(haystack);
+			needle_str = mem_get_ustr(needle);
 
 			int n_needle_chars =
 				sql_utf8_char_count(needle_str, n_needle_bytes);
@@ -715,8 +733,7 @@ printfFunc(sql_context * context, int argc, sql_value ** argv)
 	int n;
 	sql *db = sql_context_db_handle(context);
 
-	if (argc >= 1
-	    && (zFormat = (const char *)sql_value_text(argv[0])) != 0) {
+	if (argc >= 1 && (zFormat = mem_get_str(argv[0])) != NULL) {
 		x.nArg = argc - 1;
 		x.nUsed = 0;
 		x.apArg = argv + 1;
@@ -773,7 +790,7 @@ substrFunc(sql_context * context, int argc, sql_value ** argv)
 			return;
 		assert(len == sql_value_bytes(argv[0]));
 	} else {
-		z = sql_value_text(argv[0]);
+		z = mem_get_ustr(argv[0]);
 		if (z == 0)
 			return;
 		len = 0;
@@ -936,13 +953,13 @@ case_type##ICUFunc(sql_context *context, int argc, sql_value **argv)   \
 		context->is_aborted = true;                                    \
 		return;                                                        \
 	}                                                                      \
-	z2 = (char *)sql_value_text(argv[0]);                              \
+	z2 = mem_get_str(argv[0]);                                            \
 	n = sql_value_bytes(argv[0]);                                      \
 	/*                                                                     \
 	 * Verify that the call to _bytes()                                    \
 	 * does not invalidate the _text() pointer.                            \
 	 */                                                                    \
-	assert(z2 == (char *)sql_value_text(argv[0]));                     \
+	assert(z2 == mem_get_str(argv[0]));                                   \
 	if (!z2)                                                               \
 		return;                                                        \
 	z1 = contextMalloc(context, ((i64) n) + 1);                            \
@@ -1273,8 +1290,8 @@ likeFunc(sql_context *context, int argc, sql_value **argv)
 		context->is_aborted = true;
 		return;
 	}
-	const char *zB = (const char *) sql_value_text(argv[0]);
-	const char *zA = (const char *) sql_value_text(argv[1]);
+	const char *zB = mem_get_str(argv[0]);
+	const char *zA = mem_get_str(argv[1]);
 	const char *zB_end = zB + sql_value_bytes(argv[0]);
 	const char *zA_end = zA + sql_value_bytes(argv[1]);
 
@@ -1293,7 +1310,7 @@ likeFunc(sql_context *context, int argc, sql_value **argv)
 		return;
 	}
 	/* Encoding did not change */
-	assert(zB == (const char *) sql_value_text(argv[0]));
+	assert(zB == mem_get_str(argv[0]));
 
 	if (argc == 3) {
 		/*
@@ -1301,7 +1318,7 @@ likeFunc(sql_context *context, int argc, sql_value **argv)
 		 * single UTF-8 character. Otherwise, return an
 		 * error.
 		 */
-		const unsigned char *zEsc = sql_value_text(argv[2]);
+		const unsigned char *zEsc = mem_get_ustr(argv[2]);
 		if (zEsc == 0)
 			return;
 		if (sql_utf8_char_count(zEsc, sql_value_bytes(argv[2])) != 1) {
@@ -1430,7 +1447,7 @@ quoteFunc(sql_context * context, int argc, sql_value ** argv)
 	case MP_STR:{
 			int i, j;
 			u64 n;
-			const unsigned char *zArg = sql_value_text(argv[0]);
+			const unsigned char *zArg = mem_get_ustr(argv[0]);
 			char *z;
 
 			if (zArg == 0)
@@ -1477,7 +1494,7 @@ quoteFunc(sql_context * context, int argc, sql_value ** argv)
 static void
 unicodeFunc(sql_context * context, int argc, sql_value ** argv)
 {
-	const unsigned char *z = sql_value_text(argv[0]);
+	const unsigned char *z = mem_get_ustr(argv[0]);
 	(void)argc;
 	if (z && z[0])
 		sql_result_uint(context, sqlUtf8Read(&z));
@@ -1596,12 +1613,12 @@ replaceFunc(sql_context * context, int argc, sql_value ** argv)
 
 	assert(argc == 3);
 	UNUSED_PARAMETER(argc);
-	zStr = sql_value_text(argv[0]);
+	zStr = mem_get_ustr(argv[0]);
 	if (zStr == 0)
 		return;
 	nStr = sql_value_bytes(argv[0]);
-	assert(zStr == sql_value_text(argv[0]));	/* No encoding change */
-	zPattern = sql_value_text(argv[1]);
+	assert(zStr == mem_get_ustr(argv[0]));	/* No encoding change */
+	zPattern = mem_get_ustr(argv[1]);
 	if (zPattern == 0) {
 		assert(mem_is_null(argv[1])
 		       || sql_context_db_handle(context)->mallocFailed);
@@ -1613,12 +1630,12 @@ replaceFunc(sql_context * context, int argc, sql_value ** argv)
 		sql_result_value(context, argv[0]);
 		return;
 	}
-	assert(zPattern == sql_value_text(argv[1]));	/* No encoding change */
-	zRep = sql_value_text(argv[2]);
+	assert(zPattern == mem_get_ustr(argv[1]));	/* No encoding change */
+	zRep = mem_get_ustr(argv[2]);
 	if (zRep == 0)
 		return;
 	nRep = sql_value_bytes(argv[2]);
-	assert(zRep == sql_value_text(argv[2]));
+	assert(zRep == mem_get_ustr(argv[2]));
 	nOut = nStr + 1;
 	assert(nOut < SQL_MAX_LENGTH);
 	zOut = contextMalloc(context, (i64) nOut);
@@ -1780,7 +1797,7 @@ trim_func_one_arg(struct sql_context *context, sql_value *arg)
 	else
 		default_trim = (const unsigned char *) " ";
 	int input_str_sz = sql_value_bytes(arg);
-	const unsigned char *input_str = sql_value_text(arg);
+	const unsigned char *input_str = mem_get_ustr(arg);
 	uint8_t trim_char_len[1] = { 1 };
 	trim_procedure(context, TRIM_BOTH, default_trim, trim_char_len, 1,
 		       input_str, input_str_sz);
@@ -1802,7 +1819,7 @@ trim_func_two_args(struct sql_context *context, sql_value *arg1,
 		   sql_value *arg2)
 {
 	const unsigned char *input_str, *trim_set;
-	if ((input_str = sql_value_text(arg2)) == NULL)
+	if ((input_str = mem_get_ustr(arg2)) == NULL)
 		return;
 
 	int input_str_sz = sql_value_bytes(arg2);
@@ -1813,7 +1830,7 @@ trim_func_two_args(struct sql_context *context, sql_value *arg1,
 		mem_get_integer(arg1, &n, &unused);
 		trim_procedure(context, n, (const unsigned char *) " ",
 			       &len_one, 1, input_str, input_str_sz);
-	} else if ((trim_set = sql_value_text(arg1)) != NULL) {
+	} else if ((trim_set = mem_get_ustr(arg1)) != NULL) {
 		int trim_set_sz = sql_value_bytes(arg1);
 		uint8_t *char_len;
 		int char_cnt = trim_prepare_char_len(context, trim_set,
@@ -1839,8 +1856,8 @@ trim_func_three_args(struct sql_context *context, sql_value *arg1,
 {
 	assert(sql_value_type(arg1) == MP_INT || sql_value_type(arg1) == MP_UINT);
 	const unsigned char *input_str, *trim_set;
-	if ((input_str = sql_value_text(arg3)) == NULL ||
-	    (trim_set = sql_value_text(arg2)) == NULL)
+	if ((input_str = mem_get_ustr(arg3)) == NULL ||
+	    (trim_set = mem_get_ustr(arg2)) == NULL)
 		return;
 
 	int trim_set_sz = sql_value_bytes(arg2);
@@ -1915,7 +1932,7 @@ soundexFunc(sql_context * context, int argc, sql_value ** argv)
 		context->is_aborted = true;
 		return;
 	}
-	zIn = (u8 *) sql_value_text(argv[0]);
+	zIn = (u8 *) mem_get_ustr(argv[0]);
 	if (zIn == 0)
 		zIn = (u8 *) "";
 	for (i = 0; zIn[i] && !sqlIsalpha(zIn[i]); i++) {
@@ -2165,7 +2182,7 @@ groupConcatStep(sql_context * context, int argc, sql_value ** argv)
 		pAccum->mxAlloc = db->aLimit[SQL_LIMIT_LENGTH];
 		if (!firstTerm) {
 			if (argc == 2) {
-				zSep = (char *)sql_value_text(argv[1]);
+				zSep = mem_get_str(argv[1]);
 				nSep = sql_value_bytes(argv[1]);
 			} else {
 				zSep = ",";
@@ -2174,7 +2191,7 @@ groupConcatStep(sql_context * context, int argc, sql_value ** argv)
 			if (zSep)
 				sqlStrAccumAppend(pAccum, zSep, nSep);
 		}
-		zVal = (char *)sql_value_text(argv[0]);
+		zVal = mem_get_str(argv[0]);
 		nVal = sql_value_bytes(argv[0]);
 		if (zVal)
 			sqlStrAccumAppend(pAccum, zVal, nVal);
diff --git a/src/box/sql/mem.c b/src/box/sql/mem.c
index 90c6ee526..efd3e4677 100644
--- a/src/box/sql/mem.c
+++ b/src/box/sql/mem.c
@@ -1319,6 +1319,15 @@ mem_get_boolean(const struct Mem *mem, bool *b)
 	return -1;
 }
 
+int
+mem_get_string0(const struct Mem *mem, const char **s)
+{
+	if ((mem->flags & MEM_Str) == 0 || (mem->flags & MEM_Term) == 0)
+		return -1;
+	*s = mem->z;
+	return 0;
+}
+
 int
 mem_copy(struct Mem *to, const struct Mem *from)
 {
@@ -1936,45 +1945,6 @@ mem_has_msgpack_subtype(struct Mem *mem)
 	       mem->subtype == SQL_SUBTYPE_MSGPACK;
 }
 
-/*
- * The pVal argument is known to be a value other than NULL.
- * Convert it into a string with encoding enc and return a pointer
- * to a zero-terminated version of that string.
- */
-static SQL_NOINLINE const void *
-valueToText(sql_value * pVal)
-{
-	assert(pVal != 0);
-	assert((pVal->flags & (MEM_Null)) == 0);
-	if ((pVal->flags & (MEM_Blob | MEM_Str)) &&
-	    !mem_has_msgpack_subtype(pVal)) {
-		if (ExpandBlob(pVal))
-			return 0;
-		pVal->flags |= MEM_Str;
-		sqlVdbeMemNulTerminate(pVal);	/* IMP: R-31275-44060 */
-	} else {
-		mem_convert_to_string(pVal);
-		assert(0 == (1 & SQL_PTR_TO_INT(pVal->z)));
-	}
-	return pVal->z;
-}
-
-/*
- * It is already known that pMem contains an unterminated string.
- * Add the zero terminator.
- */
-static SQL_NOINLINE int
-vdbeMemAddTerminator(Mem * pMem)
-{
-	if (sqlVdbeMemGrow(pMem, pMem->n + 2, 1)) {
-		return -1;
-	}
-	pMem->z[pMem->n] = 0;
-	pMem->z[pMem->n + 1] = 0;
-	pMem->flags |= MEM_Term;
-	return 0;
-}
-
 /*
  * Both *pMem1 and *pMem2 contain string values. Compare the two values
  * using the collation sequence pColl. As usual, return a negative , zero
@@ -2076,7 +2046,9 @@ sql_value_type(sql_value *pVal)
 static SQL_NOINLINE int
 valueBytes(sql_value * pVal)
 {
-	return valueToText(pVal) != 0 ? pVal->n : 0;
+	if (mem_convert_to_string(pVal) != 0)
+		return 0;
+	return pVal->n;
 }
 
 int
@@ -2311,21 +2283,6 @@ registerTrace(int iReg, Mem *p) {
 }
 #endif
 
-/*
- * Make sure the given Mem is \u0000 terminated.
- */
-int
-sqlVdbeMemNulTerminate(Mem * pMem)
-{
-	testcase((pMem->flags & (MEM_Term | MEM_Str)) == (MEM_Term | MEM_Str));
-	testcase((pMem->flags & (MEM_Term | MEM_Str)) == 0);
-	if ((pMem->flags & (MEM_Term | MEM_Str)) != MEM_Str) {
-		return 0;	/* Nothing to do */
-	} else {
-		return vdbeMemAddTerminator(pMem);
-	}
-}
-
 /*
  * If the given Mem* has a zero-filled tail, turn it into an ordinary
  * blob stored in dynamically allocated space.
@@ -2493,7 +2450,9 @@ sql_value_blob(sql_value * pVal)
 		p->flags |= MEM_Blob;
 		return p->n ? p->z : 0;
 	} else {
-		return sql_value_text(pVal);
+		if (mem_convert_to_string(pVal) != 0)
+			return NULL;
+		return pVal->z;
 	}
 }
 
@@ -2503,36 +2462,6 @@ sql_value_bytes(sql_value * pVal)
 	return sqlValueBytes(pVal);
 }
 
-const unsigned char *
-sql_value_text(sql_value * pVal)
-{
-	return (const unsigned char *)sqlValueText(pVal);
-}
-
-/* This function is only available internally, it is not part of the
- * external API. It works in a similar way to sql_value_text(),
- * except the data returned is in the encoding specified by the second
- * parameter, which must be one of SQL_UTF16BE, SQL_UTF16LE or
- * SQL_UTF8.
- *
- * (2006-02-16:)  The enc value can be or-ed with SQL_UTF16_ALIGNED.
- * If that is the case, then the result must be aligned on an even byte
- * boundary.
- */
-const void *
-sqlValueText(sql_value * pVal)
-{
-	if (!pVal)
-		return 0;
-	if ((pVal->flags & (MEM_Str | MEM_Term)) == (MEM_Str | MEM_Term)) {
-		return pVal->z;
-	}
-	if (pVal->flags & MEM_Null) {
-		return 0;
-	}
-	return valueToText(pVal);
-}
-
 /*
  * Return a pointer to static memory containing an SQL NULL value.
  */
diff --git a/src/box/sql/mem.h b/src/box/sql/mem.h
index 8015608fa..fb2385541 100644
--- a/src/box/sql/mem.h
+++ b/src/box/sql/mem.h
@@ -357,6 +357,9 @@ mem_get_double(const struct Mem *mem, double *d);
 int
 mem_get_boolean(const struct Mem *mem, bool *b);
 
+int
+mem_get_string0(const struct Mem *mem, const char **s);
+
 /**
  * Simple type to str convertor. It is used to simplify
  * error reporting.
@@ -407,7 +410,6 @@ registerTrace(int iReg, Mem *p);
 # define memAboutToChange(P,M)
 #endif
 
-int sqlVdbeMemNulTerminate(struct Mem *);
 int sqlVdbeMemExpandBlob(struct Mem *);
 #define ExpandBlob(P) (mem_is_zeroblob(P)? sqlVdbeMemExpandBlob(P) : 0)
 
@@ -439,11 +441,6 @@ sql_value_blob(struct Mem *);
 int
 sql_value_bytes(struct Mem *);
 
-const unsigned char *
-sql_value_text(struct Mem *);
-
-const void *sqlValueText(struct Mem *);
-
 #define VdbeFrameMem(p) ((Mem *)&((u8 *)p)[ROUND8(sizeof(VdbeFrame))])
 
 const Mem *
diff --git a/src/box/sql/printf.c b/src/box/sql/printf.c
index 0790a52a7..85cfe8802 100644
--- a/src/box/sql/printf.c
+++ b/src/box/sql/printf.c
@@ -165,7 +165,11 @@ getTextArg(PrintfArguments * p)
 {
 	if (p->nArg <= p->nUsed)
 		return 0;
-	return (char *)sql_value_text(p->apArg[p->nUsed++]);
+	const char *str;
+	struct Mem *mem = p->apArg[p->nUsed++];
+	if (mem_convert_to_string0(mem) != 0 || mem_get_string0(mem, &str) != 0)
+		return NULL;
+	return (char *)str;
 }
 
 /*
diff --git a/src/box/sql/sqlInt.h b/src/box/sql/sqlInt.h
index 9b4afc3f2..2f995ffc6 100644
--- a/src/box/sql/sqlInt.h
+++ b/src/box/sql/sqlInt.h
@@ -442,10 +442,6 @@ sql_column_bytes(sql_stmt *, int iCol);
 int
 sql_column_bytes16(sql_stmt *, int iCol);
 
-const unsigned char *
-sql_column_text(sql_stmt *,
-		    int iCol);
-
 char *
 sql_result_to_msgpack(struct sql_stmt *stmt, uint32_t *tuple_size,
 		      struct region *region);
diff --git a/src/box/sql/vdbeapi.c b/src/box/sql/vdbeapi.c
index 532677b05..465284567 100644
--- a/src/box/sql/vdbeapi.c
+++ b/src/box/sql/vdbeapi.c
@@ -485,12 +485,6 @@ sql_column_bytes(sql_stmt * pStmt, int i)
 	return sql_value_bytes(columnMem(pStmt, i));
 }
 
-const unsigned char *
-sql_column_text(sql_stmt * pStmt, int i)
-{
-	return sql_value_text(columnMem(pStmt, i));
-}
-
 char *
 sql_result_to_msgpack(struct sql_stmt *stmt, uint32_t *tuple_size,
 		      struct region *region)
diff --git a/src/box/sql/whereexpr.c b/src/box/sql/whereexpr.c
index 782e44734..51681b008 100644
--- a/src/box/sql/whereexpr.c
+++ b/src/box/sql/whereexpr.c
@@ -312,8 +312,11 @@ like_optimization_is_valid(Parse *pParse, Expr *pExpr, Expr **ppPrefix,
 		pVal =
 		    sqlVdbeGetBoundValue(pReprepare, iCol,
 					     FIELD_TYPE_SCALAR);
-		if (pVal != NULL && mem_is_string(pVal))
-			z = (char *)sql_value_text(pVal);
+		if (pVal != NULL && mem_is_string(pVal)) {
+			if (mem_convert_to_string0(pVal) != 0 ||
+			    mem_get_string0(pVal, &z) != 0)
+				return -1;
+		}
 		assert(pRight->op == TK_VARIABLE || pRight->op == TK_REGISTER);
 	} else if (op == TK_STRING) {
 		z = pRight->u.zToken;
-- 
2.25.1



More information about the Tarantool-patches mailing list