pytorch-inference
pytorch::Linear Class Reference

#include <Linear.hpp>

pytorch::Linear inherits from pytorch::Layer.

Public Member Functions

 Linear (const tensor &weights, const tensor &bias)
 Constructs a Linear object from weights and bias tensors.
 
 Linear (const std::string &weights_filename="", const std::vector< int > &weights_dims={}, const std::string &bias_filename="", const std::vector< int > &bias_dims={}, const std::string &python_home="../scripts")
 Constructs a Linear object given the filenames and sizes of the requisite tensors.
 
 Linear (const Linear &other)
 Copy constructor; constructs a Linear object that is an exact copy of the argument.
 
virtual ~Linear ()
 Destructor; currently trivial, though it may take on some functionality later.
 
void add_weights (const std::string &weights_filename, const std::vector< int > &weights_dims)
 Reads weights from the given file, for use when they were not passed to the constructor. Overwrites the current contents of this->weights.
 
void add_bias (const std::string &bias_filename, const std::vector< int > &bias_dims)
 Reads the bias from the given file, for use when it was not passed to the constructor. Overwrites the current contents of this->bias.
 
std::vector< tensor > forward (const std::vector< tensor > &input)
 Forward function: performs the linear operation on the input data using the already-initialized weights and bias tensors.
 
std::vector< tensor > operator() (const std::vector< tensor > &input)
 Forward function: performs the linear operation on the input data using the already-initialized weights and bias tensors.
 

Private Attributes

tensor weights
 
tensor bias
 
pycpp::py_object utils
 
bool has_bias = false
 

Constructor & Destructor Documentation

◆ Linear() [1/3]

pytorch::Linear::Linear ( const tensor &  weights,
const tensor &  bias 
)
inline

Constructs a Linear object from weights and bias tensors.

Parameters
weights: The trained weights tensor. Intended for users comfortable with Py_Cpp.
bias: The trained bias tensor. Intended for users comfortable with Py_Cpp. May be initialized to zero.
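As a rough illustration of what this constructor expects, the sketch below uses plain std::vector buffers in place of the library's `tensor` type and assumes PyTorch's nn.Linear shapes (weights of out_features x in_features, bias of out_features entries). `LinearParamsSketch` and `with_zero_bias` are hypothetical names, not part of pytorch-inference.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical in-memory stand-ins for the weights and bias tensors passed to
// the first constructor. Assumes PyTorch's nn.Linear shapes:
// weights is out_features x in_features (row-major), bias has out_features entries.
struct LinearParamsSketch {
    std::size_t in_features = 0;
    std::size_t out_features = 0;
    std::vector<float> weights;  // out_features * in_features values
    std::vector<float> bias;     // out_features values

    // True when the buffers agree with the declared feature counts.
    bool shapes_consistent() const {
        return weights.size() == in_features * out_features &&
               bias.size() == out_features;
    }
};

// Builds parameters with an all-zero bias, so the layer would compute y = W x.
LinearParamsSketch with_zero_bias(std::size_t in, std::size_t out,
                                  std::vector<float> weights) {
    LinearParamsSketch p;
    p.in_features = in;
    p.out_features = out;
    p.weights = std::move(weights);
    p.bias.assign(out, 0.0f);  // out_features zeros
    return p;
}
```

A zero bias leaves the output equal to W x, which is why a zeroed bias tensor is acceptable for this constructor.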

◆ Linear() [2/3]

pytorch::Linear::Linear ( const std::string &  weights_filename = "",
const std::vector< int > &  weights_dims = {},
const std::string &  bias_filename = "",
const std::vector< int > &  bias_dims = {},
const std::string &  python_home = "../scripts" 
)
inline

Constructs a Linear object given the filenames and sizes of the requisite tensors.

Parameters
weights_filename: The file where the weights tensor is saved; it is loaded with numpy.load(filename).
weights_dims: The dimensions of the weights tensor in PyTorch convention, i.e. (batch, channels, h, w).
bias_filename: The file where the bias tensor is saved; it is loaded with numpy.load(filename).
bias_dims: The dimensions of the bias tensor in PyTorch convention, i.e. (batch, channels, h, w).
python_home: The directory containing the utility scripts, including the loading script used to load the tensors.
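Because every argument has an empty default, a Linear object can be constructed first and have its tensors supplied later. A minimal sketch of how a loader might interpret a filename and its dims vector follows; `element_count` and `matches_declared_dims` are hypothetical helpers, and treating an empty filename as "load later" is an assumption drawn from the default values, not documented behavior.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical helper: number of elements described by a dims vector in
// PyTorch order, e.g. {batch, channels, h, w}. An empty vector (the
// constructor's default) is treated as "no tensor supplied" and yields 0.
std::size_t element_count(const std::vector<int>& dims) {
    if (dims.empty()) return 0;
    std::size_t n = 1;
    for (int d : dims) {
        assert(d > 0);  // every dimension must be positive
        n *= static_cast<std::size_t>(d);
    }
    return n;
}

// Hypothetical check that a loaded buffer matches the declared dims.
// An empty filename mirrors the constructor default and means "load later".
bool matches_declared_dims(const std::string& filename,
                           const std::vector<int>& dims,
                           std::size_t loaded_elements) {
    if (filename.empty()) return dims.empty() && loaded_elements == 0;
    return element_count(dims) == loaded_elements;
}
```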

◆ Linear() [3/3]

pytorch::Linear::Linear ( const Linear &  other)
inline

Copy constructor; constructs a Linear object that is an exact copy of the argument.

Parameters
other: Another Linear object to copy.

◆ ~Linear()

virtual pytorch::Linear::~Linear ( )
inline virtual

Destructor; currently trivial, though it may take on some functionality later.

Member Function Documentation

◆ add_bias()

void pytorch::Linear::add_bias ( const std::string &  bias_filename,
const std::vector< int > &  bias_dims 
)
inline

Reads the bias from the given file, for use when it was not passed to the constructor. Overwrites the current contents of this->bias.

Parameters
bias_filename: The file where the bias tensor is saved; it is loaded with numpy.load(filename).
bias_dims: The dimensions of the bias tensor in PyTorch convention, i.e. (batch, channels, h, w).

◆ add_weights()

void pytorch::Linear::add_weights ( const std::string &  weights_filename,
const std::vector< int > &  weights_dims 
)
inline

Reads weights from the given file, for use when they were not passed to the constructor. Overwrites the current contents of this->weights.

Parameters
weights_filename: The file where the weights tensor is saved; it is loaded with numpy.load(filename).
weights_dims: The dimensions of the weights tensor in PyTorch convention, i.e. (batch, channels, h, w).
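add_weights and add_bias support a deferred-loading pattern: construct the layer, then load each tensor, with every call overwriting the previous contents. The sketch below replaces the numpy.load call (made through the Python utility scripts) with an injected loader function; `LazyLinearSketch` and `Loader` are hypothetical, and setting a has_bias flag when a bias is loaded is an assumption about how the private `has_bias` member is used.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical loader standing in for numpy.load via the Python utility scripts.
using Loader = std::function<std::vector<float>(const std::string&)>;

// Sketch of the deferred-loading pattern: construct first, then call
// add_weights / add_bias, each of which overwrites the previous contents.
class LazyLinearSketch {
public:
    explicit LazyLinearSketch(Loader load) : load_(std::move(load)) {}

    void add_weights(const std::string& filename, const std::vector<int>& dims) {
        weights_ = load_(filename);  // overwrites any existing weights
        weights_dims_ = dims;        // kept for later shape checks
    }

    void add_bias(const std::string& filename, const std::vector<int>& dims) {
        bias_ = load_(filename);     // overwrites any existing bias
        bias_dims_ = dims;
        has_bias_ = true;            // assumption: loading a bias enables it
    }

    bool has_bias() const { return has_bias_; }
    const std::vector<float>& weights() const { return weights_; }
    const std::vector<float>& bias() const { return bias_; }

private:
    Loader load_;
    std::vector<float> weights_, bias_;
    std::vector<int> weights_dims_, bias_dims_;
    bool has_bias_ = false;
};
```

Injecting the loader keeps the sketch independent of Python; the real class performs the load itself through its `utils` object.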

◆ forward()

std::vector<tensor> pytorch::Linear::forward ( const std::vector< tensor > &  input)
inline virtual

Forward function: performs the linear operation on the input data using the already-initialized weights and bias tensors.

Parameters
input: Input data of size (dims_in, 1, 1, batch).
Returns
Transformed data of size (dims_out, 1, batch).
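The arithmetic behind forward() can be sketched on plain buffers. This assumes column-major storage, so that the documented input size (dims_in, 1, 1, batch) places feature i of sample b at index i + dims_in * b, and that the weights are dims_out x dims_in in row-major order; the bias is added only when has_bias is true. `linear_forward` is a hypothetical function, not the library's implementation, which operates on `tensor` objects.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Sketch of a batched Linear forward pass: out = W x + b for every sample.
// Assumed layout: input[i + dims_in * b] is feature i of sample b;
// weights[o * dims_in + i] is W(o, i); output uses dims_out features per sample.
std::vector<float> linear_forward(const std::vector<float>& input,
                                  std::size_t dims_in, std::size_t batch,
                                  const std::vector<float>& weights,
                                  std::size_t dims_out,
                                  const std::vector<float>& bias,
                                  bool has_bias) {
    assert(input.size() == dims_in * batch);
    assert(weights.size() == dims_out * dims_in);
    assert(!has_bias || bias.size() == dims_out);
    std::vector<float> out(dims_out * batch, 0.0f);
    for (std::size_t b = 0; b < batch; ++b) {
        for (std::size_t o = 0; o < dims_out; ++o) {
            float acc = has_bias ? bias[o] : 0.0f;  // start from the bias, if any
            for (std::size_t i = 0; i < dims_in; ++i)
                acc += weights[o * dims_in + i] * input[i + dims_in * b];
            out[o + dims_out * b] = acc;
        }
    }
    return out;
}
```

For example, with W = [[1, 2], [3, 4]], bias (1, 1), and two samples (1, 1) and (2, 0), the result is (4, 8) for the first sample and (3, 7) for the second.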

Implements pytorch::Layer.

◆ operator()()

std::vector<tensor> pytorch::Linear::operator() ( const std::vector< tensor > &  input)
inline virtual

Forward function: performs the linear operation on the input data using the already-initialized weights and bias tensors.

Parameters
input: Input data of size (dims_in, 1, 1, batch).
Returns
Transformed data of size (dims_out, 1, batch).

Implements pytorch::Layer.
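forward() and operator() carry the same description, which suggests that operator() lets a layer be called like a function. The sketch below assumes operator() simply delegates to forward(); `LayerSketch`, `ScaleByTwo`, and the `Tensor` alias are hypothetical, and unlike the real header, where both members are virtual, only forward() is virtual here.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for the library's tensor type.
using Tensor = std::vector<float>;

// Sketch of the interface pattern: a virtual forward() plus an operator()
// that delegates to it, so a layer can be invoked as layer(input).
struct LayerSketch {
    virtual ~LayerSketch() = default;
    virtual std::vector<Tensor> forward(const std::vector<Tensor>& input) = 0;
    std::vector<Tensor> operator()(const std::vector<Tensor>& input) {
        return forward(input);
    }
};

// Toy layer (not a Linear) used only to show that both call styles agree.
struct ScaleByTwo : LayerSketch {
    std::vector<Tensor> forward(const std::vector<Tensor>& input) override {
        std::vector<Tensor> out = input;
        for (Tensor& t : out)
            for (float& v : t) v *= 2.0f;
        return out;
    }
};
```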

Member Data Documentation

◆ bias

tensor pytorch::Linear::bias
private

◆ has_bias

bool pytorch::Linear::has_bias = false
private

◆ utils

pycpp::py_object pytorch::Linear::utils
private

◆ weights

tensor pytorch::Linear::weights
private

The documentation for this class was generated from the following file: