Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
### Fixed
- Raised an error when an expression is used when a variable is required
- Fixed some compile warnings
- Fixed the type of @ matrix operation result from MatrixVariable to MatrixExpr.
### Changed
- MatrixExpr.sum() now supports axis arguments and can return either a scalar or MatrixExpr, depending on the result dimensions.
- AddMatrixCons() also accepts ExprCons.
Expand Down
5 changes: 4 additions & 1 deletion src/pyscipopt/matrix.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,10 @@ class MatrixExpr(np.ndarray):

def __rsub__(self, other):
return super().__rsub__(other).view(MatrixExpr)


def __matmul__(self, other):
return super().__matmul__(other).view(MatrixExpr)

class MatrixGenExpr(MatrixExpr):
pass

Expand Down
17 changes: 17 additions & 0 deletions tests/test_matrix_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,3 +392,20 @@ def test_matrix_cons_indicator():
assert m.getVal(is_equal).sum() == 2
assert (m.getVal(x) == m.getVal(y)).all().all()
assert (m.getVal(x) == np.array([[5, 5, 5], [5, 5, 5]])).all().all()


def test_matrix_matmul_return_type():
# test #1058, require returning type is MatrixExpr not MatrixVariable
m = Model()

# test 1D @ 1D → 0D
x = m.addMatrixVar(3)
assert type(x @ x) is MatrixExpr

# test 1D @ 1D → 2D
assert type(x[:, None] @ x[None, :]) is MatrixExpr

# test 2D @ 2D → 2D
y = m.addMatrixVar((2, 3))
z = m.addMatrixVar((3, 4))
assert type(y @ z) is MatrixExpr
Loading