Coverage for /builds/ericyuan00000/ase/ase/quaternions.py: 77.78%
144 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-06-18 01:20 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2025-06-18 01:20 +0000
1# fmt: off
3import numpy as np
5from ase.atoms import Atoms
class Quaternions(Atoms):
    """Atoms subclass that additionally stores per-atom quaternions and shapes.

    The extra data lives in the Atoms per-atom arrays ``'quaternions'``
    (per-atom shape ``(4,)``) and ``'shapes'`` (per-atom shape ``(3,)``).
    """

    def __init__(self, *args, **kwargs):
        """Accept all ``Atoms`` arguments plus an optional ``quaternions``.

        ``quaternions`` is taken out of ``kwargs`` before the remaining
        arguments are forwarded to ``Atoms.__init__``; passing
        ``quaternions=None`` is the same as not passing it.  Every atom is
        then given the default shape ``[3, 2, 1]``.
        """
        quaternions = kwargs.pop('quaternions', None)
        Atoms.__init__(self, *args, **kwargs)
        if quaternions is not None:
            self.set_array('quaternions', np.array(quaternions), shape=(4,))
        # set default shapes.  np.tile keeps a (0, 3) shape for zero atoms,
        # whereas np.array([[3, 2, 1]] * 0) collapses to shape (0,) and then
        # fails the per-atom shape check in set_array.
        self.set_shapes(np.tile([3, 2, 1], (len(self), 1)))

    def set_shapes(self, shapes):
        """Store per-atom shape vectors (one 3-vector per atom)."""
        self.set_array('shapes', shapes, shape=(3,))

    def set_quaternions(self, quaternions):
        """Store per-atom quaternions (one 4-vector per atom).

        The per-atom shape is passed as ``shape=(4,)``; ``set_array`` has no
        ``quaternion`` keyword, so the previous call always raised TypeError.
        """
        self.set_array('quaternions', np.asarray(quaternions), shape=(4,))

    def get_shapes(self):
        """Return the per-atom shape vectors stored under ``'shapes'``."""
        return self.get_array('shapes')

    def get_quaternions(self):
        """Return a copy of the per-atom quaternions."""
        return self.get_array('quaternions').copy()
class Quaternion:
    """Quaternion ``q = [w, x, y, z]`` (scalar part first) for 3D rotations.

    ``self.q`` is not normalized automatically.  ``rotate``,
    ``rotate_byq`` and ``rotation_matrix`` use ``ww + xx - yy - zz``-style
    terms, which describe a pure rotation only for a unit quaternion.
    """

    def __init__(self, qin=(1, 0, 0, 0)):
        """Create a quaternion from the four components ``[w, x, y, z]``.

        The default is the identity rotation.  A tuple is used as the
        default so that no mutable object is shared between calls.
        """
        assert len(qin) == 4
        self.q = np.array(qin)

    def __str__(self):
        return self.q.__str__()

    def __mul__(self, other):
        """Return the Hamilton product ``self * other``."""
        sw, sx, sy, sz = self.q
        ow, ox, oy, oz = other.q
        return Quaternion([sw * ow - sx * ox - sy * oy - sz * oz,
                           sw * ox + sx * ow + sy * oz - sz * oy,
                           sw * oy + sy * ow + sz * ox - sx * oz,
                           sw * oz + sz * ow + sx * oy - sy * ox])

    def conjugate(self):
        """Return ``[w, -x, -y, -z]``; for a unit quaternion, the inverse."""
        return Quaternion(-self.q * np.array([-1., 1., 1., 1.]))

    def rotate(self, vector):
        """Apply the rotation matrix to a vector."""
        # Same arithmetic as rotate_byq; delegate instead of duplicating it.
        return Quaternion.rotate_byq(self.q, vector)

    def rotation_matrix(self):
        """Return the 3x3 rotation matrix corresponding to this quaternion."""
        qw, qx, qy, qz = self.q[0], self.q[1], self.q[2], self.q[3]

        ww = qw * qw
        xx = qx * qx
        yy = qy * qy
        zz = qz * qz
        wx = qw * qx
        wy = qw * qy
        wz = qw * qz
        xy = qx * qy
        xz = qx * qz
        yz = qy * qz

        return np.array([[ww + xx - yy - zz, 2 * (xy - wz), 2 * (xz + wy)],
                         [2 * (xy + wz), ww - xx + yy - zz, 2 * (yz - wx)],
                         [2 * (xz - wy), 2 * (yz + wx), ww - xx - yy + zz]])

    def axis_angle(self):
        """Returns axis and angle (in radians) for the rotation described
        by this Quaternion.

        When the vector part is exactly zero the angle is 0 and the
        (arbitrary) z axis ``[0, 0, 1]`` is returned.
        """
        sinth_2 = np.linalg.norm(self.q[1:])

        if sinth_2 == 0:
            # The angle is zero
            theta = 0.0
            n = np.array([0, 0, 1])
        else:
            # arctan2 does not require self.q to be normalized.
            theta = np.arctan2(sinth_2, self.q[0]) * 2
            n = self.q[1:] / sinth_2

        return n, theta

    def euler_angles(self, mode='zyz'):
        """Return three Euler angles describing the rotation, in radians.
        Mode can be zyz or zxz. Default is zyz.

        The result ``[alpha, beta, gamma]`` is ordered like the arguments
        of ``from_euler_angles``.  Raises ValueError for any other mode.
        """
        # https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0276302
        # (a, b, c, d) are the components permuted/sign-adjusted for the
        # chosen axis sequence, as in the referenced paper.
        if mode == 'zyz':
            a, b, c, d = self.q[0], self.q[3], self.q[2], -self.q[1]
        elif mode == 'zxz':
            a, b, c, d = self.q[0], self.q[3], self.q[1], self.q[2]
        else:
            raise ValueError(f'Invalid Euler angles mode {mode}')

        beta = 2 * np.arccos(
            np.sqrt((a**2 + b**2) / (a**2 + b**2 + c**2 + d**2))
        )
        gap = np.arctan2(b, a)  # (gamma + alpha) / 2
        gam = np.arctan2(d, c)  # (gamma - alpha) / 2
        if np.isclose(beta, 0):
            # gam is meaningless here
            alpha = 0
            gamma = 2 * gap - alpha
        elif np.isclose(beta, np.pi):
            # gap is meaningless here
            alpha = 0
            gamma = 2 * gam + alpha
        else:
            alpha = gap - gam
            gamma = gap + gam

        return np.array([alpha, beta, gamma])

    def arc_distance(self, other):
        """Gives a metric of the distance between two quaternions,
        expressed as 1-|q1.q2|"""
        return 1.0 - np.abs(np.dot(self.q, other.q))

    @staticmethod
    def rotate_byq(q, vector):
        """Apply the rotation matrix of quaternion ``q = [w, x, y, z]`` to a
        vector, without building the matrix explicitly."""
        qw, qx, qy, qz = q[0], q[1], q[2], q[3]
        x, y, z = vector[0], vector[1], vector[2]

        ww = qw * qw
        xx = qx * qx
        yy = qy * qy
        zz = qz * qz
        wx = qw * qx
        wy = qw * qy
        wz = qw * qz
        xy = qx * qy
        xz = qx * qz
        yz = qy * qz

        return np.array(
            [(ww + xx - yy - zz) * x + 2 * ((xy - wz) * y + (xz + wy) * z),
             (ww - xx + yy - zz) * y + 2 * ((xy + wz) * x + (yz - wx) * z),
             (ww - xx - yy + zz) * z + 2 * ((xz - wy) * x + (yz + wx) * y)])

    @staticmethod
    def from_matrix(matrix):
        """Build quaternion from a 3x3 rotation matrix."""
        m = np.array(matrix)
        assert m.shape == (3, 3)

        # Now we need to find out the whole quaternion
        # This method takes into account the possibility of qw being nearly
        # zero, so it picks the stablest solution: the branch whose square
        # root argument (4 * component**2) is largest, avoiding division by
        # a small number.

        if m[2, 2] < 0:
            if (m[0, 0] > m[1, 1]):
                # Use x-form
                qx = np.sqrt(1 + m[0, 0] - m[1, 1] - m[2, 2]) / 2.0
                fac = 1.0 / (4 * qx)
                qw = (m[2, 1] - m[1, 2]) * fac
                qy = (m[0, 1] + m[1, 0]) * fac
                qz = (m[0, 2] + m[2, 0]) * fac
            else:
                # Use y-form
                qy = np.sqrt(1 - m[0, 0] + m[1, 1] - m[2, 2]) / 2.0
                fac = 1.0 / (4 * qy)
                qw = (m[0, 2] - m[2, 0]) * fac
                qx = (m[0, 1] + m[1, 0]) * fac
                qz = (m[1, 2] + m[2, 1]) * fac
        else:
            if (m[0, 0] < -m[1, 1]):
                # Use z-form
                qz = np.sqrt(1 - m[0, 0] - m[1, 1] + m[2, 2]) / 2.0
                fac = 1.0 / (4 * qz)
                qw = (m[1, 0] - m[0, 1]) * fac
                qx = (m[2, 0] + m[0, 2]) * fac
                qy = (m[1, 2] + m[2, 1]) * fac
            else:
                # Use w-form
                qw = np.sqrt(1 + m[0, 0] + m[1, 1] + m[2, 2]) / 2.0
                fac = 1.0 / (4 * qw)
                qx = (m[2, 1] - m[1, 2]) * fac
                qy = (m[0, 2] - m[2, 0]) * fac
                qz = (m[1, 0] - m[0, 1]) * fac

        return Quaternion(np.array([qw, qx, qy, qz]))

    @staticmethod
    def from_axis_angle(n, theta):
        """Build quaternion from axis (n, vector of 3 components) and angle
        (theta, in radians).

        The axis is normalized here, so it need not be a unit vector.
        Raises ValueError if the axis has zero length (previously this
        silently produced an all-NaN quaternion).
        """
        norm = np.linalg.norm(n)
        if norm == 0:
            raise ValueError('Rotation axis must be a non-zero vector')
        n = np.array(n, float) / norm
        return Quaternion(np.concatenate([[np.cos(theta / 2.0)],
                                          np.sin(theta / 2.0) * n]))

    @staticmethod
    def from_euler_angles(a, b, c, mode='zyz'):
        """Build quaternion from Euler angles, given in radians. Default
        mode is ZYZ, but it can be set to ZXZ as well.

        The rotation by ``a`` about z is applied first, then ``b`` about y
        (or x), then ``c`` about z.  Raises ValueError for any other mode.
        """
        q_a = Quaternion.from_axis_angle([0, 0, 1], a)
        q_c = Quaternion.from_axis_angle([0, 0, 1], c)

        if mode == 'zyz':
            q_b = Quaternion.from_axis_angle([0, 1, 0], b)
        elif mode == 'zxz':
            q_b = Quaternion.from_axis_angle([1, 0, 0], b)
        else:
            raise ValueError(f'Invalid Euler angles mode {mode}')

        return q_c * q_b * q_a