Coverage for /builds/ericyuan00000/ase/ase/utils/filecache.py: 95.92%

196 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-06-18 01:20 +0000

1# fmt: off 

2 

3import json 

4from collections.abc import Mapping, MutableMapping 

5from contextlib import contextmanager 

6from pathlib import Path 

7 

8from ase.io.jsonio import encode as encode_json 

9from ase.io.jsonio import read_json, write_json 

10from ase.io.ulm import InvalidULMFileError, NDArrayReader, Writer, ulmopen 

11from ase.parallel import world 

12from ase.utils import opencew 

13 

14 

def missing(key):
    """Report that ``key`` is absent from a cache.

    Always raises ``KeyError(key)``; this function never returns.
    """
    raise KeyError(key)

17 

18 

class Locked(Exception):
    """Raised when a cache entry cannot be opened for exclusive writing."""

21 

22 

23# Note: 

24# 

25# The communicator handling is a complete hack. 

26# We should entirely remove communicators from these objects. 

27# (Actually: opencew() should not know about communicators.) 

28# Then the caller is responsible for handling parallelism, 

29# which makes life simpler for both the caller and us! 

30# 

31# Also, things like clean()/__del__ are not correctly implemented 

32# in parallel. The reason why it currently "works" is that 

33# we don't call those functions from Vibrations etc., or they do so 

34# only for rank==0. 

35 

36 

class JSONBackend:
    """Cache backend that stores entries as JSON files."""

    # Suffix appended to every file name produced for this backend.
    extension = '.json'
    # Exception type callers catch when a file cannot be decoded.
    DecodeError = json.decoder.JSONDecodeError

    @staticmethod
    def open_for_writing(path, comm):
        """Open ``path`` through ``opencew`` and return its result.

        Callers in this module treat a ``None`` result as "could not
        obtain the file for writing".
        """
        return opencew(path, world=comm)

    @staticmethod
    def read(fname):
        """Return the decoded contents of the JSON file ``fname``."""
        return read_json(fname, always_array=False)

    @staticmethod
    def open_and_write(target, data, comm):
        """Write ``data`` as JSON to ``target``; only rank 0 writes."""
        if comm.rank != 0:
            return
        write_json(target, data)

    @staticmethod
    def write(fd, value):
        """Encode ``value`` as JSON and write it to ``fd`` as UTF-8 bytes."""
        text = encode_json(value)
        fd.write(text.encode('utf-8'))

    @classmethod
    def dump_cache(cls, path, dct, comm):
        """Write ``dct`` as a combined JSON cache and return the cache."""
        return CombinedJSONCache.dump_cache(path, dct, comm)

    @classmethod
    def create_multifile_cache(cls, directory, comm):
        """Return a multi-file JSON cache rooted at ``directory``."""
        return MultiFileJSONCache(directory, comm=comm)

65 

66 

class ULMBackend:
    """Cache backend that stores entries as ULM files."""

    # Suffix appended to every file name produced for this backend.
    extension = '.ulm'
    # Exception type callers catch when a file cannot be decoded.
    DecodeError = InvalidULMFileError

    @staticmethod
    def open_for_writing(path, comm):
        """Open ``path`` through ``opencew`` and wrap it in a ULM ``Writer``.

        Returns ``None`` when ``opencew`` returns ``None``.
        """
        raw = opencew(path, world=comm)
        if raw is None:
            return None
        return Writer(raw, 'w', '')

    @staticmethod
    def read(fname):
        """Return the value stored under ``'cache'`` in the file ``fname``.

        If the stored object is an ``NDArrayReader``, the result of its
        ``read()`` is returned instead of the reader itself.
        """
        with ulmopen(fname, 'r') as reader:
            payload = reader._data['cache']
            return (payload.read() if isinstance(payload, NDArrayReader)
                    else payload)

    @staticmethod
    def open_and_write(target, data, comm):
        """Write ``data`` under ``'cache'`` to ``target``; only rank 0 writes."""
        if comm.rank != 0:
            return
        with ulmopen(target, 'w') as writer:
            writer.write('cache', data)

    @staticmethod
    def write(fd, value):
        """Write ``value`` under the key ``'cache'`` to the open writer ``fd``."""
        fd.write('cache', value)

    @classmethod
    def dump_cache(cls, path, dct, comm):
        """Write ``dct`` as a combined ULM cache and return the cache."""
        return CombinedULMCache.dump_cache(path, dct, comm)

    @classmethod
    def create_multifile_cache(cls, directory, comm):
        """Return a multi-file ULM cache rooted at ``directory``."""
        return MultiFileULMCache(directory, comm=comm)

102 

103 

class CacheLock:
    """Handle for a single cache entry that has been opened for writing.

    ``fd`` is the open file-like object, ``key`` the cache key it belongs
    to, and ``backend`` the object whose ``write(fd, value)`` serialises
    the value.
    """

    def __init__(self, fd, key, backend):
        self.fd, self.key, self.backend = fd, key, backend

    def save(self, value):
        """Write ``value`` through the backend, then close the file.

        The file is closed whether or not writing succeeds.

        Raises:
            RuntimeError: if the backend raises any ``Exception`` while
                writing; the original exception is chained as the cause.
        """
        try:
            self.backend.write(self.fd, value)
        except Exception as err:
            message = f'Failed to save {value} to cache'
            raise RuntimeError(message) from err
        finally:
            self.fd.close()

117 

118 

class _MultiFileCacheTemplate(MutableMapping):
    """Mutable mapping that stores each entry in its own file.

    Entries live in ``directory`` as ``cache.<key><extension>``, where
    the extension comes from the ``backend`` class attribute that
    concrete subclasses must define.  The set of keys is whatever files
    matching that pattern currently exist.
    """

    writable = True

    def __init__(self, directory, comm=world):
        self.directory = Path(directory)
        self.comm = comm

    def _filename(self, key):
        """Return the path of the file that holds ``key``."""
        name = f'cache.{key}' + self.backend.extension
        return self.directory / name

    def _glob(self):
        """Iterate over the paths of all entry files in the directory."""
        pattern = 'cache.*' + self.backend.extension
        return self.directory.glob(pattern)

    def __iter__(self):
        for entry in self._glob():
            prefix, key = entry.stem.split('.', 1)
            if prefix == 'cache':
                yield key

    def __len__(self):
        # Walks the whole directory listing on every call; inefficient,
        # but this is not a big use case.
        return sum(1 for _ in self._glob())

    @contextmanager
    def lock(self, key):
        """Context manager that opens the file for ``key`` for writing.

        Yields a ``CacheLock`` for the opened file, or ``None`` when the
        backend returned ``None`` instead of a file.  An opened file is
        closed on exit.  The directory is created by rank 0 only.
        """
        if self.comm.rank == 0:
            self.directory.mkdir(exist_ok=True, parents=True)
        fd = self.backend.open_for_writing(self._filename(key), self.comm)
        if fd is None:
            yield None
            return
        try:
            yield CacheLock(fd, key, self.backend)
        finally:
            fd.close()

    def __setitem__(self, key, value):
        """Store ``value`` under ``key``; raise ``Locked`` if not obtainable."""
        with self.lock(key) as entry:
            if entry is None:
                raise Locked(key)
            entry.save(value)

    def __getitem__(self, key):
        """Return the value stored for ``key``.

        Raises ``KeyError`` if no file exists for ``key``.  Returns
        ``None`` if the file exists but cannot be decoded.
        """
        path = self._filename(key)
        value = None
        try:
            value = self.backend.read(path)
        except FileNotFoundError:
            missing(key)
        except self.backend.DecodeError:
            # May be partially written, which typically means empty
            # because the file was locked with exclusive-write-open.
            #
            # Since we decide what keys we have based on which files
            # exist, we are obligated to return a value for this case
            # too.  So the result stays None.
            pass
        return value

    def __delitem__(self, key):
        """Delete the file for ``key``; raise ``KeyError`` if absent."""
        try:
            target = self._filename(key)
            target.unlink()
        except FileNotFoundError:
            missing(key)

    def combine(self):
        """Dump all entries into one combined cache and clear this one.

        Returns the combined cache created by the backend.
        """
        snapshot = dict(self)
        combined = self.backend.dump_cache(self.directory, snapshot,
                                           comm=self.comm)
        assert set(combined) == set(self)
        self.clear()
        assert len(self) == 0
        return combined

    def split(self):
        """Return ``self``; this cache is already in multi-file form."""
        return self

    def filecount(self):
        """Return the number of entry files, which equals ``len(self)``."""
        return len(self)

    def strip_empties(self):
        """Delete every entry whose value reads as ``None``.

        Returns the number of entries removed.
        """
        # Collect first, delete afterwards, so the directory listing is
        # not modified while it is being iterated.
        empty_keys = []
        for key, value in self.items():
            if value is None:
                empty_keys.append(key)
        for key in empty_keys:
            del self[key]
        return len(empty_keys)

204 

205 

class _CombinedCacheTemplate(Mapping):
    """Read-only mapping backed by a single ``combined<extension>`` file.

    The entries are held in memory in ``self._dct``; the extension comes
    from the ``backend`` class attribute that concrete subclasses must
    define.
    """

    writable = False

    def __init__(self, directory, dct, comm=world):
        self.directory = Path(directory)
        self._dct = dict(dct)  # private copy of the given mapping
        self.comm = comm

    @property
    def _filename(self):
        """Path of the combined file inside ``self.directory``."""
        name = 'combined' + self.backend.extension
        return self.directory / name

    def filecount(self):
        """Return 1 if the combined file exists as a file, otherwise 0."""
        return 1 if self._filename.is_file() else 0

    def __len__(self):
        return len(self._dct)

    def __iter__(self):
        return iter(self._dct)

    def __getitem__(self, index):
        return self._dct[index]

    def _dump(self):
        """Write the in-memory entries to the combined file.

        Raises:
            RuntimeError: if the combined file already exists.
        """
        destination = self._filename
        if destination.exists():
            raise RuntimeError(f'Already exists: {destination}')
        self.directory.mkdir(exist_ok=True, parents=True)
        self.backend.open_and_write(destination, self._dct, comm=self.comm)

    @classmethod
    def dump_cache(cls, path, dct, comm=world):
        """Create a cache holding ``dct``, write it to disk and return it."""
        instance = cls(path, dct, comm=comm)
        instance._dump()
        return instance

    @classmethod
    def load(cls, path, comm):
        """Return a cache populated from the combined file under ``path``."""
        # XXX Very hacky this one
        instance = cls(path, {}, comm=comm)
        stored = cls.backend.read(instance._filename)
        instance._dct.update(stored)
        return instance

    def clear(self):
        """Delete the combined file, then drop the in-memory entries."""
        self._filename.unlink()
        self._dct.clear()

    def combine(self):
        """Return ``self``; this cache is already in combined form."""
        return self

    def split(self):
        """Copy all entries into a multi-file cache and clear this one.

        Returns the multi-file cache created by the backend.
        """
        multifile = self.backend.create_multifile_cache(self.directory,
                                                        comm=self.comm)
        assert len(multifile) == 0
        multifile.update(self)
        assert set(multifile) == set(self)
        self.clear()
        return multifile

266 

267 

class MultiFileJSONCache(_MultiFileCacheTemplate):
    """Multi-file cache whose entries are handled by ``JSONBackend``."""

    backend = JSONBackend()

270 

271 

class MultiFileULMCache(_MultiFileCacheTemplate):
    """Multi-file cache whose entries are handled by ``ULMBackend``."""

    backend = ULMBackend()

274 

275 

class CombinedJSONCache(_CombinedCacheTemplate):
    """Combined single-file cache handled by ``JSONBackend``."""

    backend = JSONBackend()

278 

279 

class CombinedULMCache(_CombinedCacheTemplate):
    """Combined single-file cache handled by ``ULMBackend``."""

    backend = ULMBackend()

282 

283 

def get_json_cache(directory, comm=world):
    """Return the JSON cache stored in ``directory``.

    The combined cache is loaded first; if loading raises
    ``FileNotFoundError``, a multi-file JSON cache for the same
    directory is returned instead.
    """
    try:
        cache = CombinedJSONCache.load(directory, comm=comm)
    except FileNotFoundError:
        cache = MultiFileJSONCache(directory, comm=comm)
    return cache