Skip to content

Commit 8dc28d7

Browse files
committed
Optimise numeric multiplication using base-NBASE^2 arithmetic.
Currently mul_var() uses the schoolbook multiplication algorithm, which is O(n^2) in the number of NBASE digits. To improve performance for large inputs, convert the inputs to base NBASE^2 before multiplying, which effectively halves the number of digits in each input, theoretically speeding up the computation by a factor of 4. In practice, the actual speedup for large inputs varies between around 3 and 6 times, depending on the system and compiler used. In turn, this significantly reduces the runtime of the numeric_big regression test. For this to work, 64-bit integers are required for the products of base-NBASE^2 digits, so this works best on 64-bit machines, on which it is faster whenever the shorter input has more than 4 or 5 NBASE digits. On 32-bit machines, the additional overheads, especially during carry propagation and the final conversion back to base-NBASE, are significantly higher, and it is only faster when the shorter input has more than around 50 NBASE digits. When the shorter input has more than 6 NBASE digits (so that mul_var_short() cannot be used), but fewer than around 50 NBASE digits, there may be a noticeable slowdown on 32-bit machines. That seems to be an acceptable tradeoff, given the performance gains for other inputs, and the effort that would be required to maintain code specifically targeting 32-bit machines. Joel Jacobson and Dean Rasheed. Discussion: https://postgr.es/m/9d8a4a42-c354-41f3-bbf3-199e1957db97%40app.fastmail.com
1 parent c4e4422 commit 8dc28d7

File tree

1 file changed

+150
-74
lines changed

1 file changed

+150
-74
lines changed

src/backend/utils/adt/numeric.c

Lines changed: 150 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ typedef signed char NumericDigit;
101101
typedef int16 NumericDigit;
102102
#endif
103103

104+
#define NBASE_SQR (NBASE * NBASE)
105+
104106
/*
105107
* The Numeric type as stored on disk.
106108
*
@@ -8668,21 +8670,30 @@ mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result,
86688670
int rscale)
86698671
{
86708672
int res_ndigits;
8673+
int res_ndigitpairs;
86718674
int res_sign;
86728675
int res_weight;
8676+
int pair_offset;
86738677
int maxdigits;
8674-
int *dig;
8675-
int carry;
8676-
int maxdig;
8677-
int newdig;
8678+
int maxdigitpairs;
8679+
uint64 *dig,
8680+
*dig_i1_off;
8681+
uint64 maxdig;
8682+
uint64 carry;
8683+
uint64 newdig;
86788684
int var1ndigits;
86798685
int var2ndigits;
8686+
int var1ndigitpairs;
8687+
int var2ndigitpairs;
86808688
NumericDigit *var1digits;
86818689
NumericDigit *var2digits;
8690+
uint32 var1digitpair;
8691+
uint32 *var2digitpairs;
86828692
NumericDigit *res_digits;
86838693
int i,
86848694
i1,
8685-
i2;
8695+
i2,
8696+
i2limit;
86868697

86878698
/*
86888699
* Arrange for var1 to be the shorter of the two numbers. This improves
@@ -8723,137 +8734,202 @@ mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result,
87238734
return;
87248735
}
87258736

8726-
/* Determine result sign and (maximum possible) weight */
8737+
/* Determine result sign */
87278738
if (var1->sign == var2->sign)
87288739
res_sign = NUMERIC_POS;
87298740
else
87308741
res_sign = NUMERIC_NEG;
8731-
res_weight = var1->weight + var2->weight + 2;
87328742

87338743
/*
8734-
* Determine the number of result digits to compute. If the exact result
8735-
* would have more than rscale fractional digits, truncate the computation
8736-
* with MUL_GUARD_DIGITS guard digits, i.e., ignore input digits that
8737-
* would only contribute to the right of that. (This will give the exact
8744+
* Determine the number of result digits to compute and the (maximum
8745+
* possible) result weight. If the exact result would have more than
8746+
* rscale fractional digits, truncate the computation with
8747+
* MUL_GUARD_DIGITS guard digits, i.e., ignore input digits that would
8748+
* only contribute to the right of that. (This will give the exact
87388749
* rounded-to-rscale answer unless carries out of the ignored positions
87398750
* would have propagated through more than MUL_GUARD_DIGITS digits.)
87408751
*
87418752
* Note: an exact computation could not produce more than var1ndigits +
8742-
* var2ndigits digits, but we allocate one extra output digit in case
8743-
* rscale-driven rounding produces a carry out of the highest exact digit.
8753+
* var2ndigits digits, but we allocate at least one extra output digit in
8754+
* case rscale-driven rounding produces a carry out of the highest exact
8755+
* digit.
8756+
*
8757+
* The computation itself is done using base-NBASE^2 arithmetic, so we
8758+
* actually process the input digits in pairs, producing a base-NBASE^2
8759+
* intermediate result. This significantly improves performance, since
8760+
* schoolbook multiplication is O(N^2) in the number of input digits, and
8761+
* working in base NBASE^2 effectively halves "N".
8762+
*
8763+
* Note: in a truncated computation, we must compute at least one extra
8764+
* output digit to ensure that all the guard digits are fully computed.
87448765
*/
8745-
res_ndigits = var1ndigits + var2ndigits + 1;
8766+
/* digit pairs in each input */
8767+
var1ndigitpairs = (var1ndigits + 1) / 2;
8768+
var2ndigitpairs = (var2ndigits + 1) / 2;
8769+
8770+
/* digits in exact result */
8771+
res_ndigits = var1ndigits + var2ndigits;
8772+
8773+
/* digit pairs in exact result with at least one extra output digit */
8774+
res_ndigitpairs = res_ndigits / 2 + 1;
8775+
8776+
/* pair offset to align result to end of dig[] */
8777+
pair_offset = res_ndigitpairs - var1ndigitpairs - var2ndigitpairs + 1;
8778+
8779+
/* maximum possible result weight (odd-length inputs shifted up below) */
8780+
res_weight = var1->weight + var2->weight + 1 + 2 * res_ndigitpairs -
8781+
res_ndigits - (var1ndigits & 1) - (var2ndigits & 1);
8782+
8783+
/* rscale-based truncation with at least one extra output digit */
87468784
maxdigits = res_weight + 1 + (rscale + DEC_DIGITS - 1) / DEC_DIGITS +
87478785
MUL_GUARD_DIGITS;
8748-
res_ndigits = Min(res_ndigits, maxdigits);
8786+
maxdigitpairs = maxdigits / 2 + 1;
8787+
8788+
res_ndigitpairs = Min(res_ndigitpairs, maxdigitpairs);
8789+
res_ndigits = 2 * res_ndigitpairs;
87498790

8750-
if (res_ndigits < 3)
8791+
/*
8792+
* In the computation below, digit pair i1 of var1 and digit pair i2 of
8793+
* var2 are multiplied and added to digit i1+i2+pair_offset of dig[]. Thus
8794+
* input digit pairs with index >= res_ndigitpairs - pair_offset don't
8795+
* contribute to the result, and can be ignored.
8796+
*/
8797+
if (res_ndigitpairs <= pair_offset)
87518798
{
87528799
/* All input digits will be ignored; so result is zero */
87538800
zero_var(result);
87548801
result->dscale = rscale;
87558802
return;
87568803
}
8804+
var1ndigitpairs = Min(var1ndigitpairs, res_ndigitpairs - pair_offset);
8805+
var2ndigitpairs = Min(var2ndigitpairs, res_ndigitpairs - pair_offset);
87578806

87588807
/*
8759-
* We do the arithmetic in an array "dig[]" of signed int's. Since
8760-
* INT_MAX is noticeably larger than NBASE*NBASE, this gives us headroom
8761-
* to avoid normalizing carries immediately.
8808+
* We do the arithmetic in an array "dig[]" of unsigned 64-bit integers.
8809+
* Since PG_UINT64_MAX is much larger than NBASE^4, this gives us a lot of
8810+
* headroom to avoid normalizing carries immediately.
87628811
*
87638812
* maxdig tracks the maximum possible value of any dig[] entry; when this
8764-
* threatens to exceed INT_MAX, we take the time to propagate carries.
8765-
* Furthermore, we need to ensure that overflow doesn't occur during the
8766-
* carry propagation passes either. The carry values could be as much as
8767-
* INT_MAX/NBASE, so really we must normalize when digits threaten to
8768-
* exceed INT_MAX - INT_MAX/NBASE.
8813+
* threatens to exceed PG_UINT64_MAX, we take the time to propagate
8814+
* carries. Furthermore, we need to ensure that overflow doesn't occur
8815+
* during the carry propagation passes either. The carry values could be
8816+
* as much as PG_UINT64_MAX / NBASE^2, so really we must normalize when
8817+
* digits threaten to exceed PG_UINT64_MAX - PG_UINT64_MAX / NBASE^2.
87698818
*
8770-
* To avoid overflow in maxdig itself, it actually represents the max
8771-
* possible value divided by NBASE-1, ie, at the top of the loop it is
8772-
* known that no dig[] entry exceeds maxdig * (NBASE-1).
8819+
* To avoid overflow in maxdig itself, it actually represents the maximum
8820+
* possible value divided by NBASE^2-1, i.e., at the top of the loop it is
8821+
* known that no dig[] entry exceeds maxdig * (NBASE^2-1).
8822+
*
8823+
* The conversion of var1 to base NBASE^2 is done on the fly, as each new
8824+
* digit is required. The digits of var2 are converted upfront, and
8825+
* stored at the end of dig[]. To avoid loss of precision, the input
8826+
* digits are aligned with the start of digit pair array, effectively
8827+
* shifting them up (multiplying by NBASE) if the inputs have an odd
8828+
* number of NBASE digits.
87738829
*/
8774-
dig = (int *) palloc0(res_ndigits * sizeof(int));
8775-
maxdig = 0;
8830+
dig = (uint64 *) palloc(res_ndigitpairs * sizeof(uint64) +
8831+
var2ndigitpairs * sizeof(uint32));
8832+
8833+
/* convert var2 to base NBASE^2, shifting up if its length is odd */
8834+
var2digitpairs = (uint32 *) (dig + res_ndigitpairs);
8835+
8836+
for (i2 = 0; i2 < var2ndigitpairs - 1; i2++)
8837+
var2digitpairs[i2] = var2digits[2 * i2] * NBASE + var2digits[2 * i2 + 1];
8838+
8839+
if (2 * i2 + 1 < var2ndigits)
8840+
var2digitpairs[i2] = var2digits[2 * i2] * NBASE + var2digits[2 * i2 + 1];
8841+
else
8842+
var2digitpairs[i2] = var2digits[2 * i2] * NBASE;
87768843

87778844
/*
8778-
* The least significant digits of var1 should be ignored if they don't
8779-
* contribute directly to the first res_ndigits digits of the result that
8780-
* we are computing.
8845+
* Start by multiplying var2 by the least significant contributing digit
8846+
* pair from var1, storing the results at the end of dig[], and filling
8847+
* the leading digits with zeros.
87818848
*
8782-
* Digit i1 of var1 and digit i2 of var2 are multiplied and added to digit
8783-
* i1+i2+2 of the accumulator array, so we need only consider digits of
8784-
* var1 for which i1 <= res_ndigits - 3.
8849+
* The loop here is the same as the inner loop below, except that we set
8850+
* the results in dig[], rather than adding to them. This is the
8851+
* performance bottleneck for multiplication, so we want to keep it simple
8852+
* enough so that it can be auto-vectorized. Accordingly, process the
8853+
* digits left-to-right even though schoolbook multiplication would
8854+
* suggest right-to-left. Since we aren't propagating carries in this
8855+
* loop, the order does not matter.
8856+
*/
8857+
i1 = var1ndigitpairs - 1;
8858+
if (2 * i1 + 1 < var1ndigits)
8859+
var1digitpair = var1digits[2 * i1] * NBASE + var1digits[2 * i1 + 1];
8860+
else
8861+
var1digitpair = var1digits[2 * i1] * NBASE;
8862+
maxdig = var1digitpair;
8863+
8864+
i2limit = Min(var2ndigitpairs, res_ndigitpairs - i1 - pair_offset);
8865+
dig_i1_off = &dig[i1 + pair_offset];
8866+
8867+
memset(dig, 0, (i1 + pair_offset) * sizeof(uint64));
8868+
for (i2 = 0; i2 < i2limit; i2++)
8869+
dig_i1_off[i2] = (uint64) var1digitpair * var2digitpairs[i2];
8870+
8871+
/*
8872+
* Next, multiply var2 by the remaining digit pairs from var1, adding the
8873+
* results to dig[] at the appropriate offsets, and normalizing whenever
8874+
* there is a risk of any dig[] entry overflowing.
87858875
*/
8786-
for (i1 = Min(var1ndigits - 1, res_ndigits - 3); i1 >= 0; i1--)
8876+
for (i1 = i1 - 1; i1 >= 0; i1--)
87878877
{
8788-
NumericDigit var1digit = var1digits[i1];
8789-
8790-
if (var1digit == 0)
8878+
var1digitpair = var1digits[2 * i1] * NBASE + var1digits[2 * i1 + 1];
8879+
if (var1digitpair == 0)
87918880
continue;
87928881

87938882
/* Time to normalize? */
8794-
maxdig += var1digit;
8795-
if (maxdig > (INT_MAX - INT_MAX / NBASE) / (NBASE - 1))
8883+
maxdig += var1digitpair;
8884+
if (maxdig > (PG_UINT64_MAX - PG_UINT64_MAX / NBASE_SQR) / (NBASE_SQR - 1))
87968885
{
8797-
/* Yes, do it */
8886+
/* Yes, do it (to base NBASE^2) */
87988887
carry = 0;
8799-
for (i = res_ndigits - 1; i >= 0; i--)
8888+
for (i = res_ndigitpairs - 1; i >= 0; i--)
88008889
{
88018890
newdig = dig[i] + carry;
8802-
if (newdig >= NBASE)
8891+
if (newdig >= NBASE_SQR)
88038892
{
8804-
carry = newdig / NBASE;
8805-
newdig -= carry * NBASE;
8893+
carry = newdig / NBASE_SQR;
8894+
newdig -= carry * NBASE_SQR;
88068895
}
88078896
else
88088897
carry = 0;
88098898
dig[i] = newdig;
88108899
}
88118900
Assert(carry == 0);
88128901
/* Reset maxdig to indicate new worst-case */
8813-
maxdig = 1 + var1digit;
8902+
maxdig = 1 + var1digitpair;
88148903
}
88158904

8816-
/*
8817-
* Add the appropriate multiple of var2 into the accumulator.
8818-
*
8819-
* As above, digits of var2 can be ignored if they don't contribute,
8820-
* so we only include digits for which i1+i2+2 < res_ndigits.
8821-
*
8822-
* This inner loop is the performance bottleneck for multiplication,
8823-
* so we want to keep it simple enough so that it can be
8824-
* auto-vectorized. Accordingly, process the digits left-to-right
8825-
* even though schoolbook multiplication would suggest right-to-left.
8826-
* Since we aren't propagating carries in this loop, the order does
8827-
* not matter.
8828-
*/
8829-
{
8830-
int i2limit = Min(var2ndigits, res_ndigits - i1 - 2);
8831-
int *dig_i1_2 = &dig[i1 + 2];
8905+
/* Multiply and add */
8906+
i2limit = Min(var2ndigitpairs, res_ndigitpairs - i1 - pair_offset);
8907+
dig_i1_off = &dig[i1 + pair_offset];
88328908

8833-
for (i2 = 0; i2 < i2limit; i2++)
8834-
dig_i1_2[i2] += var1digit * var2digits[i2];
8835-
}
8909+
for (i2 = 0; i2 < i2limit; i2++)
8910+
dig_i1_off[i2] += (uint64) var1digitpair * var2digitpairs[i2];
88368911
}
88378912

88388913
/*
8839-
* Now we do a final carry propagation pass to normalize the result, which
8840-
* we combine with storing the result digits into the output. Note that
8841-
* this is still done at full precision w/guard digits.
8914+
* Now we do a final carry propagation pass to normalize back to base
8915+
* NBASE^2, and construct the base-NBASE result digits. Note that this is
8916+
* still done at full precision w/guard digits.
88428917
*/
88438918
alloc_var(result, res_ndigits);
88448919
res_digits = result->digits;
88458920
carry = 0;
8846-
for (i = res_ndigits - 1; i >= 0; i--)
8921+
for (i = res_ndigitpairs - 1; i >= 0; i--)
88478922
{
88488923
newdig = dig[i] + carry;
8849-
if (newdig >= NBASE)
8924+
if (newdig >= NBASE_SQR)
88508925
{
8851-
carry = newdig / NBASE;
8852-
newdig -= carry * NBASE;
8926+
carry = newdig / NBASE_SQR;
8927+
newdig -= carry * NBASE_SQR;
88538928
}
88548929
else
88558930
carry = 0;
8856-
res_digits[i] = newdig;
8931+
res_digits[2 * i + 1] = (NumericDigit) ((uint32) newdig % NBASE);
8932+
res_digits[2 * i] = (NumericDigit) ((uint32) newdig / NBASE);
88578933
}
88588934
Assert(carry == 0);
88598935

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy