-
Notifications
You must be signed in to change notification settings - Fork 11
feat(lattice): Make lattice geometries differentiable and backend-agn… #30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 11 commits
9e01be8
9d22384
d71d4a1
bb65592
0ad707c
92bc8e4
efaee05
7063c6f
0660abf
589763e
daa3ff2
9575be5
d372f72
0b38522
04aca93
283e1fd
494a99b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
""" | ||
Lennard-Jones Potential Optimization Example | ||
|
||
This script demonstrates how to use TensorCircuit's differentiable lattice geometries | ||
to optimize crystal structure. It finds the equilibrium lattice constant that minimizes | ||
the total Lennard-Jones potential energy of a 2D square lattice. | ||
|
||
The optimization showcases a key capability of this feature: making lattice parameters
differentiable for variational material design. | ||
""" | ||
|
||
import optax | ||
refraction-ray marked this conversation as resolved.
Show resolved
Hide resolved
|
||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
import tensorcircuit as tc | ||
|
||
|
||
tc.set_dtype("float64")  # Use tc for universal control; float64 keeps gradients numerically stable
K = tc.set_backend("jax")  # K exposes the backend-agnostic numeric API (exp, sum, where, vmap, ...)
|
||
|
||
def calculate_potential(log_a, epsilon=0.5, sigma=1.0):
    """
    Total Lennard-Jones energy of a periodic 4x4 square lattice, as a
    function of the logarithm of the lattice constant.

    The lattice is constructed inside the function from the differentiable
    parameter, so gradients flow through the geometry itself. Optimizing
    ``log_a`` rather than the lattice constant keeps the constant positive.

    :param log_a: logarithm of the lattice constant (backend tensor)
    :param epsilon: Lennard-Jones well depth, defaults to 0.5
    :param sigma: Lennard-Jones length scale, defaults to 1.0
    :return: scalar total potential energy
    """
    a = K.exp(log_a)

    # Build the lattice from the differentiable parameter; small size keeps
    # the example fast.
    grid = tc.templates.lattice.SquareLattice((4, 4), lattice_constant=a, pbc=True)
    dist = grid.distance_matrix

    # Guard the zero entries on the diagonal before dividing by the distance.
    dist = K.where(dist > 1e-9, dist, K.convert_to_tensor(1e-9))

    repulsive = K.power(sigma / dist, 12)
    attractive = K.power(sigma / dist, 6)
    pair_energy = 4 * epsilon * (repulsive - attractive)

    # Mask out self-interactions on the diagonal.
    n = grid.num_sites
    pair_energy = pair_energy * (1 - K.eye(n, dtype=pair_energy.dtype))

    # The full matrix counts every pair twice.
    return K.sum(pair_energy) / 2.0
|
||
|
||
# Create a jitted value-and-grad function so each optimization step is cheap.
value_and_grad_fun = K.jit(K.value_and_grad(calculate_potential))

optimizer = optax.adam(learning_rate=0.01)

# Optimize log(a) so the lattice constant stays strictly positive.
# (K.log already returns a backend tensor; no extra convert_to_tensor needed.)
log_a = K.log(K.convert_to_tensor(1.1))

opt_state = optimizer.init(log_a)

# Record the trajectory as plain Python floats so plotting does not hold on
# to backend tensors (or their computation graphs).
history = {"a": [], "energy": []}

print("Starting optimization of lattice constant...")
for i in range(200):
    energy, grad = value_and_grad_fun(log_a)

    history["a"].append(float(K.numpy(K.exp(log_a))))
    history["energy"].append(float(K.numpy(energy)))

    updates, opt_state = optimizer.update(grad, opt_state)
    log_a = optax.apply_updates(log_a, updates)

    if (i + 1) % 20 == 0:
        current_a = K.exp(log_a)
        print(
            f"Iteration {i+1}/200: Total Energy = {energy:.4f}, Lattice Constant = {current_a:.4f}"
        )

final_a = K.exp(log_a)
final_energy = calculate_potential(log_a)

if not np.isnan(K.numpy(final_energy)):
    print("\nOptimization finished!")
    print(f"Final optimized lattice constant: {final_a:.6f}")
    print(f"Corresponding minimum total energy: {final_energy:.6f}")

    # Sweep a range of lattice constants to draw the full potential curve.
    a_vals = np.linspace(0.8, 1.5, 200)
    log_a_vals = K.log(K.convert_to_tensor(a_vals))

    # vmap evaluates the potential for every lattice constant in one batched call.
    vmap_potential = K.vmap(lambda la: calculate_potential(la))
    potential_curve = vmap_potential(log_a_vals)

    plt.figure(figsize=(10, 6))
    plt.plot(a_vals, potential_curve, label="Lennard-Jones Potential", color="blue")
    plt.scatter(
        history["a"],
        history["energy"],
        color="red",
        s=20,
        zorder=5,
        label="Optimization Steps",
    )
    plt.scatter(
        final_a,
        final_energy,
        color="green",
        s=100,
        zorder=6,
        marker="*",
        label="Final Optimized Point",
    )

    plt.title("Lennard-Jones Potential Optimization")
    plt.xlabel("Lattice Constant (a)")
    plt.ylabel("Total Potential Energy")
    plt.legend()
    plt.grid(True)
    plt.show()
else:
    print("\nOptimization failed. Final energy is NaN.")
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -596,6 +596,82 @@ def argsort(self: Any, a: Tensor, axis: int = -1) -> Tensor: | |
"Backend '{}' has not implemented `argsort`.".format(self.name) | ||
) | ||
|
||
def sort(self: Any, a: Tensor, axis: int = -1) -> Tensor:
    """
    Return a sorted copy of the tensor along the given axis.

    :param a: Input tensor to sort
    :type a: Tensor
    :param axis: Axis along which to sort, defaults to -1 (the last axis)
    :type axis: int, optional
    :return: Tensor of the same shape as ``a`` with values sorted along ``axis``
    :rtype: Tensor
    """
    raise NotImplementedError(
        "Backend '{}' has not implemented `sort`.".format(self.name)
    )
|
||
def all(self: Any, a: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor:
    """
    Test whether all elements of ``a`` evaluate to True along the given axis.

    :param a: Input tensor
    :type a: Tensor
    :param axis: Axis or axes along which a logical AND reduction is performed;
        ``None`` (the default) reduces over every axis.
    :type axis: Optional[Sequence[int]], optional
    :return: A new boolean or tensor resulting from the AND reduction
    :rtype: Tensor
    """
    # Abstract stub: each concrete backend supplies its own implementation.
    message = "Backend '{}' has not implemented `all`.".format(self.name)
    raise NotImplementedError(message)
|
||
def meshgrid(self: Any, *args: Any, **kwargs: Any) -> Any:
    """
    Return coordinate matrices from coordinate vectors.

    :param args: one-dimensional coordinate vectors
    :type args: Any
    :param kwargs: keyword arguments forwarded to the backend implementation,
        typically ``indexing``, which can be ``"ij"`` (matrix indexing) or
        ``"xy"`` (Cartesian indexing):

        - ``"ij"``: matrix indexing, first dimension corresponds to rows (default)
        - ``"xy"``: Cartesian indexing, first dimension corresponds to columns

        Example::

            >>> x, y = backend.meshgrid([0, 1], [0, 2], indexing='xy')
            >>> x.shape, y.shape  # rows correspond to the y vector length
            ((2, 2), (2, 2))
            >>> x
            [[0, 1],
             [0, 1]]
            >>> y
            [[0, 0],
             [2, 2]]
    :type kwargs: Any
    :return: list of coordinate matrices
    :rtype: Any
    """
    raise NotImplementedError(
        "Backend '{}' has not implemented `meshgrid`.".format(self.name)
    )
|
||
def expand_dims(self: Any, a: Tensor, axis: int) -> Tensor:
    """
    Expand the shape of a tensor by inserting a new length-one axis at the
    ``axis`` position of the expanded shape.

    :param a: Input tensor
    :type a: Tensor
    :param axis: Position in the expanded axes where the new axis is placed
    :type axis: int
    :return: Output tensor with the number of dimensions increased by one.
    :rtype: Tensor
    """
    # Abstract stub: each concrete backend supplies its own implementation.
    message = "Backend '{}' has not implemented `expand_dims`.".format(self.name)
    raise NotImplementedError(message)
|
||
def unique_with_counts(self: Any, a: Tensor, **kws: Any) -> Tuple[Tensor, Tensor]: | ||
""" | ||
Find the unique elements and their corresponding counts of the given tensor ``a``. | ||
|
@@ -733,6 +809,21 @@ def cast(self: Any, a: Tensor, dtype: str) -> Tensor: | |
"Backend '{}' has not implemented `cast`.".format(self.name) | ||
) | ||
|
||
def convert_to_tensor(self: Any, a: Tensor, dtype: Optional[str] = None) -> Tensor:
    """
    Convert the input data to a backend-native tensor.

    :param a: input data to be converted
    :type a: Tensor
    :param dtype: target dtype; when ``None`` the backend chooses, optional
    :type dtype: Optional[str]
    :return: converted tensor
    :rtype: Tensor
    """
    # Abstract stub: each concrete backend supplies its own implementation.
    message = "Backend '{}' has not implemented `convert_to_tensor`.".format(self.name)
    raise NotImplementedError(message)
|
||
def mod(self: Any, x: Tensor, y: Tensor) -> Tensor: | ||
""" | ||
Compute y-mod of x (negative number behavior is not guaranteed to be consistent) | ||
|
@@ -1404,6 +1495,28 @@ def cond( | |
"Backend '{}' has not implemented `cond`.".format(self.name) | ||
) | ||
|
||
def where(
    self: Any,
    condition: Tensor,
    x: Optional[Tensor] = None,
    y: Optional[Tensor] = None,
) -> Tensor:
    """
    Select elements from ``x`` or ``y`` depending on ``condition``.

    :param condition: Where True, yield x, otherwise yield y.
    :type condition: Tensor (bool)
    :param x: Values from which to choose when condition is True.
    :type x: Tensor
    :param y: Values from which to choose when condition is False.
    :type y: Tensor
    :return: A tensor with elements from x where condition is True, and y otherwise.
    :rtype: Tensor
    """
    # Abstract stub: each concrete backend supplies its own implementation.
    message = "Backend '{}' has not implemented `where`.".format(self.name)
    raise NotImplementedError(message)
|
||
def switch( | ||
self: Any, index: Tensor, branches: Sequence[Callable[[], Tensor]] | ||
) -> Tensor: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
users have no idea what taks 3 is, please rephrase