I'm working through the problems in Project Euler as a way of learning Haskell, and I find that my programs are a lot slower than a comparable C version, even when compiled. What can I do to speed up my Haskell programs?
For example, my brute-force solution to Problem 14 is:
import Data.Int
import Data.Ord
import Data.List
searchTo = 1000000
nextNumber :: Int64 -> Int64
nextNumber n
    | even n    = n `div` 2
    | otherwise = 3 * n + 1

sequenceLength :: Int64 -> Int
sequenceLength 1 = 1
sequenceLength n = 1 + sequenceLength next
  where next = nextNumber n
longestSequence = maximumBy (comparing sequenceLength) [1..searchTo]
main = putStrLn $ show $ longestSequence
Which takes around 220 seconds, while an "equivalent" brute-force C version only takes 1.2 seconds.
#include <stdio.h>

int main(int argc, char **argv)
{
    int longest = 0;
    int terms = 0;
    int i;
    unsigned long j;

    for (i = 1; i <= 1000000; i++)
    {
        j = i;
        int this_terms = 1;

        while (j != 1)
        {
            this_terms++;

            if (this_terms > terms)
            {
                terms = this_terms;
                longest = i;
            }

            if (j % 2 == 0)
                j = j / 2;
            else
                j = 3 * j + 1;
        }
    }

    printf("%d\n", longest);
    return 0;
}
What am I doing wrong? Or am I naive to think that Haskell could even approach C's speed?
(I'm compiling the C version with gcc -O2, and the Haskell version with ghc --make -O).
For testing purposes I have just set searchTo = 100000. The time taken is 7.34s. A few modifications lead to some big improvements:
Use an Integer instead of Int64. This improves the time to 1.75s.
Use an accumulator (you don't need sequenceLength to be lazy, right?). 1.54s.
seqLen2 :: Int -> Integer -> Int
seqLen2 a 1 = a
seqLen2 a n = seqLen2 (a+1) (nextNumber n)
sequenceLength :: Integer -> Int
sequenceLength = seqLen2 1
Rewrite nextNumber using quotRem, thus avoiding computing the division twice (once in even and once in div). 1.27s.
nextNumber :: Integer -> Integer
nextNumber n
    | r == 0    = q
    | otherwise = 6*q + 4
  where (q,r) = quotRem n 2
Use a Schwartzian transform instead of maximumBy. The problem with maximumBy . comparing is that the sequenceLength function is called more than once for each value. 0.32s.
longestSequence = snd $ maximum [(sequenceLength a, a) | a <- [1..searchTo]]
Note:
I check the time by compiling with ghc -O and running with +RTS -s.
My machine is running on Mac OS X 10.6. The GHC version is 6.12.2. The compiled file is in i386 architecture.
The C program runs in 0.078s with the corresponding parameter. It is compiled with gcc -O3 -m32.
Although this is already rather old, let me chime in; there's one crucial point that hasn't been addressed before.
First, the timings of the different programmes on my box. Since I'm on a 64-bit Linux system, they show somewhat different characteristics: using Integer instead of Int64 does not improve performance as it would with a 32-bit GHC, where each Int64 operation incurs the cost of a C call while computations with Integers fitting in signed 32-bit integers don't need a foreign call (since only a few operations exceed that range here, Integer is the better choice on a 32-bit GHC).
C: 0.3 seconds
Original Haskell: 14.24 seconds, using Integer instead of Int64: 33.96 seconds
KennyTM's improved version: 5.55 seconds, using Int: 1.85 seconds
Chris Kuklewicz's version: 5.73 seconds, using Int: 1.90 seconds
FUZxxl's version: 3.56 seconds, using quotRem instead of divMod: 1.79 seconds
So what have we?
Calculate the length with an accumulator so the compiler can transform it (basically) into a loop
Don't recalculate the sequence lengths for the comparisons
Don't use div resp. divMod when it's not necessary, quot resp. quotRem are much faster
What is still missing?
if (j % 2 == 0)
    j = j / 2;
else
    j = 3 * j + 1;
Any C compiler I have used transforms the test j % 2 == 0 into a bit-masking and doesn't use a division instruction. GHC does not (yet) do that. So testing even n or computing n `quotRem` 2 is quite an expensive operation. Replacing nextNumber in KennyTM's Integer version with
nextNumber :: Integer -> Integer
nextNumber n
    | fromInteger n .&. 1 == (0 :: Int) = n `quot` 2
    | otherwise                         = 3*n+1
reduces its running time to 3.25 seconds (Note: for Integer, n `quot` 2 is faster than n `shiftR` 1, that takes 12.69 seconds!).
Doing the same in the Int version reduces its running time to 0.41 seconds. For Ints, the bit-shift for division by 2 is a bit faster than the quot operation, reducing its running time to 0.39 seconds.
Eliminating the construction of the list (that doesn't appear in the C version either),
module Main (main) where

import Data.Bits

result :: Int
result = findMax 0 0 1

findMax :: Int -> Int -> Int -> Int
findMax start len can
    | can > 1000000 = start
    | canlen > len  = findMax can canlen (can+1)
    | otherwise     = findMax start len (can+1)
  where
    canlen = findLen 1 can

findLen :: Int -> Int -> Int
findLen l 1 = l
findLen l n
    | n .&. 1 == 0 = findLen (l+1) (n `shiftR` 1)
    | otherwise    = findLen (l+1) (3*n+1)

main :: IO ()
main = print result
yields a further small speedup, resulting in a running time of 0.37 seconds.
So the Haskell version that's in close correspondence to the C version doesn't take that much longer; it's a factor of ~1.3.
Well, let's be fair, there's an inefficiency in the C version that's not present in the Haskell versions,
if (this_terms > terms)
{
    terms = this_terms;
    longest = i;
}
appearing in the inner loop. Lifting that out of the inner loop in the C version reduces its running time to 0.27 seconds, making the factor ~1.4.
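For reference, here is what that hoisted version looks like, as a sketch based on the question's C code; only the position of the comparison changes:

for (i = 1; i <= 1000000; i++)
{
    j = i;
    int this_terms = 1;
    while (j != 1)
    {
        this_terms++;
        if (j % 2 == 0)
            j = j / 2;
        else
            j = 3 * j + 1;
    }
    /* compare once per starting value, not once per Collatz step */
    if (this_terms > terms)
    {
        terms = this_terms;
        longest = i;
    }
}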
The comparing may be recomputing sequenceLength too much. This is my best version:
type I = Integer
data P = P {-# UNPACK #-} !Int {-# UNPACK #-} !I deriving (Eq,Ord,Show)
searchTo = 1000000
nextNumber :: I -> I
nextNumber n = case quotRem n 2 of
    (n2,0) -> n2
    _      -> 3*n+1

sequenceLength :: I -> Int
sequenceLength x = count x 1 where
    count 1 acc = acc
    count n acc = count (nextNumber n) (succ acc)
longestSequence = maximum . map (\i -> P (sequenceLength i) i) $ [1..searchTo]
main = putStrLn $ show $ longestSequence
The answer and timing follow; it is slower than C, but it does use arbitrary-precision integers (through the Integer type):
ghc -O2 --make euler14-fgij.hs
time ./euler14-fgij
P 525 837799
real 0m3.235s
user 0m3.184s
sys 0m0.015s
Haskell's lists are heap-based, whereas your C code is exceedingly tight and makes no heap use at all. You need to refactor to remove the dependency on lists.
Even if I'm a bit late, here is mine: I removed the dependency on lists, and this solution uses no heap at all either.
{-# LANGUAGE BangPatterns #-}
-- Compiled with ghc -O2 -fvia-C -optc-O3 -Wall euler.hs
module Main (main) where
searchTo :: Int
searchTo = 1000000
nextNumber :: Int -> Int
nextNumber n = case n `divMod` 2 of
    (k,0) -> k
    _     -> 3*n + 1

sequenceLength :: Int -> Int
sequenceLength n = sl 1 n where
    sl k 1 = k
    sl k x = sl (k + 1) (nextNumber x)

longestSequence :: Int
longestSequence = testValues 1 0 0 where
    testValues number !longest !longestNum
        | number > searchTo = longestNum
        | otherwise = testValues (number + 1) longest' longestNum'
        where
          nlength = sequenceLength number
          (longest', longestNum') = if nlength > longest
                                      then (nlength, number)
                                      else (longest, longestNum)
main :: IO ()
main = print longestSequence
I compiled this piece with ghc -O2 -fvia-C -optc-O3 -Wall euler.hs and it runs in 5 seconds, compared to 80 for the original implementation. It doesn't use Integer, but because I'm on a 64-bit machine, the results may be skewed.
The compiler can unbox all Ints in this case, resulting in really fast code. It runs faster than all other solutions I've seen so far, but C is still faster.
Related
Let's say I have 3 double-precision arrays,
real*8, dimension(n) :: x, y, z
which are initialized as
x = 1.
y = (/ (1., i=1,n) /)
z = (/ (1. +0*i, i=1,n) /)
They should initialize all elements of all arrays to 1. In ifort (16.0.0 20150815), this works as intended for any n within the range of the declared precision. That is, if we declare n as
integer*4, parameter :: n
then as long as n < 2147483647, the initialization works as intended for all declarations.
In gfortran (4.8.5 20150623 Red Hat 4.8.5-16), the initialization fails for y (the array constructor with a constant argument) as long as n > 65535, independent of its precision. AFAIK, 65535 is the maximum of an unsigned short int, aka an unsigned integer*2, which is well within the range of integer*4.
Below is an MWE:
program test
    implicit none
    integer*4, parameter :: n = 65536
    integer*4, parameter :: m = 65535
    real*8, dimension(n) :: x, y, z
    real*8, dimension(m) :: a, b, c
    integer*4 :: i

    print *, huge(n)
    x = 1.
    y = (/ (1., i=1,n) /)
    z = (/ (1.+0*i, i=1,n) /)
    print *, x(n), y(n), z(n)

    a = 1.
    b = (/ (1., i=1,m) /)
    c = (/ (1.+0*i, i=1,m) /)
    print *, a(m), b(m), c(m)
end program test
Compiling with gfortran (gfortran test.f90 -o gfortran_test), it outputs:
2147483647
1.0000000000000000 0.0000000000000000 1.0000000000000000
1.0000000000000000 1.0000000000000000 1.0000000000000000
Compiling with ifort (ifort test.f90 -o ifort_test), it outputs:
2147483647
1.00000000000000 1.00000000000000 1.00000000000000
1.00000000000000 1.00000000000000 1.00000000000000
What gives?
There is indeed a big difference in how the compiler treats the array constructors. For n<=65535 there is the actual array of [1., 1., 1.,...] stored in the object file (or in some of the intermediate representations).
For a larger array the compiler generates a loop:
(*(real(kind=8)[65536] * restrict) atmp.0.data)[offset.1] = 1.0e+0;
offset.1 = offset.1 + 1;
{
    integer(kind=8) S.2;

    S.2 = 0;
    while (1)
    {
        if (S.2 > 65535) goto L.1;
        y[S.2] = (*(real(kind=8)[65536] * restrict) atmp.0.data)[S.2];
        S.2 = S.2 + 1;
    }
    L.1:;
}
It appears to me that it first sets only one element of a temporary array and then copies the (mostly undefined) temporary array to y. And that is wrong. Valgrind also reports use of uninitialized memory.
For a default real we have
while (1)
{
    if (shadow_loopvar.2 > 65536) goto L.1;
    (*(real(kind=4)[65536] * restrict) atmp.0.data)[offset.1] = 1.0e+0;
    offset.1 = offset.1 + 1;
    shadow_loopvar.2 = shadow_loopvar.2 + 1;
}
L.1:;
{
    integer(kind=8) S.3;

    S.3 = 0;
    while (1)
    {
        if (S.3 > 65535) goto L.2;
        y[S.3] = (*(real(kind=4)[65536] * restrict) atmp.0.data)[S.3];
        S.3 = S.3 + 1;
    }
    L.2:;
}
We have two loops now, one sets the whole temporary array and the second one copies that to y and everything is fine.
Conclusion: a compiler bug.
The issue was fixed by GCC developers who read this question. The bug is tracked at https://gcc.gnu.org/bugzilla/show_bug.cgi?id=84931
They also identified that the problem is connected to type conversion. The constructor contains the default-precision literal 1.; with a single-precision array there is no type conversion, but for a double-precision array there is one. That caused the difference between these two cases.
How do I create C code that receives an int parameter n and returns the value of this mathematical equation:
f(n) = 3 * f(n - 1) + 4, where f(0) = 1
Each time the program receives n, it should work from 0 up to n, which in code means a for loop.
The problem here is that I can't translate this into code; I'm stuck at the f(n-1) part. How can I make this work in C?
Note: this code should be built only in basic C (nothing more than loops, no functions, inside void main, etc.).
It's called recursion, and you have a base case where f(0) == 1, so just check if (n == 0) and return 1, or recurse:
int f(int n)
{
    if (n == 0)
        return 1;

    return 3 * f(n - 1) + 4;
}
An iterative solution is quite simple too; for example, for f(5):
#include <stdio.h>

int
main(void)
{
    int f;
    int n;

    f = 1;
    for (n = 1; n <= 5; ++n)
        f = 3 * f + 4;
    printf("%d\n", f);

    return 0;
}
A LRE (linear recurrence equation) can be converted into a matrix multiply. In this case:
F(0) = | 1 |      (the current LRE value)
       | 1 |      (this is just copied, used for the + 4)

M = | 3  4 |      (calculates LRE to new 1st number)
    | 0  1 |      (copies previous 2nd number to new 2nd number (the 1))

F(n) = M F(n-1) = matrixpower(M, n) F(0)
You can raise a matrix to the power n by using repeated squaring, sometimes called binary exponentiation. Example code for integer:
r = 1;          /* result */
s = m;          /* s = squares of integer m */
while (n) {     /* while exponent != 0 */
    if (n & 1)  /* if bit of exponent set */
        r *= s; /* multiply by s */
    s *= s;     /* s = s squared */
    n >>= 1;    /* test next exponent bit */
}
For an unsigned 64 bit integer, the max value for n is 40, so the maximum number of loops would be 6, since 2^6 > 40.
If this expression was calculating f(n) = 3 f(n-1) + 4 modulo some prime number (like 1,000,000,007) for very large n, then the matrix method would be useful, but in this case, with a max value of n = 40, recursion or iteration is good enough and simpler.
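If you did want the matrix route anyway, here is a minimal self-contained sketch; the mat2 struct and the mat_mul/mat_pow helpers are my own illustrative names, and overflow is ignored (fine for n up to ~40 as noted above):

#include <stdio.h>
#include <stdint.h>

typedef struct { uint64_t a, b, c, d; } mat2; /* row-major 2x2: | a b | over | c d | */

static mat2 mat_mul(mat2 x, mat2 y)
{
    mat2 r = { x.a*y.a + x.b*y.c, x.a*y.b + x.b*y.d,
               x.c*y.a + x.d*y.c, x.c*y.b + x.d*y.d };
    return r;
}

static mat2 mat_pow(mat2 m, unsigned n) /* repeated squaring, as above */
{
    mat2 r = { 1, 0, 0, 1 }; /* identity */
    while (n) {
        if (n & 1)
            r = mat_mul(r, m);
        m = mat_mul(m, m);
        n >>= 1;
    }
    return r;
}

int main(void)
{
    mat2 m = { 3, 4, 0, 1 };
    mat2 p = mat_pow(m, 5);
    printf("f(5) = %llu\n", (unsigned long long)(p.a + p.b)); /* M^5 * (1,1)^T; prints 727 */
    return 0;
}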
Best will be to use recursion. Learn it online.
It is a very powerful method for solving problems. The classical example is calculating factorials. It is used widely in many algorithms, like tree/graph traversal, etc.
Recursion in computer science is a method where the solution to a problem depends on solutions to smaller instances of the same problem.
Here you break your problem of size n into an instance of the same subproblem of size n-1 (times 3), plus a problem of constant size, at each such step.
Recursion stops at the base case, i.e. the trivial case: here, for n = 0, the function (the smallest subproblem) has the value 1.
I have written this code in C, where a, b, cc, ma, mb, mcc, N, and k are each int. But as per the specification of the problem, N and k could be as big as 10^9. 10^9 can be stored in an int variable on my machine, but the internal and final values of a, b, cc, ma, mb, mcc will be much bigger for large values of N and k, and cannot be stored even in an unsigned long long int variable.
Now I want to print the value of mcc % 1000000007, as you can see in the code. I know some clever modular-arithmetic tricks in the body of the for loop can produce the correct output without any overflow and also make the program time efficient. Being new to modular arithmetic, I failed to solve this. Can someone point out those steps?
ma = 1; mb = 0; mcc = 0;
for (i = 1; i <= N; ++i) {
    a = ma; b = mb; cc = mcc;
    ma = k*a + 1;
    mb = k*b + k*(k-1)*a*a;
    mcc = k*cc + k*(k-1)*a*(3*b + (k-2)*a*a);
}
printf("%d\n", mcc % 1000000007);
My attempt:
I used a, b, cc, ma, mb, mcc as long long and did this. Could it be optimized more?
ma = 1; mb = 0; cc = 0;
ok = k*(k-1);
for (i = 1; i <= N; ++i) {
    a = ma; b = mb;
    as = (a*a) % MOD;
    ma = (k*a + 1) % MOD;
    temp1 = (k*b) % MOD;
    temp2 = (as*ok) % MOD;
    mb = (temp1 + temp2) % MOD;
    temp1 = (k*cc) % MOD;
    temp2 = (as*(k-2)) % MOD;
    temp3 = (3*b) % MOD;
    temp2 = (temp2 + temp3) % MOD;
    temp2 = (temp2*a) % MOD;
    temp2 = (ok*temp2) % MOD;
    cc = (temp1 + temp2) % MOD;
}
printf("%lld\n", cc);
Let's look at a small example:
mb = (k*b + k*(k-1)*a*a)%MOD;
Here, k*b and k*(k-1)*a*a can overflow, and so can the sum. Taking into account
(x + y) mod m = (x mod m + y mod m) mod m
we can rewrite this (with x = k*b, y = k*(k-1)*a*a and m = MOD) as
mb = ((k*b) % MOD + (k*(k-1)*a*a) %MOD) % MOD
Now, we can go one step further. Since
x * y mod m = (x mod m * y mod m) mod m
we can also rewrite the multiplication k*(k-1)*a*a % MOD, with x = k*(k-1) and y = a*a, to
((k*(k-1)) %MOD) * ((a*a) %MOD)) % MOD
I'm sure you can do the rest. While you can sprinkle % MOD all over the place, you should carefully consider whether you need it or not, taking John's hint into account:
Adding two n-digit numbers produces a number of up to n+1 digits, and
multiplying an n-digit number by an m-digit number produces a result
with up to n + m digits.
As such, there are places where you will need to use the modulus properties, and some where you surely don't, but that part of the work is yours ;).
It's a good exercise to build a template class along these lines:
template <int N>
class modulo_int_t
{
public:
    modulo_int_t(int value) : value_(value % N) {}

    modulo_int_t<N> operator+(const modulo_int_t<N> &rhs) const
    {
        return modulo_int_t<N>(value_ + rhs.value_);
    }

    // fill in the other operations

private:
    int value_;
};
Then write the operations using modulo_int_t<1000000007> objects instead of int.
Disclaimer: make use of long long where appropriate and take care of negative differences...
Forgive me if I am being a bit silly, but I have only very recently started programming, and am maybe a little out of my depth doing Problem 160 on Project Euler. I have made some attempts at solving it, but it seems that going through 10^12 numbers will take too long on any personal computer, so I guess I should be looking into the mathematics to find some shortcuts.
Project Euler Problem 160:
For any N, let f(N) be the last five digits before the trailing zeroes
in N!. For example,
9! = 362880 so f(9) = 36288
10! = 3628800 so f(10) = 36288
20! = 2432902008176640000 so f(20) = 17664
Find f(1,000,000,000,000)
New attempt:
#include <stdio.h>

int main()
{
    //I have used long long ints everywhere to avoid possible multiplication errors
    long long f; //f is f(1,000,000,000,000)
    f = 1;
    for (long long i = 1; i <= 1000000000000; ++i) {
        long long p;
        for (p = i; (p % 10) == 0; p = p / 10) //p is i without trailing zeros
            ;
        p = (p % 1000000); //p is last six nontrivial digits of i
        for (f = f * p; (f % 10) == 0; f = f / 10)
            ;
        f = (f % 1000000);
    }
    f = (f % 100000);
    printf("f(1,000,000,000,000) = %lld\n", f);
    return 0;
}
Old attempt:
#include <stdio.h>

int main()
{
    //This part of the programme removes the zeros in factorials by dividing by 10 for each factor of 5, and finds f(1,000,000,000,000) inductively
    long long int f, m; //f is f(n), m is 10^k for each multiple of 5
    short k;            //Stores multiplicity of 5 for each multiple of 5
    f = 1;
    for (long long i = 1; i <= 1000000000000; ++i) {
        if ((i % 5) == 0) {
            k = 1;
            for (m = i / 5; (m % 5) == 0; m = m / 5) //Computes multiplicity of 5 in factorisation of i
                ++k;
            m = 1;
            for (short j = 1; j <= k; ++j) //Computes 10^k
                m = 10 * m;
            f = (((f * i) / m) % 100000);
        }
        else
            f = ((f * i) % 100000);
    }
    printf("f(1,000,000,000,000) = %lld\n", f);
    return 0;
}
The problem is:
For any N, let f(N) be the last five digits before the trailing zeroes in N!. Find f(1,000,000,000,000)
Let's rephrase the question:
For any N, let g(N) be the last five digits before the trailing zeroes in N. For any N, let f(N) be g(N!). Find f(1,000,000,000,000).
Now, before you write the code, prove this assertion mathematically:
For any N > 1, f(N) is equal to g(f(N-1) * g(N))
Note that I have not proved this myself; I might be making a mistake here. (UPDATE: It appears to be wrong! We'll have to give this more thought.) Prove it to your satisfaction. You might want to start by proving some intermediate results, like:
g(x * y) = g(g(x) * g(y))
And so on.
Once you have obtained a proof of this result, now you have a recurrence relation that you can use to find any f(N), and the numbers you have to deal with don't ever get much larger than N.
A product of k terms, each of the form n*a + c, satisfies Prod(n*a + c) mod a = c^k mod a.
For example
prod[ 3, 1000003, 2000003,... , 999999000003 ] mod 1000000
equals
3^(1,000,000,000,000/1,000,000) mod 1000000
And the number of trailing zeros in N! equals the number of factors of 5 in the factorisation of N!.
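Counting those factors of 5 is cheap with the standard formula floor(N/5) + floor(N/25) + floor(N/125) + ...; a small sketch (trailing_zeros is an illustrative helper name):

#include <stdio.h>

/* number of trailing zeros of n! = number of factors of 5 in n! */
long long trailing_zeros(long long n)
{
    long long count = 0;
    for (long long p = 5; p <= n; p *= 5)
        count += n / p; /* multiples of 5, 25, 125, ... each contribute one more */
    return count;
}

int main(void)
{
    printf("%lld\n", trailing_zeros(20)); /* 4, matching 20! = 2432902008176640000 */
    return 0;
}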
I would compute the whole thing and then separate the first nonzero digits from the LSB ...
but for you I think this is better:
1. Use a bigger base
Any number can be rewritten as a sum of multiples of powers of the same number (the base).
For example 1234560004587786542 can be rewritten in base b = 1 000 000 000 like this:
1*b^2 + 234560004*b^1 + 587786542*b^0
2. When you multiply, the lowest digit depends only on the lowest digits of the multiplied numbers
A*B = (a0*b^0+a1*b^1+...)*(b0*b^0+b1*b^1+...)
= (a0*b0*b^0)+ (...*b^1) + (...*b^2)+ ...
3. Put it together
for (f = 1, i = 1; i <= N; i++)
{
    j = i % base;
    // here remove ending zeroes from j
    f *= j;
    // here remove ending zeroes from f
    f %= base;
}
Do not forget that the variable f has to be big enough for base^2,
and base has to be at least 2 digits bigger than 100000 to cover 5 digits and the overflow to zero.
base must be a power of 10 to preserve decimal digits.
[edit1] implementation
uint<2> f, i, j, n, base; // my 64-bit unsigned ints (I use a 32-bit compiler/app)
base = "10000000000";     // base >= 100000^2 ... must be a string to avoid 32-bit truncation
n = "20";                 // f(n) ... must be a string to avoid 32-bit truncation
for (f = 1, i = 1; i <= n; i++)
{
    j = i % base;
    for (; (j) && ((j % 10).iszero()); j /= 10);
    f *= j;
    for (; (f) && ((f % 10).iszero()); f /= 10);
    f %= base;
}
f %= 100000;
int s = f.a[1]; // export low 32-bit part of 64-bit uint (s is the result)
It is too slow :(
f(1000000)=12544 [17769.414 ms]
f( 20)=17664 [ 0.122 ms]
f( 10)=36288 [ 0.045 ms]
for more speed, use any fast factorial implementation
[edit2] just a few more 32-bit n! factorials for testing
this statement is not valid :(
//You could attempt to exploit that
//f(n) = ( f(n%base) * (f(base)^floor(n/base)) )%base
//do not forget that this is true only if base fulfill the conditions above
luckily this one seems to be true :) but only if a is much, much bigger than b and a % base == 0
g((a+b)!)=g(g(a!)*g(b!))
// g mod base without last zeroes...
// this can speed up things a lot
f( 1)=00001
f( 10)=36288
f( 100)=16864
f( 1,000)=53472
f( 10,000)=79008
f( 100,000)=56096
f( 1,000,000)=12544
f( 10,000,000)=28125
f( 1,000,100)=42016
f( 1,000,100)=g(??????12544*??????16864)=g(??????42016)->42016
the closer a is to b, the fewer valid digits there are!!!
that is why f(1001000) will not work ...
I'm not an expert Project Euler solver, but here is some general advice for all Euler problems.
1 - Start by solving the problem in the most obvious way first. This may lead to insights for later attempts
2 - Work the problem for a smaller range. Euler usually give an answer for the smaller range that you can use to check your algorithm
3 - Scale up the problem and work out how the problem will scale, time-wise, as the problem gets bigger
4 - If the solution is going to take longer than a few minutes, it's time to check the algorithm and come up with a better way
5 - Remember that Euler problems always have an answer and rely on a combination of clever programming and clever mathematics
6 - A problem that has been solved by many people cannot be wrong, it's you that's wrong!
I recently solved the phidigital number problem (Euler's site is down, I can't look up the number; it's quite recent at time of posting) using exactly these steps. My initial brute-force algorithm was going to take 60 hours; I took a look at the patterns that solving to 1,000,000 showed, and got the insight to find a solution that took 1.25s.
It might be an idea to deal with numbers ending in 2, 4, 5, 6, 8, 0 separately. Numbers ending in 1, 3, 7, 9 cannot contribute to trailing zeros. Let
A(n) = 1 * 3 * 7 * 9 * 11 * 13 * 17 * 19 * ... * (n-1).
B(n) = 2 * 4 * 5 * 6 * 8 * 10 * 12 * 14 * 15 * 16 * 18 * 20 * ... * n.
The factorial of n is A(n)*B(n). We can find the last five digits of A(n) quite easily. First find A(100,000) MOD 100,000; we can make this easier by doing all the multiplications mod 100,000. Note that A(200,000) MOD 100,000 is just A(100,000)*A(100,000) MOD 100,000, as 100,001 = 1 MOD 100,000, etc. So A(1,000,000,000,000) is just A(100,000)^10,000,000 MOD 100,000.
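As a rough sketch of that exponentiation step in C (powmod is an illustrative helper; the intermediate products stay below 10^10, so unsigned 64-bit arithmetic is enough, but I haven't verified the printed value against the actual answer):

#include <stdio.h>

/* modular exponentiation by repeated squaring */
static unsigned long long powmod(unsigned long long b, unsigned long long e,
                                 unsigned long long m)
{
    unsigned long long r = 1;
    b %= m;
    while (e) {
        if (e & 1)
            r = (r * b) % m;
        b = (b * b) % m;
        e >>= 1;
    }
    return r;
}

int main(void)
{
    /* A(100,000): product of the numbers below 100,000 ending in 1, 3, 7, 9 */
    unsigned long long a = 1;
    for (unsigned long long i = 1; i < 100000; i += 2)
        if (i % 5 != 0)
            a = (a * i) % 100000;
    /* A(1,000,000,000,000) MOD 100,000 = A(100,000)^10,000,000 MOD 100,000 */
    printf("%llu\n", powmod(a, 10000000ULL, 100000ULL));
    return 0;
}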
More care is needed with 2, 4, 5, 6, 8, 0: you'll need to track when these add a trailing zero. Obviously, whenever we multiply numbers ending in 2 and 5 together we will end up with a zero. However, there are cases when you can get two zeros: 25*4 = 100.
So, I would like to convert a part of C code to Haskell. I wrote this part (it's a simplified example of what I want to do) in C, but being the newbie I am in Haskell, I can't really make it work.
float g(int n, float a, float p, float s)
{
    int c;
    while (n > 0)
    {
        c = n % 2;
        if (!c) s += p;
        else s -= p;
        p *= a;
        n--;
    }
    return s;
}
Anyone got any ideas/solutions?
Lee's translation is already pretty good (well, he confused the odd and even cases(1)), but he fell into a couple of performance traps.
g n a p s =
    if n > 0
      then
        let c  = n `mod` 2
            s' = (if c == 0 then (-) else (+)) s p
            p' = p * a
        in g (n-1) a p' s'
      else s
He used mod instead of rem. The latter maps to machine division, the former performs additional checks to ensure a non-negative result. Thus mod is a bit slower than rem, and if either satisfies the needs - because they yield identical results in the case where both arguments are non-negative; or because the result is only compared to 0 (both conditions are satisfied here) - rem is preferable. Even better, and a bit more idiomatic is to use even (which uses rem for the reasons mentioned above). The difference is not huge, though.
No type signature. That means that the code is (type-class) polymorphic, and thus no strictness analysis is possible, nor any specialisations. If the code is used in the same module at a specific type, GHC can (and usually will, if optimisations are enabled) create a specialised version for that specific type that allows strictness analysis and some other optimisations (inlining of class methods like (+) etc.); in that case, one does not pay the polymorphism penalty. But if the use site is in a different module, that cannot happen. If (type-class) polymorphic code is desired, one should mark it INLINABLE or INLINE (for GHC < 7), so that its unfolding is exposed in the .hi file and the function can be specialised and optimised at the use site.
Since g is recursive, it cannot be inlined [meaning, GHC cannot inline it; in principle it is possible] at use sites, which often would enable more optimisations than a mere specialisation.
One technique that often allows better optimisation for recursive functions is the worker/wrapper transformation. One creates a wrapper that calls a recursive (local) worker, then the non-recursive wrapper can be inlined, and when the worker is called with known arguments, that can enable further optimisations like constant folding or, in the case of function arguments, inlining. In particular the latter often has an enormous impact, when combined with a static-argument-transformation (arguments that never change in the recursive calls are not passed as arguments to the recursive worker).
In this case, we only have one static argument of type Float, so a worker/wrapper transformation with a SAT typically makes no difference (as a rule of thumb, a SAT pays off when
the static argument is a function
several non-function arguments are static
so by this rule, we shouldn't expect any benefit from w/w + SAT, and in general, there is none). Here we have one special case where w/w + SAT can make a big difference, and that is when the factor a is 1. GHC has {-# RULES #-} that eliminate multiplication by 1 for various types, and with such a short loop body, a multiplication more or less per iteration makes a difference, the running time is reduced by about 40% after points 3 and 4 have been applied. (There are no RULES for multiplication by 0 or by -1 for floating point types because 0*x = 0 resp. (-1)*x = -x don't hold for NaNs.) For all other a, the w/w + SATed
{-# INLINABLE g #-}
g n a p s = worker n p s
  where
    worker n p s
        | n <= 0    = s
        | otherwise = let s' = if even n then s + p else s - p
                      in worker (n-1) (p*a) s'
does not perform measurably different from the top-level recursive version with the same optimisations done.
Strictness. GHC's strictness analyser is good, but not perfect. It cannot see far enough through the algorithm to determine that the function is
strict in p if n >= 1 (assuming addition - (+) - is strict in both arguments)
also strict in a if n >= 2 (assuming strictness of (*) in both arguments)
and then produce a worker that is strict in both. Instead you get a worker that uses an unboxed Int# for n and an unboxed Float# for s (I'm using the type Int -> Float -> Float -> Float -> Float here, corresponding to the C), and boxed Floats for a and p. Thus in each iteration you get two unboxings and a re-boxing. That costs (relatively) a lot of time, since besides that it's just a bit of simple arithmetic and tests.
Help GHC along a bit, and make the worker (or g itself, if you don't do the worker/wrapper transform) strict in p (bang pattern for example). That is enough to allow GHC producing a worker using unboxed values throughout.
Using division to test parity (not applicable if the type is Int and the LLVM backend is used).
GHC's optimiser hasn't got down to the low-level bits very much yet, so the native code generator emits a division instruction for
x `rem` 2 == 0
and, when the rest of the loop body is as cheap as it is here, that costs a lot of time. LLVM's optimiser has already been taught to replace that with a bitmasking at type Int, so with ghc -O2 -fllvm you don't need to do that manually. With the native code generator, substituting that with
x .&. 1 == 0
(needs import Data.Bits of course) produces a significant speedup (on normal platforms where a bitwise and is much faster than a division).
The final result
{-# INLINABLE g #-}
g n a p s = worker n p s
  where
    worker k !ap acc
        | k > 0     = worker (k-1) (ap*a) (if k .&. (1 :: Int) == 0 then acc + ap else acc - ap)
        | otherwise = acc
performs not measurably different (for the tested values) from the result of gcc -O3 -msse2 loop.c, except for a = -1, where gcc replaces the multiplication with a negation (assuming all NaNs equivalent).
(1) He's not alone in that,
c = n % 2;
if (!c) s += p;
else s -= p;
seems to be really tricky; as far as I can see, everybody(2) got it wrong.
(2) With one exception ;)
As a first step, let's simplify your code:
float g(int n, float a, float p, float s) {
    if (n <= 0) return s;
    float s2 = n % 2 == 0 ? s + p : s - p;
    return g(n - 1, a, a*p, s2);
}
We have turned your original function into a recursive one that exhibits a certain structure. It's a sequence! We can turn this into Haskell conveniently:
gs :: Bool -> Float -> Float -> Float -> [Float]
gs nb a p s = s : gs (not nb) a (a*p) (if nb then s - p else s + p)
Finally we just need to index this list:
g :: Int -> Float -> Float -> Float -> Float
g n a p s = gs (even n) a p s !! (n - 1)
The code is not tested, but it should work. If not, it's probably just an off-by-one error.
Here is how I would tackle this problem in Haskell. First, I observe that there are several loops merged into one here: we are
forming a geometric sequence (whose factor is a suitably negative version of p)
taking a prefix of the sequence
summing the result
So my solution follows this structure as well, with a tiny bit of s and p thrown in for good measure because that's what your code does. In a from-scratch version, I'd probably drop those two parameters entirely.
g n a p s = sum (s : take n (iterate (*(-a)) start)) where
    start | odd n     = -p
          | otherwise = p
A fairly direct translation would be:
g n a p s =
    if n > 0
      then
        let c  = n `mod` 2
            s' = (if c == 0 then (-) else (+)) s p
            p' = p * a
        in g (n-1) a p' s'
      else s
Look at the signature of the g function (i.e., float g(int n, float a, float p, float s)): you know that your Haskell function will receive 4 arguments and return a Float, thus:
g :: Integer -> Float -> Float -> Float -> Float
Let us now look into the loop: it runs while n > 0, so n <= 0 is the stop case, and n--; will be the decreasing step used in the recursive call. Therefore:
g :: Integer -> Float -> Float -> Float -> Float
g n a p s | n <= 0 = s
For n > 0, you have another conditional, if (!(n % 2)) s += p; else s -= p;, inside the loop. If n is odd then you will do s += p, p *= a and n--. In Haskell it will be:
g :: Integer -> Float -> Float -> Float -> Float
g n a p s | n <= 0 = s
          | odd n  = g (n-1) a (p*a) (s+p)
If n is even then you will do s -= p, p *= a and n--. Thus:
g :: Integer -> Float -> Float -> Float -> Float
g n a p s | n <= 0    = s
          | odd n     = g (n-1) a (p*a) (s+p)
          | otherwise = g (n-1) a (p*a) (s-p)
To expand on #Landei's and #MathematicalOrchid's comments below the question: the algorithm proposed to solve the problem at hand is always O(n). However, if you realize that what you're actually doing is computing a partial sum of a geometric series, you can use the well-known summation formula:
g n a p s = s + (-1)**n * p * ((-a)**n-1) / (-a-1)
This will be faster as the exponentiation can be done faster than O(n) by repeated squaring or other clever methods, which are likely automatically employed for integer powers by modern compilers.
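For illustration, a direct C transcription of that formula (a sketch: it assumes a != -1, where the denominator vanishes, and relies on pow being exact enough for integral exponents; g_closed is an illustrative name):

#include <math.h>

/* closed form of the loop: s plus a partial sum of a geometric series */
float g_closed(int n, float a, float p, float s)
{
    float start = (n % 2 != 0) ? -p : p; /* sign of the first term from the parity of n */
    return s + start * (powf(-a, (float)n) - 1.0f) / (-a - 1.0f);
}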
You can encode loops almost-naturally with the Haskell Prelude function until :: (a -> Bool) -> (a -> a) -> a -> a:
g :: Int -> Float -> Float -> Float -> Float
g n a p s =
    fst . snd $
    until ((<= 0) . fst)
          (\(n,(!s,!p)) -> (n-1, (if even n then s+p else s-p, p*a)))
          (n,(s,p))
The bang-patterns !s and !p mark strictly-calculated intermediate variables, to prevent excessive laziness which would otherwise harm efficiency.
until pred step start repeatedly applies the step function until pred called with the last generated value will hold, starting with initial value start. It can be represented by the pseudocode:
def until (pred, step, start):
    while( true ):
        if pred(start): return(start)
        start := step(start)

// well, actually,

def until (pred, step, start):
    if pred(start): return(start)
    call until(pred, step, step(start))
The first pseudocode is equivalent to the second (which is how until is actually implemented) in the presence of tail call optimization, which is why in many functional languages where TCO is present loops are encoded via recursion.
So in Haskell, until is coded as
until p f x | p x       = x
            | otherwise = until p f (f x)
But it could have been coded differently, making explicit the interim results:
until p f x = last $ go x -- or, last (go x)
  where go x | p x       = [x]
             | otherwise = x : go (f x)
Using the Haskell standard higher-order functions break and iterate this could be written as a stream-processing code,
until p f x = let (_,(r:_)) = break p (iterate f x) in r
-- or: span (not.p) ....
or just
until p f x = head $ dropWhile (not.p) $ iterate f x -- or, equivalently,
-- head . dropWhile (not.p) . iterate f $ x
If TCO weren't present in a given Haskell implementation, the last version would be the one to use.
Hopefully this makes clearer how the stream-processing code from Daniel Wagner's answer comes about,
g n a p s = s + (sum . take n . iterate (*(-a)) $ if odd n then (-p) else p)
because the predicate involved is about counting down from n, and
fst . snd . head . dropWhile ((> 0).fst) $
iterate (\(n,(!s,!p)) -> (n-1, (if even n then s+p else s-p, p*a)))
(n,(s,p))
===
fst . snd . head . dropWhile ((> 0).fst) $
iterate (\(n,(!s,!p)) -> (n-1, (s+p, p*(-a))))
(n,(s, if odd n then (-p) else p)) -- 0 is even
===
fst . (!! n) $
iterate (\(!s,!p) -> (s+p, p*(-a)))
(s, if odd n then (-p) else p)
===
foldl' (+) s . take n . iterate (*(-a)) $ if odd n then (-p) else p
In pure FP, the stream-processing paradigm makes all history of a computation available, as a stream (list) of values.