File tree Expand file tree Collapse file tree 2 files changed +8
-1
lines changed
proteinworkshop/models/graph_encoders/layers Expand file tree Collapse file tree 2 files changed +8
-1
lines changed Original file line number Diff line number Diff line change 1616* Add support for handling backward OOMs gracefully [ #83 ] ( https://github.com/a-r-j/ProteinWorkshop/pull/83 )
1717* Update GCPNet paper link [ #85 ] ( https://github.com/a-r-j/ProteinWorkshop/pull/85 )
1818* Add ability for ` BenchmarkModel ` to have its decoder disabled [ #101 ] ( https://github.com/a-r-j/ProteinWorkshop/pull/101 )
19+ * Fix dtype mismatch in ` gcp.py ` that broke Automatic Mixed Precision (AMP) training [ #102 ] ( https://github.com/a-r-j/ProteinWorkshop/pull/102 )
1920
2021### Framework
2122
Original file line number Diff line number Diff line change @@ -265,14 +265,20 @@ def scalarize(
265265
266266 if node_mask is not None :
267267 edge_mask = node_mask [row ] & node_mask [col ]
268+ # Initialize destination tensor
268269 local_scalar_rep_i = torch .zeros (
269270 (edge_index .shape [1 ], 3 , 3 ), device = edge_index .device
270271 )
271- local_scalar_rep_i [edge_mask ] = torch .matmul (
272+ # Calculate the source value (result of matmul, likely Half under AMP)
273+ matmul_result = torch .matmul (
272274 frames [edge_mask ], vector_rep_i [edge_mask ]
273275 )
276+ # Explicitly cast the source value to the destination's dtype before assignment
277+ local_scalar_rep_i [edge_mask ] = matmul_result .to (local_scalar_rep_i .dtype )
278+
274279 local_scalar_rep_i = local_scalar_rep_i .transpose (- 1 , - 2 )
275280 else :
281+ # This path might need similar treatment if it causes issues
276282 local_scalar_rep_i = torch .matmul (frames , vector_rep_i ).transpose (- 1 , - 2 )
277283
278284 # potentially enable E(3)-equivariance and, thereby, chirality-invariance
You can’t perform that action at this time.
0 commit comments