There is an existing question "Average of 3 long integers" that is specifically concerned with the efficient computation of the average of three signed integers.
The use of unsigned integers however allows for additional optimizations not applicable to the scenario covered in the previous question. This question is about the efficient computation of the average of three unsigned integers, where the average is rounded towards zero, i.e. in mathematical terms I want to compute ⌊ (a + b + c) / 3 ⌋.
A straightforward way to compute this average is
avg = a / 3 + b / 3 + c / 3 + (a % 3 + b % 3 + c % 3) / 3;
To first order, modern optimizing compilers will transform the divisions into multiplications with a reciprocal plus a shift, and the modulo operations into a back-multiply and a subtraction, where the back-multiply may use a scale_add idiom available on many architectures, e.g. lea on x86_64, add with lsl #n on ARM, iscadd on NVIDIA GPUs.
In trying to optimize the above in a generic fashion suitable for many common platforms, I observe that typically the cost of integer operations is in the relationship logical ≤ (add | sub) ≤ shift ≤ scale_add ≤ mul. Cost here refers to all of latency, throughput limitations, and power consumption. Any such differences become more pronounced when the integer type processed is wider than the native register width, e.g. when processing uint64_t data on a 32-bit processor.
My optimization strategy was therefore to minimize instruction count and replace "expensive" with "cheap" operations where possible, while not increasing register pressure and retaining exploitable parallelism for wide out-of-order processors.
The first observation is that we can reduce a sum of three operands into a sum of two operands by first applying a CSA (carry save adder) that produces a sum value and a carry value, where the carry value has twice the weight of the sum value. The cost of a software-based CSA is five logicals on most processors. Some processors, like NVIDIA GPUs, have a LOP3 instruction that can compute an arbitrary logical expression of three operands in one fell swoop, in which case CSA condenses to two LOP3s (note: I have yet to convince the CUDA compiler to emit those two LOP3s; it currently produces four LOP3s!).
The second observation is that because we are computing the modulo of division by 3, we don't need a back-multiply to compute it. We can instead use dividend % 3 = ((dividend / 3) + dividend) & 3, reducing the modulo to an add plus a logical since we already have the division result. This is an instance of the general algorithm: dividend % (2^n - 1) = ((dividend / (2^n - 1)) + dividend) & (2^n - 1).
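The remainder identity above can be sketched in isolation. This is only an illustration: the function names are mine, and the quotient is assumed to be available already (as it is in the averaging code).

```c
#include <stdint.h>

/* Remainder of x / 3, given the already-computed quotient q = x / 3.
   With x = 3*q + r we get x + q = 4*q + r, so the low two bits are r.
   A wrap of x + q modulo 2^32 does not disturb the low bits. */
uint32_t mod3_from_quotient (uint32_t x, uint32_t q)
{
    return (x + q) & 3;
}

/* Generalization to divisors d = 2^n - 1 with 1 <= n < 32:
   x + q = 2^n * q + r, so the low n bits hold r. */
uint32_t mod_pow2_minus_1 (uint32_t x, unsigned n)
{
    uint32_t d = (UINT32_C(1) << n) - 1;
    uint32_t q = x / d;
    return (x + q) & d;
}
```

The trick only pays off when the quotient is needed anyway; on its own it still contains a division.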
Finally for the division by 3 in the correction term (a % 3 + b % 3 + c % 3) / 3 we don't need the code for generic division by 3. Since the dividend is very small, in [0, 6], we can simplify x / 3 into (3 * x) / 8 which requires just a scale_add plus a shift.
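A small sketch to confirm the range claim for the shortcut division. The helper names are hypothetical and exist only for this check.

```c
#include <stdint.h>

/* x / 3 for small x only: 3/8 is close enough to 1/3 over [0, 6]. */
uint32_t div3_small (uint32_t x)
{
    return (3 * x) >> 3;
}

/* Number of inputs in [0, limit] where the shortcut differs from x / 3. */
unsigned count_div3_small_mismatches (uint32_t limit)
{
    unsigned bad = 0;
    for (uint32_t x = 0; x <= limit; x++) {
        bad += (div3_small (x) != x / 3);
    }
    return bad;
}
```

The shortcut cannot be stretched much further: x = 8 yields 3 instead of 2.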
The code below shows my current work-in-progress. Using Compiler Explorer to check the code generated for various platforms shows the tight code I would expect (when compiled with -O3).
However, in timing the code on my Ivy Bridge x86_64 machine using the Intel 13.x compiler, a flaw became apparent: while my code improves latency (from 18 cycles to 15 cycles for uint64_t data) compared to the simple version, throughput worsens (from one result every 6.8 cycles to one result every 8.5 cycles for uint64_t data). Looking at the assembly code more closely it is quite apparent why that is: I basically managed to take the code down from roughly three-way parallelism to roughly two-way parallelism.
Is there a generically applicable optimization technique, beneficial on common processors in particular all flavors of x86 and ARM as well as GPUs, that preserves more parallelism? Alternatively, is there an optimization technique that further reduces overall operation count to make up for reduced parallelism? The computation of the correction term (tail in the code below) seems like a good target. The simplification (carry_mod_3 + sum_mod_3) / 2 looked enticing but delivers an incorrect result for one of the nine possible combinations.
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#define BENCHMARK (1)
#define SIMPLE_COMPUTATION (0)
#if BENCHMARK
#define T uint64_t
#else // !BENCHMARK
#define T uint8_t
#endif // BENCHMARK
T average_of_3 (T a, T b, T c)
{
T avg;
#if SIMPLE_COMPUTATION
avg = a / 3 + b / 3 + c / 3 + (a % 3 + b % 3 + c % 3) / 3;
#else // !SIMPLE_COMPUTATION
/* carry save adder */
T a_xor_b = a ^ b;
T sum = a_xor_b ^ c;
T carry = (a_xor_b & c) | (a & b);
/* here 2 * carry + sum = a + b + c */
T sum_div_3 = (sum / 3); // {MUL|MULHI}, SHR
T sum_mod_3 = (sum + sum_div_3) & 3; // ADD, AND
if (sizeof (size_t) == sizeof (T)) { // "native precision" (well, not always)
T two_carry_div_3 = (carry / 3) * 2; // MULHI, ANDN
T two_carry_mod_3 = (2 * carry + two_carry_div_3) & 6; // SCALE_ADD, AND
T head = two_carry_div_3 + sum_div_3; // ADD
T tail = (3 * (two_carry_mod_3 + sum_mod_3)) / 8; // ADD, SCALE_ADD, SHR
avg = head + tail; // ADD
} else {
T carry_div_3 = (carry / 3); // MUL, SHR
T carry_mod_3 = (carry + carry_div_3) & 3; // ADD, AND
T head = (2 * carry_div_3 + sum_div_3); // SCALE_ADD
T tail = (3 * (2 * carry_mod_3 + sum_mod_3)) / 8; // SCALE_ADD, SCALE_ADD, SHR
avg = head + tail; // ADD
}
#endif // SIMPLE_COMPUTATION
return avg;
}
#if !BENCHMARK
/* Test correctness on 8-bit data exhaustively. Should catch most errors */
int main (void)
{
T a, b, c, res, ref;
a = 0;
do {
b = 0;
do {
c = 0;
do {
res = average_of_3 (a, b, c);
ref = ((uint64_t)a + (uint64_t)b + (uint64_t)c) / 3;
if (res != ref) {
printf ("a=%08x b=%08x c=%08x res=%08x ref=%08x\n",
a, b, c, res, ref);
return EXIT_FAILURE;
}
c++;
} while (c);
b++;
} while (b);
a++;
} while (a);
return EXIT_SUCCESS;
}
#else // BENCHMARK
#include <math.h>
// A routine to give access to a high precision timer on most systems.
#if defined(_WIN32)
#if !defined(WIN32_LEAN_AND_MEAN)
#define WIN32_LEAN_AND_MEAN
#endif
#include <windows.h>
double second (void)
{
LARGE_INTEGER t;
static double oofreq;
static int checkedForHighResTimer;
static BOOL hasHighResTimer;
if (!checkedForHighResTimer) {
hasHighResTimer = QueryPerformanceFrequency (&t);
oofreq = 1.0 / (double)t.QuadPart;
checkedForHighResTimer = 1;
}
if (hasHighResTimer) {
QueryPerformanceCounter (&t);
return (double)t.QuadPart * oofreq;
} else {
return (double)GetTickCount() * 1.0e-3;
}
}
#elif defined(__linux__) || defined(__APPLE__)
#include <stddef.h>
#include <sys/time.h>
double second (void)
{
struct timeval tv;
gettimeofday(&tv, NULL);
return (double)tv.tv_sec + (double)tv.tv_usec * 1.0e-6;
}
#else
#error unsupported platform
#endif
#define N (3000000)
int main (void)
{
double start, stop, elapsed = INFINITY;
int i, k;
T a, b;
T avg0 = 0xffffffff, avg1 = 0xfffffffe;
T avg2 = 0xfffffffd, avg3 = 0xfffffffc;
T avg4 = 0xfffffffb, avg5 = 0xfffffffa;
T avg6 = 0xfffffff9, avg7 = 0xfffffff8;
T avg8 = 0xfffffff7, avg9 = 0xfffffff6;
T avg10 = 0xfffffff5, avg11 = 0xfffffff4;
T avg12 = 0xfffffff3, avg13 = 0xfffffff2;
T avg14 = 0xfffffff1, avg15 = 0xfffffff0;
a = 0x31415926;
b = 0x27182818;
avg0 = average_of_3 (a, b, avg0);
for (k = 0; k < 5; k++) {
start = second();
for (i = 0; i < N; i++) {
avg0 = average_of_3 (a, b, avg0);
avg0 = average_of_3 (a, b, avg0);
avg0 = average_of_3 (a, b, avg0);
avg0 = average_of_3 (a, b, avg0);
avg0 = average_of_3 (a, b, avg0);
avg0 = average_of_3 (a, b, avg0);
avg0 = average_of_3 (a, b, avg0);
avg0 = average_of_3 (a, b, avg0);
avg0 = average_of_3 (a, b, avg0);
avg0 = average_of_3 (a, b, avg0);
avg0 = average_of_3 (a, b, avg0);
avg0 = average_of_3 (a, b, avg0);
avg0 = average_of_3 (a, b, avg0);
avg0 = average_of_3 (a, b, avg0);
avg0 = average_of_3 (a, b, avg0);
avg0 = average_of_3 (a, b, avg0);
b = (b + avg0) ^ a;
a = (a ^ b) + avg0;
}
stop = second();
elapsed = fmin (stop - start, elapsed);
}
printf ("a=%016llx b=%016llx avg=%016llx",
(unsigned long long)a, (unsigned long long)b, (unsigned long long)avg0);
printf ("\rlatency: each average_of_3() took %.6e seconds\n",
elapsed / 16 / N);
a = 0x31415926;
b = 0x27182818;
avg0 = average_of_3 (a, b, avg0);
for (k = 0; k < 5; k++) {
start = second();
for (i = 0; i < N; i++) {
avg0 = average_of_3 (a, b, avg0);
avg1 = average_of_3 (a, b, avg1);
avg2 = average_of_3 (a, b, avg2);
avg3 = average_of_3 (a, b, avg3);
avg4 = average_of_3 (a, b, avg4);
avg5 = average_of_3 (a, b, avg5);
avg6 = average_of_3 (a, b, avg6);
avg7 = average_of_3 (a, b, avg7);
avg8 = average_of_3 (a, b, avg8);
avg9 = average_of_3 (a, b, avg9);
avg10 = average_of_3 (a, b, avg10);
avg11 = average_of_3 (a, b, avg11);
avg12 = average_of_3 (a, b, avg12);
avg13 = average_of_3 (a, b, avg13);
avg14 = average_of_3 (a, b, avg14);
avg15 = average_of_3 (a, b, avg15);
b = (b + avg0) ^ a;
a = (a ^ b) + avg0;
}
stop = second();
elapsed = fmin (stop - start, elapsed);
}
printf ("a=%016llx b=%016llx avg=%016llx", (unsigned long long)a, (unsigned long long)b,
(unsigned long long)(avg0 + avg1 + avg2 + avg3 + avg4 + avg5 + avg6 + avg7 +
avg8 + avg9 +avg10 +avg11 +avg12 +avg13 +avg14 +avg15));
printf ("\rthroughput: each average_of_3() took %.6e seconds\n",
elapsed / 16 / N);
return EXIT_SUCCESS;
}
#endif // BENCHMARK
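The remark in the question about the tempting (carry_mod_3 + sum_mod_3) / 2 shortcut can be checked by brute force over the nine remainder pairs. This is a small sketch with a made-up function name, comparing against the exact correction term (2 * carry_mod_3 + sum_mod_3) / 3.

```c
/* Counts the (carry_mod_3, sum_mod_3) pairs for which the shortcut
   (cm + sm) / 2 differs from the exact term (2 * cm + sm) / 3. */
unsigned count_tail_shortcut_mismatches (void)
{
    unsigned bad = 0;
    for (unsigned cm = 0; cm < 3; cm++) {
        for (unsigned sm = 0; sm < 3; sm++) {
            bad += ((cm + sm) / 2 != (2 * cm + sm) / 3);
        }
    }
    return bad;
}
```

The single failing pair is carry_mod_3 = 0, sum_mod_3 = 2, where the shortcut gives 1 and the exact term is 0.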
Let me throw my hat in the ring. Not doing anything too tricky here, I think.
#include <stdint.h>
uint64_t average_of_three(uint64_t a, uint64_t b, uint64_t c) {
uint64_t hi = (a >> 32) + (b >> 32) + (c >> 32);
uint64_t lo = hi + (a & 0xffffffff) + (b & 0xffffffff) + (c & 0xffffffff);
return 0x55555555 * hi + lo / 3;
}
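To see why the split works, here is the same idea shrunk to 8-bit operands with 4-bit halves, which is small enough to check exhaustively. This is my own scaled-down sketch, not part of the answer's code; the function names are hypothetical.

```c
#include <stdint.h>

/* 16 = 3 * 5 + 1, so
   a + b + c = 16 * hi + lo_sum = 3 * (5 * hi) + (hi + lo_sum).
   The 64-bit version uses 2^32 = 3 * 0x55555555 + 1 in the same way. */
uint8_t average_of_three_u8 (uint8_t a, uint8_t b, uint8_t c)
{
    unsigned hi = (a >> 4) + (b >> 4) + (c >> 4);          /* at most 45 */
    unsigned lo = hi + (a & 0xf) + (b & 0xf) + (c & 0xf);  /* at most 90 */
    return (uint8_t)(5 * hi + lo / 3);
}

/* Exhaustive comparison against the plain wide computation. */
unsigned count_average_u8_mismatches (void)
{
    unsigned bad = 0;
    for (unsigned a = 0; a < 256; a++)
        for (unsigned b = 0; b < 256; b++)
            for (unsigned c = 0; c < 256; c++)
                bad += (average_of_three_u8 ((uint8_t)a, (uint8_t)b, (uint8_t)c)
                        != (a + b + c) / 3);
    return bad;
}
```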
Following discussion below about different splits, here's a version that saves a multiply at the expense of three bitwise-ANDs:
T hi = (a >> 2) + (b >> 2) + (c >> 2);
T lo = (a & 3) + (b & 3) + (c & 3);
avg = hi + (hi + lo) / 3;
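Wrapped into a complete function for 64-bit operands, the fragment above might look like the following sketch. The function name is an assumption of mine; the comments spell out why nothing can wrap.

```c
#include <stdint.h>

uint64_t average_of_three_split2 (uint64_t a, uint64_t b, uint64_t c)
{
    /* a + b + c = 4 * hi + lo = 3 * hi + (hi + lo) */
    uint64_t hi = (a >> 2) + (b >> 2) + (c >> 2);  /* below 3 * 2^62 */
    uint64_t lo = (a & 3) + (b & 3) + (c & 3);     /* at most 9 */
    return hi + (hi + lo) / 3;                     /* hi + lo stays below 2^64 */
}
```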
I'm not sure if it fits your requirements, but maybe it works to just calculate the result and then fixup the error from the overflow:
T average_of_3 (T a, T b, T c)
{
T r = ((T) (a + b + c)) / 3;
T o = (a > (T) ~b) + ((T) (a + b) > (T) (~c));
if (o) r += ((T) 0x5555555555555555) << (o - 1);
T rem = ((T) (a + b + c)) % 3;
if (rem >= (3 - o)) ++r;
return r;
}
[EDIT] Here is the best branch-and-compare-less version I can come up with. On my machine, this version actually has slightly higher throughput than njuffa's code. __builtin_add_overflow(x, y, r) is supported by gcc and clang and returns 1 if the sum x + y overflows the type of *r and 0 otherwise, so the calculation of o is equivalent to the portable code in the first version, but at least gcc produces better code with the builtin.
T average_of_3 (T a, T b, T c)
{
T r = ((T) (a + b + c)) / 3;
T rem = ((T) (a + b + c)) % 3;
T dummy;
T o = __builtin_add_overflow(a, b, &dummy) + __builtin_add_overflow((T) (a + b), c, &dummy);
r += -((o - 1) & 0xaaaaaaaaaaaaaaab) ^ 0x5555555555555555;
r += (rem + o + 1) >> 2;
return r;
}
I answered the question you linked to already, so I am only answering the part that is different about this one: performance.
If you really cared about performance, then the answer is:
( a + b + c ) / 3
Since you cared about performance, you should have an intuition about the size of data you are working with. You should not have worried about overflow on addition (multiplication is another matter) of only 3 values, because if your data is already big enough to use the high bits of your chosen data type, you are in danger of overflow anyway and should have used a larger integer type. If you are overflowing on uint64_t, then you should really ask yourself why exactly do you need to count accurately up to 18 quintillion, and perhaps consider using float or double.
Now, having said all that, I will give you my actual reply: It doesn't matter. The question doesn't come up in real life and when it does, perf doesn't matter.
It could be a real performance question if you are doing it a million times in SIMD, because there, you are really incentivized to use integers of smaller width and you may need that last bit of headroom, but that wasn't your question.
New answer, new idea. This one's based on the mathematical identity
floor((a+b+c)/3) = floor(x + (a+b+c - 3x)/3)
When does this work with machine integers and unsigned division?
When the difference doesn't wrap, i.e. 0 ≤ a+b+c - 3x ≤ T_MAX.
This definition of x is fast and gets the job done.
T avg3(T a, T b, T c) {
T x = (a >> 2) + (b >> 2) + (c >> 2);
return x + (a + b + c - 3 * x) / 3;
}
Weirdly, ICC inserts an extra neg unless I do this:
T avg3(T a, T b, T c) {
T x = (a >> 2) + (b >> 2) + (c >> 2);
return x + (a + b + c - (x + x * 2)) / 3;
}
Note that T must be at least five bits wide.
If T is two platform words long, then you can save some double word operations by omitting the low word of x.
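A 32-bit instance of this approach, together with a wide reference for comparison, could be sketched as follows. The names avg3_u32 and avg3_ref are mine; the comment explains why the wrapped subtraction is still exact.

```c
#include <stdint.h>

uint32_t avg3_u32 (uint32_t a, uint32_t b, uint32_t c)
{
    uint32_t x = (a >> 2) + (b >> 2) + (c >> 2);
    /* a + b + c and 3 * x may each wrap, but their true difference is
       (a>>2) + (b>>2) + (c>>2) + (a&3) + (b&3) + (c&3), which fits in
       32 bits, so the modular subtraction yields it exactly. */
    return x + (a + b + c - 3 * x) / 3;
}

/* Reference computed in a wider type, for checking only. */
uint32_t avg3_ref (uint32_t a, uint32_t b, uint32_t c)
{
    return (uint32_t)(((uint64_t)a + b + c) / 3);
}
```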
Alternative version with worse latency but maybe slightly higher throughput?
T lo = a + b;
T hi = lo < b;
lo += c;
hi += lo < c;
T x = (hi << (sizeof(T) * CHAR_BIT - 2)) + (lo >> 2);
avg = x + (T)(lo - 3 * x) / 3;
I suspect SIMPLE is defeating the throughput benchmark by CSEing and hoisting a/3+b/3 and a%3+b%3 out of the loop, reusing those results for all 16 avg0..15 results.
(The SIMPLE version can hoist much more of the work than the tricky version; really just a ^ b and a & b in that version.)
Forcing the function to not inline introduces more front-end overhead, but does make your version win, as we expect it should on a CPU with deep out-of-order execution buffers to overlap independent work. There's lots of ILP to find across iterations, for the throughput benchmark. (I didn't look closely at the asm for the non-inline version.)
https://godbolt.org/z/j95qn3 (using __attribute__((noinline)) with clang -O3 -march=skylake on Godbolt's SKX CPUs) shows 2.58 nanosec throughput for the simple way, 2.48 nanosec throughput for your way. vs. 1.17 nanosec throughput with inlining for the simple version.
-march=skylake allows mulx for more flexible full-multiply, but otherwise no benefit from BMI2. andn isn't used; the line you commented with mulhi / andn is mulx into RCX / and rcx, -2 which only requires a sign-extended immediate.
Another way to do this without forcing call/ret overhead would be inline asm like in Preventing compiler optimizations while benchmarking (Chandler Carruth's CppCon talk has some example of how he uses a couple wrappers), or Google Benchmark's benchmark::DoNotOptimize.
Specifically, GNU C asm("" : "+r"(a), "+r"(b)) between each avgX = average_of_3 (a, b, avgX); statement will make the compiler forget everything it knows about the values of a and b, while keeping them in registers.
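A minimal sketch of that barrier, assuming a GNU C compatible compiler and a 64-bit target (so a uint64_t fits one register). The macro name FORGET and the two wrapper functions are hypothetical, for illustration only.

```c
#include <stdint.h>

/* GNU C extension: an empty asm statement with a "+r" operand. The compiler
   must assume x was rewritten, so results derived from the old value cannot
   be reused or hoisted across this point. No instruction is emitted. */
#define FORGET(x) __asm__ ("" : "+r" (x))

uint64_t simple_average (uint64_t a, uint64_t b, uint64_t c)
{
    return a / 3 + b / 3 + c / 3 + (a % 3 + b % 3 + c % 3) / 3;
}

/* Two back-to-back calls; the barriers keep a / 3 + b / 3 and
   a % 3 + b % 3 from being shared between them. */
uint64_t two_averages (uint64_t a, uint64_t b, uint64_t c0, uint64_t c1)
{
    uint64_t r0 = simple_average (a, b, c0);
    FORGET (a);
    FORGET (b);
    uint64_t r1 = simple_average (a, b, c1);
    return r0 ^ r1;
}
```

The values of a and b are unchanged at run time; only the compiler's knowledge about them is discarded.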
My answer on I don't understand the definition of DoNotOptimizeAway goes into more detail about using a read-only "r" register constraint to force the compiler to materialize a result in a register, vs. "+r" to make it assume the value has been modified.
If you understand GNU C inline asm well, it may be easier to roll your own in ways that you know exactly what they do.
[Falk Hüffner points out in comments that this answer has similarities to his answer. Looking at his code more closely, belatedly, I do find some similarities. However, what I posted here is the product of an independent thought process, a continuation of my original idea "reduce three items to two prior to div-mod". I understood Hüffner's approach to be different: "naive computation followed by corrections".]
I have found a better way than the CSA-technique in my question to reduce the division and modulo work from three operands to two operands. First, form the full double-word sum, then apply the division and modulo by 3 to each of the halves separately, finally combine the results. Since the most significant half can only take the values 0, 1, or 2, computing the quotient and remainder of division by three is trivial. Also, the combination into the final result becomes simpler.
Compared to the non-simple code variant from the question this achieves speedup on all platforms I examined. The quality of the code generated by compilers for the simulated double-word addition varies but is satisfactory overall. Nonetheless it may be worthwhile to code this portion in a non-portable way, e.g. with inline assembly.
T average_of_3_hilo (T a, T b, T c)
{
const T fives = (((T)(~(T)0)) / 3); // 0x5555...
T avg, hi, lo, lo_div_3, lo_mod_3, hi_div_3, hi_mod_3;
/* compute the full sum a + b + c into the operand pair hi:lo */
lo = a + b;
hi = lo < a;
lo = c + lo;
hi = hi + (lo < c);
/* determine quotient and remainder of each half separately */
lo_div_3 = lo / 3;
lo_mod_3 = (lo + lo_div_3) & 3;
hi_div_3 = hi * fives;
hi_mod_3 = hi;
/* combine partial results into the division result for the full sum */
avg = lo_div_3 + hi_div_3 + ((lo_mod_3 + hi_mod_3 + 1) / 4);
return avg;
}
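For readers who want to check the combination step, here is a 32-bit instance with the reasoning spelled out in comments. It is a sketch under the assumption that T is uint32_t; the function name is mine.

```c
#include <stdint.h>

uint32_t average_of_3_hilo_u32 (uint32_t a, uint32_t b, uint32_t c)
{
    const uint32_t fives = UINT32_MAX / 3;  /* 0x55555555; 2^32 = 3 * fives + 1 */
    uint32_t lo = a + b;
    uint32_t hi = lo < a;                   /* carry out of a + b */
    lo += c;
    hi += lo < c;                           /* hi:lo = a + b + c, hi in [0, 2] */
    uint32_t lo_div_3 = lo / 3;
    uint32_t lo_mod_3 = (lo + lo_div_3) & 3;
    /* hi * 2^32 = 3 * (hi * fives) + hi, so hi contributes the quotient
       hi * fives and the remainder hi. The leftover r = lo_mod_3 + hi lies
       in [0, 4], and on that range r / 3 equals (r + 1) / 4. */
    return lo_div_3 + hi * fives + ((lo_mod_3 + hi + 1) >> 2);
}
```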
An experimental build of GCC-11 compiles the obvious naive function to something like:
uint32_t avg3t (uint32_t a, uint32_t b, uint32_t c) {
a += b;
b = a < b;
a += c;
b += a < c;
b = b + a;
b += b < a;
return (a - (b % 3)) * 0xaaaaaaab;
}
Which is similar to some of the other answers posted here.
Any explanation of how these solutions work would be welcome
(not sure of the netiquette here).
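As a partial explanation, here is my reading of the function above, rewritten with separate variables and comments. This is a sketch of how I understand it, not a statement about GCC's internals; the name avg3t_annotated is mine.

```c
#include <stdint.h>

uint32_t avg3t_annotated (uint32_t a, uint32_t b, uint32_t c)
{
    uint32_t lo = a + b;
    uint32_t hi = lo < b;        /* carry out of a + b */
    lo += c;
    hi += lo < c;                /* exact sum is hi * 2^32 + lo, hi in [0, 2] */

    /* 2^32 leaves remainder 1 when divided by 3, so the exact sum is
       congruent to hi + lo modulo 3. If hi + lo itself wraps, the lost 2^32
       is again worth 1, which the end-around carry adds back. */
    uint32_t folded = hi + lo;
    folded += folded < lo;
    uint32_t rem = folded % 3;   /* (a + b + c) % 3 */

    /* sum - rem is an exact multiple of 3, and 0xaaaaaaab is the inverse of
       3 modulo 2^32 (3 * 0xaaaaaaab = 0x200000001), so the multiplication
       performs an exact division. Only the low word of the sum is needed
       because the quotient fits in 32 bits. */
    return (lo - rem) * UINT32_C(0xaaaaaaab);
}
```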
I've been asked to write a short program that reads 3 numbers and prints the middle one,
i.e. 3 2 0 -> "middle number is 2"
another example: 4 4 5 -> "middle number is 4" (the user may enter the same number twice)
It would be pretty easy with if statements and the like, but
I'm not allowed to use conditionals or loops.
Edit: I can't use arrays either.
I've tried dividing the numbers by each other to get a 0, which indicates that one is bigger or smaller.
As mentioned in the almost duplicate question and in a previous answer, one way is to sum the three values and then to subtract the min and max values.
However, using min and max function of the library is cheating a little bit, as it is likely that internally conditionals are used.
A way forward is to use the relations:
min(a, b) = (a + b - abs(a-b))/2;
max(a, b) = (a + b + abs(a-b))/2;
etc.
It was mentioned in a comment that using abs is cheating also.
In practice, I don't know how it is implemented internally.
One possibility is to use the tricks mentioned in this post: get absolute values....
For example, abs(a) = sqrt (a*a);
Use these two functions:
int maximum(int a, int b, int c)
{
// initialize max with a
int max = a;
// set max to b if and only if max is less than b
(max < b) && (max = b); // these are not conditional statements.
// set max to c if and only if max is less than c
(max < c) && (max = c); // these are just Boolean expressions.
return max;
}
int minimum(int a, int b, int c)
{
// initialize min with a
int min = a;
// set min to b if and only if min is more than b
(min > b) && (min = b);
// set min to c if and only if min is more than c
(min > c) && (min = c);
return min;
}
Then the mid-value can be calculated easily using:
(a+b+c)-minimum(a, b, c)-maximum(a, b, c)
Edit:
To run the code without defining any function, we can use the following. It is exactly same as above, but does all the works inside main function.
int main() {
int a=3, b=2, c=0, min, max;
// initialize min with a
min = a;
// set min to b if and only if min is more than b
(min > b) && (min = b);
// set min to c if and only if min is more than c
(min > c) && (min = c);
// similarly find max
max = a;
(max < b) && (max = b);
(max < c) && (max = c);
printf("%d", a+b+c-min-max);
return 0;
}
Use wider than int math to handle range limitations.
Find absolute value by division trick. [1]
((3 * d) / (3 * d + 1)) --> 0 when d >= 0
((3 * d) / (3 * d + 1)) --> 1 when d < 0
No conditions used.
long long my_abs_diff(long long a, long long b) {
long long d = 0LL + a - b;
return d * (1 - 2 * ((3 * d) / (3 * d + 1)));
}
long long my_min(long long a, long long b) {
return ((a + b) - my_abs_diff(a, b)) / 2;
}
long long my_max(long long a, long long b) {
return ((a + b) + my_abs_diff(a, b)) / 2;
}
int median3(int a, int b, int c) {
long long mn = my_min(my_min(a, b), c);
long long mx = my_max(my_max(a, b), c);
return (int) (-mn - mx + a + b + c);
}
Simplification may exist.
[1] In C, since C99, integer division is defined as truncating towards zero. It is this truncation towards zero that the code exploits to effect an "if". A cost of this approach is that we need a few more bits of integer range. For pre-C99, code can use div() to ensure the desired quotient truncation.
Tested successfully with many edge cases and 10,000,00,00 random combinations over entire int range: [INT_MIN ... INT_MAX].
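The division trick can be looked at in isolation. The following sketch uses helper names of my own and assumes that 3 * d + 1 does not overflow long long.

```c
/* 0 for d >= 0 and 1 for d < 0, using only truncating division.
   For d >= 0 the quotient 3d / (3d + 1) lies in [0, 1) and truncates to 0.
   For d < 0 both operands are negative and the quotient lies in (1, 2),
   which truncates to 1. */
long long is_negative (long long d)
{
    return (3 * d) / (3 * d + 1);
}

/* |d| without conditionals: multiply by +1 or by -1. */
long long abs_no_branch (long long d)
{
    return d * (1 - 2 * is_negative (d));
}
```

The factor 3 avoids a zero divisor: with d / (d + 1) the case d = -1 would divide by zero.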
Hint: first find the min and max values by using std::min and std::max.
Then sum all the numbers and subtract the min and max values.
What you are left with is the middle value.
Is there any way to find the nth root of a number without any external library in C? I'm working on bare-metal code, so there is no OS. Also, a complete C implementation is not available.
You can write a program like this for nth root. This program is for square root.
int floorSqrt(int x)
{
    // Base cases
    if (x == 0 || x == 1)
        return x;
    // Starting from 1, try all numbers until
    // i*i is greater than x.
    int i = 1, result = 1;
    while (result <= x)
    {
        i++;
        result = i*i;
    }
    return i-1;
}
You can use the same approach for nth root.
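One way the same linear search might be extended to an integer nth root is sketched below. The function names are hypothetical, and the power test divides instead of multiplying so that no intermediate value can overflow.

```c
/* 1 if i^n <= x, 0 otherwise. Requires i >= 1.
   Dividing x by i n times gives floor(x / i^n), which is nonzero
   exactly when i^n <= x. */
int power_fits (unsigned i, unsigned n, unsigned x)
{
    while (n > 0 && x > 0) {
        x /= i;
        n--;
    }
    return x > 0;
}

/* Largest r with r^n <= x, for n >= 1. */
unsigned floor_nth_root (unsigned x, unsigned n)
{
    unsigned r = 0;
    /* r < x keeps r + 1 from overflowing and stops early for x <= 1. */
    while (r < x && power_fits (r + 1, n, x)) {
        r++;
    }
    return r;
}
```

The running time is proportional to the root itself, so for large inputs and small n a binary search over r with the same power_fits test would be the natural refinement.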
Here is a C implementation of the nth root algorithm you can find on Wikipedia. It needs an exponentiation algorithm, so I also include an implementation of a basic method for exponentiation by squaring that you can also find on Wikipedia.
double npower(double const base, int const n)
{
if (n < 0) return npower(1/base, -n);
else if (n == 0) return 1.0;
else if (n == 1) return base;
else if (n % 2) return base*npower(base*base, n/2);
else return npower(base*base, n/2);
}
double nroot(double const base, int const n)
{
if (n == 1) return base;
else if (n <= 0 || base < 0) return NAN;
else {
double delta, x = base/n;
do {
delta = (base/npower(x,n-1)-x)/n;
x += delta;
} while (fabs(delta) >= 1e-8);
return x;
}
}
Some comments on this:
The nth root algorithm in wikipedia leaves freedom for the initial guess. In this example I set it up to be base/n, but this was just a guess.
The macro NAN is usually defined in <math.h>, so you would need to define it to be suitable for your needs.
Both functions are implemented in a very rough and simple way, and their performance can be greatly improved with careful thought.
The tolerance in this example is set to 1e-8 and should be changed to something different. It should probably be proportional to the value of the base.
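As one possible way to act on the last comment, here is a sketch with a stopping threshold relative to the current estimate, and without any library calls. The names ipow and nroot_rel, the starting guess max(base, 1), the 1e-12 threshold, and the iteration cap are all my own choices, not part of the code above.

```c
/* base^n for n >= 0 by squaring, iterative form. */
double ipow (double base, unsigned n)
{
    double result = 1.0;
    while (n > 0) {
        if (n & 1) result *= base;
        base *= base;
        n >>= 1;
    }
    return result;
}

/* nth root of base, assuming base > 0 and n >= 1. */
double nroot_rel (double base, unsigned n)
{
    /* Starting at or above the root makes the iteration decrease steadily. */
    double x = base > 1.0 ? base : 1.0;
    for (int iter = 0; iter < 10000; iter++) {
        double delta = (base / ipow (x, n - 1) - x) / n;
        double mag = delta < 0 ? -delta : delta;
        x += delta;
        /* Stop once the step is tiny relative to the estimate itself. */
        if (mag <= 1e-12 * x) break;
    }
    return x;
}
```

A relative threshold behaves the same for roots near 1e-3 as for roots near 1e+6, which a fixed 1e-8 does not.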
You can try the nth_root C function :
// return a number that, when multiplied by itself nth times, makes N.
unsigned nth_root(const unsigned n, const unsigned nth) {
unsigned a = n, b, c, r = nth ? n + (n > 1) : n == 1 ;
for (; a < r; b = a + (nth - 1) * r, a = b / nth)
for (r = a, a = n, c = nth - 1; c && (a /= r); --c);
return r;
}
Source
I'm implementing an algorithm in C that needs to do modular addition and subtraction quickly on unsigned integers and can handle overflow conditions correctly. Here's what I have now (which does work):
/* a and/or b may be greater than m */
uint32_t modadd_32(uint32_t a, uint32_t b, uint32_t m) {
uint32_t tmp;
if (b <= UINT32_MAX - a)
return (a + b) % m;
if (m <= (UINT32_MAX>>1))
return ((a % m) + (b % m)) % m;
tmp = a + b;
if (tmp > (uint32_t)(m * 2)) // m*2 must be truncated before compare
tmp -= m;
tmp -= m;
return tmp % m;
}
/* a and/or b may be greater than m */
uint32_t modsub_32(uint32_t a, uint32_t b, uint32_t m) {
uint32_t tmp;
if (a >= b)
return (a - b) % m;
tmp = (m - ((b - a) % m)); /* results in m when 0 is needed */
if (tmp == m)
return 0;
return tmp;
}
Anybody know of a better algorithm? The libraries I've found that do modular arithmetic all seem to be for large arbitrary precision numbers which is way overkill.
Edit: I want this to run well on a 32 bit machine. Also, my existing functions are trivially converted to work on other sizes of unsigned integers, a property which would be nice to retain.
Modular operations usually assume that a and b are less than m. This allows simpler algorithms:
umod_t sub_mod(umod_t a, umod_t b, umod_t m)
{
if ( a>=b )
return a - b;
else
return m - b + a;
}
umod_t add_mod(umod_t a, umod_t b, umod_t m)
{
if ( 0==b ) return a;
// return sub_mod(a, m-b, m);
b = m - b;
if ( a>=b )
return a - b;
else
return m - b + a;
}
Source: Matters Computational, chapter 39.1.
I'd just do the arithmetic in uint32_t if it fits and in uint64_t otherwise.
uint32_t modadd_32(uint32_t a, uint32_t b, uint32_t m) {
if (b <= UINT32_MAX - a)
return (a + b) % m;
else
return ((uint64_t)a + b) % m;
}
On an architecture with 64-bit integer types, this should add almost no overhead; you could even think of just doing everything in uint64_t. On architectures where uint64_t is synthesized, let the compiler decide what it thinks is best, and then look at the generated assembly and measure to see if this is satisfactory.
Overflow-safe modular addition
First establish that a<m and b<m with the usual % m.
Add updated a and b.
Should a (or b) exceed the uintN_t sum, then the mathematical sum overflowed uintN_t, and a subtraction of m will "mod" the mathematical sum into the range of uintN_t.
If the sum equals or exceeds m, then, as in the step above, a single subtraction of m will "mod" the sum.
uintN_t modadd_N(uintN_t a, uintN_t b, uintN_t m) {
// may omit these 2 steps if a < m and b < m are known before the call.
a %= m;
b %= m;
uintN_t sum = a + b;
if (sum >= m || sum < a) {
sum -= m;
}
return sum;
}
Quite simple in the end.
Overflow-safe modular subtraction
Variation on @Evgeny Kluev's good answer.
uintN_t modsub_N(uintN_t a, uintN_t b, uintN_t m) {
// may omit these 2 steps if a < m and b < m are known before the call.
a %= m;
b %= m;
uintN_t diff = a - b;
if (a < b) {
diff += m;
}
return diff;
}
Note this approach works for various N such as 32, 64, 16 or unsigned, unsigned long, etc. without resorting to wider types. It also works for unsigned types narrower than int/unsigned.
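For concreteness, here are the two routines instantiated for N = 32, as a sketch that can be compared against a 64-bit reference. The names modadd_u32 and modsub_u32 are mine; m is assumed to be nonzero.

```c
#include <stdint.h>

uint32_t modadd_u32 (uint32_t a, uint32_t b, uint32_t m)
{
    a %= m;
    b %= m;
    uint32_t sum = a + b;                  /* may wrap past 2^32 */
    if (sum >= m || sum < a) sum -= m;     /* sum < a detects the wrap */
    return sum;
}

uint32_t modsub_u32 (uint32_t a, uint32_t b, uint32_t m)
{
    a %= m;
    b %= m;
    uint32_t diff = a - b;                 /* wraps when a < b */
    if (a < b) diff += m;                  /* wraps back into [0, m) */
    return diff;
}
```

With m = 0xfffffffb and a = b = 0xfffffffa the 32-bit sum wraps, and the single subtraction of m still lands on the correct residue 0xfffffff9.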