-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ddf9b6b
commit ce91bdf
Showing
1 changed file
with
301 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,301 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Hist Design Prototype" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"This is `fill` method in python loop:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import numpy as np\n", | ||
"import numba as nb\n", | ||
"from hist import Hist\n", | ||
"from hist.axis import Regular\n", | ||
"\n", | ||
"# assets\n", | ||
"array = np.random.randn(\n", | ||
" 10000,\n", | ||
")\n", | ||
"h = Hist.new.Reg(100, -3, 3, name=\"x\", label=\"x-axis\").Double()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# python fill\n", | ||
"# h.fill(array)\n", | ||
"# h" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Numba: Hist" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"To extend the Numba, we first need to create a Hist type `HistType` for `Hist`, and then teach Numba about our type inference additions:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from numba import types\n", | ||
"from numba.extending import typeof_impl, as_numba_type, type_callable\n", | ||
"\n", | ||
"# create Numba type\n", | ||
"class RegularType(types.Type):\n", | ||
" def __init__(self):\n", | ||
" super().__init__(name=\"Regular\")\n", | ||
"\n", | ||
"\n", | ||
"regular_type = RegularType()\n", | ||
"\n", | ||
"# infer values\n", | ||
"@typeof_impl.register(Regular)\n", | ||
"def typeof_index(val, c):\n", | ||
" return regular_type\n", | ||
"\n", | ||
"\n", | ||
"# infer annotations\n", | ||
"as_numba_type.register(Regular, regular_type)\n", | ||
"\n", | ||
"# infer operations\n", | ||
"@type_callable(Regular)\n", | ||
"def type_regular(context):\n", | ||
" def typer(bins, lo, hi):\n", | ||
" if (\n", | ||
" isinstance(bins, types.Integer)\n", | ||
" and isinstance(lo, types.Float)\n", | ||
" and isinstance(hi, types.Float)\n", | ||
" ):\n", | ||
" return regular_type\n", | ||
"\n", | ||
" return typer" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"We also need to teach Numba how to actually generate native representation for the new operations:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from numba.core import cgutils\n", | ||
"from numba.extending import (\n", | ||
" models,\n", | ||
" register_model,\n", | ||
" make_attribute_wrapper,\n", | ||
" overload_attribute,\n", | ||
" lower_builtin,\n", | ||
" box,\n", | ||
" unbox,\n", | ||
" NativeValue,\n", | ||
")\n", | ||
"\n", | ||
"# define data model\n", | ||
"@register_model(RegularType)\n", | ||
"class RegularModel(models.StructModel):\n", | ||
" def __init__(self, dmm, fe_type):\n", | ||
" members = [\n", | ||
" (\"bins\", types.int32),\n", | ||
" (\"lo\", types.float64),\n", | ||
" (\"hi\", types.float64),\n", | ||
" ]\n", | ||
" models.StructModel.__init__(self, dmm, fe_type, members)\n", | ||
"\n", | ||
"\n", | ||
"# expose attributes, porperties and constructors\n", | ||
"make_attribute_wrapper(RegularType, \"bins\", \"bins\")\n", | ||
"make_attribute_wrapper(RegularType, \"lo\", \"lo\")\n", | ||
"make_attribute_wrapper(RegularType, \"hi\", \"hi\")\n", | ||
"\n", | ||
"\n", | ||
"@overload_attribute(RegularType, \"width\")\n", | ||
"def get_width(reg):\n", | ||
" def getter(reg):\n", | ||
" return (reg.hi - reg.lo) / reg.bins\n", | ||
"\n", | ||
" return getter\n", | ||
"\n", | ||
"\n", | ||
"@lower_builtin(Regular, types.Integer, types.Float, types.Float)\n", | ||
"def impl_reg(context, builder, sig, args):\n", | ||
" typ = sig.return_type\n", | ||
" lo, hi, bins = args\n", | ||
" reg = cgutils.create_struct_proxy(typ)(context, builder)\n", | ||
" reg.lo = lo\n", | ||
" reg.hi = hi\n", | ||
" reg.bins = bins\n", | ||
" return reg._getvalue()\n", | ||
"\n", | ||
"\n", | ||
"# unbox and box\n", | ||
"@unbox(RegularType)\n", | ||
"def unbox_reg(typ, obj, c):\n", | ||
" bins_obj = c.pyapi.object_getattr_string(obj, \"bins\")\n", | ||
" lo_obj = c.pyapi.object_getattr_string(obj, \"lo\")\n", | ||
" hi_obj = c.pyapi.object_getattr_string(obj, \"hi\")\n", | ||
" reg = cgutils.create_struct_proxy(typ)(c.context, c.builder)\n", | ||
" reg.bins = c.pyapi.float_as_double(bins_obj)\n", | ||
" reg.lo = c.pyapi.float_as_double(lo_obj)\n", | ||
" reg.hi = c.pyapi.float_as_double(hi_obj)\n", | ||
" c.pyapi.decref(bins_obj)\n", | ||
" c.pyapi.decref(lo_obj)\n", | ||
" c.pyapi.decref(hi_obj)\n", | ||
" is_error = cgutils.is_not_null(c.builder, c.pyapi.err_occurred())\n", | ||
" return NativeValue(reg._getvalue(), is_error=is_error)\n", | ||
"\n", | ||
"\n", | ||
"@box(RegularType)\n", | ||
"def box_reg(typ, val, c):\n", | ||
" reg = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val)\n", | ||
" bins_obj = c.pyapi.float_from_double(reg.bins)\n", | ||
" lo_obj = c.pyapi.float_from_double(reg.lo)\n", | ||
" hi_obj = c.pyapi.float_from_double(reg.hi)\n", | ||
" class_obj = c.pyapi.unserialize(c.pyapi.serialize_object(Regular))\n", | ||
" res = c.pyapi.call_function_objargs(class_obj, (bins_obj, lo_obj, hi_obj))\n", | ||
" c.pyapi.decref(bins_obj)\n", | ||
" c.pyapi.decref(lo_obj)\n", | ||
" c.pyapi.decref(hi_obj)\n", | ||
" c.pyapi.decref(class_obj)\n", | ||
" return res" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"<ipython-input-5-4658abf95da1>:1: NumbaWarning: \n", | ||
"Compilation is falling back to object mode WITH looplifting enabled because Function nb_create_reg failed at nopython mode lowering due to: Invalid store of i64 to double in <__main__.RegularModel object at 0x183567040> (trying to write member #1)\n", | ||
"\n", | ||
"File \"<ipython-input-5-4658abf95da1>\", line 3:\n", | ||
"def nb_create_reg():\n", | ||
" return Regular(50, -5., 5.,)\n", | ||
" ^\n", | ||
"\n", | ||
"During: lowering \"$10call_function.4 = call $2load_global.0($const4.1, $const6.2, $const8.3, func=$2load_global.0, args=[Var($const4.1, <ipython-input-5-4658abf95da1>:3), Var($const6.2, <ipython-input-5-4658abf95da1>:3), Var($const8.3, <ipython-input-5-4658abf95da1>:3)], kws=(), vararg=None)\" at <ipython-input-5-4658abf95da1> (3)\n", | ||
" @nb.jit\n", | ||
"/Users/ninolau/anaconda3/envs/hist/lib/python3.9/site-packages/numba/core/object_mode_passes.py:151: NumbaWarning: Function \"nb_create_reg\" was compiled in object mode without forceobj=True.\n", | ||
"\n", | ||
"File \"<ipython-input-5-4658abf95da1>\", line 2:\n", | ||
"@nb.jit\n", | ||
"def nb_create_reg():\n", | ||
"^\n", | ||
"\n", | ||
" warnings.warn(errors.NumbaWarning(warn_msg,\n", | ||
"/Users/ninolau/anaconda3/envs/hist/lib/python3.9/site-packages/numba/core/object_mode_passes.py:161: NumbaDeprecationWarning: \n", | ||
"Fall-back from the nopython compilation path to the object mode compilation path has been detected, this is deprecated behaviour.\n", | ||
"\n", | ||
"For more information visit https://numba.pydata.org/numba-doc/latest/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit\n", | ||
"\n", | ||
"File \"<ipython-input-5-4658abf95da1>\", line 2:\n", | ||
"@nb.jit\n", | ||
"def nb_create_reg():\n", | ||
"^\n", | ||
"\n", | ||
" warnings.warn(errors.NumbaDeprecationWarning(msg,\n" | ||
] | ||
}, | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"Regular(50, -5, 5)" | ||
] | ||
}, | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"@nb.jit\n", | ||
"def nb_create_reg():\n", | ||
" return Regular(\n", | ||
" 50,\n", | ||
" -5.0,\n", | ||
" 5.0,\n", | ||
" )\n", | ||
"\n", | ||
"\n", | ||
"nb_create_reg()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# def nb_hist_property(h):\n", | ||
"# print(h.lo)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Numba fill\n", | ||
"# nb_fill(h, array)\n", | ||
"# h" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "hist", | ||
"language": "python", | ||
"name": "hist" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.4" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 4 | ||
} |