numpy/
sum_products.rs

1use std::borrow::Cow;
2use std::ffi::{CStr, CString};
3use std::ptr::null_mut;
4
5use ndarray::{Dimension, IxDyn};
6use pyo3::types::PyAnyMethods;
7use pyo3::{Borrowed, Bound, FromPyObject, PyResult};
8
9use crate::array::PyArray;
10use crate::dtype::Element;
11use crate::npyffi::{array::PY_ARRAY_API, NPY_CASTING, NPY_ORDER};
12
13/// Return value of a function that can yield either an array or a scalar.
14pub trait ArrayOrScalar<'py, T>: FromPyObject<'py> {}
15
16impl<'py, T, D> ArrayOrScalar<'py, T> for Bound<'py, PyArray<T, D>>
17where
18    T: Element,
19    D: Dimension,
20{
21}
22
23impl<'py, T> ArrayOrScalar<'py, T> for T where T: Element + FromPyObject<'py> {}
24
25/// Return the inner product of two arrays.
26///
27/// [NumPy's documentation][inner] has the details.
28///
29/// # Examples
30///
31/// Note that this function can either return a scalar...
32///
33/// ```
34/// use pyo3::Python;
35/// use numpy::{inner, pyarray, PyArray0};
36///
37/// Python::with_gil(|py| {
38///     let vector = pyarray![py, 1.0, 2.0, 3.0];
39///     let result: f64 = inner(&vector, &vector).unwrap();
40///     assert_eq!(result, 14.0);
41/// });
42/// ```
43///
44/// ...or an array depending on its arguments.
45///
46/// ```
47/// use pyo3::{Python, Bound};
48/// use numpy::prelude::*;
49/// use numpy::{inner, pyarray, PyArray0};
50///
51/// Python::with_gil(|py| {
52///     let vector = pyarray![py, 1, 2, 3];
53///     let result: Bound<'_, PyArray0<_>> = inner(&vector, &vector).unwrap();
54///     assert_eq!(result.item(), 14);
55/// });
56/// ```
57///
58/// [inner]: https://numpy.org/doc/stable/reference/generated/numpy.inner.html
59pub fn inner<'py, T, DIN1, DIN2, OUT>(
60    array1: &Bound<'py, PyArray<T, DIN1>>,
61    array2: &Bound<'py, PyArray<T, DIN2>>,
62) -> PyResult<OUT>
63where
64    T: Element,
65    DIN1: Dimension,
66    DIN2: Dimension,
67    OUT: ArrayOrScalar<'py, T>,
68{
69    let py = array1.py();
70    let obj = unsafe {
71        let result = PY_ARRAY_API.PyArray_InnerProduct(py, array1.as_ptr(), array2.as_ptr());
72        Bound::from_owned_ptr_or_err(py, result)?
73    };
74    obj.extract()
75}
76
77/// Deprecated name for [`inner`].
78#[deprecated(since = "0.23.0", note = "renamed to `inner`")]
79#[inline]
80pub fn inner_bound<'py, T, DIN1, DIN2, OUT>(
81    array1: &Bound<'py, PyArray<T, DIN1>>,
82    array2: &Bound<'py, PyArray<T, DIN2>>,
83) -> PyResult<OUT>
84where
85    T: Element,
86    DIN1: Dimension,
87    DIN2: Dimension,
88    OUT: ArrayOrScalar<'py, T>,
89{
90    inner(array1, array2)
91}
92
93/// Return the dot product of two arrays.
94///
95/// [NumPy's documentation][dot] has the details.
96///
97/// # Examples
98///
99/// Note that this function can either return an array...
100///
101/// ```
102/// use pyo3::{Python, Bound};
103/// use ndarray::array;
104/// use numpy::{dot, pyarray, PyArray2, PyArrayMethods};
105///
106/// Python::with_gil(|py| {
107///     let matrix = pyarray![py, [1, 0], [0, 1]];
108///     let another_matrix = pyarray![py, [4, 1], [2, 2]];
109///
110///     let result: Bound<'_, PyArray2<_>> = dot(&matrix, &another_matrix).unwrap();
111///
112///     assert_eq!(
113///         result.readonly().as_array(),
114///         array![[4, 1], [2, 2]]
115///     );
116/// });
117/// ```
118///
119/// ...or a scalar depending on its arguments.
120///
121/// ```
122/// use pyo3::Python;
123/// use numpy::{dot, pyarray, PyArray0};
124///
125/// Python::with_gil(|py| {
126///     let vector = pyarray![py, 1.0, 2.0, 3.0];
127///     let result: f64 = dot(&vector, &vector).unwrap();
128///     assert_eq!(result, 14.0);
129/// });
130/// ```
131///
132/// [dot]: https://numpy.org/doc/stable/reference/generated/numpy.dot.html
133pub fn dot<'py, T, DIN1, DIN2, OUT>(
134    array1: &Bound<'py, PyArray<T, DIN1>>,
135    array2: &Bound<'py, PyArray<T, DIN2>>,
136) -> PyResult<OUT>
137where
138    T: Element,
139    DIN1: Dimension,
140    DIN2: Dimension,
141    OUT: ArrayOrScalar<'py, T>,
142{
143    let py = array1.py();
144    let obj = unsafe {
145        let result = PY_ARRAY_API.PyArray_MatrixProduct(py, array1.as_ptr(), array2.as_ptr());
146        Bound::from_owned_ptr_or_err(py, result)?
147    };
148    obj.extract()
149}
150
151/// Deprecated name for [`dot`].
152#[deprecated(since = "0.23.0", note = "renamed to `dot`")]
153#[inline]
154pub fn dot_bound<'py, T, DIN1, DIN2, OUT>(
155    array1: &Bound<'py, PyArray<T, DIN1>>,
156    array2: &Bound<'py, PyArray<T, DIN2>>,
157) -> PyResult<OUT>
158where
159    T: Element,
160    DIN1: Dimension,
161    DIN2: Dimension,
162    OUT: ArrayOrScalar<'py, T>,
163{
164    dot(array1, array2)
165}
166
167/// Return the Einstein summation convention of given tensors.
168///
169/// This is usually invoked via the the [`einsum!`][crate::einsum!] macro.
170pub fn einsum<'py, T, OUT>(
171    subscripts: &str,
172    arrays: &[Borrowed<'_, 'py, PyArray<T, IxDyn>>],
173) -> PyResult<OUT>
174where
175    T: Element,
176    OUT: ArrayOrScalar<'py, T>,
177{
178    let subscripts = match CStr::from_bytes_with_nul(subscripts.as_bytes()) {
179        Ok(subscripts) => Cow::Borrowed(subscripts),
180        Err(_) => Cow::Owned(CString::new(subscripts).unwrap()),
181    };
182
183    let py = arrays[0].py();
184    let obj = unsafe {
185        let result = PY_ARRAY_API.PyArray_EinsteinSum(
186            py,
187            subscripts.as_ptr() as _,
188            arrays.len() as _,
189            arrays.as_ptr() as _,
190            null_mut(),
191            NPY_ORDER::NPY_KEEPORDER,
192            NPY_CASTING::NPY_NO_CASTING,
193            null_mut(),
194        );
195        Bound::from_owned_ptr_or_err(py, result)?
196    };
197    obj.extract()
198}
199
200/// Deprecated name for [`einsum`].
201#[deprecated(since = "0.23.0", note = "renamed to `einsum`")]
202#[inline]
203pub fn einsum_bound<'py, T, OUT>(
204    subscripts: &str,
205    arrays: &[Borrowed<'_, 'py, PyArray<T, IxDyn>>],
206) -> PyResult<OUT>
207where
208    T: Element,
209    OUT: ArrayOrScalar<'py, T>,
210{
211    einsum(subscripts, arrays)
212}
213
/// Evaluates the Einstein summation convention on the given tensors.
///
/// The subscripts must be a string literal; a terminating NUL byte is appended at compile time
/// so the underlying `einsum` function can pass it to NumPy without copying. Each operand must
/// be an identifier naming an array. It is converted to its dynamically-dimensioned form before
/// the call.
///
/// For more about the Einstein summation convention, please refer to
/// [NumPy's documentation][einsum].
///
/// # Example
///
/// ```
/// use pyo3::{Python, Bound};
/// use ndarray::array;
/// use numpy::{einsum, pyarray, PyArray, PyArray2, PyArrayMethods};
///
/// Python::with_gil(|py| {
///     let tensor = PyArray::arange(py, 0, 2 * 3 * 4, 1).reshape([2, 3, 4]).unwrap();
///     let another_tensor = pyarray![py, [20, 30], [40, 50], [60, 70]];
///
///     let result: Bound<'_, PyArray2<_>> = einsum!("ijk,ji->ik", tensor, another_tensor).unwrap();
///
///     assert_eq!(
///         result.readonly().as_array(),
///         array![[640,  760,  880, 1000], [2560, 2710, 2860, 3010]]
///     );
/// });
/// ```
///
/// [einsum]: https://numpy.org/doc/stable/reference/generated/numpy.einsum.html
#[macro_export]
macro_rules! einsum {
    ($subscripts:literal $(,$array:ident)+ $(,)*) => {{
        let subscripts = concat!($subscripts, "\0");
        let arrays = [$($array.to_dyn().as_borrowed()),+];
        $crate::einsum(subscripts, &arrays)
    }};
}
247
/// Deprecated name for [`einsum!`].
///
/// Accepts exactly the same input as [`einsum!`] and expands to it, so both names stay in sync.
#[deprecated(since = "0.23.0", note = "renamed to `einsum!`")]
#[macro_export]
macro_rules! einsum_bound {
    ($subscripts:literal $(,$array:ident)+ $(,)*) => {
        $crate::einsum!($subscripts $(,$array)+)
    };
}