1#![allow(missing_docs)]
2use std::cell::UnsafeCell;
5use std::marker::PhantomData;
6use std::mem::ManuallyDrop;
7use std::sync::atomic::{AtomicUsize, Ordering};
8
9use crate::impl_::pyclass::{
10 PyClassBaseType, PyClassDict, PyClassImpl, PyClassThreadChecker, PyClassWeakRef,
11};
12use crate::internal::get_slot::TP_FREE;
13use crate::type_object::{PyLayout, PySizedLayout};
14use crate::types::{PyType, PyTypeMethods};
15use crate::{ffi, PyClass, PyTypeInfo, Python};
16
17use super::{PyBorrowError, PyBorrowMutError};
18
19pub trait PyClassMutability {
20 type Storage: PyClassBorrowChecker;
23 type Checker: PyClassBorrowChecker;
27 type ImmutableChild: PyClassMutability;
28 type MutableChild: PyClassMutability;
29}
30
/// Mutability marker for a frozen class with no mutable ancestor: it stores no
/// borrow flag and shared borrows always succeed.
pub struct ImmutableClass(());
/// Mutability marker for a non-frozen class with no mutable ancestor: it owns
/// the `BorrowChecker` that it and its subclasses consult.
pub struct MutableClass(());
/// Mutability marker for a class that has a mutable ancestor. `M` is
/// `MutableClass` or `ImmutableClass` depending on whether this class itself is
/// frozen; the borrow flag lives in the ancestor, so nothing extra is stored here.
pub struct ExtendsMutableAncestor<M: PyClassMutability>(PhantomData<M>);
34
35impl PyClassMutability for ImmutableClass {
36 type Storage = EmptySlot;
37 type Checker = EmptySlot;
38 type ImmutableChild = ImmutableClass;
39 type MutableChild = MutableClass;
40}
41
42impl PyClassMutability for MutableClass {
43 type Storage = BorrowChecker;
44 type Checker = BorrowChecker;
45 type ImmutableChild = ExtendsMutableAncestor<ImmutableClass>;
46 type MutableChild = ExtendsMutableAncestor<MutableClass>;
47}
48
49impl<M: PyClassMutability> PyClassMutability for ExtendsMutableAncestor<M> {
50 type Storage = EmptySlot;
51 type Checker = BorrowChecker;
52 type ImmutableChild = ExtendsMutableAncestor<ImmutableClass>;
53 type MutableChild = ExtendsMutableAncestor<MutableClass>;
54}
55
/// Atomic borrow state. The value is `BorrowFlag::UNUSED` (0) when not
/// borrowed and `BorrowFlag::HAS_MUTABLE_BORROW` while mutably borrowed.
/// Any other value is the number of live shared borrows.
#[derive(Debug)]
struct BorrowFlag(AtomicUsize);
58
59impl BorrowFlag {
60 pub(crate) const UNUSED: usize = 0;
61 const HAS_MUTABLE_BORROW: usize = usize::MAX;
62 fn increment(&self) -> Result<(), PyBorrowError> {
63 let mut value = self.0.load(Ordering::Relaxed);
64 loop {
65 if value == BorrowFlag::HAS_MUTABLE_BORROW {
66 return Err(PyBorrowError { _private: () });
67 }
68 match self.0.compare_exchange(
69 value,
72 value + 1,
73 Ordering::Relaxed,
74 Ordering::Relaxed,
75 ) {
76 Ok(..) => {
77 std::sync::atomic::fence(Ordering::Acquire);
80 break Ok(());
81 }
82 Err(changed_value) => {
83 value = changed_value;
85 }
86 }
87 }
88 }
89 fn decrement(&self) {
90 self.0.fetch_sub(1, Ordering::Relaxed);
94 }
95}
96
/// Zero-sized placeholder used where no borrow state is stored or checked.
pub struct EmptySlot(());
/// Borrow checker backed by an atomic `BorrowFlag`.
pub struct BorrowChecker(BorrowFlag);
99
100pub trait PyClassBorrowChecker {
101 fn new() -> Self;
103
104 fn try_borrow(&self) -> Result<(), PyBorrowError>;
106
107 fn release_borrow(&self);
109 fn try_borrow_mut(&self) -> Result<(), PyBorrowMutError>;
111 fn release_borrow_mut(&self);
113}
114
115impl PyClassBorrowChecker for EmptySlot {
116 #[inline]
117 fn new() -> Self {
118 EmptySlot(())
119 }
120
121 #[inline]
122 fn try_borrow(&self) -> Result<(), PyBorrowError> {
123 Ok(())
124 }
125
126 #[inline]
127 fn release_borrow(&self) {}
128
129 #[inline]
130 fn try_borrow_mut(&self) -> Result<(), PyBorrowMutError> {
131 unreachable!()
132 }
133
134 #[inline]
135 fn release_borrow_mut(&self) {
136 unreachable!()
137 }
138}
139
140impl PyClassBorrowChecker for BorrowChecker {
141 #[inline]
142 fn new() -> Self {
143 Self(BorrowFlag(AtomicUsize::new(BorrowFlag::UNUSED)))
144 }
145
146 fn try_borrow(&self) -> Result<(), PyBorrowError> {
147 self.0.increment()
148 }
149
150 fn release_borrow(&self) {
151 self.0.decrement();
152 }
153
154 fn try_borrow_mut(&self) -> Result<(), PyBorrowMutError> {
155 let flag = &self.0;
156 match flag.0.compare_exchange(
157 BorrowFlag::UNUSED,
160 BorrowFlag::HAS_MUTABLE_BORROW,
161 Ordering::AcqRel,
164 Ordering::Relaxed,
167 ) {
168 Ok(..) => Ok(()),
169 Err(..) => Err(PyBorrowMutError { _private: () }),
170 }
171 }
172
173 fn release_borrow_mut(&self) {
174 self.0 .0.store(BorrowFlag::UNUSED, Ordering::Release)
175 }
176}
177
/// Locates the borrow checker guarding a `PyClassObject<T>`. It is stored
/// either in the object's own contents or in an ancestor layer's contents.
pub trait GetBorrowChecker<T: PyClassImpl> {
    /// Returns a reference to the checker for `class_object`.
    fn borrow_checker(
        class_object: &PyClassObject<T>,
    ) -> &<T::PyClassMutability as PyClassMutability>::Checker;
}
183
184impl<T: PyClassImpl<PyClassMutability = Self>> GetBorrowChecker<T> for MutableClass {
185 fn borrow_checker(class_object: &PyClassObject<T>) -> &BorrowChecker {
186 &class_object.contents.borrow_checker
187 }
188}
189
190impl<T: PyClassImpl<PyClassMutability = Self>> GetBorrowChecker<T> for ImmutableClass {
191 fn borrow_checker(class_object: &PyClassObject<T>) -> &EmptySlot {
192 &class_object.contents.borrow_checker
193 }
194}
195
196impl<T: PyClassImpl<PyClassMutability = Self>, M: PyClassMutability> GetBorrowChecker<T>
197 for ExtendsMutableAncestor<M>
198where
199 T::BaseType: PyClassImpl + PyClassBaseType<LayoutAsBase = PyClassObject<T::BaseType>>,
200 <T::BaseType as PyClassImpl>::PyClassMutability: PyClassMutability<Checker = BorrowChecker>,
201{
202 fn borrow_checker(class_object: &PyClassObject<T>) -> &BorrowChecker {
203 <<T::BaseType as PyClassImpl>::PyClassMutability as GetBorrowChecker<T::BaseType>>::borrow_checker(&class_object.ob_base)
204 }
205}
206
/// Root layer of a `#[pyclass]` object layout: wraps the layout `T` of the
/// native (non-`#[pyclass]`) base type.
#[doc(hidden)]
#[repr(C)]
pub struct PyClassObjectBase<T> {
    ob_base: T,
}
213
// SAFETY: `PyClassObjectBase<U>` is `repr(C)` with `U` as its only field, so it
// has exactly the layout of `U`, and `U` is bounded by `PySizedLayout<T>`.
unsafe impl<T, U> PyLayout<T> for PyClassObjectBase<U> where U: PySizedLayout<T> {}
215
216#[doc(hidden)]
217pub trait PyClassObjectLayout<T>: PyLayout<T> {
218 fn ensure_threadsafe(&self);
219 fn check_threadsafe(&self) -> Result<(), PyBorrowError>;
220 unsafe fn tp_dealloc(py: Python<'_>, slf: *mut ffi::PyObject);
225}
226
impl<T, U> PyClassObjectLayout<T> for PyClassObjectBase<U>
where
    U: PySizedLayout<T>,
    T: PyTypeInfo,
{
    /// The native base layer carries no thread checker, so this is a no-op.
    fn ensure_threadsafe(&self) {}
    /// The native base layer carries no thread checker, so this always succeeds.
    fn check_threadsafe(&self) -> Result<(), PyBorrowError> {
        Ok(())
    }
    /// Final step of deallocating a `#[pyclass]` object: gives the memory back
    /// via the native base type `T`.
    ///
    /// # Safety
    /// `slf` must point to an object being deallocated whose native base is `T`.
    /// It must not be used after this call returns.
    unsafe fn tp_dealloc(py: Python<'_>, slf: *mut ffi::PyObject) {
        // Type object of the native base `T`.
        let type_obj = T::type_object(py);
        let type_ptr = type_obj.as_type_ptr();
        // The concrete (most-derived) type of the object being freed.
        let actual_type = PyType::from_borrowed_type_ptr(py, ffi::Py_TYPE(slf));

        // Base is plain `object`: release the memory directly with the concrete
        // type's `tp_free` slot.
        if type_ptr == std::ptr::addr_of_mut!(ffi::PyBaseObject_Type) {
            let tp_free = actual_type
                .get_slot(TP_FREE)
                .expect("PyBaseObject_Type should have tp_free");
            return tp_free(slf.cast());
        }

        // Any other native base: use the base type's own `tp_dealloc` if it has
        // one. Otherwise fall back to the concrete type's `tp_free`.
        #[cfg(not(Py_LIMITED_API))]
        {
            if let Some(dealloc) = (*type_ptr).tp_dealloc {
                // NOTE(review): before CPython 3.11 (and never on PyPy), exception
                // subclasses are re-tracked by the GC first. Presumably the base
                // exception dealloc untracks unconditionally there; confirm
                // against CPython's `BaseException_dealloc`.
                #[cfg(not(any(Py_3_11, PyPy)))]
                if ffi::PyType_FastSubclass(type_ptr, ffi::Py_TPFLAGS_BASE_EXC_SUBCLASS) == 1 {
                    ffi::PyObject_GC_Track(slf.cast());
                }
                dealloc(slf);
            } else {
                (*actual_type.as_type_ptr())
                    .tp_free
                    .expect("type missing tp_free")(slf.cast());
            }
        }

        // Under `abi3`, only the `object` path above is expected to be taken.
        #[cfg(Py_LIMITED_API)]
        unreachable!("subclassing native types is not possible with the `abi3` feature");
    }
}
275
/// In-memory layout of a `#[pyclass]` instance.
#[repr(C)]
pub struct PyClassObject<T: PyClassImpl> {
    /// Layout of the base class. `repr(C)` keeps it first, so the same object
    /// pointer can be handed to the base layer (as `tp_dealloc` does).
    pub(crate) ob_base: <T::BaseType as PyClassBaseType>::LayoutAsBase,
    /// State specific to `T`, placed after the base layout.
    pub(crate) contents: PyClassObjectContents<T>,
}
282
/// Per-class state stored after the base layout inside a `PyClassObject<T>`.
#[repr(C)]
pub(crate) struct PyClassObjectContents<T: PyClassImpl> {
    /// The Rust value. It is dropped manually in `tp_dealloc`, and only when the
    /// thread checker's `can_drop` allows it.
    pub(crate) value: ManuallyDrop<UnsafeCell<T>>,
    /// Borrow state stored by this layer (an `EmptySlot` when none is needed).
    pub(crate) borrow_checker: <T::PyClassMutability as PyClassMutability>::Storage,
    /// Thread-affinity checks (`ensure` / `check` / `can_drop`).
    pub(crate) thread_checker: T::ThreadChecker,
    /// Instance dict slot; its offset is reported by `dict_offset`.
    pub(crate) dict: T::Dict,
    /// Weak-reference slot; its offset is reported by `weaklist_offset`.
    pub(crate) weakref: T::WeakRef,
}
291
292impl<T: PyClassImpl> PyClassObject<T> {
293 pub(crate) fn get_ptr(&self) -> *mut T {
294 self.contents.value.get()
295 }
296
297 pub(crate) fn dict_offset() -> ffi::Py_ssize_t {
299 use memoffset::offset_of;
300
301 let offset =
302 offset_of!(PyClassObject<T>, contents) + offset_of!(PyClassObjectContents<T>, dict);
303
304 #[allow(clippy::useless_conversion)]
306 offset.try_into().expect("offset should fit in Py_ssize_t")
307 }
308
309 pub(crate) fn weaklist_offset() -> ffi::Py_ssize_t {
311 use memoffset::offset_of;
312
313 let offset =
314 offset_of!(PyClassObject<T>, contents) + offset_of!(PyClassObjectContents<T>, weakref);
315
316 #[allow(clippy::useless_conversion)]
318 offset.try_into().expect("offset should fit in Py_ssize_t")
319 }
320}
321
322impl<T: PyClassImpl> PyClassObject<T> {
323 pub(crate) fn borrow_checker(&self) -> &<T::PyClassMutability as PyClassMutability>::Checker {
324 T::PyClassMutability::borrow_checker(self)
325 }
326}
327
// SAFETY (review note): relies on `PyClassObject<T>` being `repr(C)` with the
// base layout as its first field. Confirm against `PyLayout`'s documented contract.
unsafe impl<T: PyClassImpl> PyLayout<T> for PyClassObject<T> {}
impl<T: PyClass> PySizedLayout<T> for PyClassObject<T> {}
330
impl<T: PyClassImpl> PyClassObjectLayout<T> for PyClassObject<T>
where
    <T::BaseType as PyClassBaseType>::LayoutAsBase: PyClassObjectLayout<T::BaseType>,
{
    /// Runs this layer's thread check, then recurses into the base layer.
    fn ensure_threadsafe(&self) {
        self.contents.thread_checker.ensure();
        self.ob_base.ensure_threadsafe();
    }
    /// Returns `PyBorrowError` if this layer's thread checker rejects the
    /// current thread. Otherwise defers to the base layer's check.
    fn check_threadsafe(&self) -> Result<(), PyBorrowError> {
        if !self.contents.thread_checker.check() {
            return Err(PyBorrowError { _private: () });
        }
        self.ob_base.check_threadsafe()
    }
    /// Tears down this layer's contents, then delegates to the base layer's
    /// `tp_dealloc`. The chain ends at the native base, which frees the memory.
    ///
    /// # Safety
    /// `slf` must point to a `PyClassObject<T>` that is being deallocated and is
    /// not accessed again afterwards.
    unsafe fn tp_dealloc(py: Python<'_>, slf: *mut ffi::PyObject) {
        let class_object = &mut *(slf.cast::<PyClassObject<T>>());
        // If the thread checker refuses, the Rust value's destructor never runs
        // (the `ManuallyDrop` is simply not dropped).
        if class_object.contents.thread_checker.can_drop(py) {
            ManuallyDrop::drop(&mut class_object.contents.value);
        }
        class_object.contents.dict.clear_dict(py);
        class_object.contents.weakref.clear_weakrefs(slf, py);
        <T::BaseType as PyClassBaseType>::LayoutAsBase::tp_dealloc(py, slf)
    }
}
356
#[cfg(test)]
#[cfg(feature = "macros")]
mod tests {
    use super::*;

    use crate::prelude::*;
    use crate::pyclass::boolean_struct::{False, True};

    // Class hierarchy used below: every combination of frozen / non-frozen
    // subclasses, up to three levels deep, under a mutable and an immutable root.

    #[pyclass(crate = "crate", subclass)]
    struct MutableBase;

    #[pyclass(crate = "crate", extends = MutableBase, subclass)]
    struct MutableChildOfMutableBase;

    #[pyclass(crate = "crate", extends = MutableBase, frozen, subclass)]
    struct ImmutableChildOfMutableBase;

    #[pyclass(crate = "crate", extends = MutableChildOfMutableBase)]
    struct MutableChildOfMutableChildOfMutableBase;

    #[pyclass(crate = "crate", extends = ImmutableChildOfMutableBase)]
    struct MutableChildOfImmutableChildOfMutableBase;

    #[pyclass(crate = "crate", extends = MutableChildOfMutableBase, frozen)]
    struct ImmutableChildOfMutableChildOfMutableBase;

    #[pyclass(crate = "crate", extends = ImmutableChildOfMutableBase, frozen)]
    struct ImmutableChildOfImmutableChildOfMutableBase;

    #[pyclass(crate = "crate", frozen, subclass)]
    struct ImmutableBase;

    #[pyclass(crate = "crate", extends = ImmutableBase, subclass)]
    struct MutableChildOfImmutableBase;

    #[pyclass(crate = "crate", extends = ImmutableBase, frozen, subclass)]
    struct ImmutableChildOfImmutableBase;

    #[pyclass(crate = "crate", extends = MutableChildOfImmutableBase)]
    struct MutableChildOfMutableChildOfImmutableBase;

    #[pyclass(crate = "crate", extends = ImmutableChildOfImmutableBase)]
    struct MutableChildOfImmutableChildOfImmutableBase;

    #[pyclass(crate = "crate", extends = MutableChildOfImmutableBase, frozen)]
    struct ImmutableChildOfMutableChildOfImmutableBase;

    #[pyclass(crate = "crate", extends = ImmutableChildOfImmutableBase, frozen)]
    struct ImmutableChildOfImmutableChildOfImmutableBase;

    // Compile-time assertions: each helper only type-checks when `T` has the
    // stated `Frozen` flag and `PyClassMutability` marker.
    fn assert_mutable<T: PyClass<Frozen = False, PyClassMutability = MutableClass>>() {}
    fn assert_immutable<T: PyClass<Frozen = True, PyClassMutability = ImmutableClass>>() {}
    fn assert_mutable_with_mutable_ancestor<
        T: PyClass<Frozen = False, PyClassMutability = ExtendsMutableAncestor<MutableClass>>,
    >() {
    }
    fn assert_immutable_with_mutable_ancestor<
        T: PyClass<Frozen = True, PyClassMutability = ExtendsMutableAncestor<ImmutableClass>>,
    >() {
    }

    #[test]
    fn test_inherited_mutability() {
        // Under a mutable root, every descendant has a mutable ancestor.
        assert_mutable::<MutableBase>();

        assert_mutable_with_mutable_ancestor::<MutableChildOfMutableBase>();
        assert_immutable_with_mutable_ancestor::<ImmutableChildOfMutableBase>();

        assert_mutable_with_mutable_ancestor::<MutableChildOfMutableChildOfMutableBase>();
        assert_mutable_with_mutable_ancestor::<MutableChildOfImmutableChildOfMutableBase>();
        assert_immutable_with_mutable_ancestor::<ImmutableChildOfMutableChildOfMutableBase>();
        assert_immutable_with_mutable_ancestor::<ImmutableChildOfImmutableChildOfMutableBase>();

        // Under a frozen root, a class gains a mutable ancestor only once a
        // non-frozen class appears above it.
        assert_immutable::<ImmutableBase>();
        assert_immutable::<ImmutableChildOfImmutableBase>();
        assert_immutable::<ImmutableChildOfImmutableChildOfImmutableBase>();

        assert_mutable::<MutableChildOfImmutableBase>();
        assert_mutable::<MutableChildOfImmutableChildOfImmutableBase>();

        assert_mutable_with_mutable_ancestor::<MutableChildOfMutableChildOfImmutableBase>();
        assert_immutable_with_mutable_ancestor::<ImmutableChildOfMutableChildOfImmutableBase>();
    }

    #[test]
    fn test_mutable_borrow_prevents_further_borrows() {
        Python::with_gil(|py| {
            let mmm = Py::new(
                py,
                PyClassInitializer::from(MutableBase)
                    .add_subclass(MutableChildOfMutableBase)
                    .add_subclass(MutableChildOfMutableChildOfMutableBase),
            )
            .unwrap();

            let mmm_bound: &Bound<'_, MutableChildOfMutableChildOfMutableBase> = mmm.bind(py);

            let mmm_refmut = mmm_bound.borrow_mut();

            // While the mutable borrow is alive, no borrow of any layer succeeds,
            // because all layers share one borrow flag.
            assert!(mmm_bound
                .extract::<PyRef<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_err());
            assert!(mmm_bound
                .extract::<PyRef<'_, MutableChildOfMutableBase>>()
                .is_err());
            assert!(mmm_bound.extract::<PyRef<'_, MutableBase>>().is_err());
            assert!(mmm_bound
                .extract::<PyRefMut<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_err());
            assert!(mmm_bound
                .extract::<PyRefMut<'_, MutableChildOfMutableBase>>()
                .is_err());
            assert!(mmm_bound.extract::<PyRefMut<'_, MutableBase>>().is_err());

            drop(mmm_refmut);

            // Once released, every kind of borrow works again.
            assert!(mmm_bound
                .extract::<PyRef<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_bound
                .extract::<PyRef<'_, MutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_bound.extract::<PyRef<'_, MutableBase>>().is_ok());
            assert!(mmm_bound
                .extract::<PyRefMut<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_bound
                .extract::<PyRefMut<'_, MutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_bound.extract::<PyRefMut<'_, MutableBase>>().is_ok());
        })
    }

    #[test]
    fn test_immutable_borrows_prevent_mutable_borrows() {
        Python::with_gil(|py| {
            let mmm = Py::new(
                py,
                PyClassInitializer::from(MutableBase)
                    .add_subclass(MutableChildOfMutableBase)
                    .add_subclass(MutableChildOfMutableChildOfMutableBase),
            )
            .unwrap();

            let mmm_bound: &Bound<'_, MutableChildOfMutableChildOfMutableBase> = mmm.bind(py);

            // A shared borrow (named accordingly; it is not a `PyRefMut`).
            let mmm_ref = mmm_bound.borrow();

            // Further shared borrows of any layer are fine...
            assert!(mmm_bound
                .extract::<PyRef<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_bound
                .extract::<PyRef<'_, MutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_bound.extract::<PyRef<'_, MutableBase>>().is_ok());

            // ...but mutable borrows of any layer are rejected.
            assert!(mmm_bound
                .extract::<PyRefMut<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_err());
            assert!(mmm_bound
                .extract::<PyRefMut<'_, MutableChildOfMutableBase>>()
                .is_err());
            assert!(mmm_bound.extract::<PyRefMut<'_, MutableBase>>().is_err());

            drop(mmm_ref);

            assert!(mmm_bound
                .extract::<PyRefMut<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_bound
                .extract::<PyRefMut<'_, MutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_bound.extract::<PyRefMut<'_, MutableBase>>().is_ok());
        })
    }

    #[test]
    #[cfg(not(target_arch = "wasm32"))]
    fn test_thread_safety() {
        #[crate::pyclass(crate = "crate")]
        struct MyClass {
            x: u64,
        }

        Python::with_gil(|py| {
            let inst = Py::new(py, MyClass { x: 0 }).unwrap();

            let total_modifications = py.allow_threads(|| {
                std::thread::scope(|s| {
                    // Several threads race to mutably borrow `inst`; each counts
                    // how many of its attempts succeeded.
                    let threads = (0..10)
                        .map(|_| {
                            s.spawn(|| {
                                Python::with_gil(|py| {
                                    let mut local_modifications = 0;
                                    for _ in 0..100 {
                                        if let Ok(mut i) = inst.try_borrow_mut(py) {
                                            i.x += 1;
                                            local_modifications += 1;
                                        }
                                    }
                                    local_modifications
                                })
                            })
                        })
                        .collect::<Vec<_>>();

                    threads.into_iter().map(|t| t.join().unwrap()).sum::<u64>()
                })
            });

            // Each successful mutable borrow incremented `x` exactly once.
            assert_eq!(total_modifications, inst.borrow(py).x);
        });
    }

    #[test]
    #[cfg(not(target_arch = "wasm32"))]
    fn test_thread_safety_2() {
        struct SyncUnsafeCell<T>(UnsafeCell<T>);
        unsafe impl<T> Sync for SyncUnsafeCell<T> {}

        impl<T> SyncUnsafeCell<T> {
            fn get(&self) -> *mut T {
                self.0.get()
            }
        }

        let data = SyncUnsafeCell(UnsafeCell::new(0));
        let data2 = SyncUnsafeCell(UnsafeCell::new(0));
        let borrow_checker = BorrowChecker(BorrowFlag(AtomicUsize::new(BorrowFlag::UNUSED)));

        std::thread::scope(|s| {
            s.spawn(|| {
                for _ in 0..1_000_000 {
                    if borrow_checker.try_borrow_mut().is_ok() {
                        // The writer updates both values while holding the mutable borrow.
                        unsafe { *data.get() += 1 };
                        unsafe { *data2.get() += 1 };
                        borrow_checker.release_borrow_mut();
                    }
                }
            });

            s.spawn(|| {
                for _ in 0..1_000_000 {
                    if borrow_checker.try_borrow().is_ok() {
                        // With a correct borrow checker, a reader can never
                        // observe the two values out of step.
                        assert_eq!(unsafe { *data.get() }, unsafe { *data2.get() });
                        borrow_checker.release_borrow();
                    }
                }
            });
        });
    }
}