5 #include "flame/base.h"
8 #define NO_IMPORT_ARRAY
9 #define PY_ARRAY_UNIQUE_SYMBOL FLAME_PyArray_API
10 #include <numpy/ndarrayobject.h>
12 #if SIZE_MAX==NPY_MAX_UINT32
13 #define NPY_SIZE_T NPY_UINT32
14 #elif SIZE_MAX==NPY_MAX_UINT64
15 #define NPY_SIZE_T NPY_UINT64
17 #error logic error with SIZE_MAX
20 #define TRY PyState *state = (PyState*)raw; try
26 PyObject *dict, *weak;
32 int PyState_traverse(PyObject *raw, visitproc visit,
void *arg)
34 PyState *state = (PyState*)raw;
35 Py_VISIT(state->attrs);
36 Py_VISIT(state->dict);
41 int PyState_clear(PyObject *raw)
43 PyState *state = (PyState*)raw;
44 Py_CLEAR(state->dict);
45 Py_CLEAR(state->attrs);
50 void PyState_free(PyObject *raw)
53 std::auto_ptr<StateBase> S(state->state);
57 PyObject_ClearWeakRefs(raw);
61 Py_TYPE(raw)->tp_free(raw);
62 } CATCH2V(std::exception, RuntimeError)
66 PyObject *PyState_getattro(PyObject *raw, PyObject *attr)
69 PyObject *idx = PyDict_GetItem(state->attrs, attr);
71 return PyObject_GenericGetAttr(raw, attr);
73 int i = PyInt_AsLong(idx);
78 if(!state->state->getArray(i, info))
79 return PyErr_Format(PyExc_RuntimeError,
"invalid attribute name (sub-class forgot %d)", i);
83 case StateBase::ArrayInfo::Double:
84 return PyFloat_FromDouble(*(
double*)info.
ptr);
85 case StateBase::ArrayInfo::Sizet:
86 return PyLong_FromSize_t(*(
size_t*)info.
ptr);
88 return PyErr_Format(PyExc_TypeError,
"unsupported type code %d", info.type);
93 case StateBase::ArrayInfo::Double: pytype = NPY_DOUBLE;
break;
94 case StateBase::ArrayInfo::Sizet: pytype = NPY_SIZE_T;
break;
96 return PyErr_Format(PyExc_TypeError,
"unsupported type code %d", info.type);
99 npy_intp dims[StateBase::ArrayInfo::maxdims];
101 info.
dim+StateBase::ArrayInfo::maxdims,
106 PyRef<PyArrayObject> obj(PyArray_SimpleNew(info.
ndim, dims, pytype));
110 pyinfo.
ptr = PyArray_BYTES(obj.py());
111 pyinfo.
ndim= PyArray_NDIM(obj.get());
112 std::copy(PyArray_DIMS(obj.get()),
113 PyArray_DIMS(obj.get())+pyinfo.
ndim,
115 std::copy(PyArray_STRIDES(obj.get()),
116 PyArray_STRIDES(obj.get())+pyinfo.
ndim,
121 for(; !idxiter.done; idxiter.next()) {
122 void *dest = pyinfo.raw(idxiter.index);
123 const void *src = info .raw(idxiter.index);
126 case StateBase::ArrayInfo::Double: *(
double*)dest = *(
double*)src;
break;
127 case StateBase::ArrayInfo::Sizet: *(
size_t*)dest = *(
size_t*)src;
break;
131 return obj.releasePy();
136 int PyState_setattro(PyObject *raw, PyObject *attr, PyObject *val)
139 PyObject *idx = PyDict_GetItem(state->attrs, attr);
141 return PyObject_GenericSetAttr(raw, attr, val);
142 int i = PyInt_AsLong(idx);
146 if(!state->state->getArray(i, info)) {
147 PyErr_Format(PyExc_RuntimeError,
"invalid attribute name (sub-class forgot %d)", i);
155 case StateBase::ArrayInfo::Double: {
156 double *dest = (
double*)info.
ptr;
157 if(PyFloat_Check(val))
158 *dest = PyFloat_AsDouble(val);
159 else if(PyLong_Check(val))
160 *dest = PyLong_AsDouble(val);
161 else if(PyInt_Check(val))
162 *dest = PyInt_AsLong(val);
164 PyErr_Format(PyExc_ValueError,
"Can't assign to double field");
167 case StateBase::ArrayInfo::Sizet: {
168 size_t *dest = (
size_t*)info.
ptr;
169 if(PyFloat_Check(val))
170 *dest = PyFloat_AsDouble(val);
171 else if(PyLong_Check(val))
172 *dest = PyLong_AsUnsignedLongLong(val);
173 else if(PyInt_Check(val))
174 *dest = PyInt_AsLong(val);
176 PyErr_Format(PyExc_ValueError,
"Can't assign to double field");
180 PyErr_Format(PyExc_TypeError,
"unsupported type code %d", info.type);
183 return PyErr_Occurred() ? -1 : 0;
189 case StateBase::ArrayInfo::Double: pytype = NPY_DOUBLE;
break;
190 case StateBase::ArrayInfo::Sizet: pytype = NPY_SIZE_T;
break;
192 PyErr_Format(PyExc_TypeError,
"unsupported type code %d", info.type);
198 PyRef<PyArrayObject> arr(PyArray_FromObject(val, pytype, info.
ndim, info.
ndim));
200 if(info.
ndim!=(
size_t)PyArray_NDIM(arr.py())) {
201 PyErr_Format(PyExc_ValueError,
"cardinality don't match");
203 }
else if(!std::equal(info.
dim, info.
dim+info.
ndim,
204 PyArray_DIMS(arr.py()))) {
205 PyErr_Format(PyExc_ValueError,
"shape does not match don't match");
211 pyinfo.
ptr = PyArray_BYTES(arr.py());
212 pyinfo.
ndim= PyArray_NDIM(arr.get());
213 std::copy(PyArray_DIMS(arr.get()),
214 PyArray_DIMS(arr.get())+pyinfo.
ndim,
216 std::copy(PyArray_STRIDES(arr.get()),
217 PyArray_STRIDES(arr.get())+pyinfo.
ndim,
222 for(; !idxiter.done; idxiter.next()) {
223 const void *src = pyinfo .raw(idxiter.index);
224 void *dest = info.raw(idxiter.index);
227 case StateBase::ArrayInfo::Double: *(
double*)dest = *(
double*)src;
break;
228 case StateBase::ArrayInfo::Sizet: *(
size_t*)dest = *(
size_t*)src;
break;
234 for(
size_t i=0; i<info.
dim[0]; i++) {
235 const void *src = PyArray_GETPTR1(arr.py(), i);
236 void *dest = info.raw(&i);
238 case StateBase::ArrayInfo::Double: *(
double*)dest = *(
double*)src;
break;
239 case StateBase::ArrayInfo::Sizet: *(
size_t*)dest = *(
size_t*)src;
break;
242 }
else if(info.
ndim==2) {
244 for(idx[0]=0; idx[0]<info.
dim[0]; idx[0]++) {
245 for(idx[1]=0; idx[1]<info.
dim[1]; idx[1]++) {
246 const void *src = PyArray_GETPTR2(arr.py(), idx[0], idx[1]);
247 void *dest = info.raw(idx);
249 case StateBase::ArrayInfo::Double: *(
double*)dest = *(
double*)src;
break;
250 case StateBase::ArrayInfo::Sizet: *(
size_t*)dest = *(
size_t*)src;
break;
257 } CATCH3(std::exception, RuntimeError, -1)
261 PyObject* PyState_str(PyObject *raw)
264 std::ostringstream strm;
265 state->state->show(strm, 0);
266 return PyString_FromString(strm.str().c_str());
271 PyObject* PyState_iter(PyObject *raw)
274 return PyObject_GetIter(state->attrs);
279 Py_ssize_t PyState_len(PyObject *raw)
282 return PyObject_Length(state->attrs);
286 static PySequenceMethods PyState_seq = {
291 PyObject* PyState_clone(PyObject *raw, PyObject *unused)
294 std::auto_ptr<StateBase> newstate(state->state->clone());
296 PyObject *ret = wrapstate(newstate.get());
303 PyObject* PyState_show(PyObject *raw, PyObject *args, PyObject *kws)
306 unsigned long level = 1;
307 const char *names[] = {
"level", NULL};
308 if(!PyArg_ParseTupleAndKeywords(args, kws,
"|k", (
char**)names, &level))
311 std::ostringstream strm;
312 state->state->show(strm, level);
313 return PyString_FromString(strm.str().c_str());
317 static PyMethodDef PyState_methods[] = {
318 {
"clone", (PyCFunction)&PyState_clone, METH_NOARGS,
320 "Returns a new State instance which is a copy of this one"
322 {
"show", (PyCFunction)&PyState_show, METH_VARARGS|METH_KEYWORDS,
325 {NULL, NULL, 0, NULL}
328 static PyTypeObject PyStateType = {
329 #if PY_MAJOR_VERSION >= 3
330 PyVarObject_HEAD_INIT(NULL, 0)
332 PyObject_HEAD_INIT(NULL)
335 "flame._internal.State",
345 PyRef<PyState> state(PyStateType.tp_alloc(&PyStateType, 0));
348 state->attrs = state->weak = state->dict = 0;
350 state->attrs = PyDict_New();
354 for(
unsigned i=0;
true; i++)
361 bool skip = info.
ndim>3;
363 case StateBase::ArrayInfo::Double:
364 case StateBase::ArrayInfo::Sizet:
372 PyRef<> name(PyInt_FromLong(i));
373 if(PyDict_SetItemString(state->attrs, info.
name, name.py()))
374 throw std::runtime_error(
"Failed to insert into Dict");
378 return state.releasePy();
385 if(!PyObject_TypeCheck(raw, &PyStateType))
386 throw std::invalid_argument(
"Argument is not a State");
387 PyState *state = (PyState*)raw;
391 static const char pymdoc[] =
392 "The interface to a sub-class of C++ StateBase.\n"
393 "Can't be constructed from python, see Machine.allocState()\n"
395 "Provides access to some C++ member variables via the Machine::getArray() interface.\n"
398 int registerModState(PyObject *mod)
400 PyStateType.tp_doc = pymdoc;
402 PyStateType.tp_str = &PyState_str;
403 PyStateType.tp_repr = &PyState_str;
404 PyStateType.tp_dealloc = &PyState_free;
406 PyStateType.tp_iter = &PyState_iter;
407 PyStateType.tp_as_sequence = &PyState_seq;
409 PyStateType.tp_weaklistoffset = offsetof(PyState, weak);
410 PyStateType.tp_traverse = &PyState_traverse;
411 PyStateType.tp_clear = &PyState_clear;
413 PyStateType.tp_dictoffset = offsetof(PyState, dict);
414 PyStateType.tp_getattro = &PyState_getattro;
415 PyStateType.tp_setattro = &PyState_setattro;
417 PyStateType.tp_flags = Py_TPFLAGS_DEFAULT|Py_TPFLAGS_BASETYPE|Py_TPFLAGS_HAVE_GC;
418 PyStateType.tp_methods = PyState_methods;
420 if(PyType_Ready(&PyStateType))
423 Py_INCREF(&PyStateType);
424 if(PyModule_AddObject(mod,
"State", (PyObject*)&PyStateType)) {
425 Py_DECREF(&PyStateType);
The abstract base class for all simulation state objects.
Used with StateBase::getArray() to describe a single parameter.
Helper to step through the indices of an N-dimensional array.
size_t dim[maxdims]
Array dimensions in elements.
size_t stride[maxdims]
Array strides in bytes.
virtual bool getArray(unsigned index, ArrayInfo &Info)
Introspect named parameter of the derived class.
const char * name
The parameter name.