[tarantool-patches] [PATCH v2 06/15] sql: arithmetic functions support big integers
Stanislav Zudin
szudin at tarantool.org
Mon Apr 1 23:44:44 MSK 2019
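Makes the SQL arithmetic functions (addition, subtraction,
multiplication, division and remainder) accept operands with
values in the range [2^63, 2^64) by passing the signedness of
each operand along with its 64-bit value.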
Part of #3810
---
src/box/sql/func.c | 2 +-
src/box/sql/sqlInt.h | 23 +++-
src/box/sql/util.c | 236 ++++++++++++++++++++++++++++++++----------
src/box/sql/vdbe.c | 36 ++++---
src/box/sql/vdbeInt.h | 2 +-
5 files changed, 223 insertions(+), 76 deletions(-)
diff --git a/src/box/sql/func.c b/src/box/sql/func.c
index cf65bf2a2..8a8acc216 100644
--- a/src/box/sql/func.c
+++ b/src/box/sql/func.c
@@ -1437,7 +1437,7 @@ sumStep(sql_context * context, int argc, sql_value ** argv)
i64 v = sql_value_int64(argv[0]);
p->rSum += v;
if ((p->approx | p->overflow) == 0
- && sqlAddInt64(&p->iSum, v)) {
+ && sqlAddInt64(&p->iSum, true, v, true) != ATHR_SIGNED) {
p->overflow = 1;
}
} else {
diff --git a/src/box/sql/sqlInt.h b/src/box/sql/sqlInt.h
index 9b1d7df9a..7f8e3f04e 100644
--- a/src/box/sql/sqlInt.h
+++ b/src/box/sql/sqlInt.h
@@ -4383,9 +4383,26 @@ Expr *sqlExprAddCollateString(Parse *, Expr *, const char *);
Expr *sqlExprSkipCollate(Expr *);
int sqlCheckIdentifierName(Parse *, char *);
void sqlVdbeSetChanges(sql *, int);
-int sqlAddInt64(i64 *, i64);
-int sqlSubInt64(i64 *, i64);
-int sqlMulInt64(i64 *, i64);
+
+enum arithmetic_result {
+ /* Success: the result fits in a signed 64-bit integer. */
+ ATHR_SIGNED,
+ /* Success: the result is above INT64_MAX but still fits in an
+ * unsigned 64-bit integer; the i64 output holds its bit pattern.
+ */
+ ATHR_UNSIGNED,
+ /* Failure: the result cannot be represented; the output is not written. */
+ ATHR_OVERFLOW,
+ /* Failure: the divisor is zero (division and remainder only). */
+ ATHR_DIVBYZERO
+};
+
+enum arithmetic_result sqlAddInt64(i64 *, bool, i64, bool);
+enum arithmetic_result sqlSubInt64(i64 *, bool, i64, bool);
+enum arithmetic_result sqlMulInt64(i64 *, bool, i64, bool);
+enum arithmetic_result sqlDivInt64(i64 *, bool, i64, bool);
+enum arithmetic_result sqlRemInt64(i64 *, bool, i64, bool);
+
int sqlAbsInt32(int);
#ifdef SQL_ENABLE_8_3_NAMES
void sqlFileSuffix3(const char *, char *);
diff --git a/src/box/sql/util.c b/src/box/sql/util.c
index be77f72f8..3786c5083 100644
--- a/src/box/sql/util.c
+++ b/src/box/sql/util.c
@@ -1249,74 +1249,202 @@ sqlSafetyCheckSickOrOk(sql * db)
}
/*
- * Attempt to add, substract, or multiply the 64-bit signed value iB against
- * the other 64-bit signed integer at *pA and store the result in *pA.
- * Return 0 on success. Or if the operation would have resulted in an
- * overflow, leave *pA unchanged and return 1.
+ * Return the magnitude (absolute value) of a 64-bit integer as u64.
*/
-int
-sqlAddInt64(i64 * pA, i64 iB)
+static u64 mod64(i64 v, bool is_signed)
+{
+	/* Unsigned operands and non-negative signed ones are their own magnitude. */
+	if (!is_signed || v >= 0)
+		return (u64)v;
+	/* -INT64_MIN is not representable; its bit pattern read as u64 is used instead. */
+	return (v != INT64_MIN) ? (u64)(-v) : (u64)v;
+}
+
+/*
+ * Attempt to add, subtract, or multiply the 64-bit value iB against
+ * the other 64-bit integer at *pA and store the result in *pA.
+ * An operand whose is_signed flag is false is read as an unsigned
+ * 64-bit value, so it is never treated as negative.
+ * Return ATHR_SIGNED or ATHR_UNSIGNED on success.
+ * Or if the operation would have resulted in an
+ * overflow, leave *pA unchanged and return ATHR_OVERFLOW.
+ */
+enum arithmetic_result
+sqlAddInt64(i64 * pA, bool is_signedA, i64 iB, bool is_signedB)
{
i64 iA = *pA;
- testcase(iA == 0);
- testcase(iA == 1);
- testcase(iB == -1);
- testcase(iB == 0);
- if (iB >= 0) {
- testcase(iA > 0 && LARGEST_INT64 - iA == iB);
- testcase(iA > 0 && LARGEST_INT64 - iA == iB - 1);
- if (iA > 0 && LARGEST_INT64 - iA < iB)
- return 1;
+
+	bool is_negA = iA < 0 && is_signedA;
+	bool is_negB = iB < 0 && is_signedB;
+
+	/*
+	 * Make sure we've got only one combination of
+	 * positive and negative operands: if exactly one
+	 * operand is negative, it ends up in iB.
+	 */
+	if (is_negA > is_negB) {
+		SWAP(is_negA, is_negB);
+		SWAP(iA, iB);
+	}
+
+	if (is_negA != is_negB) {
+		/*
+		 * iA is non-negative here, but it may be an unsigned
+		 * value above INT64_MAX whose i64 bit pattern is
+		 * negative, so only the sign flags can be asserted.
+		 */
+		assert(!is_negA && is_negB);
+		u64 uB = mod64(iB, true);
+
+		if ((u64)iA >= uB) {
+			/* The result is non-negative. */
+			u64 sum = (u64)iA - uB;
+			*pA = (i64)sum;
+			return (sum <= INT64_MAX) ? ATHR_SIGNED
+						  : ATHR_UNSIGNED;
+		} else {
+			/* The result is negative; sum is its magnitude. */
+			u64 sum = uB - (u64)iA;
+			if (sum == INT64_MIN_MOD) {
+				*pA = INT64_MIN;
+			} else {
+				/* sum may equal INT64_MAX, e.g. INT64_MIN + 1. */
+				assert(sum <= INT64_MAX);
+				*pA = -(i64)sum;
+			}
+			return ATHR_SIGNED;
+		}
+	}
+
+	if (is_negA) {
+		/* Both operands are negative signed values. */
+		assert(is_signedA && is_signedB);
+		if (-(iA + LARGEST_INT64) > iB + 1)
+			return ATHR_OVERFLOW;
+		*pA = iA + iB;
+		return ATHR_SIGNED;
} else {
- testcase(iA < 0 && -(iA + LARGEST_INT64) == iB + 1);
- testcase(iA < 0 && -(iA + LARGEST_INT64) == iB + 2);
- if (iA < 0 && -(iA + LARGEST_INT64) > iB + 1)
- return 1;
+		/* Both operands are non-negative: add them as unsigned. */
+		if (UINT64_MAX - (u64)iA < (u64)iB)
+			return ATHR_OVERFLOW;
+
+		u64 sum = (u64)iA + (u64)iB;
+		*pA = (i64)sum;
+		return (sum <= INT64_MAX) ? ATHR_SIGNED
+					  : ATHR_UNSIGNED;
}
- *pA += iB;
- return 0;
}
-int
-sqlSubInt64(i64 * pA, i64 iB)
+/*
+ * Subtract iB from the 64-bit integer at *pA and store the
+ * difference in *pA. An operand whose is_signed flag is false is
+ * read as an unsigned 64-bit value.
+ * Return ATHR_SIGNED or ATHR_UNSIGNED on success. If the difference
+ * is not representable, return ATHR_OVERFLOW and leave *pA unchanged.
+ */
+enum arithmetic_result
+sqlSubInt64(i64 * pA, bool is_signedA, i64 iB, bool is_signedB)
{
- testcase(iB == SMALLEST_INT64 + 1);
- if (iB == SMALLEST_INT64) {
- testcase((*pA) == (-1));
- testcase((*pA) == 0);
- if ((*pA) >= 0)
- return 1;
- *pA -= iB;
- return 0;
+	i64 iA = *pA;
+
+	bool is_negA = iA < 0 && is_signedA;
+	bool is_negB = iB < 0 && is_signedB;
+
+	if (is_negA) {
+		if (is_negB) {
+			if (iB == INT64_MIN) {
+				/*
+				 * -iB is not representable, but
+				 * iA - iB lies in [0, INT64_MAX]
+				 * because iA is negative.
+				 */
+				*pA = iA - iB;
+				return ATHR_SIGNED;
+			}
+			/* iA - iB => iA + (-iB) */
+			return sqlAddInt64(pA, true, -iB, true);
+		}
+		/*
+		 * Negative minus non-negative. A subtrahend above
+		 * INT64_MAX pushes the result below INT64_MIN.
+		 */
+		if ((u64)iB > INT64_MAX)
+			return ATHR_OVERFLOW;
+		return sqlAddInt64(pA, true, -iB, true);
+	}
+
+	if (is_negB) {
+		/* iA - (-iB) => iA + iB */
+		u64 uB = mod64(iB, true);
+		/* The magnitude of INT64_MIN only fits as unsigned. */
+		if (iB == INT64_MIN)
+			is_signedB = false;
+
+		return sqlAddInt64(pA, is_signedA, uB, is_signedB);
} else {
- return sqlAddInt64(pA, -iB);
+		/* Both iA & iB are non-negative */
+		u64 uA = (u64)iA;
+		u64 uB = (u64)iB;
+		if (uA >= uB) {
+			u64 val = uA - uB;
+			*pA = (i64)val;
+			return (val > INT64_MAX) ? ATHR_UNSIGNED
+						 : ATHR_SIGNED;
+		}
+		/*
+		 * The difference is negative (e.g. 1 - 2); it fits
+		 * as long as its magnitude does not exceed 2^63.
+		 */
+		u64 diff = uB - uA;
+		if (diff > INT64_MIN_MOD)
+			return ATHR_OVERFLOW;
+		*pA = (diff == INT64_MIN_MOD) ? INT64_MIN : -(i64)diff;
+		return ATHR_SIGNED;
}
}
-int
-sqlMulInt64(i64 * pA, i64 iB)
+static enum arithmetic_result
+apply_sign(i64* pOut, u64 value, bool is_neg)
{
- i64 iA = *pA;
- if (iB > 0) {
- if (iA > LARGEST_INT64 / iB)
- return 1;
- if (iA < SMALLEST_INT64 / iB)
- return 1;
- } else if (iB < 0) {
- if (iA > 0) {
- if (iB < SMALLEST_INT64 / iA)
- return 1;
- } else if (iA < 0) {
- if (iB == SMALLEST_INT64)
- return 1;
- if (iA == SMALLEST_INT64)
- return 1;
- if (-iA > LARGEST_INT64 / -iB)
- return 1;
- }
+	if (is_neg) {
+		/*
+		 * A negative result: value is its magnitude. Anything
+		 * above INT64_MIN_MOD is rejected without writing *pOut;
+		 * INT64_MIN_MOD itself is stored by a plain cast because
+		 * it cannot be negated as i64.
+		 */
+		if (value > INT64_MIN_MOD)
+			return ATHR_OVERFLOW;
+		*pOut = (value == INT64_MIN_MOD) ? (i64)value : -(i64)value;
+		return ATHR_SIGNED;
}
- *pA = iA * iB;
- return 0;
+
+	/* A non-negative result: store the bits and report which range it is in. */
+	*pOut = (i64)value;
+	if (value > INT64_MAX)
+		return ATHR_UNSIGNED;
+	return ATHR_SIGNED;
+}
+
+/*
+ * Multiply the 64-bit integer at *pA by iB and store the product
+ * in *pA. An operand whose is_signed flag is false is read as an
+ * unsigned 64-bit value.
+ * Return ATHR_SIGNED or ATHR_UNSIGNED on success. If the product is
+ * not representable, return ATHR_OVERFLOW and leave *pA unchanged.
+ */
+enum arithmetic_result
+sqlMulInt64(i64 * pA, bool is_signedA, i64 iB, bool is_signedB)
+{
+	/* A zero operand gives zero; it also keeps the divisions below safe. */
+	if (*pA == 0 || iB == 0) {
+		*pA = 0;
+		return ATHR_SIGNED;
+	}
+
+	bool is_negA = *pA < 0 && is_signedA;
+	bool is_negB = iB < 0 && is_signedB;
+
+	/* The product is negative when exactly one operand is negative. */
+	bool is_neg = is_negA != is_negB;
+
+	u64 uA = mod64(*pA, is_signedA);
+	u64 uB = mod64(iB, is_signedB);
+
+	if (is_neg) {
+		/* A negative product must not go below INT64_MIN. */
+		if (INT64_MIN_MOD / uA < uB)
+			return ATHR_OVERFLOW;
+	} else {
+		/*
+		 * A non-negative product may use the whole unsigned
+		 * range; apply_sign() reports ATHR_UNSIGNED when it
+		 * exceeds INT64_MAX.
+		 */
+		if (UINT64_MAX / uA < uB)
+			return ATHR_OVERFLOW;
+	}
+
+	u64 mul = uA * uB;
+	return apply_sign(pA, mul, is_neg);
+}
+
+/*
+ * Divide the 64-bit integer at *pA by iB and store the quotient
+ * in *pA. An operand whose is_signed flag is false is read as an
+ * unsigned 64-bit value.
+ * Return ATHR_DIVBYZERO if iB is zero, ATHR_OVERFLOW if the
+ * quotient is not representable, otherwise ATHR_SIGNED or
+ * ATHR_UNSIGNED.
+ */
+enum arithmetic_result
+sqlDivInt64(i64 * pA, bool is_signedA, i64 iB, bool is_signedB)
+{
+	/* Check the divisor first so that 0 / 0 is reported as an error. */
+	if (iB == 0)
+		return ATHR_DIVBYZERO;
+	/* Zero divided by anything non-zero stays zero. */
+	if (*pA == 0)
+		return ATHR_SIGNED;
+
+	bool is_negA = *pA < 0 && is_signedA;
+	bool is_negB = iB < 0 && is_signedB;
+
+	/* The quotient is negative when exactly one operand is negative. */
+	bool is_neg = is_negA != is_negB;
+
+	u64 uA = mod64(*pA, is_signedA);
+	u64 uB = mod64(iB, is_signedB);
+
+	u64 div = uA / uB;
+	return apply_sign(pA, div, is_neg);
+}
+
+enum arithmetic_result
+sqlRemInt64(i64 * pA, bool is_signedA, i64 iB, bool is_signedB)
+{
+	/* A zero divisor is rejected before any arithmetic is attempted. */
+	if (iB == 0)
+		return ATHR_DIVBYZERO;
+
+	/*
+	 * The remainder is computed on magnitudes and then takes
+	 * the sign of the dividend only, so that whenever a/b is
+	 * representable, (a/b)*b + a%b == a holds. The sign of the
+	 * divisor plays no role.
+	 */
+	u64 divisor = mod64(iB, is_signedB);
+	u64 dividend = mod64(*pA, is_signedA);
+	bool negative = is_signedA && *pA < 0;
+
+	return apply_sign(pA, dividend % divisor, negative);
}
/*
diff --git a/src/box/sql/vdbe.c b/src/box/sql/vdbe.c
index ea9d9d98f..d4bd845fb 100644
--- a/src/box/sql/vdbe.c
+++ b/src/box/sql/vdbe.c
@@ -1672,28 +1672,29 @@ case OP_Remainder: { /* same as TK_REM, in1, in2, out3 */
if ((type1 & type2 & MEM_Int)!=0) {
iA = pIn1->u.i;
iB = pIn2->u.i;
+ bool is_signedA = (type1 & MEM_Unsigned) == 0;
+ bool is_signedB = (type2 & MEM_Unsigned) == 0;
bIntint = 1;
+ enum arithmetic_result arr;
switch( pOp->opcode) {
- case OP_Add: if (sqlAddInt64(&iB,iA)) goto integer_overflow; break;
- case OP_Subtract: if (sqlSubInt64(&iB,iA)) goto integer_overflow; break;
- case OP_Multiply: if (sqlMulInt64(&iB,iA)) goto integer_overflow; break;
- case OP_Divide: {
- if (iA == 0)
- goto division_by_zero;
- if (iA==-1 && iB==SMALLEST_INT64) goto integer_overflow;
- iB /= iA;
- break;
+ case OP_Add: arr = sqlAddInt64(&iB, is_signedA, iA, is_signedB); break;
+ case OP_Subtract: arr = sqlSubInt64(&iB, is_signedA, iA, is_signedB); break;
+ case OP_Multiply: arr = sqlMulInt64(&iB, is_signedA, iA, is_signedB); break;
+ case OP_Divide: arr = sqlDivInt64(&iB, is_signedA, iA, is_signedB); break;
+ default: arr = sqlRemInt64(&iB, is_signedA, iA, is_signedB); break;
}
- default: {
- if (iA == 0)
- goto division_by_zero;
- if (iA==-1) iA = 1;
- iB %= iA;
+
+ switch(arr){
+ case ATHR_SIGNED:
+ MemSetTypeFlag(pOut, MEM_Int);
break;
- }
+ case ATHR_UNSIGNED:
+ MemSetTypeFlag(pOut, MEM_Int|MEM_Unsigned);
+ break;
+ case ATHR_OVERFLOW: goto integer_overflow;
+ case ATHR_DIVBYZERO: goto division_by_zero;
}
pOut->u.i = iB;
- MemSetTypeFlag(pOut, MEM_Int);
} else {
bIntint = 0;
if (sqlVdbeRealValue(pIn1, &rA) != 0) {
@@ -5177,7 +5178,8 @@ case OP_OffsetLimit: { /* in1, out2, in3 */
assert(pIn1->flags & MEM_Int);
assert(pIn3->flags & MEM_Int);
x = pIn1->u.i;
- if (x<=0 || sqlAddInt64(&x, pIn3->u.i>0?pIn3->u.i:0)) {
+ if (x<=0 ||
+ sqlAddInt64(&x, true, pIn3->u.i>0?pIn3->u.i:0, true) != ATHR_SIGNED) {
/* If the LIMIT is less than or equal to zero, loop forever. This
* is documented. But also, if the LIMIT+OFFSET exceeds 2^63 then
* also loop forever. This is undocumented. In fact, one could argue
diff --git a/src/box/sql/vdbeInt.h b/src/box/sql/vdbeInt.h
index 0375845d9..42f22df52 100644
--- a/src/box/sql/vdbeInt.h
+++ b/src/box/sql/vdbeInt.h
@@ -276,7 +276,7 @@ enum {
* Clear any existing type flags from a Mem and replace them with f
*/
#define MemSetTypeFlag(p, f) \
- ((p)->flags = ((p)->flags&~(MEM_TypeMask|MEM_Zero))|f)
+ ((p)->flags = ((p)->flags&~(MEM_TypeMask|MEM_Zero|MEM_Unsigned))|(f))
/*
* Return true if a memory cell is not marked as invalid. This macro
--
2.17.1
More information about the Tarantool-patches
mailing list