pyo3/
pybacked.rs

1//! Contains types for working with Python objects that own the underlying data.
2
3use std::{convert::Infallible, ops::Deref, ptr::NonNull, sync::Arc};
4
5use crate::{
6    types::{
7        bytearray::PyByteArrayMethods, bytes::PyBytesMethods, string::PyStringMethods, PyByteArray,
8        PyBytes, PyString,
9    },
10    Bound, DowncastError, FromPyObject, IntoPyObject, Py, PyAny, PyErr, PyResult, Python,
11};
12
13/// A wrapper around `str` where the storage is owned by a Python `bytes` or `str` object.
14///
15/// This type gives access to the underlying data via a `Deref` implementation.
16#[cfg_attr(feature = "py-clone", derive(Clone))]
17pub struct PyBackedStr {
18    #[allow(dead_code)] // only held so that the storage is not dropped
19    storage: Py<PyAny>,
20    data: NonNull<str>,
21}
22
23impl Deref for PyBackedStr {
24    type Target = str;
25    fn deref(&self) -> &str {
26        // Safety: `data` is known to be immutable and owned by self
27        unsafe { self.data.as_ref() }
28    }
29}
30
31impl AsRef<str> for PyBackedStr {
32    fn as_ref(&self) -> &str {
33        self
34    }
35}
36
37impl AsRef<[u8]> for PyBackedStr {
38    fn as_ref(&self) -> &[u8] {
39        self.as_bytes()
40    }
41}
42
43// Safety: the underlying Python str (or bytes) is immutable and
44// safe to share between threads
45unsafe impl Send for PyBackedStr {}
46unsafe impl Sync for PyBackedStr {}
47
48impl std::fmt::Display for PyBackedStr {
49    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50        self.deref().fmt(f)
51    }
52}
53
54impl_traits!(PyBackedStr, str);
55
56impl TryFrom<Bound<'_, PyString>> for PyBackedStr {
57    type Error = PyErr;
58    fn try_from(py_string: Bound<'_, PyString>) -> Result<Self, Self::Error> {
59        #[cfg(any(Py_3_10, not(Py_LIMITED_API)))]
60        {
61            let s = py_string.to_str()?;
62            let data = NonNull::from(s);
63            Ok(Self {
64                storage: py_string.into_any().unbind(),
65                data,
66            })
67        }
68        #[cfg(not(any(Py_3_10, not(Py_LIMITED_API))))]
69        {
70            let bytes = py_string.encode_utf8()?;
71            let s = unsafe { std::str::from_utf8_unchecked(bytes.as_bytes()) };
72            let data = NonNull::from(s);
73            Ok(Self {
74                storage: bytes.into_any().unbind(),
75                data,
76            })
77        }
78    }
79}
80
81impl FromPyObject<'_> for PyBackedStr {
82    fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
83        let py_string = obj.cast::<PyString>()?.to_owned();
84        Self::try_from(py_string)
85    }
86}
87
88impl<'py> IntoPyObject<'py> for PyBackedStr {
89    type Target = PyAny;
90    type Output = Bound<'py, Self::Target>;
91    type Error = Infallible;
92
93    #[cfg(any(Py_3_10, not(Py_LIMITED_API)))]
94    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
95        Ok(self.storage.into_bound(py))
96    }
97
98    #[cfg(not(any(Py_3_10, not(Py_LIMITED_API))))]
99    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
100        Ok(PyString::new(py, &self).into_any())
101    }
102}
103
104impl<'py> IntoPyObject<'py> for &PyBackedStr {
105    type Target = PyAny;
106    type Output = Bound<'py, Self::Target>;
107    type Error = Infallible;
108
109    #[cfg(any(Py_3_10, not(Py_LIMITED_API)))]
110    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
111        Ok(self.storage.bind(py).to_owned())
112    }
113
114    #[cfg(not(any(Py_3_10, not(Py_LIMITED_API))))]
115    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
116        Ok(PyString::new(py, self).into_any())
117    }
118}
119
120/// A wrapper around `[u8]` where the storage is either owned by a Python `bytes` object, or a Rust `Box<[u8]>`.
121///
122/// This type gives access to the underlying data via a `Deref` implementation.
123#[cfg_attr(feature = "py-clone", derive(Clone))]
124pub struct PyBackedBytes {
125    #[allow(dead_code)] // only held so that the storage is not dropped
126    storage: PyBackedBytesStorage,
127    data: NonNull<[u8]>,
128}
129
130#[allow(dead_code)]
131#[cfg_attr(feature = "py-clone", derive(Clone))]
132enum PyBackedBytesStorage {
133    Python(Py<PyBytes>),
134    Rust(Arc<[u8]>),
135}
136
137impl Deref for PyBackedBytes {
138    type Target = [u8];
139    fn deref(&self) -> &[u8] {
140        // Safety: `data` is known to be immutable and owned by self
141        unsafe { self.data.as_ref() }
142    }
143}
144
145impl AsRef<[u8]> for PyBackedBytes {
146    fn as_ref(&self) -> &[u8] {
147        self
148    }
149}
150
151// Safety: the underlying Python bytes or Rust bytes is immutable and
152// safe to share between threads
153unsafe impl Send for PyBackedBytes {}
154unsafe impl Sync for PyBackedBytes {}
155
156impl<const N: usize> PartialEq<[u8; N]> for PyBackedBytes {
157    fn eq(&self, other: &[u8; N]) -> bool {
158        self.deref() == other
159    }
160}
161
162impl<const N: usize> PartialEq<PyBackedBytes> for [u8; N] {
163    fn eq(&self, other: &PyBackedBytes) -> bool {
164        self == other.deref()
165    }
166}
167
168impl<const N: usize> PartialEq<&[u8; N]> for PyBackedBytes {
169    fn eq(&self, other: &&[u8; N]) -> bool {
170        self.deref() == *other
171    }
172}
173
174impl<const N: usize> PartialEq<PyBackedBytes> for &[u8; N] {
175    fn eq(&self, other: &PyBackedBytes) -> bool {
176        self == &other.deref()
177    }
178}
179
180impl_traits!(PyBackedBytes, [u8]);
181
182impl From<Bound<'_, PyBytes>> for PyBackedBytes {
183    fn from(py_bytes: Bound<'_, PyBytes>) -> Self {
184        let b = py_bytes.as_bytes();
185        let data = NonNull::from(b);
186        Self {
187            storage: PyBackedBytesStorage::Python(py_bytes.to_owned().unbind()),
188            data,
189        }
190    }
191}
192
193impl From<Bound<'_, PyByteArray>> for PyBackedBytes {
194    fn from(py_bytearray: Bound<'_, PyByteArray>) -> Self {
195        let s = Arc::<[u8]>::from(py_bytearray.to_vec());
196        let data = NonNull::from(s.as_ref());
197        Self {
198            storage: PyBackedBytesStorage::Rust(s),
199            data,
200        }
201    }
202}
203
204impl FromPyObject<'_> for PyBackedBytes {
205    fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
206        if let Ok(bytes) = obj.cast::<PyBytes>() {
207            Ok(Self::from(bytes.to_owned()))
208        } else if let Ok(bytearray) = obj.cast::<PyByteArray>() {
209            Ok(Self::from(bytearray.to_owned()))
210        } else {
211            Err(DowncastError::new(obj, "`bytes` or `bytearray`").into())
212        }
213    }
214}
215
216impl<'py> IntoPyObject<'py> for PyBackedBytes {
217    type Target = PyBytes;
218    type Output = Bound<'py, Self::Target>;
219    type Error = Infallible;
220
221    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
222        match self.storage {
223            PyBackedBytesStorage::Python(bytes) => Ok(bytes.into_bound(py)),
224            PyBackedBytesStorage::Rust(bytes) => Ok(PyBytes::new(py, &bytes)),
225        }
226    }
227}
228
229impl<'py> IntoPyObject<'py> for &PyBackedBytes {
230    type Target = PyBytes;
231    type Output = Bound<'py, Self::Target>;
232    type Error = Infallible;
233
234    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
235        match &self.storage {
236            PyBackedBytesStorage::Python(bytes) => Ok(bytes.bind(py).clone()),
237            PyBackedBytesStorage::Rust(bytes) => Ok(PyBytes::new(py, bytes)),
238        }
239    }
240}
241
/// Implements `Debug`, equality, ordering and hashing for a wrapper type `$slf` that
/// dereferences to `$equiv`. Every operation delegates to the dereferenced value.
///
/// Generated impls:
/// - `Debug`, `Eq`, `Ord` and `Hash` for `$slf`;
/// - `PartialEq` between `$slf` and `$slf`, `$equiv` and `&$equiv`, in both directions;
/// - `PartialOrd` between `$slf` and `$slf` or `$equiv`, in both directions.
///
/// `$slf` must implement `Deref<Target = $equiv>`. The expansion dereferences with the
/// `*` operator and calls trait methods by fully-qualified path, so no traits need to be
/// imported at the invocation site.
macro_rules! impl_traits {
    ($slf:ty, $equiv:ty) => {
        impl std::fmt::Debug for $slf {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                <$equiv as std::fmt::Debug>::fmt(&**self, f)
            }
        }

        impl PartialEq for $slf {
            fn eq(&self, other: &Self) -> bool {
                **self == **other
            }
        }

        impl PartialEq<$equiv> for $slf {
            fn eq(&self, other: &$equiv) -> bool {
                **self == *other
            }
        }

        impl PartialEq<&$equiv> for $slf {
            fn eq(&self, other: &&$equiv) -> bool {
                **self == **other
            }
        }

        impl PartialEq<$slf> for $equiv {
            fn eq(&self, other: &$slf) -> bool {
                *self == **other
            }
        }

        impl PartialEq<$slf> for &$equiv {
            fn eq(&self, other: &$slf) -> bool {
                **self == **other
            }
        }

        impl Eq for $slf {}

        impl PartialOrd for $slf {
            fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
                Some(<Self as Ord>::cmp(self, other))
            }
        }

        impl PartialOrd<$equiv> for $slf {
            fn partial_cmp(&self, other: &$equiv) -> Option<std::cmp::Ordering> {
                <$equiv as PartialOrd>::partial_cmp(&**self, other)
            }
        }

        impl PartialOrd<$slf> for $equiv {
            fn partial_cmp(&self, other: &$slf) -> Option<std::cmp::Ordering> {
                <$equiv as PartialOrd>::partial_cmp(self, &**other)
            }
        }

        impl Ord for $slf {
            fn cmp(&self, other: &Self) -> std::cmp::Ordering {
                <$equiv as Ord>::cmp(&**self, &**other)
            }
        }

        impl std::hash::Hash for $slf {
            fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
                <$equiv as std::hash::Hash>::hash(&**self, state)
            }
        }
    };
}
313use impl_traits;
314
#[cfg(test)]
mod test {
    use super::*;
    use crate::impl_::pyclass::{value_of, IsSend, IsSync};
    use crate::types::PyAnyMethods as _;
    use crate::{IntoPyObject, Python};
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    #[test]
    fn py_backed_str_empty() {
        Python::attach(|py| {
            let s = PyString::new(py, "");
            let py_backed_str = s.extract::<PyBackedStr>().unwrap();
            assert_eq!(&*py_backed_str, "");
        });
    }

    #[test]
    fn py_backed_str() {
        Python::attach(|py| {
            let s = PyString::new(py, "hello");
            let py_backed_str = s.extract::<PyBackedStr>().unwrap();
            assert_eq!(&*py_backed_str, "hello");
        });
    }

    #[test]
    fn py_backed_str_try_from() {
        Python::attach(|py| {
            let s = PyString::new(py, "hello");
            let py_backed_str = PyBackedStr::try_from(s).unwrap();
            assert_eq!(&*py_backed_str, "hello");
        });
    }

    #[test]
    fn py_backed_str_into_pyobject() {
        Python::attach(|py| {
            let orig_str = PyString::new(py, "hello");
            let py_backed_str = orig_str.extract::<PyBackedStr>().unwrap();
            let new_str = py_backed_str.into_pyobject(py).unwrap();
            assert_eq!(new_str.extract::<PyBackedStr>().unwrap(), "hello");
            #[cfg(any(Py_3_10, not(Py_LIMITED_API)))]
            assert!(new_str.is(&orig_str));
        });
    }

    #[test]
    fn py_backed_bytes_empty() {
        Python::attach(|py| {
            let b = PyBytes::new(py, b"");
            let py_backed_bytes = b.extract::<PyBackedBytes>().unwrap();
            assert_eq!(&*py_backed_bytes, b"");
        });
    }

    #[test]
    fn py_backed_bytes() {
        Python::attach(|py| {
            let b = PyBytes::new(py, b"abcde");
            let py_backed_bytes = b.extract::<PyBackedBytes>().unwrap();
            assert_eq!(&*py_backed_bytes, b"abcde");
        });
    }

    #[test]
    fn py_backed_bytes_from_bytes() {
        Python::attach(|py| {
            let b = PyBytes::new(py, b"abcde");
            let py_backed_bytes = PyBackedBytes::from(b);
            assert_eq!(&*py_backed_bytes, b"abcde");
        });
    }

    #[test]
    fn py_backed_bytes_from_bytearray() {
        Python::attach(|py| {
            let b = PyByteArray::new(py, b"abcde");
            let py_backed_bytes = PyBackedBytes::from(b);
            assert_eq!(&*py_backed_bytes, b"abcde");
        });
    }

    #[test]
    fn py_backed_bytes_into_pyobject() {
        Python::attach(|py| {
            let orig_bytes = PyBytes::new(py, b"abcde");
            let py_backed_bytes = PyBackedBytes::from(orig_bytes.clone());
            assert!((&py_backed_bytes)
                .into_pyobject(py)
                .unwrap()
                .is(&orig_bytes));
        });
    }

    #[test]
    fn py_backed_bytes_into_pyobject_owned() {
        Python::attach(|py| {
            let orig_bytes = PyBytes::new(py, b"abcde");
            let py_backed_bytes = PyBackedBytes::from(orig_bytes.clone());
            // Converting by value hands back the original `bytes` object.
            assert!(py_backed_bytes.into_pyobject(py).unwrap().is(&orig_bytes));
        });
    }

    #[test]
    fn rust_backed_bytes_into_pyobject() {
        Python::attach(|py| {
            let orig_bytes = PyByteArray::new(py, b"abcde");
            let rust_backed_bytes = PyBackedBytes::from(orig_bytes);
            assert!(matches!(
                rust_backed_bytes.storage,
                PyBackedBytesStorage::Rust(_)
            ));
            let to_object = (&rust_backed_bytes).into_pyobject(py).unwrap();
            assert!(to_object.is_exact_instance_of::<PyBytes>());
            assert_eq!(&to_object.extract::<PyBackedBytes>().unwrap(), b"abcde");
        });
    }

    #[test]
    fn rust_backed_bytes_into_pyobject_owned() {
        Python::attach(|py| {
            let rust_backed_bytes = PyBackedBytes::from(PyByteArray::new(py, b"abcde"));
            let to_object = rust_backed_bytes.into_pyobject(py).unwrap();
            assert!(to_object.is_exact_instance_of::<PyBytes>());
            assert_eq!(&to_object.extract::<PyBackedBytes>().unwrap(), b"abcde");
        });
    }

    #[test]
    fn test_backed_types_send_sync() {
        assert!(value_of!(IsSend, PyBackedStr));
        assert!(value_of!(IsSync, PyBackedStr));

        assert!(value_of!(IsSend, PyBackedBytes));
        assert!(value_of!(IsSync, PyBackedBytes));
    }

    #[cfg(feature = "py-clone")]
    #[test]
    fn test_backed_str_clone() {
        Python::attach(|py| {
            let s1: PyBackedStr = PyString::new(py, "hello").try_into().unwrap();
            let s2 = s1.clone();
            assert_eq!(s1, s2);

            drop(s1);
            assert_eq!(s2, "hello");
        });
    }

    #[test]
    fn test_backed_str_eq() {
        Python::attach(|py| {
            let s1: PyBackedStr = PyString::new(py, "hello").try_into().unwrap();
            let s2: PyBackedStr = PyString::new(py, "hello").try_into().unwrap();
            assert_eq!(s1, "hello");
            assert_eq!(s1, s2);

            let s3: PyBackedStr = PyString::new(py, "abcde").try_into().unwrap();
            assert_eq!("abcde", s3);
            assert_ne!(s1, s3);
        });
    }

    #[test]
    fn test_backed_str_display_debug() {
        Python::attach(|py| {
            let s: PyBackedStr = PyString::new(py, "hello").try_into().unwrap();
            assert_eq!(s.to_string(), "hello");
            assert_eq!(format!("{:?}", s), "\"hello\"");
        });
    }

    #[test]
    fn test_backed_str_hash() {
        Python::attach(|py| {
            let h = {
                let mut hasher = DefaultHasher::new();
                "abcde".hash(&mut hasher);
                hasher.finish()
            };

            let s1: PyBackedStr = PyString::new(py, "abcde").try_into().unwrap();
            let h1 = {
                let mut hasher = DefaultHasher::new();
                s1.hash(&mut hasher);
                hasher.finish()
            };

            assert_eq!(h, h1);
        });
    }

    #[test]
    fn test_backed_str_ord() {
        Python::attach(|py| {
            let mut a = vec!["a", "c", "d", "b", "f", "g", "e"];
            let mut b = a
                .iter()
                .map(|s| PyString::new(py, s).try_into().unwrap())
                .collect::<Vec<PyBackedStr>>();

            a.sort();
            b.sort();

            assert_eq!(a, b);
        })
    }

    #[cfg(feature = "py-clone")]
    #[test]
    fn test_backed_bytes_from_bytes_clone() {
        Python::attach(|py| {
            let b1: PyBackedBytes = PyBytes::new(py, b"abcde").into();
            let b2 = b1.clone();
            assert_eq!(b1, b2);

            drop(b1);
            assert_eq!(b2, b"abcde");
        });
    }

    #[cfg(feature = "py-clone")]
    #[test]
    fn test_backed_bytes_from_bytearray_clone() {
        Python::attach(|py| {
            let b1: PyBackedBytes = PyByteArray::new(py, b"abcde").into();
            let b2 = b1.clone();
            assert_eq!(b1, b2);

            drop(b1);
            assert_eq!(b2, b"abcde");
        });
    }

    #[test]
    fn test_backed_bytes_eq() {
        Python::attach(|py| {
            let b1: PyBackedBytes = PyBytes::new(py, b"abcde").into();
            let b2: PyBackedBytes = PyByteArray::new(py, b"abcde").into();

            assert_eq!(b1, b"abcde");
            assert_eq!(b1, b2);

            let b3: PyBackedBytes = PyBytes::new(py, b"hello").into();
            assert_eq!(b"hello", b3);
            assert_ne!(b1, b3);
        });
    }

    #[test]
    fn test_backed_bytes_debug() {
        Python::attach(|py| {
            let b: PyBackedBytes = PyBytes::new(py, b"ab").into();
            assert_eq!(format!("{:?}", b), "[97, 98]");
        });
    }

    #[test]
    fn test_backed_bytes_hash() {
        Python::attach(|py| {
            let h = {
                let mut hasher = DefaultHasher::new();
                b"abcde".hash(&mut hasher);
                hasher.finish()
            };

            let b1: PyBackedBytes = PyBytes::new(py, b"abcde").into();
            let h1 = {
                let mut hasher = DefaultHasher::new();
                b1.hash(&mut hasher);
                hasher.finish()
            };

            let b2: PyBackedBytes = PyByteArray::new(py, b"abcde").into();
            let h2 = {
                let mut hasher = DefaultHasher::new();
                b2.hash(&mut hasher);
                hasher.finish()
            };

            assert_eq!(h, h1);
            assert_eq!(h, h2);
        });
    }

    #[test]
    fn test_backed_bytes_ord() {
        Python::attach(|py| {
            let mut a = vec![b"a", b"c", b"d", b"b", b"f", b"g", b"e"];
            let mut b = a
                .iter()
                .map(|&b| PyBytes::new(py, b).into())
                .collect::<Vec<PyBackedBytes>>();

            a.sort();
            b.sort();

            assert_eq!(a, b);
        })
    }
}