pytorch-inference
include/utils.hpp
Go to the documentation of this file.
1 //
2 // Created by Aman LaChapelle on 5/18/17.
3 //
4 // pytorch_inference
5 // Copyright (c) 2017 Aman LaChapelle
6 // Full license at pytorch_inference/LICENSE.txt
7 //
8 
9 #ifndef PYTORCH_INFERENCE_EXTRACT_NUMPY_HPP
10 #define PYTORCH_INFERENCE_EXTRACT_NUMPY_HPP
11 
// Python
#include <Python.h>
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#include <numpy/arrayobject.h>

// STL
#include <stdexcept>
#include <string>
#include <vector>

// ArrayFire
#include <arrayfire.h>
22 
23 namespace pytorch::internal {
24 
35  inline af::array from_numpy(PyArrayObject *array, int ndim, std::vector<int> dims){
36 
37  array = PyArray_GETCONTIGUOUS(array); // make sure it's contiguous (might already be)
38 
39  int array_ndim = PyArray_NDIM(array);
40  assert(ndim == array_ndim);
41 
42  npy_intp *array_dims = PyArray_SHAPE(array);
43 
44  for (int i = 0; i < ndim; i++){
45  assert(dims[i] == array_dims[i]); // make sure dimensions are right
46  }
47 
48  for (int i = dims.size(); i < 4; i++){
49  dims.push_back(1);
50  }
51 
52  int n, k, h, w;
53  n = dims[0]; k = dims[1]; h = dims[2]; w = dims[3];
54 
55  af::array out (w, h, k, n, reinterpret_cast<float *>(PyArray_DATA(array))); // errors out here
56  out = af::reorder(out, 1, 0, 2, 3); // reorder to arrayfire specs (h, w, k, batch)
57 
58  return out;
59 
60  }
61 
/**
 * Verifies that two sizes are equal.
 *
 * @param size1 First size.
 * @param size2 Second size; must equal size1.
 * @param func  Name of the caller, appended to the error message.
 * @throws std::runtime_error when size1 and size2 differ.
 */
inline void check_size(const int &size1, const int &size2, const std::string &func){
  if (size1 != size2){
    const std::string sizes = std::to_string(size1) + ", " + std::to_string(size2);
    throw std::runtime_error("Incorrect size passed! Sizes: " + sizes + " Function: " + func);
  }
}
72 
/**
 * Verifies that size1 does not exceed size2.
 *
 * @param size1 Value that must be less than or equal to size2.
 * @param size2 Upper bound (inclusive).
 * @param func  Name of the caller, appended to the error message.
 * @throws std::runtime_error when size1 > size2.
 */
inline void check_num_leq(const int &size1, const int &size2, const std::string &func){
  const bool within_limit = !(size1 > size2);
  if (within_limit){
    return;
  }

  std::string message{"Incorrect size passed! Sizes: "};
  message.append(std::to_string(size1)).append(", ").append(std::to_string(size2));
  message.append(" Function: ").append(func);
  throw std::runtime_error(message);
}
83 
84 } // pytorch::internal
85 
86 #endif //PYTORCH_INFERENCE_EXTRACT_NUMPY_HPP
Definition: layers.hpp:37
Definition: include/utils.hpp:23
af::array from_numpy(PyArrayObject *array, int ndim, std::vector< int > dims)
Converts a numpy array to an ArrayFire array. It is necessary to specify all 4 dimensions if there is...
Definition: include/utils.hpp:35
Definition: layers.hpp:36
Definition: layers.hpp:35
Definition: layers.hpp:34
void check_size(const int &size1, const int &size2, const std::string &func)
Definition: include/utils.hpp:62
dims
Convenience enum to use whenever you need to specify a dimension (like in Concat) ...
Definition: layers.hpp:33
void check_num_leq(const int &size1, const int &size2, const std::string &func)
Definition: include/utils.hpp:73