1use crate::dimension::IntoDimension;
2use crate::Dimension;
3use crate::order::Order;
4
5#[derive(Copy, Clone, Debug)]
9pub struct Shape<D> {
10 pub(crate) dim: D,
12 pub(crate) strides: Strides<Contiguous>,
14}
15
/// Uninhabited marker type (no variants, so no values exist).
///
/// Used as the payload type of `Strides::Custom` inside `Shape`, which makes
/// `Strides<Contiguous>` limited to the `C` and `F` variants.
#[derive(Clone, Copy, Debug)]
pub(crate) enum Contiguous {}
18
19impl<D> Shape<D> {
20 pub(crate) fn is_c(&self) -> bool {
21 matches!(self.strides, Strides::C)
22 }
23}
24
25#[derive(Copy, Clone, Debug)]
27pub struct StrideShape<D> {
28 pub(crate) dim: D,
29 pub(crate) strides: Strides<D>,
30}
31
32impl<D> StrideShape<D>
33where
34 D: Dimension,
35{
36 pub fn raw_dim(&self) -> &D {
38 &self.dim
39 }
40 pub fn size(&self) -> usize {
42 self.dim.size()
43 }
44}
45
/// Describes how the strides of an array are determined.
#[derive(Clone, Copy, Debug)]
pub(crate) enum Strides<D> {
    /// C order: strides computed by `Dimension::default_strides`.
    C,
    /// F order: strides computed by `Dimension::fortran_strides`.
    F,
    /// Explicit, caller-supplied strides.
    Custom(D),
}
56
57impl<D> Strides<D> {
58 pub(crate) fn strides_for_dim(self, dim: &D) -> D
60 where
61 D: Dimension,
62 {
63 match self {
64 Strides::C => dim.default_strides(),
65 Strides::F => dim.fortran_strides(),
66 Strides::Custom(c) => {
67 debug_assert_eq!(
68 c.ndim(),
69 dim.ndim(),
70 "Custom strides given with {} dimensions, expected {}",
71 c.ndim(),
72 dim.ndim()
73 );
74 c
75 }
76 }
77 }
78
79 pub(crate) fn is_custom(&self) -> bool {
80 matches!(*self, Strides::Custom(_))
81 }
82}
83
84pub trait ShapeBuilder {
90 type Dim: Dimension;
91 type Strides;
92
93 fn into_shape(self) -> Shape<Self::Dim>;
94 fn f(self) -> Shape<Self::Dim>;
95 fn set_f(self, is_f: bool) -> Shape<Self::Dim>;
96 fn strides(self, strides: Self::Strides) -> StrideShape<Self::Dim>;
97}
98
99impl<D> From<D> for Shape<D>
100where
101 D: Dimension,
102{
103 fn from(dimension: D) -> Shape<D> {
105 dimension.into_shape()
106 }
107}
108
109impl<T, D> From<T> for StrideShape<D>
110where
111 D: Dimension,
112 T: ShapeBuilder<Dim = D>,
113{
114 fn from(value: T) -> Self {
115 let shape = value.into_shape();
116 let st = if shape.is_c() { Strides::C } else { Strides::F };
117 StrideShape {
118 strides: st,
119 dim: shape.dim,
120 }
121 }
122}
123
124impl<T> ShapeBuilder for T
125where
126 T: IntoDimension,
127{
128 type Dim = T::Dim;
129 type Strides = T;
130 fn into_shape(self) -> Shape<Self::Dim> {
131 Shape {
132 dim: self.into_dimension(),
133 strides: Strides::C,
134 }
135 }
136 fn f(self) -> Shape<Self::Dim> {
137 self.set_f(true)
138 }
139 fn set_f(self, is_f: bool) -> Shape<Self::Dim> {
140 self.into_shape().set_f(is_f)
141 }
142 fn strides(self, st: T) -> StrideShape<Self::Dim> {
143 self.into_shape().strides(st.into_dimension())
144 }
145}
146
147impl<D> ShapeBuilder for Shape<D>
148where
149 D: Dimension,
150{
151 type Dim = D;
152 type Strides = D;
153
154 fn into_shape(self) -> Shape<D> {
155 self
156 }
157
158 fn f(self) -> Self {
159 self.set_f(true)
160 }
161
162 fn set_f(mut self, is_f: bool) -> Self {
163 self.strides = if !is_f { Strides::C } else { Strides::F };
164 self
165 }
166
167 fn strides(self, st: D) -> StrideShape<D> {
168 StrideShape {
169 dim: self.dim,
170 strides: Strides::Custom(st),
171 }
172 }
173}
174
175impl<D> Shape<D>
176where
177 D: Dimension,
178{
179 pub fn raw_dim(&self) -> &D {
181 &self.dim
182 }
183 pub fn size(&self) -> usize {
185 self.dim.size()
186 }
187}
188
189
190pub trait ShapeArg {
199 type Dim: Dimension;
200 fn into_shape_and_order(self) -> (Self::Dim, Option<Order>);
201}
202
203impl<T> ShapeArg for T where T: IntoDimension {
204 type Dim = T::Dim;
205
206 fn into_shape_and_order(self) -> (Self::Dim, Option<Order>) {
207 (self.into_dimension(), None)
208 }
209}
210
211impl<T> ShapeArg for (T, Order) where T: IntoDimension {
212 type Dim = T::Dim;
213
214 fn into_shape_and_order(self) -> (Self::Dim, Option<Order>) {
215 (self.0.into_dimension(), Some(self.1))
216 }
217}