FLAME  devel
 All Classes Functions Variables Typedefs Enumerations Pages
modstate.cpp
1 
2 #include <string>
3 #include <sstream>
4 
5 #include "flame/base.h"
6 #include "pyflame.h"
7 
8 #define NO_IMPORT_ARRAY
9 #define PY_ARRAY_UNIQUE_SYMBOL FLAME_PyArray_API
10 #include <numpy/ndarrayobject.h>
11 
12 #if SIZE_MAX==NPY_MAX_UINT32
13 #define NPY_SIZE_T NPY_UINT32
14 #elif SIZE_MAX==NPY_MAX_UINT64
15 #define NPY_SIZE_T NPY_UINT64
16 #else
17 #error logic error with SIZE_MAX
18 #endif
19 
// Binds 'state' to the PyState behind the local 'raw' pointer and opens a try block.
// Every use in this file is closed by one of the CATCH*() macros.
#define TRY PyState *state = reinterpret_cast<PyState*>(raw); try
21 
22 namespace {
23 
24 struct PyState {
25  PyObject_HEAD
26  PyObject *dict, *weak; // __dict__ and __weakref__
27  PyObject *attrs; // lookup name to attribute index (for StateBase)
28  StateBase *state;
29 };
30 
31 static
32 int PyState_traverse(PyObject *raw, visitproc visit, void *arg)
33 {
34  PyState *state = (PyState*)raw;
35  Py_VISIT(state->attrs);
36  Py_VISIT(state->dict);
37  return 0;
38 }
39 
40 static
41 int PyState_clear(PyObject *raw)
42 {
43  PyState *state = (PyState*)raw;
44  Py_CLEAR(state->dict);
45  Py_CLEAR(state->attrs);
46  return 0;
47 }
48 
49 static
50 void PyState_free(PyObject *raw)
51 {
52  TRY {
53  std::auto_ptr<StateBase> S(state->state);
54  state->state = NULL;
55 
56  if(state->weak)
57  PyObject_ClearWeakRefs(raw);
58 
59  PyState_clear(raw);
60 
61  Py_TYPE(raw)->tp_free(raw);
62  } CATCH2V(std::exception, RuntimeError)
63 }
64 
65 static
66 PyObject *PyState_getattro(PyObject *raw, PyObject *attr)
67 {
68  TRY {
69  PyObject *idx = PyDict_GetItem(state->attrs, attr);
70  if(!idx) {
71  return PyObject_GenericGetAttr(raw, attr);
72  }
73  int i = PyInt_AsLong(idx);
74 
75 
77 
78  if(!state->state->getArray(i, info))
79  return PyErr_Format(PyExc_RuntimeError, "invalid attribute name (sub-class forgot %d)", i);
80 
81  if(info.ndim==0) { // Scalar
82  switch(info.type) {
83  case StateBase::ArrayInfo::Double:
84  return PyFloat_FromDouble(*(double*)info.ptr);
85  case StateBase::ArrayInfo::Sizet:
86  return PyLong_FromSize_t(*(size_t*)info.ptr);
87  }
88  return PyErr_Format(PyExc_TypeError, "unsupported type code %d", info.type);
89  }
90 
91  int pytype;
92  switch(info.type) {
93  case StateBase::ArrayInfo::Double: pytype = NPY_DOUBLE; break;
94  case StateBase::ArrayInfo::Sizet: pytype = NPY_SIZE_T; break;
95  default:
96  return PyErr_Format(PyExc_TypeError, "unsupported type code %d", info.type);
97  }
98 
99  npy_intp dims[StateBase::ArrayInfo::maxdims];
100  std::copy(info.dim,
101  info.dim+StateBase::ArrayInfo::maxdims,
102  dims);
103 
104  // Alloc new array and copy in
105 
106  PyRef<PyArrayObject> obj(PyArray_SimpleNew(info.ndim, dims, pytype));
107 
108  // pull parts from PyArray into ArrayInfo so we can use ArrayInfo::raw() to access
109  StateBase::ArrayInfo pyinfo;
110  pyinfo.ptr = PyArray_BYTES(obj.py());
111  pyinfo.ndim= PyArray_NDIM(obj.get());
112  std::copy(PyArray_DIMS(obj.get()),
113  PyArray_DIMS(obj.get())+pyinfo.ndim,
114  pyinfo.dim);
115  std::copy(PyArray_STRIDES(obj.get()),
116  PyArray_STRIDES(obj.get())+pyinfo.ndim,
117  pyinfo.stride);
118 
120 
121  for(; !idxiter.done; idxiter.next()) {
122  void *dest = pyinfo.raw(idxiter.index);
123  const void *src = info .raw(idxiter.index);
124 
125  switch(info.type) {
126  case StateBase::ArrayInfo::Double: *(double*)dest = *(double*)src; break;
127  case StateBase::ArrayInfo::Sizet: *(size_t*)dest = *(size_t*)src; break;
128  }
129  }
130 
131  return obj.releasePy();
132  } CATCH()
133 }
134 
135 static
136 int PyState_setattro(PyObject *raw, PyObject *attr, PyObject *val)
137 {
138  TRY {
139  PyObject *idx = PyDict_GetItem(state->attrs, attr);
140  if(!idx)
141  return PyObject_GenericSetAttr(raw, attr, val);
142  int i = PyInt_AsLong(idx);
143 
145 
146  if(!state->state->getArray(i, info)) {
147  PyErr_Format(PyExc_RuntimeError, "invalid attribute name (sub-class forgot %d)", i);
148  return -1;
149  }
150 
151  if(info.ndim==0) {
152  // Scalar (use python primative types)
153 
154  switch(info.type) {
155  case StateBase::ArrayInfo::Double: {
156  double *dest = (double*)info.ptr;
157  if(PyFloat_Check(val))
158  *dest = PyFloat_AsDouble(val);
159  else if(PyLong_Check(val))
160  *dest = PyLong_AsDouble(val);
161  else if(PyInt_Check(val))
162  *dest = PyInt_AsLong(val);
163  else
164  PyErr_Format(PyExc_ValueError, "Can't assign to double field");
165  }
166  break;
167  case StateBase::ArrayInfo::Sizet: {
168  size_t *dest = (size_t*)info.ptr;
169  if(PyFloat_Check(val))
170  *dest = PyFloat_AsDouble(val);
171  else if(PyLong_Check(val))
172  *dest = PyLong_AsUnsignedLongLong(val);
173  else if(PyInt_Check(val))
174  *dest = PyInt_AsLong(val);
175  else
176  PyErr_Format(PyExc_ValueError, "Can't assign to double field");
177  }
178  break;
179  default:
180  PyErr_Format(PyExc_TypeError, "unsupported type code %d", info.type);
181  }
182 
183  return PyErr_Occurred() ? -1 : 0;
184  }
185  // array (use numpy)
186 
187  int pytype;
188  switch(info.type) {
189  case StateBase::ArrayInfo::Double: pytype = NPY_DOUBLE; break;
190  case StateBase::ArrayInfo::Sizet: pytype = NPY_SIZE_T; break;
191  default:
192  PyErr_Format(PyExc_TypeError, "unsupported type code %d", info.type);
193  return -1;
194  }
195 
196  // ValueError: object too deep for desired array
197  // means assignment with wrong cardinality
198  PyRef<PyArrayObject> arr(PyArray_FromObject(val, pytype, info.ndim, info.ndim));
199 
200  if(info.ndim!=(size_t)PyArray_NDIM(arr.py())) {
201  PyErr_Format(PyExc_ValueError, "cardinality don't match");
202  return -1;
203  } else if(!std::equal(info.dim, info.dim+info.ndim,
204  PyArray_DIMS(arr.py()))) {
205  PyErr_Format(PyExc_ValueError, "shape does not match don't match");
206  return -1;
207  }
208 
209  // pull parts from PyArray into ArrayInfo so we can use ArrayInfo::raw() to access
210  StateBase::ArrayInfo pyinfo;
211  pyinfo.ptr = PyArray_BYTES(arr.py());
212  pyinfo.ndim= PyArray_NDIM(arr.get());
213  std::copy(PyArray_DIMS(arr.get()),
214  PyArray_DIMS(arr.get())+pyinfo.ndim,
215  pyinfo.dim);
216  std::copy(PyArray_STRIDES(arr.get()),
217  PyArray_STRIDES(arr.get())+pyinfo.ndim,
218  pyinfo.stride);
219 
221 
222  for(; !idxiter.done; idxiter.next()) {
223  const void *src = pyinfo .raw(idxiter.index);
224  void *dest = info.raw(idxiter.index);
225 
226  switch(info.type) {
227  case StateBase::ArrayInfo::Double: *(double*)dest = *(double*)src; break;
228  case StateBase::ArrayInfo::Sizet: *(size_t*)dest = *(size_t*)src; break;
229  }
230  }
231 
232 
233  if(info.ndim==1) {
234  for(size_t i=0; i<info.dim[0]; i++) {
235  const void *src = PyArray_GETPTR1(arr.py(), i);
236  void *dest = info.raw(&i);
237  switch(info.type) {
238  case StateBase::ArrayInfo::Double: *(double*)dest = *(double*)src; break;
239  case StateBase::ArrayInfo::Sizet: *(size_t*)dest = *(size_t*)src; break;
240  }
241  }
242  } else if(info.ndim==2) {
243  size_t idx[2];
244  for(idx[0]=0; idx[0]<info.dim[0]; idx[0]++) {
245  for(idx[1]=0; idx[1]<info.dim[1]; idx[1]++) {
246  const void *src = PyArray_GETPTR2(arr.py(), idx[0], idx[1]);
247  void *dest = info.raw(idx);
248  switch(info.type) {
249  case StateBase::ArrayInfo::Double: *(double*)dest = *(double*)src; break;
250  case StateBase::ArrayInfo::Sizet: *(size_t*)dest = *(size_t*)src; break;
251  }
252  }
253  }
254  }
255 
256  return 0;
257  } CATCH3(std::exception, RuntimeError, -1)
258 }
259 
260 static
261 PyObject* PyState_str(PyObject *raw)
262 {
263  TRY {
264  std::ostringstream strm;
265  state->state->show(strm, 0);
266  return PyString_FromString(strm.str().c_str());
267  } CATCH()
268 }
269 
270 static
271 PyObject* PyState_iter(PyObject *raw)
272 {
273  TRY {
274  return PyObject_GetIter(state->attrs);
275  }CATCH()
276 }
277 
278 static
279 Py_ssize_t PyState_len(PyObject *raw)
280 {
281  TRY{
282  return PyObject_Length(state->attrs);
283  }CATCH1(-1)
284 }
285 
286 static PySequenceMethods PyState_seq = {
287  &PyState_len
288 };
289 
290 static
291 PyObject* PyState_clone(PyObject *raw, PyObject *unused)
292 {
293  TRY {
294  std::auto_ptr<StateBase> newstate(state->state->clone());
295 
296  PyObject *ret = wrapstate(newstate.get());
297  newstate.release();
298  return ret;
299  } CATCH()
300 }
301 
302 static
303 PyObject* PyState_show(PyObject *raw, PyObject *args, PyObject *kws)
304 {
305  TRY {
306  unsigned long level = 1;
307  const char *names[] = {"level", NULL};
308  if(!PyArg_ParseTupleAndKeywords(args, kws, "|k", (char**)names, &level))
309  return NULL;
310 
311  std::ostringstream strm;
312  state->state->show(strm, level);
313  return PyString_FromString(strm.str().c_str());
314  } CATCH()
315 }
316 
317 static PyMethodDef PyState_methods[] = {
318  {"clone", (PyCFunction)&PyState_clone, METH_NOARGS,
319  "clone()\n\n"
320  "Returns a new State instance which is a copy of this one"
321  },
322  {"show", (PyCFunction)&PyState_show, METH_VARARGS|METH_KEYWORDS,
323  "show(level=1)"
324  },
325  {NULL, NULL, 0, NULL}
326 };
327 
328 static PyTypeObject PyStateType = {
329 #if PY_MAJOR_VERSION >= 3
330  PyVarObject_HEAD_INIT(NULL, 0)
331 #else
332  PyObject_HEAD_INIT(NULL)
333  0,
334 #endif
335  "flame._internal.State",
336  sizeof(PyState),
337 };
338 
339 } // namespace
340 
341 PyObject* wrapstate(StateBase* b)
342 {
343  try {
344 
345  PyRef<PyState> state(PyStateType.tp_alloc(&PyStateType, 0));
346 
347  state->state = b;
348  state->attrs = state->weak = state->dict = 0;
349 
350  state->attrs = PyDict_New();
351  if(!state->attrs)
352  return NULL;
353 
354  for(unsigned i=0; true; i++)
355  {
357 
358  if(!b->getArray(i, info))
359  break;
360 
361  bool skip = info.ndim>3;
362  switch(info.type) {
363  case StateBase::ArrayInfo::Double:
364  case StateBase::ArrayInfo::Sizet:
365  break;
366  default:
367  skip = true;
368  }
369 
370  if(skip) continue;
371 
372  PyRef<> name(PyInt_FromLong(i));
373  if(PyDict_SetItemString(state->attrs, info.name, name.py()))
374  throw std::runtime_error("Failed to insert into Dict");
375 
376  }
377 
378  return state.releasePy();
379  } CATCH()
380 }
381 
382 
383 StateBase* unwrapstate(PyObject* raw)
384 {
385  if(!PyObject_TypeCheck(raw, &PyStateType))
386  throw std::invalid_argument("Argument is not a State");
387  PyState *state = (PyState*)raw;
388  return state->state;
389 }
390 
// Docstring of the State type (installed as tp_doc).
static const char pymdoc[] =
    "The interface to a sub-class of C++ StateBase.\n"
    "Can't be constructed from python, see Machine.allocState()\n"
    "\n"
    "Provides access to some C++ member variables via the StateBase::getArray() interface.\n"
    ;
397 
398 int registerModState(PyObject *mod)
399 {
400  PyStateType.tp_doc = pymdoc;
401 
402  PyStateType.tp_str = &PyState_str;
403  PyStateType.tp_repr = &PyState_str;
404  PyStateType.tp_dealloc = &PyState_free;
405 
406  PyStateType.tp_iter = &PyState_iter;
407  PyStateType.tp_as_sequence = &PyState_seq;
408 
409  PyStateType.tp_weaklistoffset = offsetof(PyState, weak);
410  PyStateType.tp_traverse = &PyState_traverse;
411  PyStateType.tp_clear = &PyState_clear;
412 
413  PyStateType.tp_dictoffset = offsetof(PyState, dict);
414  PyStateType.tp_getattro = &PyState_getattro;
415  PyStateType.tp_setattro = &PyState_setattro;
416 
417  PyStateType.tp_flags = Py_TPFLAGS_DEFAULT|Py_TPFLAGS_BASETYPE|Py_TPFLAGS_HAVE_GC;
418  PyStateType.tp_methods = PyState_methods;
419 
420  if(PyType_Ready(&PyStateType))
421  return -1;
422 
423  Py_INCREF(&PyStateType);
424  if(PyModule_AddObject(mod, "State", (PyObject*)&PyStateType)) {
425  Py_DECREF(&PyStateType);
426  return -1;
427  }
428 
429  return 0;
430 }
The abstract base class for all simulation state objects.
Definition: base.h:28
Used with StateBase::getArray() to describe a single parameter.
Definition: base.h:48
Helper to step through the indices of an Nd array.
Definition: util.h:111
size_t dim[maxdims]
Array dimensions in elements.
Definition: base.h:66
size_t stride[maxdims]
Array strides in bytes.
Definition: base.h:68
virtual bool getArray(unsigned index, ArrayInfo &Info)
Introspect named parameter of the derived class.
Definition: base.cpp:35
const char * name
The parameter name.
Definition: base.h:52
unsigned ndim
Definition: base.h:64