Skip to content

Commit 3101a2b

Browse files
sunqmQiming Sun
andauthored
Analytical Gradients for DFT+U (#459)
* DFT+U gradients for molecular DFT * Refactor kpoint DFT+U * Analytical gradients for KDFT+U * Fix KUKS+U linear response U method * Fix bug introduced in the gradients base method * miss import * DFT+U gradients fixes * Update KRKSpU and KUKSpU tests * Update pcm tests --------- Co-authored-by: Qiming Sun <qiming.sun@bytedance.com>
1 parent 5965dc4 commit 3101a2b

23 files changed

+1922
-199
lines changed

gpu4pyscf/dft/rkspu.py

Lines changed: 350 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,350 @@
1+
#!/usr/bin/env python
2+
# Copyright 2025 The PySCF Developers. All Rights Reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
'''
17+
DFT+U for molecules
18+
19+
See also the pbc.dft.krkspu and pbc.dft.kukspu module
20+
21+
Refs:
22+
Heather J. Kulik, J. Chem. Phys. 142, 240901 (2015)
23+
'''
24+
25+
import itertools
26+
import numpy as np
27+
import cupy as cp
28+
from pyscf import gto
29+
from pyscf.data.nist import HARTREE2EV
30+
from pyscf.lo.iao import reference_mol
31+
from gpu4pyscf.dft import rks
32+
from gpu4pyscf.lib import logger
33+
from gpu4pyscf.lib.cupy_helper import asarray
34+
35+
def get_veff(ks, mol=None, dm=None, dm_last=0, vhf_last=0, hermi=1):
36+
"""
37+
Coulomb + XC functional + Hubbard U terms for RKS+U.
38+
39+
.. note::
40+
This function will change the ks object.
41+
42+
Args:
43+
ks : an instance of :class:`RKS`
44+
XC functional are controlled by ks.xc attribute. Attribute
45+
ks.grids might be initialized.
46+
dm : ndarray or list of ndarrays
47+
A density matrix or a list of density matrices
48+
49+
Returns:
50+
Veff : ``(nao, nao)`` or ``(*, nao, nao)`` ndarray
51+
Veff = J + Vxc + V_U.
52+
"""
53+
if mol is None: mol = ks.mol
54+
if dm is None: dm = ks.make_rdm1()
55+
56+
# J + V_xc
57+
vxc = rks.get_veff(ks, mol, dm, dm_last=dm_last, vhf_last=vhf_last,
58+
hermi=hermi)
59+
60+
# V_U
61+
ovlp = asarray(mol.intor('int1e_ovlp', hermi=1))
62+
pmol = reference_mol(mol, ks.minao_ref)
63+
U_idx, U_val, U_lab = _set_U(mol, pmol, ks.U_idx, ks.U_val)
64+
# Construct orthogonal minao local orbitals.
65+
assert ks.C_ao_lo is None
66+
C_ao_lo = _make_minao_lo(mol, pmol)
67+
68+
alphas = ks.alpha
69+
if not hasattr(alphas, '__len__'): # not a list or tuple
70+
alphas = [alphas] * len(U_idx)
71+
72+
E_U = 0.0
73+
logger.info(ks, "-" * 79)
74+
lab_string = " "
75+
with np.printoptions(precision=5, suppress=True, linewidth=1000):
76+
for idx, val, lab, alpha in zip(U_idx, U_val, U_lab, alphas):
77+
if ks.verbose >= logger.INFO:
78+
lab_string = " "
79+
for l in lab:
80+
lab_string += "%9s" %(l.split()[-1])
81+
lab_sp = lab[0].split()
82+
logger.info(ks, "local rdm1 of atom %s: ",
83+
" ".join(lab_sp[:2]) + " " + lab_sp[2][:2])
84+
C_loc = C_ao_lo[:,idx]
85+
SC = ovlp.dot(C_loc) # ~ C^{-1}
86+
P = SC.conj().T.dot(dm).dot(SC)
87+
loc_sites = P.shape[-1]
88+
vhub_loc = (cp.eye(loc_sites) - P) * (val * 0.5)
89+
if alpha is not None:
90+
# The alpha perturbation is only applied to the linear term of
91+
# the local density.
92+
E_U += alpha * P.trace()
93+
vhub_loc += cp.eye(loc_sites) * alpha
94+
# vxc is a tagged array. The inplace updating avoids loosing the
95+
# tagged attributes.
96+
vxc[:] += SC.dot(vhub_loc).dot(SC.conj().T)
97+
E_U += (val * 0.5) * (P.trace() - P.dot(P).trace() * 0.5)
98+
logger.info(ks, "%s\n%s", lab_string, P)
99+
logger.info(ks, "-" * 79)
100+
101+
E_U = E_U.real.get()[()]
102+
if E_U < 0.0 and all(np.asarray(U_val) > 0):
103+
logger.warn(ks, "E_U (%s) is negative...", E_U)
104+
vxc.E_U = E_U
105+
return vxc
106+
107+
def energy_elec(mf, dm=None, h1e=None, vhf=None):
108+
"""
109+
Electronic energy for RKSpU.
110+
"""
111+
if dm is None: dm = mf.make_rdm1()
112+
if h1e is None: h1e = mf.get_hcore()
113+
if vhf is None: vhf = mf.get_veff(mf.mol, dm)
114+
e1 = cp.einsum('ij,ji->', h1e, dm).get()[()].real
115+
ecoul = vhf.ecoul.real
116+
exc = vhf.exc.real
117+
E_U = vhf.E_U
118+
if isinstance(ecoul, cp.ndarray):
119+
ecoul = ecoul.get()[()]
120+
if isinstance(exc, cp.ndarray):
121+
exc = exc.get()[()]
122+
e2 = ecoul + exc + E_U
123+
mf.scf_summary['e1'] = e1
124+
mf.scf_summary['coul'] = ecoul
125+
mf.scf_summary['exc'] = exc
126+
mf.scf_summary['E_U'] = E_U
127+
logger.debug(mf, 'E1 = %s Ecoul = %s Exc = %s EU = %s', e1, ecoul, exc, E_U)
128+
return e1+e2, e2
129+
130+
def _groupby(inp, labels):
131+
_, where, counts = np.unique(labels, return_index=True, return_counts=True)
132+
return [inp[start:start+count] for start, count in zip(where, counts)]
133+
134+
def _set_U(mol, minao_mol, U_idx, U_val):
135+
"""
136+
Regularize the U_idx and U_val to each atom,
137+
"""
138+
assert len(U_idx) == len(U_val)
139+
140+
ao_loc = minao_mol.ao_loc_nr()
141+
dims = ao_loc[1:] - ao_loc[:-1]
142+
# atm_ids labels the atom Id for each function
143+
atm_ids = np.repeat(minao_mol._bas[:,gto.ATOM_OF], dims)
144+
145+
ao_labels = mol.ao_labels()
146+
minao_labels = minao_mol.ao_labels()
147+
148+
U_indices = []
149+
U_values = []
150+
for i, idx in enumerate(U_idx):
151+
if isinstance(idx, str):
152+
lab_idx = minao_mol.search_ao_label(idx)
153+
# Group basis functions centered on the same atom
154+
for idxj in _groupby(lab_idx, atm_ids[lab_idx]):
155+
U_indices.append(idxj)
156+
U_values.append(U_val[i])
157+
else:
158+
# Map to MINAO indices
159+
idx_minao = [minao_labels.index(ao_labels[i]) for i in idx]
160+
U_indices.append(idx_minao)
161+
U_values.append(U_val[i])
162+
163+
if len(U_indices) == 0:
164+
logger.warn(mol, "No sites specified for Hubbard U. "
165+
"Please check if 'U_idx' is correctly specified")
166+
167+
U_values = np.asarray(U_values) / HARTREE2EV
168+
169+
U_labels = [[minao_labels[i] for i in idx] for idx in U_indices]
170+
return U_indices, U_values, U_labels
171+
172+
def _make_minao_lo(mol, minao_ref='minao'):
173+
'''
174+
Construct orthogonal minao local orbitals.
175+
'''
176+
if isinstance(minao_ref, str):
177+
minao_mol = reference_mol(mol, minao_ref)
178+
else:
179+
minao_mol = minao_ref
180+
ovlp = asarray(mol.intor('int1e_ovlp', hermi=1))
181+
s12 = asarray(gto.intor_cross('int1e_ovlp', mol, minao_mol))
182+
C_minao = cp.linalg.solve(ovlp, s12)
183+
S0 = C_minao.T.dot(ovlp).dot(C_minao)
184+
w2, v = cp.linalg.eigh(S0)
185+
C_minao = C_minao.dot((v*cp.sqrt(1./w2)).dot(v.T))
186+
return C_minao
187+
188+
def _format_idx(idx_list):
189+
string = ''
190+
for k, g in itertools.groupby(enumerate(idx_list), lambda ix: ix[0] - ix[1]):
191+
g = list(g)
192+
if len(g) > 1:
193+
string += '%d-%d, '%(g[0][1], g[-1][1])
194+
else:
195+
string += '%d, '%(g[0][1])
196+
return string[:-2]
197+
198+
def _print_U_info(mf, log):
199+
mol = mf.mol
200+
pmol = reference_mol(mol, mf.minao_ref)
201+
U_idx, U_val, U_lab = _set_U(mol, pmol, mf.U_idx, mf.U_val)
202+
alphas = mf.alpha
203+
if not hasattr(alphas, '__len__'): # not a list or tuple
204+
alphas = [alphas] * len(U_idx)
205+
log.info("-" * 79)
206+
log.info('U indices and values: ')
207+
for idx, val, lab, alpha in zip(U_idx, U_val, U_lab, alphas):
208+
log.info('%6s [%.6g eV] ==> %-100s', _format_idx(idx),
209+
val * HARTREE2EV, "".join(lab))
210+
if alpha is not None:
211+
log.info(' alpha for LR-cDFT %s (eV)',
212+
alpha * HARTREE2EV)
213+
log.info("-" * 79)
214+
215+
class RKSpU(rks.RKS):
216+
"""
217+
DFT+U for RKS
218+
"""
219+
220+
_keys = {"U_idx", "U_val", "C_ao_lo", "U_lab", 'minao_ref', 'alpha'}
221+
222+
get_veff = get_veff
223+
energy_elec = energy_elec
224+
to_hf = NotImplemented
225+
226+
def __init__(self, mol, xc='LDA,VWN',
227+
U_idx=[], U_val=[], C_ao_lo=None, minao_ref='MINAO'):
228+
"""
229+
Args:
230+
U_idx: can be
231+
list of list: each sublist is a set indices for AO orbitals
232+
(indcies corresponding to the large-basis-set mol).
233+
list of string: each string is one kind of LO orbitals,
234+
e.g. ['Ni 3d', '1 O 2pz'].
235+
or a combination of these two.
236+
U_val: a list of effective U [in eV], i.e. U-J in Dudarev's DFT+U.
237+
each U corresponds to one kind of LO orbitals, should have
238+
the same length as U_idx.
239+
C_ao_lo: Customized LO coefficients, can be
240+
np.array, shape ((spin,), nao, nlo),
241+
minao_ref: reference for minao orbitals, default is 'MINAO'.
242+
243+
Attributes:
244+
U_idx: same as the input.
245+
U_val: effectiv U-J [in AU]
246+
C_ao_loc: np.array
247+
alpha: the perturbation [in AU] used to compute U in LR-cDFT.
248+
Refs: Cococcioni and de Gironcoli, PRB 71, 035105 (2005)
249+
"""
250+
super().__init__(mol, xc=xc)
251+
252+
self.U_idx = U_idx
253+
self.U_val = U_val
254+
if isinstance(C_ao_lo, str):
255+
assert C_ao_lo.upper() == 'MINAO'
256+
C_ao_lo = None # API backward compatibility
257+
self.C_ao_lo = C_ao_lo
258+
self.minao_ref = minao_ref
259+
# The perturbation (eV) used to compute U in LR-cDFT.
260+
self.alpha = None
261+
262+
def dump_flags(self, verbose=None):
263+
log = logger.new_logger(self, verbose)
264+
super().dump_flags(log)
265+
if log.verbose >= logger.INFO:
266+
_print_U_info(self, log)
267+
return self
268+
269+
def Gradients(self):
270+
from gpu4pyscf.grad.rkspu import Gradients
271+
return Gradients(self)
272+
273+
def nuc_grad_method(self):
274+
return self.Gradients()
275+
276+
def linear_response_u(mf_plus_u, alphalist=(0.02, 0.05, 0.08)):
277+
'''
278+
Refs:
279+
[1] M. Cococcioni and S. de Gironcoli, Phys. Rev. B 71, 035105 (2005)
280+
[2] H. J. Kulik, M. Cococcioni, D. A. Scherlis, and N. Marzari, Phys. Rev. Lett. 97, 103001 (2006)
281+
[3] Heather J. Kulik, J. Chem. Phys. 142, 240901 (2015)
282+
[4] https://hjkgrp.mit.edu/tutorials/2011-05-31-calculating-hubbard-u/
283+
[5] https://hjkgrp.mit.edu/tutorials/2011-06-28-hubbard-u-multiple-sites/
284+
285+
Args:
286+
alphalist :
287+
alpha parameters (in eV) are the displacements for the linear
288+
response calculations. For each alpha in this list, the DFT+U with
289+
U=u0+alpha, U=u0-alpha are evaluated. u0 is the U value from the
290+
reference mf_plus_u object, which will be treated as a standard DFT
291+
functional.
292+
'''
293+
assert isinstance(mf_plus_u, RKSpU)
294+
assert len(mf_plus_u.U_idx) > 0
295+
if not mf_plus_u.converged:
296+
mf_plus_u.run()
297+
assert mf_plus_u.converged
298+
# The bare density matrix without adding U
299+
bare_dm = mf_plus_u.make_rdm1()
300+
301+
mf = mf_plus_u.copy()
302+
log = logger.new_logger(mf)
303+
304+
alphalist = np.asarray(alphalist)
305+
alphalist = np.append(-alphalist[::-1], alphalist)
306+
307+
mol = mf.mol
308+
pmol = reference_mol(mol, mf.minao_ref)
309+
U_idx, U_val, U_lab = _set_U(mol, pmol, mf.U_idx, mf.U_val)
310+
# Construct orthogonal minao local orbitals.
311+
assert mf.C_ao_lo is None
312+
C_ao_lo = _make_minao_lo(mol, pmol)
313+
ovlp = asarray(mol.intor('int1e_ovlp', hermi=1))
314+
C_inv = []
315+
for idx in U_idx:
316+
c = C_ao_lo[:,idx]
317+
C_inv.append(c.conj().T.dot(ovlp))
318+
319+
bare_occupancies = []
320+
final_occupancies = []
321+
for alpha in alphalist:
322+
# All in atomic unit
323+
mf.alpha = alpha / HARTREE2EV
324+
mf.kernel(dm0=bare_dm)
325+
local_occ = 0
326+
for c in C_inv:
327+
C_on_site = c.dot(mf.mo_coeff)
328+
rdm1_lo = mf.make_rdm1(C_on_site, mf.mo_occ)
329+
local_occ += rdm1_lo.trace()
330+
final_occupancies.append(local_occ.get())
331+
332+
# The first iteration of SCF
333+
fock = mf.get_fock(dm=bare_dm)
334+
e, mo = mf.eig(fock, ovlp)
335+
local_occ = 0
336+
for c in C_inv:
337+
C_on_site = c.dot(mo)
338+
rdm1_lo = mf.make_rdm1(C_on_site, mf.mo_occ)
339+
local_occ += rdm1_lo.trace()
340+
bare_occupancies.append(local_occ.get())
341+
log.info('alpha=%f bare_occ=%g final_occ=%g',
342+
alpha, bare_occupancies[-1], final_occupancies[-1])
343+
344+
chi0, occ0 = np.polyfit(alphalist, bare_occupancies, deg=1)
345+
chif, occf = np.polyfit(alphalist, final_occupancies, deg=1)
346+
log.info('Line fitting chi0 = %f x + %f', chi0, occ0)
347+
log.info('Line fitting chif = %f x + %f', chif, occf)
348+
Uresp = 1./chi0 - 1./chif
349+
log.note('Uresp = %f, chi0 = %f, chif = %f', Uresp, chi0, chif)
350+
return Uresp

gpu4pyscf/dft/tests/test_dftu.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#!/usr/bin/env python
2+
# Copyright 2025 The PySCF Developers. All Rights Reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
17+
import unittest
18+
import numpy as np
19+
from pyscf import gto
20+
from gpu4pyscf.dft import rkspu, ukspu
21+
22+
class KnownValues(unittest.TestCase):
23+
def test_RKSpU_linear_response(self):
24+
mol = gto.M(atom='''
25+
O 0. 0. 0.
26+
H 0. -0.757 0.587
27+
H 0. 0.757 0.587''', basis='6-31g')
28+
mf = rkspu.RKSpU(mol, xc='pbe', U_idx=['O 2p'], U_val=[3.5])
29+
mf.run()
30+
u0 = rkspu.linear_response_u(mf)
31+
assert abs(u0 - 5.8926) < 1e-2
32+
33+
def test_UKSpU_linear_response(self):
34+
mol = gto.M(atom='''
35+
O 0. 0. 0.
36+
H 0. -0.757 0.587
37+
H 0. 0.757 0.587''', basis='6-31g')
38+
mf = ukspu.UKSpU(mol, xc='pbe', U_idx=['O 2p'], U_val=[3.5])
39+
mf.run()
40+
u0 = ukspu.linear_response_u(mf)
41+
assert abs(u0 - 5.8926) < 1e-2
42+
43+
if __name__ == '__main__':
44+
print("Full Tests for dft+U")
45+
unittest.main()

0 commit comments

Comments
 (0)