Coverage for /builds/ericyuan00000/ase/ase/gui/images.py: 70.82%

281 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-06-18 01:20 +0000

1# fmt: off 

2 

3import warnings 

4from math import sqrt 

5 

6import numpy as np 

7 

8from ase import Atoms 

9from ase.calculators.singlepoint import SinglePointCalculator 

10from ase.constraints import FixAtoms 

11from ase.data import covalent_radii 

12from ase.geometry import find_mic 

13from ase.gui.defaults import read_defaults 

14from ase.gui.i18n import _ 

15from ase.io import read, string2index, write 

16 

17 

class Images:
    """Container for the images (Atoms objects) handled by the GUI.

    Besides the images themselves, it keeps per-display state:
    - covalent radii and their scale factor;
    - the selection and visibility masks, sized for the largest image;
    - the current repeat of the unit cell.
    """

    def __init__(self, images=None):
        self.covalent_radii = covalent_radii.copy()
        self.config = read_defaults()
        self.atom_scale = self.config['radii_scale']
        if images is None:
            images = [Atoms()]
        self.initialize(images)

    def __len__(self):
        return len(self._images)

    def __getitem__(self, index):
        return self._images[index]

    def __iter__(self):
        return iter(self._images)

    # XXXXXXX hack
    # compatibility hacks while allowing variable number of atoms
    def get_dynamic(self, atoms: Atoms) -> np.ndarray:
        """Return a boolean mask that is False for atoms fixed by FixAtoms.

        Constraints other than FixAtoms are ignored.
        """
        dynamic = np.ones(len(atoms), bool)
        for constraint in atoms.constraints:
            if isinstance(constraint, FixAtoms):
                dynamic[constraint.index] = False
        return dynamic

    def set_dynamic(self, mask, value):
        """Mark the atoms selected by *mask* as dynamic (True) or fixed (False).

        All existing FixAtoms constraints of each image are replaced by a
        single FixAtoms built from the updated mask; other constraints are
        kept.
        """
        # Does not make much sense if different images have different
        # atom counts. Attempts to apply mask to all images,
        # to the extent possible.
        for atoms in self:
            dynamic = self.get_dynamic(atoms)
            dynamic[mask[:len(atoms)]] = value
            atoms.constraints = [c for c in atoms.constraints
                                 if not isinstance(c, FixAtoms)]
            atoms.constraints.append(FixAtoms(mask=~dynamic))

    def scale_radii(self, scaling_factor):
        """Multiply the stored covalent radii in place by *scaling_factor*."""
        self.covalent_radii *= scaling_factor

    def get_energy(self, atoms: Atoms) -> np.float64:
        """Return the potential energy of *atoms*, or NaN if unavailable."""
        try:
            return atoms.get_potential_energy()
        except RuntimeError:
            return np.nan  # type: ignore[return-value]

    def get_forces(self, atoms: Atoms):
        """Return the unconstrained forces of *atoms*, or None if unavailable."""
        try:
            return atoms.get_forces(apply_constraint=False)
        except RuntimeError:
            return None

    def initialize(self, images, filenames=None):
        """Store *images* and reset the selection, visibility and repeat state.

        The Atoms objects are stored as-is (not copied). Edits made through
        this container therefore also affect the caller's objects.
        *filenames* defaults to one None per image.
        """
        if filenames is None:
            filenames = [None] * len(images)
        self.filenames = filenames

        self._images = []

        # Whether length or chemical composition changes:
        self.have_varying_species = False
        mixed_pbc = False
        for atoms in images:
            # copy atoms or not? Not copying allows back-editing,
            # but copying actually forgets things like the attached
            # calculator (might have forces/energies)
            self._images.append(atoms)
            self.have_varying_species |= not np.array_equal(self[0].numbers,
                                                            atoms.numbers)
            if (atoms.pbc != self[0].pbc).any():
                mixed_pbc = True

        if mixed_pbc:
            warnings.warn('Not all images have the same boundary conditions!')

        # default=0: an empty image list (e.g. after deleting the last
        # image) must not raise ValueError here.
        self.maxnatoms = max((len(atoms) for atoms in self), default=0)
        self.selected = np.zeros(self.maxnatoms, bool)
        self.selected_ordered = []
        self.visible = np.ones(self.maxnatoms, bool)
        self.repeat = np.ones(3, int)

    def get_radii(self, atoms: Atoms) -> np.ndarray:
        """Return the display radius of every atom in *atoms*.

        Each radius is the stored covalent radius scaled by atom_scale.
        """
        radii = np.array([self.covalent_radii[z] for z in atoms.numbers])
        radii *= self.atom_scale
        return radii

    def read(self, filenames, default_index=':', filetype=None):
        """Read images from *filenames* and make them the current images.

        Each filename may carry an ``@index`` suffix; otherwise
        *default_index* is used. The filename ``'-'`` reads a trajectory
        from standard input. Images are named ``filename@index``.
        """
        from ase.io.formats import parse_filename

        if isinstance(default_index, str):
            default_index = string2index(default_index)

        images = []
        names = []
        for filename in filenames:
            if ('@' in filename and 'postgres' not in filename or
                    'postgres' in filename and filename.count('@') == 2):
                actual_filename, index = parse_filename(filename, None)
            else:
                actual_filename, index = parse_filename(filename,
                                                        default_index)

            # Format for this file only. Stdin forces 'traj' without
            # leaking that choice into the following filenames.
            file_format = filetype

            # Read from stdin:
            if filename == '-':
                import sys
                from io import BytesIO
                buf = BytesIO(sys.stdin.buffer.read())
                buf.seek(0)
                filename = buf
                file_format = 'traj'

            imgs = read(filename, index, file_format)
            if hasattr(imgs, 'iterimages'):
                imgs = list(imgs.iterimages())

            images.extend(imgs)

            # Name each file as filename@index:
            if isinstance(index, slice):
                start = index.start or 0
                step = index.step or 1
            else:
                start = index
                step = 1
            for i, _ in enumerate(imgs):
                if isinstance(start, int):
                    names.append(f'{actual_filename}@{start + i * step}')
                else:
                    names.append(f'{actual_filename}@{start}')

        self.initialize(images, names)

    def repeat_results(self, atoms: Atoms, repeat=None, oldprod=None):
        """Return a dictionary which updates the magmoms, energy and forces
        to the repeated amount of atoms.

        *atoms* is assumed to currently hold *oldprod* copies of the
        original atoms. *repeat* is the new repeat: a per-axis sequence or
        a total count. Both default to the product of self.repeat.
        Per-atom results are tiled from the first copy. Energy and total
        magnetic moment are scaled by newprod / oldprod. Quantities that
        are not available are left out of the result.
        """
        def getresult(name, get_quantity):
            # ase/io/trajectory.py line 170 does this by using
            # the get_property(prop, atoms, allow_calculation=False)
            # so that is an alternative option.
            try:
                if (not atoms.calc or
                        atoms.calc.calculation_required(atoms, [name])):
                    quantity = None
                else:
                    quantity = get_quantity()
            except Exception as err:
                # Deliberately best-effort: a failing property is skipped,
                # but the user is told why.
                quantity = None
                errmsg = ('An error occurred while retrieving {} '
                          'from the calculator: {}'.format(name, err))
                warnings.warn(errmsg)
            return quantity

        if repeat is None:
            repeat = self.repeat.prod()
        if oldprod is None:
            oldprod = self.repeat.prod()

        results = {}

        original_length = len(atoms) // oldprod
        # np.prod accepts arrays, tuples and plain/NumPy integers alike.
        newprod = np.prod(repeat)

        # Read the old properties
        magmoms = getresult('magmoms', atoms.get_magnetic_moments)
        magmom = getresult('magmom', atoms.get_magnetic_moment)
        energy = getresult('energy', atoms.get_potential_energy)
        forces = getresult('forces', atoms.get_forces)

        # Update old properties to the repeated image
        if magmoms is not None:
            magmoms = np.tile(magmoms[:original_length], newprod)
            results['magmoms'] = magmoms

        if magmom is not None:
            magmom = magmom * newprod / oldprod
            results['magmom'] = magmom

        if forces is not None:
            forces = np.tile(forces[:original_length].T, newprod).T
            results['forces'] = forces

        if energy is not None:
            energy = energy * newprod / oldprod
            results['energy'] = energy

        return results

    def repeat_unit_cell(self):
        """Make the current repeat part of each image's unit cell.

        Each cell is multiplied row-wise by self.repeat. The current
        results are re-attached through a SinglePointCalculator, and
        self.repeat is reset to (1, 1, 1).
        """
        for atoms in self:
            # Get quantities taking into account current repeat():
            results = self.repeat_results(atoms, self.repeat.prod(),
                                          oldprod=self.repeat.prod())

            atoms.cell *= self.repeat.reshape((3, 1))
            atoms.calc = SinglePointCalculator(atoms, **results)
        self.repeat = np.ones(3, int)

    def repeat_images(self, repeat):
        """Repeat every image *repeat* times along the cell vectors.

        Each image is first cut back to its unrepeated atoms, then repeated.
        The original cell is kept, and results are rescaled with
        repeat_results. Constraints other than FixAtoms are dropped, and a
        warning dialog is shown when that happens.
        """
        repeat = np.array(repeat)
        oldprod = self.repeat.prod()
        images = []
        constraints_removed = False

        for atoms in self:
            refcell = atoms.get_cell()
            fixatoms = []
            for c in atoms._constraints:
                if isinstance(c, FixAtoms):
                    fixatoms.append(c)
                else:
                    constraints_removed = True
            atoms.set_constraint(fixatoms)

            # Update results dictionary to repeated atoms
            results = self.repeat_results(atoms, repeat, oldprod)

            del atoms[len(atoms) // oldprod:]  # Original atoms

            atoms *= repeat
            atoms.cell = refcell

            atoms.calc = SinglePointCalculator(atoms, **results)

            images.append(atoms)

        if constraints_removed:
            from ase.gui.ui import showwarning, tk

            # We must be able to show warning before the main GUI
            # has been created. So we create a new window,
            # then show the warning, then destroy the window.
            tmpwindow = tk.Tk()
            tmpwindow.withdraw()  # Host window will never be shown
            showwarning(_('Constraints discarded'),
                        _('Constraints other than FixAtoms '
                          'have been discarded.'))
            tmpwindow.destroy()

        self.initialize(images, filenames=self.filenames)
        self.repeat = repeat

    def center(self):
        """Center each image in the existing unit cell, keeping the
        cell constant."""
        for atoms in self:
            atoms.center()

    def graph(self, expr: str) -> np.ndarray:
        """Routine to create the data in graphs, defined by the
        string expr.

        *expr* is a comma-separated list of Python expressions. Each one
        is evaluated once per image, in a namespace with these names:

        - ``i``: image index;
        - ``s``: accumulated minimum-image path length;
        - ``R``, ``V``, ``F``, ``A``, ``M``: positions, velocities,
          forces, cell and masses;
        - ``E``: energies of all images;
        - ``f``, ``fmax``, ``fave``: force norms on non-fixed atoms;
        - ``epot``, ``ekin``, ``e`` and ``T``;
        - helpers ``d``, ``a`` and ``dih``: distance, angle and dihedral
          (both in degrees) between atom indices.

        Quantities unavailable for an image are NaN, like missing energies
        in ``E``. This covers missing forces, and ``T`` when no atom is
        dynamic.

        Returns an array of shape (number of expressions, len(self)),
        or None if there are no images.
        """
        import ase.units as units

        # NOTE: expr is run through eval(). It is meant to come from the
        # local GUI user and must never be taken from untrusted input.
        code = compile(expr + ',', '<input>', 'eval')

        nimages = len(self)

        def d(n1, n2):
            return sqrt(((R[n1] - R[n2])**2).sum())

        def a(n1, n2, n3):
            v1 = R[n1] - R[n2]
            v2 = R[n3] - R[n2]
            arg = np.vdot(v1, v2) / (sqrt((v1**2).sum() * (v2**2).sum()))
            # Clamp rounding errors into the domain of arccos.
            arg = np.clip(arg, -1.0, 1.0)
            return 180.0 * np.arccos(arg) / np.pi

        def dih(n1, n2, n3, n4):
            # vector 0->1, 1->2, 2->3 and their normalized cross products:
            v01 = R[n2] - R[n1]
            v12 = R[n3] - R[n2]
            v23 = R[n4] - R[n3]
            bxa = np.cross(v12, v01)
            bxa /= np.sqrt(np.vdot(bxa, bxa))
            cxb = np.cross(v23, v12)
            cxb /= np.sqrt(np.vdot(cxb, cxb))
            # check for numerical trouble due to finite precision:
            angle = np.arccos(np.clip(np.vdot(bxa, cxb), -1.0, 1.0))
            if np.vdot(bxa, v23) > 0:
                angle = 2 * np.pi - angle
            return angle * 180.0 / np.pi

        # Energies of all images (NaN where unavailable):
        E = np.array([self.get_energy(atoms) for atoms in self])

        s = 0.0  # accumulated path length

        # Namespace for eval:
        ns = {'E': E,
              'd': d, 'a': a, 'dih': dih}

        xy = None
        for i, atoms in enumerate(self):
            # R is also read by the d/a/dih closures above.
            R = atoms.get_positions()
            # XXX askhl verify:
            dynamic = self.get_dynamic(atoms)
            F = self.get_forces(atoms)
            if F is None:
                # Missing forces become NaN instead of silently keeping the
                # values of the previous image in the shared namespace.
                F = np.full((len(atoms), 3), np.nan)
            f = ((F * dynamic[:, None])**2).sum(1)**.5
            epot = E[i]
            ekin = atoms.get_kinetic_energy()
            ndynamic = dynamic.sum()

            ns.update(
                i=i,
                s=s,
                R=R,
                V=atoms.get_velocities(),
                F=F,
                A=atoms.get_cell(),
                M=atoms.get_masses(),
                f=f,
                fmax=f.max() if len(f) else np.nan,
                fave=f.mean() if len(f) else np.nan,
                epot=epot,
                ekin=ekin,
                e=epot + ekin,
                T=(2.0 * ekin / (3.0 * ndynamic * units.kB)
                   if ndynamic > 0 else np.nan),
            )
            data = eval(code, ns)
            if xy is None:
                xy = np.empty((len(data), nimages))
            xy[:, i] = data
            if i + 1 < nimages and not self.have_varying_species:
                dR = find_mic(self[i + 1].positions - R, atoms.get_cell(),
                              atoms.get_pbc())[0]
                s += sqrt((dR**2).sum())
        return xy

    def write(self, filename, rotations='', bbox=None,
              **kwargs):
        """Write (a subset of) the images to *filename*.

        A trailing ``@index`` (e.g. ``out.traj@-1`` or ``out.xyz@0:10``)
        selects the images to write. If the text after the last ``@`` is
        not a valid index, the whole string is used as the filename.
        *rotations* and *bbox* are only passed on for .eps, .png and .pov
        output.
        """
        # XXX We should show the unit cell whenever there is one
        indices = range(len(self))
        p = filename.rfind('@')
        if p != -1:
            try:
                selection = string2index(filename[p + 1:])
            except ValueError:
                pass
            else:
                indices = indices[selection]
                filename = filename[:p]
        if isinstance(indices, int):
            indices = [indices]

        images = [self.get_atoms(i) for i in indices]
        if len(filename) > 4 and filename.endswith(('.eps', '.png', '.pov')):
            write(filename, images,
                  rotation=rotations,
                  bbox=bbox, **kwargs)
        else:
            write(filename, images, **kwargs)

    def get_atoms(self, frame, remove_hidden=False):
        """Return image *frame* with its energy and forces attached.

        The energy and forces are attached via a SinglePointCalculator and
        are None where unavailable. With *remove_hidden*, only the atoms in
        self.visible are kept.
        """
        atoms = self[frame]
        try:
            E = atoms.get_potential_energy()
        except RuntimeError:
            E = None
        try:
            F = atoms.get_forces()
        except RuntimeError:
            F = None

        # Remove hidden atoms if applicable
        if remove_hidden:
            atoms = atoms[self.visible]
            if F is not None:
                F = F[self.visible]
        # NOTE(review): without remove_hidden, *atoms* is the stored image
        # itself. Its calculator is replaced here, so results other than
        # energy/forces are dropped from it. Confirm this side effect is
        # intended.
        atoms.calc = SinglePointCalculator(atoms, energy=E, forces=F)
        return atoms

    def delete(self, i):
        """Remove image *i* and its filename, then reset the display state."""
        self._images.pop(i)
        self.filenames.pop(i)
        self.initialize(self._images, self.filenames)