Conversation

Owner

@smsharma smsharma commented Mar 1, 2025

Warning

Untested, generated by Claude Code.

Summary

  • Implements the ERWIN (Efficient Hierarchical Transformer) model based on the paper
  • Port of the original PyTorch implementation to JAX/Flax
  • Includes ball tree construction for local attention operations
  • Implements a hierarchical transformer with pooling/unpooling operations
  • Adds tests for the core functionality

Implementation Details

  • Supports both node-level and graph-level prediction tasks
  • Uses JAX-native operations for the ball tree construction
  • Includes Ball Multi-Head Self-Attention (BMSA) for local attention

Note

The radius graph construction is not fully JIT-compatible when computing the graph structure dynamically. Users should pre-compute edges or provide them directly when using within JIT-compiled functions.
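A minimal sketch of the workaround described in the note, assuming the edges are built once on the host and then passed into the JIT-compiled function as ordinary array arguments. The helper names `radius_graph_host` and `aggregate_neighbors` are hypothetical and are not taken from this PR; the brute-force O(n²) distance computation is for illustration only.

```python
import numpy as np
import jax
import jax.numpy as jnp


def radius_graph_host(pos, radius):
    """Hypothetical helper: build radius-graph edges on the host with NumPy.

    The number of edges depends on the data, which is why this step is kept
    outside of jit. Brute force, for illustration only.
    """
    pos = np.asarray(pos)
    dist = np.linalg.norm(pos[:, None, :] - pos[None, :, :], axis=-1)
    mask = (dist < radius) & ~np.eye(len(pos), dtype=bool)  # drop self-loops
    senders, receivers = np.nonzero(mask)
    return senders, receivers


@jax.jit
def aggregate_neighbors(features, senders, receivers):
    """Sum sender features into each receiver node.

    The edge arrays arrive as inputs, so their shapes are known at trace time.
    """
    messages = features[senders]
    return jax.ops.segment_sum(
        messages, receivers, num_segments=features.shape[0]
    )


# Pre-compute the edges once, then call the compiled function.
pos = np.array([[0.0, 0.0], [1.0, 0.0], [5.0, 0.0]])
features = jnp.array([[1.0], [2.0], [4.0]])
senders, receivers = radius_graph_host(pos, radius=1.5)
out = aggregate_neighbors(features, senders, receivers)
```

Note that a different number of edges changes the input shapes and triggers a recompilation, so padding the edge arrays to a fixed length may be worthwhile when graphs vary in size.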

🤖 Generated with Claude Code
Co-Authored-By: Claude <noreply@anthropic.com>

This PR took ~10 minutes and ~$7 to generate.

Implement ERWIN (Efficient Hierarchical Transformer) model for point cloud data:
- Added hierarchical ball tree construction
- Implemented Ball Multi-Head Self-Attention (BMSA)
- Created UNet-style encoder-decoder architecture
- Added comprehensive tests for all components

This implementation offers O(n) complexity instead of O(n²) through local
ball attention and hierarchical representation learning.
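To illustrate where the linear cost could come from, here is a simplified sketch of attention restricted to balls. It assumes the points are already stored in ball-tree order, that the number of points is divisible by a fixed `ball_size`, and it leaves out learned projections and multiple heads. The function name `ball_attention` is hypothetical and is not the BMSA module from this PR. Each ball costs `ball_size²` score entries, so the total is `n * ball_size`, which grows linearly in `n` for a fixed ball size.

```python
import jax
import jax.numpy as jnp


def ball_attention(x, ball_size):
    """Single-head self-attention within contiguous balls (no projections).

    x: (n, d) features in ball-tree order; n must be divisible by ball_size.
    """
    n, d = x.shape
    balls = x.reshape(n // ball_size, ball_size, d)
    # One (ball_size, ball_size) score matrix per ball instead of one (n, n).
    scores = jnp.einsum("bid,bjd->bij", balls, balls) / jnp.sqrt(d)
    weights = jax.nn.softmax(scores, axis=-1)
    out = jnp.einsum("bij,bjd->bid", weights, balls)
    return out.reshape(n, d)


x = jax.random.normal(jax.random.PRNGKey(0), (8, 4))
y = ball_attention(x, ball_size=4)  # shape (8, 4)
```

Because attention never crosses a ball boundary, changing the points in one ball leaves the outputs of every other ball untouched.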

maxxxzdn commented Mar 2, 2025

haha Claude rocks!

However, I don't think the code will work. Particularly,

  • build_balltree_from_jraph does something weird; it doesn't even start recursion, so I assume it will not produce a meaningful ball tree.
  • on the same note, I spent quite some time implementing the whole procedure in C++, so it is as fast as possible. It is designed to work with ragged arrays, but it will be straightforward to make a wrapper to handle batches of same-sized point clouds/graphs. Although not sure how much XLA will like it.
  • Claude doesn't seem to get that the tree we use is always complete, meaning the number of leaves is always a power of 2. This makes handling balls, pooling, and unpooling so trivial - simply reshape and mean. Right now, there is unnecessary complexity in pooling/attention with padding.
  • Overall, Claude overcomplicated tree handling a lot. For example, with leaf extraction at the end of the decoder stage. Although it might work, Claude should have just used argsort(idx) imo.
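The "reshape and mean" and `argsort(idx)` suggestions above can be sketched roughly as follows. This is an illustration under stated assumptions, not the Erwin implementation: the points are assumed to be stored in tree order as `x_tree = x[idx]`, where `idx[i]` is the original index of the point at tree position `i`, and the leaf count is assumed to be a power of 2 so every ball has the same size. The names `pool_balls`, `unpool_balls`, and `restore_order` are hypothetical.

```python
import jax.numpy as jnp


def pool_balls(x_tree, ball_size):
    """Coarsen by averaging each contiguous ball of `ball_size` leaves."""
    n, d = x_tree.shape
    return x_tree.reshape(n // ball_size, ball_size, d).mean(axis=1)


def unpool_balls(x_coarse, ball_size):
    """Refine by copying each coarse feature back to the leaves of its ball."""
    return jnp.repeat(x_coarse, ball_size, axis=0)


def restore_order(x_tree, idx):
    """Undo the tree permutation: argsort(idx) is the inverse of idx."""
    return x_tree[jnp.argsort(idx)]


x = jnp.arange(16.0).reshape(8, 2)           # 8 points, 2 features each
idx = jnp.array([3, 0, 6, 1, 7, 2, 5, 4])    # assumed tree ordering
x_tree = x[idx]

coarse = pool_balls(x_tree, ball_size=2)     # shape (4, 2)
fine = unpool_balls(coarse, ball_size=2)     # shape (8, 2)
x_back = restore_order(x_tree, idx)          # same rows as x
```

With a complete tree no padding or masking is needed, since every level divides evenly into balls.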

I used Claude a lot when developing Erwin, but it required extensive parenting, especially when it came to anything related to ball tree handling. I cannot blame him for it since I use ball trees rather unorthodoxly, but it is something to keep in mind :)

I am going to implement Erwin in JAX once I am done polishing it in PyTorch (the current repo is not the final one, e.g. there is no cross-ball interaction yet). If Claude does it before, I am afraid we have AGI on our hands haha


adam-hartshorne commented Aug 11, 2025

I don't know if you intend to reimplement Erwin in JAX, but the Foreign Function Interface is one option for addressing many of the issues discussed above:

https://docs.jax.dev/en/latest/ffi.html
