Coverage for /builds/ericyuan00000/ase/ase/io/jsonio.py: 89.09%
110 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-06-18 01:20 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2025-06-18 01:20 +0000
1# fmt: off
3import datetime
4import json
6import numpy as np
8from ase.utils import reader, writer
10# Note: We are converting JSON classes to the recommended mechanisms
11# by the json module. That means instead of classes, we will use the
12# functions default() and object_hook().
13#
14# The encoder classes are to be deprecated (but maybe not removed, if
15# widely used).
def default(obj):
    """Convert *obj* into something the json module can serialize.

    Intended for use as the ``default`` hook of a JSON encoder.

    Supported inputs:

    * objects with a ``todict()`` method -> the returned dict, with an
      extra ``'__ase_objtype__'`` entry when the object has an
      ``ase_objtype`` attribute;
    * ``np.ndarray`` -> ``{'__ndarray__': (shape, str(dtype), flat list)}``
      (complex arrays are stored as a flat list of real numbers);
    * numpy integer / floating / bool scalars -> ``int``/``float``/``bool``;
    * ``datetime.datetime`` -> ``{'__datetime__': obj.isoformat()}``;
    * ``complex`` -> ``{'__complex__': (real, imag)}``.

    Raises:
        RuntimeError: if ``obj.todict()`` does not return a dict.
        TypeError: if *obj* is of none of the supported types.
    """
    if hasattr(obj, 'todict'):
        dct = obj.todict()

        if not isinstance(dct, dict):
            raise RuntimeError('todict() of {} returned object of type {} '
                               'but should have returned dict'
                               .format(obj, type(dct)))
        if hasattr(obj, 'ase_objtype'):
            # We modify the dictionary, so it is wise to take a copy.
            dct = dct.copy()
            dct['__ase_objtype__'] = obj.ase_objtype

        return dct
    if isinstance(obj, np.ndarray):
        flatobj = obj.ravel()
        if np.iscomplexobj(obj):
            # Reinterpret the complex buffer as real numbers.  A view is
            # used rather than assigning to .dtype, which numpy discourages.
            flatobj = flatobj.view(obj.real.dtype)
        # We use str(obj.dtype) here instead of obj.dtype.name, because
        # they are not always the same (e.g. for numpy arrays of strings).
        # Using obj.dtype.name can break the ability to recursively decode/
        # encode such arrays.
        return {'__ndarray__': (obj.shape,
                                str(obj.dtype),
                                flatobj.tolist())}
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        # E.g. np.float32, which (unlike np.float64) is not a Python float
        # and would otherwise be rejected below.
        return float(obj)
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, datetime.datetime):
        return {'__datetime__': obj.isoformat()}
    if isinstance(obj, complex):
        return {'__complex__': (obj.real, obj.imag)}

    raise TypeError(f'Cannot convert object of type {type(obj)} to '
                    'dictionary for JSON')
class MyEncoder(json.JSONEncoder):
    """JSON encoder that hands unsupported objects to module-level default()."""

    def default(self, obj):
        """Return a JSON-serializable replacement for *obj*."""
        # Inside a method body the bare name "default" is looked up in the
        # module namespace, so this calls the function, not this method.
        converted = default(obj)
        return converted
# Module-level encoding function: the bound encode() method of a single
# MyEncoder instance created at import time.
encode = MyEncoder().encode
def object_hook(dct):
    """Turn specially tagged dicts back into Python/numpy/ASE objects.

    Intended for use as the ``object_hook`` of a JSON decoder.  The first
    matching tag decides the result:

    * ``'__datetime__'`` -> ``datetime.datetime``
    * ``'__complex__'`` -> ``complex``
    * ``'__ndarray__'`` -> result of ``create_ndarray()``
    * ``'__complex_ndarray__'`` -> complex array (legacy format)
    * ``'__ase_objtype__'`` -> result of ``create_ase_object()``

    A dict without any of these keys is returned unchanged.

    Raises:
        ValueError: if a ``'__datetime__'`` string cannot be parsed.
    """
    if '__datetime__' in dct:
        stamp = dct['__datetime__']
        # datetime.isoformat() drops the fractional part when microseconds
        # are zero and appends a UTC offset for aware datetimes; a fixed
        # '...%S.%f' pattern rejects both.  fromisoformat() is the inverse
        # of isoformat() and handles all of these.
        try:
            return datetime.datetime.fromisoformat(stamp)
        except ValueError:
            # Fall back to the historical pattern for strings that only
            # strptime() accepts.
            return datetime.datetime.strptime(stamp,
                                              '%Y-%m-%dT%H:%M:%S.%f')

    if '__complex__' in dct:
        return complex(*dct['__complex__'])

    if '__ndarray__' in dct:
        return create_ndarray(*dct['__ndarray__'])

    # No longer used (only here for backwards compatibility):
    if '__complex_ndarray__' in dct:
        r, i = (np.array(x) for x in dct['__complex_ndarray__'])
        return r + i * 1j

    if '__ase_objtype__' in dct:
        # The tag is removed so that only the object's own data remains.
        objtype = dct.pop('__ase_objtype__')
        dct = numpyfy(dct)
        return create_ase_object(objtype, dct)

    return dct
def create_ndarray(shape, dtype, data):
    """Create ndarray from shape, dtype and flattened data.

    Args:
        shape: shape of the array to create.
        dtype: dtype of the array, in any form ``np.empty`` accepts.
        data: flat sequence of values.  For a complex dtype the values are
            assigned through a real-valued view of the array's buffer, so
            *data* must hold the real numbers making up that buffer.

    Returns:
        The newly allocated, filled array.
    """
    array = np.empty(shape, dtype=dtype)
    # ravel() of a freshly allocated array is a view, so writing through
    # flatbuf fills ``array`` itself.
    flatbuf = array.ravel()
    if np.iscomplexobj(array):
        # Reinterpret the buffer as real numbers.  A view is used rather
        # than assigning to .dtype, which numpy discourages.
        flatbuf = flatbuf.view(array.real.dtype)
    flatbuf[:] = data
    return array
def create_ase_object(objtype, dct):
    """Build the ASE object identified by *objtype* from the data in *dct*.

    Recognized values of *objtype* are ``'cell'``, ``'bandstructure'``,
    ``'bandpath'``, ``'atoms'`` and ``'vibrationsdata'``.  Note that *dct*
    may be modified in the process (keys are popped from it).

    Raises:
        ValueError: if *objtype* is not one of the recognized names.
    """
    known_types = ('cell', 'bandstructure', 'bandpath', 'atoms',
                   'vibrationsdata')
    if objtype not in known_types:
        raise ValueError('Do not know how to decode object type {} '
                         'into an actual object'.format(objtype))

    # Each kind is instantiated by hand; the class is imported only in the
    # branch that needs it.  This can be formalized later if necessary.
    if objtype == 'cell':
        from ase.cell import Cell
        # Compatibility: older data also stored pbc alongside the cell.
        dct.pop('pbc', None)
        instance = Cell(**dct)
    elif objtype == 'bandstructure':
        from ase.spectrum.band_structure import BandStructure
        instance = BandStructure(**dct)
    elif objtype == 'bandpath':
        from ase.dft.kpoints import BandPath
        # The stored 'labelseq' entry is passed on as the 'path' argument.
        labelseq = dct.pop('labelseq')
        instance = BandPath(path=labelseq, **dct)
    elif objtype == 'atoms':
        from ase import Atoms
        instance = Atoms.fromdict(dct)
    else:
        from ase.vibrations import VibrationsData
        instance = VibrationsData.fromdict(dct)

    # Sanity check: the created object must report the requested type.
    assert instance.ase_objtype == objtype
    return instance
# Module-level decoding function: decode() of a single JSONDecoder that
# passes every decoded JSON object (dict) through object_hook().
mydecode = json.JSONDecoder(object_hook=object_hook).decode
def intkey(key):
    """Return *key* converted to int if possible, else *key* unchanged."""
    try:
        number = int(key)
    except ValueError:
        # Not an integer literal; hand the key back untouched.
        return key
    return number
def fix_int_keys_in_dicts(obj):
    """Turn integer-like str keys back into ints: "1" -> 1.

    The json.dump() function writes int dict keys as str keys.  This
    function reverses that for *obj* and for dicts nested as dict values.
    Anything that is not a dict is returned unchanged.
    """
    if not isinstance(obj, dict):
        return obj

    fixed = {}
    for key, value in obj.items():
        new_key = intkey(key)
        fixed[new_key] = fix_int_keys_in_dicts(value)
    return fixed
def numpyfy(obj):
    """Convert suitable lists (and legacy complex dicts) to numpy arrays.

    * A dict with a ``'__complex_ndarray__'`` key becomes a complex array
      built from its (real, imaginary) parts.
    * A non-empty list that numpy turns into an array of dtype bool, int
      or float is returned as that array.
    * Any other non-empty list is processed item by item, recursively.
    * Everything else is returned unchanged.
    """
    if isinstance(obj, dict) and '__complex_ndarray__' in obj:
        real, imag = (np.array(part) for part in obj['__complex_ndarray__'])
        return real + imag * 1j

    if not isinstance(obj, list) or len(obj) == 0:
        return obj

    try:
        candidate = np.array(obj)
    except ValueError:
        # numpy could not build an array from this list.
        candidate = None

    if candidate is not None and candidate.dtype in (bool, int, float):
        return candidate

    # Not a plain numeric/bool array: convert the items individually.
    return [numpyfy(item) for item in obj]
def decode(txt, always_array=True):
    """Decode the JSON string *txt* into Python objects.

    The text is parsed with mydecode(), then integer-like dict keys are
    restored with fix_int_keys_in_dicts().  If *always_array* is true, the
    result is finally passed through numpyfy().
    """
    decoded = fix_int_keys_in_dicts(mydecode(txt))
    return numpyfy(decoded) if always_array else decoded
@reader
def read_json(fd, always_array=True):
    """Read all of *fd* and return its decoded JSON content.

    *always_array* is forwarded to decode().
    """
    return decode(fd.read(), always_array=always_array)
@writer
def write_json(fd, obj):
    """Encode *obj* with encode() and write the resulting text to *fd*."""
    text = encode(obj)
    fd.write(text)