Coverage for /builds/ericyuan00000/ase/ase/io/jsonio.py: 89.09%

110 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-06-18 01:20 +0000

1# fmt: off 

2 

3import datetime 

4import json 

5 

6import numpy as np 

7 

8from ase.utils import reader, writer 

9 

10# Note: We are converting the JSON classes to the mechanisms recommended

11# by the json module. That means that instead of classes, we will use the

12# functions default() and object_hook().

13# 

14# The encoder classes are to be deprecated (but maybe not removed, if 

15# widely used). 

16 

17 

def default(obj):
    """Convert *obj* into something the json module can serialize.

    This is the fallback conversion used for objects that the json
    module cannot encode natively.  Supported inputs:

    * objects with a ``todict()`` method: the returned dict is used; if
      the object also has an ``ase_objtype`` attribute, it is stored
      under the ``'__ase_objtype__'`` key of a copy of that dict,
    * ``numpy.ndarray``: ``{'__ndarray__': (shape, dtype-string, flat
      list)}``; complex arrays are stored as interleaved real/imaginary
      floats,
    * NumPy integer, floating and boolean scalars: the corresponding
      builtin ``int``, ``float`` or ``bool``,
    * ``datetime.datetime``: ``{'__datetime__': isoformat-string}``,
    * complex numbers (builtin or NumPy):
      ``{'__complex__': (real, imag)}``.

    Raises:
        RuntimeError: if ``obj.todict()`` does not return a dict.
        TypeError: if *obj* is of none of the supported kinds.
    """
    if hasattr(obj, 'todict'):
        dct = obj.todict()

        if not isinstance(dct, dict):
            raise RuntimeError('todict() of {} returned object of type {} '
                               'but should have returned dict'
                               .format(obj, type(dct)))
        if hasattr(obj, 'ase_objtype'):
            # We modify the dictionary, so it is wise to take a copy.
            dct = dct.copy()
            dct['__ase_objtype__'] = obj.ase_objtype

        return dct
    if isinstance(obj, np.ndarray):
        flatobj = obj.ravel()
        if np.iscomplexobj(obj):
            # Reinterpret the complex buffer as interleaved (real, imag)
            # floats.  A view is used rather than assigning to
            # ``flatobj.dtype``, which NumPy discourages.
            flatobj = flatobj.view(obj.real.dtype)
        # We use str(obj.dtype) here instead of obj.dtype.name, because
        # they are not always the same (e.g. for numpy arrays of strings).
        # Using obj.dtype.name can break the ability to recursively decode/
        # encode such arrays.
        return {'__ndarray__': (obj.shape,
                                str(obj.dtype),
                                flatobj.tolist())}
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        # np.float64 is a float subclass and never reaches this function
        # through the json module, but e.g. np.float32 does.
        return float(obj)
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, datetime.datetime):
        return {'__datetime__': obj.isoformat()}
    if isinstance(obj, (complex, np.complexfloating)):
        # float() makes the parts plain Python floats, so that e.g. the
        # float32 parts of an np.complex64 are directly serializable.
        return {'__complex__': (float(obj.real), float(obj.imag))}

    raise TypeError(f'Cannot convert object of type {type(obj)} to '
                    'dictionary for JSON')

54 

55 

class MyEncoder(json.JSONEncoder):
    """JSON encoder whose fallback conversion is the module-level default().

    Per the comment at the top of the module, the encoder classes are
    slated for deprecation in favour of the plain functions.
    """

    def default(self, obj):
        """Return a JSON-serializable stand-in for *obj*."""
        # (Note the name "default" comes from the outer namespace, so
        # not actually recursive)
        return default(obj)

61 

62 

# Module-level encode function: the bound encode() method of a single
# MyEncoder instance created at import time.
encode = MyEncoder().encode

64 

65 

def object_hook(dct):
    """Turn a decoded JSON dict back into the object it represents.

    Used as the ``object_hook`` of the JSON decoder.  Dicts carrying one
    of the marker keys are converted as follows:

    * ``'__datetime__'``: a ``datetime.datetime``,
    * ``'__complex__'``: a ``complex``,
    * ``'__ndarray__'``: an array built by ``create_ndarray()``,
    * ``'__complex_ndarray__'``: a complex array (legacy format),
    * ``'__ase_objtype__'``: an object built by ``create_ase_object()``.

    Any other dict is returned unchanged.
    """
    if '__datetime__' in dct:
        text = dct['__datetime__']
        try:
            # isoformat() omits the fractional part when microseconds are
            # zero and appends a UTC offset for aware datetimes; a fixed
            # strptime pattern with '.%f' rejects both forms.
            return datetime.datetime.fromisoformat(text)
        except ValueError:
            # Fall back to the historical pattern for any string that
            # fromisoformat() refuses but strptime() used to accept.
            return datetime.datetime.strptime(text, '%Y-%m-%dT%H:%M:%S.%f')

    if '__complex__' in dct:
        return complex(*dct['__complex__'])

    if '__ndarray__' in dct:
        return create_ndarray(*dct['__ndarray__'])

    # No longer used (only here for backwards compatibility):
    if '__complex_ndarray__' in dct:
        r, i = (np.array(x) for x in dct['__complex_ndarray__'])
        return r + i * 1j

    if '__ase_objtype__' in dct:
        # The marker is removed so that only the object's own data is
        # passed on to create_ase_object().
        objtype = dct.pop('__ase_objtype__')
        dct = numpyfy(dct)
        return create_ase_object(objtype, dct)

    return dct

88 

89 

def create_ndarray(shape, dtype, data):
    """Create ndarray from shape, dtype and flattened data.

    For complex dtypes, *data* holds interleaved real and imaginary
    parts, i.e. twice as many numbers as the array has elements.
    """
    array = np.empty(shape, dtype=dtype)
    # ravel() of the freshly allocated array shares its memory, so
    # filling the flat buffer fills ``array`` itself.
    flatbuf = array.ravel()
    if np.iscomplexobj(array):
        # Reinterpret the buffer as real numbers through a view rather
        # than by assigning to ``flatbuf.dtype``, which NumPy discourages.
        flatbuf = flatbuf.view(array.real.dtype)
    flatbuf[:] = data
    return array

98 

99 

def create_ase_object(objtype, dct):
    """Rebuild the ASE object tagged *objtype* from its dictionary *dct*.

    Each known type tag is handled by its own branch, which imports the
    class it needs and instantiates it by hand.  *dct* may be modified
    in the process.

    Raises:
        ValueError: if *objtype* is not one of the known type tags.
    """
    # No registry yet: the branches below are simply tried in turn.
    if objtype == 'cell':
        from ase.cell import Cell

        # compatibility; we once had pbc
        dct.pop('pbc', None)
        decoded = Cell(**dct)
    elif objtype == 'bandstructure':
        from ase.spectrum.band_structure import BandStructure

        decoded = BandStructure(**dct)
    elif objtype == 'bandpath':
        from ase.dft.kpoints import BandPath

        # 'labelseq' is handed over as the ``path`` keyword and must not
        # appear a second time among the remaining keywords.
        labelseq = dct.pop('labelseq')
        decoded = BandPath(path=labelseq, **dct)
    elif objtype == 'atoms':
        from ase import Atoms

        decoded = Atoms.fromdict(dct)
    elif objtype == 'vibrationsdata':
        from ase.vibrations import VibrationsData

        decoded = VibrationsData.fromdict(dct)
    else:
        raise ValueError('Do not know how to decode object type {} '
                         'into an actual object'.format(objtype))

    # Sanity check: the created object must report the requested tag.
    assert decoded.ase_objtype == objtype
    return decoded

125 

126 

# Raw decoding function: the bound decode() method of a JSONDecoder that
# passes every decoded JSON object (dict) through object_hook().
mydecode = json.JSONDecoder(object_hook=object_hook).decode

128 

129 

def intkey(key):
    """Return ``int(key)`` when that conversion succeeds, else *key* itself."""
    try:
        as_int = int(key)
    except ValueError:
        # Not an integer literal: leave the key untouched.
        return key
    return as_int

136 

137 

def fix_int_keys_in_dicts(obj):
    """Turn integer-like string keys back into ints: "1" -> 1.

    json.dump() writes int dictionary keys as strings; this function
    reverses that.  Dicts are processed recursively through their
    values; any object that is not a dict is returned as it is.
    """
    if not isinstance(obj, dict):
        return obj

    fixed = {}
    for key, value in obj.items():
        new_key = intkey(key)
        fixed[new_key] = fix_int_keys_in_dicts(value)
    return fixed

148 

149 

def numpyfy(obj):
    """Convert suitable lists in *obj* to numpy arrays.

    * A dict with a ``'__complex_ndarray__'`` key (legacy format holding
      a real and an imaginary part) becomes a complex array.
    * A non-empty list that numpy turns into an array of dtype bool,
      int or float becomes that array.
    * Any other non-empty list is rebuilt with each item converted
      recursively.
    * Everything else is returned unchanged.
    """
    if isinstance(obj, dict) and '__complex_ndarray__' in obj:
        real, imag = (np.array(part) for part in obj['__complex_ndarray__'])
        return real + imag * 1j

    if not isinstance(obj, list) or len(obj) == 0:
        return obj

    try:
        candidate = np.array(obj)
    except ValueError:
        # The list cannot be turned into an array as a whole.
        candidate = None

    if candidate is not None and candidate.dtype in (bool, int, float):
        return candidate
    return [numpyfy(item) for item in obj]

165 

166 

def decode(txt, always_array=True):
    """Decode the JSON string *txt* into Python/ASE objects.

    The text is parsed with ``mydecode()``, after which integer-like
    dictionary keys are restored by ``fix_int_keys_in_dicts()``.  If
    *always_array* is true, the result is finally passed through
    ``numpyfy()``.
    """
    decoded = fix_int_keys_in_dicts(mydecode(txt))
    return numpyfy(decoded) if always_array else decoded

173 

174 

@reader
def read_json(fd, always_array=True):
    """Read everything from *fd* and return the decoded JSON content.

    *always_array* is forwarded to ``decode()``.
    """
    # NOTE(review): @reader comes from ase.utils and presumably lets
    # callers pass a filename as well as an open file -- confirm there.
    return decode(fd.read(), always_array=always_array)

179 

180 

@writer
def write_json(fd, obj):
    """Encode *obj* as JSON and write the resulting string to *fd*."""
    # NOTE(review): @writer comes from ase.utils and presumably lets
    # callers pass a filename as well as an open file -- confirm there.
    serialized = encode(obj)
    fd.write(serialized)