Skip to content

Commit 021fb8f

Browse files
committed
removing comments
1 parent 6a3f6a9 commit 021fb8f

File tree

1 file changed

+13
-56
lines changed

1 file changed

+13
-56
lines changed

src/lib.rs

Lines changed: 13 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,92 +1,49 @@
11
use pyo3::prelude::*;
22
use pyo3::exceptions::PyValueError;
33
use ndarray::linalg::Dot;
4-
use ndarray::{Array2, ArrayBase, Axis, Ix2, LinalgScalar, OwnedRepr};
5-
use std::ops::AddAssign;
4+
use ndarray::{Array2, ArrayBase, Ix2, LinalgScalar, OwnedRepr};
65

7-
// Define a custom type just to make this a bit more readable
86
type Matrix<A> = ArrayBase<OwnedRepr<A>, Ix2>;
97

10-
// Define a Rust function for exponentiating a matrix
11-
// This function has never even heard of Python!
128
pub fn matrix_power_rust<A>(
13-
matrix: &Matrix<A>, // <- type definitions like this are MANDATORY in Rust
14-
mut exponent: usize // by default, all objects are immutable. We need to mark them `mut`
15-
) -> Matrix<A> // to change them
9+
matrix: &Matrix<A>,
10+
mut exponent: usize
11+
) -> Matrix<A>
1612
where
1713
A: LinalgScalar,
1814
Matrix<A>: Dot<Matrix<A>, Output = Matrix<A>>
1915
{
20-
let mut result = Array2::eye(matrix.nrows()); // initializing identity matrix
16+
let mut result = Array2::eye(matrix.nrows());
2117
let mut base = matrix.to_owned();
22-
23-
// implementing the binary exponentiation algorithm matrices
24-
// works in O(log(N)) time
2518
while exponent > 0 {
26-
if exponent % 2 == 1 {
27-
result = result.dot(&base);
28-
}
19+
if exponent % 2 == 1 { result = result.dot(&base); }
2920
base = base.dot(&base);
3021
exponent /= 2;
3122
}
3223
result
3324
}
3425

35-
// Define a pyfunction: this is a function written in Rust, which is designed to
36-
// take inputs from and return outputs to Python
3726
#[pyfunction]
3827
pub fn matrix_power(
39-
// 1. Accept any Python object. A list of lists is expected
40-
matrix_obj: &Bound<'_, PyAny>, // <- Remember how type definitions are mandatory?
41-
exponent: usize, // Well, dealing with a language which doesn't have
42-
) -> PyResult<Vec<Vec<f64>>> { // strong types, we sometimes need a type that says 'I dunno :P'
43-
44-
// 2. Try to extract the Python object into a Rust nested vector of floats: `Vec<Vec<f64>>`
45-
// PyO3 will return a PyErr (raising a Python TypeError) if the object
46-
// doesn't have the right structure (e.g., it's not a list of lists of floats)
28+
matrix_obj: &Bound<'_, PyAny>,
29+
exponent: usize,
30+
) -> PyResult<Vec<Vec<f64>>> {
4731
let nested_vecs: Vec<Vec<f64>> = matrix_obj.extract()?;
48-
// !^!
49-
// this `?` notation means 'if this returns an error,
50-
// return it from this function early'
51-
// in this case, it would return a Python TypeError
52-
53-
// if we succeed, we transform it into a two-dimensional matrix of floats: ndarray::Array2<f64>
5432

55-
// Get the dimensions from the nested Vecs
5633
let n_rows = nested_vecs.len();
57-
if n_rows == 0 {
58-
// Handle empty matrix case, returning an error early
59-
return Err(PyValueError::new_err("Matrices cannot be empty".to_string()));
60-
}
34+
if n_rows == 0 { return Err(PyValueError::new_err("Matrices cannot be empty".to_string())); }
6135
let n_cols = nested_vecs[0].len();
6236

63-
// Flatten the nested vectors into a single vector, just like np.flatten()
6437
let flat_vec: Vec<f64> = nested_vecs.into_iter().flatten().collect();
65-
// !^! !^!
66-
// What are these doing here? If you want to iterate over something, you need
67-
// to be explicit about it! This means 'transform flat_vec into an iterator in place,
68-
// flatten it, and then collect it into a new vector of floats'
69-
// Because we used .into_iter() instead of .iter(), nested_vecs no longer exists beyond this
70-
// line!
71-
72-
// 3. Create the ndarray Array from the shape and the flat data
38+
7339
let rust_matrix = Array2::from_shape_vec((n_rows, n_cols), flat_vec)
7440
.map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
75-
// !^!
76-
// This will return an error if the number
77-
// of elements doesn't match the shape
7841

79-
// 4. Call our all-Rust matrix power function
8042
let result_matrix = matrix_power_rust(&rust_matrix, exponent);
8143

82-
// 5. Convert the result back into a Python-friendly type, a Vec<Vec<f64>>
83-
let result_vecs: Vec<Vec<f64>> = result_matrix
84-
.rows()
85-
.into_iter()
86-
.map(|row| row.to_vec())
87-
.collect();
44+
let result_vecs: Vec<Vec<f64>> = result_matrix.rows().into_iter()
45+
.map(|row| row.to_vec()) .collect();
8846

89-
// 6. Return the result to Python
9047
Ok(result_vecs)
9148
}
9249

0 commit comments

Comments
 (0)