Skip to content

Commit a5565c1

Browse files
author
Han Wang
committed
support virial in qe/traj
1 parent 5aade0d commit a5565c1

File tree

2 files changed

+64
-0
lines changed

2 files changed

+64
-0
lines changed

dpdata/qe/traj.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
if TYPE_CHECKING:
1212
from dpdata.utils import FileType
1313

14+
import os
15+
1416
from ..unit import (
1517
EnergyConversion,
1618
ForceConversion,
@@ -20,6 +22,7 @@
2022

2123
ry2ev = EnergyConversion("rydberg", "eV").value()
2224
kbar2evperang3 = PressureConversion("kbar", "eV/angstrom^3").value()
25+
gpa2evperbohr = PressureConversion("GPa", "eV/bohr^3").value()
2326

2427
length_convert = LengthConversion("bohr", "angstrom").value()
2528
energy_convert = EnergyConversion("hartree", "eV").value()
@@ -228,6 +231,32 @@ def to_system_data(input_name, prefix, begin=0, step=1):
228231
)
229232
except FileNotFoundError:
230233
data["cells"] = np.tile(cell, (data["coords"].shape[0], 1, 1))
234+
235+
# handle virial
236+
stress_fname = prefix + ".str"
237+
if os.path.exists(stress_fname):
238+
# 1. Read stress tensor (in GPa) for each structure
239+
stress, vsteps = load_data(stress_fname, 3, begin=begin, step=step, convert=1.0)
240+
if csteps != vsteps:
241+
csteps.append(None)
242+
vsteps.append(None)
243+
for int_id in range(len(csteps)):
244+
if csteps[int_id] != vsteps[int_id]:
245+
break
246+
step_id = begin + int_id * step
247+
raise RuntimeError(
248+
f"the step key between files are not consistent. "
249+
f"The difference locates at step: {step_id}, "
250+
f".pos is {csteps[int_id]}, .str is {vsteps[int_id]}"
251+
)
252+
# 2. Calculate volume from cell. revert unit to bohr before taking det
253+
volumes = np.linalg.det(data["cells"] / length_convert).reshape(-1)
254+
# 3. Calculate virials for each structure
255+
virials = []
256+
for i in range(stress.shape[0]):
257+
virials.append(stress[i] * gpa2evperbohr * volumes[i])
258+
data["virials"] = np.array(virials) # shape: [nframe, 3, 3]
259+
231260
return data, csteps
232261

233262

tests/test_qe_cp_traj.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,5 +68,40 @@ def test_case_null(self):
6868
self.assertAlmostEqual(cell[ii][jj], ref[ii][jj])
6969

7070

71+
class TestVirial(unittest.TestCase):
72+
def test(self):
73+
self.system = dpdata.LabeledSystem("qe.traj/si/si", fmt="qe/cp/traj")
74+
self.assertEqual(self.system["virials"].shape, (2, 3, 3))
75+
np.testing.assert_almost_equal(
76+
self.system["virials"][0],
77+
np.array(
78+
[
79+
[0.31120718, -0.03261485, -0.02537362],
80+
[-0.03261485, 0.3100397, 0.04211053],
81+
[-0.02537362, 0.04211057, 0.30571264],
82+
]
83+
),
84+
)
85+
np.testing.assert_almost_equal(
86+
self.system["virials"][1],
87+
np.array(
88+
[
89+
[0.31072979, -0.03151186, -0.02302297],
90+
[-0.03151186, 0.30951293, 0.04078447],
91+
[-0.02302297, 0.04078451, 0.30544987],
92+
]
93+
),
94+
)
95+
96+
def test_raise(self):
97+
with self.assertRaises(RuntimeError) as c:
98+
self.system = dpdata.LabeledSystem(
99+
"qe.traj/si.wrongstr/si", fmt="qe/cp/traj"
100+
)
101+
self.assertTrue(
102+
"the step key between files are not consistent." in str(c.exception)
103+
)
104+
105+
71106
if __name__ == "__main__":
72107
unittest.main()

0 commit comments

Comments
 (0)