Skip to content

Commit

Permalink
feat: Add B-field accessors to Python bindings
Browse files Browse the repository at this point in the history
As acts-project#3479 reveals, we don't currently have any clean, cache-aware ways of
accessing B-fields in Python code. In order to avoid hacks, this commit
adds the necessary bindings to allow us to cleanly access B-fields with
cache objects.
  • Loading branch information
stephenswat committed Aug 26, 2024
1 parent 2cfa5f7 commit 4aaac1f
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 2 deletions.
7 changes: 7 additions & 0 deletions Examples/Python/src/Base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "Acts/Geometry/GeometryContext.hpp"
#include "Acts/MagneticField/MagneticFieldContext.hpp"
#include "Acts/Plugins/Python/Utilities.hpp"
#include "Acts/Utilities/Any.hpp"
#include "Acts/Utilities/AxisFwd.hpp"
#include "Acts/Utilities/BinningData.hpp"
#include "Acts/Utilities/CalibrationContext.hpp"
Expand Down Expand Up @@ -42,6 +43,12 @@ void addContext(Context& ctx) {
.def(py::init<>());
}

void addAny(Context& ctx) {
auto& m = ctx.get("main");

py::class_<Acts::AnyBase<512>>(m, "AnyBase512").def(py::init<>());
}

void addUnits(Context& ctx) {
auto& m = ctx.get("main");
auto u = m.def_submodule("UnitConstants");
Expand Down
20 changes: 19 additions & 1 deletion Examples/Python/src/MagneticField.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,30 @@ using namespace pybind11::literals;

namespace Acts::Python {

/// @brief Get the value of a field, throwing an exception if the result is
/// invalid.
Acts::Vector3 getField(Acts::MagneticFieldProvider& self,
const Acts::Vector3& position,
Acts::MagneticFieldProvider::Cache& cache) {
if (Result<Vector3> res = self.getField(position, cache); !res.ok()) {
std::stringstream ss;

ss << "Field lookup failure with error: \"" << res.error() << "\"";

throw std::runtime_error{ss.str()};
} else {
return *res;
}
}

void addMagneticField(Context& ctx) {
auto [m, mex, prop] = ctx.get("main", "examples", "propagation");

py::class_<Acts::MagneticFieldProvider,
std::shared_ptr<Acts::MagneticFieldProvider>>(
m, "MagneticFieldProvider");
m, "MagneticFieldProvider")
.def("getField", &getField)
.def("makeCache", &Acts::MagneticFieldProvider::makeCache);

py::class_<Acts::InterpolatedMagneticField,
std::shared_ptr<Acts::InterpolatedMagneticField>>(
Expand Down
2 changes: 2 additions & 0 deletions Examples/Python/src/ModuleEntry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ using namespace Acts::Python;

namespace Acts::Python {
void addContext(Context& ctx);
void addAny(Context& ctx);
void addUnits(Context& ctx);
void addFramework(Context& ctx);
void addLogging(Context& ctx);
Expand Down Expand Up @@ -108,6 +109,7 @@ PYBIND11_MODULE(ActsPythonBindings, m) {
}

addContext(ctx);
addAny(ctx);
addUnits(ctx);
addFramework(ctx);
addLogging(ctx);
Expand Down
39 changes: 38 additions & 1 deletion Examples/Python/tests/test_magnetic_field.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
import random

import acts
import acts.examples
Expand All @@ -7,16 +8,52 @@


def test_null_bfield():
assert acts.NullBField()
nb = acts.NullBField()
assert nb

ct = acts.MagneticFieldContext()
assert ct

fc = nb.makeCache(ct)
assert fc

for i in range(100):
x = random.uniform(-10000.0, 10000.0)
y = random.uniform(-10000.0, 10000.0)
z = random.uniform(-10000.0, 10000.0)

rv = nb.getField(acts.Vector3(x, y, z), fc)

assert rv[0] == pytest.approx(0.0)
assert rv[1] == pytest.approx(0.0)
assert rv[2] == pytest.approx(0.0)


def test_constant_bfield():
with pytest.raises(TypeError):
acts.ConstantBField()

v = acts.Vector3(1, 2, 3)
cb = acts.ConstantBField(v)
assert cb

ct = acts.MagneticFieldContext()
assert ct

fc = cb.makeCache(ct)
assert fc

for i in range(100):
x = random.uniform(-10000.0, 10000.0)
y = random.uniform(-10000.0, 10000.0)
z = random.uniform(-10000.0, 10000.0)

rv = cb.getField(acts.Vector3(x, y, z), fc)

assert rv[0] == pytest.approx(1.0)
assert rv[1] == pytest.approx(2.0)
assert rv[2] == pytest.approx(3.0)


def test_solenoid(conf_const):
solenoid = conf_const(
Expand Down

0 comments on commit 4aaac1f

Please sign in to comment.