Skip to content

Commit f156004

Browse files
pankit-engfacebook-github-bot
authored andcommitted
Implement __str__ for Point struct py binding (#977)
Summary: Pull Request resolved: #977 Implement `__str__` method for the `Point` struct in the `monarch_hyperactor` library. The `__str__` method returns a string representation of the `Point` struct, which includes the labels and coordinates of the point. Also, added a unit test for the same. Related github issue #935 TODO: Plan to audit other structs with missing __str__. If there are missing ones, then those shall be sent via separate diff. Reviewed By: highker Differential Revision: D80862068 fbshipit-source-id: 11af7638fe49feeba5102e8af66d87370396d5af
1 parent 0eb5e6c commit f156004

File tree

4 files changed

+53
-4
lines changed

4 files changed

+53
-4
lines changed

monarch_hyperactor/src/shape.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,31 @@ impl PyPoint {
210210
}
211211
}
212212

213+
fn __str__(&self, py: Python) -> PyResult<String> {
214+
let shape = self.shape.bind(py).get();
215+
let inner_shape = &shape.inner;
216+
let slice = inner_shape.slice();
217+
let total_size = slice.len();
218+
let current_rank = self.rank;
219+
220+
let coords = slice
221+
.coordinates(current_rank)
222+
.map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
223+
224+
// Create the underlying Point struct from ndslice::view
225+
let extent =
226+
ndslice::view::Extent::new(inner_shape.labels().to_vec(), slice.sizes().to_vec())
227+
.map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
228+
let point = extent
229+
.point(coords)
230+
.map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
231+
232+
Ok(format!(
233+
"rank={}/{} coords={{{}}}",
234+
current_rank, total_size, point
235+
))
236+
}
237+
213238
fn __repr__(&self, py: Python) -> PyResult<String> {
214239
let shape = self.shape.bind(py).get();
215240
let inner_shape = &shape.inner;
@@ -234,6 +259,7 @@ impl PyPoint {
234259

235260
let coords_str = coords_parts.join(",");
236261

262+
// TODO: Should we call the Display implementation of the Point struct using extent here as well?
237263
Ok(format!(
238264
"rank={}/{} coords={{{}}}",
239265
current_rank, total_size, coords_str

ndslice/src/view.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,13 @@ impl std::fmt::Display for Point {
370370
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
371371
let n = self.coords.len();
372372
for i in 0..n {
373-
write!(f, "{}={}", self.extent.labels()[i], self.coords[i])?;
373+
write!(
374+
f,
375+
"{}={}/{}",
376+
self.extent.labels()[i],
377+
self.coords[i],
378+
self.extent.sizes()[i]
379+
)?;
374380
if i != n - 1 {
375381
write!(f, ",")?;
376382
}
@@ -851,7 +857,7 @@ mod test {
851857
fn test_point_display() {
852858
let extent = Extent::new(vec!["x".into(), "y".into(), "z".into()], vec![4, 5, 6]).unwrap();
853859
let point = extent.point(vec![1, 2, 3]).unwrap();
854-
assert_eq!(format!("{}", point), "x=1,y=2,z=3");
860+
assert_eq!(format!("{}", point), "x=1/4,y=2/5,z=3/6");
855861

856862
assert!(extent.point(vec![]).is_err());
857863

python/monarch/_rust_bindings/monarch_hyperactor/shape.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ class Shape:
106106
107107
Arguments:
108108
- `labels`: A list of strings representing the labels for each dimension.
109-
- `slice`: An Slice object representing the shape.
109+
- `slice`: A Slice object representing the shape.
110110
"""
111111
def __new__(cls, labels: Sequence[str], slice: Slice) -> "Shape": ...
112112
@property

python/tests/_monarch/test_ndslice.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from monarch._rust_bindings.monarch_hyperactor.selection import Selection
1414

15-
from monarch._rust_bindings.monarch_hyperactor.shape import Shape, Slice
15+
from monarch._rust_bindings.monarch_hyperactor.shape import Point, Shape, Slice
1616

1717

1818
class TestNdslice(TestCase):
@@ -191,6 +191,23 @@ def test_shape_repr(self) -> None:
191191
)
192192

193193

194+
class TestPoint(TestCase):
195+
def test_point_str_simple(self) -> None:
196+
"""Test __str__ method for Point with simple 2D shape."""
197+
s = Slice(offset=0, sizes=[3, 4], strides=[4, 1])
198+
shape = Shape(["label0", "label1"], s)
199+
200+
# Test different ranks and their string representations
201+
point_0 = Point(0, shape)
202+
self.assertEqual(str(point_0), "rank=0/12 coords={label0=0/3,label1=0/4}")
203+
204+
point_3 = Point(3, shape)
205+
self.assertEqual(str(point_3), "rank=3/12 coords={label0=0/3,label1=3/4}")
206+
207+
point_11 = Point(11, shape)
208+
self.assertEqual(str(point_11), "rank=11/12 coords={label0=2/3,label1=3/4}")
209+
210+
194211
class TestSelection(TestCase):
195212
def test_constants(self) -> None:
196213
self.assertEqual(repr(Selection.any()), "Any(True)")

0 commit comments

Comments
 (0)