numpy/sum_products.rs
1use std::borrow::Cow;
2use std::ffi::{CStr, CString};
3use std::ptr::null_mut;
4
5use ndarray::{Dimension, IxDyn};
6use pyo3::types::PyAnyMethods;
7use pyo3::{Borrowed, Bound, FromPyObject, PyResult};
8
9use crate::array::PyArray;
10use crate::dtype::Element;
11use crate::npyffi::{array::PY_ARRAY_API, NPY_CASTING, NPY_ORDER};
12
13/// Return value of a function that can yield either an array or a scalar.
14pub trait ArrayOrScalar<'py, T>: FromPyObject<'py> {}
15
16impl<'py, T, D> ArrayOrScalar<'py, T> for Bound<'py, PyArray<T, D>>
17where
18 T: Element,
19 D: Dimension,
20{
21}
22
23impl<'py, T> ArrayOrScalar<'py, T> for T where T: Element + FromPyObject<'py> {}
24
25/// Return the inner product of two arrays.
26///
27/// [NumPy's documentation][inner] has the details.
28///
29/// # Examples
30///
31/// Note that this function can either return a scalar...
32///
33/// ```
34/// use pyo3::Python;
35/// use numpy::{inner, pyarray, PyArray0};
36///
37/// Python::with_gil(|py| {
38/// let vector = pyarray![py, 1.0, 2.0, 3.0];
39/// let result: f64 = inner(&vector, &vector).unwrap();
40/// assert_eq!(result, 14.0);
41/// });
42/// ```
43///
44/// ...or an array depending on its arguments.
45///
46/// ```
47/// use pyo3::{Python, Bound};
48/// use numpy::prelude::*;
49/// use numpy::{inner, pyarray, PyArray0};
50///
51/// Python::with_gil(|py| {
52/// let vector = pyarray![py, 1, 2, 3];
53/// let result: Bound<'_, PyArray0<_>> = inner(&vector, &vector).unwrap();
54/// assert_eq!(result.item(), 14);
55/// });
56/// ```
57///
58/// [inner]: https://numpy.org/doc/stable/reference/generated/numpy.inner.html
59pub fn inner<'py, T, DIN1, DIN2, OUT>(
60 array1: &Bound<'py, PyArray<T, DIN1>>,
61 array2: &Bound<'py, PyArray<T, DIN2>>,
62) -> PyResult<OUT>
63where
64 T: Element,
65 DIN1: Dimension,
66 DIN2: Dimension,
67 OUT: ArrayOrScalar<'py, T>,
68{
69 let py = array1.py();
70 let obj = unsafe {
71 let result = PY_ARRAY_API.PyArray_InnerProduct(py, array1.as_ptr(), array2.as_ptr());
72 Bound::from_owned_ptr_or_err(py, result)?
73 };
74 obj.extract()
75}
76
77/// Deprecated name for [`inner`].
78#[deprecated(since = "0.23.0", note = "renamed to `inner`")]
79#[inline]
80pub fn inner_bound<'py, T, DIN1, DIN2, OUT>(
81 array1: &Bound<'py, PyArray<T, DIN1>>,
82 array2: &Bound<'py, PyArray<T, DIN2>>,
83) -> PyResult<OUT>
84where
85 T: Element,
86 DIN1: Dimension,
87 DIN2: Dimension,
88 OUT: ArrayOrScalar<'py, T>,
89{
90 inner(array1, array2)
91}
92
93/// Return the dot product of two arrays.
94///
95/// [NumPy's documentation][dot] has the details.
96///
97/// # Examples
98///
99/// Note that this function can either return an array...
100///
101/// ```
102/// use pyo3::{Python, Bound};
103/// use ndarray::array;
104/// use numpy::{dot, pyarray, PyArray2, PyArrayMethods};
105///
106/// Python::with_gil(|py| {
107/// let matrix = pyarray![py, [1, 0], [0, 1]];
108/// let another_matrix = pyarray![py, [4, 1], [2, 2]];
109///
110/// let result: Bound<'_, PyArray2<_>> = dot(&matrix, &another_matrix).unwrap();
111///
112/// assert_eq!(
113/// result.readonly().as_array(),
114/// array![[4, 1], [2, 2]]
115/// );
116/// });
117/// ```
118///
119/// ...or a scalar depending on its arguments.
120///
121/// ```
122/// use pyo3::Python;
123/// use numpy::{dot, pyarray, PyArray0};
124///
125/// Python::with_gil(|py| {
126/// let vector = pyarray![py, 1.0, 2.0, 3.0];
127/// let result: f64 = dot(&vector, &vector).unwrap();
128/// assert_eq!(result, 14.0);
129/// });
130/// ```
131///
132/// [dot]: https://numpy.org/doc/stable/reference/generated/numpy.dot.html
133pub fn dot<'py, T, DIN1, DIN2, OUT>(
134 array1: &Bound<'py, PyArray<T, DIN1>>,
135 array2: &Bound<'py, PyArray<T, DIN2>>,
136) -> PyResult<OUT>
137where
138 T: Element,
139 DIN1: Dimension,
140 DIN2: Dimension,
141 OUT: ArrayOrScalar<'py, T>,
142{
143 let py = array1.py();
144 let obj = unsafe {
145 let result = PY_ARRAY_API.PyArray_MatrixProduct(py, array1.as_ptr(), array2.as_ptr());
146 Bound::from_owned_ptr_or_err(py, result)?
147 };
148 obj.extract()
149}
150
151/// Deprecated name for [`dot`].
152#[deprecated(since = "0.23.0", note = "renamed to `dot`")]
153#[inline]
154pub fn dot_bound<'py, T, DIN1, DIN2, OUT>(
155 array1: &Bound<'py, PyArray<T, DIN1>>,
156 array2: &Bound<'py, PyArray<T, DIN2>>,
157) -> PyResult<OUT>
158where
159 T: Element,
160 DIN1: Dimension,
161 DIN2: Dimension,
162 OUT: ArrayOrScalar<'py, T>,
163{
164 dot(array1, array2)
165}
166
167/// Return the Einstein summation convention of given tensors.
168///
169/// This is usually invoked via the the [`einsum!`][crate::einsum!] macro.
170pub fn einsum<'py, T, OUT>(
171 subscripts: &str,
172 arrays: &[Borrowed<'_, 'py, PyArray<T, IxDyn>>],
173) -> PyResult<OUT>
174where
175 T: Element,
176 OUT: ArrayOrScalar<'py, T>,
177{
178 let subscripts = match CStr::from_bytes_with_nul(subscripts.as_bytes()) {
179 Ok(subscripts) => Cow::Borrowed(subscripts),
180 Err(_) => Cow::Owned(CString::new(subscripts).unwrap()),
181 };
182
183 let py = arrays[0].py();
184 let obj = unsafe {
185 let result = PY_ARRAY_API.PyArray_EinsteinSum(
186 py,
187 subscripts.as_ptr() as _,
188 arrays.len() as _,
189 arrays.as_ptr() as _,
190 null_mut(),
191 NPY_ORDER::NPY_KEEPORDER,
192 NPY_CASTING::NPY_NO_CASTING,
193 null_mut(),
194 );
195 Bound::from_owned_ptr_or_err(py, result)?
196 };
197 obj.extract()
198}
199
200/// Deprecated name for [`einsum`].
201#[deprecated(since = "0.23.0", note = "renamed to `einsum`")]
202#[inline]
203pub fn einsum_bound<'py, T, OUT>(
204 subscripts: &str,
205 arrays: &[Borrowed<'_, 'py, PyArray<T, IxDyn>>],
206) -> PyResult<OUT>
207where
208 T: Element,
209 OUT: ArrayOrScalar<'py, T>,
210{
211 einsum(subscripts, arrays)
212}
213
/// Return the Einstein summation convention of given tensors.
///
/// Each operand identifier is converted with `to_dyn()`, so arrays of any dimensionality can
/// be mixed. The subscripts literal is nul-terminated at compile time before being passed to
/// [`einsum`][crate::einsum()].
///
/// For more about the Einstein summation convention, please refer to
/// [NumPy's documentation][einsum].
///
/// # Example
///
/// ```
/// use pyo3::{Python, Bound};
/// use ndarray::array;
/// use numpy::{einsum, pyarray, PyArray, PyArray2, PyArrayMethods};
///
/// Python::with_gil(|py| {
///     let tensor = PyArray::arange(py, 0, 2 * 3 * 4, 1).reshape([2, 3, 4]).unwrap();
///     let another_tensor = pyarray![py, [20, 30], [40, 50], [60, 70]];
///
///     let result: Bound<'_, PyArray2<_>> = einsum!("ijk,ji->ik", tensor, another_tensor).unwrap();
///
///     assert_eq!(
///         result.readonly().as_array(),
///         array![[640, 760, 880, 1000], [2560, 2710, 2860, 3010]]
///     );
/// });
/// ```
///
/// [einsum]: https://numpy.org/doc/stable/reference/generated/numpy.einsum.html
#[macro_export]
macro_rules! einsum {
    ($subscripts:literal $(,$array:ident)+ $(,)*) => {{
        let operands = [$( $array.to_dyn().as_borrowed() ),+];
        $crate::einsum(concat!($subscripts, "\0"), &operands)
    }};
}
247
248/// Deprecated name for [`einsum!`].
249#[deprecated(since = "0.23.0", note = "renamed to `einsum!`")]
250#[macro_export]
251macro_rules! einsum_bound {
252 ($subscripts:literal $(,$array:ident)+ $(,)*) => {{
253 let arrays = [$($array.to_dyn().as_borrowed(),)+];
254 $crate::einsum(concat!($subscripts, "\0"), &arrays)
255 }};
256}