|
1 | 1 | use pyo3::prelude::*; |
2 | 2 | use pyo3::exceptions::PyValueError; |
3 | 3 | use ndarray::linalg::Dot; |
4 | | -use ndarray::{Array2, ArrayBase, Axis, Ix2, LinalgScalar, OwnedRepr}; |
5 | | -use std::ops::AddAssign; |
| 4 | +use ndarray::{Array2, ArrayBase, Ix2, LinalgScalar, OwnedRepr}; |
6 | 5 |
|
7 | | -// Define a custom type just to make this a bit more readable |
8 | 6 | type Matrix<A> = ArrayBase<OwnedRepr<A>, Ix2>; |
9 | 7 |
|
10 | | -// Define a Rust function for exponentiating a matrix |
11 | | -// This function has never even heard of Python! |
12 | 8 | pub fn matrix_power_rust<A>( |
13 | | - matrix: &Matrix<A>, // <- type definitions like this are MANDATORY in Rust |
14 | | - mut exponent: usize // by default, all objects are immutable. We need to mark them `mut` |
15 | | -) -> Matrix<A> // to change them |
| 9 | + matrix: &Matrix<A>, |
| 10 | + mut exponent: usize |
| 11 | +) -> Matrix<A> |
16 | 12 | where |
17 | 13 | A: LinalgScalar, |
18 | 14 | Matrix<A>: Dot<Matrix<A>, Output = Matrix<A>> |
19 | 15 | { |
20 | | - let mut result = Array2::eye(matrix.nrows()); // initializing identity matrix |
| 16 | + let mut result = Array2::eye(matrix.nrows()); |
21 | 17 | let mut base = matrix.to_owned(); |
22 | | - |
23 | | - // implementing the binary exponentiation algorithm matrices |
24 | | - // works in O(log(N)) time |
25 | 18 | while exponent > 0 { |
26 | | - if exponent % 2 == 1 { |
27 | | - result = result.dot(&base); |
28 | | - } |
| 19 | + if exponent % 2 == 1 { result = result.dot(&base); } |
29 | 20 | base = base.dot(&base); |
30 | 21 | exponent /= 2; |
31 | 22 | } |
32 | 23 | result |
33 | 24 | } |
34 | 25 |
|
35 | | -// Define a pyfunction: this is a function written in Rust, which is designed to |
36 | | -// take inputs from and return outputs to Python |
37 | 26 | #[pyfunction] |
38 | 27 | pub fn matrix_power( |
39 | | - // 1. Accept any Python object. A list of lists is expected |
40 | | - matrix_obj: &Bound<'_, PyAny>, // <- Remember how type definitions are mandatory? |
41 | | - exponent: usize, // Well, dealing with a language which doesn't have |
42 | | -) -> PyResult<Vec<Vec<f64>>> { // strong types, we sometimes need a type that says 'I dunno :P' |
43 | | - |
44 | | - // 2. Try to extract the Python object into a Rust nested vector of floats: `Vec<Vec<f64>>` |
45 | | - // PyO3 will return a PyErr (raising a Python TypeError) if the object |
46 | | - // doesn't have the right structure (e.g., it's not a list of lists of floats) |
| 28 | + matrix_obj: &Bound<'_, PyAny>, |
| 29 | + exponent: usize, |
| 30 | +) -> PyResult<Vec<Vec<f64>>> { |
47 | 31 | let nested_vecs: Vec<Vec<f64>> = matrix_obj.extract()?; |
48 | | - // !^! |
49 | | - // this `?` notation means 'if this returns an error, |
50 | | - // return it from this function early' |
51 | | - // in this case, it would return a Python TypeError |
52 | | - |
53 | | - // if we succeed, we transform it into a two-dimensional matrix of floats: ndarray::Array2<f64> |
54 | 32 |
|
55 | | - // Get the dimensions from the nested Vecs |
56 | 33 | let n_rows = nested_vecs.len(); |
57 | | - if n_rows == 0 { |
58 | | - // Handle empty matrix case, returning an error early |
59 | | - return Err(PyValueError::new_err("Matrices cannot be empty".to_string())); |
60 | | - } |
| 34 | + if n_rows == 0 { return Err(PyValueError::new_err("Matrices cannot be empty".to_string())); } |
61 | 35 | let n_cols = nested_vecs[0].len(); |
62 | 36 |
|
63 | | - // Flatten the nested vectors into a single vector, just like np.flatten() |
64 | 37 | let flat_vec: Vec<f64> = nested_vecs.into_iter().flatten().collect(); |
65 | | - // !^! !^! |
66 | | - // What are these doing here? If you want to iterate over something, you need |
67 | | - // to be explicit about it! This means 'transform flat_vec into an iterator in place, |
68 | | - // flatten it, and then collect it into a new vector of floats' |
69 | | - // Because we used .into_iter() instead of .iter(), nested_vecs no longer exists beyond this |
70 | | - // line! |
71 | | - |
72 | | - // 3. Create the ndarray Array from the shape and the flat data |
| 38 | + |
73 | 39 | let rust_matrix = Array2::from_shape_vec((n_rows, n_cols), flat_vec) |
74 | 40 | .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?; |
75 | | - // !^! |
76 | | - // This will return an error if the number |
77 | | - // of elements doesn't match the shape |
78 | 41 |
|
79 | | - // 4. Call our all-Rust matrix power function |
80 | 42 | let result_matrix = matrix_power_rust(&rust_matrix, exponent); |
81 | 43 |
|
82 | | - // 5. Convert the result back into a Python-friendly type, a Vec<Vec<f64>> |
83 | | - let result_vecs: Vec<Vec<f64>> = result_matrix |
84 | | - .rows() |
85 | | - .into_iter() |
86 | | - .map(|row| row.to_vec()) |
87 | | - .collect(); |
| 44 | + let result_vecs: Vec<Vec<f64>> = result_matrix.rows().into_iter() |
| 45 | + .map(|row| row.to_vec()) .collect(); |
88 | 46 |
|
89 | | - // 6. Return the result to Python |
90 | 47 | Ok(result_vecs) |
91 | 48 | } |
92 | 49 |
|
|
0 commit comments