go_pybindings/
analysis.rs

1use pyo3::prelude::*;
2use pyo3::exceptions::PyValueError;
3
4use numpy::PyArray1;
5use numpy::PyArrayMethods;
6
7use tof_dataclasses::analysis::{
8    find_peaks,
9    find_peaks_zscore,
10    interpolate_time,
11    cfd_simple,
12    integrate,
13    time2bin,
14    calc_edep_simple
15};
16
17use tof_dataclasses::calibrations::{
18    find_zero_crossings,
19    get_periods,
20    Edge,
21};
22
23
24///helper
25fn convert_pyarray1<'_py>(arr : Bound<'_py, PyArray1<f32>>) -> Vec<f32> {
26  let mut vec = Vec::<f32>::new();
27  unsafe {
28    vec.extend_from_slice(arr.as_slice().unwrap());
29  }
30  return vec;
31}
32
33#[pyfunction]
34#[pyo3(name="get_periods")]
35pub fn py_get_periods<'_py>(trace   : Bound<'_py, PyArray1<f32>>,
36                            dts     : Bound<'_py, PyArray1<f32>>,
37                            nperiod : f32,
38                            nskip   : f32)
39    -> PyResult<(Vec<usize>, Vec<f32>)> {
40  // we fix the edge here
41  let edge = Edge::Rising;
42  let wr_trace : Vec<f32>;
43  let wr_dts   : Vec<f32>;
44  wr_trace = convert_pyarray1(trace);
45  wr_dts   = convert_pyarray1(dts);
46  let result   = get_periods(&wr_trace, &wr_dts, nperiod, nskip, &edge);
47  Ok(result)
48}
49
50
51
52#[pyfunction]
53#[pyo3(name="calc_edep_simple")]
54pub fn py_calc_edep_simple(peak_voltage : f32) -> f32 {
55  calc_edep_simple(peak_voltage)
56}
57
58#[pyfunction]
59#[pyo3(name="find_zero_crossings")]
60/// Get a vector with the indizes where 
61/// the input array crosses zero
62pub fn py_find_zero_crossings<'_py>(trace : Bound<'_py,PyArray1<f32>>) 
63  -> PyResult<Vec<usize>> {
64  let tr  = convert_pyarray1(trace);
65  let zcs = find_zero_crossings(&tr);
66  Ok(zcs)
67}
68
69#[pyfunction]
70#[pyo3(name="cfd_simple")]
71/// Find the peak onset time based on a cfd
72/// "Constant fraction discrimination" algorithm
73///
74/// # Arguments
75///
76/// * start_peak : bin
77/// * end_peak   : bin
78/// * cfd_frac   : 0.2 is the typical default
79pub fn py_cfd_simple<'_py>(voltages    : Bound<'_py,PyArray1<f32>>,
80                           nanoseconds : Bound<'_py,PyArray1<f32>>,
81                           cfd_frac    : f32,
82                           start_peak  : usize,
83                           end_peak    : usize) -> PyResult<f32> {
84  let voltages_vec    = convert_pyarray1(voltages);
85  let nanoseconds_vec = convert_pyarray1(nanoseconds);
86  match cfd_simple(&voltages_vec   ,
87                   &nanoseconds_vec,
88                   cfd_frac       ,
89                   start_peak  ,
90                   end_peak) {
91    Ok(result) => Ok(result),
92    Err(err)   => {
93     return Err(PyValueError::new_err(err.to_string()));
94    } 
95  }
96}
97
98#[pyfunction]
99#[pyo3(name="interpolate_time")]
100pub fn py_interpolate_time<'_py>(voltages    : Bound<'_py,PyArray1<f32>>,
101                                 nanoseconds : Bound<'_py,PyArray1<f32>>,
102                                 threshold   : f32,
103                                 idx         : usize,
104                                 size        : usize) -> PyResult<f32> {
105  let mut voltages_vec    = Vec::<f32>::new();
106  let mut nanoseconds_vec = Vec::<f32>::new(); 
107  unsafe {
108    voltages_vec.extend_from_slice(voltages.as_slice().unwrap());
109    nanoseconds_vec.extend_from_slice(nanoseconds.as_slice().unwrap());
110  }
111  match interpolate_time (&voltages_vec   ,
112                          &nanoseconds_vec, 
113                          threshold      ,
114                          idx            ,
115                          size) {
116   Ok(result) => Ok(result),
117   Err(err)   => {
118    return Err(PyValueError::new_err(err.to_string()));
119   } 
120  }
121}
122
123#[pyfunction]
124#[pyo3(name="time2bin")]
125pub fn py_time2bin<'_py>(nanoseconds : Bound<'_py,PyArray1<f32>>,
126                         t_ns        : f32) -> PyResult<usize> {
127 let mut nanoseconds_vec = Vec::<f32>::new(); 
128 unsafe {
129   nanoseconds_vec.extend_from_slice(nanoseconds.as_slice().unwrap());
130 }
131 match time2bin (&nanoseconds_vec,
132                 t_ns){
133   Ok(result) => Ok(result),
134   Err(err)   => {
135    return Err(PyValueError::new_err(err.to_string()));
136   } 
137 }
138}
139
140#[pyfunction]
141#[pyo3(name="integrate")]
142pub fn py_integrate<'_py>(voltages    : Bound<'_py,PyArray1<f32>>,
143                          nanoseconds : Bound<'_py,PyArray1<f32>>,
144                          lower_bin   : usize,
145                          upper_bin   : usize,
146                          impedance   : f32) -> PyResult<f32>  {
147 let mut voltages_vec    = Vec::<f32>::new();
148 let mut nanoseconds_vec = Vec::<f32>::new(); 
149 unsafe {
150   voltages_vec.extend_from_slice(voltages.as_slice().unwrap());
151   nanoseconds_vec.extend_from_slice(nanoseconds.as_slice().unwrap());
152 }
153 match integrate(&voltages_vec, &nanoseconds_vec, lower_bin, upper_bin, impedance) {
154   Ok(result) => Ok(result),
155   Err(err)   => {
156    return Err(PyValueError::new_err(err.to_string()));
157   }
158 }
159}
160
161#[pyfunction]
162#[pyo3(name = "find_peaks")]
163/// The GAPS peak finding algorithm, based on 
164/// legacy code written by the UCLA TOF team.
165///
166/// This needs to be applied AFTER the peakfinding
167/// and takes a specific peak as input argument
168///
169/// # Arguments
170/// 
171/// * voltages     (np.ndarray) | These both together
172/// * nanosecondes (np.ndarray) | are a calibrated waveform
173/// * start_time   (float)      - begin peak search at this time
174/// * window_size  (float)      - limit peak search to start_time 
175///                               + start_time + window_size (in ns)
176/// * min_peak_width (usize)    - If a peak has a lower width, it 
177///                               will get discarded (in bins)
178/// * threshold      (f32)      - Ingore peaks which fall below this
179///                               voltage (in mV)
180/// * max_peaks      (usize)    - Stop peak search after max_peaks are
181///                              found
182pub fn py_find_peaks<'_py>(voltages       : Bound<'_py, PyArray1<f32>>,
183                           nanoseconds    : Bound<'_py, PyArray1<f32>>,
184                           start_time     : f32,
185                           window_size    : f32,
186                           min_peak_width : usize,
187                           threshold      : f32,
188                           max_peaks      : usize) -> PyResult<Vec<(usize,usize)>> {
189 let mut voltages_vec    = Vec::<f32>::new();
190 let mut nanoseconds_vec = Vec::<f32>::new(); 
191 unsafe {
192   voltages_vec.extend_from_slice(voltages.as_slice().unwrap());
193   nanoseconds_vec.extend_from_slice(nanoseconds.as_slice().unwrap());
194 }
195
196 match find_peaks(&voltages_vec  , 
197                  &nanoseconds_vec   , 
198                  start_time    , 
199                  window_size   , 
200                  min_peak_width, 
201                  threshold     , 
202                  max_peaks     ) {
203   Ok(result) => Ok(result),
204   Err(err)   => {
205    return Err(PyValueError::new_err(err.to_string()));
206   }
207 }
208}
209
210#[pyfunction]
211#[pyo3(name = "find_peaks_zscore")]
212pub fn py_find_peaks_zscore<'_py>(voltages       : Bound<'_py,PyArray1<f32>>,
213                                  nanoseconds    : Bound<'_py,PyArray1<f32>>,
214                                  start_time     : f32,
215                                  window_size    : f32,
216                                  lag            : usize,
217                                  threshold      : f64,
218                                  influence      : f64) -> PyResult<Vec<(usize,usize)>> {
219 let mut voltages_vec    = Vec::<f32>::new();
220 let mut nanoseconds_vec = Vec::<f32>::new(); 
221 unsafe {
222   voltages_vec.extend_from_slice(voltages.as_slice().unwrap());
223   nanoseconds_vec.extend_from_slice(nanoseconds.as_slice().unwrap());
224 }
225
226 match find_peaks_zscore(&nanoseconds_vec, 
227                         &voltages_vec   ,   
228                         start_time      , 
229                         window_size     , 
230                         lag             , 
231                         threshold       , 
232                         influence) {
233   Ok(result) => Ok(result),
234   Err(err)   => {
235     return Err(PyValueError::new_err(err.to_string()));
236   }
237 }
238}
239