1#![allow(missing_docs)]
2use std::cell::UnsafeCell;
5use std::marker::PhantomData;
6use std::mem::ManuallyDrop;
7use std::sync::atomic::{AtomicUsize, Ordering};
8
9use crate::impl_::pyclass::{
10 PyClassBaseType, PyClassDict, PyClassImpl, PyClassThreadChecker, PyClassWeakRef,
11};
12use crate::internal::get_slot::TP_FREE;
13use crate::type_object::{PyLayout, PySizedLayout};
14use crate::types::{PyType, PyTypeMethods};
15use crate::{ffi, PyClass, PyTypeInfo, Python};
16
17use super::{PyBorrowError, PyBorrowMutError};
18
19pub trait PyClassMutability {
20 type Storage: PyClassBorrowChecker;
23 type Checker: PyClassBorrowChecker;
27 type ImmutableChild: PyClassMutability;
28 type MutableChild: PyClassMutability;
29}
30
31pub struct ImmutableClass(());
32pub struct MutableClass(());
33pub struct ExtendsMutableAncestor<M: PyClassMutability>(PhantomData<M>);
34
35impl PyClassMutability for ImmutableClass {
36 type Storage = EmptySlot;
37 type Checker = EmptySlot;
38 type ImmutableChild = ImmutableClass;
39 type MutableChild = MutableClass;
40}
41
42impl PyClassMutability for MutableClass {
43 type Storage = BorrowChecker;
44 type Checker = BorrowChecker;
45 type ImmutableChild = ExtendsMutableAncestor<ImmutableClass>;
46 type MutableChild = ExtendsMutableAncestor<MutableClass>;
47}
48
49impl<M: PyClassMutability> PyClassMutability for ExtendsMutableAncestor<M> {
50 type Storage = EmptySlot;
51 type Checker = BorrowChecker;
52 type ImmutableChild = ExtendsMutableAncestor<ImmutableClass>;
53 type MutableChild = ExtendsMutableAncestor<MutableClass>;
54}
55
56#[derive(Debug)]
57struct BorrowFlag(AtomicUsize);
58
59impl BorrowFlag {
60 pub(crate) const UNUSED: usize = 0;
61 const HAS_MUTABLE_BORROW: usize = usize::MAX;
62 fn increment(&self) -> Result<(), PyBorrowError> {
63 let mut value = self.0.load(Ordering::Relaxed);
65 loop {
66 if value == BorrowFlag::HAS_MUTABLE_BORROW {
67 return Err(PyBorrowError { _private: () });
68 }
69 match self.0.compare_exchange(
70 value,
73 value + 1,
74 Ordering::AcqRel,
77 Ordering::Relaxed,
79 ) {
80 Ok(..) => {
81 break Ok(());
82 }
83 Err(changed_value) => {
84 value = changed_value;
86 }
87 }
88 }
89 }
90 fn decrement(&self) {
91 self.0.fetch_sub(1, Ordering::Release);
93 }
94}
95
96pub struct EmptySlot(());
97pub struct BorrowChecker(BorrowFlag);
98
99pub trait PyClassBorrowChecker {
100 fn new() -> Self
102 where
103 Self: Sized;
104
105 fn try_borrow(&self) -> Result<(), PyBorrowError>;
107
108 fn release_borrow(&self);
110 fn try_borrow_mut(&self) -> Result<(), PyBorrowMutError>;
112 fn release_borrow_mut(&self);
114}
115
116impl PyClassBorrowChecker for EmptySlot {
117 #[inline]
118 fn new() -> Self {
119 EmptySlot(())
120 }
121
122 #[inline]
123 fn try_borrow(&self) -> Result<(), PyBorrowError> {
124 Ok(())
125 }
126
127 #[inline]
128 fn release_borrow(&self) {}
129
130 #[inline]
131 fn try_borrow_mut(&self) -> Result<(), PyBorrowMutError> {
132 unreachable!()
133 }
134
135 #[inline]
136 fn release_borrow_mut(&self) {
137 unreachable!()
138 }
139}
140
141impl PyClassBorrowChecker for BorrowChecker {
142 #[inline]
143 fn new() -> Self {
144 Self(BorrowFlag(AtomicUsize::new(BorrowFlag::UNUSED)))
145 }
146
147 fn try_borrow(&self) -> Result<(), PyBorrowError> {
148 self.0.increment()
149 }
150
151 fn release_borrow(&self) {
152 self.0.decrement();
153 }
154
155 fn try_borrow_mut(&self) -> Result<(), PyBorrowMutError> {
156 let flag = &self.0;
157 match flag.0.compare_exchange(
158 BorrowFlag::UNUSED,
161 BorrowFlag::HAS_MUTABLE_BORROW,
162 Ordering::AcqRel,
165 Ordering::Relaxed,
168 ) {
169 Ok(..) => Ok(()),
170 Err(..) => Err(PyBorrowMutError { _private: () }),
171 }
172 }
173
174 fn release_borrow_mut(&self) {
175 self.0 .0.store(BorrowFlag::UNUSED, Ordering::Release)
176 }
177}
178
179pub trait GetBorrowChecker<T: PyClassImpl> {
180 fn borrow_checker(
181 class_object: &PyClassObject<T>,
182 ) -> &<T::PyClassMutability as PyClassMutability>::Checker;
183}
184
185impl<T: PyClassImpl<PyClassMutability = Self>> GetBorrowChecker<T> for MutableClass {
186 fn borrow_checker(class_object: &PyClassObject<T>) -> &BorrowChecker {
187 &class_object.contents.borrow_checker
188 }
189}
190
191impl<T: PyClassImpl<PyClassMutability = Self>> GetBorrowChecker<T> for ImmutableClass {
192 fn borrow_checker(class_object: &PyClassObject<T>) -> &EmptySlot {
193 &class_object.contents.borrow_checker
194 }
195}
196
197impl<T: PyClassImpl<PyClassMutability = Self>, M: PyClassMutability> GetBorrowChecker<T>
198 for ExtendsMutableAncestor<M>
199where
200 T::BaseType: PyClassImpl + PyClassBaseType<LayoutAsBase = PyClassObject<T::BaseType>>,
201 <T::BaseType as PyClassImpl>::PyClassMutability: PyClassMutability<Checker = BorrowChecker>,
202{
203 fn borrow_checker(class_object: &PyClassObject<T>) -> &BorrowChecker {
204 <<T::BaseType as PyClassImpl>::PyClassMutability as GetBorrowChecker<T::BaseType>>::borrow_checker(&class_object.ob_base)
205 }
206}
207
208#[doc(hidden)]
210#[repr(C)]
211pub struct PyClassObjectBase<T> {
212 ob_base: T,
213}
214
215unsafe impl<T, U> PyLayout<T> for PyClassObjectBase<U> where U: PySizedLayout<T> {}
216
217#[doc(hidden)]
218pub trait PyClassObjectLayout<T>: PyLayout<T> {
219 fn ensure_threadsafe(&self);
220 fn check_threadsafe(&self) -> Result<(), PyBorrowError>;
221 unsafe fn tp_dealloc(py: Python<'_>, slf: *mut ffi::PyObject);
226}
227
// Terminal layer of the per-layer recursion: the native base has no thread checker and
// owns no Rust value.
impl<T, U> PyClassObjectLayout<T> for PyClassObjectBase<U>
where
    U: PySizedLayout<T>,
    T: PyTypeInfo,
{
    // No thread checker at the native layer, so there is nothing to enforce.
    fn ensure_threadsafe(&self) {}
    // No thread checker at the native layer, so the check always passes.
    fn check_threadsafe(&self) -> Result<(), PyBorrowError> {
        Ok(())
    }
    /// Final step of deallocating `slf`, reached after all pyclass layers above have run.
    ///
    /// - If `T`'s type object is `object` (`PyBaseObject_Type`), the memory is released
    ///   directly. The `tp_free` slot of the object's *actual* (runtime) type is used.
    /// - Otherwise (non-limited API only), the native base type's `tp_dealloc` is called.
    ///   If the base has no `tp_dealloc`, the actual type's `tp_free` is used instead.
    /// - Under `Py_LIMITED_API`, reaching the non-`object` path panics via `unreachable!`.
    ///
    /// # Safety
    /// `slf` must point to a live object whose layout begins with this base layout.
    /// NOTE(review): presumably `slf` is dangling after this call; confirm with callers.
    unsafe fn tp_dealloc(py: Python<'_>, slf: *mut ffi::PyObject) {
        unsafe {
            // Type object of the native base `T` (not the object's runtime type).
            let type_obj = T::type_object(py);
            let type_ptr = type_obj.as_type_ptr();
            // Runtime type of `slf`, which may be a subclass of `T`.
            let actual_type = PyType::from_borrowed_type_ptr(py, ffi::Py_TYPE(slf));

            // Base is plain `object`: there is no native state to tear down, so free the memory.
            if std::ptr::eq(type_ptr, std::ptr::addr_of!(ffi::PyBaseObject_Type)) {
                let tp_free = actual_type
                    .get_slot(TP_FREE)
                    .expect("PyBaseObject_Type should have tp_free");
                return tp_free(slf.cast());
            }

            // Other native bases: delegate to the base type's own deallocator.
            #[cfg(not(Py_LIMITED_API))]
            {
                if let Some(dealloc) = (*type_ptr).tp_dealloc {
                    // Before Python 3.11 (and never on PyPy), exception subclasses are
                    // re-tracked by the GC before calling the base dealloc.
                    // NOTE(review): presumably because the old exception dealloc untracks
                    // unconditionally; confirm against CPython < 3.11.
                    #[cfg(not(any(Py_3_11, PyPy)))]
                    if ffi::PyType_FastSubclass(type_ptr, ffi::Py_TPFLAGS_BASE_EXC_SUBCLASS) == 1 {
                        ffi::PyObject_GC_Track(slf.cast());
                    }
                    dealloc(slf);
                } else {
                    // No base deallocator: just free the memory using the runtime type's slot.
                    (*actual_type.as_type_ptr())
                        .tp_free
                        .expect("type missing tp_free")(slf.cast());
                }
            }

            #[cfg(Py_LIMITED_API)]
            unreachable!("subclassing native types is not possible with the `abi3` feature");
        }
    }
}
278
/// In-memory layout of a Python object whose type is the `#[pyclass]` `T`.
///
/// The base type's layout comes first and `T`'s own contents follow. `#[repr(C)]` keeps this
/// order, so a pointer to the whole object can also be handed to the base layer's
/// `tp_dealloc`, and `ob_base` can be viewed as the base class's `PyClassObject`.
#[repr(C)]
pub struct PyClassObject<T: PyClassImpl> {
    // Layout of `T::BaseType` (a `PyClassObject` of the base, or a native wrapper).
    pub(crate) ob_base: <T::BaseType as PyClassBaseType>::LayoutAsBase,
    // State that belongs to `T` itself.
    pub(crate) contents: PyClassObjectContents<T>,
}

/// Per-class state appended after the base layout in a [`PyClassObject`].
#[repr(C)]
pub(crate) struct PyClassObjectContents<T: PyClassImpl> {
    // The Rust value. It is `ManuallyDrop` because `tp_dealloc` drops it explicitly, and only
    // when the thread checker allows dropping on the current thread.
    pub(crate) value: ManuallyDrop<UnsafeCell<T>>,
    // This class's borrow-flag storage. It is `EmptySlot` when the real flag lives in an
    // ancestor or when no flag is needed.
    pub(crate) borrow_checker: <T::PyClassMutability as PyClassMutability>::Storage,
    pub(crate) thread_checker: T::ThreadChecker,
    // Instance `__dict__` slot; its byte offset is reported by `dict_offset()`.
    pub(crate) dict: T::Dict,
    // Weak-reference list slot; its byte offset is reported by `weaklist_offset()`.
    pub(crate) weakref: T::WeakRef,
}
294
295impl<T: PyClassImpl> PyClassObject<T> {
296 pub(crate) fn get_ptr(&self) -> *mut T {
297 self.contents.value.get()
298 }
299
300 pub(crate) fn dict_offset() -> ffi::Py_ssize_t {
302 use memoffset::offset_of;
303
304 let offset =
305 offset_of!(PyClassObject<T>, contents) + offset_of!(PyClassObjectContents<T>, dict);
306
307 #[allow(clippy::useless_conversion)]
309 offset.try_into().expect("offset should fit in Py_ssize_t")
310 }
311
312 pub(crate) fn weaklist_offset() -> ffi::Py_ssize_t {
314 use memoffset::offset_of;
315
316 let offset =
317 offset_of!(PyClassObject<T>, contents) + offset_of!(PyClassObjectContents<T>, weakref);
318
319 #[allow(clippy::useless_conversion)]
321 offset.try_into().expect("offset should fit in Py_ssize_t")
322 }
323}
324
325impl<T: PyClassImpl> PyClassObject<T> {
326 pub(crate) fn borrow_checker(&self) -> &<T::PyClassMutability as PyClassMutability>::Checker {
327 T::PyClassMutability::borrow_checker(self)
328 }
329}
330
331unsafe impl<T: PyClassImpl> PyLayout<T> for PyClassObject<T> {}
332impl<T: PyClass> PySizedLayout<T> for PyClassObject<T> {}
333
334impl<T: PyClassImpl> PyClassObjectLayout<T> for PyClassObject<T>
335where
336 <T::BaseType as PyClassBaseType>::LayoutAsBase: PyClassObjectLayout<T::BaseType>,
337{
338 fn ensure_threadsafe(&self) {
339 self.contents.thread_checker.ensure();
340 self.ob_base.ensure_threadsafe();
341 }
342 fn check_threadsafe(&self) -> Result<(), PyBorrowError> {
343 if !self.contents.thread_checker.check() {
344 return Err(PyBorrowError { _private: () });
345 }
346 self.ob_base.check_threadsafe()
347 }
348 unsafe fn tp_dealloc(py: Python<'_>, slf: *mut ffi::PyObject) {
349 let class_object = unsafe { &mut *(slf.cast::<PyClassObject<T>>()) };
351 if class_object.contents.thread_checker.can_drop(py) {
352 unsafe { ManuallyDrop::drop(&mut class_object.contents.value) };
353 }
354 class_object.contents.dict.clear_dict(py);
355 unsafe {
356 class_object.contents.weakref.clear_weakrefs(slf, py);
357 <T::BaseType as PyClassBaseType>::LayoutAsBase::tp_dealloc(py, slf)
358 }
359 }
360}
361
#[cfg(test)]
#[cfg(feature = "macros")]
mod tests {
    use super::*;

    use crate::prelude::*;
    use crate::pyclass::boolean_struct::{False, True};

    // Class hierarchy rooted at a mutable base.
    #[pyclass(crate = "crate", subclass)]
    struct MutableBase;

    #[pyclass(crate = "crate", extends = MutableBase, subclass)]
    struct MutableChildOfMutableBase;

    #[pyclass(crate = "crate", extends = MutableBase, frozen, subclass)]
    struct ImmutableChildOfMutableBase;

    #[pyclass(crate = "crate", extends = MutableChildOfMutableBase)]
    struct MutableChildOfMutableChildOfMutableBase;

    #[pyclass(crate = "crate", extends = ImmutableChildOfMutableBase)]
    struct MutableChildOfImmutableChildOfMutableBase;

    #[pyclass(crate = "crate", extends = MutableChildOfMutableBase, frozen)]
    struct ImmutableChildOfMutableChildOfMutableBase;

    #[pyclass(crate = "crate", extends = ImmutableChildOfMutableBase, frozen)]
    struct ImmutableChildOfImmutableChildOfMutableBase;

    // Class hierarchy rooted at a frozen (immutable) base.
    #[pyclass(crate = "crate", frozen, subclass)]
    struct ImmutableBase;

    #[pyclass(crate = "crate", extends = ImmutableBase, subclass)]
    struct MutableChildOfImmutableBase;

    #[pyclass(crate = "crate", extends = ImmutableBase, frozen, subclass)]
    struct ImmutableChildOfImmutableBase;

    #[pyclass(crate = "crate", extends = MutableChildOfImmutableBase)]
    struct MutableChildOfMutableChildOfImmutableBase;

    #[pyclass(crate = "crate", extends = ImmutableChildOfImmutableBase)]
    struct MutableChildOfImmutableChildOfImmutableBase;

    #[pyclass(crate = "crate", extends = MutableChildOfImmutableBase, frozen)]
    struct ImmutableChildOfMutableChildOfImmutableBase;

    #[pyclass(crate = "crate", extends = ImmutableChildOfImmutableBase, frozen)]
    struct ImmutableChildOfImmutableChildOfImmutableBase;

    // Compile-time assertions: each helper only type-checks for classes with the expected
    // `Frozen` flag and mutability marker.
    fn assert_mutable<T: PyClass<Frozen = False, PyClassMutability = MutableClass>>() {}
    fn assert_immutable<T: PyClass<Frozen = True, PyClassMutability = ImmutableClass>>() {}
    fn assert_mutable_with_mutable_ancestor<
        T: PyClass<Frozen = False, PyClassMutability = ExtendsMutableAncestor<MutableClass>>,
    >() {
    }
    fn assert_immutable_with_mutable_ancestor<
        T: PyClass<Frozen = True, PyClassMutability = ExtendsMutableAncestor<ImmutableClass>>,
    >() {
    }

    #[test]
    fn test_inherited_mutability() {
        assert_mutable::<MutableBase>();

        assert_mutable_with_mutable_ancestor::<MutableChildOfMutableBase>();
        assert_immutable_with_mutable_ancestor::<ImmutableChildOfMutableBase>();

        assert_mutable_with_mutable_ancestor::<MutableChildOfMutableChildOfMutableBase>();
        assert_mutable_with_mutable_ancestor::<MutableChildOfImmutableChildOfMutableBase>();
        assert_immutable_with_mutable_ancestor::<ImmutableChildOfMutableChildOfMutableBase>();
        assert_immutable_with_mutable_ancestor::<ImmutableChildOfImmutableChildOfMutableBase>();

        assert_immutable::<ImmutableBase>();
        assert_immutable::<ImmutableChildOfImmutableBase>();
        assert_immutable::<ImmutableChildOfImmutableChildOfImmutableBase>();

        assert_mutable::<MutableChildOfImmutableBase>();
        assert_mutable::<MutableChildOfImmutableChildOfImmutableBase>();

        assert_mutable_with_mutable_ancestor::<MutableChildOfMutableChildOfImmutableBase>();
        assert_immutable_with_mutable_ancestor::<ImmutableChildOfMutableChildOfImmutableBase>();
    }

    #[test]
    fn test_mutable_borrow_prevents_further_borrows() {
        Python::attach(|py| {
            let mmm = Py::new(
                py,
                PyClassInitializer::from(MutableBase)
                    .add_subclass(MutableChildOfMutableBase)
                    .add_subclass(MutableChildOfMutableChildOfMutableBase),
            )
            .unwrap();

            let mmm_bound: &Bound<'_, MutableChildOfMutableChildOfMutableBase> = mmm.bind(py);

            let mmm_refmut = mmm_bound.borrow_mut();

            // While the exclusive borrow is alive, every borrow at every level must fail.
            assert!(mmm_bound
                .extract::<PyRef<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_err());
            assert!(mmm_bound
                .extract::<PyRef<'_, MutableChildOfMutableBase>>()
                .is_err());
            assert!(mmm_bound.extract::<PyRef<'_, MutableBase>>().is_err());
            assert!(mmm_bound
                .extract::<PyRefMut<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_err());
            assert!(mmm_bound
                .extract::<PyRefMut<'_, MutableChildOfMutableBase>>()
                .is_err());
            assert!(mmm_bound.extract::<PyRefMut<'_, MutableBase>>().is_err());

            drop(mmm_refmut);

            // Once it is released, every borrow succeeds again.
            assert!(mmm_bound
                .extract::<PyRef<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_bound
                .extract::<PyRef<'_, MutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_bound.extract::<PyRef<'_, MutableBase>>().is_ok());
            assert!(mmm_bound
                .extract::<PyRefMut<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_bound
                .extract::<PyRefMut<'_, MutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_bound.extract::<PyRefMut<'_, MutableBase>>().is_ok());
        })
    }

    #[test]
    fn test_immutable_borrows_prevent_mutable_borrows() {
        Python::attach(|py| {
            let mmm = Py::new(
                py,
                PyClassInitializer::from(MutableBase)
                    .add_subclass(MutableChildOfMutableBase)
                    .add_subclass(MutableChildOfMutableChildOfMutableBase),
            )
            .unwrap();

            let mmm_bound: &Bound<'_, MutableChildOfMutableChildOfMutableBase> = mmm.bind(py);

            // A shared borrow (named accordingly: this is a `PyRef`, not a `PyRefMut`).
            let mmm_ref = mmm_bound.borrow();

            // Further shared borrows at any level are fine...
            assert!(mmm_bound
                .extract::<PyRef<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_bound
                .extract::<PyRef<'_, MutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_bound.extract::<PyRef<'_, MutableBase>>().is_ok());

            // ...but exclusive borrows at any level are refused.
            assert!(mmm_bound
                .extract::<PyRefMut<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_err());
            assert!(mmm_bound
                .extract::<PyRefMut<'_, MutableChildOfMutableBase>>()
                .is_err());
            assert!(mmm_bound.extract::<PyRefMut<'_, MutableBase>>().is_err());

            drop(mmm_ref);

            assert!(mmm_bound
                .extract::<PyRefMut<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_bound
                .extract::<PyRefMut<'_, MutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_bound.extract::<PyRefMut<'_, MutableBase>>().is_ok());
        })
    }

    #[test]
    fn test_borrow_checker_state_transitions() {
        // Exercises the flag state machine directly, without any Python object.
        let checker = BorrowChecker::new();

        assert!(checker.try_borrow().is_ok());
        assert!(checker.try_borrow().is_ok());
        assert!(checker.try_borrow_mut().is_err());

        checker.release_borrow();
        assert!(checker.try_borrow_mut().is_err());
        checker.release_borrow();

        assert!(checker.try_borrow_mut().is_ok());
        assert!(checker.try_borrow().is_err());
        assert!(checker.try_borrow_mut().is_err());
        checker.release_borrow_mut();

        assert!(checker.try_borrow().is_ok());
        checker.release_borrow();
    }

    #[test]
    #[cfg(not(target_arch = "wasm32"))]
    fn test_thread_safety() {
        #[crate::pyclass(crate = "crate")]
        struct MyClass {
            x: u64,
        }

        Python::attach(|py| {
            let inst = Py::new(py, MyClass { x: 0 }).unwrap();

            // Each thread counts the mutations it actually performed. The total must equal the
            // final value, i.e. no increment was lost to a racing mutable borrow.
            let total_modifications = py.detach(|| {
                std::thread::scope(|s| {
                    let threads = (0..10)
                        .map(|_| {
                            s.spawn(|| {
                                Python::attach(|py| {
                                    let mut local_modifications = 0;
                                    for _ in 0..100 {
                                        if let Ok(mut i) = inst.try_borrow_mut(py) {
                                            i.x += 1;
                                            local_modifications += 1;
                                        }
                                    }
                                    local_modifications
                                })
                            })
                        })
                        .collect::<Vec<_>>();

                    threads.into_iter().map(|t| t.join().unwrap()).sum::<u64>()
                })
            });

            assert_eq!(total_modifications, inst.borrow(py).x);
        });
    }

    #[test]
    #[cfg(not(target_arch = "wasm32"))]
    fn test_thread_safety_2() {
        struct SyncUnsafeCell<T>(UnsafeCell<T>);
        unsafe impl<T> Sync for SyncUnsafeCell<T> {}

        impl<T> SyncUnsafeCell<T> {
            fn get(&self) -> *mut T {
                self.0.get()
            }
        }

        let data = SyncUnsafeCell(UnsafeCell::new(0));
        let data2 = SyncUnsafeCell(UnsafeCell::new(0));
        let borrow_checker = BorrowChecker(BorrowFlag(AtomicUsize::new(BorrowFlag::UNUSED)));

        std::thread::scope(|s| {
            // Writer: mutates both cells only while holding the exclusive borrow.
            s.spawn(|| {
                for _ in 0..1_000_000 {
                    if borrow_checker.try_borrow_mut().is_ok() {
                        unsafe { *data.get() += 1 };
                        unsafe { *data2.get() += 1 };
                        borrow_checker.release_borrow_mut();
                    }
                }
            });

            // Reader: under a shared borrow it must never observe a half-finished update.
            s.spawn(|| {
                for _ in 0..1_000_000 {
                    if borrow_checker.try_borrow().is_ok() {
                        assert_eq!(unsafe { *data.get() }, unsafe { *data2.get() });
                        borrow_checker.release_borrow();
                    }
                }
            });
        });
    }
}