// pyo3/types/set.rs

1use crate::types::PyIterator;
2use crate::{
3    err::{self, PyErr, PyResult},
4    ffi_ptr_ext::FfiPtrExt,
5    instance::Bound,
6    py_result_ext::PyResultExt,
7};
8use crate::{ffi, Borrowed, BoundObject, IntoPyObject, IntoPyObjectExt, PyAny, Python};
9use std::ptr;
10
/// Represents a Python `set`.
///
/// Values of this type are accessed via PyO3's smart pointers, e.g. as
/// [`Py<PySet>`][crate::Py] or [`Bound<'py, PySet>`][Bound].
///
/// For APIs available on `set` objects, see the [`PySetMethods`] trait which is implemented for
/// [`Bound<'py, PySet>`][Bound].
// `repr(transparent)` gives `PySet` exactly the layout of the wrapped `PyAny`; presumably PyO3
// relies on this to reinterpret references to generic objects as `PySet` — confirm before
// changing the representation.
#[repr(transparent)]
pub struct PySet(PyAny);
20
// When not building for PyPy or GraalPy, `PySet` is tied to the concrete `ffi::PySetObject`
// layout; as the macro name suggests, this is what makes `set` usable as a native base type
// for subclassing.
#[cfg(not(any(PyPy, GraalPy)))]
pyobject_subclassable_native_type!(PySet, crate::ffi::PySetObject);

// Native-type glue outside PyPy/GraalPy: binds `PySet` to the `PySetObject` layout, to the
// static `PySet_Type` type object, and to `PySet_Check` as its type check.
#[cfg(not(any(PyPy, GraalPy)))]
pyobject_native_type!(
    PySet,
    ffi::PySetObject,
    pyobject_native_static_type_object!(ffi::PySet_Type),
    #checkfunction=ffi::PySet_Check
);

// On PyPy and GraalPy only the core glue (type object + check function) is generated and the
// `PySetObject` layout is not referenced — presumably because that struct is not available on
// those interpreters; confirm against the ffi bindings.
#[cfg(any(PyPy, GraalPy))]
pyobject_native_type_core!(
    PySet,
    pyobject_native_static_type_object!(ffi::PySet_Type),
    #checkfunction=ffi::PySet_Check
);
35    pyobject_native_static_type_object!(ffi::PySet_Type),
36    #checkfunction=ffi::PySet_Check
37);
38
39impl PySet {
40    /// Creates a new set with elements from the given slice.
41    ///
42    /// Returns an error if some element is not hashable.
43    #[inline]
44    pub fn new<'py, T>(
45        py: Python<'py>,
46        elements: impl IntoIterator<Item = T>,
47    ) -> PyResult<Bound<'py, PySet>>
48    where
49        T: IntoPyObject<'py>,
50    {
51        try_new_from_iter(py, elements)
52    }
53
54    /// Creates a new empty set.
55    pub fn empty(py: Python<'_>) -> PyResult<Bound<'_, PySet>> {
56        unsafe {
57            ffi::PySet_New(ptr::null_mut())
58                .assume_owned_or_err(py)
59                .cast_into_unchecked()
60        }
61    }
62}
63
64/// Implementation of functionality for [`PySet`].
65///
66/// These methods are defined for the `Bound<'py, PySet>` smart pointer, so to use method call
67/// syntax these methods are separated into a trait, because stable Rust does not yet support
68/// `arbitrary_self_types`.
69#[doc(alias = "PySet")]
70pub trait PySetMethods<'py>: crate::sealed::Sealed {
71    /// Removes all elements from the set.
72    fn clear(&self);
73
74    /// Returns the number of items in the set.
75    ///
76    /// This is equivalent to the Python expression `len(self)`.
77    fn len(&self) -> usize;
78
79    /// Checks if set is empty.
80    fn is_empty(&self) -> bool {
81        self.len() == 0
82    }
83
84    /// Determines if the set contains the specified key.
85    ///
86    /// This is equivalent to the Python expression `key in self`.
87    fn contains<K>(&self, key: K) -> PyResult<bool>
88    where
89        K: IntoPyObject<'py>;
90
91    /// Removes the element from the set if it is present.
92    ///
93    /// Returns `true` if the element was present in the set.
94    fn discard<K>(&self, key: K) -> PyResult<bool>
95    where
96        K: IntoPyObject<'py>;
97
98    /// Adds an element to the set.
99    fn add<K>(&self, key: K) -> PyResult<()>
100    where
101        K: IntoPyObject<'py>;
102
103    /// Removes and returns an arbitrary element from the set.
104    fn pop(&self) -> Option<Bound<'py, PyAny>>;
105
106    /// Returns an iterator of values in this set.
107    ///
108    /// # Panics
109    ///
110    /// If PyO3 detects that the set is mutated during iteration, it will panic.
111    fn iter(&self) -> BoundSetIterator<'py>;
112}
113
114impl<'py> PySetMethods<'py> for Bound<'py, PySet> {
115    #[inline]
116    fn clear(&self) {
117        unsafe {
118            ffi::PySet_Clear(self.as_ptr());
119        }
120    }
121
122    #[inline]
123    fn len(&self) -> usize {
124        unsafe { ffi::PySet_Size(self.as_ptr()) as usize }
125    }
126
127    fn contains<K>(&self, key: K) -> PyResult<bool>
128    where
129        K: IntoPyObject<'py>,
130    {
131        fn inner(set: &Bound<'_, PySet>, key: Borrowed<'_, '_, PyAny>) -> PyResult<bool> {
132            match unsafe { ffi::PySet_Contains(set.as_ptr(), key.as_ptr()) } {
133                1 => Ok(true),
134                0 => Ok(false),
135                _ => Err(PyErr::fetch(set.py())),
136            }
137        }
138
139        let py = self.py();
140        inner(
141            self,
142            key.into_pyobject_or_pyerr(py)?.into_any().as_borrowed(),
143        )
144    }
145
146    fn discard<K>(&self, key: K) -> PyResult<bool>
147    where
148        K: IntoPyObject<'py>,
149    {
150        fn inner(set: &Bound<'_, PySet>, key: Borrowed<'_, '_, PyAny>) -> PyResult<bool> {
151            match unsafe { ffi::PySet_Discard(set.as_ptr(), key.as_ptr()) } {
152                1 => Ok(true),
153                0 => Ok(false),
154                _ => Err(PyErr::fetch(set.py())),
155            }
156        }
157
158        let py = self.py();
159        inner(
160            self,
161            key.into_pyobject_or_pyerr(py)?.into_any().as_borrowed(),
162        )
163    }
164
165    fn add<K>(&self, key: K) -> PyResult<()>
166    where
167        K: IntoPyObject<'py>,
168    {
169        fn inner(set: &Bound<'_, PySet>, key: Borrowed<'_, '_, PyAny>) -> PyResult<()> {
170            err::error_on_minusone(set.py(), unsafe {
171                ffi::PySet_Add(set.as_ptr(), key.as_ptr())
172            })
173        }
174
175        let py = self.py();
176        inner(
177            self,
178            key.into_pyobject_or_pyerr(py)?.into_any().as_borrowed(),
179        )
180    }
181
182    fn pop(&self) -> Option<Bound<'py, PyAny>> {
183        let element = unsafe { ffi::PySet_Pop(self.as_ptr()).assume_owned_or_err(self.py()) };
184        element.ok()
185    }
186
187    fn iter(&self) -> BoundSetIterator<'py> {
188        BoundSetIterator::new(self.clone())
189    }
190}
191
192impl<'py> IntoIterator for Bound<'py, PySet> {
193    type Item = Bound<'py, PyAny>;
194    type IntoIter = BoundSetIterator<'py>;
195
196    /// Returns an iterator of values in this set.
197    ///
198    /// # Panics
199    ///
200    /// If PyO3 detects that the set is mutated during iteration, it will panic.
201    fn into_iter(self) -> Self::IntoIter {
202        BoundSetIterator::new(self)
203    }
204}
205
206impl<'py> IntoIterator for &Bound<'py, PySet> {
207    type Item = Bound<'py, PyAny>;
208    type IntoIter = BoundSetIterator<'py>;
209
210    /// Returns an iterator of values in this set.
211    ///
212    /// # Panics
213    ///
214    /// If PyO3 detects that the set is mutated during iteration, it will panic.
215    fn into_iter(self) -> Self::IntoIter {
216        self.iter()
217    }
218}
219
220/// PyO3 implementation of an iterator for a Python `set` object.
221pub struct BoundSetIterator<'p> {
222    it: Bound<'p, PyIterator>,
223    // Remaining elements in the set. This is fine to store because
224    // Python will error if the set changes size during iteration.
225    remaining: usize,
226}
227
228impl<'py> BoundSetIterator<'py> {
229    pub(super) fn new(set: Bound<'py, PySet>) -> Self {
230        Self {
231            it: PyIterator::from_object(&set).unwrap(),
232            remaining: set.len(),
233        }
234    }
235}
236
237impl<'py> Iterator for BoundSetIterator<'py> {
238    type Item = Bound<'py, super::PyAny>;
239
240    /// Advances the iterator and returns the next value.
241    fn next(&mut self) -> Option<Self::Item> {
242        self.remaining = self.remaining.saturating_sub(1);
243        self.it.next().map(Result::unwrap)
244    }
245
246    fn size_hint(&self) -> (usize, Option<usize>) {
247        (self.remaining, Some(self.remaining))
248    }
249
250    #[inline]
251    fn count(self) -> usize
252    where
253        Self: Sized,
254    {
255        self.len()
256    }
257}
258
259impl ExactSizeIterator for BoundSetIterator<'_> {
260    fn len(&self) -> usize {
261        self.remaining
262    }
263}
264
265#[inline]
266pub(crate) fn try_new_from_iter<'py, T>(
267    py: Python<'py>,
268    elements: impl IntoIterator<Item = T>,
269) -> PyResult<Bound<'py, PySet>>
270where
271    T: IntoPyObject<'py>,
272{
273    let set = unsafe {
274        // We create the `Bound` pointer because its Drop cleans up the set if
275        // user code errors or panics.
276        ffi::PySet_New(std::ptr::null_mut())
277            .assume_owned_or_err(py)?
278            .cast_into_unchecked()
279    };
280    let ptr = set.as_ptr();
281
282    elements.into_iter().try_for_each(|element| {
283        let obj = element.into_pyobject_or_pyerr(py)?;
284        err::error_on_minusone(py, unsafe { ffi::PySet_Add(ptr, obj.as_ptr()) })
285    })?;
286
287    Ok(set)
288}
289
#[cfg(test)]
mod tests {
    use super::PySet;
    use crate::{
        conversion::IntoPyObject,
        ffi,
        types::{PyAnyMethods, PySetMethods},
        Python,
    };
    use std::collections::HashSet;

    #[test]
    fn test_set_new() {
        Python::attach(|py| {
            let set = PySet::new(py, [1]).unwrap();
            assert_eq!(1, set.len());

            // A list is unhashable, so it cannot become a set element.
            let v = vec![1];
            assert!(PySet::new(py, &[v]).is_err());
        });
    }

    #[test]
    fn test_set_new_from_iterator() {
        Python::attach(|py| {
            // Any iterator is accepted, and duplicate elements collapse as in Python.
            let set = PySet::new(py, (1..=3).chain(1..=3)).unwrap();
            assert_eq!(3, set.len());
            for x in 1..=3 {
                assert!(set.contains(x).unwrap());
            }
        });
    }

    #[test]
    fn test_set_empty() {
        Python::attach(|py| {
            let set = PySet::empty(py).unwrap();
            assert_eq!(0, set.len());
            assert!(set.is_empty());
        });
    }

    #[test]
    fn test_set_len() {
        Python::attach(|py| {
            let mut v = HashSet::<i32>::new();
            let ob = (&v).into_pyobject(py).unwrap();
            let set = ob.cast::<PySet>().unwrap();
            assert_eq!(0, set.len());
            v.insert(7);
            let ob = v.into_pyobject(py).unwrap();
            let set2 = ob.cast::<PySet>().unwrap();
            assert_eq!(1, set2.len());
        });
    }

    #[test]
    fn test_set_clear() {
        Python::attach(|py| {
            let set = PySet::new(py, [1]).unwrap();
            assert_eq!(1, set.len());
            set.clear();
            assert_eq!(0, set.len());
            assert!(set.is_empty());
        });
    }

    #[test]
    fn test_set_contains() {
        Python::attach(|py| {
            let set = PySet::new(py, [1]).unwrap();
            assert!(set.contains(1).unwrap());
            assert!(!set.contains(2).unwrap());

            // Unhashable keys are reported as errors rather than `false`.
            assert!(set.contains(vec![1, 2]).is_err());
        });
    }

    #[test]
    fn test_set_discard() {
        Python::attach(|py| {
            let set = PySet::new(py, [1]).unwrap();
            assert!(!set.discard(2).unwrap());
            assert_eq!(1, set.len());

            assert!(set.discard(1).unwrap());
            assert_eq!(0, set.len());
            assert!(!set.discard(1).unwrap());

            assert!(set.discard(vec![1, 2]).is_err());
        });
    }

    #[test]
    fn test_set_add() {
        Python::attach(|py| {
            let set = PySet::new(py, [1, 2]).unwrap();
            set.add(1).unwrap(); // Add a duplicated element
            assert!(set.contains(1).unwrap());
            assert_eq!(2, set.len());

            set.add(3).unwrap();
            assert!(set.contains(3).unwrap());
            assert_eq!(3, set.len());

            // Unhashable elements are rejected and leave the set unchanged.
            assert!(set.add(vec![1, 2]).is_err());
            assert_eq!(3, set.len());
        });
    }

    #[test]
    fn test_set_pop() {
        Python::attach(|py| {
            let set = PySet::new(py, [1]).unwrap();
            let val = set.pop();
            assert_eq!(1i32, val.unwrap().extract::<i32>().unwrap());
            assert!(set.is_empty());
            let val2 = set.pop();
            assert!(val2.is_none());
            assert!(py
                .eval(
                    ffi::c_str!("print('Exception state should not be set.')"),
                    None,
                    None
                )
                .is_ok());
        });
    }

    #[test]
    fn test_set_iter() {
        Python::attach(|py| {
            let set = PySet::new(py, [1]).unwrap();

            let mut seen = 0;
            for el in set {
                assert_eq!(1i32, el.extract::<i32>().unwrap());
                seen += 1;
            }
            assert_eq!(1, seen);
        });
    }

    #[test]
    fn test_set_iter_bound() {
        Python::attach(|py| {
            let set = PySet::new(py, [1]).unwrap();

            let mut seen = 0;
            for el in &set {
                assert_eq!(1i32, el.extract::<i32>().unwrap());
                seen += 1;
            }
            assert_eq!(1, seen);
        });
    }

    #[test]
    #[should_panic]
    fn test_set_iter_mutation() {
        Python::attach(|py| {
            let set = PySet::new(py, [1, 2, 3, 4, 5]).unwrap();

            for _ in &set {
                let _ = set.add(42);
            }
        });
    }

    #[test]
    #[should_panic]
    fn test_set_iter_mutation_same_len() {
        Python::attach(|py| {
            let set = PySet::new(py, [1, 2, 3, 4, 5]).unwrap();

            for item in &set {
                let item: i32 = item.extract().unwrap();
                let _ = set.del_item(item);
                let _ = set.add(item + 10);
            }
        });
    }

    #[test]
    fn test_set_iter_size_hint() {
        Python::attach(|py| {
            let set = PySet::new(py, [1]).unwrap();
            let mut iter = set.iter();

            // Exact size
            assert_eq!(iter.len(), 1);
            assert_eq!(iter.size_hint(), (1, Some(1)));
            iter.next();
            assert_eq!(iter.len(), 0);
            assert_eq!(iter.size_hint(), (0, Some(0)));

            // Advancing past the end keeps the size pinned at zero.
            assert!(iter.next().is_none());
            assert_eq!(iter.len(), 0);
        });
    }

    #[test]
    fn test_iter_count() {
        Python::attach(|py| {
            let set = PySet::new(py, vec![1, 2, 3]).unwrap();
            assert_eq!(set.iter().count(), 3);
        })
    }
}