Skip to content

Commit

Permalink
fix garbage collection in inheritance cases
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Sep 17, 2024
1 parent a32afdd commit 0d7d013
Show file tree
Hide file tree
Showing 4 changed files with 320 additions and 4 deletions.
54 changes: 52 additions & 2 deletions pyo3-macros-backend/src/pymethod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ impl PyMethodKind {
"__ior__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IOR__)),
"__getbuffer__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__GETBUFFER__)),
"__releasebuffer__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__RELEASEBUFFER__)),
"__clear__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__CLEAR__)),
// Protocols implemented through traits
"__getattribute__" => {
PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__GETATTRIBUTE__))
Expand Down Expand Up @@ -146,6 +145,7 @@ impl PyMethodKind {
// Some tricky protocols which don't fit the pattern of the rest
"__call__" => PyMethodKind::Proto(PyMethodProtoKind::Call),
"__traverse__" => PyMethodKind::Proto(PyMethodProtoKind::Traverse),
"__clear__" => PyMethodKind::Proto(PyMethodProtoKind::Clear),
// Not a proto
_ => PyMethodKind::Fn,
}
Expand All @@ -156,6 +156,7 @@ enum PyMethodProtoKind {
Slot(&'static SlotDef),
Call,
Traverse,
Clear,
SlotFragment(&'static SlotFragmentDef),
}

Expand Down Expand Up @@ -217,6 +218,9 @@ pub fn gen_py_method(
PyMethodProtoKind::Traverse => {
GeneratedPyMethod::Proto(impl_traverse_slot(cls, spec, ctx)?)
}
PyMethodProtoKind::Clear => {
GeneratedPyMethod::Proto(impl_clear_slot(cls, spec, ctx)?)
}
PyMethodProtoKind::SlotFragment(slot_fragment_def) => {
let proto = slot_fragment_def.generate_pyproto_fragment(cls, spec, ctx)?;
GeneratedPyMethod::SlotTraitImpl(method.method_name, proto)
Expand Down Expand Up @@ -462,7 +466,7 @@ fn impl_traverse_slot(
visit: #pyo3_path::ffi::visitproc,
arg: *mut ::std::os::raw::c_void,
) -> ::std::os::raw::c_int {
#pyo3_path::impl_::pymethods::_call_traverse::<#cls>(slf, #cls::#rust_fn_ident, visit, arg)
#pyo3_path::impl_::pymethods::_call_traverse::<#cls>(slf, #cls::#rust_fn_ident, visit, arg, #cls::__pymethod_traverse__)
}
};
let slot_def = quote! {
Expand All @@ -477,6 +481,52 @@ fn impl_traverse_slot(
})
}

fn impl_clear_slot(cls: &syn::Type, spec: &FnSpec<'_>, ctx: &Ctx) -> syn::Result<MethodAndSlotDef> {
let Ctx { pyo3_path, .. } = ctx;
let (py_arg, args) = split_off_python_arg(&spec.signature.arguments);
let self_type = match &spec.tp {
FnType::Fn(self_type) => self_type,
_ => bail_spanned!(spec.name.span() => "expected instance method for `__clear__` function"),
};
let mut holders = Holders::new();
let slf = self_type.receiver(cls, ExtractErrorMode::Raise, &mut holders, ctx);

if let [arg, ..] = args {
bail_spanned!(arg.ty().span() => "`__clear__` function expected to have no arguments");
}

let name = &spec.name;
let holders = holders.init_holders(ctx);
let fncall = if py_arg.is_some() {
quote!(#cls::#name(#slf, py))
} else {
quote!(#cls::#name(#slf))
};

let associated_method = quote! {
pub unsafe extern "C" fn __pymethod_clear__(
_slf: *mut #pyo3_path::ffi::PyObject,
) -> ::std::os::raw::c_int {
#pyo3_path::impl_::pymethods::_call_clear(_slf, |py, _slf| {
#holders
let result = #fncall;
let result = #pyo3_path::impl_::wrap::converter(&result).wrap(result)?;
Ok(result)
}, #cls::__pymethod_clear__)
}
};
let slot_def = quote! {
#pyo3_path::ffi::PyType_Slot {
slot: #pyo3_path::ffi::Py_tp_clear,
pfunc: #cls::__pymethod_clear__ as #pyo3_path::ffi::inquiry as _
}
};
Ok(MethodAndSlotDef {
associated_method,
slot_def,
})
}

fn impl_py_class_attribute(
cls: &syn::Type,
spec: &FnSpec<'_>,
Expand Down
119 changes: 119 additions & 0 deletions src/impl_/pymethods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ use std::os::raw::{c_int, c_void};
use std::panic::{catch_unwind, AssertUnwindSafe};
use std::ptr::null_mut;

use super::trampoline;

/// Python 3.8 and up - __ipow__ has modulo argument correctly populated.
#[cfg(Py_3_8)]
#[repr(transparent)]
Expand Down Expand Up @@ -275,6 +277,7 @@ pub unsafe fn _call_traverse<T>(
impl_: fn(&T, PyVisit<'_>) -> Result<(), PyTraverseError>,
visit: ffi::visitproc,
arg: *mut c_void,
current_traverse: ffi::traverseproc,
) -> c_int
where
T: PyClass,
Expand All @@ -289,6 +292,11 @@ where
let trap = PanicTrap::new("uncaught panic inside __traverse__ handler");
let lock = LockGIL::during_traverse();

let super_retval = call_super_traverse(slf, visit, arg, current_traverse);
if super_retval != 0 {
return super_retval;
}

// SAFETY: `slf` is a valid Python object pointer to a class object of type T, and
// traversal is running so no mutations can occur.
let class_object: &PyClassObject<T> = &*slf.cast();
Expand Down Expand Up @@ -328,6 +336,117 @@ where
retval
}

/// Call super-type traverse method, if necessary.
///
/// Adapted from https://github.com/cython/cython/blob/7acfb375fb54a033f021b0982a3cd40c34fb22ac/Cython/Utility/ExtensionTypes.c#L386
///
/// TODO: There are possible optimizations over looking up the base type in this way
/// - if the base type is known in this module, can potentially look it up directly in module state
/// (when we have it)
/// - if the base type is a Python builtin, can jut call the C function directly
/// - if the base type is a PyO3 type defined in the same module, can potentially do similar to
/// tp_alloc where we solve this at compile time
unsafe fn call_super_traverse(
obj: *mut ffi::PyObject,
visit: ffi::visitproc,
arg: *mut c_void,
current_traverse: ffi::traverseproc,
) -> c_int {
let mut ty = ffi::Py_TYPE(obj);
let mut traverse: Option<ffi::traverseproc>;

// First find the current type by the current_traverse function
loop {
traverse = std::mem::transmute(ffi::PyType_GetSlot(ty, ffi::Py_tp_traverse));
if traverse == Some(current_traverse) {
break;
}
ty = ffi::PyType_GetSlot(ty, ffi::Py_tp_base).cast();
if ty.is_null() {
// FIXME: return an error if current type not in the MRO? Should be impossible.
return 0;
}
}

// Get first base which has a different traverse function
while !ty.is_null() && traverse == Some(current_traverse) {
ty = ffi::PyType_GetSlot(ty, ffi::Py_tp_base).cast();
if ty.is_null() {
break;
}
traverse = std::mem::transmute(ffi::PyType_GetSlot(ty, ffi::Py_tp_traverse));
}

// If we found a type with a different traverse function, call it
if let Some(traverse) = traverse {
return traverse(obj, visit, arg);
}

// FIXME same question as cython: what if the current type is not in the MRO?
return 0;
}

/// Calls an implementation of __clear__ for tp_clear
pub unsafe fn _call_clear(
slf: *mut ffi::PyObject,
impl_: for<'py> unsafe fn(Python<'py>, *mut ffi::PyObject) -> PyResult<()>,
current_clear: ffi::inquiry,
) -> c_int {
let super_retval = call_super_clear(slf, current_clear);
if super_retval != 0 {
return super_retval;
}
trampoline::trampoline(move |py| {
impl_(py, slf)?;
Ok(0)
})
}

/// Call super-type traverse method, if necessary.
///
/// Adapted from https://github.com/cython/cython/blob/7acfb375fb54a033f021b0982a3cd40c34fb22ac/Cython/Utility/ExtensionTypes.c#L386
///
/// TODO: There are possible optimizations over looking up the base type in this way
/// - if the base type is known in this module, can potentially look it up directly in module state
/// (when we have it)
/// - if the base type is a Python builtin, can jut call the C function directly
/// - if the base type is a PyO3 type defined in the same module, can potentially do similar to
/// tp_alloc where we solve this at compile time
unsafe fn call_super_clear(obj: *mut ffi::PyObject, current_clear: ffi::inquiry) -> c_int {
let mut ty = ffi::Py_TYPE(obj);
let mut clear: Option<ffi::inquiry>;

// First find the current type by the current_clear function
loop {
clear = std::mem::transmute(ffi::PyType_GetSlot(ty, ffi::Py_tp_clear));
if clear == Some(current_clear) {
break;
}
ty = ffi::PyType_GetSlot(ty, ffi::Py_tp_base).cast();
if ty.is_null() {
// FIXME: return an error if current type not in the MRO? Should be impossible.
return 0;
}
}

// Get first base which has a different clear function
while !ty.is_null() && clear == Some(current_clear) {
ty = ffi::PyType_GetSlot(ty, ffi::Py_tp_base).cast();
if ty.is_null() {
break;
}
clear = std::mem::transmute(ffi::PyType_GetSlot(ty, ffi::Py_tp_clear));
}

// If we found a type with a different clear function, call it
if let Some(clear) = clear {
return clear(obj);
}

// FIXME same question as cython: what if the current type is not in the MRO?
return 0;
}

// Autoref-based specialization for handling `__next__` returning `Option`

pub struct IterBaseTag;
Expand Down
25 changes: 23 additions & 2 deletions src/pyclass/create_type_object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
assign_sequence_item_from_mapping, get_sequence_item_from_mapping, tp_dealloc,
tp_dealloc_with_gc, MaybeRuntimePyMethodDef, PyClassItemsIter,
},
pymethods::{Getter, PyGetterDef, PyMethodDefType, PySetterDef, Setter},
pymethods::{Getter, PyGetterDef, PyMethodDefType, PySetterDef, Setter, _call_clear},
trampoline::trampoline,
},
internal_tricks::ptr_from_ref,
Expand Down Expand Up @@ -432,7 +432,8 @@ impl PyTypeBuilder {
unsafe { self.push_slot(ffi::Py_tp_new, no_constructor_defined as *mut c_void) }
}

let tp_dealloc = if self.has_traverse || unsafe { ffi::PyType_IS_GC(self.tp_base) == 1 } {
let base_is_gc = unsafe { ffi::PyType_IS_GC(self.tp_base) == 1 };
let tp_dealloc = if self.has_traverse || base_is_gc {
self.tp_dealloc_with_gc
} else {
self.tp_dealloc
Expand All @@ -446,6 +447,22 @@ impl PyTypeBuilder {
)));
}

// If this type is a GC type, and the base also is, we may need to add
// `tp_traverse` / `tp_clear` implementations to call the base, if this type didn't
// define `__traverse__` or `__clear__`.
//
// This is because when Py_TPFLAGS_HAVE_GC is set, then `tp_traverse` and
// `tp_clear` are not inherited.
if ((self.class_flags & ffi::Py_TPFLAGS_HAVE_GC) != 0) && base_is_gc {
// If this assertion breaks, need to consider doing the same for __traverse__.
assert!(self.has_traverse); // Py_TPFLAGS_HAVE_GC is set when a `__traverse__` method is found

if !self.has_clear {
// Safety: This is the correct slot type for Py_tp_clear
unsafe { self.push_slot(ffi::Py_tp_clear, call_super_clear as *mut c_void) }
}
}

// For sequences, implement sq_length instead of mp_length
if self.is_sequence {
for slot in &mut self.slots {
Expand Down Expand Up @@ -540,6 +557,10 @@ unsafe extern "C" fn no_constructor_defined(
})
}

unsafe extern "C" fn call_super_clear(slf: *mut ffi::PyObject) -> c_int {
_call_clear(slf, |_, _| Ok(()), call_super_clear)
}

#[derive(Default)]
struct GetSetDefBuilder {
doc: Option<&'static CStr>,
Expand Down
Loading

0 comments on commit 0d7d013

Please sign in to comment.