Coverage for /builds/ericyuan00000/ase/ase/utils/structure_comparator.py: 94.88%
293 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-06-18 01:20 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2025-06-18 01:20 +0000
1# fmt: off
3"""Determine symmetry equivalence of two structures.
4Based on the recipe from Comput. Phys. Commun. 183, 690-697 (2012)."""
5from collections import Counter
6from itertools import combinations, filterfalse, product
8import numpy as np
9from scipy.spatial import cKDTree as KDTree
11from ase import Atom, Atoms
12from ase.build.tools import niggli_reduce
def normalize(cell):
    """Rescale the first three row vectors of ``cell`` to unit length.

    The division happens in place on ``cell`` (which must support row
    indexing and in-place division, e.g. a NumPy array); nothing is
    returned.
    """
    row_lengths = [np.linalg.norm(cell[row]) for row in range(3)]
    for row, length in enumerate(row_lengths):
        cell[row] /= length
class SpgLibNotFoundError(Exception):
    """Signal that spglib could not be imported although it is required."""

    def __init__(self, msg):
        # Hand the message straight to the base exception class.
        Exception.__init__(self, msg)
class SymmetryEquivalenceCheck:
    """Compare two structures to determine if they are symmetry equivalent.

    Based on the recipe from Comput. Phys. Commun. 183, 690-697 (2012).

    Parameters:

    angle_tol: float
        angle tolerance for the lattice vectors in degrees

    ltol: float
        relative tolerance for the length of the lattice vectors (per atom)

    stol: float
        position tolerance for the site comparison in units of
        (V/N)^(1/3) (average length between atoms)

    vol_tol: float
        volume tolerance in angstrom cubed to compare the volumes of
        the two structures

    scale_volume: bool
        if True the volumes of the two structures are scaled to be equal

    to_primitive: bool
        if True the structures are reduced to their primitive cells
        note that this feature requires spglib to be installed

    Examples:

    >>> from ase.build import bulk
    >>> from ase.utils.structure_comparator import SymmetryEquivalenceCheck
    >>> comp = SymmetryEquivalenceCheck()

    Compare a cell with a rotated version

    >>> a = bulk('Al', orthorhombic=True)
    >>> b = a.copy()
    >>> b.rotate(60, 'x', rotate_cell=True)
    >>> comp.compare(a, b)
    True

    Transform to the primitive cell and then compare

    >>> pa = bulk('Al')
    >>> comp.compare(a, pa)
    False
    >>> comp = SymmetryEquivalenceCheck(to_primitive=True)
    >>> comp.compare(a, pa)
    True

    Compare one structure with a list of other structures

    >>> import numpy as np
    >>> from ase import Atoms
    >>> s1 = Atoms('H3', positions=[[0.5, 0.5, 0],
    ...                             [0.5, 1.5, 0],
    ...                             [1.5, 1.5, 0]],
    ...            cell=[2, 2, 2], pbc=True)
    >>> comp = SymmetryEquivalenceCheck(stol=0.068)
    >>> s2_list = []
    >>> for d in np.linspace(0.1, 1.0, 5):
    ...     s2 = s1.copy()
    ...     s2.positions[0] += [d, 0, 0]
    ...     s2_list.append(s2)
    >>> comp.compare(s1, s2_list[:-1])
    False
    >>> comp.compare(s1, s2_list)
    True

    """

    def __init__(self, angle_tol=1.0, ltol=0.05, stol=0.05, vol_tol=0.1,
                 scale_volume=False, to_primitive=False):
        self.angle_tol = angle_tol * np.pi / 180.0  # convert to radians
        self.scale_volume = scale_volume
        self.stol = stol
        self.ltol = ltol
        self.vol_tol = vol_tol
        # Absolute position tolerance; set in compare() from stol and the
        # volume per atom of s1.
        self.position_tolerance = 0.0
        self.to_primitive = to_primitive

        # Variables to be used in the compare function
        self.s1 = None
        self.s2 = None
        self.expanded_s1 = None
        self.expanded_s2 = None
        self.least_freq_element = None

    def _niggli_reduce(self, atoms):
        """Reduce to niggli cells.

        Reduce the atoms to niggli cells, then rotates the niggli cells to
        the so called "standard" orientation with one lattice vector along the
        x-axis and a second vector in the xy plane.
        """
        niggli_reduce(atoms)
        self._standarize_cell(atoms)

    def _standarize_cell(self, atoms):
        """Rotate *atoms* (cell and positions) into the standard orientation.

        The first lattice vector is rotated onto the positive x-axis and the
        second one into the xy plane (with a non-negative y component).
        Positions are rotated with the same matrix and then wrapped into the
        cell. *atoms* is modified in place and also returned.
        """
        cell = np.array(atoms.get_cell(), dtype=float)
        rotation = self._standard_rotation(cell)
        # Lattice vectors and positions are stored as rows, so the rotation
        # is applied from the right as its transpose.
        atoms.set_cell(cell.dot(rotation.T))
        atoms.set_positions(atoms.get_positions().dot(rotation.T))
        atoms.wrap(pbc=[1, 1, 1])
        return atoms

    @staticmethod
    def _standard_rotation(cell):
        """Return the rotation matrix that brings *cell* to standard form.

        *cell* holds the lattice vectors as rows. The returned proper rotation
        R satisfies (up to rounding) ``R @ cell[0] == (|cell[0]|, 0, 0)`` and
        ``R @ cell[1] == (x, y, 0)`` with ``y >= 0``.

        Three plane rotations are composed: one about y (removes the z
        component of the first vector), one about z (removes its y component)
        and one about x (removes the z component of the second vector). The
        angles come from ``arctan2``, which covers every quadrant. This
        includes vectors lying exactly on a negative axis and a vanishing
        in-plane component, which gives angle 0, i.e. no rotation.
        """
        vectors = np.array(cell, dtype=float).T  # lattice vectors as columns
        total = np.eye(3)
        # (column of the vector to align, axis to keep, axis to zero out)
        for column, keep, zero in ((0, 0, 2), (0, 0, 1), (1, 1, 2)):
            vec = vectors[:, column]
            angle = np.arctan2(vec[zero], vec[keep])
            cos_a = np.cos(angle)
            sin_a = np.sin(angle)
            rot = np.eye(3)
            rot[keep, keep] = cos_a
            rot[keep, zero] = sin_a
            rot[zero, keep] = -sin_a
            rot[zero, zero] = cos_a
            total = rot.dot(total)
            vectors = rot.dot(vectors)
        return total

    def _get_element_count(self, struct):
        """Count the number of elements in each of the structures."""
        return Counter(struct.numbers)

    def _get_angles(self, cell):
        """Get the internal angles of the unit cell.

        Returns the angles (in radians) between lattice vectors 1-2, 1-3 and
        2-3, in that order. *cell* (lattice vectors as rows) is not modified.
        """
        unit = np.array(cell, dtype=float)
        unit /= np.linalg.norm(unit, axis=1)[:, np.newaxis]

        dot = unit.dot(unit.T)

        # Extract only the relevant dot products. Clip them so that rounding
        # noise (e.g. nearly parallel vectors) cannot leave the domain of
        # arccos and produce NaN.
        cosines = np.clip([dot[0, 1], dot[0, 2], dot[1, 2]], -1.0, 1.0)

        # Return angles
        return np.arccos(cosines)

    def _has_same_elements(self):
        """Check if two structures have same elements."""
        elem1 = self._get_element_count(self.s1)
        return elem1 == self._get_element_count(self.s2)

    def _has_same_angles(self):
        """Check that the Niggli unit vectors have the same internal angles."""
        ang1 = np.sort(self._get_angles(self.s1.get_cell()))
        ang2 = np.sort(self._get_angles(self.s2.get_cell()))

        return np.allclose(ang1, ang2, rtol=0, atol=self.angle_tol)

    def _has_same_volume(self):
        """Check that the cell volumes of s1 and s2 differ by < vol_tol."""
        vol1 = self.s1.get_volume()
        vol2 = self.s2.get_volume()
        return np.abs(vol1 - vol2) < self.vol_tol

    def _scale_volumes(self):
        """Scale the cell of s2 (and its atoms) to the volume of s1.

        Absolute volumes are used. With signed determinants, cells of
        opposite handedness would give a negative ratio, and its cube root
        is NaN.
        """
        cell2 = self.s2.get_cell()
        # Get the volumes
        v2 = abs(np.linalg.det(cell2))
        v1 = abs(np.linalg.det(self.s1.get_cell()))

        # Scale the cells
        coordinate_scaling = (v1 / v2)**(1.0 / 3.0)
        cell2 *= coordinate_scaling
        self.s2.set_cell(cell2, scale_atoms=True)

    def compare(self, s1, s2):
        """Compare the two structures.

        Return *True* if the two structures are equivalent, *False* otherwise.
        The method works on copies: neither *s1* nor the structures in *s2*
        are modified.

        Parameters:

        s1: Atoms object.
            Transformation matrices are calculated based on this structure.

        s2: Atoms or list
            s1 can be compared to one structure or many structures supplied in
            a list. If s2 is a list it returns True if any structure in s2
            matches s1, False otherwise.
        """
        # Work on a private copy, so the translation below never leaks back
        # into the caller's Atoms object. _reduce_to_primitive already builds
        # a new Atoms object.
        if self.to_primitive:
            s1 = self._reduce_to_primitive(s1)
        else:
            s1 = s1.copy()
        self._set_least_frequent_element(s1)
        self._least_frequent_element_to_origin(s1)
        self.s1 = s1.copy()
        vol = self.s1.get_volume()
        self.expanded_s1 = None
        # Niggli-reduced copy of s1. It is computed once, the first time a
        # candidate passes the cheap checks, and then reused for every s2.
        s1_reduced = None

        if isinstance(s2, Atoms):
            # Just make it a list of length 1
            s2 = [s2]

        matrices = None
        translations = None
        transposed_matrices = None
        for struct in s2:
            self.s2 = struct.copy()
            self.expanded_s2 = None

            if self.to_primitive:
                self.s2 = self._reduce_to_primitive(self.s2)

            # Compare number of elements in structures
            if len(self.s1) != len(self.s2):
                continue

            # Compare chemical formulae
            if not self._has_same_elements():
                continue

            # Compare angles of the Niggli-reduced cells
            if s1_reduced is None:
                s1_reduced = s1.copy()
                self._niggli_reduce(s1_reduced)
            self.s1 = s1_reduced.copy()
            self._niggli_reduce(self.s2)
            if not self._has_same_angles():
                continue

            # Compare volumes
            if self.scale_volume:
                self._scale_volumes()
            if not self._has_same_volume():
                continue

            if matrices is None:
                matrices = self._get_rotation_reflection_matrices()
                if matrices is None:
                    continue

            if translations is None:
                translations = self._get_least_frequent_positions(self.s1)

            # After the candidate translation based on s1 has been computed
            # we need potentially to swap s1 and s2 for robust comparison
            self._least_frequent_element_to_origin(self.s2)
            switch = self._switch_reference_struct()
            if switch:
                # Remember the matrices and translations used before
                old_matrices = matrices
                old_translations = translations

                # If a s1 and s2 has been switched we need to use the
                # transposed version of the matrices to map atoms the
                # other way
                if transposed_matrices is None:
                    transposed_matrices = np.transpose(matrices,
                                                       axes=[0, 2, 1])
                matrices = transposed_matrices
                translations = self._get_least_frequent_positions(self.s1)

            # Calculate tolerance on positions
            self.position_tolerance = \
                self.stol * (vol / len(self.s2))**(1.0 / 3.0)

            if self._positions_match(matrices, translations):
                return True

            # Set the reference structure back to its original
            self.s1 = s1_reduced.copy()
            if switch:
                self.expanded_s1 = self.expanded_s2
                matrices = old_matrices
                translations = old_translations
        return False

    def _set_least_frequent_element(self, atoms):
        """Save the atomic number of the least frequent element."""
        elem1 = self._get_element_count(atoms)
        self.least_freq_element = elem1.most_common()[-1][0]

    def _get_least_frequent_positions(self, atoms):
        """Get the positions of the least frequent element in atoms."""
        pos = atoms.get_positions(wrap=True)
        return pos[atoms.numbers == self.least_freq_element]

    def _get_only_least_frequent_of(self, struct):
        """Get the atoms object with all other elements than the least frequent
        one removed. Wrap the positions to get everything in the cell."""
        pos = struct.get_positions(wrap=True)

        indices = struct.numbers == self.least_freq_element
        least_freq_struct = struct[indices]
        least_freq_struct.set_positions(pos[indices])

        return least_freq_struct

    def _switch_reference_struct(self):
        """Possibly swap s1 and s2, together with their expanded cells.

        The algorithm is intrinsically asymmetric: only the expanded cell of
        s2 is searched when matching the (non-expanded) positions of s1. This
        can make the result depend on which structure is passed first. By
        convention, the structure whose expanded cell has the fewest atoms
        ends up as s2.

        Returns True if a switch of structures has been performed.
        """
        # First expand the cells
        if self.expanded_s1 is None:
            self.expanded_s1 = self._expand(self.s1)
        if self.expanded_s2 is None:
            self.expanded_s2 = self._expand(self.s2)

        exp1 = self.expanded_s1
        exp2 = self.expanded_s2
        if len(exp1) < len(exp2):
            # s1 should be the reference structure
            # We have to swap s1 and s2
            s1_temp = self.s1.copy()
            self.s1 = self.s2
            self.s2 = s1_temp
            exp1_temp = self.expanded_s1.copy()
            self.expanded_s1 = self.expanded_s2
            self.expanded_s2 = exp1_temp
            return True
        return False

    def _positions_match(self, rotation_reflection_matrices, translations):
        """Check if the position and elements match.

        Every (translation, matrix) pair is tried: the wrapped positions of
        s1 are shifted by the translation, transformed by the matrix and
        compared with the expanded s2.

        Note that this function changes self.s1 and self.s2 to the rotation and
        translation that matches best. Hence, it is crucial that this function
        calls the element comparison, not the other way around.
        """
        pos1_ref = self.s1.get_positions(wrap=True)

        # Get the expanded reference object
        exp2 = self.expanded_s2
        # Build a KD tree to enable fast look-up of nearest neighbours
        tree = KDTree(exp2.get_positions())
        for i in range(translations.shape[0]):
            # Translate
            pos1_trans = pos1_ref - translations[i]
            for matrix in rotation_reflection_matrices:
                # Rotate
                pos1 = matrix.dot(pos1_trans.T).T

                # Update the atoms positions
                self.s1.set_positions(pos1)
                self.s1.wrap(pbc=[1, 1, 1])
                if self._elements_match(self.s1, exp2, tree):
                    return True
        return False

    def _expand(self, ref_atoms, tol=0.0001):
        """If an atom is closer to a boundary than tol it is repeated at the
        opposite boundaries.

        This ensures that atoms having crossed the cell boundaries due to
        numerical noise are properly detected.

        The distance between a position and cell boundary is calculated as:
        dot(position, (b_vec x c_vec) / |b_vec x c_vec|), where x is the
        cross product.

        Returns a new Atoms object; *ref_atoms* itself is not modified.
        """
        syms = ref_atoms.get_chemical_symbols()
        cell = ref_atoms.get_cell()
        positions = ref_atoms.get_positions(wrap=True)
        expanded_atoms = ref_atoms.copy()

        # Calculate normal vectors to the unit cell faces
        normal_vectors = np.array([np.cross(cell[1, :], cell[2, :]),
                                   np.cross(cell[0, :], cell[2, :]),
                                   np.cross(cell[0, :], cell[1, :])])
        normalize(normal_vectors)

        # Get the distance to the unit cell faces from each atomic position
        pos2faces = np.abs(positions.dot(normal_vectors.T))

        # And the opposite faces
        pos2oppofaces = np.abs(np.dot(positions - np.sum(cell, axis=0),
                                      normal_vectors.T))

        for i, i2face in enumerate(pos2faces):
            # Append indices for positions close to the other faces
            # and convert to boolean array signifying if the position at
            # index i is close to the faces bordering origo (0, 1, 2) or
            # the opposite faces (3, 4, 5)
            i_close2face = np.append(i2face, pos2oppofaces[i]) < tol
            # For each position i.e. row it holds that
            # 1 x True -> close to face -> 1 extra atom at opposite face
            # 2 x True -> close to edge -> 3 extra atoms at opposite edges
            # 3 x True -> close to corner -> 7 extra atoms opposite corners
            # E.g. to add atoms at all corners we need to use the cell
            # vectors: (a, b, c, a + b, a + c, b + c, a + b + c), we use
            # itertools.combinations to get them all
            for j in range(sum(i_close2face)):
                for c in combinations(np.nonzero(i_close2face)[0], j + 1):
                    # Get the displacement vectors by adding the corresponding
                    # cell vectors, if the atom is close to an opposite face
                    # i.e. k > 2 subtract the cell vector
                    disp_vec = np.zeros(3)
                    for k in c:
                        disp_vec += cell[k % 3] * (int(k < 3) * 2 - 1)
                    pos = positions[i] + disp_vec
                    expanded_atoms.append(Atom(syms[i], position=pos))
        return expanded_atoms

    def _equal_elements_in_array(self, arr):
        """Return True if any value occurs more than once in *arr*."""
        s = np.sort(arr)
        return np.any(s[1:] == s[:-1])

    def _elements_match(self, s1, s2, kdtree):
        """Check if all the elements in s1 match corresponding position in s2.

        *kdtree* must be built from the positions of *s2*. Every atom of *s1*
        is paired with its nearest neighbour in *s2*. The match succeeds only
        if each pair has the same atomic number and lies within
        ``self.position_tolerance``, and no atom of *s2* is claimed twice.
        """
        dists, closest_in_s2 = kdtree.query(s1.get_positions())

        # Check if the elements are the same
        if not np.all(s2.numbers[closest_in_s2] == s1.numbers):
            return False

        # Check if any distance is too large
        if np.any(dists > self.position_tolerance):
            return False

        # Check for duplicates in what atom is closest
        return not self._equal_elements_in_array(closest_in_s2)

    def _least_frequent_element_to_origin(self, atoms):
        """Put one of the least frequent elements at the origin.

        *atoms* is translated in place so that its first least-frequent atom
        sits 1e-6 of the cell diagonal inside the origin corner. The
        positions are then wrapped.
        """
        least_freq_pos = self._get_least_frequent_positions(atoms)
        cell_diag = np.sum(atoms.get_cell(), axis=0)
        d = least_freq_pos[0] - 1e-6 * cell_diag
        atoms.positions -= d
        atoms.wrap(pbc=[1, 1, 1])

    def _get_rotation_reflection_matrices(self):
        """Compute candidates for the transformation matrix.

        Trial vectors are taken from a 3x3x3 supercell of the least frequent
        element of s1. Triples whose lengths and mutual angles match the
        lattice vectors of s2 (within ``ltol`` and ``angle_tol``) are turned
        into candidate matrices. Returns an array of shape (N, 3, 3), or None
        if no candidate triple is found.
        """
        atoms1_ref = self._get_only_least_frequent_of(self.s1)
        cell = self.s1.get_cell().T
        cell_diag = np.sum(cell, axis=1)
        angle_tol = self.angle_tol

        # Additional vector that is added to make sure that
        # there always is an atom at the origin
        delta_vec = 1E-6 * cell_diag

        # Store three reference vectors and their lengths
        ref_vec = self.s2.get_cell()
        ref_vec_lengths = np.linalg.norm(ref_vec, axis=1)

        # Compute ref vec angles
        # ref_angles are arranged as [angle12, angle13, angle23]
        ref_angles = np.array(self._get_angles(ref_vec))
        large_angles = ref_angles > np.pi / 2.0
        ref_angles[large_angles] = np.pi - ref_angles[large_angles]

        # Translate by one cell diagonal so that a central cell is
        # surrounded by cells in all directions
        sc_atom_search = atoms1_ref * (3, 3, 3)
        new_sc_pos = sc_atom_search.get_positions()
        new_sc_pos -= new_sc_pos[0] + cell_diag - delta_vec

        lengths = np.linalg.norm(new_sc_pos, axis=1)

        candidate_indices = []
        rtol = self.ltol / len(self.s1)
        for k in range(3):
            correct_lengths_mask = np.isclose(lengths,
                                              ref_vec_lengths[k],
                                              rtol=rtol, atol=0)
            # The first vector is not interesting
            correct_lengths_mask[0] = False

            # If no trial vectors can be found (for any direction)
            # then the candidates are different and we return None
            if not np.any(correct_lengths_mask):
                return None

            candidate_indices.append(np.nonzero(correct_lengths_mask)[0])

        # Now we calculate all relevant angles in one step. The relevant angles
        # are the ones made by the current candidates. We will have to keep
        # track of the indices in the angles matrix and the indices in the
        # position and length arrays.

        # Get all candidate indices (aci), only unique values
        aci = np.sort(list(set().union(*candidate_indices)))

        # Make a dictionary from original positions and lengths index to
        # index in angle matrix
        i2ang = dict(zip(aci, range(len(aci))))

        # Calculate the dot product divided by the lengths:
        # cos(angle) = dot(vec1, vec2) / |vec1| |vec2|
        cosa = np.inner(new_sc_pos[aci],
                        new_sc_pos[aci]) / np.outer(lengths[aci],
                                                    lengths[aci])
        # Make sure the inverse cosine will work
        cosa[cosa > 1] = 1
        cosa[cosa < -1] = -1
        angles = np.arccos(cosa)
        # Do trick for enantiomorphic structures
        angles[angles > np.pi / 2] = np.pi - angles[angles > np.pi / 2]

        # Check which angles match the reference angles
        # Test for all combinations on candidates. filterfalse makes sure
        # that there are no duplicate candidates. product is the same as
        # nested for loops.
        refined_candidate_list = []
        for p in filterfalse(self._equal_elements_in_array,
                             product(*candidate_indices)):
            a = np.array([angles[i2ang[p[0]], i2ang[p[1]]],
                          angles[i2ang[p[0]], i2ang[p[2]]],
                          angles[i2ang[p[1]], i2ang[p[2]]]])

            if np.allclose(a, ref_angles, atol=angle_tol, rtol=0):
                refined_candidate_list.append(new_sc_pos[np.array(p)].T)

        # Get the rotation/reflection matrix [R] by:
        # [R] = [V][T]^-1, where [V] is the reference vectors and
        # [T] is the trial vectors.
        # Each entry of refined_candidate_list is a 3x3 array whose columns
        # are the three trial vectors.
        if len(refined_candidate_list) == 0:
            return None
        inverted_trial = np.linalg.inv(refined_candidate_list)

        # np.dot(ref_vec.T, inverted_trial.T).T yields, for every candidate
        # n, (ref_vec.T @ inverted_trial[n]).T. That is the TRANSPOSE of [R]
        # above, not np.matmul(ref_vec.T, inverted_trial). _positions_match
        # applies these matrices to the positions of s1, and compare() uses
        # their transposes after swapping s1 and s2.
        # NOTE(review): confirm which direction is intended before changing
        # either side.
        candidate_trans_mat = np.dot(ref_vec.T, inverted_trial.T).T
        return candidate_trans_mat

    def _reduce_to_primitive(self, structure):
        """Reduce the structure to its primitive cell using spglib.

        Returns a new periodic Atoms object; *structure* is not modified.

        Raises SpgLibNotFoundError if spglib cannot be imported.
        """
        try:
            import spglib
        except ImportError as err:
            raise SpgLibNotFoundError(
                "SpgLib is required if to_primitive=True") from err
        cell = (structure.get_cell()).tolist()
        pos = structure.get_scaled_positions().tolist()
        numbers = structure.get_atomic_numbers()

        cell, scaled_pos, numbers = spglib.standardize_cell(
            (cell, pos, numbers), to_primitive=True)

        atoms = Atoms(
            scaled_positions=scaled_pos,
            numbers=numbers,
            cell=cell,
            pbc=True)
        return atoms