Skip to content

Commit

Permalink
Update HistNumbaFill.ipynb
Browse files Browse the repository at this point in the history
  • Loading branch information
LovelyBuggies authored and henryiii committed Sep 24, 2021
1 parent ce91bdf commit 00c1d2c
Showing 1 changed file with 63 additions and 136 deletions.
199 changes: 63 additions & 136 deletions docs/examples/HistNumbaFill.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import numba as nb\n",
"from hist import Hist\n",
"from hist.axis import Regular\n",
"from hist import axis\n",
"\n",
"# assets\n",
"array = np.random.randn(\n",
Expand All @@ -34,13 +34,24 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# python fill\n",
"# h.fill(array)\n",
"# h"
"h.fill(array)\n",
"h"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import hist\n",
"\n",
"isinstance(h.axes[0], hist.axis.Regular)"
]
},
{
Expand All @@ -59,42 +70,39 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from numba import types\n",
"from numba.extending import typeof_impl, as_numba_type, type_callable\n",
"\n",
"# create Numba type\n",
"class RegularType(types.Type):\n",
"class HistType(types.Type):\n",
" def __init__(self):\n",
" super().__init__(name=\"Regular\")\n",
" super().__init__(name=\"Hist\")\n",
"\n",
"\n",
"regular_type = RegularType()\n",
"hist_type = HistType()\n",
"\n",
"# infer values\n",
"@typeof_impl.register(Regular)\n",
"@typeof_impl.register(Hist)\n",
"def typeof_index(val, c):\n",
" return regular_type\n",
" return hist_type\n",
"\n",
"\n",
"# infer annotations\n",
"as_numba_type.register(Regular, regular_type)\n",
"as_numba_type.register(Hist, hist_type)\n",
"\n",
"# infer operations\n",
"@type_callable(Regular)\n",
"def type_regular(context):\n",
" def typer(bins, lo, hi):\n",
" if (\n",
" isinstance(bins, types.Integer)\n",
" and isinstance(lo, types.Float)\n",
" and isinstance(hi, types.Float)\n",
" ):\n",
" return regular_type\n",
"\n",
" return typer"
"@type_callable(Hist)\n",
"def type_hist(context):\n",
" def typer(axes):\n",
" for ax in axes:\n",
" # if not (isinstance(ax, types of axis)):\n",
" return typer\n",
"\n",
" return hist_type"
]
},
{
Expand All @@ -106,7 +114,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -123,151 +131,70 @@
")\n",
"\n",
"# define data model\n",
"@register_model(RegularType)\n",
"class RegularModel(models.StructModel):\n",
"@register_model(HistType)\n",
"class HistModel(models.StructModel):\n",
" def __init__(self, dmm, fe_type):\n",
" members = [\n",
" (\"bins\", types.int32),\n",
" (\"lo\", types.float64),\n",
" (\"hi\", types.float64),\n",
" # (\"axes\", types of list of axis),\n",
" ]\n",
" models.StructModel.__init__(self, dmm, fe_type, members)\n",
"\n",
"\n",
"# expose attributes, porperties and constructors\n",
"make_attribute_wrapper(RegularType, \"bins\", \"bins\")\n",
"make_attribute_wrapper(RegularType, \"lo\", \"lo\")\n",
"make_attribute_wrapper(RegularType, \"hi\", \"hi\")\n",
"\n",
"\n",
"@overload_attribute(RegularType, \"width\")\n",
"def get_width(reg):\n",
" def getter(reg):\n",
" return (reg.hi - reg.lo) / reg.bins\n",
"\n",
" return getter\n",
"make_attribute_wrapper(HistType, \"axes\", \"axes\")\n",
"\n",
"\n",
"@lower_builtin(Regular, types.Integer, types.Float, types.Float)\n",
"def impl_reg(context, builder, sig, args):\n",
"def impl_h(context, builder, sig, args):\n",
" axes = args\n",
" typ = sig.return_type\n",
" lo, hi, bins = args\n",
" reg = cgutils.create_struct_proxy(typ)(context, builder)\n",
" reg.lo = lo\n",
" reg.hi = hi\n",
" reg.bins = bins\n",
" h = cgutils.create_struct_proxy(typ)(context, builder)\n",
" h.axes = axes\n",
" return reg._getvalue()\n",
"\n",
"\n",
"# unbox and box\n",
"@unbox(RegularType)\n",
"def unbox_reg(typ, obj, c):\n",
" bins_obj = c.pyapi.object_getattr_string(obj, \"bins\")\n",
" lo_obj = c.pyapi.object_getattr_string(obj, \"lo\")\n",
" hi_obj = c.pyapi.object_getattr_string(obj, \"hi\")\n",
" reg = cgutils.create_struct_proxy(typ)(c.context, c.builder)\n",
" reg.bins = c.pyapi.float_as_double(bins_obj)\n",
" reg.lo = c.pyapi.float_as_double(lo_obj)\n",
" reg.hi = c.pyapi.float_as_double(hi_obj)\n",
" c.pyapi.decref(bins_obj)\n",
" c.pyapi.decref(lo_obj)\n",
" c.pyapi.decref(hi_obj)\n",
"@unbox(HistType)\n",
"def unbox_h(typ, obj, c):\n",
" axes_obj = c.pyapi.object_getattr_string(obj, \"axes\")\n",
" h = cgutils.create_struct_proxy(typ)(c.context, c.builder)\n",
" # h.axes = c.pyapi.float_as_double(axes_obj)\n",
" c.pyapi.decref(axes_obj)\n",
" is_error = cgutils.is_not_null(c.builder, c.pyapi.err_occurred())\n",
" return NativeValue(reg._getvalue(), is_error=is_error)\n",
" return NativeValue(h._getvalue(), is_error=is_error)\n",
"\n",
"\n",
"@box(RegularType)\n",
"def box_reg(typ, val, c):\n",
" reg = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val)\n",
" bins_obj = c.pyapi.float_from_double(reg.bins)\n",
" lo_obj = c.pyapi.float_from_double(reg.lo)\n",
" hi_obj = c.pyapi.float_from_double(reg.hi)\n",
" class_obj = c.pyapi.unserialize(c.pyapi.serialize_object(Regular))\n",
" res = c.pyapi.call_function_objargs(class_obj, (bins_obj, lo_obj, hi_obj))\n",
" c.pyapi.decref(bins_obj)\n",
" c.pyapi.decref(lo_obj)\n",
" c.pyapi.decref(hi_obj)\n",
"@box(HistType)\n",
"def box_h(typ, val, c):\n",
" h = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val)\n",
" axes_obj = c.pyapi.float_from_double(h.axes)\n",
" class_obj = c.pyapi.unserialize(c.pyapi.serialize_object(Hist))\n",
" res = c.pyapi.call_function_objargs(class_obj, (axes_obj))\n",
" c.pyapi.decref(axes_obj)\n",
" c.pyapi.decref(class_obj)\n",
" return res"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"<ipython-input-5-4658abf95da1>:1: NumbaWarning: \n",
"Compilation is falling back to object mode WITH looplifting enabled because Function nb_create_reg failed at nopython mode lowering due to: Invalid store of i64 to double in <__main__.RegularModel object at 0x183567040> (trying to write member #1)\n",
"\n",
"File \"<ipython-input-5-4658abf95da1>\", line 3:\n",
"def nb_create_reg():\n",
" return Regular(50, -5., 5.,)\n",
" ^\n",
"\n",
"During: lowering \"$10call_function.4 = call $2load_global.0($const4.1, $const6.2, $const8.3, func=$2load_global.0, args=[Var($const4.1, <ipython-input-5-4658abf95da1>:3), Var($const6.2, <ipython-input-5-4658abf95da1>:3), Var($const8.3, <ipython-input-5-4658abf95da1>:3)], kws=(), vararg=None)\" at <ipython-input-5-4658abf95da1> (3)\n",
" @nb.jit\n",
"/Users/ninolau/anaconda3/envs/hist/lib/python3.9/site-packages/numba/core/object_mode_passes.py:151: NumbaWarning: Function \"nb_create_reg\" was compiled in object mode without forceobj=True.\n",
"\n",
"File \"<ipython-input-5-4658abf95da1>\", line 2:\n",
"@nb.jit\n",
"def nb_create_reg():\n",
"^\n",
"\n",
" warnings.warn(errors.NumbaWarning(warn_msg,\n",
"/Users/ninolau/anaconda3/envs/hist/lib/python3.9/site-packages/numba/core/object_mode_passes.py:161: NumbaDeprecationWarning: \n",
"Fall-back from the nopython compilation path to the object mode compilation path has been detected, this is deprecated behaviour.\n",
"\n",
"For more information visit https://numba.pydata.org/numba-doc/latest/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit\n",
"\n",
"File \"<ipython-input-5-4658abf95da1>\", line 2:\n",
"@nb.jit\n",
"def nb_create_reg():\n",
"^\n",
"\n",
" warnings.warn(errors.NumbaDeprecationWarning(msg,\n"
]
},
{
"data": {
"text/plain": [
"Regular(50, -5, 5)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"reg_ax = axis.Regular(10, 0, 1)\n",
"\n",
"\n",
"@nb.jit\n",
"def nb_create_reg():\n",
" return Regular(\n",
" 50,\n",
" -5.0,\n",
" 5.0,\n",
" )\n",
"def nb_create_Hist():\n",
" Hist(reg_ax)\n",
"\n",
"\n",
"nb_create_reg()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# def nb_hist_property(h):\n",
"# print(h.lo)"
"nb_create_Hist()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down

0 comments on commit 00c1d2c

Please sign in to comment.