Skip to content

Commit c20458a

Browse files
authored
Merge pull request #754 from nilgoyette/zip_mut_with
zip_mut_with 'f' order
2 parents 2f06327 + f5bd781 commit c20458a

File tree

3 files changed

+49
-10
lines changed

3 files changed

+49
-10
lines changed

benches/iter.rs

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -370,3 +370,23 @@ fn iter_axis_chunks_5_iter_sum(bench: &mut Bencher) {
370370
.sum::<f32>()
371371
});
372372
}
373+
374+
/// Benchmark helper: overwrites each element of `out` with the element of
/// `data` that the array method `zip_mut_with` pairs it with.
pub fn zip_mut_with(data: &Array3<f32>, out: &mut Array3<f32>) {
375+
out.zip_mut_with(&data, |o, &i| {
376+
*o = i;
377+
});
378+
}
379+
380+
/// Benchmarks the `zip_mut_with` helper on two zero-filled `ISZ`^3 arrays
/// whose shapes are passed without `.f()`, i.e. both use the default
/// memory order.
#[bench]
381+
fn zip_mut_with_cc(b: &mut Bencher) {
382+
let data: Array3<f32> = Array3::zeros((ISZ, ISZ, ISZ));
383+
let mut out = Array3::zeros(data.dim());
384+
b.iter(|| black_box(zip_mut_with(&data, &mut out)));
385+
}
386+
387+
/// Benchmarks the `zip_mut_with` helper on two zero-filled `ISZ`^3 arrays
/// whose shapes are both marked with `.f()` (ndarray's `ShapeBuilder`
/// marker for column-major / Fortran memory order).
#[bench]
388+
fn zip_mut_with_ff(b: &mut Bencher) {
389+
let data: Array3<f32> = Array3::zeros((ISZ, ISZ, ISZ).f());
390+
let mut out = Array3::zeros(data.dim().f());
391+
b.iter(|| black_box(zip_mut_with(&data, &mut out)));
392+
}

src/dimension/dimension_trait.rs

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -229,6 +229,25 @@ pub trait Dimension:
229229
!end_iteration
230230
}
231231

232+
/// Returns `true` iff `strides1` and `strides2` are equivalent for the
233+
/// shape `self`.
234+
///
235+
/// The strides are equivalent if, for each axis with length > 1, the
236+
/// strides are equal.
237+
///
238+
/// Note: Returns `false` if any of the ndims don't match.
///
/// `strides1` has the same dimension type as `self`, while `strides2` may be
/// any `Dimension` type `D`; the number of axes is therefore checked at
/// runtime for both before any strides are compared.
239+
#[doc(hidden)]
240+
fn strides_equivalent<D>(&self, strides1: &Self, strides2: &D) -> bool
241+
where
242+
D: Dimension,
243+
{
244+
let shape_ndim = self.ndim();
245+
// Guard: the shape and both stride sets must have the same number of axes;
// `&&` short-circuits, so the per-axis comparison only runs when they do.
shape_ndim == strides1.ndim()
246+
&& shape_ndim == strides2.ndim()
247+
&& izip!(self.slice(), strides1.slice(), strides2.slice())
248+
// Axes of length 0 or 1 are ignored; for every other axis the two
// strides are compared after casting each to `isize`.
.all(|(&d, &s1, &s2)| d <= 1 || s1 as isize == s2 as isize)
249+
}
250+
232251
#[doc(hidden)]
233252
/// Return stride offset for index.
234253
fn stride_offset(index: &Self, strides: &Self) -> isize {

src/impl_methods.rs

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,6 @@
66
// option. This file may not be copied, modified, or distributed
77
// except according to those terms.
88

9-
use std::cmp;
109
use std::ptr as std_ptr;
1110
use std::slice;
1211

@@ -1937,18 +1936,19 @@ where
19371936
F: FnMut(&mut A, &B),
19381937
{
19391938
debug_assert_eq!(self.shape(), rhs.shape());
1940-
if let Some(self_s) = self.as_slice_mut() {
1941-
if let Some(rhs_s) = rhs.as_slice() {
1942-
let len = cmp::min(self_s.len(), rhs_s.len());
1943-
let s = &mut self_s[..len];
1944-
let r = &rhs_s[..len];
1945-
for i in 0..len {
1946-
f(&mut s[i], &r[i]);
1939+
1940+
if self.dim.strides_equivalent(&self.strides, &rhs.strides) {
1941+
if let Some(self_s) = self.as_slice_memory_order_mut() {
1942+
if let Some(rhs_s) = rhs.as_slice_memory_order() {
1943+
for (s, r) in self_s.iter_mut().zip(rhs_s) {
1944+
f(s, &r);
1945+
}
1946+
return;
19471947
}
1948-
return;
19491948
}
19501949
}
1951-
// otherwise, fall back to the outer iter
1950+
1951+
// Otherwise, fall back to the outer iter
19521952
self.zip_mut_with_by_rows(rhs, f);
19531953
}
19541954

0 commit comments

Comments
 (0)