Skip to content

Commit f5bd781

Browse files
committed
Dimension::strides_equivalent
1 parent 59eadb5 commit f5bd781

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

src/dimension/dimension_trait.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,25 @@ pub trait Dimension:
229229
!end_iteration
230230
}
231231

232+
/// Returns `true` iff `strides1` and `strides2` are equivalent for the
233+
/// shape `self`.
234+
///
235+
/// The strides are equivalent if, for each axis with length > 1, the
236+
/// strides are equal.
237+
///
238+
/// Note: Returns `false` if any of the ndims don't match.
239+
#[doc(hidden)]
240+
fn strides_equivalent<D>(&self, strides1: &Self, strides2: &D) -> bool
241+
where
242+
D: Dimension,
243+
{
244+
let shape_ndim = self.ndim();
245+
shape_ndim == strides1.ndim()
246+
&& shape_ndim == strides2.ndim()
247+
&& izip!(self.slice(), strides1.slice(), strides2.slice())
248+
.all(|(&d, &s1, &s2)| d <= 1 || s1 as isize == s2 as isize)
249+
}
250+
232251
#[doc(hidden)]
233252
/// Return stride offset for index.
234253
fn stride_offset(index: &Self, strides: &Self) -> isize {

src/impl_methods.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1917,8 +1917,7 @@ where
19171917
{
19181918
debug_assert_eq!(self.shape(), rhs.shape());
19191919

1920-
// Same shape and order should have same strides
1921-
if self.strides() == rhs.strides() {
1920+
if self.dim.strides_equivalent(&self.strides, &rhs.strides) {
19221921
if let Some(self_s) = self.as_slice_memory_order_mut() {
19231922
if let Some(rhs_s) = rhs.as_slice_memory_order() {
19241923
for (s, r) in self_s.iter_mut().zip(rhs_s) {

0 commit comments

Comments
 (0)