Coverage for /builds/ericyuan00000/ase/ase/utils/filecache.py: 95.92%
196 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-06-18 01:20 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2025-06-18 01:20 +0000
1# fmt: off
3import json
4from collections.abc import Mapping, MutableMapping
5from contextlib import contextmanager
6from pathlib import Path
8from ase.io.jsonio import encode as encode_json
9from ase.io.jsonio import read_json, write_json
10from ase.io.ulm import InvalidULMFileError, NDArrayReader, Writer, ulmopen
11from ase.parallel import world
12from ase.utils import opencew
def missing(key):
    """Raise ``KeyError`` for *key*.

    Shared helper so the mapping methods report an absent entry the same
    way a plain ``dict`` would.
    """
    error = KeyError(key)
    raise error
class Locked(Exception):
    """Raised when a cache entry cannot be opened for writing."""
23# Note:
24#
25# The communicator handling is a complete hack.
26# We should entirely remove communicators from these objects.
27# (Actually: opencew() should not know about communicators.)
28# Then the caller is responsible for handling parallelism,
29# which makes life simpler for both the caller and us!
30#
31# Also, things like clean()/__del__ are not correctly implemented
32# in parallel. The reason why it currently "works" is that
33# we don't call those functions from Vibrations etc., or they do so
34# only for rank==0.
class JSONBackend:
    """JSON-specific I/O operations used by the cache classes."""

    # Suffix of every file this backend reads or writes.
    extension = '.json'
    # Exception type that read() may raise for an undecodable file.
    DecodeError = json.decoder.JSONDecodeError

    @staticmethod
    def open_for_writing(path, comm):
        """Open *path* through ``opencew`` and return whatever it returns.

        *comm* is forwarded to ``opencew`` as its ``world`` argument.
        """
        handle = opencew(path, world=comm)
        return handle

    @staticmethod
    def read(fname):
        """Return the contents of *fname* decoded with ``read_json``."""
        contents = read_json(fname, always_array=False)
        return contents

    @staticmethod
    def open_and_write(target, data, comm):
        """Write *data* to *target* as JSON; only rank 0 of *comm* writes."""
        if comm.rank != 0:
            return
        write_json(target, data)

    @staticmethod
    def write(fd, value):
        """Encode *value* as JSON and write it to *fd* as UTF-8 bytes."""
        text = encode_json(value)
        fd.write(text.encode('utf-8'))

    @classmethod
    def dump_cache(cls, path, dct, comm):
        """Create a ``CombinedJSONCache`` from *dct* and dump it to disk."""
        return CombinedJSONCache.dump_cache(path, dct, comm)

    @classmethod
    def create_multifile_cache(cls, directory, comm):
        """Return a ``MultiFileJSONCache`` for *directory*."""
        return MultiFileJSONCache(directory, comm=comm)
class ULMBackend:
    """ULM-specific I/O operations used by the cache classes."""

    # Suffix of every file this backend reads or writes.
    extension = '.ulm'
    # Exception type that read() may raise for an undecodable file.
    DecodeError = InvalidULMFileError

    @staticmethod
    def open_for_writing(path, comm):
        """Open *path* through ``opencew`` and wrap the result in a Writer.

        Returns ``None`` when ``opencew`` returns ``None``; otherwise a
        ULM ``Writer`` built on the opened file object.
        """
        fd = opencew(path, world=comm)
        if fd is None:
            return None
        return Writer(fd, 'w', '')

    @staticmethod
    def read(fname):
        """Return the object stored under ``'cache'`` in the file *fname*."""
        with ulmopen(fname, 'r') as reader:
            payload = reader._data['cache']
            if isinstance(payload, NDArrayReader):
                # Materialise the array while the file is still open.
                payload = payload.read()
        return payload

    @staticmethod
    def open_and_write(target, data, comm):
        """Write *data* to *target* under ``'cache'``; only rank 0 writes."""
        if comm.rank != 0:
            return
        with ulmopen(target, 'w') as writer:
            writer.write('cache', data)

    @staticmethod
    def write(fd, value):
        """Store *value* under the name ``'cache'`` using the writer *fd*."""
        fd.write('cache', value)

    @classmethod
    def dump_cache(cls, path, dct, comm):
        """Create a ``CombinedULMCache`` from *dct* and dump it to disk."""
        return CombinedULMCache.dump_cache(path, dct, comm)

    @classmethod
    def create_multifile_cache(cls, directory, comm):
        """Return a ``MultiFileULMCache`` for *directory*."""
        return MultiFileULMCache(directory, comm=comm)
class CacheLock:
    """Handle for one opened cache entry that is waiting to be written.

    Holds the open file object, the key it belongs to, and the backend
    whose ``write`` performs the serialisation.
    """

    def __init__(self, fd, key, backend):
        self.backend = backend
        self.key = key
        self.fd = fd

    def save(self, value):
        """Write *value* through the backend, then close the file object.

        Raises:
            RuntimeError: if the backend's ``write`` raises any
                ``Exception``; the original error is chained as the cause.
        """
        try:
            try:
                self.backend.write(self.fd, value)
            except Exception as err:
                raise RuntimeError(
                    f'Failed to save {value} to cache') from err
        finally:
            # The file object is closed whether or not the write succeeded.
            self.fd.close()
class _MultiFileCacheTemplate(MutableMapping):
    """Mutable mapping that keeps every entry in a file of its own.

    Entries live in ``directory`` as ``cache.<key><extension>``.  The
    extension and all reading and writing come from the ``backend`` class
    attribute, which this template does not define itself.
    """

    writable = True

    def __init__(self, directory, comm=world):
        self.comm = comm
        self.directory = Path(directory)

    def _filename(self, key):
        """Return the path of the file that holds *key*."""
        name = f'cache.{key}' + self.backend.extension
        return self.directory / name

    def _glob(self):
        """Iterate over the paths of all entry files in the directory."""
        pattern = 'cache.*' + self.backend.extension
        return self.directory.glob(pattern)

    def __iter__(self):
        for entry in self._glob():
            # The stem is 'cache.<key>'; the key itself may contain dots.
            prefix, key = entry.stem.split('.', 1)
            if prefix == 'cache':
                yield key

    def __len__(self):
        # Walks the whole directory on every call; acceptable because
        # this is not a common operation.
        return sum(1 for _ in self._glob())

    @contextmanager
    def lock(self, key):
        """Context manager yielding a ``CacheLock`` for *key*, or ``None``.

        ``None`` is yielded when the backend's ``open_for_writing`` returns
        ``None``.  Otherwise the opened file object is closed on exit.
        """
        if self.comm.rank == 0:
            self.directory.mkdir(parents=True, exist_ok=True)
        fd = self.backend.open_for_writing(self._filename(key), self.comm)
        if fd is None:
            yield None
            return
        try:
            yield CacheLock(fd, key, self.backend)
        finally:
            fd.close()

    def __setitem__(self, key, value):
        with self.lock(key) as entry_lock:
            if entry_lock is None:
                raise Locked(key)
            entry_lock.save(value)

    def __getitem__(self, key):
        target = self._filename(key)
        try:
            value = self.backend.read(target)
        except FileNotFoundError:
            missing(key)
        except self.backend.DecodeError:
            # The file may be partially written, which typically means
            # empty because it was locked with exclusive-write-open.
            #
            # The set of keys is defined by which files exist, so a value
            # has to be returned for this case as well: None.
            return None
        else:
            return value

    def __delitem__(self, key):
        target = self._filename(key)
        try:
            target.unlink()
        except FileNotFoundError:
            missing(key)

    def combine(self):
        """Dump all entries through the backend's ``dump_cache``.

        This cache is cleared afterwards and the object returned by
        ``dump_cache`` is handed back to the caller.
        """
        snapshot = dict(self)
        combined = self.backend.dump_cache(self.directory, snapshot,
                                           comm=self.comm)
        assert set(combined) == set(self)
        self.clear()
        assert len(self) == 0
        return combined

    def split(self):
        """Return ``self``."""
        return self

    def filecount(self):
        """Return the number of entry files, which equals ``len(self)``."""
        return len(self)

    def strip_empties(self):
        """Delete every entry whose value is ``None``; return the count."""
        empty_keys = [key for key, value in self.items() if value is None]
        for key in empty_keys:
            del self[key]
        return len(empty_keys)
class _CombinedCacheTemplate(Mapping):
    """Read-only mapping backed by one ``combined<extension>`` file.

    The entries are held in memory in a private dict.  The extension and
    the file I/O come from the ``backend`` class attribute, which this
    template does not define itself.
    """

    writable = False

    def __init__(self, directory, dct, comm=world):
        self.comm = comm
        self._dct = dict(dct)  # private copy of the caller's mapping
        self.directory = Path(directory)

    def filecount(self):
        """Return 1 if the combined file exists as a file, otherwise 0."""
        return 1 if self._filename.is_file() else 0

    @property
    def _filename(self):
        """Path of the combined file inside the cache directory."""
        name = 'combined' + self.backend.extension
        return self.directory / name

    def __len__(self):
        return len(self._dct)

    def __iter__(self):
        return iter(self._dct)

    def __getitem__(self, index):
        return self._dct[index]

    def _dump(self):
        """Write the in-memory entries to the combined file.

        Raises:
            RuntimeError: if the combined file already exists.
        """
        destination = self._filename
        if destination.exists():
            raise RuntimeError(f'Already exists: {destination}')
        self.directory.mkdir(parents=True, exist_ok=True)
        self.backend.open_and_write(destination, self._dct, comm=self.comm)

    @classmethod
    def dump_cache(cls, path, dct, comm=world):
        """Build a cache from *dct*, write it to disk and return it."""
        instance = cls(path, dct, comm=comm)
        instance._dump()
        return instance

    @classmethod
    def load(cls, path, comm):
        """Return a cache filled from the combined file in *path*."""
        # XXX Very hacky: start empty, then fill from what is on disk.
        instance = cls(path, {}, comm=comm)
        stored = cls.backend.read(instance._filename)
        instance._dct.update(stored)
        return instance

    def clear(self):
        """Delete the combined file, then drop all in-memory entries."""
        self._filename.unlink()
        self._dct.clear()

    def combine(self):
        """Return ``self``."""
        return self

    def split(self):
        """Copy all entries into a multi-file cache and clear this one.

        Returns the multi-file cache created by the backend.
        """
        multi = self.backend.create_multifile_cache(self.directory,
                                                    comm=self.comm)
        assert len(multi) == 0
        multi.update(self)
        assert set(multi) == set(self)
        self.clear()
        return multi
class MultiFileJSONCache(_MultiFileCacheTemplate):
    """Multi-file cache whose backend is a ``JSONBackend`` instance."""

    backend = JSONBackend()
class MultiFileULMCache(_MultiFileCacheTemplate):
    """Multi-file cache whose backend is a ``ULMBackend`` instance."""

    backend = ULMBackend()
class CombinedJSONCache(_CombinedCacheTemplate):
    """Combined cache whose backend is a ``JSONBackend`` instance."""

    backend = JSONBackend()
class CombinedULMCache(_CombinedCacheTemplate):
    """Combined cache whose backend is a ``ULMBackend`` instance."""

    backend = ULMBackend()
def get_json_cache(directory, comm=world):
    """Return the JSON cache for *directory*.

    The combined cache is loaded first; if that raises
    ``FileNotFoundError``, a multi-file cache is returned instead.
    """
    try:
        cache = CombinedJSONCache.load(directory, comm=comm)
    except FileNotFoundError:
        cache = MultiFileJSONCache(directory, comm=comm)
    return cache