Hi, Sergey!
Thanks for the patch!
From: Mike Pall <mike>

See the discussion in the corresponding ticket for the rationale.

(cherry picked from commit de2e1ca9d3d87e74c0c20c1e4ad3c32b31a5875b)

For the modulo operation, the arm64 VM uses `fmsub` [1] instruction,
which is the fused multiply-add (FMA [2]) operation (more precisely,
multiply-sub). Hence, it may produce different results compared to the
unfused one. This patch fixes the behaviour by using the unfused
instructions by default. However, the new JIT optimization flag (fma) is
introduced to make it possible to take advantage of the FMA
optimizations.

Sergey Kaplun:
* added the description and the test for the problem

[1]: https://developer.arm.com/documentation/dui0801/g/A64-Floating-point-Instructions/FMSUB
[2]: https://en.wikipedia.org/wiki/Multiply%E2%80%93accumulate_operation

Part of tarantool/tarantool#10709
---

I intentionally avoid mentioning the ticket in the commit message to
avoid excess mentioning in the LuaJIT issue tracker. You can see the
LuaJIT/LuaJIT#918 link in the cover letter.

 doc/running.html                              |  8 +++++
 src/lj_asm_arm.h                              |  6 +++-
 src/lj_asm_arm64.h                            |  3 +-
 src/lj_asm_ppc.h                              |  3 +-
 src/lj_jit.h                                  |  4 ++-
 src/lj_vmmath.c                               | 13 ++-
 src/vm_arm64.dasc                             |  4 ++-
 ...lj-918-fma-numerical-accuracy-jit.test.lua | 36 +++++++++++++++++++
 .../lj-918-fma-numerical-accuracy.test.lua    | 31 ++++++++++++++++
 .../lj-918-fma-optimization.test.lua          | 25 +++++++++++++
 10 files changed, 127 insertions(+), 6 deletions(-)
 create mode 100644 test/tarantool-tests/lj-918-fma-numerical-accuracy-jit.test.lua
 create mode 100644 test/tarantool-tests/lj-918-fma-numerical-accuracy.test.lua
 create mode 100644 test/tarantool-tests/lj-918-fma-optimization.test.lua

diff --git a/doc/running.html b/doc/running.html
index 7868efab..1cf41f1b 100644
--- a/doc/running.html
+++ b/doc/running.html
@@ -226,6 +226,12 @@ mix the three forms, but note that setting an optimization level
 overrides all earlier flags.
 </p>
 <p>
+Note that <tt>-Ofma</tt> is not enabled by default at any level,
+because it affects floating-point result accuracy. Only enable this,
+if you fully understand the trade-offs of FMA for performance (higher),
+determinism (lower) and numerical accuracy (higher).
+</p>
+<p>
 Here are the available flags and at what optimization levels they
 are enabled:
 </p>
@@ -257,6 +263,8 @@ are enabled:
 <td class="flag_name">sink</td><td class="flag_level"> </td><td class="flag_level"> </td><td class="flag_level">•</td><td class="flag_desc">Allocation/Store Sinking</td></tr>
 <tr class="even">
 <td class="flag_name">fuse</td><td class="flag_level"> </td><td class="flag_level"> </td><td class="flag_level">•</td><td class="flag_desc">Fusion of operands into instructions</td></tr>
+<tr class="odd">
+<td class="flag_name">fma </td><td class="flag_level"> </td><td class="flag_level"> </td><td class="flag_level"> </td><td class="flag_desc">Fused multiply-add</td></tr>
 </table>
 <p>
 Here are the parameters and their default settings:
diff --git a/src/lj_asm_arm.h b/src/lj_asm_arm.h
index 5a0f925f..041cd794 100644
--- a/src/lj_asm_arm.h
+++ b/src/lj_asm_arm.h
@@ -310,7 +310,11 @@ static void asm_fusexref(ASMState *as, ARMIns ai, Reg rd, IRRef ref,
 }
 
 #if !LJ_SOFTFP
-/* Fuse to multiply-add/sub instruction. */
+/*
+** Fuse to multiply-add/sub instruction.
+** VMLA rounds twice (UMA, not FMA) -- no need to check for JIT_F_OPT_FMA.
+** VFMA needs VFPv4, which is uncommon on the remaining ARM32 targets.
+*/
 static int asm_fusemadd(ASMState *as, IRIns *ir, ARMIns ai, ARMIns air)
 {
   IRRef lref = ir->op1, rref = ir->op2;
diff --git a/src/lj_asm_arm64.h b/src/lj_asm_arm64.h
index 88b47ceb..554bb60a 100644
--- a/src/lj_asm_arm64.h
+++ b/src/lj_asm_arm64.h
@@ -334,7 +334,8 @@ static int asm_fusemadd(ASMState *as, IRIns *ir, A64Ins ai, A64Ins air)
 {
   IRRef lref = ir->op1, rref = ir->op2;
   IRIns *irm;
-  if (lref != rref &&
+  if ((as->flags & JIT_F_OPT_FMA) &&
+      lref != rref &&
       ((mayfuse(as, lref) && (irm = IR(lref), irm->o == IR_MUL) &&
         ra_noreg(irm->r)) ||
        (mayfuse(as, rref) && (irm = IR(rref), irm->o == IR_MUL) &&
diff --git a/src/lj_asm_ppc.h b/src/lj_asm_ppc.h
index 7bba71b3..52db2926 100644
--- a/src/lj_asm_ppc.h
+++ b/src/lj_asm_ppc.h
@@ -232,7 +232,8 @@ static int asm_fusemadd(ASMState *as, IRIns *ir, PPCIns pi, PPCIns pir)
 {
   IRRef lref = ir->op1, rref = ir->op2;
   IRIns *irm;
-  if (lref != rref &&
+  if ((as->flags & JIT_F_OPT_FMA) &&
+      lref != rref &&
       ((mayfuse(as, lref) && (irm = IR(lref), irm->o == IR_MUL) &&
         ra_noreg(irm->r)) ||
        (mayfuse(as, rref) && (irm = IR(rref), irm->o == IR_MUL) &&
diff --git a/src/lj_jit.h b/src/lj_jit.h
index 47df85c6..73c355b9 100644
--- a/src/lj_jit.h
+++ b/src/lj_jit.h
@@ -86,10 +86,11 @@
 #define JIT_F_OPT_ABC (JIT_F_OPT << 7)
 #define JIT_F_OPT_SINK (JIT_F_OPT << 8)
 #define JIT_F_OPT_FUSE (JIT_F_OPT << 9)
+#define JIT_F_OPT_FMA (JIT_F_OPT << 10)
 
 /* Optimizations names for -O. Must match the order above. */
 #define JIT_F_OPTSTRING \
-  "\4fold\3cse\3dce\3fwd\3dse\6narrow\4loop\3abc\4sink\4fuse"
+  "\4fold\3cse\3dce\3fwd\3dse\6narrow\4loop\3abc\4sink\4fuse\3fma"
 
 /* Optimization levels set a fixed combination of flags. */
 #define JIT_F_OPT_0 0
@@ -98,6 +99,7 @@
 #define JIT_F_OPT_3 (JIT_F_OPT_2|\
   JIT_F_OPT_FWD|JIT_F_OPT_DSE|JIT_F_OPT_ABC|JIT_F_OPT_SINK|JIT_F_OPT_FUSE)
 #define JIT_F_OPT_DEFAULT JIT_F_OPT_3
+/* Note: FMA is not set by default. */
 
 /* -- JIT engine parameters ----------------------------------------------- */
 
diff --git a/src/lj_vmmath.c b/src/lj_vmmath.c
index faebe719..29b72e0c 100644
--- a/src/lj_vmmath.c
+++ b/src/lj_vmmath.c
@@ -36,6 +36,17 @@ LJ_FUNCA double lj_wrap_fmod(double x, double y) { return fmod(x, y); }
 
 /* -- Helper functions ---------------------------------------------------- */
 
+/* Required to prevent the C compiler from applying FMA optimizations.
+**
+** Yes, there's -ffp-contract and the FP_CONTRACT pragma ... in theory.
+** But the current state of C compilers is a mess in this regard.
+** Also, this function is not performance sensitive at all.
+*/
+LJ_NOINLINE static double lj_vm_floormul(double x, double y)
+{
+  return lj_vm_floor(x / y) * y;
+}
+
 double lj_vm_foldarith(double x, double y, int op)
 {
   switch (op) {
@@ -43,7 +54,7 @@ double lj_vm_foldarith(double x, double y, int op)
   case IR_SUB - IR_ADD: return x-y; break;
   case IR_MUL - IR_ADD: return x*y; break;
   case IR_DIV - IR_ADD: return x/y; break;
-  case IR_MOD - IR_ADD: return x-lj_vm_floor(x/y)*y; break;
+  case IR_MOD - IR_ADD: return x-lj_vm_floormul(x, y); break;
   case IR_POW - IR_ADD: return pow(x, y); break;
   case IR_NEG - IR_ADD: return -x; break;
   case IR_ABS - IR_ADD: return fabs(x); break;
diff --git a/src/vm_arm64.dasc b/src/vm_arm64.dasc
index 1cf1ea51..c5f0a7a7 100644
--- a/src/vm_arm64.dasc
+++ b/src/vm_arm64.dasc
@@ -2581,7 +2581,9 @@ static void build_ins(BuildCtx *ctx, BCOp op, int defop)
   |.macro ins_arithmod, res, reg1, reg2
   |  fdiv d2, reg1, reg2
   |  frintm d2, d2
-  |  fmsub res, d2, reg2, reg1
+  |  // Cannot use fmsub, because FMA is not enabled by default.
+  |  fmul d2, d2, reg2
+  |  fsub res, reg1, d2
   |.endmacro
   |
   |.macro ins_arithdn, intins, fpins
diff --git a/test/tarantool-tests/lj-918-fma-numerical-accuracy-jit.test.lua b/test/tarantool-tests/lj-918-fma-numerical-accuracy-jit.test.lua
new file mode 100644
index 00000000..55ec7b98
--- /dev/null
+++ b/test/tarantool-tests/lj-918-fma-numerical-accuracy-jit.test.lua
@@ -0,0 +1,36 @@
+local tap = require('tap')
+
+-- Test file to demonstrate consistent behaviour for JIT and the
+-- VM regarding FMA optimization (disabled by default).
+-- XXX: The VM behaviour is checked in the
+-- <lj-918-fma-numerical-accuracy.test.lua>.
+-- See also: https://github.com/LuaJIT/LuaJIT/issues/918.
+local test = tap.test('lj-918-fma-numerical-accuracy-jit'):skipcond({
+  ['Test requires JIT enabled'] = not jit.status(),
+})
+
+test:plan(1)
+
+local _2pow52 = 2 ^ 52
+
+-- IEEE754 components to double:
+-- sign * (2 ^ (exp - 1023)) * (mantissa / _2pow52 + normal).
+local a = 1 * (2 ^ (1083 - 1023)) * (4080546448249347 / _2pow52 + 1)
+assert(a == 2197541395358679800)
+
+local b = -1 * (2 ^ (1052 - 1023)) * (3927497732209973 / _2pow52 + 1)
+assert(b == -1005065126.3690554)
+

Please add a comment explaining why exactly these test values are
used. If I understand correctly, the idea is to compute `a % b` for a
positive `a` and a negative `b`, right?
Why do you think two values are enough to check that the JIT and the
VM behave consistently?
Should we check more corner cases?

+local results = {}
+
+jit.opt.start('hotloop=1')
+for i = 1, 4 do
+  results[i] = a % b
+end
+
+-- XXX: The test doesn't fail before the commit. But it is

Please add the commit hash and its short description.

+-- required to be sure that there are no inconsistencies after the
+-- commit.
+test:samevalues(results, 'consistent behaviour between the JIT and the VM')
+
+test:done(true)
diff --git a/test/tarantool-tests/lj-918-fma-numerical-accuracy.test.lua b/test/tarantool-tests/lj-918-fma-numerical-accuracy.test.lua
new file mode 100644
index 00000000..a3775d6d
--- /dev/null
+++ b/test/tarantool-tests/lj-918-fma-numerical-accuracy.test.lua
@@ -0,0 +1,31 @@
+local tap = require('tap')
+
+-- Test file to demonstrate possible numerical inaccuracy if FMA
+-- optimization takes place.

I suppose we don't need to test FMA itself; instead, we should check
that FMA is actually used when its option is enabled. Right? If so,
I would merge lj-918-fma-numerical-accuracy.test.lua and
lj-918-fma-optimization.test.lua into a single test.

+-- XXX: The JIT consistency is checked in the
+-- <lj-918-fma-numerical-accuracy-jit.test.lua>.
+-- See also: https://github.com/LuaJIT/LuaJIT/issues/918.
+local test = tap.test('lj-918-fma-numerical-accuracy')
+
+test:plan(2)
+
+local _2pow52 = 2 ^ 52
+
+-- IEEE754 components to double:
+-- sign * (2 ^ (exp - 1023)) * (mantissa / _2pow52 + normal).
+local a = 1 * (2 ^ (1083 - 1023)) * (4080546448249347 / _2pow52 + 1)
+assert(a == 2197541395358679800)
+
+local b = -1 * (2 ^ (1052 - 1023)) * (3927497732209973 / _2pow52 + 1)
+assert(b == -1005065126.3690554)

The same questions as above.

+
+-- These tests fail on ARM64 before the patch or with FMA
+-- optimization enabled.
+-- The first test may not fail if the compiler doesn't generate
+-- an ARM64 FMA operation in `lj_vm_foldarith()`.
+test:is(2197541395358679800 % -1005065126.3690554, -606337536,
+        'FMA in the lj_vm_foldarith() during parsing')
+
+test:is(a % b, -606337536, 'FMA in the VM')
+
+test:done(true)
diff --git a/test/tarantool-tests/lj-918-fma-optimization.test.lua b/test/tarantool-tests/lj-918-fma-optimization.test.lua
new file mode 100644
index 00000000..af749eb5
--- /dev/null
+++ b/test/tarantool-tests/lj-918-fma-optimization.test.lua
@@ -0,0 +1,25 @@
+local tap = require('tap')
+local test = tap.test('lj-918-fma-optimization'):skipcond({
+  ['Test requires JIT enabled'] = not jit.status(),
+})
+
+test:plan(3)
+
+local function jit_opt_is_on(needed)

Why `needed` and not something like `flag`?

+  for _, opt in ipairs({jit.status()}) do
+    if opt == needed then
+      return true
+    end
+  end
+  return false
+end
+
+test:ok(not jit_opt_is_on('fma'), 'FMA is disabled by default')
+
+local ok, _ = pcall(jit.opt.start, '+fma')
+
+test:ok(ok, 'fma flag is recognized')
+
+test:ok(jit_opt_is_on('fma'), 'FMA is enabled after jit.opt.start()')
+
+test:done(true)
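
On the 'fma flag is recognized' check: it presumably relies on the new name being appended to `JIT_F_OPTSTRING` at the same position as its bit, `JIT_F_OPT << 10` ("Must match the order above"). Below is a minimal C sketch of such a lookup. It assumes only the length-prefixed layout shown in the lj_jit.h hunk (a length byte, then the name). `OPTSTRING` and `opt_index()` are hypothetical and are not LuaJIT's own parser.

```c
#include <assert.h>
#include <string.h>

/* Copy of the patched flag-name list: each entry is a length byte
 * followed by the name, in the same order as the JIT_F_OPT_* bits. */
#define OPTSTRING \
  "\4fold\3cse\3dce\3fwd\3dse\6narrow\4loop\3abc\4sink\4fuse\3fma"

/* Hypothetical helper: returns the bit offset of `name` relative to
 * JIT_F_OPT (e.g. "fma" -> 10, i.e. JIT_F_OPT << 10), or -1 if the
 * name is not in the list. */
static int opt_index(const char *name)
{
  const char *lst = OPTSTRING;
  size_t n = strlen(name);
  int idx;
  /* The list ends at the string's terminating NUL byte. */
  for (idx = 0; *lst != '\0'; idx++) {
    size_t len = (unsigned char)*lst;
    if (len == n && memcmp(name, lst + 1, len) == 0)
      return idx;
    lst += 1 + len; /* Skip the length byte and the name. */
  }
  return -1;
}
```

Because the position in the string doubles as the shift amount, appending `\3fma` right after `\4fuse` is what keeps the name in sync with `JIT_F_OPT_FMA`.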