Index a `jax` array with a variable number of dimensions

I am trying to write a general utility that updates values at given indices of a jax array, where the array may have a different number of dimensions depending on the instance.
I know that I have to use the .at[].set() method, and this is what I have so far:
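import jax.numpy as jnp

b = jnp.arange(16).reshape([4, 4])
print(b)
update_indices = jnp.array([[1, 1], [3, 2], [0, 3]])  # one (row, col) pair per update
update_indices = jnp.moveaxis(update_indices, -1, 0)   # shape (2, 3): one row of indices per axis
b = b.at[update_indices[0], update_indices[1]].set(jnp.array([333, 444, 555]))
print(b)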
This transforms:
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]
 [12 13 14 15]]
into
[[  0   1   2 555]
 [  4 333   6   7]
 [  8   9  10  11]
 [ 12  13 444  15]]
My problem is that I have had to hard code the argument to at as update_indices[0], update_indices[1]. However, in general b could have an arbitrary number of dimensions so this will not work. (e.g. for a 3D array I would have to replace it with update_indices[0], update_indices[1], update_indices[2]).
It would be nice if I could write something like b.at[*update_indices] but this does not work.

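Convert the index array to a tuple, so that each row supplies the indices for one axis. This works for any number of dimensions:
b = b.at[tuple(update_indices)].set([333, 444, 555])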
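Why this works: in Python, `b[x, y]` is shorthand for `b[(x, y)]`, because the subscript is passed to the indexer as a tuple. So `tuple(update_indices)`, which yields one index array per axis, is equivalent to writing out `update_indices[0], update_indices[1], ...` for however many dimensions `b` has. Below is a minimal sketch of a rank-agnostic helper, assuming JAX is installed; the name `set_at_points` is hypothetical, and it takes the points as rows (the layout before the `moveaxis` call in the question).

```python
import jax.numpy as jnp


def set_at_points(arr, points, values):
    """Hypothetical helper: return a copy of arr with arr[p] = v for each point p.

    `points` has shape (num_points, arr.ndim), one coordinate row per update,
    so the same code handles 1D, 2D, 3D, ... arrays.
    """
    points = jnp.asarray(points)
    # Transpose to shape (arr.ndim, num_points): one index array per axis.
    # Turning that into a tuple makes JAX treat it as multi-axis indexing.
    index = tuple(points.T)
    # .at[...].set() is functional: it returns a new array, arr is unchanged.
    return arr.at[index].set(values)


# 2D case from the question.
b2 = set_at_points(jnp.arange(16).reshape(4, 4),
                   [[1, 1], [3, 2], [0, 3]],
                   jnp.array([333, 444, 555]))

# The same helper on a 3D array, with no code changes.
b3 = set_at_points(jnp.zeros((2, 3, 4), dtype=jnp.int32),
                   [[0, 1, 2], [1, 2, 3]],
                   jnp.array([7, 9]))
```

Because `.at[...].set()` never mutates in place, the helper returns the updated array and the caller rebinds it, just as the question does with `b = b.at[...].set(...)`.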

Related

Is there a MATLAB function for splitting an array into several arrays?

I want to split an array into several arrays automatically. For example:
a=[1 2 3 4 5 6 7 8 9]
b=[2 5]
That is, I want to split it into:
c1=[1 2]
c2=[3 4 5]
c3=[6 7 8 9]
How to do it?
A simple way is to use mat2cell:
a = [1 2 3 4 5 6 7 8 9];
b = [2 5];
c = mat2cell(a, 1, diff([0 b numel(a)]));
This gives a cell array c containing the subarrays of a:
>> celldisp(c)
c{1} =
1 2
c{2} =
3 4 5
c{3} =
6 7 8 9
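For comparison, a NumPy sketch of the same split (an analogue, not part of the MATLAB answer). `numpy.split` takes the 0-based start index of each new piece; for this example those happen to be the same numbers as MATLAB's "split after element 2 and 5". The piece lengths are computed the same way as the `diff([0 b numel(a)])` argument to `mat2cell`.

```python
import numpy as np

a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
b = [2, 5]  # split after the 2nd and 5th elements, as in the MATLAB example

# Piece lengths, mirroring diff([0 b numel(a)]) in the mat2cell call.
lengths = np.diff([0, *b, a.size])

# np.split cuts before each listed 0-based index: a[:2], a[2:5], a[5:].
pieces = np.split(a, b)
```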

How to print the maximum value of a specific row/column in a numpy array?

#input
import numpy as np
arr = np.array([(1,2,3,4),(5,6,7,8),(9,10,11,12)])
print(arr)
print(np.max(arr))
print(np.max(arr,0))
print(np.max(arr,1))
#output
[[ 1 2 3 4]
[ 5 6 7 8]
[ 9 10 11 12]]
12
[ 9 10 11 12]
[ 4 8 12]
I am getting the max value of the whole matrix, and the max values of all rows and of all columns, but how can I get the maximum value of one particular row? Suppose I want to print the max value of only row 1 (i.e. 8), not of each and every row.
How can I do that?
np.max(arr, 1) gives you the maximum of each row:
[ 4  8 12]
so you can simply index into that result to get the max value of row i:
np.max(arr, 1)[i]
or, without computing every row's maximum, take the max of row i directly:
np.max(arr[i])
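A short sketch covering both halves of the title question: a single row is `arr[i]`, and a single column is `arr[:, j]`, so reducing just that slice gives its maximum. The indices `i` and `j` are chosen for illustration.

```python
import numpy as np

arr = np.array([(1, 2, 3, 4), (5, 6, 7, 8), (9, 10, 11, 12)])

i, j = 1, 2
row_max = arr[i].max()                    # max of row i only
col_max = arr[:, j].max()                 # max of column j only
row_max_via_all = np.max(arr, axis=1)[i]  # computes every row max, then picks one

print(row_max, col_max)  # 8 11
```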

Given an index of choices for each column, construct a 1D array from a 2D array

I have a 2D array such as:
julia> m = [1 2 3 4 5
6 7 8 9 10
11 12 13 14 15]
3×5 Array{Int64,2}:
1 2 3 4 5
6 7 8 9 10
11 12 13 14 15
I want to pick one value from each column and construct a 1D array.
So for instance, if my choices are
julia> choices = [1, 2, 3, 2, 1]
5-element Array{Int64,1}:
1
2
3
2
1
Then the desired output is [1, 7, 13, 9, 5]. What's the best way to do that? In my particular application, I am randomly generating these values, e.g.
choices = rand(1:size(m)[1], size(m)[2])
Thank you!
This is probably the simplest approach:
[m[c, i] for (i, c) in enumerate(choices)]
EDIT:
If by "best" you mean fastest, a function like the following should be approximately 2x faster than the comprehension for large m:
function selector(m, choices)
    v = similar(m, size(m, 2))
    for i in eachindex(choices)
        @inbounds v[i] = m[choices[i], i]
    end
    v
end
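For comparison only, the same "one element per column" gather in NumPy (an analogue, not a Julia answer): pairing an array of row indices with an array of column indices picks element `(row[k], col[k])` for each `k` in one vectorized step. NumPy is 0-based, so the 1-based Julia choices are shifted by one.

```python
import numpy as np

m = np.array([[1, 2, 3, 4, 5],
              [6, 7, 8, 9, 10],
              [11, 12, 13, 14, 15]])
choices = np.array([1, 2, 3, 2, 1])  # 1-based row choices, as in the Julia example

# Pair row (choices[k] - 1) with column k for every k and gather in one step.
picked = m[choices - 1, np.arange(m.shape[1])]
print(picked)  # [ 1  7 13  9  5]

# Random 0-based choices, analogous to rand(1:size(m)[1], size(m)[2]).
rng = np.random.default_rng(0)
random_rows = rng.integers(0, m.shape[0], size=m.shape[1])
random_picked = m[random_rows, np.arange(m.shape[1])]
```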

MATLAB: Delete elements in cell array with certain length

How can I delete all elements of a cell array that have fewer than, for example, 5 elements inside?
result{1}= 1
result{2}= [2 3 4 5 6 7 8]
result{3}= [9 10 11 12 13 14 16 17 18]
result{4}= [19 20 21]
In this example I want to delete result{1} and result{4}, because they have fewer than 5 elements inside.
From this question (matlab length of each element in cell array) I know how to get the length of each element, but how can I delete the elements of a specific length?
Just choose the ones that have more than 4 elements by logical indexing:
result = result(cellfun('length', result) >= 5);
This loop also does what you need, although the logical-indexing answer above from Mohsen is more compact:
result{1} = 1;
result{2} = [2 3 4 5 6 7 8];
result{3} = [9 10 11 12 13 14 16 17 18];
result{4} = [19 20 21];
i = 1;
while i <= numel(result)
    if numel(result{i}) < 5
        result(i) = [];   % later elements shift left, so do not advance i
    else
        i = i + 1;
    end
end
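The same two approaches sketched in Python for comparison (an analogue, not MATLAB code): a filtering comprehension plays the role of logical indexing with `cellfun`, and the index-based loop only advances when nothing was removed, because deleting an element shifts the following ones one position left. Incrementing unconditionally would skip the element right after each deletion, which matters when two short elements are adjacent. The helper name `drop_short_in_place` is hypothetical.

```python
result = [[1], [2, 3, 4, 5, 6, 7, 8], [9, 10, 11, 12, 13, 14, 16, 17, 18], [19, 20, 21]]

# Filtering: keep only the elements with at least 5 entries.
kept = [x for x in result if len(x) >= 5]


def drop_short_in_place(items, min_len):
    """Hypothetical helper: delete entries shorter than min_len from items in place."""
    i = 0
    while i < len(items):
        if len(items[i]) < min_len:
            del items[i]  # the next element now sits at index i, so re-check it
        else:
            i += 1
    return items


# Two short entries in a row: an always-increment loop would keep [2].
short_runs = [[1], [2], [3, 4, 5, 6, 7]]
drop_short_in_place(short_runs, 5)
```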

Array filtering based on cell content

I have a cell of length n where each element is a numeric array of varying length.
eg
C = { [ 1 2 3] ; [ 4 1 ] ; [ 28 5 15] }
And an nx4 numeric array
eg
A = [[ 1 2 3 4] ; [ 5 6 7 8 ] ; [ 9 10 11 12]]
I'd like to filter the numeric array A based on the content in cell C.
The filter may be to return all rows in A which have a 28 in the corresponding element in C.
ans = [ 9 10 11 12 ]
Or, the filter may be to return all rows in A whose corresponding vector in C has a 1 as its first element or a 5 as its second element.
ans = [[ 1 2 3 4] ; [ 9 10 11 12]]
Hope this makes sense! It's relating the vectors in the cell to the rows of the main array that I'm struggling with.
Cellfun makes this relatively straightforward. Write a function that maps each vector in C to a single logical scalar according to your filter conditions, and pass it as the first input to cellfun, with your cell array as the second input. The output is your nx1 logical "filter" vector. Then use it to index the dimension of A that has length n, with a colon operator in the other dimension.
First one:
A(cellfun(@(x) ismember(28, x), C), :)
ans =
9 10 11 12
Second one:
A(cellfun(@(x) (x(1)==1) || (x(2)==5), C), :)
ans =
1 2 3 4
9 10 11 12
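For comparison, the same pattern in NumPy (an analogue, not MATLAB code): build one boolean per vector in C, which is what cellfun returns here, then use that boolean vector to select rows of A. Python indexing is 0-based, so MATLAB's `x(2)` becomes `x[1]`.

```python
import numpy as np

C = [[1, 2, 3], [4, 1], [28, 5, 15]]
A = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])

# One boolean per vector in C, like the logical vector cellfun produces.
has_28 = np.array([28 in x for x in C])
first_is_1_or_second_is_5 = np.array([x[0] == 1 or x[1] == 5 for x in C])

# Boolean mask on the rows (length-n axis), all columns kept.
rows_a = A[has_28, :]
rows_b = A[first_is_1_or_second_is_5, :]
```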
