From 49029e0347a550ca15df06558fe5aad9ca80479a Mon Sep 17 00:00:00 2001 From: Takafumi Arakaki Date: Mon, 3 Sep 2018 07:32:43 -0700 Subject: [PATCH] Don't steal reference from existing Python object (#553) * Don't steal reference from existing Python object closes #551 * Refactoring: add pyreturn function * Fix pystealref! in pyjlwrap_getattr * Fix pystealref! in pyjlwrap_iternext --- src/PyCall.jl | 10 +++++++++ src/callback.jl | 6 +++--- src/pyiterator.jl | 4 ++-- src/pytype.jl | 2 +- test/runtests.jl | 52 +++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 68 insertions(+), 6 deletions(-) diff --git a/src/PyCall.jl b/src/PyCall.jl index 8c137294..71bb65e4 100644 --- a/src/PyCall.jl +++ b/src/PyCall.jl @@ -133,6 +133,16 @@ function pystealref!(o::PyObject) return optr end +""" + pyreturn(x) :: PyPtr + +Prepare `PyPtr` from `x` for passing it to Python. If `x` is already +a `PyObject`, the refcount is incremented. Otherwise a `PyObject` +wrapping/converted from `x` is created. +""" +pyreturn(x::Any) = pystealref!(PyObject(x)) +pyreturn(x::PyObject) = pyincref_(x.o) + function Base.copy!(dest::PyObject, src::PyObject) pydecref(dest) dest.o = src.o diff --git a/src/callback.jl b/src/callback.jl index 0cd7152c..55d84936 100644 --- a/src/callback.jl +++ b/src/callback.jl @@ -25,7 +25,7 @@ function _pyjlwrap_call(f, args_::PyPtr, kw_::PyPtr) # we need to use invokelatest to get execution in newest world if kw_ == C_NULL - ret = PyObject(Base.invokelatest(f, jlargs...)) + ret = Base.invokelatest(f, jlargs...) else kw = PyDict{Symbol,PyObject}(pyincref(kw_)) kwargs = [ (k,julia_kwarg(f,k,v)) for (k,v) in kw ] @@ -34,10 +34,10 @@ function _pyjlwrap_call(f, args_::PyPtr, kw_::PyPtr) # use a closure over kwargs. see: # https://github.com/JuliaLang/julia/pull/22646 f_kw_closure() = f(jlargs...; kwargs...) - ret = PyObject(Core._apply_latest(f_kw_closure)) + ret = Core._apply_latest(f_kw_closure) end - return pystealref!(ret) + return pyreturn(ret) catch e pyraise(e) finally diff --git a/src/pyiterator.jl b/src/pyiterator.jl index 0e03e4b1..ccce277e 100644 --- a/src/pyiterator.jl +++ b/src/pyiterator.jl @@ -65,7 +65,7 @@ const jlWrapIteratorType = PyTypeObject() if !done(iter, state) item, state′ = next(iter, state) stateref[] = state′ # stores new state in the iterator object - return pystealref!(PyObject(item)) + return pyreturn(item) end catch e pyraise(e) @@ -80,7 +80,7 @@ else if iter_result !== nothing item, state = iter_result iter_result_ref[] = iterate(iter, state) - return pystealref!(PyObject(item)) + return pyreturn(item) end catch e pyraise(e) diff --git a/src/pytype.jl b/src/pytype.jl index cd1adc2c..a05d5f8c 100644 --- a/src/pytype.jl +++ b/src/pytype.jl @@ -377,7 +377,7 @@ function pyjlwrap_getattr(self_::PyPtr, attr__::PyPtr) else fidx = Base.fieldindex(typeof(f), Symbol(attr), false) if fidx != 0 - return pystealref!(PyObject(getfield(f, fidx))) + return pyreturn(getfield(f, fidx)) else return ccall(@pysym(:PyObject_GenericGetAttr), PyPtr, (PyPtr,PyPtr), self_, attr__) end diff --git a/test/runtests.jl b/test/runtests.jl index 4da2865f..6b3099d2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -573,4 +573,56 @@ end @test (@pywith IgnoreError(true) error(); true) end +@testset "callback" begin + # Returning existing PyObject in Julia should not invalidate it. + # https://github.com/JuliaPy/PyCall.jl/pull/552 + anonymous = Module() + Base.eval( + anonymous, quote + using PyCall + obj = pyimport("sys") # get some PyObject + end) + py""" + ns = {} + def set(name): + ns[name] = $include_string($anonymous, name) + """ + py"set"("obj") + @test anonymous.obj != PyNULL() + + # Test above for pyjlwrap_getattr too: + anonymous = Module() + Base.eval( + anonymous, quote + using PyCall + struct S + x + end + obj = S(pyimport("sys")) + end) + py""" + ns = {} + def set(name): + ns[name] = $include_string($anonymous, name).x + """ + py"set"("obj") + @test anonymous.obj.x != PyNULL() + + # Test above for pyjlwrap_iternext too: + anonymous = Module() + Base.eval( + anonymous, quote + using PyCall + sys = pyimport("sys") + obj = (sys for _ in 1:1) + end) + py""" + ns = {} + def set(name): + ns[name] = list(iter($include_string($anonymous, name))) + """ + py"set"("obj") + @test anonymous.sys != PyNULL() +end + include("test_pyfncall.jl")