Merged
Changes from 10 commits
138 changes: 138 additions & 0 deletions examples/lennard_jones_optimization.py
@@ -0,0 +1,138 @@
"""
Lennard-Jones Potential Optimization Example

This script demonstrates how to use TensorCircuit's differentiable lattice geometries
to optimize crystal structure. It finds the equilibrium lattice constant that minimizes
the total Lennard-Jones potential energy of a 2D square lattice.

The optimization showcases the key Task 3 capability: making lattice parameters
Review comment (Member): users have no idea what task 3 is, please rephrase
differentiable for variational material design.
"""

import optax
import numpy as np
import matplotlib.pyplot as plt

# Try to enable JAX 64-bit precision if available (safe fallback)
try:  # pragma: no cover - optional optimization
    from jax import config as jax_config  # type: ignore

    jax_config.update("jax_enable_x64", True)
except Exception:  # broad: environment may not have config attribute
    pass
import tensorcircuit as tc # noqa: E402


tc.set_dtype("float64") # Use tc for universal control
K = tc.set_backend("jax")


def calculate_potential(log_a, epsilon=0.5, sigma=1.0):
"""
Calculate the total Lennard-Jones potential energy for a given logarithm of the lattice constant (log_a).
This version creates the lattice inside the function to demonstrate truly differentiable geometry.
"""
lattice_constant = K.exp(log_a)

# Create lattice with the differentiable parameter
size = (4, 4) # Smaller size for demonstration
lattice = tc.templates.lattice.SquareLattice(
size, lattice_constant=lattice_constant, pbc=True
)
d = lattice.distance_matrix

d_safe = K.where(d > 1e-9, d, K.convert_to_tensor(1e-9))
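    # Pairwise Lennard-Jones potential: V(r) = 4*epsilon*((sigma/r)**12 - (sigma/r)**6),
    # evaluated for every pair of sites via the lattice distance matrix.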

    term12 = K.power(sigma / d_safe, 12)
    term6 = K.power(sigma / d_safe, 6)
    potential_matrix = 4 * epsilon * (term12 - term6)

    num_sites = lattice.num_sites
    # Zero out self-interactions (diagonal elements)
    eye_mask = K.eye(num_sites, dtype=potential_matrix.dtype)
    potential_matrix = potential_matrix * (1 - eye_mask)

    potential_energy = K.sum(potential_matrix) / 2.0

    return potential_energy


# Create a lambda function for optimization
potential_fun_for_grad = lambda log_a: calculate_potential(log_a)
value_and_grad_fun = K.jit(K.value_and_grad(potential_fun_for_grad))

optimizer = optax.adam(learning_rate=0.01)

log_a = K.convert_to_tensor(K.log(K.convert_to_tensor(1.1)))
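# Working in log(a) keeps the lattice constant strictly positive throughout
# the gradient-descent updates below.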
Review comment (Member): starting from 2.0

opt_state = optimizer.init(log_a)

history = {"a": [], "energy": []}

print("Starting optimization of lattice constant...")
for i in range(200):
    energy, grad = value_and_grad_fun(log_a)

    history["a"].append(K.exp(log_a))
    history["energy"].append(energy)

    # Check for NaN gradients using TensorCircuit's backend-agnostic approach
    if K.sum(tc.num_to_tensor(np.isnan(K.numpy(grad)))) > 0:
        print(f"Gradient became NaN at iteration {i+1}. Stopping optimization.")
        print(f"Current energy: {energy}, Current log_a: {log_a}")
        break

    updates, opt_state = optimizer.update(grad, opt_state)
    log_a = optax.apply_updates(log_a, updates)

    if (i + 1) % 20 == 0:
        current_a = K.exp(log_a)
        print(
            f"Iteration {i+1}/200: Total Energy = {energy:.4f}, Lattice Constant = {current_a:.4f}"
        )

final_a = K.exp(log_a)
final_energy = calculate_potential(log_a)

if not np.isnan(K.numpy(final_energy)):
print("\nOptimization finished!")
print(f"Final optimized lattice constant: {final_a:.6f}")
print(f"Corresponding minimum total energy: {final_energy:.6f}")

# Vectorized calculation for the potential curve
a_vals = np.linspace(0.8, 1.5, 200)
log_a_vals = K.log(K.convert_to_tensor(a_vals))

# Use vmap to create a vectorized version of the potential function
vmap_potential = K.vmap(lambda la: calculate_potential(la))
potential_curve = vmap_potential(log_a_vals)

plt.figure(figsize=(10, 6))
plt.plot(a_vals, potential_curve, label="Lennard-Jones Potential", color="blue")
plt.scatter(
history["a"],
history["energy"],
color="red",
s=20,
zorder=5,
label="Optimization Steps",
)
plt.scatter(
final_a,
final_energy,
color="green",
s=100,
zorder=6,
marker="*",
label="Final Optimized Point",
)

plt.title("Lennard-Jones Potential Optimization")
plt.xlabel("Lattice Constant (a)")
plt.ylabel("Total Potential Energy")
plt.legend()
plt.grid(True)
plt.show()
else:
print("\nOptimization failed. Final energy is NaN.")
86 changes: 86 additions & 0 deletions tensorcircuit/backends/abstract_backend.py
@@ -596,6 +596,70 @@ def argsort(self: Any, a: Tensor, axis: int = -1) -> Tensor:
"Backend '{}' has not implemented `argsort`.".format(self.name)
)

    def sort(self: Any, a: Tensor, axis: int = -1) -> Tensor:
        """
        Sort a tensor along the given axis.

        :param a: Input tensor to be sorted
        :type a: Tensor
        :param axis: the axis along which to sort, defaults to -1 (the last axis)
        :type axis: int, optional
        :return: Sorted tensor
        :rtype: Tensor
        """
        raise NotImplementedError(
            "Backend '{}' has not implemented `sort`.".format(self.name)
        )

    def all(self: Any, a: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor:
        """
        Test whether all array elements along a given axis evaluate to True.

        :param a: Input tensor
        :type a: Tensor
        :param axis: Axis or axes along which a logical AND reduction is performed,
            defaults to None
        :type axis: Optional[Sequence[int]], optional
        :return: A new boolean or tensor resulting from the AND reduction
        :rtype: Tensor
        """
        raise NotImplementedError(
            "Backend '{}' has not implemented `all`.".format(self.name)
        )

    def meshgrid(self: Any, *args: Any, **kwargs: Any) -> Any:
        """
        Return coordinate matrices from coordinate vectors.

        :param args: coordinate vectors
        :type args: Any
        :param kwargs: keyword arguments for meshgrid, typically includes 'indexing'
            which can be 'ij' (matrix indexing) or 'xy' (Cartesian indexing)
        :type kwargs: Any
        :return: list of coordinate matrices
        :rtype: Any
        """
        raise NotImplementedError(
            "Backend '{}' has not implemented `meshgrid`.".format(self.name)
        )

    def expand_dims(self: Any, a: Tensor, axis: int) -> Tensor:
        """
        Expand the shape of a tensor.
        Insert a new axis that will appear at the `axis` position in the expanded
        tensor shape.

        :param a: Input tensor
        :type a: Tensor
        :param axis: Position in the expanded axes where the new axis is placed
        :type axis: int
        :return: Output tensor with the number of dimensions increased by one.
        :rtype: Tensor
        """
        raise NotImplementedError(
            "Backend '{}' has not implemented `expand_dims`.".format(self.name)
        )

    def unique_with_counts(self: Any, a: Tensor, **kws: Any) -> Tuple[Tensor, Tensor]:
        """
        Find the unique elements and their corresponding counts of the given tensor ``a``.
@@ -1404,6 +1468,28 @@ def cond(
            "Backend '{}' has not implemented `cond`.".format(self.name)
        )

    def where(
        self: Any,
        condition: Tensor,
        x: Optional[Tensor] = None,
        y: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Return a tensor of elements selected from either x or y, depending on condition.

        :param condition: Where True, yield x, otherwise yield y.
        :type condition: Tensor (bool)
        :param x: Values from which to choose when condition is True.
        :type x: Tensor
        :param y: Values from which to choose when condition is False.
        :type y: Tensor
        :return: A tensor with elements from x where condition is True, and y otherwise.
        :rtype: Tensor
        """
        raise NotImplementedError(
            "Backend '{}' has not implemented `where`.".format(self.name)
        )

    def switch(
        self: Any, index: Tensor, branches: Sequence[Callable[[], Tensor]]
    ) -> Tensor:
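The abstract signatures above only fix the interface; every concrete backend supplies its own implementation (the JAX versions appear later in this diff). As a rough sketch of what satisfying the contract looks like, a NumPy-flavoured backend could be written as follows; the class name NumpyLikeBackend is purely illustrative and not part of the PR.

import numpy as np

class NumpyLikeBackend:
    name = "numpy-like"

    def sort(self, a, axis=-1):
        return np.sort(a, axis=axis)

    def all(self, a, axis=None):
        return np.all(a, axis=axis)

    def meshgrid(self, *args, **kwargs):
        return np.meshgrid(*args, **kwargs)

    def expand_dims(self, a, axis):
        return np.expand_dims(a, axis)

    def where(self, condition, x=None, y=None):
        # Single-argument form returns the indices of True entries,
        # mirroring numpy/jax semantics.
        if x is None and y is None:
            return np.where(condition)
        return np.where(condition, x, y)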
4 changes: 3 additions & 1 deletion tensorcircuit/backends/cupy_backend.py
@@ -56,10 +56,12 @@ def __init__(self) -> None:
            cpx = cupyx
        self.name = "cupy"

    def convert_to_tensor(self, a: Tensor) -> Tensor:
    def convert_to_tensor(self, a: Tensor, dtype: Optional[str] = None) -> Tensor:
        if not isinstance(a, cp.ndarray) and not cp.isscalar(a):
            a = cp.array(a)
            a = cp.asarray(a)
        if dtype is not None:
            a = self.cast(a, dtype)
        return a

    def sum(
36 changes: 34 additions & 2 deletions tensorcircuit/backends/jax_backend.py
@@ -50,12 +50,17 @@ def update(self, grads: pytree, params: pytree) -> pytree:
        return params


def _convert_to_tensor_jax(self: Any, tensor: Tensor) -> Tensor:
def _convert_to_tensor_jax(
    self: Any, tensor: Tensor, dtype: Optional[str] = None
) -> Tensor:
    if not isinstance(tensor, (np.ndarray, jnp.ndarray)) and not jnp.isscalar(tensor):
        raise TypeError(
            ("Expected a `jnp.array`, `np.array` or scalar. " f"Got {type(tensor)}")
        )
    result = jnp.asarray(tensor)
    if dtype is not None:
        # Use the backend's cast method to handle dtype conversion
        result = self.cast(result, dtype)
    return result


@@ -243,8 +248,10 @@ def zeros(self, shape: Tuple[int, ...], dtype: Optional[str] = None) -> Tensor:
    def copy(self, tensor: Tensor) -> Tensor:
        return jnp.array(tensor, copy=True)

    def convert_to_tensor(self, tensor: Tensor) -> Tensor:
    def convert_to_tensor(self, tensor: Tensor, dtype: Optional[str] = None) -> Tensor:
        result = jnp.asarray(tensor)
        if dtype is not None:
            result = self.cast(result, dtype)
        return result

    def abs(self, a: Tensor) -> Tensor:
@@ -390,6 +397,9 @@ def argmin(self, a: Tensor, axis: int = 0) -> Tensor:
    def argsort(self, a: Tensor, axis: int = -1) -> Tensor:
        return jnp.argsort(a, axis=axis)

    def sort(self, a: Tensor, axis: int = -1) -> Tensor:
        return jnp.sort(a, axis=axis)

    def unique_with_counts(  # type: ignore
        self, a: Tensor, *, size: Optional[int] = None, fill_value: Optional[int] = None
    ) -> Tuple[Tensor, Tensor]:
@@ -410,6 +420,9 @@ def onehot(self, a: Tensor, num: int) -> Tensor:
    def cumsum(self, a: Tensor, axis: Optional[int] = None) -> Tensor:
        return jnp.cumsum(a, axis)

    def all(self, a: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor:
        return jnp.all(a, axis=axis)

    def is_tensor(self, a: Any) -> bool:
        if not isinstance(a, jnp.ndarray):
            return False
@@ -812,4 +825,23 @@ def wrapper(

    vvag = vectorized_value_and_grad

    def meshgrid(self, *args: Any, **kwargs: Any) -> Any:
        """
        Backend-agnostic meshgrid function.
        """
        return jnp.meshgrid(*args, **kwargs)

    optimizer = optax_optimizer

    def expand_dims(self, a: Tensor, axis: int) -> Tensor:
        return jnp.expand_dims(a, axis)

    def where(
        self,
        condition: Tensor,
        x: Optional[Tensor] = None,
        y: Optional[Tensor] = None,
    ) -> Tensor:
        if x is None and y is None:
            return jnp.where(condition)
        return jnp.where(condition, x, y)
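A short usage sketch of the newly exposed primitives through the backend object; this is illustrative only (it assumes the JAX backend is installed) and not part of the PR.

import numpy as np
import tensorcircuit as tc

K = tc.set_backend("jax")

# convert_to_tensor now accepts an optional dtype argument
d = K.convert_to_tensor(np.array([[0.0, 1.0], [1.0, 0.0]]), dtype="float32")
d_safe = K.where(d > 1e-9, d, K.convert_to_tensor(1e-9))  # three-argument form
print(K.sort(d, axis=-1))          # row-wise sort
print(K.all(d >= 0))               # logical AND reduction over all elements
print(K.expand_dims(d, 0).shape)   # (1, 2, 2)
xs, ys = K.meshgrid(np.arange(2), np.arange(3), indexing="ij")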