Score:2

Barrett reduction to get 64-bit remainder of a 128-bit number

ru flag

On GitHub there's this code, part of Microsoft's SEAL:

SEAL_ITERATE(iter(operand1, operand2, result), coeff_count, [&](auto I) {
    // Reduces z using base 2^64 Barrett reduction
    unsigned long long z[2], tmp1, tmp2[2], tmp3, carry;
    multiply_uint64(get<0>(I), get<1>(I), z);

    // Multiply input and const_ratio
    // Round 1
    multiply_uint64_hw64(z[0], const_ratio_0, &carry);
    multiply_uint64(z[0], const_ratio_1, tmp2);
    tmp3 = tmp2[1] + add_uint64(tmp2[0], carry, &tmp1);

    // Round 2
    multiply_uint64(z[1], const_ratio_0, tmp2);
    carry = tmp2[1] + add_uint64(tmp1, tmp2[0], &tmp1);

    // This is all we care about
    tmp1 = z[1] * const_ratio_1 + tmp3 + carry;

    // Barrett subtraction
    tmp3 = z[0] - tmp1 * modulus_value;

    // Claim: One more subtraction is enough
    get<2>(I) = SEAL_COND_SELECT(tmp3 >= modulus_value, tmp3 - modulus_value, tmp3);
});

which is supposed to perform Barrett reduction, a technique for computing a remainder modulo some number without division.

It looks like multiply_uint64_hw64 multiplies two 64-bit numbers and keeps only the 64 most significant bits of the product, while multiply_uint64 returns the full 128-bit product as two 64-bit numbers. However, I don't understand what's being done and, most importantly, where does

$$a-\left\lfloor a\,s\right\rfloor\,n$$

happen. There's not even a floor function in this code.

fgrieu avatar
ng flag
The code multiplies inputs `get<0>(I)` and `get<1>(I)` into `z`, then reduces `z` modulo constant `modulus_value`$=n$, with the outcome `get<2>(I)`. The $s=1/n$ in the linked wiki on Barrett reduction is scaled to $r=\lfloor2^{128}/n\rfloor$, precomputed externally as `const_ratio_0` and `const_ratio_1` (low and high 64-bit limbs). If something remains mysterious, please pinpoint it; and preferably transcribe what you understand (including my hints) in the question, giving variables nice and consistent names with e.g. $r_0+2^{64}r_1=r$ for `const_ratio`, same for `z` and `tmp2`.
ru flag
@fgrieu I basically didn't understand the separation into most and least significant bits. How can multiplication work with upper and bottom parts? I understood how `const_ratio` is broken into 2 pieces but not how things are done after that
Maarten Bodewes avatar
in flag
Maybe I'm too simple here, but with integer operations you don't need a floor, as rounding down is the same as forgetting everything behind the comma.
Score:3
ng flag

The question's code computes the 64-bit get<2>(I)=$h:=f\,g\bmod n$ from inputs:

  • 64-bit modulus_value=$n$ with $n\in[2,\,2^{63}]$
  • 64-bit get<0>(I)=$f$ with $f\in[0,\,2^{64}-1]$
  • 64-bit get<1>(I)=$g$ with $g\in[0,\,2^{64}-1]$ and $f\,g<2^{64}\,n$, a condition that's met if $f,g\in[0,\,n-1]$ (which I guess is always the case in the application).

The result $h$ is the remainder of the Euclidean division of $f\,g$ by $n$. It is mathematically defined by $0\le h<n$ and $\exists q\in\mathbb Z,\ f\,g=q\cdot n+h$.

The code first computes the 128-bit z$=z:=f\,g$, then the result get<2>(I)$=h:=z\bmod n$ by Barrett reduction:

  • It was precomputed (externally to the question's code) const_ratio$=r=\left\lfloor2^{128}/n\right\rfloor$
  • $\hat q:=\left\lfloor z\,r/2^{128}\right\rfloor$, which is either the correct $q$ or $q-1$, never more (note: $\hat q$ is the final value of variable tmp1).
  • $\hat h:=z-\hat q\cdot n$, which is either the correct $h$ or $h+n$ (note: $\hat h$ is the final value of variable tmp3).
  • $h:=\hat h-n$ when $\hat h\ge n$, or $h:=\hat h$ otherwise.
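
The four steps above can be sketched at a smaller scale, where everything fits ordinary integer types. The following is a minimal illustration and not SEAL's code: it assumes 16-bit limbs instead of 64-bit ones, so that $r=\left\lfloor2^{32}/n\right\rfloor$ and the whole product $z\,r$ fits a `uint64_t`. The name `barrett16` is made up for this example.

```cpp
#include <cassert>
#include <cstdint>

// Scaled-down Barrett reduction with 16-bit limbs (illustrative only).
// Preconditions mirror the answer's: 2 <= n <= 2^15 and f*g < 2^16 * n.
inline std::uint16_t barrett16(std::uint16_t f, std::uint16_t g, std::uint16_t n) {
    assert(n >= 2 && n <= (1u << 15));
    std::uint32_t z = std::uint32_t(f) * g;  // two-limb product
    assert(z < (std::uint64_t(n) << 16));
    // r = floor(2^32 / n); in real use this would be precomputed once per modulus.
    std::uint32_t r = std::uint32_t((std::uint64_t(1) << 32) / n);
    // q_hat = floor(z * r / 2^32): the right shift is the floor.
    std::uint32_t q_hat = std::uint32_t((std::uint64_t(z) * r) >> 32);
    std::uint32_t h_hat = z - q_hat * n;  // lies in [0, 2n)
    return std::uint16_t(h_hat >= n ? h_hat - n : h_hat);  // one conditional subtraction
}
```

For instance, `barrett16(1234, 5678, 9973)` should agree with `(1234u * 5678u) % 9973u`, without any division by `n` once `r` is known.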

The code uses primary school algorithms to perform multiple-digit arithmetic, transposed from base $10$ to base $2^{64}$ (with optimizations and a variant detailed in the next section). The equivalent of digits are so-called limbs, here 64-bit.

The product z$=z$ is expressed as two limbs z[0]$=z_0$ (low-order), z[1]$=z_1$ (high-order), thus with $z=z_0+2^{64}\,z_1$, and $z_0, z_1\in[0,\,2^{64}-1]$.

The Barrett multiplier const_ratio$=r=\left\lfloor2^{128}/n\right\rfloor$ is similarly expressed as two limbs const_ratio_0$=r_0$ and const_ratio_1$=r_1$, thanks to the low end of the interval in the precondition $n\in[2,\,2^{63}]$.

The intermediary product $z\,r$ always fits three limbs (even though in general the product of two quantities expressed as two limbs would require four limbs).

The tentative quotient $\hat q$, which approximates $q$ from below, fits a single limb, since $\hat q\le q$ and $q<2^{64}$, the latter being ensured by the input precondition $f\,g<2^{64}\,n$.

The tentative remainder $\hat h$ fits a single limb, since $\hat h<2n$ and $2n\le2^{64}$, the latter thanks to the high end of the interval in the precondition $n\in[2,\,2^{63}]$.
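
The bounds $\hat q\le q$ and $\hat h<2n$ used above follow from a short estimate (a sketch, with $q=\left\lfloor z/n\right\rfloor$). Since $2^{128}/n-1<r\le2^{128}/n$ and $0\le z<2^{128}$,

$$\frac zn-1\;<\;\frac zn-\frac z{2^{128}}\;\le\;\frac{z\,r}{2^{128}}\;\le\;\frac zn$$

Taking floors gives $q-1\le\hat q\le q$, hence $\hat h=z-\hat q\,n=h+(q-\hat q)\,n\in\{h,\;h+n\}$ and $0\le\hat h<2n$.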


Detailing the code's algorithm (as asked in comment and bounty):

I'll use an illustration in decimal. In that base, since $2\le n\le10/2$, const_ratio$=r$ can only be $\left\lfloor100/2\right\rfloor=50$, $\left\lfloor100/3\right\rfloor=33$, $\left\lfloor100/4\right\rfloor=25$, $\left\lfloor100/5\right\rfloor=20$, but I'll pretend const_ratio$=r=29$ because that makes a more interesting example. For the same reason I'll use z$=z=34$, even though that can't be obtained as the product of two digits.

The product z is obtained in the code by multiply_uint64(get<0>(I), get<1>(I), z) as two limbs z[0] and z[1].

The meat of the computation is $\hat q:=\left\lfloor z\,r/2^{128}\right\rfloor$. That's the analog in base $2^{64}$ of $9:=\left\lfloor29\cdot34/100\right\rfloor$ in base 10. Both arguments $29$ and $34$ to the multiplication are two-digit, but small enough that their product $986$ is three-digit (rather than four), and we are only interested in the third digit from the right. The primary school algorithm to compute $986:=29\cdot34$ would be presented as

      2 9   const_ratio
   x  3 4   z
    -----
    1 1 6
+   8 7
  -------
    9 8 6

In the primary school algorithm there are four single-digit multiplications (which the code performs) and a few extra operations (that the code reorganizes slightly):

  • 4 times 9, 36; write 6, keep 3;
  • 4 times 2, 8; plus 3 (kept), 11; write that.
  • 3 times 9, 27; write 7, keep 2;
  • 3 times 2, 6; plus 2 (kept), 8; write that.

The first of these four multiplications occurs in the code fragment multiply_uint64_hw64(z[0], const_ratio_0, &carry), which multiplies the low-order limb of $r$ with the low-order limb of $z$, like we multiply the low-order digit 4 of 34 with the low-order digit 9 of 29. Notice that "write 6" is pointless in the circumstance: whatever digit it writes stays confined to the rightmost column of the computation, with no opportunity to influence any digit to its left, and is ignored when we divide by 100 and round down (equivalently, keep only the third digit from the right). That's why the low-order 64 bits of the 128-bit product are not even computed, as noted in the question. The equivalent of 3 in 36 is kept in carry.
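
For readers without the SEAL sources at hand, the two multiplication helpers can be pictured as below. This is only a sketch of their behavior as described in the question, assuming a compiler that provides `unsigned __int128` (GCC and Clang do; it is not standard C++). The `_sketch` names are made up, and SEAL's own implementations may differ.

```cpp
#include <cassert>
#include <cstdint>

// Full 64x64 -> 128-bit product, stored as two 64-bit limbs (low limb first).
inline void multiply_uint64_sketch(std::uint64_t a, std::uint64_t b, std::uint64_t *out) {
    assert(out != nullptr);
    unsigned __int128 p = static_cast<unsigned __int128>(a) * b;
    out[0] = static_cast<std::uint64_t>(p);        // low-order limb
    out[1] = static_cast<std::uint64_t>(p >> 64);  // high-order limb
}

// Same product, but only the high-order limb is kept.
inline void multiply_uint64_hw64_sketch(std::uint64_t a, std::uint64_t b, std::uint64_t *hw64) {
    assert(hw64 != nullptr);
    *hw64 = static_cast<std::uint64_t>((static_cast<unsigned __int128>(a) * b) >> 64);
}
```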

The second multiplication occurs in multiply_uint64(z[0], const_ratio_1, tmp2), which multiplies the high-order limb of $r$ with the low-order limb of $z$, with the result in the two limbs of tmp2: the 64-bit tmp2[0] receives the equivalent of 8 in 08, and tmp2[1] receives the equivalent of 0 in 08 (notice a leading 0 is suppressed in the conventional writing of decimal integers). The equivalent of 8 plus 3 occurs in add_uint64(tmp2[0], carry, &tmp1), with the low-order digit 1 of the result 11 in tmp1, and the new carry 1 in the output of that function. That output is used as the right operand in tmp3 = tmp2[1] + … (an addition that happens to be skipped in the primary school algorithm with the particular example I took, since the 0 was suppressed), yielding the equivalent of the left 1 in 116. [Note on the output of add_uint64: it's generated by static_cast<unsigned char>(*result < operand1), which compares *result and operand1, then turns true to 1, false to 0. Made after *result = operand1 + operand2, that tells if this addition generated a carry. Some compilers recognize this idiom, use the C bit of the status word, and reuse C in the forthcoming addition].
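
The carry-detection idiom from the note can be isolated as follows. This is a sketch reconstructed from that description, not a copy of SEAL's source; the name `add_uint64_sketch` is made up.

```cpp
#include <cassert>
#include <cstdint>

// Adds two 64-bit limbs; stores the low 64 bits of the sum in *result
// and returns the carry out (0 or 1).
inline unsigned char add_uint64_sketch(std::uint64_t operand1, std::uint64_t operand2,
                                       std::uint64_t *result) {
    assert(result != nullptr);
    *result = operand1 + operand2;  // unsigned arithmetic wraps modulo 2^64
    // The sum wrapped around exactly when it ended up smaller than an operand.
    return static_cast<unsigned char>(*result < operand1);
}
```

The decimal analog: 8 + 3 yields the digit 1, and 1 < 8 reveals that a carry occurred.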

The third multiplication occurs in multiply_uint64(z[1], const_ratio_0, tmp2), which multiplies the low-order limb of $r$ with the high-order limb of $z$, with the result as two limbs in tmp2, like we do 3 x 9 = 27. This time we need both limbs/digits: the equivalent of 7 goes to tmp2[0] and the equivalent of 2 goes to tmp2[1]. Here the code uses a variant of the primary school algorithm: it immediately adds tmp1 (the equivalent of the middle 1 in 116) to the low-order limb with add_uint64(tmp1, tmp2[0], &tmp1), performing the equivalent of 1 + 7 = 8, no carry. The result 8 is stored in tmp1 because add_uint64 needs a destination, but it's really ignored, because we don't care for the middle digit in 986. The carry output by add_uint64 is used as the right operand in carry = tmp2[1] + …, performing the equivalent of 2 + 0 = 2 in our example. Despite the name carry, that holds a full-blown 64-bit limb/digit.

The fourth multiplication occurs in z[1] * const_ratio_1, which multiplies the high-order limb of $r$ with the high-order limb of $z$, like we do 3 x 2 = 6. Here the context ensures the result fits a single limb, thus the native C operator for multiplication can be used. The outcome is then used as the left operand of … + tmp3 + carry, performing the equivalent of 6 + 1 + 2 = 9. Again the context ensures this value $\hat q$, stored in tmp1, fits a single limb/digit.
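
The decimal illustration can be replayed mechanically. The sketch below mirrors the two rounds of the question's code with decimal digits in place of 64-bit limbs; the function name `quotient_digit` is made up, and the variable names are only borrowed from the code to show the correspondence.

```cpp
#include <cassert>

// Base-10 mirror of the code's two rounds: each argument is one decimal digit,
// with z = 10*z1 + z0 and r = 10*r1 + r0. Returns floor(z*r / 100),
// assuming (as in the text) that this quotient fits a single digit.
inline unsigned quotient_digit(unsigned z0, unsigned z1, unsigned r0, unsigned r1) {
    assert(z0 < 10 && z1 < 10 && r0 < 10 && r1 < 10);
    // Round 1
    unsigned carry = (z0 * r0) / 10;    // like multiply_uint64_hw64: high digit only
    unsigned p = z0 * r1;               // like multiply_uint64 into tmp2
    unsigned sum = p % 10 + carry;      // like add_uint64(tmp2[0], carry, &tmp1)
    unsigned tmp1 = sum % 10;
    unsigned tmp3 = p / 10 + sum / 10;  // tmp2[1] + carry out
    // Round 2
    p = z1 * r0;
    carry = p / 10 + (tmp1 + p % 10) / 10;  // tmp2[1] + carry out; the sum digit is dropped
    // This is all we care about
    return z1 * r1 + tmp3 + carry;
}
```

With z = 34 and r = 29, `quotient_digit(4, 3, 9, 2)` returns 9, matching $\left\lfloor29\cdot34/100\right\rfloor$.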

Then tmp3 = z[0] - tmp1 * modulus_value performs $\hat h:=z-\hat q\cdot n$. The context ensures the mathematically exact result fits a single limb/digit (stored in tmp3) even though $\hat q\cdot n$ does not. This allows the use of the native C operators, which skip computing the high-order limb entirely.
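
Why dropping the high-order limb is harmless can be written out in one line, using $z=z_0+2^{64}\,z_1$ and the bound $0\le\hat h<2n\le2^{64}$:

$$\hat h\;=\;\hat h\bmod2^{64}\;=\;(z-\hat q\,n)\bmod2^{64}\;=\;(z_0-\hat q\,n)\bmod2^{64}$$

The right-hand side is what unsigned 64-bit arithmetic, which wraps modulo $2^{64}$, computes for z[0] - tmp1 * modulus_value.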

Then SEAL_COND_SELECT(tmp3 >= modulus_value, tmp3 - modulus_value, tmp3) computes $h$ from $\hat h$ by conditionally subtracting $n$ when $\hat h\ge n$. The selection operator is hidden in a macro.
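
As a sketch of this last step, assuming the macro behaves like the ternary operator `cond ? a : b` (its actual definition is not shown in the question), the correction can be written in either of the two forms below. Both function names are made up; whether a compiler emits a branch for either form is not guaranteed.

```cpp
#include <cassert>
#include <cstdint>

// Final correction step: maps h_hat in [0, 2n) to h in [0, n).
inline std::uint64_t conditional_subtract(std::uint64_t h_hat, std::uint64_t n) {
    assert(n >= 1);
    return h_hat >= n ? h_hat - n : h_hat;
}

// Same result using a mask: all ones when h_hat >= n, all zeros otherwise.
inline std::uint64_t conditional_subtract_masked(std::uint64_t h_hat, std::uint64_t n) {
    assert(n >= 1);
    std::uint64_t mask = std::uint64_t(0) - std::uint64_t(h_hat >= n);
    return h_hat - (n & mask);
}
```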

Two examples for base $2^{64}$ (values in big-endian hexadecimal with space inserted between limbs):

modulus_value                      000076513ae0b1cd
const_ratio       00000000000229e6 7f4ca82ba3a115f1
get<0>(I)                          00005f0fd669f2c7
get<1>(I)                          000041a1f91ef16f
z                 00000000185f2ae8 a455846cb7cf9b49
tmp1                               000034bb854f9a8d
tmp3                               00000fcebfd55b60
get<2>(I)                          00000fcebfd55b60

modulus_value                      686f4b7702a9c775
const_ratio       0000000000000002 7387d66ffd685b82
get<0>(I)                          536094611fa2b19b
get<1>(I)                          675ef5187093ff63
z                 21aac8fcf31d6421 62e675ba16d513f1
tmp1                               5287278703394bb1
tmp3                               72b1d3d2b9f5e50c
get<2>(I)                          0a42885bb74c1d97
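
To experiment with such examples, the whole computation can be reproduced outside SEAL. The following is a self-contained sketch and not the library's code: it assumes a compiler with `unsigned __int128` (GCC or Clang), recomputes const_ratio on each call instead of precomputing it, and uses made-up names (`barrett_ratio`, `mulmod_barrett`, `mulmod_reference`).

```cpp
#include <cassert>
#include <cstdint>

using u64 = std::uint64_t;
using u128 = unsigned __int128;  // GCC/Clang extension, assumed available

// r = floor(2^128 / n) for n >= 2, without ever representing 2^128.
inline u128 barrett_ratio(u64 n) {
    assert(n >= 2);
    u128 max = ~u128(0);  // 2^128 - 1
    u128 r = max / n;
    if (max % n == u128(n) - 1) ++r;  // n divides 2^128 exactly
    return r;
}

// f*g mod n, following the question's code limb by limb.
// Assumes 2 <= n <= 2^63 and f*g < 2^64 * n.
inline u64 mulmod_barrett(u64 f, u64 g, u64 n) {
    u128 r = barrett_ratio(n);
    u64 r0 = u64(r), r1 = u64(r >> 64);  // const_ratio_0, const_ratio_1
    u128 z = u128(f) * g;
    u64 z0 = u64(z), z1 = u64(z >> 64);

    // Round 1
    u64 carry = u64((u128(z0) * r0) >> 64);  // high limb of z0*r0
    u128 t = u128(z0) * r1;
    u64 tmp1 = u64(t) + carry;
    u64 tmp3 = u64(t >> 64) + (tmp1 < carry);  // plus carry out of the addition
    // Round 2
    t = u128(z1) * r0;
    u64 low = u64(t);
    tmp1 += low;
    carry = u64(t >> 64) + (tmp1 < low);  // plus carry out of the addition
    // Tentative quotient, then Barrett subtraction in wrapping 64-bit arithmetic
    u64 q_hat = z1 * r1 + tmp3 + carry;
    u64 h_hat = z0 - q_hat * n;
    return h_hat >= n ? h_hat - n : h_hat;
}

// Reference result using the compiler's 128-bit remainder.
inline u64 mulmod_reference(u64 f, u64 g, u64 n) {
    return u64(u128(f) * g % n);
}
```

Comparing `mulmod_barrett` with `mulmod_reference` on the inputs of the two examples is a quick way to cross-check the limb-by-limb logic.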

Note: for $n\in[2^{63}+1,\,2^{64}-1]$, the quantity $\hat h$ can overflow one limb and the code as it stands fails. E.g. for $n=2^{64}-2$ (thus $r=2^{64}+2$) and inputs $f=2^{64}-4$, $g=2^{64}-3$, we get $z=n\,(n-3)+2$, thus $q=n-3$ and the true $h=2$; but $\hat q=q-1$, thus $\hat h=h+n=2^{64}$, which wraps to $0$ in one limb, and the output is $0$ rather than the true $h=2$. The full source has a comment "the Modulus class represents a non-negative integer modulus up to 61 bits", thus such issues for large $n$ occur only when the calling code errs. Plus, if I understand correctly, $n=2^{60}-2^{14}+1$ is the main target.


Alternative: For something performing the same function as the question's code, in 4 short lines of code instead of 11, for all $n\in[1,\,2^{64}-1]$, needing no precomputation, possibly faster, but compatible only with recent x64 compilers+CPUs, see the first of these code snippets (the second is a small variant without the restriction $f\,g<2^{64}\,n$ ). I make no statement about constant-timeness of either code.

ru flag
Why does `add_uint64` always return $0$ or $1$? Isn't there bigger carry possibility?
fgrieu avatar
ng flag
@Guerlando OCs: `add_uint64` adds two limbs passed as first and second arguments, thus two quantities in $[0,\ 2^{64}-1]$. The mathematically exact result thus is in $[0,\ 2^{65}-2]$, thus fits one 64-bit limb and one bit. That's similar to the sum of two decimal digits fitting one digit and a carry 0 (e.g. 4+5=9) or 1 (e.g. 9+9=18).