Coverage for /builds/ericyuan00000/ase/ase/optimize/optimize.py: 94.83%
174 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-06-18 01:20 +0000
1# fmt: off
3"""Structure optimization. """
4import time
5import warnings
6from collections.abc import Callable
7from functools import cached_property
8from math import sqrt
9from os.path import isfile
10from pathlib import Path
11from typing import IO, Any, Dict, List, Optional, Tuple, Union
13from ase import Atoms
14from ase.calculators.calculator import PropertyNotImplementedError
15from ase.filters import UnitCellFilter
16from ase.parallel import world
17from ase.utils import IOContext
18from ase.utils.abc import Optimizable
# Default value of the ``steps`` argument of run()/irun(): a very large
# number, so that in practice only convergence stops the loop.
DEFAULT_MAX_STEPS = 10 ** 8
class RestartError(RuntimeError):
    """Raised when an optimizer restart file cannot be used."""
class OptimizableAtoms(Optimizable):
    """Expose an :class:`~ase.Atoms` object through the ``Optimizable``
    interface by forwarding every call to the wrapped atoms."""

    def __init__(self, atoms):
        self.atoms = atoms

    def get_positions(self):
        return self.atoms.get_positions()

    def set_positions(self, positions):
        self.atoms.set_positions(positions)

    def get_forces(self):
        return self.atoms.get_forces()

    @cached_property
    def _use_force_consistent_energy(self):
        """True if the calculator can return the force-consistent energy.

        The answer is probed once (by requesting
        ``get_potential_energy(force_consistent=True)``) and then cached.
        NOTE: the cache is never invalidated, so swapping the calculator
        afterwards can give odd results in multi-step optimizations.
        """
        try:
            self.atoms.get_potential_energy(force_consistent=True)
        except PropertyNotImplementedError:
            # Quietly fall back to the ordinary energy. A FutureWarning
            # asking calculators to always provide 'free_energy' is planned
            # but currently disabled.
            return False
        return True

    def get_potential_energy(self):
        # Prefer the force-consistent energy whenever the calculator has it.
        return self.atoms.get_potential_energy(
            force_consistent=self._use_force_consistent_energy)

    def iterimages(self):
        # XXX document purpose of iterimages
        return self.atoms.iterimages()

    def __len__(self):
        # Currently the number of atoms, not the number of degrees of freedom.
        # TODO: return 3 * len(self.atoms), so that the length equals the
        # number of DOFs.
        return len(self.atoms)
class Dynamics(IOContext):
    """Base-class for all MD and structure optimization classes."""

    def __init__(
        self,
        atoms: Atoms,
        logfile: Optional[Union[IO, Path, str]] = None,
        trajectory: Optional[Union[str, Path]] = None,
        append_trajectory: bool = False,
        master: Optional[bool] = None,
        comm=world,
        *,
        loginterval: int = 1,
    ):
        """Dynamics object.

        Parameters
        ----------
        atoms : Atoms object
            The Atoms object to operate on.

        logfile : file object, Path, or str
            If *logfile* is a string, a file with that name will be opened.
            Use '-' for stdout.

        trajectory : Trajectory object, str, or Path
            Attach a trajectory object. If *trajectory* is a string/Path, a
            Trajectory will be constructed. Use *None* for no trajectory.

        append_trajectory : bool
            Defaults to False, which causes the trajectory file to be
            overwritten each time the dynamics is restarted from scratch.
            If True, the new structures are appended to the trajectory
            file instead.

        master : bool
            Defaults to None, which causes only rank 0 to save files. If set to
            true, this rank will save files.

        comm : Communicator object
            Communicator to handle parallel file reading and writing.

        loginterval : int, default: 1
            Interval (in steps) at which the attached trajectory is written.
            Subclasses may use it for other periodic output as well.
        """
        self.atoms = atoms
        self.optimizable = atoms.__ase_optimizable__()
        self.logfile = self.openfile(file=logfile, comm=comm, mode='a')
        self.observers: List[Tuple[Callable, int, Tuple, Dict[str, Any]]] = []
        self.nsteps = 0
        self.max_steps = 0  # to be updated in run or irun
        self.comm = comm

        if trajectory is not None:
            if isinstance(trajectory, (str, Path)):
                # We open this file ourselves, so register it with
                # closelater() to tie its lifetime to this IOContext.
                from ase.io.trajectory import Trajectory
                mode = "a" if append_trajectory else "w"
                trajectory = self.closelater(Trajectory(
                    trajectory, mode=mode, master=master, comm=comm
                ))
            self.attach(
                trajectory,
                interval=loginterval,
                atoms=self.optimizable,
            )

        self.trajectory = trajectory

    def todict(self) -> Dict[str, Any]:
        """Describe this dynamics object; must be provided by subclasses."""
        raise NotImplementedError

    def get_number_of_steps(self):
        """Return the number of steps taken so far."""
        return self.nsteps

    def insert_observer(
        self, function, position=0, interval=1, *args, **kwargs
    ):
        """Insert an observer.

        This can be used for pre-processing before logging and dumping.

        Parameters
        ----------
        function : callable or object with a ``write`` method
            Called as ``function(*args, **kwargs)``. Non-callable objects
            (e.g. trajectory writers) have their ``write`` method used.
        position : int, default: 0
            Index in the observer list at which to insert (0 = run first).
        interval : int, default: 1
            Same meaning as in :meth:`attach`.

        Examples
        --------
        >>> from ase.build import bulk
        >>> from ase.calculators.emt import EMT
        >>> from ase.optimize import BFGS
        ...
        ...
        >>> def update_info(atoms, opt):
        ...     atoms.info["nsteps"] = opt.nsteps
        ...
        ...
        >>> atoms = bulk("Cu", cubic=True) * 2
        >>> atoms.rattle()
        >>> atoms.calc = EMT()
        >>> with BFGS(atoms, logfile=None, trajectory="opt.traj") as opt:
        ...     opt.insert_observer(update_info, atoms=atoms, opt=opt)
        ...     opt.run(fmax=0.05, steps=10)
        True
        """
        if not isinstance(function, Callable):
            function = function.write
        self.observers.insert(position, (function, interval, args, kwargs))

    def attach(self, function, interval=1, *args, **kwargs):
        """Attach callback function.

        If *interval > 0*, call *function* with arguments *args* and keyword
        arguments *kwargs* on every step whose index is a multiple of
        *interval* (including step 0).

        If *interval <= 0*, call *function* only once, at step
        ``abs(interval)``. Step indices are zero based.

        If *function* has a ``set_description`` method, it receives
        ``self.todict()`` extended with the interval. If *function* is not
        callable, its ``write`` method is attached instead.
        """
        if hasattr(function, "set_description"):
            d = self.todict()
            d.update(interval=interval)
            function.set_description(d)
        if not isinstance(function, Callable):
            function = function.write
        self.observers.append((function, interval, args, kwargs))

    def call_observers(self):
        """Call every observer that is due at the current step."""
        for function, interval, args, kwargs in self.observers:
            if interval > 0:
                # Periodic observer: every *interval* steps.
                call = (self.nsteps % interval) == 0
            else:
                # One-shot observer: only at step abs(interval).
                call = self.nsteps == abs(interval)
            if call:
                function(*args, **kwargs)

    def irun(self, steps=DEFAULT_MAX_STEPS):
        """Run dynamics algorithm as generator.

        Being a generator, nothing (not even the update of
        ``self.max_steps``) happens until iteration starts.

        Parameters
        ----------
        steps : int, default=DEFAULT_MAX_STEPS
            Number of dynamics steps to be run.

        Yields
        ------
        converged : bool
            True if the forces on atoms are converged.

        Examples
        --------
        This method allows, e.g., to run two optimizers or MD thermostats at
        the same time.
        >>> opt1 = BFGS(atoms)
        >>> opt2 = BFGS(StrainFilter(atoms)).irun()
        >>> for _ in opt2:
        ...     opt1.run()
        """

        # update the maximum number of steps
        self.max_steps = self.nsteps + steps

        # compute the initial step
        self.optimizable.get_forces()

        # log the initial step
        if self.nsteps == 0:
            self.log()

        # Call observers for the initial state, unless a trajectory is
        # attached that already holds frames (e.g. a restart appending to an
        # existing file); that would store the same frame twice.
        if self.trajectory is None:
            self.call_observers()
        elif len(self.trajectory) == 0:
            self.call_observers()

        # check convergence
        is_converged = self.converged()
        yield is_converged

        # run the algorithm until converged or max_steps reached
        while not is_converged and self.nsteps < self.max_steps:
            # compute the next step
            self.step()
            self.nsteps += 1

            # log the step
            self.log()
            self.call_observers()

            # check convergence
            is_converged = self.converged()
            yield is_converged

    def run(self, steps=DEFAULT_MAX_STEPS):
        """Run dynamics algorithm.

        This method returns when :meth:`converged` reports convergence or
        when the number of steps exceeds *steps*.

        Parameters
        ----------
        steps : int, default=DEFAULT_MAX_STEPS
            Number of dynamics steps to be run.

        Returns
        -------
        converged : bool
            True if the forces on atoms are converged.
        """
        # Dynamics.irun is called explicitly (not self.irun) because
        # subclasses override irun with a different signature.
        # irun always yields at least once, so *converged* is always bound.
        for converged in Dynamics.irun(self, steps=steps):
            pass
        return converged

    def converged(self):
        """A dummy function as placeholder for a real criterion, e.g. in
        Optimizer."""
        return False

    def log(self, *args):
        """A dummy function as placeholder for a real logger, e.g. in
        Optimizer."""
        return True

    def step(self):
        """This needs to be implemented by subclasses."""
        raise RuntimeError("step not implemented.")
class Optimizer(Dynamics):
    """Base-class for all structure optimization classes."""

    # default maxstep for all optimizers
    defaults = {'maxstep': 0.2}
    _deprecated = object()

    def __init__(
        self,
        atoms: Atoms,
        restart: Optional[str] = None,
        logfile: Optional[Union[IO, str, Path]] = None,
        trajectory: Optional[Union[str, Path]] = None,
        append_trajectory: bool = False,
        **kwargs,
    ):
        """

        Parameters
        ----------
        atoms: :class:`~ase.Atoms`
            The Atoms object to relax.

        restart: str
            Filename for restart file. Default value is *None*. If the file
            exists, the optimizer state is read from it; otherwise the
            optimizer is initialized from scratch.

        logfile: file object, Path, or str
            If *logfile* is a string, a file with that name will be opened.
            Use '-' for stdout.

        trajectory: Trajectory object, Path, or str
            Attach trajectory object. If *trajectory* is a string a
            Trajectory will be constructed. Use *None* for no
            trajectory.

        append_trajectory: bool
            If True, append to the trajectory file instead of overwriting it.

        kwargs : dict, optional
            Extra arguments passed to :class:`~ase.optimize.optimize.Dynamics`.

        """
        super().__init__(
            atoms=atoms,
            logfile=logfile,
            trajectory=trajectory,
            append_trajectory=append_trajectory,
            **kwargs,
        )

        self.restart = restart

        self.fmax = None

        if restart is None or not isfile(restart):
            self.initialize()
        else:
            self.read()
            # NOTE(review): presumably keeps all ranks in step after reading
            # the restart file — confirm.
            self.comm.barrier()

    def read(self):
        """Restore optimizer state from *self.restart* (subclass hook)."""
        raise NotImplementedError

    def todict(self):
        """Return a description of this optimizer and its settings.

        Of the attributes ``maxstep``, ``alpha``, ``max_steps``, ``restart``
        and ``fmax``, those present on the instance are included.
        """
        description = {
            "type": "optimization",
            "optimizer": self.__class__.__name__,
        }
        # add custom attributes from subclasses
        for attr in ('maxstep', 'alpha', 'max_steps', 'restart',
                     'fmax'):
            if hasattr(self, attr):
                description[attr] = getattr(self, attr)
        return description

    def initialize(self):
        """Set up a fresh optimizer state (subclass hook; no-op here)."""
        pass

    def irun(self, fmax=0.05, steps=DEFAULT_MAX_STEPS):
        """Run optimizer as generator.

        *fmax* is stored immediately, before the returned generator is
        iterated.

        Parameters
        ----------
        fmax : float
            Convergence criterion of the forces on atoms.
        steps : int, default=DEFAULT_MAX_STEPS
            Number of optimizer steps to be run.

        Yields
        ------
        converged : bool
            True if the forces on atoms are converged.
        """
        self.fmax = fmax
        return Dynamics.irun(self, steps=steps)

    def run(self, fmax=0.05, steps=DEFAULT_MAX_STEPS):
        """Run optimizer.

        Parameters
        ----------
        fmax : float
            Convergence criterion of the forces on atoms.
        steps : int, default=DEFAULT_MAX_STEPS
            Number of optimizer steps to be run.

        Returns
        -------
        converged : bool
            True if the forces on atoms are converged.
        """
        self.fmax = fmax
        return Dynamics.run(self, steps=steps)

    def converged(self, forces=None):
        """Did the optimization converge?

        Delegates to ``self.optimizable.converged(forces, self.fmax)``;
        *forces* are fetched from the optimizable when not given.
        """
        if forces is None:
            forces = self.optimizable.get_forces()
        return self.optimizable.converged(forces, self.fmax)

    def log(self, forces=None):
        """Write one log line: step, wall-clock time, energy, max force.

        A header line is written first on step 0. Forces and energy are
        evaluated even when there is no logfile.
        """
        if forces is None:
            forces = self.optimizable.get_forces()
        # Largest per-atom force norm (rows of *forces* are atoms).
        fmax = sqrt((forces ** 2).sum(axis=1).max())
        e = self.optimizable.get_potential_energy()
        T = time.localtime()
        if self.logfile is not None:
            name = self.__class__.__name__
            if self.nsteps == 0:
                args = (" " * len(name), "Step", "Time", "Energy", "fmax")
                msg = "%s %4s %8s %15s %12s\n" % args
                self.logfile.write(msg)

            # T[3], T[4], T[5] are hour, minute, second.
            args = (name, self.nsteps, T[3], T[4], T[5], e, fmax)
            msg = "%s: %3d %02d:%02d:%02d %15.6f %15.6f\n" % args
            self.logfile.write(msg)
            self.logfile.flush()

    def dump(self, data):
        """Write *data* as JSON to *self.restart*.

        Only rank 0 writes, and nothing is written when no restart file was
        requested.
        """
        from ase.io.jsonio import write_json
        if self.comm.rank == 0 and self.restart is not None:
            with open(self.restart, 'w') as fd:
                write_json(fd, data)

    def load(self):
        """Read and return the JSON data stored in *self.restart*.

        Raises
        ------
        RestartError
            If the file content cannot be decoded as JSON.
        """
        from ase.io.jsonio import read_json
        from ase.optimize import BFGS
        with open(self.restart) as fd:
            # Per the warning text, restarting with a UnitCellFilter is
            # untested for optimizers other than BFGS. This check stays
            # outside the ``try`` below so that a failure here (e.g. the
            # warning escalated to an error with ``-W error``) is not
            # misreported as a JSON decoding problem.
            if not isinstance(self, BFGS) and isinstance(
                self.atoms, UnitCellFilter
            ):
                warnings.warn(
                    "WARNING: restart function is untested and may result "
                    "in unintended behavior. Namely orig_cell is not "
                    "loaded in the UnitCellFilter. Please test on your own"
                    " to ensure consistent results."
                )
            try:
                return read_json(fd, always_array=False)
            except Exception as ex:
                msg = ('Could not decode restart file as JSON. '
                       'You may need to delete the restart file '
                       f'{self.restart}')
                raise RestartError(msg) from ex