Z3 quantified formula with implication giving unsat - arrays

I'm still new to Z3 and don't understand why the formula below is unsat. It should be sat at least for a ts_var array in which each of the 32 elements (32-bit bitvectors) has a single 1 bit in a different position and zeros in all other positions, so that every bvxor result is different. Any advice or hints about what I'm doing wrong?
UPDATE: When I reversed the implication (=> a!1 a!2) from the way it is written in the code, Z3 produced sat, but that is not what I want. I want to find an array in which different combinations of two elements give different results when XORed together. That is the implication in the code, which still gives unsat.
(assert (exists ((ts_var (Array (_ BitVec 5) (_ BitVec 32))))
(forall ((k (_ BitVec 5)) (l (_ BitVec 5)) (m (_ BitVec 5)) (n (_ BitVec 5)))
(let ((a!1 (and (not (= k l))
(not (= n m))
(=> (= k m) (not (= l n)))
(=> (= l n) (not (= k m)))))
(a!2 (not (= (bvxor (select ts_var k) (select ts_var l))
(bvxor (select ts_var m) (select ts_var n))))))
(=> a!1 a!2)
)
)
)
)
(check-sat)
I originally wrote the code which gave this result using the C-API:
Z3_ast mk_var(Z3_context ctx, const char * name, Z3_sort ty)
{
Z3_symbol s = Z3_mk_string_symbol(ctx, name);
return Z3_mk_const(ctx, s, ty);
}
bv_w_sort = Z3_mk_bv_sort (ctx, 32);
index_w_sort = Z3_mk_bv_sort (ctx, 5);
array_sort = Z3_mk_array_sort(ctx, index_w_sort, bv_w_sort);
ts_var = mk_var(ctx, "ts_var" , array_sort);
fp1 = mk_var(ctx, "fp1" , bv_w_sort);
fp2 = mk_var(ctx, "fp2" , bv_w_sort);
fp1 = Z3_mk_bvxor(ctx, Z3_mk_select(ctx, ts_var, k) , Z3_mk_select(ctx, ts_var, l) );
fp2 = Z3_mk_bvxor(ctx, Z3_mk_select(ctx, ts_var, m) , Z3_mk_select(ctx, ts_var, n) );
cond_uniq = Z3_mk_not (ctx,Z3_mk_eq (ctx, fp1, fp2) );
cond_k_neq_l = Z3_mk_not (ctx,Z3_mk_eq (ctx, k, l));
cond_n_neq_m = Z3_mk_not (ctx,Z3_mk_eq (ctx, n, m));
cond_l_neq_n = Z3_mk_not (ctx,Z3_mk_eq (ctx, l, n));
cond_k_neq_m = Z3_mk_not (ctx,Z3_mk_eq (ctx, k, m));
cond_k_eq_m = Z3_mk_eq (ctx, k, m);
cond_l_eq_n = Z3_mk_eq (ctx, l, n);
cond_imply1 = Z3_mk_implies (ctx, cond_k_eq_m, cond_l_neq_n);
cond_imply2 = Z3_mk_implies (ctx, cond_l_eq_n, cond_k_neq_m);
args[0]= cond_k_neq_l;
args[1]= cond_n_neq_m;
args[2]= cond_imply1;
args[3]= cond_imply2;
exp4 = Z3_mk_and(ctx, 4, args);
bound[0] = (Z3_app) k;
bound[1] = (Z3_app) l;
bound[2] = (Z3_app) m;
bound[3] = (Z3_app) n;
bound4[0]= (Z3_app)ts_var;
exp2 = Z3_mk_implies(ctx, exp4, cond_uniq);
exp1 = Z3_mk_forall_const(ctx, 0, 4, bound, 0, 0, exp2);
q = Z3_mk_exists_const(ctx, 0, 1, bound4, 0, 0, exp1);
Z3_solver_assert(ctx, s, q);
I'm also not sure whether I have to use patterns over the variables, as suggested here: Does Z3 support variable-only patterns in quantified formulas?
But according to what I read in this tutorial, http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.225.8231&rep=rep1&type=pdf, it seems OK not to use any patterns, right?

The way you pick k, l, m, and n allows symmetries. For instance:
k = 0
l = 1
m = 1
n = 0
satisfies your condition a!1, but it picks the same two elements of ts_var in swapped order. Since bvxor is commutative, both XOR results are equal, which makes a!2 false. Hence your entire query becomes unsat.
You can replace the definition of your a!1 with the following:
(a!1 (distinct k l m n))
which concisely states that these four variables are all different. With that change, Z3 does indeed find a model.
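The counterexample can also be checked without a solver. The sketch below is plain Python rather than Z3; the names a1_original, a1_distinct, a2, holds_for_all and one_hot are hypothetical helpers introduced only for this illustration, and one_hot is the candidate array described in the question (element i has only bit i set).

```python
from itertools import product

def a1_original(k, l, m, n):
    """The question's premise a!1, written with Python booleans."""
    return (k != l and n != m
            and (k != m or l != n)      # (=> (= k m) (not (= l n)))
            and (l != n or k != m))     # (=> (= l n) (not (= k m)))

def a1_distinct(k, l, m, n):
    """The suggested replacement: (distinct k l m n)."""
    return len({k, l, m, n}) == 4

def a2(ts, k, l, m, n):
    """The conclusion a!2: the two XOR results differ."""
    return (ts[k] ^ ts[l]) != (ts[m] ^ ts[n])

def holds_for_all(premise, ts):
    """True if premise => a!2 for every choice of four indices into ts."""
    idx = range(len(ts))
    return all(a2(ts, k, l, m, n)
               for k, l, m, n in product(idx, repeat=4)
               if premise(k, l, m, n))

# Candidate model from the question: element i has only bit i set.
one_hot = [1 << i for i in range(32)]

print(a1_original(0, 1, 1, 0))              # True: the swapped pair passes a!1
print(a2(one_hot, 0, 1, 1, 0))              # False: XOR is commutative
print(holds_for_all(a1_distinct, one_hot))  # True
```

This only confirms that the one-hot array is a witness for the distinct version; it says nothing about which model Z3 itself returns.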


Calculating multiples in Haskell (conversion from C)? [closed]

I'm trying to write a Haskell program that counts multiples. Given two integers a and b, I want to find how many integers 1 ≤ bi ≤ b are a multiple of some integer 2 ≤ ai ≤ a. For example, if a = 3 and b = 30, I want to know how many integers in the range 1-30 are a multiple of 2 or 3; there are 20 such integers: 2, 3, 4, 6, 8, 9, 10, 12, 14, 15, 16, 18, 20, 21, 22, 24, 26, 27, 28, 30.
I have a C program that does this. I'm trying to get this translated into Haskell, but part of the difficulty is getting around the loops that I've used since Haskell doesn't use loops. I appreciate any and all help in translating this!
My C program for reference (sorry if formatting is off):
#include <stdio.h>
#include <math.h>
#define PRIME_RANGE 130
#define PRIME_CNT 32
#define UPPER_LIMIT (1000000000000000ull) //10^15
#define MAX_BASE_MULTIPLES_COUNT 25000000
typedef struct
{
char primeFactorFlag;
long long multiple;
}multipleInfo;
unsigned char primeFlag[PRIME_RANGE + 1];
int primes[PRIME_CNT];
int primeCnt = 0;
int maxPrimeStart[PRIME_CNT];
multipleInfo baseMultiples[MAX_BASE_MULTIPLES_COUNT];
multipleInfo mergedMultiples[MAX_BASE_MULTIPLES_COUNT];
int baseMultiplesCount, mergedMultiplesCount;
void findOddMultiples(int a, long long b, long long *count);
void generateBaseMultiples(void);
void mergeLists(multipleInfo listSource[], int countS, multipleInfo
listDest[], int *countD);
void sieve(void);
int main(void)
{
int i, j, a, n, startInd, endInd;
long long b, multiples;
//Generate primes
sieve();
primes[primeCnt] = PRIME_RANGE + 1;
generateBaseMultiples();
baseMultiples[baseMultiplesCount].multiple = UPPER_LIMIT + 1;
//Input and Output
scanf("%d", &n);
for(i = 1; i <= n; i++)
{
scanf("%d%lld", &a, &b);
//If b <= a, all are multiple except 1
if(b <= a)
printf("%lld\n",b-1);
else
{
//Add all even multiples
multiples = b / 2;
//Add all odd multiples
findOddMultiples(a, b, &multiples);
printf("%lld\n", multiples);
}
}
return 0;
}
void findOddMultiples(int a, long long b, long long *count)
{
int i, k;
long long currentNum;
for(k = 1; k < primeCnt && primes[k] <= a; k++)
{
for(i = maxPrimeStart[k]; i < maxPrimeStart[k + 1] &&
baseMultiples[i].multiple <= b; i++)
{
currentNum = b/baseMultiples[i].multiple;
currentNum = (currentNum + 1) >> 1; // remove even multiples
if(baseMultiples[i].primeFactorFlag) //odd number of factors
(*count) += currentNum;
else
(*count) -= currentNum;
}
}
}
void addTheMultiple(long long value, int primeFactorFlag)
{
baseMultiples[baseMultiplesCount].multiple = value;
baseMultiples[baseMultiplesCount].primeFactorFlag = primeFactorFlag;
baseMultiplesCount++;
}
void generateBaseMultiples(void)
{
int i, j, t, prevCount;
long long curValue;
addTheMultiple(3, 1);
mergedMultiples[0] = baseMultiples[0];
mergedMultiplesCount = 1;
maxPrimeStart[1] = 0;
prevCount = mergedMultiplesCount;
for(i = 2; i < primeCnt; i++)
{
maxPrimeStart[i] = baseMultiplesCount;
addTheMultiple(primes[i], 1);
for(j = 0; j < prevCount; j++)
{
curValue = mergedMultiples[j].multiple * primes[i];
if(curValue > UPPER_LIMIT)
break;
addTheMultiple(curValue, 1 - mergedMultiples[j].primeFactorFlag);
}
if(i < primeCnt - 1)
mergeLists(&baseMultiples[prevCount], baseMultiplesCount - prevCount, mergedMultiples, &mergedMultiplesCount);
prevCount = mergedMultiplesCount;
}
maxPrimeStart[primeCnt] = baseMultiplesCount;
}
void mergeLists(multipleInfo listSource[], int countS, multipleInfo listDest[], int *countD)
{
int limit = countS + *countD;
int i1, i2, j, k;
//Copy one list in unused safe memory
for(j = limit - 1, k = *countD - 1; k >= 0; j--, k--)
listDest[j] = listDest[k];
//Merge the lists
for(i1 = 0, i2 = countS, k = 0; i1 < countS && i2 < limit; k++)
{
if(listSource[i1].multiple <= listDest[i2].multiple)
listDest[k] = listSource[i1++];
else
listDest[k] = listDest[i2++];
}
while(i1 < countS)
listDest[k++] = listSource[i1++];
while(i2 < limit)
listDest[k++] = listDest[i2++];
*countD = k;
}
void sieve(void)
{
int i, j, root = sqrt(PRIME_RANGE);
primes[primeCnt++] = 2;
for(i = 3; i <= PRIME_RANGE; i+= 2)
{
if(!primeFlag[i])
{
primes[primeCnt++] = i;
if(root >= i)
{
for(j = i * i; j <= PRIME_RANGE; j += i << 1)
primeFlag[j] = 1;
}
}
}
}
First, unless I'm grossly misunderstanding, the number of multiples you have there is wrong. The number of multiples of 2 between 1 and 30 is 15, and the number of multiples of 3 between 1 and 30 is 10, so there should be 25 numbers there.
EDIT: I did misunderstand; you want unique multiples.
To get unique multiples, you can use Data.Set, which maintains the invariant that the elements of the Set are unique and kept in ascending order.
If you know your values aren't going to exceed maxBound :: Int, you can get even better speedups using Data.IntSet. I've also included some test cases, annotated with comments showing how long they take on my machine.
{-# LANGUAGE BangPatterns #-}
{-# OPTIONS_GHC -O2 #-}
module Main (main) where
import System.CPUTime (getCPUTime)
import Data.IntSet (IntSet)
import qualified Data.IntSet as IntSet
main :: IO ()
main = do
  test 3 30 -- 0.12 ms
  test 131 132 -- 0.14 ms
  test 500 300000 -- 117.63 ms
test :: Int -> Int -> IO ()
test !a !b = do
  start <- getCPUTime
  print (numMultiples a b)
  end <- getCPUTime
  print $ "Needed " ++ show ((fromIntegral (end - start)) / 10^9) ++ " ms.\n"
numMultiples :: Int -> Int -> Int
numMultiples !a !b = IntSet.size (foldMap go [2..a])
  where
    go :: Int -> IntSet
    go !x = IntSet.fromAscList [x, x+x .. b]
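For readers who want to cross-check the counts outside Haskell, the same set-union idea can be sketched in Python. The function name num_multiples is an assumption made for this example; it mirrors numMultiples above but is not a translation of the timing harness.

```python
def num_multiples(a, b):
    """Count integers in 1..b divisible by at least one of 2..a."""
    seen = set()
    for x in range(2, a + 1):
        seen.update(range(x, b + 1, x))   # x, 2x, 3x, ... up to b
    return len(seen)

print(num_multiples(3, 30))   # 20, matching the example in the question
```

Like the IntSet version, this materialises every multiple, so memory grows with b.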
I'm not really into understanding your C, so I implemented a solution afresh using the algorithm discussed here. The N in the linked algorithm is the product of the primes up to a in your problem description.
So first we'll need a list of primes. There's a standardish trick for getting a list of primes that is at once very idiomatic and relatively efficient:
primes :: [Integer]
primes = 2:filter isPrime [3..]
-- Doesn't work right for n<2, but we never call it there, so who cares?
isPrime :: Integer -> Bool
isPrime n = go primes n where
  go (p:ps) n | p*p>n = True
              | otherwise = n `rem` p /= 0 && go ps n
Next up: we want a way to iterate over the positive square-free divisors of N. This can be achieved by iterating over the subsets of the primes no larger than a. There's a standard idiomatic way to get a powerset, namely:
-- import Control.Monad
-- powerSet :: [a] -> [[a]]
-- powerSet = filterM (const [False, True])
That would be a fine component to use, but since at the end of the day we only care about the product of each powerset element and the value of the Mobius function of that product, we would end up duplicating a lot of multiplications and counting problems. It's cheaper to compute those two things directly while producing the powerset. So:
-- Given the prime factorization of a square-free number, produce a list of
-- its divisors d together with mu(d).
divisorsWithMu :: Num a => [a] -> [(a, a)]
divisorsWithMu [] = [(1, 1)]
divisorsWithMu (p:ps) = rec ++ [(p*d, -mu) | (d, mu) <- rec] where
  rec = divisorsWithMu ps
With that in hand, we can just iterate and do a little arithmetic.
f :: Integer -> Integer -> Integer
f a b = b - sum
  [ mu * (b `div` d)
  | (d, mu) <- divisorsWithMu (takeWhile (<=a) primes)
  ]
And that's all the code. Crunched 137 lines of C down to 15 lines of Haskell -- not bad! Try it out in ghci:
> f 3 30
20
As an additional optimization, one could consider modifying divisorsWithMu to short-circuit when its divisor is bigger than b, as we know such terms will not contribute to the final sum. This makes a noticeable difference for large a, as without it there are exponentially many elements in the powerset. Here's how that modification looks:
-- Given an upper bound and the prime factorization of a square-free number,
-- produce a list of its divisors d that are no larger than the upper bound
-- together with mu(d).
divisorsWithMuUnder :: (Ord a, Num a) => a -> [a] -> [(a, a)]
divisorsWithMuUnder n [] = [(1, 1)]
divisorsWithMuUnder n (p:ps) = rec ++ [(p*d, -mu) | (d, mu) <- rec, p*d<=n]
  where rec = divisorsWithMuUnder n ps
f' :: Integer -> Integer -> Integer
f' a b = b - sum
  [ mu * (b `div` d)
  | (d, mu) <- divisorsWithMuUnder b (takeWhile (<=a) primes)
  ]
Not much more complicated; the only really interesting difference is that there's now a condition in the list comprehension. Here's an example of f' finishing quickly for inputs that would take infeasibly long with f:
> f' 100 100000
88169
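The inclusion-exclusion step may be easier to follow with the arithmetic spelled out. Below is a rough Python sketch of the pruned version; primes_up_to, divisors_with_mu_under and count_multiples are names made up for this illustration, and the trial-division prime list is a simplification rather than the lazy list used above.

```python
def primes_up_to(a):
    """Primes p with 2 <= p <= a, by trial division (fine for small a)."""
    found = []
    for n in range(2, a + 1):
        if all(n % p for p in found if p * p <= n):
            found.append(n)
    return found

def divisors_with_mu_under(bound, primes):
    """Pairs (d, mu(d)) for products d <= bound of distinct primes."""
    pairs = [(1, 1)]
    for p in primes:
        # Each new prime flips the sign of mu; skip products above the bound.
        pairs += [(p * d, -mu) for d, mu in pairs if p * d <= bound]
    return pairs

def count_multiples(a, b):
    """Integers in 1..b with at least one prime factor <= a."""
    coprime = sum(mu * (b // d)
                  for d, mu in divisors_with_mu_under(b, primes_up_to(a)))
    return b - coprime

print(count_multiples(3, 30))   # 20 = 30 - (30 - 15 - 10 + 5)
```

For a = 3 and b = 30 the divisors are 1, 2, 3 and 6, so the inner sum is 30 - 15 - 10 + 5 = 10 numbers coprime to 6, leaving 20 multiples.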
With the data-ordlist package mentioned by Daniel Wagner in the comments, it is just
-- import Data.List.Ordered (unionAll)
f a b = length $ unionAll [ [p,p+p..b] | p <- takeWhile (<= a) primes]
That is all. Some timings, for non-compiled code run inside GHCi:
~> f 100 (10^5)
88169
(0.05 secs, 48855072 bytes)
~> f 131 (3*10^6)
2659571
(0.55 secs, 1493586480 bytes)
~> f 131 132
131
(0.00 secs, 0 bytes)
~> f 500 300000
274055
(0.11 secs, 192704760 bytes)
Compiling will surely make the memory consumption a non-issue, by converting the length to a counting loop.
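The ordered-union idea carries over to other languages that can merge sorted streams lazily. As an assumption-laden sketch, the Python below uses heapq.merge in place of unionAll and merges the multiples of every integer in 2..a rather than only the primes, which gives the same count without needing a prime list. The name count_union is hypothetical.

```python
from heapq import merge
from itertools import groupby

def count_union(a, b):
    """Count distinct values across the ascending lists [x, 2x, ...] up to b."""
    streams = [range(x, b + 1, x) for x in range(2, a + 1)]
    merged = merge(*streams)          # lazy, ascending, still has duplicates
    # groupby collapses each run of equal values into a single group.
    return sum(1 for _value, _run in groupby(merged))

print(count_union(3, 30))   # 20
```

Because the merge is lazy, only one pending element per stream is held at a time instead of the whole union.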
You'll have to use recursion in place of loops.
In (most) procedural or object-oriented languages, recursion is usually avoided for plain iteration. It can be inefficient, as a new stack frame is normally created each time the recursive function is called.
However, in a functional language like Haskell, the compiler is often able to optimize the recursion away into a loop, which removes that overhead.
I've converted your sieve function into a set of recursive functions in C. I'll leave it to you to convert it into Haskell:
int main(void) {
//...
int root = sqrt(PRIME_RANGE);
primes[primeCnt++] = 2;
sieve(3, PRIME_RANGE, root);
//...
}
void sieve(int i, int end, int root) {
if(i > end) {
return;
}
if(!primeFlag[i]) {
primes[primeCnt++] = i;
if(root >= i) {
markMultiples(i * i, PRIME_RANGE, i);
}
}
i += 2;
sieve(i, end, root);
}
void markMultiples(int j, int end, int prime) {
if(j > end) {
return;
}
primeFlag[j] = 1;
j += prime << 1;
markMultiples(j, end, prime);
}
The point of recursion is that the same function is called repeatedly, until a condition is met. The results of one recursive call are passed onto the next call, until the condition is met.
Also, why are you bit-fiddling instead of just multiplying or dividing by 2? Any half-decent compiler these days can convert most multiplications and divisions by 2 into a bit-shift.

Functional way to find a pair of integers, which sum to X, in a sorted array

This is a follow-up to my previous question.
Suppose I want to find a pair of integers, which sum to a given number x, in a given sorted array. The well-known "one pass" solution looks like that:
def pair(a: Array[Int], target: Int): Option[(Int, Int)] = {
var l = 0
var r = a.length - 1
var result: Option[(Int, Int)] = None
while (l < r && result.isEmpty) {
(a(l), a(r)) match {
case (x, y) if x + y == target => result = Some(x, y)
case (x, y) if x + y < target => l = l + 1
case (x, y) if x + y > target => r = r - 1
}
}
result
}
How would you suggest writing this functionally, without any mutable state?
I guess I can write a recursive version with Stream (the lazy list in Scala).
Could you suggest a non-recursive version?
Here's a fairly straightforward version. It creates a Stream of Vectors that removes the first or last element on each iteration. Then we limit the size of the otherwise infinite Stream (-1 so you can't add a number with itself), then map it into the output format and check for the target condition.
def findSum(a: Vector[Int], target: Int): Option[(Int, Int)] = {
def stream = Stream.iterate(a){
xs => if (xs.head + xs.last > target) xs.init else xs.tail
}
stream.take (a.size - 1)
.map {xs => (xs.head, xs.last)}
.find {case (x,y) => x + y == target}
}
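The same "iterate a state, then search the resulting sequence" shape can be sketched in Python with generators. This is only an illustration under a few assumptions: iterate is a hand-written stand-in for Stream.iterate, find_sum is a hypothetical name, and the state is a pair of indices instead of a shrinking Vector.

```python
from itertools import islice

def iterate(step, state):
    """Yield state, step(state), step(step(state)), ... without end."""
    while True:
        yield state
        state = step(state)

def find_sum(a, target):
    """Return a pair from the sorted sequence a that sums to target, or None."""
    def step(bounds):
        lo, hi = bounds
        # Sum too large: drop the right end; otherwise drop the left end.
        return (lo, hi - 1) if a[lo] + a[hi] > target else (lo + 1, hi)

    # len(a) - 1 states keep lo < hi, so an element is never paired with itself.
    states = islice(iterate(step, (0, len(a) - 1)), max(len(a) - 1, 0))
    pairs = ((a[lo], a[hi]) for lo, hi in states)
    return next((p for p in pairs if sum(p) == target), None)

print(find_sum([1, 2, 3, 4, 5], 9))   # (4, 5)
```

The generator hides a loop internally, so this removes mutable state from the caller's code rather than from the program as a whole.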
There are a lot of gems hidden in the companion objects of Scala's collections, like Stream.iterate. I highly recommend checking them out. Knowing about them can greatly simplify a problem like this.
Here's a version that doesn't use indices (which I try to avoid unless there's an important computation with the value of them):
def findPair2(x: Int, a: Array[Int]): Option[(Int, Int)] = {
def findPairLoop(x: Int, l: Array[Int], r: Array[Int]): Option[(Int, Int)] = {
if (l.head >= r.last) None
else if (l.head + r.last == x) Some((l.head, r.last))
else if (l.head + r.last > x) findPairLoop(x, l, r.init)
else findPairLoop(x, l.tail, r)
}
findPairLoop(x, a, a)
}
It's recursive, but doesn't need Stream. tail and init are O(N) for Array, but if we use Lists and reverse the r collection to avoid init and last, an O(N) version can be written:
def findPairInOrderN(x: Int, a: Array[Int]): Option[(Int, Int)] = {
def findPairLoop(x: Int, l: List[Int], r: List[Int]): Option[(Int, Int)] = {
if (l.head >= r.head) None
else if (l.head + r.head == x) Some((l.head, r.head))
else if (l.head + r.head > x) findPairLoop(x, l, r.tail)
else findPairLoop(x, l.tail, r)
}
val l = a.toList
findPairLoop(x, l, l.reverse)
}
If we don't care about one-pass (or efficiency generally :)) it's a one liner
(for (m <-a ; n <- a if m + n == x) yield (m,n)).headOption
Unwrapping that into map and then using collectFirst gives us this, which is fairly neat and more efficient (but still not O(n)): it stops at the first correct pair but does more work than necessary to get there.
a.iterator.map(m => a.collectFirst { case n if n + m == x => (m, n) }).collectFirst { case Some(pair) => pair }
Without recursion and without a mutable state it can get pretty ugly. Here's my attempt:
def f(l: List[Int], x: Int): Option[(Int, Int)] = {
l.foldLeft(l.reverse) {
(list, first) =>
list.headOption.map {
last =>
first + last match {
case `x` => return Some(first, last)
case sum if sum < x => list
case sum if sum > x =>
val innerList = list.dropWhile(_ + first > x)
innerList.headOption.collect {
case r if r + first == x => return Some(first, r)
}.getOrElse {
innerList
}
}
}.getOrElse {
return None
}
}
None
}
Examples:
scala> f(List(1, 2, 3, 4, 5), 3)
res33: Option[(Int, Int)] = Some((1,2))
scala> f(List(1, 2, 3, 4, 5), 9)
res34: Option[(Int, Int)] = Some((4,5))
scala> f(List(1, 2, 3, 4, 5), 12)
res36: Option[(Int, Int)] = None
The .reverse in the beginning plus the foldLeft with a return when the result is found makes this O(2n).
This is a one liner
[ (x, y) | x <- array, y <- array, x+y == n ]
It even works on unsorted lists.
But if you want to take advantage of the sorting, just do a binary search for (n-x) for every x in the array instead of going through the array.
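A minimal sketch of that binary-search suggestion, written in Python with the standard bisect module. The name find_pair is made up for this example, and the search starts to the right of the current element so that a value is not paired with itself.

```python
from bisect import bisect_left

def find_pair(a, target):
    """Return a pair from the sorted list a summing to target, or None."""
    for i, x in enumerate(a):
        # Look for target - x only among the elements after position i.
        j = bisect_left(a, target - x, i + 1)
        if j < len(a) and a[j] == target - x:
            return (x, a[j])
    return None

print(find_pair([1, 2, 3, 4, 5], 9))   # (4, 5)
```

This runs in O(n log n), which is slower than the two-pointer pass but needs no second cursor.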

Optimizing mutable array state heavy manipulation code

I've been trying to complete this exercise on hackerrank in time.
But my following Haskell solution fails on test case 13 to 15 due to time out.
My Haskell solution
import Data.Vector(Vector(..),fromList,(!),(//),toList)
import Data.Vector.Mutable
import qualified Data.Vector as V
import Data.ByteString.Lazy.Char8 (ByteString(..))
import qualified Data.ByteString.Lazy.Char8 as L
import Data.ByteString.Lazy.Builder
import Data.Maybe
import Control.Applicative
import Data.Monoid
import Prelude hiding (length)
readInt' = fst . fromJust . L.readInt
toB [] = mempty
toB (x:xs) = string8 (show x) <> string8 " " <> toB xs
main = do
  [firstLine, secondLine] <- L.lines <$> L.getContents
  let [n,k] = map readInt' $ L.words firstLine
  let xs = largestPermutation n k $ fromList $ map readInt' $ Prelude.take n $ L.words secondLine
  L.putStrLn $ toLazyByteString $ toB $ toList xs
largestPermutation n k v
  | i >= l || k == 0 = v
  | n == x = largestPermutation (n-1) k v
  | otherwise = largestPermutation (n-1) (k-1) (replaceOne n x (i+1) (V.modify (\v' -> write v' i n) v))
  where l = V.length v
        i = l - n
        x = v!i
replaceOne n x i v
  | n == h = V.modify (\v' -> write v' i x ) v
  | otherwise = replaceOne n x (i+1) v
  where h = v!i
The best solution I've found keeps two arrays up to date: one is the main target array, and the other is used for fast index lookups.
Better Java solution
public static void main(String[] args) {
Scanner input = new Scanner(System.in);
int n = input.nextInt();
int k = input.nextInt();
int[] a = new int[n];
int[] index = new int[n + 1];
for (int i = 0; i < n; i++) {
a[i] = input.nextInt();
index[a[i]] = i;
}
for (int i = 0; i < n && k > 0; i++) {
if (a[i] == n - i) {
continue;
}
a[index[n - i]] = a[i];
index[a[i]] = index[n - i];
a[i] = n - i;
index[n - i] = i;
k--;
}
for (int i = 0; i < n; i++) {
System.out.print(a[i] + " ");
}
}
My questions are:
1. What's an elegant and fast implementation of this algorithm in Haskell?
2. Is there a faster way to solve this problem than the Java solution?
3. How should I deal with heavy array updates elegantly and yet efficiently in Haskell in general?
One optimization you can do to mutable arrays is not to use them at all. In particular, the problem you have linked to has a right fold solution.
The idea being that you fold the list and greedily swap the items with the largest value to the right and maintain swaps already made in a Data.Map:
import qualified Data.Map as M
import Data.Map (empty, insert)
solve :: Int -> Int -> [Int] -> [Int]
solve n k xs = foldr go (\_ _ _ -> []) xs n empty k
  where
    go x run i m k
      -- out of budget to do a swap or no swap necessary
      | k == 0 || y == i = y : run (pred i) m k
      -- make a swap and record the swap made in the map
      | otherwise = i : run (pred i) (insert i y m) (k - 1)
      where
        -- find the value current position is swapped with
        y = find x
        find k = case M.lookup k m of
          Just a -> find a
          Nothing -> k
In above, run is a function which given the reverse index i, current mapping m and the remaining swap budget k, solves the rest of the list onwards. By reverse index I mean indices of the list in the reverse direction: n, n - 1, ..., 1.
The folding function go, builds the run function at each step by updating values of i, m and k which are passed to the next step. At the end we call this function with initial parameters i = n, m = empty and initial swap budget k.
The recursive search in find can be optimized away by maintaining a reverse map, but this already performs much faster than the Java code you have posted.
Edit: The above solution still pays a logarithmic cost for each tree access. Here is an alternative solution using a mutable STUArray and the monadic fold foldM_, which in fact performs faster than the above:
import Control.Monad.ST (ST)
import Control.Monad (foldM_)
import Data.Array.Unboxed (UArray, elems, listArray, array)
import Data.Array.ST (STUArray, readArray, writeArray, runSTUArray, thaw)
-- first 3 args are the scope, which will be curried
swap :: STUArray s Int Int -> STUArray s Int Int -> Int
     -> Int -> Int -> ST s Int
swap _ _ _ 0 _ = return 0 -- out of budget to make a swap
swap arr rev n k i = do
  xi <- readArray arr i
  if xi + i == n + 1
    then return k -- no swap necessary
    else do -- make a swap, and reduce budget
      j <- readArray rev (n + 1 - i)
      writeArray rev xi j
      writeArray arr j xi
      writeArray arr i (n + 1 - i)
      return $ pred k
solve :: Int -> Int -> [Int] -> [Int]
solve n k xs = elems $ runSTUArray $ do
  arr <- thaw (listArray (1, n) xs :: UArray Int Int)
  rev <- thaw (array (1, n) (zip xs [1..]) :: UArray Int Int)
  foldM_ (swap arr rev n) k [1..n]
  return arr
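For comparison, the two-array greedy used by the Java and STUArray versions can be sketched in a few lines of Python. The function name largest_permutation and its argument order are assumptions for this example; the input is taken to be a permutation of 1..n, as in the problem statement.

```python
def largest_permutation(k, perm):
    """Greedy: with at most k swaps, move n, n-1, ... to the front."""
    a = list(perm)
    n = len(a)
    index = [0] * (n + 1)            # index[v] = current position of value v
    for pos, value in enumerate(a):
        index[value] = pos
    for i in range(n):
        if k == 0:
            break                    # swap budget exhausted
        want = n - i                 # largest value not yet placed
        if a[i] == want:
            continue                 # already in place, no swap spent
        j = index[want]
        a[i], a[j] = a[j], a[i]      # one swap puts `want` at position i
        index[a[j]], index[want] = j, i
        k -= 1
    return a

print(largest_permutation(1, [4, 2, 3, 5, 1]))   # [5, 2, 3, 4, 1]
```

Each position is visited once and every lookup is O(1), so the whole pass is linear in n.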
Not exactly an answer to #2, but there is a left fold solution that requires loading at most ~K values in memory at a time.
Because the problem deals with permutations, we know that 1 through N will appear in the output. If K > 0, at least the first K terms are going to be N, N-1, ..., N-K+1, because we can afford at least K swaps. In addition, we expect some (K/N) elements to be in their optimal position already.
This suggests an algorithm:
Initialize a map / dictionary and scan the input xs as zip xs [n, n-1..]. For every (x, i), if x /= i, we 'decrement' K and update our dictionary s.t. dct[i] = x. This procedure terminates when K == 0 (out of swaps) or we run out of input (can output {N, N-1, ... 1}).
Next, if we have any more x <- xs we look at each one and print x if x is not in our dictionary or dct[x] otherwise.
The above algorithm can fail to produce an optimal permutation only if our dictionary contains a cycle. In that case, we moved around elements with absolute value >= K using |cycle| swaps. But this means that we moved one element to its original position! So we can always save a swap on every cycle (i.e. increment K).
Finally, this gives the memory efficient algorithm.
Step 0: get N, K
Step 1: Read the input permutation and output {N, N-1, ... N-K-E}, N <- N - K - E, K <- 0, update dict as per above,
where E = number of elements X equal to N - (index of X)
Step 2: remove and count cycles from dict; let cycles = number of cycles; if cycles > 0, let K <- |cycles|, go to step 1,
else go to step 3. We can make this step more efficient by optimizing the dict.
Step 3: Output the rest of the input as is.
The following Python code implements the idea and can be made quite fast if better cycle detection is used. Of course, data better be read in chunks, unlike below.
from collections import deque
n, t = map(int, raw_input().split())
xs = deque(map(int, raw_input().split()))
dct = {}
cycles = True
while cycles:
    while t > 0 and xs:
        x = xs.popleft()
        if x != n:
            dct[n] = x
            t -= 1
        print n,
        n -= 1
    cycles = False
    for k, v in dct.items():
        visited = set()
        cycle = False
        while v in dct:
            if v in visited:
                cycle = True
                break
            visited.add(v)
            v, buf = dct[v], v
            dct[buf] = v
        if cycle:
            cycles = True
            for i in visited:
                del dct[i]
            t += 1
        else:
            dct[k] = v
while xs:
    x = xs.popleft()
    print dct.get(x, x),

Why is iterating through an array faster than Seq.find

I have an array sums that gives all the possible sums of a function f. This function accepts integers (say between 1 and 200, but same applies for say 1 and 10000) and converts them to double. I want to store sums as an array as I still haven't figured out how to do the algorithm I need without a loop.
Here's the code for how I generate sums:
let f n k = exp (double(k)/double(n)) - 1.0
let n = 200
let maxLimit = int(Math.Round(float(n)*1.5))
let FunctionValues = [|1..maxLimit|] |> Array.map (fun k -> f n k)
let sums = FunctionValues |> Array.map (fun i -> Array.map (fun j -> j + i) FunctionValues) |> Array.concat |> Array.sort
For certain elements of the array sums, I want to find the integers that, when passed into the function f and then added, equal that value in sums. I could store the integers alongside the sums, but I found that this destroys my memory.
Now I have two algorithms. Algorithm 1 uses a simple loop and a mutable int to store the values I care about. It shouldn't be very efficient since there isn't a break statement when it finds all the possible integers. I tried implementing Algorithm 2 that is more functional style, but I found it slower (~10% slower or 4200ms vs 4600ms with n = 10000), despite Seq being lazy. Why is this?
Algorithm 1:
let mutable a = 0
let mutable b = 0
let mutable c = 0
let mutable d = 0
for i in 1..maxLimit do
    for j in i..maxLimit do
        if sums.[bestI] = f n i + f n j then
            a <- i
            b <- j
        if sums.[bestMid] = f n i + f n j then
            c <- i
            d <- j
Algorithm 2:
let findNM x =
    let seq = {1..maxLimit} |> Seq.map (fun k -> (f n k, k))
    let get2nd3rd (a, b, c) = (b, c)
    seq |> Seq.map (fun (i, n) -> Seq.map (fun (j, m) -> (j + i, n, m) ) seq)
        |> Seq.concat |> Seq.find (fun (i, n, m) -> i = x)
        |> get2nd3rd
let digitsBestI = findNM sums.[bestI]
let digitsBestMid = findNM sums.[bestMid]
let a = fst digitsBestI
let b = snd digitsBestI
let c = fst digitsBestMid
let d = snd digitsBestMid
Edit: Note that the array sums is length maxLimit*maxLimit not length n. bestI and bestMid are then indices between 0 and maxLimit*maxLimit. For the purposes of this question they can be any number in that range. Their specific values are not particularly relevant.
I extended the OP's code a bit in order to profile it:
open System
let f n k = exp (double(k)/double(n)) - 1.0
let outer = 200
let n = 200
let maxLimit= int(Math.Round(float(n)*1.5))
let FunctionValues = [|1..maxLimit|] |> Array.map (fun k -> f n k)
let random = System.Random 19740531
let sums = FunctionValues |> Array.map (fun i -> Array.map (fun j -> j + i) FunctionValues) |> Array.concat |> Array.sort
let bests =
    [| for i in [1..outer] -> (random.Next (n, maxLimit*maxLimit), random.Next (n, maxLimit*maxLimit))|]
let stopWatch =
    let sw = System.Diagnostics.Stopwatch ()
    sw.Start ()
    sw
let timeIt (name : string) (a : int*int -> 'T) : unit =
    let t = stopWatch.ElapsedMilliseconds
    let v = a (bests.[0])
    for i = 1 to (outer - 1) do
        a bests.[i] |> ignore
    let d = stopWatch.ElapsedMilliseconds - t
    printfn "%s, elapsed %d ms, result %A" name d v
let algo1 (bestI, bestMid) =
    let mutable a = 0
    let mutable b = 0
    let mutable c = 0
    let mutable d = 0
    for i in 1..maxLimit do
        for j in i..maxLimit do
            if sums.[bestI] = f n i + f n j then
                a <- i
                b <- j
            if sums.[bestMid] = f n i + f n j then
                c <- i
                d <- j
    a,b,c,d
let algo2 (bestI, bestMid) =
    let findNM x =
        let seq = {1..maxLimit} |> Seq.map (fun k -> (f n k, k))
        let get2nd3rd (a, b, c) = (b, c)
        seq |> Seq.map (fun (i, n) -> Seq.map (fun (j, m) -> (j + i, n, m) ) seq)
            |> Seq.concat |> Seq.find (fun (i, n, m) -> i = x)
            |> get2nd3rd
    let digitsBestI = findNM sums.[bestI]
    let digitsBestMid = findNM sums.[bestMid]
    let a = fst digitsBestI
    let b = snd digitsBestI
    let c = fst digitsBestMid
    let d = snd digitsBestMid
    a,b,c,d
let algo3 (bestI, bestMid) =
    let rec find best i j =
        if best = f n i + f n j then i, j
        elif i = maxLimit && j = maxLimit then 0, 0
        elif j = maxLimit then find best (i + 1) 1
        else find best i (j + 1)
    let a, b = find sums.[bestI] 1 1
    let c, d = find sums.[bestMid] 1 1
    a, b, c, d
let algo4 (bestI, bestMid) =
    let rec findI bestI mid i j =
        if bestI = f n i + f n j then
            let x, y = mid
            i, j, x, y
        elif i = maxLimit && j = maxLimit then 0, 0, 0, 0
        elif j = maxLimit then findI bestI mid (i + 1) 1
        else findI bestI mid i (j + 1)
    let rec findMid ii bestMid i j =
        if bestMid = f n i + f n j then
            let x, y = ii
            x, y, i, j
        elif i = maxLimit && j = maxLimit then 0, 0, 0, 0
        elif j = maxLimit then findMid ii bestMid (i + 1) 1
        else findMid ii bestMid i (j + 1)
    let rec find bestI bestMid i j =
        if bestI = f n i + f n j then findMid (i, j) bestMid i j
        elif bestMid = f n i + f n j then findI bestI (i, j) i j
        elif i = maxLimit && j = maxLimit then 0, 0, 0, 0
        elif j = maxLimit then find bestI bestMid (i + 1) 1
        else find bestI bestMid i (j + 1)
    find sums.[bestI] sums.[bestMid] 1 1
[<EntryPoint>]
let main argv =
    timeIt "algo1" algo1
    timeIt "algo2" algo2
    timeIt "algo3" algo3
    timeIt "algo4" algo4
    0
The test results on my machine:
algo1, elapsed 438 ms, result (162, 268, 13, 135)
algo2, elapsed 1012 ms, result (162, 268, 13, 135)
algo3, elapsed 348 ms, result (162, 268, 13, 135)
algo4, elapsed 322 ms, result (162, 268, 13, 135)
algo1 uses the naive for-loop implementation. algo2 uses a more refined algorithm relying on Seq.find. I describe algo3 and algo4 later.
OP wondered why the naive algo1 performed better even though it does more work than algo2, which is based around the lazy Seq (essentially an IEnumerable<>).
The answer is that the Seq abstraction introduces overhead and prevents useful optimizations from occurring.
I usually resort to looking at the generated IL code in order to understand what's going on (there are many good decompilers for .NET, such as ILSpy).
Let's look at algo1 (decompiled to C#)
// Program
public static Tuple<int, int, int, int> algo1(int bestI, int bestMid)
{
int a = 0;
int b = 0;
int c = 0;
int d = 0;
int i = 1;
int maxLimit = Program.maxLimit;
if (maxLimit >= i)
{
do
{
int j = i;
int maxLimit2 = Program.maxLimit;
if (maxLimit2 >= j)
{
do
{
if (Program.sums[bestI] == Math.Exp((double)i / (double)200) - 1.0 + (Math.Exp((double)j / (double)200) - 1.0))
{
a = i;
b = j;
}
if (Program.sums[bestMid] == Math.Exp((double)i / (double)200) - 1.0 + (Math.Exp((double)j / (double)200) - 1.0))
{
c = i;
d = j;
}
j++;
}
while (j != maxLimit2 + 1);
}
i++;
}
while (i != maxLimit + 1);
}
return new Tuple<int, int, int, int>(a, b, c, d);
}
algo1 is then expanded to an efficient while loop. In addition f is inlined. The JITter is easily able to create efficient machine code from this.
When we look at algo2 unpacking the full structure is too much for this post so I focus on findNM
internal static Tuple<int, int> findNM#48(double x)
{
IEnumerable<Tuple<double, int>> seq = SeqModule.Map<int, Tuple<double, int>>(new Program.seq#49(), Operators.OperatorIntrinsics.RangeInt32(1, 1, Program.maxLimit));
FSharpTypeFunc get2nd3rd = new Program.get2nd3rd#50-1();
Tuple<double, int, int> tupledArg = SeqModule.Find<Tuple<double, int, int>>(new Program.findNM#52-1(x), SeqModule.Concat<IEnumerable<Tuple<double, int, int>>, Tuple<double, int, int>>(SeqModule.Map<Tuple<double, int>, IEnumerable<Tuple<double, int, int>>>(new Program.findNM#51-2(seq), seq)));
FSharpFunc<Tuple<double, int, int>, Tuple<int, int>> fSharpFunc = (FSharpFunc<Tuple<double, int, int>, Tuple<int, int>>)((FSharpTypeFunc)((FSharpTypeFunc)get2nd3rd.Specialize<double>()).Specialize<int>()).Specialize<int>();
return Program.get2nd3rd#50<double, int, int>(tupledArg);
}
We see that it requires creating multiple objects that implement IEnumerable<>, as well as function objects that are passed to higher-order functions like Seq.find. While it is in principle possible for the JITter to inline the loop, it most likely won't because of time and memory constraints. This means each call to a function object is a virtual call, and virtual calls are quite expensive (tip: check the machine code). Because a virtual call might do anything, it in turn prevents optimizations such as using SIMD instructions.
The OP noted that F# loop expressions lack break/continue constructs, which are useful when writing efficient for loops. F# does, however, support them implicitly: if you write a tail-recursive function, F# unwinds it into an efficient loop that uses break/continue to exit early.
algo3 is an example of implementing algo2 using tail-recursion. The disassembled code is something like this:
internal static Tuple<int, int> find#66(double best, int i, int j)
{
while (best != Math.Exp((double)i / (double)200) - 1.0 + (Math.Exp((double)j / (double)200) - 1.0))
{
if (i == Program.maxLimit && j == Program.maxLimit)
{
return new Tuple<int, int>(0, 0);
}
if (j == Program.maxLimit)
{
double arg_6F_0 = best;
int arg_6D_0 = i + 1;
j = 1;
i = arg_6D_0;
best = arg_6F_0;
}
else
{
double arg_7F_0 = best;
int arg_7D_0 = i;
j++;
i = arg_7D_0;
best = arg_7F_0;
}
}
return new Tuple<int, int>(i, j);
}
This enables us to write idiomatic functional code and yet get very good performance while avoiding stack overflows.
Before I realized how well tail-recursion is implemented in F#, I tried to write efficient while loops with mutable logic in the while test expression. For the sake of humanity, that code has been abolished from existence.
algo4 is an optimized version in that it iterates over sums only once for both bestMid and bestI, much like algo1, but algo4 exits early if it can.
Hope this helps

Recursion to Iteration - Scheme to C

Can someone help me convert this Scheme function:
#lang racket
(define (powmod2 x e m)
(define xsqrm (remainder (* x x) m))
(cond [(= e 0) 1]
[(even? e) (powmod2 xsqrm (/ e 2) m)]
[(odd? e) (remainder (* x (powmod2 xsqrm (/ (sub1 e) 2) m)) m)]))
into a function in C that doesn't use recursion, i.e. uses iteration?
I'm out of ideas. The part that is bothering me is when e is odd and the recursive call sits inside the remainder call. I don't know how to translate that into a while loop. Any tips or suggestions?
This is what I have so far:
int powmod2(int x, int e, int m) {
int i = x;
int xsqrm = ((x * x) % m);
while (e != 0){
if (e%2 == 0) {
x = xsqrm;
e = (e/2);
xsqrm = ((x * x) % m);
}
else {
i = x;
x = xsqrm;
e = (e - 1)/2;
xsqrm = ((x * x) % m);
}
}
e = 1;
return (i*e)%m;
}
The even version is straightforward because the code has been written tail recursively so the call to (powmod2 xsqrm (/ e 2) m) can be expressed iteratively by replacing e with half of e and x with its square modulo m:
int powmod2(int x, int e, int m) { /* version for when e is a power of 2 */
while ((e /= 2) != 0)
x = (x * x) % m;
return x;
}
However the odd version has not been written tail recursively. One approach is to create a helper method that uses an accumulator. This helper method can then be written tail recursively for both even and odd exponent. You can then transform that into an iteration.
You are having trouble doing the conversion because the original Scheme code is not tail recursive. Try adding extra parameters to powmod2 so that, in the odd case, you do not need to do the multiplication and remainder after the recursive call returns.
To illustrate, it's hard to loopify the following function:
int fac(int n) {
    if (n == 0) {
        return 1;
    } else {
        return n * fac(n - 1);
    }
}
But it is easy to loopify the version with an accumulation parameter
int fac(int n, int res) {
    if (n == 0) {
        return res;
    } else {
        return fac(n - 1, res * n);
    }
}
int real_fac(int n) { return fac(n, 1); }
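As a minimal sketch of why the accumulator version is easy to loopify: each tail call fac(n-1, res*n) only updates the pair (n, res), so it can be replaced by one pass through a while loop. The name fac_loop and the unsigned types are assumptions made for this illustration, and the assert guards against overflow on platforms where unsigned long is only 32 bits wide.

```c
#include <assert.h>

/* Hypothetical name. Loop form of the accumulator version:
   every tail call fac(n - 1, res * n) becomes one update of (n, res). */
unsigned long fac_loop(unsigned int n)
{
    unsigned long res = 1;   /* the accumulator, as in real_fac: fac(n, 1) */

    assert(n <= 12);         /* 12! is the largest factorial that fits in 32 bits */

    while (n != 0) {         /* n == 0 is the base case: return res */
        res *= n;            /* second argument of the tail call: res * n */
        n -= 1;              /* first argument of the tail call: n - 1 */
    }
    return res;
}
```

The plain recursive version cannot be rewritten this way directly, because the multiplication by n still has to happen after the recursive call returns.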
Perhaps if you were to run the algorithm with some values to see how the result is calculated, it can help you figure out how to convert it. Let's see a single run for x=5, e=5 and m=7:
1. x=5, e=5, m=7
xsqrm=4
e:odd => res = 5*...%7
2. x=4, e=2, m=7
xsqrm=2
e:even => res = ...%7
3. x=2, e=1, m=7
xsqrm=4
e:odd => res = 2*...%7
4. x=4, e=0, m=7
e==0 => res = 1
res = 5*2%7=3
At step 1, we get a partial calculation for the result: it is 5 times the result of next step mod 7. At step 2, since it is even the result is the same as the result of the next step. At step 3, we've got something similar to step 1. The result we'll feed upstairs is calculated by multiplying next result by 2 (mod 7 again). And at termination, we've got our result to feed upstairs: 1. Now, as we go up, we just know how to calculate res: 2*1%7 for step 3, 2 for step 2, and 2*5%7 for step 1.
One way to implement it is to use a stack. At every partial result, if the exponent is odd, we can push the multiplication factor to the stack, and once we terminate, we can just multiply them all. This is the naive/cheating method for conversion.
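A rough sketch of the stack approach described above, under the assumption that x, e and m are small enough for the products to fit in an int (as in the original code). The name powmod2_stack and the fixed-size array used as a stack are choices made for this illustration only.

```c
#include <assert.h>
#include <limits.h>

/* Hypothetical name. Stack-based conversion of the recursive powmod2. */
int powmod2_stack(int x, int e, int m)
{
    /* At most one pending factor per bit of e, so this size is sufficient. */
    int factors[sizeof(int) * CHAR_BIT];
    int top = 0;
    int res = 1;                    /* result of the final step (e == 0) */

    assert(e >= 0 && m > 0);

    /* Going down: mirrors the chain of recursive calls. */
    while (e != 0) {
        if (e % 2 != 0)
            factors[top++] = x;     /* odd step: remember the pending multiplication */
        x = (x * x) % m;            /* xsqrm becomes the next x */
        e /= 2;                     /* same as (e - 1) / 2 when e is odd */
    }

    /* Going back up: apply the pending multiplications in reverse order. */
    while (top > 0)
        res = (factors[--top] * res) % m;

    return res;
}
```

For x=5, e=5, m=7 the stack holds 5 and 2 when the first loop ends, and the second loop computes 2*1%7 = 2 and then 5*2%7 = 3, matching the trace above.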
There is a more efficient way that you should be able to see when you look at the steps above. The other answers' suggestion of converting everything to tail recursion is also a very good hint.
The easiest way is to reason about what the original function is trying to compute: the value of x to the power e modulo m. If you express e in binary, you get e = e0 * 1 + e1 * 2 + e2 * 4 + e3 * 8 + ..., where each en is either 0 or 1. Then x^e = x^(e0 * 1) * x^(e1 * 2) * x^(e2 * 4) * x^(e3 * 8) * ..., that is, the product of those repeated squares x, x^2, x^4, x^8, ... whose bit en is 1.
By using the mathematical properties of the modulo operator, i.e. (a + b) % m = ((a % m) + (b % m)) % m and (a * b) % m = ((a % m) * (b % m)) % m, we can then rewrite the function as:
int powmod2(int x, int e, int m) {
    // This corresponds to (= e 0)
    int r = 1;
    while (e != 0) {
        if (e % 2) {
            // This corresponds to (odd? e)
            r = (r * x) % m;
        }
        // This corresponds to the recursive call,
        // which is made whatever the parity of e.
        x = (x * x) % m;
        e /= 2;
    }
    return r;
}
The first step would be writing the original Scheme procedure as a tail recursion. Notice that this rewrite works because of the properties of modular arithmetic:
(define (powmod2 x e m)
(define (iter x e acc)
(let ((xsqrm (remainder (* x x) m)))
(cond ((zero? e) acc)
((even? e) (iter xsqrm (/ e 2) acc))
(else (iter xsqrm (/ (sub1 e) 2) (remainder (* x acc) m))))))
(iter x e 1))
The key element of the above procedure is that the answer is passed along in the acc parameter. Now that we have a tail recursion, the conversion to a fully iterative solution is pretty straightforward:
int powmod2(int x, int e, int m) {
int acc = 1;
int xsqrm = 0;
while (e != 0) {
xsqrm = (x * x) % m;
if (e % 2 == 0) {
x = xsqrm;
e = e / 2;
}
else {
acc = (x * acc) % m;
x = xsqrm;
e = (e - 1) / 2;
}
}
return acc;
}
It can be optimized further, like this:
int powmod2(int x, int e, int m) {
int acc = 1;
while (e) {
if (e & 1) {
e--;
acc = (x * acc) % m;
}
x = (x * x) % m;
e >>= 1;
}
return acc;
}
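One caveat applies to all the int versions above: on a typical platform with 32-bit int, the products x * x and x * acc can overflow once m exceeds roughly the square root of INT_MAX. A possible variant, sketched here with fixed-width unsigned types and 64-bit intermediates, avoids that for any 32-bit modulus. The name powmod_u32 and the choice of types are assumptions for this illustration, not part of the original question.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical variant of powmod2: same square-and-multiply loop, but the
   intermediates are 64 bits wide, so the product of two values below m
   (with m < 2^32) always fits. */
uint32_t powmod_u32(uint32_t x, uint32_t e, uint32_t m)
{
    uint64_t acc;
    uint64_t base;

    assert(m != 0);

    acc = 1 % m;                 /* 0 when m == 1, otherwise 1 */
    base = x % m;                /* reduce once so base < m from the start */

    while (e != 0) {
        if (e & 1u)
            acc = (acc * base) % m;
        base = (base * base) % m;
        e >>= 1;
    }
    return (uint32_t)acc;
}
```

Unlike the original, this sketch returns 0 rather than 1 when m is 1, because every value is congruent to 0 modulo 1.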

Resources