Skip to content

Commit c64523c

Browse files
committed
Mostly docstrings of IO
1 parent 7b630b4 commit c64523c

File tree

5 files changed

+90
-56
lines changed

5 files changed

+90
-56
lines changed

mpi4py_fft/distarray.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,8 @@ def write(self, filename, name='darray', step=0, global_slice=None,
397397
f.write(step, {name: field}, as_scalar=as_scalar)
398398

399399
def read(self, filename, name='darray', step=0):
400-
"""Read from file ``filename`` into array ``self``
400+
"""Read data ``name`` at index ``step``from file ``filename`` into
401+
``self``
401402
402403
Note
403404
----
@@ -425,7 +426,7 @@ def read(self, filename, name='darray', step=0):
425426
"""
426427
if isinstance(filename, str):
427428
writer = HDF5File if filename.endswith('.h5') else NCFile
428-
f = writer(filename, u=self, mode='r')
429+
f = writer(filename, mode='r')
429430
elif isinstance(filename, FileBase):
430431
f = filename
431432
f.read(self, name, step=step)
@@ -438,8 +439,8 @@ def newDistArray(pfft, forward_output=True, val=0, rank=0, view=False):
438439
----------
439440
pfft : :class:`.PFFT` object
440441
forward_output: boolean, optional
441-
If False then create newDistArray of shape/type for input to
442-
forward transform, otherwise create newDistArray of shape/type for
442+
If False then create DistArray of shape/type for input to
443+
forward transform, otherwise create DistArray of shape/type for
443444
output from forward transform.
444445
val : int or float, optional
445446
Value used to initialize array.

mpi4py_fft/io/file_base.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ class FileBase(object):
1010
1111
Parameters
1212
----------
13+
filename : str, optional
14+
Name of backend file used to store data
1315
domain : sequence, optional
1416
An optional spatial mesh or domain to go with the data.
1517
Sequence of either
@@ -18,14 +20,15 @@ class FileBase(object):
1820
of each dimension, e.g., (0, 2*pi).
1921
- Arrays of coordinates, e.g., np.linspace(0, 2*pi, N). One
2022
array per dimension.
23+
2124
"""
22-
def __init__(self, domain=None, **kw):
25+
def __init__(self, filename=None, domain=None):
2326
self.f = None
24-
self.filename = None
27+
self.filename = filename
2528
self.domain = domain
2629

2730
def _check_domain(self, group, field):
28-
"""Write domain to file"""
31+
"""Check dimensions and store (if missing) self.domain"""
2932
raise NotImplementedError
3033

3134
def write(self, step, fields, **kw):
@@ -75,25 +78,35 @@ def _write(group, u, sl, step, kw, k=None):
7578
_write(g, u[k, l], sl, step, kw)
7679

7780
def read(self, u, name, **kw):
78-
"""Read into array ``u``
81+
"""Read field ``name`` into distributed array ``u``
7982
8083
Parameters
8184
----------
8285
u : array
83-
The array to read into.
86+
The :class:`.DistArray` to read into.
8487
name : str
85-
Name of array to be read.
88+
Name of field to be read.
89+
step : int, optional
90+
Index of field to be read. Default is 0.
8691
"""
8792
raise NotImplementedError
8893

8994
def close(self):
9095
self.f.close()
9196

92-
def open(self):
97+
def open(self, mode='r+'):
98+
"""Open the self.filename file for reading or writing
99+
100+
Parameters
101+
----------
102+
mode : str
103+
Open file in this mode. Default is 'r+'.
104+
"""
93105
raise NotImplementedError
94106

95107
@staticmethod
96108
def backend():
109+
"""Return which backend is used to store data"""
97110
raise NotImplementedError
98111

99112
def _write_slice_step(self, name, step, slices, field, **kwargs):

mpi4py_fft/io/h5py_file.py

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,17 @@ class HDF5File(FileBase):
2323
array per dimension.
2424
mode : str, optional
2525
``r``, ``w`` or ``a`` for read, write or append. Default is ``a``.
26+
kw : dict, optional
27+
Optional additional keyword arguments used when creating the file
28+
used to store data.
2629
"""
2730
def __init__(self, h5name, domain=None, mode='a', **kw):
28-
FileBase.__init__(self, domain=domain, **kw)
31+
FileBase.__init__(self, h5name, domain=domain)
2932
import h5py
30-
self.filename = h5name
31-
self.f = h5py.File(h5name, mode, driver="mpio", comm=comm)
33+
self.f = h5py.File(h5name, mode, driver="mpio", comm=comm, **kw)
3234
self.close()
3335

3436
def _check_domain(self, group, field):
35-
"""Check dimensions of domain and write to file"""
3637
if self.domain is None:
3738
self.domain = ((0, 2*np.pi),)*field.dimensions
3839
assert len(self.domain) == field.dimensions
@@ -66,9 +67,9 @@ def _check_domain(self, group, field):
6667
def backend():
6768
return 'hdf5'
6869

69-
def open(self):
70+
def open(self, mode='r+'):
7071
import h5py
71-
self.f = h5py.File(self.filename, 'r+', driver="mpio", comm=comm)
72+
self.f = h5py.File(self.filename, mode, driver="mpio", comm=comm)
7273

7374
def write(self, step, fields, **kw):
7475
"""Write snapshot ``step`` of ``fields`` to HDF5 file
@@ -82,6 +83,8 @@ def write(self, step, fields, **kw):
8283
and either arrays or 2-tuples, respectively. The arrays are complete
8384
arrays to be stored, whereas 2-tuples are arrays with associated
8485
*global* slices.
86+
as_scalar : boolean, optional
87+
Whether to store rank > 0 arrays as scalars. Default is False.
8588
8689
Example
8790
-------
@@ -116,17 +119,6 @@ def write(self, step, fields, **kw):
116119
self.close()
117120

118121
def read(self, u, name, **kw):
119-
"""Read from file ``self`` into array ``u``
120-
121-
Parameters
122-
----------
123-
u : array
124-
The array to read into.
125-
name : str
126-
Name of array to be read.
127-
step : int, optional
128-
Index of field to be read. Default is 0.
129-
"""
130122
step = kw.get('step', 0)
131123
self.open()
132124
s = u.local_slice()

mpi4py_fft/io/nc_file.py

Lines changed: 53 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -29,19 +29,24 @@ class NCFile(FileBase):
2929
mode : str
3030
``r``, ``w`` or ``a`` for read, write or append. Default is ``a``.
3131
clobber : bool, optional
32+
If True (default), opening a file with mode='w' will clobber an
33+
existing file with the same name. If False, an exception will be
34+
raised if a file with the same name already exists.
35+
kw : dict, optional
36+
Optional additional keyword arguments used when creating the backend
37+
file.
3238
3339
Note
3440
----
3541
Each class instance creates one unique NetCDF4-file, with one step-counter.
3642
It is possible to store multiple fields in each file, but all snapshots of
3743
the fields must be taken at the same time. If you want one field stored
3844
every 10th timestep and another every 20th timestep, then use two different
39-
class instances and as such two NetCDF4-files.
45+
class instances with two different filenames ``ncname``.
4046
"""
4147
def __init__(self, ncname, domain=None, mode='a', clobber=True, **kw):
42-
FileBase.__init__(self, domain=domain, **kw)
48+
FileBase.__init__(self, ncname, domain=domain)
4349
from netCDF4 import Dataset
44-
self.filename = ncname
4550
# netCDF4 does not seem to handle 'a' if the file does not already exist
4651
if mode == 'a' and not os.path.exists(ncname):
4752
mode = 'w'
@@ -51,11 +56,9 @@ def __init__(self, ncname, domain=None, mode='a', clobber=True, **kw):
5156
if not 'time' in self.f.variables:
5257
self.f.createDimension('time', None)
5358
self.f.createVariable('time', np.float, ('time'))
54-
5559
self.close()
5660

57-
def _check_domain(self, write_domain, field):
58-
"""Check dimensions of domain and write to file if missing"""
61+
def _check_domain(self, group, field):
5962
N = field.global_shape[field.rank:]
6063
if self.domain is None:
6164
self.domain = []
@@ -92,22 +95,59 @@ def _check_domain(self, write_domain, field):
9295
def backend():
9396
return 'netcdf4'
9497

95-
def open(self):
98+
def open(self, mode='r+'):
9699
from netCDF4 import Dataset
97-
self.f = Dataset(self.filename, mode='r+', parallel=True, comm=comm)
100+
self.f = Dataset(self.filename, mode=mode, parallel=True, comm=comm)
98101

99102
def write(self, step, fields, **kw):
100-
"""Write snapshot step of ``fields`` to NetCDF4 file
103+
"""Write snapshot ``step`` of ``fields`` to NetCDF4 file
101104
102105
Parameters
103106
----------
104107
step : int
105-
Index of snapshot
108+
Index of snapshot.
106109
fields : dict
107110
The fields to be dumped to file. (key, value) pairs are group name
108111
and either arrays or 2-tuples, respectively. The arrays are complete
109112
arrays to be stored, whereas 2-tuples are arrays with associated
110113
*global* slices.
114+
as_scalar : boolean, optional
115+
Whether to store rank > 0 arrays as scalars. Default is False.
116+
117+
Example
118+
-------
119+
>>> from mpi4py import MPI
120+
>>> from mpi4py_fft import PFFT, NCFile, newDistArray
121+
>>> comm = MPI.COMM_WORLD
122+
>>> T = PFFT(comm, (15, 16, 17))
123+
>>> u = newDistArray(T, forward_output=False, val=1)
124+
>>> v = newDistArray(T, forward_output=False, val=2)
125+
>>> f = NCFile('ncfilename.nc', mode='w')
126+
>>> f.write(0, {'u': [u, (u, [slice(None), 4, slice(None)])],
127+
... 'v': [v, (v, [slice(None), 5, 5])]})
128+
>>> f.write(1, {'u': [u, (u, [slice(None), 4, slice(None)])],
129+
... 'v': [v, (v, [slice(None), 5, 5])]})
130+
>>> f.close()
131+
132+
This stores the following datasets to the file ``ncfilename.nc``.
133+
Using in a terminal 'ncdump -h ncfilename.nc', one gets::
134+
135+
netcdf ncfilename {
136+
dimensions:
137+
time = UNLIMITED ; // (2 currently)
138+
x = 15 ;
139+
y = 16 ;
140+
z = 17 ;
141+
variables:
142+
double time(time) ;
143+
double x(x) ;
144+
double y(y) ;
145+
double z(z) ;
146+
double u(time, x, y, z) ;
147+
double u_slice_4_slice(time, x, z) ;
148+
double v(time, x, y, z) ;
149+
double v_slice_5_5(time, x) ;
150+
}
111151
112152
"""
113153
self.open()
@@ -122,17 +162,6 @@ def write(self, step, fields, **kw):
122162
self.close()
123163

124164
def read(self, u, name, **kw):
125-
"""Read into array ``u``
126-
127-
Parameters
128-
----------
129-
u : array
130-
The array to read into.
131-
name : str
132-
Name of array to be read.
133-
step : int, optional
134-
Index of field to be read. Default is 0.
135-
"""
136165
step = kw.get('step', 0)
137166
self.open()
138167
s = u.local_slice()
@@ -141,10 +170,9 @@ def read(self, u, name, **kw):
141170
self.close()
142171

143172
def _write_slice_step(self, name, step, slices, field, **kw):
144-
assert name not in self.dims
173+
assert name not in self.dims # Crashes if user tries to name fields x, y, z, .
145174
rank = field.rank
146-
slices = (slice(None),)*rank + tuple(slices)
147-
slices = list(slices)
175+
slices = list((slice(None),)*rank + tuple(slices))
148176
slname = self._get_slice_name(slices[rank:])
149177
s = field.local_slice()
150178
slices, inside = self._get_local_slices(slices, s)
@@ -168,7 +196,7 @@ def _write_slice_step(self, name, step, slices, field, **kw):
168196
self.f.sync()
169197

170198
def _write_group(self, name, u, step, **kw):
171-
assert name not in self.dims
199+
assert name not in self.dims # Crashes if user tries to name fields x, y, z, .
172200
s = u.local_slice()
173201
if name not in self.f.variables:
174202
h = self.f.createVariable(name, u.dtype, self.dims)

tests/test_io.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def test_2D(backend, forward_output):
5151
generate_xdmf(filename, order='visit')
5252

5353
u0 = newDistArray(T, forward_output=forward_output, rank=rank)
54-
read = reader[backend](filename, u=u0)
54+
read = reader[backend](filename)
5555
read.read(u0, 'u', step=0)
5656
u0.read(filename, 'u', 2)
5757
u0.read(read, 'u', 2)
@@ -117,7 +117,7 @@ def test_3D(backend, forward_output):
117117
generate_xdmf('v'+filename, order='visit')
118118

119119
u0 = newDistArray(T, forward_output=forward_output, rank=rank)
120-
read = reader[backend]('uv'+filename, u=u0)
120+
read = reader[backend]('uv'+filename)
121121
read.read(u0, 'u', step=0)
122122
assert np.allclose(u0, u)
123123
read.read(u0, 'v', step=0)
@@ -153,7 +153,7 @@ def test_4D(backend, forward_output):
153153
generate_xdmf('uv'+filename)
154154

155155
u0 = newDistArray(T, forward_output=forward_output, rank=rank)
156-
read = reader[backend]('uv'+filename, u=u0)
156+
read = reader[backend]('uv'+filename)
157157
read.read(u0, 'u', step=0)
158158
assert np.allclose(u0, u)
159159
read.read(u0, 'v', step=0)

0 commit comments

Comments
 (0)