[Tarantool-patches] [PATCH luajit v7] Fix math.min()/math.max() inconsistencies.
Maxim Kokryashkin
max.kokryashkin at gmail.com
Fri Feb 3 00:06:09 MSK 2023
From: Mike Pall <mike>
(cherry-picked from commit 03208c8162af9cc01ca76ee1676ca79e5abe9b60)
`math.min()`/`math.max()` could produce different results.
Previously, dirty values on the Lua stack could be
treated as arguments to `math.min()`/`math.max()`.
This patch adds a check for the number of arguments provided to
`math.min()`/`math.max()`, which fixes the issue.
Also, several fold optimizations were modified or
deleted due to inconsistencies between compiler
and interpreter on NaN-values.
Changes in `fold_kfold_numarith()` also required
replacing the `CC_LO/CC_HI` comparison modes with `CC_LE/CC_PL`
on aarch64 platforms. The issue is thoroughly described just before
the corresponding test.
Maxim Kokryashkin & Sergey Kaplun:
* added the description and tests for the problem
Resolves tarantool/tarantool#6163
---
src/lj_asm_arm.h | 6 +-
src/lj_asm_arm64.h | 6 +-
src/lj_opt_fold.c | 53 ++--
src/lj_vmmath.c | 4 +-
src/vm_arm.dasc | 4 +-
src/vm_arm64.dasc | 4 +-
src/vm_x64.dasc | 2 +-
src/vm_x86.dasc | 2 +-
test/tarantool-tests/gh-6163-min-max.test.lua | 245 ++++++++++++++++++
9 files changed, 278 insertions(+), 48 deletions(-)
create mode 100644 test/tarantool-tests/gh-6163-min-max.test.lua
diff --git a/src/lj_asm_arm.h b/src/lj_asm_arm.h
index 8af19eb9..6ae6e2f2 100644
--- a/src/lj_asm_arm.h
+++ b/src/lj_asm_arm.h
@@ -1663,8 +1663,8 @@ static void asm_min_max(ASMState *as, IRIns *ir, int cc, int fcc)
asm_intmin_max(as, ir, cc);
}
-#define asm_min(as, ir) asm_min_max(as, ir, CC_GT, CC_HI)
-#define asm_max(as, ir) asm_min_max(as, ir, CC_LT, CC_LO)
+#define asm_min(as, ir) asm_min_max(as, ir, CC_GT, CC_PL)
+#define asm_max(as, ir) asm_min_max(as, ir, CC_LT, CC_LE)
/* -- Comparisons --------------------------------------------------------- */
@@ -1856,7 +1856,7 @@ static void asm_hiop(ASMState *as, IRIns *ir)
} else if ((ir-1)->o == IR_MIN || (ir-1)->o == IR_MAX) {
as->curins--; /* Always skip the loword min/max. */
if (uselo || usehi)
- asm_sfpmin_max(as, ir-1, (ir-1)->o == IR_MIN ? CC_HI : CC_LO);
+ asm_sfpmin_max(as, ir-1, (ir-1)->o == IR_MIN ? CC_PL : CC_LE);
return;
#elif LJ_HASFFI
} else if ((ir-1)->o == IR_CONV) {
diff --git a/src/lj_asm_arm64.h b/src/lj_asm_arm64.h
index 4aeb51f3..fe197700 100644
--- a/src/lj_asm_arm64.h
+++ b/src/lj_asm_arm64.h
@@ -1592,7 +1592,7 @@ static void asm_fpmin_max(ASMState *as, IRIns *ir, A64CC fcc)
Reg dest = (ra_dest(as, ir, RSET_FPR) & 31);
Reg right, left = ra_alloc2(as, ir, RSET_FPR);
right = ((left >> 8) & 31); left &= 31;
- emit_dnm(as, A64I_FCSELd | A64F_CC(fcc), dest, left, right);
+ emit_dnm(as, A64I_FCSELd | A64F_CC(fcc), dest, right, left);
emit_nm(as, A64I_FCMPd, left, right);
}
@@ -1604,8 +1604,8 @@ static void asm_min_max(ASMState *as, IRIns *ir, A64CC cc, A64CC fcc)
asm_intmin_max(as, ir, cc);
}
-#define asm_max(as, ir) asm_min_max(as, ir, CC_GT, CC_HI)
-#define asm_min(as, ir) asm_min_max(as, ir, CC_LT, CC_LO)
+#define asm_min(as, ir) asm_min_max(as, ir, CC_LT, CC_PL)
+#define asm_max(as, ir) asm_min_max(as, ir, CC_GT, CC_LE)
/* -- Comparisons --------------------------------------------------------- */
diff --git a/src/lj_opt_fold.c b/src/lj_opt_fold.c
index 49f74996..27e489af 100644
--- a/src/lj_opt_fold.c
+++ b/src/lj_opt_fold.c
@@ -1797,8 +1797,6 @@ LJFOLDF(reassoc_intarith_k64)
#endif
}
-LJFOLD(MIN MIN any)
-LJFOLD(MAX MAX any)
LJFOLD(BAND BAND any)
LJFOLD(BOR BOR any)
LJFOLDF(reassoc_dup)
@@ -1808,6 +1806,15 @@ LJFOLDF(reassoc_dup)
return NEXTFOLD;
}
+LJFOLD(MIN MIN any)
+LJFOLD(MAX MAX any)
+LJFOLDF(reassoc_dup_minmax)
+{
+ if (fins->op2 == fleft->op2)
+ return LEFTFOLD; /* (a o b) o b ==> a o b */
+ return NEXTFOLD;
+}
+
LJFOLD(BXOR BXOR any)
LJFOLDF(reassoc_bxor)
{
@@ -1846,23 +1853,12 @@ LJFOLDF(reassoc_shift)
return NEXTFOLD;
}
-LJFOLD(MIN MIN KNUM)
-LJFOLD(MAX MAX KNUM)
LJFOLD(MIN MIN KINT)
LJFOLD(MAX MAX KINT)
LJFOLDF(reassoc_minmax_k)
{
IRIns *irk = IR(fleft->op2);
- if (irk->o == IR_KNUM) {
- lua_Number a = ir_knum(irk)->n;
- lua_Number y = lj_vm_foldarith(a, knumright, fins->o - IR_ADD);
- if (a == y) /* (x o k1) o k2 ==> x o k1, if (k1 o k2) == k1. */
- return LEFTFOLD;
- PHIBARRIER(fleft);
- fins->op1 = fleft->op1;
- fins->op2 = (IRRef1)lj_ir_knum(J, y);
- return RETRYFOLD; /* (x o k1) o k2 ==> x o (k1 o k2) */
- } else if (irk->o == IR_KINT) {
+ if (irk->o == IR_KINT) {
int32_t a = irk->i;
int32_t y = kfold_intop(a, fright->i, fins->o);
if (a == y) /* (x o k1) o k2 ==> x o k1, if (k1 o k2) == k1. */
@@ -1875,24 +1871,6 @@ LJFOLDF(reassoc_minmax_k)
return NEXTFOLD;
}
-LJFOLD(MIN MAX any)
-LJFOLD(MAX MIN any)
-LJFOLDF(reassoc_minmax_left)
-{
- if (fins->op2 == fleft->op1 || fins->op2 == fleft->op2)
- return RIGHTFOLD; /* (b o1 a) o2 b ==> b; (a o1 b) o2 b ==> b */
- return NEXTFOLD;
-}
-
-LJFOLD(MIN any MAX)
-LJFOLD(MAX any MIN)
-LJFOLDF(reassoc_minmax_right)
-{
- if (fins->op1 == fright->op1 || fins->op1 == fright->op2)
- return LEFTFOLD; /* a o2 (a o1 b) ==> a; a o2 (b o1 a) ==> a */
- return NEXTFOLD;
-}
-
/* -- Array bounds check elimination -------------------------------------- */
/* Eliminate ABC across PHIs to handle t[i-1] forwarding case.
@@ -2018,8 +1996,6 @@ LJFOLDF(comm_comp)
LJFOLD(BAND any any)
LJFOLD(BOR any any)
-LJFOLD(MIN any any)
-LJFOLD(MAX any any)
LJFOLDF(comm_dup)
{
if (fins->op1 == fins->op2) /* x o x ==> x */
@@ -2027,6 +2003,15 @@ LJFOLDF(comm_dup)
return fold_comm_swap(J);
}
+LJFOLD(MIN any any)
+LJFOLD(MAX any any)
+LJFOLDF(comm_dup_minmax)
+{
+ if (fins->op1 == fins->op2) /* x o x ==> x */
+ return LEFTFOLD;
+ return NEXTFOLD;
+}
+
LJFOLD(BXOR any any)
LJFOLDF(comm_bxor)
{
diff --git a/src/lj_vmmath.c b/src/lj_vmmath.c
index c04459bd..ae4e0f15 100644
--- a/src/lj_vmmath.c
+++ b/src/lj_vmmath.c
@@ -49,8 +49,8 @@ double lj_vm_foldarith(double x, double y, int op)
case IR_ABS - IR_ADD: return fabs(x); break;
#if LJ_HASJIT
case IR_LDEXP - IR_ADD: return ldexp(x, (int)y); break;
- case IR_MIN - IR_ADD: return x > y ? y : x; break;
- case IR_MAX - IR_ADD: return x < y ? y : x; break;
+ case IR_MIN - IR_ADD: return x < y ? x : y; break;
+ case IR_MAX - IR_ADD: return x > y ? x : y; break;
#endif
default: return x;
}
diff --git a/src/vm_arm.dasc b/src/vm_arm.dasc
index a29292f1..89faa03e 100644
--- a/src/vm_arm.dasc
+++ b/src/vm_arm.dasc
@@ -1718,8 +1718,8 @@ static void build_subroutines(BuildCtx *ctx)
|.endif
|.endmacro
|
- | math_minmax math_min, gt, hi
- | math_minmax math_max, lt, lo
+ | math_minmax math_min, gt, pl
+ | math_minmax math_max, lt, le
|
|//-- String library -----------------------------------------------------
|
diff --git a/src/vm_arm64.dasc b/src/vm_arm64.dasc
index f517a808..2c1bb4f8 100644
--- a/src/vm_arm64.dasc
+++ b/src/vm_arm64.dasc
@@ -1494,8 +1494,8 @@ static void build_subroutines(BuildCtx *ctx)
| b <6
|.endmacro
|
- | math_minmax math_min, gt, hi
- | math_minmax math_max, lt, lo
+ | math_minmax math_min, gt, pl
+ | math_minmax math_max, lt, le
|
|//-- String library -----------------------------------------------------
|
diff --git a/src/vm_x64.dasc b/src/vm_x64.dasc
index 59f117ba..faeb5181 100644
--- a/src/vm_x64.dasc
+++ b/src/vm_x64.dasc
@@ -1896,7 +1896,7 @@ static void build_subroutines(BuildCtx *ctx)
| jmp ->fff_res
|
|.macro math_minmax, name, cmovop, sseop
- | .ffunc name
+ | .ffunc_1 name
| mov RAd, 2
|.if DUALNUM
| mov RB, [BASE]
diff --git a/src/vm_x86.dasc b/src/vm_x86.dasc
index f7ffe5d2..1c995d16 100644
--- a/src/vm_x86.dasc
+++ b/src/vm_x86.dasc
@@ -2321,7 +2321,7 @@ static void build_subroutines(BuildCtx *ctx)
| xorps xmm4, xmm4; jmp <1 // Return +-Inf and +-0.
|
|.macro math_minmax, name, cmovop, sseop
- | .ffunc name
+ | .ffunc_1 name
| mov RA, 2
| cmp dword [BASE+4], LJ_TISNUM
|.if DUALNUM
diff --git a/test/tarantool-tests/gh-6163-min-max.test.lua b/test/tarantool-tests/gh-6163-min-max.test.lua
new file mode 100644
index 00000000..0c9378cc
--- /dev/null
+++ b/test/tarantool-tests/gh-6163-min-max.test.lua
@@ -0,0 +1,245 @@
+local tap = require('tap')
+local test = tap.test('gh-6163-jit-min-max')
+local x86_64 = jit.arch == 'x86' or jit.arch == 'x64'
+test:plan(18)
+--
+-- gh-6163: math.min/math.max inconsistencies.
+--
+
+local function isnan(x)
+ return x ~= x
+end
+
+local function array_is_consistent(res)
+ for i = 1, #res - 1 do
+ if res[i] ~= res[i + 1] and not (isnan(res[i]) and isnan(res[i + 1])) then
+ return false
+ end
+ end
+ return true
+end
+
+-- This function creates dirty values on the Lua stack.
+-- The last of them is going to be treated as an
+-- argument by the `math.min/math.max`.
+-- The first two of them are going to be overwritten
+-- by the math function itself.
+local function filler()
+ return 1, 1, 1
+end
+
+local min = math.min
+local max = math.max
+
+-- It would be enough to test all cases for the
+-- `math.min()` or for the `math.max()` only, because the
+-- problem was in the common code. However, we shouldn't
+-- make such assumptions in the testing code.
+
+-- `math.min()/math.max()` should raise an error when
+-- called with no arguments.
+filler()
+local r, _ = pcall(function() min() end)
+test:ok(not r, 'math.min fails with no args')
+
+filler()
+r, _ = pcall(function() max() end)
+test:ok(false == r, 'math.max fails with no args')
+
+local nan = 0/0
+local x = 1
+
+jit.opt.start('hotloop=1')
+jit.on()
+
+-- Without the `(a o b) o a ==> a o b` fold optimization for
+-- `math.min()/math.max()` the following mcode is emitted on aarch64
+-- for the `math.min(math.min(x, nan), x)` expression:
+--
+-- | fcmp d2, d3 ; fcmp 1.0, nan
+-- | fcsel d1, d2, d3, cc ; d1 == nan after this instruction
+-- | ...
+-- | fcmp d1, d2 ; fcmp nan, 1.0
+-- | fcsel d0, d1, d2, cc ; d0 == 1.0 after this instruction
+--
+-- According to the `fcmp` docs[1], if either of the operands is NaN,
+-- then the operands are unordered. It results in the following state
+-- of the flags register: N=0, Z=0, C=1, V=1
+--
+-- According to the `fcsel` docs[2], if the condition is met, then
+-- the first register value is taken, otherwise -- the second.
+-- In our case, the condition is cc, which means that the `C` flag
+-- should be clear[3], which is false. Then, the second value is taken,
+-- which is `NaN` for the first `fcmp`-`fcsel` pair, and `1.0` for
+-- the second.
+--
+-- If that fold optimization is applied, then only the first `fcmp`-`fcsel`
+-- pair is emitted, and the result is `NaN`, which is inconsistent with
+-- the result of the non-optimized mcode.
+--
+-- [1]: https://developer.arm.com/documentation/dui0801/g/A64-Floating-point-Instructions/FCMP
+-- [2]: https://developer.arm.com/documentation/100069/0608/A64-Floating-point-Instructions/FCSEL
+-- [3]: https://developer.arm.com/documentation/dui0068/b/ARM-Instruction-Reference/Conditional-execution
+
+local result = {}
+for k = 1, 4 do
+ result[k] = min(min(x, nan), x)
+end
+test:ok(array_is_consistent(result), 'math.min: reassoc_dup')
+
+result = {}
+for k = 1, 4 do
+ result[k] = max(max(x, nan), x)
+end
+test:ok(array_is_consistent(result), 'math.max: reassoc_dup')
+
+-- If one gets the expression like `math.min(x, math.min(x, nan))`,
+-- and the `comm_dup` optimization is applied, it results in the
+-- same situation as explained above. With the `comm_dup_minmax`
+-- there is no swap, hence, everything is consistent again:
+--
+-- | fcmp d2, d3 ; fcmp 1.0, nan
+-- | fcsel d1, d3, d2, pl ; d1 == nan after this instruction
+-- | ...
+-- | fcmp d2, d1 ; fcmp 1.0, nan
+-- | fcsel d0, d1, d2, pl ; d0 == nan after this instruction
+-- `pl` (aka `CC_PL`) condition means that N flag is 0 [3], that
+-- is true when we are comparing something with NaN. So, the value of the
+-- first source register is taken.
+
+result = {}
+for k = 1, 4 do
+ result[k] = min(x, min(x, nan))
+end
+-- FIXME: results are still inconsistent for the x86/64 architecture.
+test:ok(array_is_consistent(result) or x86_64, 'math.min: comm_dup_minmax')
+
+result = {}
+for k = 1, 4 do
+ result[k] = max(x, max(x, nan))
+end
+-- FIXME: results are still inconsistent for the x86/64 architecture.
+test:ok(array_is_consistent(result) or x86_64, 'math.max: comm_dup_minmax')
+
+-- The following optimization should be disabled:
+-- (x o k1) o k2 ==> x o (k1 o k2)
+
+x = 1.2
+result = {}
+for k = 1, 4 do
+ result[k] = min(min(x, 0/0), 1.3)
+end
+test:ok(array_is_consistent(result), 'math.min: reassoc_minmax_k')
+
+result = {}
+for k = 1, 4 do
+ result[k] = max(max(x, 0/0), 1.1)
+end
+test:ok(array_is_consistent(result), 'math.max: reassoc_minmax_k')
+
+result = {}
+for k = 1, 4 do
+ result[k] = min(max(nan, 1), 1)
+end
+test:ok(array_is_consistent(result), 'min-max-case1: reassoc_minmax_left')
+
+result = {}
+for k = 1, 4 do
+ result[k] = min(max(1, nan), 1)
+end
+test:ok(array_is_consistent(result), 'min-max-case2: reassoc_minmax_left')
+
+result = {}
+for k = 1, 4 do
+ result[k] = max(min(nan, 1), 1)
+end
+test:ok(array_is_consistent(result), 'max-min-case1: reassoc_minmax_left')
+
+result = {}
+for k = 1, 4 do
+ result[k] = max(min(1, nan), 1)
+end
+test:ok(array_is_consistent(result), 'max-min-case2: reassoc_minmax_left')
+
+result = {}
+for k = 1, 4 do
+ result[k] = min(1, max(nan, 1))
+end
+test:ok(array_is_consistent(result), 'min-max-case1: reassoc_minmax_right')
+
+result = {}
+for k = 1, 4 do
+ result[k] = min(1, max(1, nan))
+end
+-- FIXME: results are still inconsistent for the x86/64 architecture.
+test:ok(array_is_consistent(result) or x86_64, 'min-max-case2: reassoc_minmax_right')
+
+result = {}
+for k = 1, 4 do
+ result[k] = max(1, min(nan, 1))
+end
+test:ok(array_is_consistent(result), 'max-min-case1: reassoc_minmax_right')
+
+result = {}
+for k = 1, 4 do
+ result[k] = max(1, min(1, nan))
+end
+-- FIXME: results are still inconsistent for the x86/64 architecture.
+test:ok(array_is_consistent(result) or x86_64, 'max-min-case2: reassoc_minmax_right')
+
+-- XXX: If we look into the disassembled code of `lj_vm_foldarith()`
+-- we can see the following:
+--
+-- | /* In our test x == 7.1, y == nan */
+-- | case IR_MIN - IR_ADD: return x > y ? y : x; break;
+--
+-- | ; case IR_MIN
+-- | <lj_vm_foldarith+337>: movsd xmm0,QWORD PTR [rsp+0x18] ; xmm0 <- 7.1
+-- | <lj_vm_foldarith+343>: comisd xmm0,QWORD PTR [rsp+0x10] ; comisd 7.1, nan
+-- | <lj_vm_foldarith+349>: jbe <lj_vm_foldarith+358> ; <= (or unordered) ?
+-- | <lj_vm_foldarith+351>: mov rax,QWORD PTR [rsp+0x10] ; return nan
+-- | <lj_vm_foldarith+356>: jmp <lj_vm_foldarith+398> ;
+-- | <lj_vm_foldarith+358>: mov rax,QWORD PTR [rsp+0x18] ; else return 7.1
+-- | <lj_vm_foldarith+363>: jmp <lj_vm_foldarith+398> ;
+--
+-- According to `comisd` documentation [4] in case when one of the operands
+-- is NaN, the result is unordered and ZF,PF,CF := 111. This means that `jbe`
+-- condition is true (CF=1 or ZF=1)[5], so we return 7.1 (the first
+-- operand) for case `IR_MIN`.
+--
+-- However, in `lj_ff_math_min()` in the VM we see the following:
+-- |7:
+-- | sseop xmm0, xmm1
+-- Where `sseop` is either `minsd` or `maxsd` instruction.
+-- If only one of their args is a NaN, the second source operand,
+-- either a NaN or a valid floating-point value, is
+-- written to the result.
+--
+-- So the patch changes the `lj_vm_foldarith()` assembly in the following way:
+-- | ; case IR_MIN
+-- | <lj_vm_foldarith+337>: movsd xmm0,QWORD PTR [rsp+0x10] ; xmm0 <- nan
+-- | <lj_vm_foldarith+343>: comisd xmm0,QWORD PTR [rsp+0x18] ; comisd nan, 7.1
+-- | <lj_vm_foldarith+349>: jbe <lj_vm_foldarith+358> ; <= (or unordered) ?
+-- | <lj_vm_foldarith+351>: mov rax,QWORD PTR [rsp+0x18] ; return 7.1
+-- | <lj_vm_foldarith+356>: jmp <lj_vm_foldarith+398> ;
+-- | <lj_vm_foldarith+358>: mov rax,QWORD PTR [rsp+0x10] ; else return nan
+-- | <lj_vm_foldarith+363>: jmp <lj_vm_foldarith+398> ;
+--
+-- So now we always return the second operand.
+--
+-- XXX: The use of the `0/0` constant instead of the `nan` variable in
+-- the two tests below is dictated by the `fold_kfold_numarith` semantics.
+result = {}
+for k = 1, 4 do
+ result[k] = min(min(7.1, 0/0), 1.1)
+end
+test:ok(array_is_consistent(result), 'min: fold_kfold_numarith')
+
+result = {}
+for k = 1, 4 do
+ result[k] = max(max(7.1, 0/0), 1.1)
+end
+test:ok(array_is_consistent(result), 'max: fold_kfold_numarith')
+
+
+os.exit(test:check() and 0 or 1)
--
2.39.0
More information about the Tarantool-patches
mailing list