Function accepting ndarray and float at the same time - arrays

I would like to write a function computing the refractive index of some material as a function of wavelength. I use the ndarray crate for arrays, similar to numpy in Python.
At the moment the function is implemented the following way:
fn refractive_index(&self, l: &Array1<f64>) -> Array1<f64> {
    let V =
        self.A[0] + self.A[1] / (1. + self.A[2] * (self.A[3] * l / self.pitch).mapv(f64::exp));
    let W =
        self.B[0] + self.B[1] / (1. + self.B[2] * (self.B[3] * l / self.pitch).mapv(f64::exp));
    (self.material.refractive_index(&l).powi(2)
        - 3. * l.powi(2) / 4. / (consts::PI * self.pitch).powi(2) * (V.powi(2) - W.powi(2)))
    .mapv(f64::sqrt)
}
With the following trait implemented, to make exponentiation shorter:
trait Squared<T> {
fn powi(&self, e: i32) -> T;
}
impl Squared<Array1<f64>> for Array1<f64> {
fn powi(&self, e: i32) -> Array1<f64> {
self.mapv(|a| a.powi(e))
}
}
However, I may also want to compute the refractive index at only one specific wavelength, so I would like the function to accept plain f64 values as well. What is the best way to implement this without writing two separate functions?
Edit:
The function uses ndarray's syntax for squaring, l.mapv(|x| x.powi(2)), which unfortunately differs from the syntax for a plain f64.
Edit 2:
As asked I included the function body.

It is definitely possible, although it might take more work than simply adding another function. You will need to first define two helper traits that specify the mapv and powi (much like Squared) operations for both f64 and Array1<f64>:
trait PowI {
fn powi(&self, e: i32) -> Self;
}
impl PowI for f64 {
fn powi(&self, e: i32) -> Self {
f64::powi(*self, e)
}
}
impl PowI for Array1<f64> {
fn powi(&self, e: i32) -> Self {
self.mapv(|a| a.powi(e))
}
}
trait MapV {
fn mapv(&self, f: impl FnMut(f64) -> f64) -> Self;
}
impl MapV for f64 {
fn mapv(&self, mut f: impl FnMut(f64) -> f64) -> Self {
f(*self)
}
}
impl MapV for Array1<f64> {
fn mapv(&self, f: impl FnMut(f64) -> f64) -> Self {
Array1::mapv(self, f)
}
}
You can then make your refractive index calculation function be generic over a type T that implements both these traits (i.e. T: PowI + MapV).
Note that you will also need additional bounds on T that specify that it can be added, divided and multiplied with Ts and f64s (which I assume is the type of self.pitch and the elements in the self.A and self.B arrays). You can do this by requiring that T, &'a T or even f64 implement the appropriate Add, Mul and Div traits. For example, your implementation may require that f64: Mul<&'a T, Output=T> and T: Div<f64, Output=T> in addition to many other bounds. The compiler errors will help you figure out exactly which ones you need.
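To make the idea concrete, here is a minimal sketch of a function that is generic over the two helper traits. It is not the asker's formula: toy_index and its expression are hypothetical, and Vec<f64> stands in for Array1<f64> so that the example needs only the standard library. The arithmetic bounds (Add, Mul, Div) are deliberately avoided by doing all arithmetic inside the closure passed to mapv.

```rust
/// Element-wise integer power for a scalar or a container.
trait PowI {
    fn powi(&self, e: i32) -> Self;
}

/// Element-wise map for a scalar or a container.
trait MapV {
    fn mapv(&self, f: impl FnMut(f64) -> f64) -> Self;
}

impl PowI for f64 {
    fn powi(&self, e: i32) -> Self {
        f64::powi(*self, e)
    }
}

impl MapV for f64 {
    fn mapv(&self, mut f: impl FnMut(f64) -> f64) -> Self {
        f(*self)
    }
}

// Assumption: Vec<f64> is only a stand-in for ndarray's Array1<f64>.
impl PowI for Vec<f64> {
    fn powi(&self, e: i32) -> Self {
        self.iter().map(|&a| f64::powi(a, e)).collect()
    }
}

impl MapV for Vec<f64> {
    fn mapv(&self, f: impl FnMut(f64) -> f64) -> Self {
        self.iter().copied().map(f).collect()
    }
}

/// Hypothetical formula sqrt(1 + (l / pitch)^2), written once for both types.
fn toy_index<T: PowI + MapV>(l: &T, pitch: f64) -> T {
    l.powi(2).mapv(|x| (1.0 + x / (pitch * pitch)).sqrt())
}
```

Calling toy_index(&3.0, 1.0) and toy_index(&vec![0.0, 3.0], 1.0) then goes through the same function body. With the real ndarray types, expressions that mix T and f64 outside of mapv would still need the operator bounds described above.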

Related

Flatten boxed slice of arrays

Can I somehow flatten a boxed slice Box<[[T;SIZE]]> into Box<[T]> (unsafe code is ok if necessary) in Rust? I own the box and I want to do it in-place without making unnecessary copies.
It is more complex than it looks. A pointer to a slice is a fat pointer, and its metadata (PtrComponents::metadata) holds the length of the slice, so the length has to be adjusted as well, as shown in https://doc.rust-lang.org/src/core/slice/mod.rs.html#137
If you don't take this into account
pub fn flatten_box<T,const DIM:usize>(v:Box<[[T;DIM]]>)->Box<[T]>{
// SAFETY: raw pointer is created from Box and
// *mut [[T; N]] has the same alignment as *mut [T]
unsafe { Box::from_raw(Box::into_raw(v) as *mut [T]) }
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test6() {
let v:Box<[[u32;13]]> = unsafe{Box::<[[u32;13]]>::new_uninit_slice(7).assume_init()};
let l = v.len();
let v = flatten_box(v);
assert_eq!(v.len(),l*13);
}
}
you will get errors
Left: 7
Right: 91
thread 'transmute::tests::test6' panicked at 'assertion failed: `(left == right)`
left: `7`,
right: `91`', vf\src\transmute.rs:76:9
You could try this:
fn transmute<T, const N: usize>(b: Box<[[T; N]]>) -> Box<[T]> {
// SAFETY: raw pointer is created from Box and
// *mut [[T; N]] has the same alignment as *mut [T]
unsafe { Box::from_raw(Box::into_raw(b) as *mut [T]) }
}
EDIT: This doesn't preserve the correct length of the slice. See the comment below by @eggyal for a solution.
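The comment referred to above is not reproduced here, so the following is only a sketch of one way to carry the corrected length across the cast. The name flatten_boxed is hypothetical. It rebuilds the fat pointer with std::ptr::slice_from_raw_parts_mut, passing the element count multiplied by N, instead of casting one slice pointer to the other (which keeps the old length).

```rust
use std::ptr;

/// Hypothetical helper: flattens a boxed slice of arrays in place, without copying.
fn flatten_boxed<T, const N: usize>(v: Box<[[T; N]]>) -> Box<[T]> {
    // The new length counts T elements, not [T; N] elements.
    let len = v
        .len()
        .checked_mul(N)
        .expect("flattened length overflows usize");
    // Casting the fat pointer to a thin *mut T drops the old length metadata.
    let data = Box::into_raw(v) as *mut T;
    // SAFETY: the allocation came from a Box, [[T; N]] and [T] have the same
    // alignment, and `len` elements of T cover exactly the same bytes, so the
    // layout used to free the new Box matches the original allocation.
    unsafe { Box::from_raw(ptr::slice_from_raw_parts_mut(data, len)) }
}
```

For example, flattening a Box<[[u32; 3]]> with two rows should yield a Box<[u32]> of length 6 that shares the original allocation.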

Rust How do you define a default method for similar types?

A very common pattern I have to deal with is, I am given some raw byte data. This data can represent an array of floats, 2D vectors, Matrices...
I know the data is compact and properly aligned. In C usually you would just do:
vec3 * ptr = (vec3*)data;
And start reading from it.
I am trying to create a view to this kind of data in rust to be able to read and write to the buffer as follows:
pub trait AccessView<T>
{
fn access_view<'a>(
offset : usize,
length : usize,
buffer : &'a Vec<u8>) -> &'a mut [T]
{
let bytes = &buffer[offset..(offset + length)];
let ptr = bytes.as_ptr() as *mut T;
return unsafe { std::slice::from_raw_parts_mut(ptr, length / size_of::<T>()) };
}
}
And then calling it:
let data: &[f32] =
AccessView::<f32>::access_view(0, 32, &buffers[0]);
The idea is, I should be able to replace f32 with vec3 or mat4 and get a slice view into the underlying data.
This fails to compile with:
--> src/main.rs:341:9
|
341 | AccessView::<f32>::access_view(&accessors[0], &buffer_views, &buffers);
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ cannot infer type
|
= note: cannot satisfy `_: AccessView<f32>`
How could I use rust to achieve my goal? i.e. have a generic "template" for turning a set of raw bytes into a range checked slice view casted to some type.
There are two important problems I can identify:
1. You are using a trait incorrectly. You have to connect a trait to an actual type. If you want to call it the way you do, it needs to be a struct instead.
2. Soundness. You are creating a mutable reference from an immutable one through unsafe code. This is unsound and dangerous. By using unsafe, you tell the compiler that you manually verified that your code is sound, and the borrow checker should blindly believe you. Your code, however, is not sound.
To part 1, @BlackBeans gave you a good answer already. I would still do it a little differently, though. I would directly implement the trait for [u8], so you can write data.access_view::<T>().
To part 2, you at least need to make the input data &mut. Further, make sure they have the same lifetime, otherwise the compiler might not realize that they are actually connected.
Also, don't use &Vec<u8> as an argument; in general, use slices (&[u8]) instead.
Be aware that, with all that said, there is still the problem of endianness. The behavior you get will not be consistent between platforms. Use other means of conversion if you need that consistency. Do not put this code in a generic library; at most, use it for your own personal project.
That all said, here is what I came up with:
pub trait AccessView {
fn access_view<'a, T>(&'a mut self, offset: usize, length: usize) -> &'a mut [T];
}
impl AccessView for [u8] {
fn access_view<T>(&mut self, offset: usize, length: usize) -> &mut [T] {
let bytes = &mut self[offset..(offset + length)];
let ptr = bytes.as_mut_ptr() as *mut T;
return unsafe { std::slice::from_raw_parts_mut(ptr, length / ::std::mem::size_of::<T>()) };
}
}
impl AccessView for Vec<u8> {
fn access_view<T>(&mut self, offset: usize, length: usize) -> &mut [T] {
self.as_mut_slice().access_view(offset, length)
}
}
fn main() {
let mut data: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8];
println!("{:?}", data);
let float_view: &mut [f32] = data.access_view(2, 4);
float_view[0] = 42.0;
println!("{:?}", float_view);
println!("{:?}", data);
// println!("{:?}", float_view); // Adding this would cause a compiler error, which shows that we implemented lifetimes correctly
}
[1, 2, 3, 4, 5, 6, 7, 8]
[42.0]
[1, 2, 0, 0, 40, 66, 7, 8]
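One thing the example above does not check is alignment: offset 2 into a Vec<u8> is not guaranteed to be suitably aligned for f32, and the requested length may not be a multiple of the element size. As a hedged sketch (checked_view_mut is a hypothetical name, not part of the answer), those conditions can be verified at runtime and reported as None instead of being assumed.

```rust
use std::mem::{align_of, size_of};
use std::slice;

/// Hypothetical checked variant of `access_view`.
///
/// # Safety
/// The caller must guarantee that every bit pattern is a valid `T`
/// (true for `f32` or `u32`, not for `bool` or references).
pub unsafe fn checked_view_mut<T>(
    buf: &mut [u8],
    offset: usize,
    length: usize,
) -> Option<&mut [T]> {
    // Range check without overflow or panics.
    let end = offset.checked_add(length)?;
    let bytes = buf.get_mut(offset..end)?;
    // The byte length must be a whole number of elements.
    let size = size_of::<T>();
    if size == 0 || length % size != 0 {
        return None;
    }
    // The start of the view must be aligned for T.
    let ptr = bytes.as_mut_ptr();
    if (ptr as usize) % align_of::<T>() != 0 {
        return None;
    }
    // SAFETY: the range is in bounds, aligned, and exclusively borrowed;
    // validity of the bit patterns is the caller's obligation.
    Some(unsafe { slice::from_raw_parts_mut(ptr as *mut T, length / size) })
}
```

The standard library's slice::align_to_mut is another option when it is acceptable for the view to skip unaligned bytes at either end.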
I think you may have misunderstood what traits are. A trait represents a characteristic of a type. For instance, since the size of u32 (32 bits) is known at compile time, u32 implements the marker trait Sized, written u32: Sized. A more feature-complete trait is Default: if there is a "default" way of building a value of type T, then we can implement Default for T, so that there is a standard way of building it.
In your example, you are using a trait as a namespace for functions, ie you could simply have
fn access_view<'a, T>(
    offset: usize,
    length: usize,
    buffer: &'a mut [u8],
) -> &'a mut [T]
{
    let bytes = &mut buffer[offset..offset + length];
    let ptr = bytes.as_mut_ptr() as *mut T;
    unsafe {
        std::slice::from_raw_parts_mut(ptr, length / std::mem::size_of::<T>())
    }
}
Or, if you want to put it as a trait:
trait Viewable: Sized {
    fn access_view<'a>(
        offset: usize,
        length: usize,
        buffer: &'a mut [u8],
    ) -> &'a mut [Self]
    {
        let bytes = &mut buffer[offset..offset + length];
        let ptr = bytes.as_mut_ptr() as *mut Self;
        unsafe {
            std::slice::from_raw_parts_mut(ptr, length / std::mem::size_of::<Self>())
        }
    }
}
Then implement it:
impl<T> Viewable for T {}
Or, again, differently
trait Viewable: Sized {
    fn access_view<'a>(
        offset: usize,
        length: usize,
        buffer: &'a mut [u8],
    ) -> &'a mut [Self];
}
impl<T> Viewable for T {
    fn access_view<'a>(
        offset: usize,
        length: usize,
        buffer: &'a mut [u8],
    ) -> &'a mut [Self]
    {
        let bytes = &mut buffer[offset..offset + length];
        let ptr = bytes.as_mut_ptr() as *mut T;
        unsafe {
            std::slice::from_raw_parts_mut(ptr, length / std::mem::size_of::<T>())
        }
    }
}
Although all of these ways of structuring the code produce roughly the same result, that does not mean they are equivalent. It may be worth learning a little more about traits before using them.
Also, your code, as is, really seems unsound, in the sense that you make a call to an unsafe function without any checking (ie. what if I call it with random nonsense in buffer?). It doesn't mean it is (we don't have access to the rest of your code), but you should be careful about that: Rust is not C.
Finally, your error simply comes from the fact that it's impossible for Rust to find out which type T you are calling the associated method access_view of.

How can I get the largest element from array, using generics and traits?

I just started learning Rust and I'm creatively trying some things as I read "the Rust Book".
I know it is possible to create a generic method to get the largest element from an array, like the following:
fn largest<T: PartialOrd + Copy> (nums: &[T]) -> T {
let mut largest = nums[0];
for &el in nums {
if el > largest {
largest = el;
}
}
return largest;
}
And calling the main function like this:
fn main() {
let list: Vec<u32> = vec![1,7,4];
println!("{}", largest(&list)); // 7
}
How would I go about doing the same thing but "extending" the array, like this:
fn main() {
let list: Vec<u32> = vec![1,7,4];
println!("{}", list.largest()); // 7
}
I guess the final question is: if it is possible, would it be a bad practice? Why?
I tried something like this, but didn't manage to figure out how to implement the "Largeble" trait returning the T:
pub trait Largeble {
fn largest(&self);
}
impl<T: Copy + PartialOrd + Display> Largeble for Vec<T> {
fn largest(&self) {
let mut largest = match self.get(0) {
Some(&el) => el,
None => panic!("Non Empty array expected")
};
for &el in self {
if el > largest {
largest = el;
}
}
println!("{}", largest);
// return largest;
}
}
You need to make the Largeble trait return a T from the Vec<T>, which you can do with an associated type:
use std::fmt;
pub trait Largeble {
type Output;
fn largest(&self) -> Self::Output;
}
impl<T: Copy + PartialOrd + fmt::Display> Largeble for Vec<T> {
type Output = T;
fn largest(&self) -> T {
let mut largest = match self.get(0) {
Some(&el) => el,
None => panic!("Non Empty array expected")
};
for &el in self {
if el > largest {
largest = el;
}
}
largest
}
}
println!("{}", vec![1, 2, 3, 2].largest()); // prints "3"
Traits like Largeble are usually called extension traits, since they extend existing items with new features. Using extension traits to extend items in existing libraries is common in the Rust ecosystem. It's common to suffix the names of extension traits with Ext (so a collection of additional methods for Vec would be called VecExt). A popular use of extension traits is the itertools library, which provides a trait that adds additional useful methods to Iterator in the standard library.
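As an illustration of the Ext naming convention, here is a small sketch of an extension trait implemented for slices rather than for Vec. The name LargestExt is hypothetical. Because Vec<T> dereferences to [T] and arrays coerce to slices in method calls, one implementation should cover all three, and returning an Option avoids the panic on empty input.

```rust
/// Hypothetical extension trait for anything that can be viewed as a slice.
pub trait LargestExt {
    type Item;
    /// Returns a reference to the largest element, or None if empty.
    fn largest(&self) -> Option<&Self::Item>;
}

impl<T: PartialOrd> LargestExt for [T] {
    type Item = T;

    fn largest(&self) -> Option<&T> {
        let mut iter = self.iter();
        // Empty slices yield None here.
        let first = iter.next()?;
        // Keep whichever of the two references compares greater.
        Some(iter.fold(first, |best, el| if el > best { el } else { best }))
    }
}
```

With the trait in scope, vec![1, 7, 4].largest() and [1.5, 0.5].largest() both resolve to this one implementation. Only PartialOrd is required, so floats work too, at the cost of unspecified results when NaN is present.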
How would I go doing the same thing but "extending" the array
Sure, your code snippet was close, you can create a trait and implement it on types.
pub trait Largeble<T>
where
T: Ord,
{
fn largest(&self) -> Option<&T>;
}
impl<T> Largeble<T> for Vec<T>
where
T: Ord,
{
fn largest(&self) -> Option<&T> {
// Iterator already has a method for getting max which simplifies things
self.iter().max()
}
}
Alternatively, you can make T an associated type which may be better suited to this example.
You can run this code in the Rust Playground.
Would it be a bad practice? Why?
Nope, it's definitely not bad practice. It is a very common way of developing in Rust and can be very powerful. Your trait needs to be in scope for you to be able to call .largest(), so it does not pollute anything.
Additionally, if you have multiple methods with the same name from different traits, you can use a longer syntax to specify the exact trait you want to use: Largeble::<u32>::largest(&list).
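A short sketch of that disambiguation, assuming a second trait whose method happens to have the same name. SizeInfo and demo are hypothetical and exist only to create and resolve the clash.

```rust
pub trait Largeble<T: Ord> {
    fn largest(&self) -> Option<&T>;
}

impl<T: Ord> Largeble<T> for Vec<T> {
    fn largest(&self) -> Option<&T> {
        self.iter().max()
    }
}

/// Hypothetical second trait with a clashing method name.
pub trait SizeInfo {
    fn largest(&self) -> usize;
}

impl<T> SizeInfo for Vec<T> {
    fn largest(&self) -> usize {
        self.len()
    }
}

fn demo() -> (Option<u32>, usize) {
    let list: Vec<u32> = vec![1, 7, 4];
    // `list.largest()` would be rejected as ambiguous with both traits in scope.
    let max = Largeble::<u32>::largest(&list).copied();
    // The fully qualified form names both the type and the trait.
    let len = <Vec<u32> as SizeInfo>::largest(&list);
    (max, len)
}
```

Both call forms name the trait explicitly, so the compiler no longer has to guess which largest is meant.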
I tried something like this, but didn't manage to figure out how to implement the "Largeble" trait returning the T.
Your code was mostly correct, but your largest method didn't return anything. That's why the trait needs a generic T, to specify that you will return T.

How can I return a Chain iterator with data added from a temporary array?

I'm writing an MQTT5 library. To send a packet, I need to know the size of the payload before writing the payload. My solution for determining the size has the following constraints, ordered by importance:
be easy to maintain
should not create copies of the data
should be fairly performant (avoid double calculations)
To determine the size I can do any of the following solutions:
1. Do the calculations by hand, which is fairly annoying.
2. Hold a copy of the data to send in memory, which I want to avoid.
3. Build a std::iter::ExactSizeIterator for the payload, which itself consists of std::iter::Chains; this leads to ugly types fast unless you create wrapper types.
I decided to go with option 3.
The example below shows my attempt at writing an MQTT String iterator. An MQTT String consists of two bytes holding the length of the string, followed by the data as UTF-8.
use std::iter::*;
use std::slice::Iter;
pub struct MQTTString<'a> {
chain: Chain<Iter<'a, u8>, Iter<'a, u8>>,
}
impl<'a> MQTTString<'a> {
pub fn new(s: &'a str) -> Self {
let u16_len = s.len() as u16;
let len_bytes = u16_len.to_be_bytes();
let len_iter = len_bytes.iter(); // len_bytes is borrowed here
let s_bytes = s.as_bytes();
let s_iter = s_bytes.iter();
let chain = len_iter.chain(s_iter);
MQTTString { chain }
}
}
impl<'a> Iterator for MQTTString<'a> {
type Item = &'a u8;
fn next(&mut self) -> Option<&'a u8> {
self.chain.next()
}
}
impl<'a> ExactSizeIterator for MQTTString<'a> {}
pub struct MQTTStringPait<'a> {
chain: Chain<std::slice::Iter<'a, u8>, std::slice::Iter<'a, u8>>,
}
This implementation doesn't compile because I borrow len_bytes instead of moving it, so it'd get dropped before the Chain can consume it:
error[E0515]: cannot return value referencing local variable `len_bytes`
--> src/lib.rs:19:9
|
12 | let len_iter = len_bytes.iter(); // len_bytes is borrowed here
| --------- `len_bytes` is borrowed here
...
19 | MQTTString { chain }
| ^^^^^^^^^^^^^^^^^^^^ returns a value referencing data owned by the current function
Is there a nice way to do this? Adding len_bytes to the MQTTString struct doesn't help. Is there a better fourth option of solving the problem?
The root problem is that iter borrows the array. When this answer was written, the by-value array::IntoIter was only available in nightly Rust (it has since been stabilized). Using it requires that you change your iterator to return u8 instead of &u8:
#![feature(array_value_iter)]
use std::array::IntoIter;
use std::iter::*;
use std::slice::Iter;
pub struct MQTTString<'a> {
chain: Chain<IntoIter<u8, 2_usize>, Copied<Iter<'a, u8>>>,
}
impl<'a> MQTTString<'a> {
pub fn new(s: &'a str) -> Self {
let u16_len = s.len() as u16;
let len_bytes = u16_len.to_be_bytes();
let len_iter = std::array::IntoIter::new(len_bytes);
let s_bytes = s.as_bytes();
let s_iter = s_bytes.iter().copied();
let chain = len_iter.chain(s_iter);
MQTTString { chain }
}
}
impl<'a> Iterator for MQTTString<'a> {
type Item = u8;
fn next(&mut self) -> Option<u8> {
self.chain.next()
}
}
impl<'a> ExactSizeIterator for MQTTString<'a> {}
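On current stable toolchains the same idea should work without the feature gate, because arrays implement IntoIterator by value. The sketch below assumes that, renames the type to MqttString to keep it apart from the code above, and also forwards size_hint. Without that forwarding, the default size_hint of (0, None) makes ExactSizeIterator::len panic, since len asserts that the lower and upper bounds agree.

```rust
use std::array;
use std::iter::{Chain, Copied};
use std::slice;

pub struct MqttString<'a> {
    chain: Chain<array::IntoIter<u8, 2>, Copied<slice::Iter<'a, u8>>>,
}

impl<'a> MqttString<'a> {
    pub fn new(s: &'a str) -> Self {
        // As in the original, `as u16` silently truncates lengths above 65535.
        let len_bytes = (s.len() as u16).to_be_bytes();
        // The path form yields the by-value array iterator on every edition.
        let len_iter = IntoIterator::into_iter(len_bytes);
        MqttString {
            chain: len_iter.chain(s.as_bytes().iter().copied()),
        }
    }
}

impl<'a> Iterator for MqttString<'a> {
    type Item = u8;

    fn next(&mut self) -> Option<u8> {
        self.chain.next()
    }

    // Both halves of the chain know their exact length, so the combined
    // hint is exact and `len()` can rely on it.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.chain.size_hint()
    }
}

impl<'a> ExactSizeIterator for MqttString<'a> {}
```

With this in place, MqttString::new("abc").len() reports the two length bytes plus the three payload bytes.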
You could do the same thing in stable Rust by using a Vec, but that'd be a bit of overkill. Instead, since you know the exact size of the array, you could get the values and chain more:
use std::iter::{self, *};
use std::slice;
pub struct MQTTString<'a> {
chain: Chain<Chain<Once<u8>, Once<u8>>, Copied<slice::Iter<'a, u8>>>,
}
impl<'a> MQTTString<'a> {
pub fn new(s: &'a str) -> Self {
let u16_len = s.len() as u16;
let [a, b] = u16_len.to_be_bytes();
let s_bytes = s.as_bytes();
let s_iter = s_bytes.iter().copied();
let chain = iter::once(a).chain(iter::once(b)).chain(s_iter);
MQTTString { chain }
}
}
impl<'a> Iterator for MQTTString<'a> {
type Item = u8;
fn next(&mut self) -> Option<u8> {
self.chain.next()
}
}
impl<'a> ExactSizeIterator for MQTTString<'a> {}
See also:
How to implement Iterator and IntoIterator for a simple struct?
An iterator of &u8 is not a good idea from the point of view of pure efficiency. On a 64-bit system, &u8 takes up 64 bits, as opposed to the 8 bits that the u8 itself would take. Additionally, dealing with this data on a byte-by-byte basis will likely impede common optimizations around copying memory around.
Instead, I'd recommend creating something that can write itself to something implementing Write. One possible implementation:
use std::{
convert::TryFrom,
io::{self, Write},
};
pub struct MQTTString<'a>(&'a str);
impl MQTTString<'_> {
pub fn write_to(&self, mut w: impl Write) -> io::Result<()> {
let len = u16::try_from(self.0.len()).expect("length exceeded 16-bit");
let len = len.to_be_bytes();
w.write_all(&len)?;
w.write_all(self.0.as_bytes())?;
Ok(())
}
}
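A possible usage sketch for the Write-based approach, writing into an in-memory Vec<u8>. The type is renamed MqttString here, and the encoded_len helper is an assumption added for illustration: it addresses the original requirement of knowing the size up front without copying or iterating.

```rust
use std::convert::TryFrom;
use std::io::{self, Write};

pub struct MqttString<'a>(pub &'a str);

impl MqttString<'_> {
    /// Hypothetical helper: two length bytes plus the UTF-8 payload.
    pub fn encoded_len(&self) -> usize {
        2 + self.0.len()
    }

    pub fn write_to(&self, mut w: impl Write) -> io::Result<()> {
        // Panics instead of truncating when the string is too long.
        let len = u16::try_from(self.0.len()).expect("length exceeded 16-bit");
        w.write_all(&len.to_be_bytes())?;
        w.write_all(self.0.as_bytes())?;
        Ok(())
    }
}

/// Hypothetical demo: serializes into a buffer sized ahead of time.
fn encode(s: &str) -> io::Result<Vec<u8>> {
    let packet = MqttString(s);
    let mut buf = Vec::with_capacity(packet.encoded_len());
    packet.write_to(&mut buf)?;
    Ok(buf)
}
```

Because &mut Vec<u8> implements Write, the same write_to call should work unchanged against a TcpStream or a BufWriter.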
See also:
How do I convert between numeric types safely and idiomatically?
Converting number primitives (i32, f64, etc) to byte representations

How to implement a Rust function that modifies multiple elements of a vector? [duplicate]

fn change(a: &mut i32, b: &mut i32) {
let c = *a;
*a = *b;
*b = c;
}
fn main() {
let mut v = vec![1, 2, 3];
change(&mut v[0], &mut v[1]);
}
When I compile the code above, it has the error:
error[E0499]: cannot borrow `v` as mutable more than once at a time
--> src/main.rs:9:32
|
9 | change(&mut v[0], &mut v[1]);
| - ^ - first borrow ends here
| | |
| | second mutable borrow occurs here
| first mutable borrow occurs here
Why does the compiler prohibit it? v[0] and v[1] occupy different memory positions, so it's not dangerous to use these together. And what should I do if I come across this problem?
You can solve this with split_at_mut():
let mut v = vec![1, 2, 3];
let (a, b) = v.split_at_mut(1); // Returns (&mut [1], &mut [2, 3])
change(&mut a[0], &mut b[0]);
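The same trick generalizes to two arbitrary indices. The following is a sketch using only safe code; two_mut is a hypothetical helper name. It splits the slice at the larger index so that each requested element ends up in a different half.

```rust
/// Hypothetical helper: mutable references to two distinct elements,
/// or None if the indices are equal or out of bounds.
fn two_mut<T>(slice: &mut [T], i: usize, j: usize) -> Option<(&mut T, &mut T)> {
    if i == j || i >= slice.len() || j >= slice.len() {
        return None;
    }
    if i < j {
        // `left` holds indices 0..j, `right` starts at index j.
        let (left, right) = slice.split_at_mut(j);
        Some((&mut left[i], &mut right[0]))
    } else {
        // `left` holds indices 0..i, `right` starts at index i.
        let (left, right) = slice.split_at_mut(i);
        Some((&mut right[0], &mut left[j]))
    }
}

fn change(a: &mut i32, b: &mut i32) {
    let c = *a;
    *a = *b;
    *b = c;
}

/// Hypothetical demo: swaps the first and last of three elements.
fn demo() -> Vec<i32> {
    let mut v = vec![1, 2, 3];
    if let Some((a, b)) = two_mut(&mut v, 2, 0) {
        change(a, b);
    }
    v
}
```

The returned pair keeps the order of the arguments, so the first reference always corresponds to index i.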
There are uncountably many safe things to do that the compiler unfortunately does not recognize yet. split_at_mut() is just like that, a safe abstraction implemented with an unsafe block internally.
We can do that too, for this problem. The following is something I use in code where I need to separate all three cases anyway (I: Index out of bounds, II: Indices equal, III: Separate indices).
enum Pair<T> {
Both(T, T),
One(T),
None,
}
fn index_twice<T>(slc: &mut [T], a: usize, b: usize) -> Pair<&mut T> {
if a == b {
slc.get_mut(a).map_or(Pair::None, Pair::One)
} else {
if a >= slc.len() || b >= slc.len() {
Pair::None
} else {
// safe because a, b are in bounds and distinct
unsafe {
let ar = &mut *(slc.get_unchecked_mut(a) as *mut _);
let br = &mut *(slc.get_unchecked_mut(b) as *mut _);
Pair::Both(ar, br)
}
}
}
}
Since Rust 1.26, pattern matching can be done on slices. You can use that as long as you don't have huge indices and your indices are known at compile-time.
fn change(a: &mut i32, b: &mut i32) {
let c = *a;
*a = *b;
*b = c;
}
fn main() {
let mut arr = [5, 6, 7, 8];
{
let [ref mut a, _, ref mut b, ..] = arr;
change(a, b);
}
assert_eq!(arr, [7, 6, 5, 8]);
}
The borrow rules of Rust need to be checked at compilation time, that is why something like mutably borrowing a part of a Vec is a very hard problem to solve (if not impossible), and why it is not possible with Rust.
Thus, when you do something like &mut v[i], it will mutably borrow the entire vector.
Imagine I did something like
let guard = something(&mut v[i]);
do_something_else(&mut v[j]);
guard.do_job();
Here, I create an object guard that internally stores a mutable reference to v[i], and will do something with it when I call do_job().
In the meantime, I did something that changed v[j]. guard holds a mutable reference that is supposed to guarantee nothing else can modify v[i]. In this case, all is good, as long as i is different from j; if the two values are equal it is a huge violation of the borrow rules.
As the compiler cannot guarantee that i != j, it is thus forbidden.
This was a simple example, but similar cases are legion, and they are why such an access mutably borrows the whole container. On top of that, the compiler does not know enough about the internals of Vec to ensure that this operation is safe even if i != j.
In your precise case, you can have a look at the swap(..) method available on Vec that does the swap you are manually implementing.
In the more general case, you'll probably need another container. Possibilities are wrapping all the values of your Vec in a type with interior mutability, such as Cell or RefCell, or even using a completely different container, as @llogiq suggested in his answer with par-vec.
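For Copy element types, the interior-mutability route does not even require changing the container. A sketch, assuming it is acceptable for change to take &Cell<i32> instead of &mut i32: Cell::from_mut together with as_slice_of_cells borrows the existing storage as a slice of cells, and shared references to any two cells may coexist. swap_first_two is a hypothetical name.

```rust
use std::cell::Cell;

/// Same swap as in the question, but through shared `Cell` references.
fn change(a: &Cell<i32>, b: &Cell<i32>) {
    let c = a.get();
    a.set(b.get());
    b.set(c);
}

/// Hypothetical demo: swaps the first two elements without splitting.
fn swap_first_two(v: &mut Vec<i32>) {
    // View the mutable slice as &[Cell<i32>]; no data is copied or moved.
    let cells: &[Cell<i32>] = Cell::from_mut(v.as_mut_slice()).as_slice_of_cells();
    // Indexing the same slot twice would also be fine, since these are
    // shared references.
    change(&cells[0], &cells[1]);
}
```

The exclusive borrow of v is still held while cells is alive, so the usual guarantees toward the rest of the program are unchanged.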
The method [T]::iter_mut() returns an iterator that can yield a mutable reference for each element in the slice. Other collections have an iter_mut method too. These methods often encapsulate unsafe code, but their interface is totally safe.
Here's a general purpose extension trait that adds a method on slices that returns mutable references to two distinct items by index:
use std::cmp::Ordering;
pub trait SliceExt {
type Item;
fn get_two_mut(&mut self, index0: usize, index1: usize) -> (&mut Self::Item, &mut Self::Item);
}
impl<T> SliceExt for [T] {
type Item = T;
fn get_two_mut(&mut self, index0: usize, index1: usize) -> (&mut Self::Item, &mut Self::Item) {
match index0.cmp(&index1) {
Ordering::Less => {
let mut iter = self.iter_mut();
let item0 = iter.nth(index0).unwrap();
let item1 = iter.nth(index1 - index0 - 1).unwrap();
(item0, item1)
}
Ordering::Equal => panic!("[T]::get_two_mut(): received same index twice ({})", index0),
Ordering::Greater => {
let mut iter = self.iter_mut();
let item1 = iter.nth(index1).unwrap();
let item0 = iter.nth(index0 - index1 - 1).unwrap();
(item0, item1)
}
}
}
}
On recent nightlies, there is get_many_mut() (later stabilized in Rust 1.86 under the name get_disjoint_mut()):
#![feature(get_many_mut)]
fn main() {
let mut v = vec![1, 2, 3];
let [a, b] = v
.get_many_mut([0, 1])
.expect("out of bounds or overlapping indices");
change(a, b);
}
Building on the answer by @bluss, you can use split_at_mut() to create a function that turns a mutable borrow of a vector into a vector of mutable borrows of its elements:
fn borrow_mut_elementwise<'a, T>(v:&'a mut Vec<T>) -> Vec<&'a mut T> {
let mut result:Vec<&mut T> = Vec::new();
let mut current: &mut [T];
let mut rest = &mut v[..];
while rest.len() > 0 {
(current, rest) = rest.split_at_mut(1);
result.push(&mut current[0]);
}
result
}
Then you can use it to get a binding that lets you mutate many items of original Vec at once, even while you are iterating over them (if you access them by index in your loop, not through any iterator):
let mut items = vec![1,2,3];
let mut items_mut = borrow_mut_elementwise(&mut items);
for i in 1..items_mut.len() {
*items_mut[i-1] = *items_mut[i];
}
println!("{:?}", items); // [2, 3, 3]
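As far as I can tell, the same vector of element borrows can also be built with iter_mut(), which already hands out one disjoint mutable reference per element. The sketch below reuses the name borrow_mut_elementwise for comparison and takes a slice instead of &mut Vec<T>. The function shift_left is a hypothetical demo.

```rust
/// Collects one mutable reference per element of the slice.
fn borrow_mut_elementwise<T>(v: &mut [T]) -> Vec<&mut T> {
    v.iter_mut().collect()
}

/// Hypothetical demo: copies each element onto its left neighbour.
fn shift_left(items: &mut Vec<i32>) {
    let mut items_mut = borrow_mut_elementwise(items);
    for i in 1..items_mut.len() {
        // The right-hand side is read and copied before the left is written.
        *items_mut[i - 1] = *items_mut[i];
    }
}
```

Applied to vec![1, 2, 3], shift_left leaves the vector as [2, 3, 3], matching the loop above.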
The problem is that &mut v[…] first mutably borrows v and then gives the mutable reference to the element to the change-function.
This reddit comment has a solution to your problem.
Edit: Thanks for the heads-up, Shepmaster. par-vec is a library that lets you mutably borrow disjoint partitions of a vec.
I published my everyday utilities for this on crates.io (see the docs).
You may use it like
use arref::array_mut_ref;
let mut arr = vec![1, 2, 3, 4];
let (a, b) = array_mut_ref!(&mut arr, [1, 2]);
assert_eq!(*a, 2);
assert_eq!(*b, 3);
let (a, b, c) = array_mut_ref!(&mut arr, [1, 2, 0]);
assert_eq!(*c, 1);
// ⚠️ The following code will panic. Because we borrow the same element twice.
// let (a, b) = array_mut_ref!(&mut arr, [1, 1]);
It's a simple wrapper around the following code, which is sound as long as the two indices are different; that condition is checked at runtime.
pub fn array_mut_ref<T>(arr: &mut [T], a0: usize, a1: usize) -> (&mut T, &mut T) {
assert!(a0 != a1);
// SAFETY: this is safe because we know a0 != a1
unsafe {
(
&mut *(&mut arr[a0] as *mut _),
&mut *(&mut arr[a1] as *mut _),
)
}
}
Alternatively, mut_twice is a variant that does not panic when the two indices are equal (out-of-bounds indices still panic):
#[inline]
pub fn mut_twice<T>(arr: &mut [T], a0: usize, a1: usize) -> Result<(&mut T, &mut T), &mut T> {
if a0 == a1 {
Err(&mut arr[a0])
} else {
unsafe {
Ok((
&mut *(&mut arr[a0] as *mut _),
&mut *(&mut arr[a1] as *mut _),
))
}
}
}

Resources