Coverage for /builds/ericyuan00000/ase/ase/quaternions.py: 77.78%

144 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-06-18 01:20 +0000

1# fmt: off 

2 

3import numpy as np 

4 

5from ase.atoms import Atoms 

6 

7 

class Quaternions(Atoms):
    """:class:`~ase.atoms.Atoms` subclass that also stores per-atom quaternions.

    Besides the usual ``Atoms`` arguments, the constructor accepts an
    optional ``quaternions`` keyword holding four components per atom.
    These are stored in the ``'quaternions'`` per-atom array.  When
    quaternions are given, a ``'shapes'`` per-atom array (three components
    per atom) is also created.  It is filled with the default ``[3, 2, 1]``
    for every atom.
    """

    def __init__(self, *args, **kwargs):
        # Strip our own keyword before delegating to Atoms.__init__.  Using
        # pop() also means an explicit ``quaternions=None`` is treated the
        # same as omitting the keyword (instead of being turned into a 0-d
        # object array that set_array would reject).
        quaternions = kwargs.pop('quaternions', None)
        Atoms.__init__(self, *args, **kwargs)
        if quaternions is not None:
            self.set_array('quaternions', np.array(quaternions), shape=(4,))
            # set default shapes
            self.set_shapes(np.array([[3, 2, 1]] * len(self)))

    def set_shapes(self, shapes):
        """Store the per-atom ``'shapes'`` array (3 components per atom)."""
        self.set_array('shapes', shapes, shape=(3,))

    def set_quaternions(self, quaternions):
        """Store the per-atom ``'quaternions'`` array (4 components per atom)."""
        # Atoms.set_array takes ``shape=``; the previous ``quaternion=``
        # keyword made every call fail with a TypeError.
        self.set_array('quaternions', quaternions, shape=(4,))

    def get_shapes(self):
        """Return the per-atom ``'shapes'`` array."""
        return self.get_array('shapes')

    def get_quaternions(self):
        """Return a copy of the per-atom ``'quaternions'`` array."""
        return self.get_array('quaternions').copy()

32 

33 

class Quaternion:
    """A quaternion ``q = [w, x, y, z]`` used to represent 3D rotations.

    The scalar part is stored first (``self.q[0]``), followed by the vector
    part (``self.q[1:]``).  ``rotate`` and ``rotation_matrix`` assume a unit
    quaternion; for a non-unit quaternion their result is scaled by
    ``|q|**2``.
    """

    def __init__(self, qin=(1, 0, 0, 0)):
        """Create a quaternion from four components ``[w, x, y, z]``.

        The default is the identity rotation.  *qin* is copied into a new
        NumPy array, so later changes to *qin* do not affect this object.
        """
        # A tuple default avoids the shared-mutable-default pitfall; the
        # resulting array is identical to the former list default.
        assert len(qin) == 4
        self.q = np.array(qin)

    def __str__(self):
        return self.q.__str__()

    def __mul__(self, other):
        """Return the Hamilton product ``self * other``.

        For rotation quaternions the product composes rotations, i.e.
        ``(a * b).rotate(v)`` equals ``a.rotate(b.rotate(v))``: *other* is
        applied first, then *self*.
        """
        sw, sx, sy, sz = self.q
        ow, ox, oy, oz = other.q
        return Quaternion([sw * ow - sx * ox - sy * oy - sz * oz,
                           sw * ox + sx * ow + sy * oz - sz * oy,
                           sw * oy + sy * ow + sz * ox - sx * oz,
                           sw * oz + sz * ow + sx * oy - sy * ox])

    def conjugate(self):
        """Return the conjugate ``[w, -x, -y, -z]``.

        For a unit quaternion this is the inverse rotation.
        """
        return Quaternion(self.q * np.array([1., -1., -1., -1.]))

    def rotate(self, vector):
        """Apply the rotation matrix to a vector."""
        # Same arithmetic as rotate_byq; delegate instead of duplicating it.
        return self.rotate_byq(self.q, vector)

    def rotation_matrix(self):
        """Return the 3x3 rotation matrix equivalent to this quaternion."""
        qw, qx, qy, qz = self.q[0], self.q[1], self.q[2], self.q[3]

        ww = qw * qw
        xx = qx * qx
        yy = qy * qy
        zz = qz * qz
        wx = qw * qx
        wy = qw * qy
        wz = qw * qz
        xy = qx * qy
        xz = qx * qz
        yz = qy * qz

        return np.array([[ww + xx - yy - zz, 2 * (xy - wz), 2 * (xz + wy)],
                         [2 * (xy + wz), ww - xx + yy - zz, 2 * (yz - wx)],
                         [2 * (xz - wy), 2 * (yz + wx), ww - xx - yy + zz]])

    def axis_angle(self):
        """Returns axis and angle (in radians) for the rotation described
        by this Quaternion.

        Returns ``(n, theta)``, where ``n`` is a unit 3-vector.  For the
        identity (zero vector part) the axis is arbitrary and ``[0, 0, 1]``
        is returned with ``theta = 0``.
        """
        # |vector part| = |q| * sin(theta / 2)
        sinth_2 = np.linalg.norm(self.q[1:])

        if sinth_2 == 0:
            # The angle is zero
            theta = 0.0
            n = np.array([0, 0, 1])
        else:
            # arctan2 does not require q to be normalized.
            theta = np.arctan2(sinth_2, self.q[0]) * 2
            n = self.q[1:] / sinth_2

        return n, theta

    def euler_angles(self, mode='zyz'):
        """Return three Euler angles describing the rotation, in radians.

        Mode can be zyz or zxz.  Default is zyz.  The result
        ``[alpha, beta, gamma]`` is meant to invert
        ``Quaternion.from_euler_angles(alpha, beta, gamma, mode)``.  At the
        gimbal-lock configurations (``beta`` close to 0 or pi), ``alpha`` is
        set to 0 and the whole rotation about z is folded into ``gamma``.

        Raises ValueError for an unsupported *mode*.
        """
        # https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0276302
        # Permute/sign the components so both proper sequences share one
        # formula (the sign on d comes from the axis-permutation parity).
        if mode == 'zyz':
            a, b, c, d = self.q[0], self.q[3], self.q[2], -self.q[1]
        elif mode == 'zxz':
            a, b, c, d = self.q[0], self.q[3], self.q[1], self.q[2]
        else:
            raise ValueError(f'Invalid Euler angles mode {mode}')

        # Normalizing by the squared norm makes this work for non-unit q.
        beta = 2 * np.arccos(
            np.sqrt((a**2 + b**2) / (a**2 + b**2 + c**2 + d**2))
        )
        gap = np.arctan2(b, a)  # (gamma + alpha) / 2
        gam = np.arctan2(d, c)  # (gamma - alpha) / 2
        if np.isclose(beta, 0):
            # gam is meaningless here
            alpha = 0
            gamma = 2 * gap - alpha
        elif np.isclose(beta, np.pi):
            # gap is meaningless here
            alpha = 0
            gamma = 2 * gam + alpha
        else:
            alpha = gap - gam
            gamma = gap + gam

        return np.array([alpha, beta, gamma])

    def arc_distance(self, other):
        """Gives a metric of the distance between two quaternions,
        expressed as 1-|q1.q2|.

        The absolute value makes ``q`` and ``-q`` (the same rotation)
        have distance 0 for unit quaternions.
        """
        return 1.0 - np.abs(np.dot(self.q, other.q))

    @staticmethod
    def rotate_byq(q, vector):
        """Rotate *vector* by the quaternion components *q* = [w, x, y, z].

        Equivalent to ``rotation_matrix() @ vector`` written out explicitly;
        only ``vector[0]``, ``vector[1]`` and ``vector[2]`` are read.
        """
        qw, qx, qy, qz = q[0], q[1], q[2], q[3]
        x, y, z = vector[0], vector[1], vector[2]

        ww = qw * qw
        xx = qx * qx
        yy = qy * qy
        zz = qz * qz
        wx = qw * qx
        wy = qw * qy
        wz = qw * qz
        xy = qx * qy
        xz = qx * qz
        yz = qy * qz

        return np.array(
            [(ww + xx - yy - zz) * x + 2 * ((xy - wz) * y + (xz + wy) * z),
             (ww - xx + yy - zz) * y + 2 * ((xy + wz) * x + (yz - wx) * z),
             (ww - xx - yy + zz) * z + 2 * ((xz - wy) * x + (yz + wx) * y)])

    @staticmethod
    def from_matrix(matrix):
        """Build quaternion from a 3x3 rotation matrix.

        One of four algebraically equivalent formulas is chosen from the
        diagonal so that the component used as divisor is never small.
        """
        m = np.array(matrix)
        assert m.shape == (3, 3)

        # Now we need to find out the whole quaternion
        # This method takes into account the possibility of qw being nearly
        # zero, so it picks the stablest solution

        if m[2, 2] < 0:
            if (m[0, 0] > m[1, 1]):
                # Use x-form
                qx = np.sqrt(1 + m[0, 0] - m[1, 1] - m[2, 2]) / 2.0
                fac = 1.0 / (4 * qx)
                qw = (m[2, 1] - m[1, 2]) * fac
                qy = (m[0, 1] + m[1, 0]) * fac
                qz = (m[0, 2] + m[2, 0]) * fac
            else:
                # Use y-form
                qy = np.sqrt(1 - m[0, 0] + m[1, 1] - m[2, 2]) / 2.0
                fac = 1.0 / (4 * qy)
                qw = (m[0, 2] - m[2, 0]) * fac
                qx = (m[0, 1] + m[1, 0]) * fac
                qz = (m[1, 2] + m[2, 1]) * fac
        else:
            if (m[0, 0] < -m[1, 1]):
                # Use z-form
                qz = np.sqrt(1 - m[0, 0] - m[1, 1] + m[2, 2]) / 2.0
                fac = 1.0 / (4 * qz)
                qw = (m[1, 0] - m[0, 1]) * fac
                qx = (m[2, 0] + m[0, 2]) * fac
                qy = (m[1, 2] + m[2, 1]) * fac
            else:
                # Use w-form
                qw = np.sqrt(1 + m[0, 0] + m[1, 1] + m[2, 2]) / 2.0
                fac = 1.0 / (4 * qw)
                qx = (m[2, 1] - m[1, 2]) * fac
                qy = (m[0, 2] - m[2, 0]) * fac
                qz = (m[1, 0] - m[0, 1]) * fac

        return Quaternion(np.array([qw, qx, qy, qz]))

    @staticmethod
    def from_axis_angle(n, theta):
        """Build quaternion from axis (n, vector of 3 components) and angle
        (theta, in radians).

        The axis does not need to be normalized; it is normalized here.
        """
        n = np.array(n, float) / np.linalg.norm(n)
        return Quaternion(np.concatenate([[np.cos(theta / 2.0)],
                                          np.sin(theta / 2.0) * n]))

    @staticmethod
    def from_euler_angles(a, b, c, mode='zyz'):
        """Build quaternion from Euler angles, given in radians.  Default
        mode is ZYZ, but it can be set to ZXZ as well.

        The result is ``q_c * q_b * q_a``: rotate by *a* about z first, then
        by *b* about y (zyz) or x (zxz), then by *c* about z, all about the
        fixed axes.  Raises ValueError for an unsupported *mode*.
        """
        q_a = Quaternion.from_axis_angle([0, 0, 1], a)
        q_c = Quaternion.from_axis_angle([0, 0, 1], c)

        if mode == 'zyz':
            q_b = Quaternion.from_axis_angle([0, 1, 0], b)
        elif mode == 'zxz':
            q_b = Quaternion.from_axis_angle([1, 0, 0], b)
        else:
            raise ValueError(f'Invalid Euler angles mode {mode}')

        return q_c * q_b * q_a