pytorch-inference
layers.hpp File Reference

Equivalent to Conv2d in pytorch.

Namespaces

 pytorch
 

Enumerations

enum  pytorch::dims { pytorch::n = 3, pytorch::k = 2, pytorch::h = 0, pytorch::w = 1 }
 Convenience enum to use whenever you need to specify a dimension (like in Concat)
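
As a rough illustration of how the enum values map names to axis positions, here is a minimal sketch that indexes a four-element shape by name. The `sketch` namespace, the `size_along` helper, and the use of `std::array` as a stand-in for a tensor's dimension vector are assumptions made for this example; they are not part of layers.hpp.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

namespace sketch {

// Same values as pytorch::dims in the listing above.
enum dims { n = 3, k = 2, h = 0, w = 1 };

// Hypothetical helper: size of a 4-D shape along a named dimension.
inline std::size_t size_along(const std::array<std::size_t, 4>& shape, dims d) {
    assert(d >= 0 && d < 4);
    return shape[static_cast<std::size_t>(d)];
}

}  // namespace sketch
```

With a shape stored as {height, width, channels, batch}, `sketch::size_along(shape, sketch::k)` returns the channel count without the caller having to remember that channels sit at position 2.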
 

Detailed Description

Equivalent to Conv2d in pytorch.

Equivalent to AvgPool2d in pytorch.

Equivalent to MaxUnpool2d in pytorch.

Equivalent to MaxPool2d in pytorch (with a caveat).

Equivalent to BatchNorm2d in pytorch.

Equivalent to Linear in pytorch.

Performs a no-op - makes it so that we can create branching networks.

Abstract base class for all layers.

Implements the forward pass for pytorch's nn.Conv2d module. The weight tensors have to be moved from Python to C++ somehow; the import constructor is meant to take care of this.
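
To make the arithmetic of the forward pass concrete, the sketch below computes a 2-D cross-correlation (the operation nn.Conv2d performs) for a single input channel and a single filter, with stride 1 and no padding. The function name `conv2d_single`, the row-major `std::vector<float>` storage, and the restriction to one channel are assumptions for illustration only; the library itself works on its own tensor type.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper, not the library's API: cross-correlation of one
// H x W input plane with one KH x KW filter, stride 1, no padding.
// Both buffers are row-major.
std::vector<float> conv2d_single(const std::vector<float>& in, std::size_t H, std::size_t W,
                                 const std::vector<float>& kernel, std::size_t KH, std::size_t KW,
                                 float bias) {
    assert(in.size() == H * W && kernel.size() == KH * KW);
    assert(KH >= 1 && KW >= 1 && KH <= H && KW <= W);
    const std::size_t OH = H - KH + 1;
    const std::size_t OW = W - KW + 1;
    std::vector<float> out(OH * OW, bias);  // every output starts at the bias
    for (std::size_t i = 0; i < OH; ++i)
        for (std::size_t j = 0; j < OW; ++j)
            for (std::size_t u = 0; u < KH; ++u)
                for (std::size_t v = 0; v < KW; ++v)
                    out[i * OW + j] += in[(i + u) * W + (j + v)] * kernel[u * KW + v];
    return out;
}
```

A full layer would repeat this over all input channels and filters and sum the per-channel results; the loop nest above is only the innermost piece.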

The inference engine stores pointers to this base class, which lets a single std::vector hold layers of different concrete types. The only requirement is that each derived class implements the forward and operator() methods.
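
The pattern described here can be sketched roughly as follows. Every name in this block (`Tensor`, `Layer`, `Skip`, `Scale`, `run`) and the method signatures are hypothetical, and `std::vector<float>` stands in for the real tensor type; the block only shows how base-class pointers let one vector hold different layer types.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

using Tensor = std::vector<float>;  // stand-in for the real tensor type

// Hypothetical abstract base class: derived layers must provide both methods.
struct Layer {
    virtual ~Layer() = default;
    virtual Tensor forward(const Tensor& input) = 0;
    virtual Tensor operator()(const Tensor& input) = 0;
};

// No-op layer: returns its input unchanged.
struct Skip : Layer {
    Tensor forward(const Tensor& input) override { return input; }
    Tensor operator()(const Tensor& input) override { return forward(input); }
};

// Toy layer that multiplies every element by a constant.
struct Scale : Layer {
    explicit Scale(float f) : factor(f) {}
    Tensor forward(const Tensor& input) override {
        Tensor out(input);
        for (float& v : out) v *= factor;
        return out;
    }
    Tensor operator()(const Tensor& input) override { return forward(input); }
    float factor;
};

// Runs the layers in order; the vector only knows about the base class.
Tensor run(const std::vector<std::unique_ptr<Layer>>& layers, Tensor x) {
    for (const auto& layer : layers) {
        assert(layer != nullptr);
        x = (*layer)(x);
    }
    return x;
}
```

Owning pointers are used here so the vector also manages layer lifetime; whether the actual engine uses raw or smart pointers is not stated in this page.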

Implements the forward pass for pytorch's nn.Linear module. The weight tensors have to be moved from Python to C++ somehow; the import constructor is meant to take care of this.
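
For reference, nn.Linear computes y = W x + b with a weight matrix of shape (out_features, in_features). A minimal sketch of that computation follows; the function name `linear_forward` and the row-major `std::vector<float>` storage are assumptions for this example and do not reflect the library's tensor type.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: y = W x + b for a single input vector.
// weight is out_features x in_features, stored row-major.
std::vector<float> linear_forward(const std::vector<float>& x,
                                  const std::vector<float>& weight,
                                  const std::vector<float>& bias) {
    const std::size_t in_features = x.size();
    const std::size_t out_features = bias.size();
    assert(weight.size() == in_features * out_features);
    std::vector<float> y(bias);  // start from the bias, then accumulate W x
    for (std::size_t o = 0; o < out_features; ++o)
        for (std::size_t i = 0; i < in_features; ++i)
            y[o] += weight[o * in_features + i] * x[i];
    return y;
}
```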

Implements the forward pass for pytorch's nn.BatchNorm2d module. You need to extract the running mean, running variance, gamma and beta tensors from pytorch yourself.
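
The four tensors mentioned above enter the inference-time formula y = gamma * (x - running_mean) / sqrt(running_var + eps) + beta, applied per channel. The sketch below assumes a single image stored channel-major in a `std::vector<float>`; the function name `batchnorm2d_inference` and the storage layout are hypothetical, and the default `eps` mirrors the PyTorch default of 1e-5.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper: inference-mode batch norm for one image.
// x is laid out as x[c * spatial + i], with spatial = height * width.
std::vector<float> batchnorm2d_inference(const std::vector<float>& x, std::size_t channels,
                                         const std::vector<float>& running_mean,
                                         const std::vector<float>& running_var,
                                         const std::vector<float>& gamma,
                                         const std::vector<float>& beta,
                                         float eps = 1e-5f) {
    assert(channels > 0 && x.size() % channels == 0);
    assert(running_mean.size() == channels && running_var.size() == channels);
    assert(gamma.size() == channels && beta.size() == channels);
    const std::size_t spatial = x.size() / channels;
    std::vector<float> y(x.size());
    for (std::size_t c = 0; c < channels; ++c) {
        // Per-channel scale is constant at inference time.
        const float scale = gamma[c] / std::sqrt(running_var[c] + eps);
        for (std::size_t i = 0; i < spatial; ++i) {
            const std::size_t idx = c * spatial + i;
            y[idx] = (x[idx] - running_mean[c]) * scale + beta[c];
        }
    }
    return y;
}
```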

This layer implements the forward pass of pytorch's nn.MaxPool2d module. The layer keeps the pooling indices internally; unpooling will be another challenge, if it is possible with this framework at all. For now, the indices are simply stored in ArrayFire format for future use.
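
To illustrate what "holding the pooling indices" means, here is a minimal sketch of max pooling over one channel that records, for each output element, the flat position of the winning input element. The names `PoolResult` and `maxpool2d_with_indices`, the row-major `std::vector<float>` storage, and the restriction to non-overlapping K x K windows are assumptions for this example; the library stores its indices in ArrayFire format instead.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical result type: pooled values plus the flat input index of each maximum.
struct PoolResult {
    std::vector<float> values;
    std::vector<std::size_t> indices;
};

// Max pooling over one H x W plane with non-overlapping K x K windows.
PoolResult maxpool2d_with_indices(const std::vector<float>& in,
                                  std::size_t H, std::size_t W, std::size_t K) {
    assert(K >= 1 && in.size() == H * W);
    assert(H % K == 0 && W % K == 0);
    const std::size_t OH = H / K;
    const std::size_t OW = W / K;
    PoolResult result;
    result.values.reserve(OH * OW);
    result.indices.reserve(OH * OW);
    for (std::size_t i = 0; i < OH; ++i) {
        for (std::size_t j = 0; j < OW; ++j) {
            std::size_t best = (i * K) * W + j * K;  // top-left of the window
            for (std::size_t u = 0; u < K; ++u) {
                for (std::size_t v = 0; v < K; ++v) {
                    const std::size_t idx = (i * K + u) * W + (j * K + v);
                    if (in[idx] > in[best]) best = idx;
                }
            }
            result.values.push_back(in[best]);
            result.indices.push_back(best);
        }
    }
    return result;
}
```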

This layer implements the forward pass of pytorch's nn.MaxUnpool2d module. It holds a pointer to a maxpool layer, from which it gets the indices for the unpooling process.
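
The relationship between the two layers might look roughly like the sketch below: the unpooling object keeps a pointer to the pooling state and scatters each pooled value back to its recorded position, leaving zeros elsewhere. `MaxPoolState`, `MaxUnpool`, and their members are hypothetical names, and flat `std::size_t` indices into a `std::vector<float>` are an assumption made to keep the example self-contained.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the state a maxpool layer keeps after its forward pass.
struct MaxPoolState {
    std::vector<std::size_t> indices;  // flat input position of each maximum
    std::size_t input_size;            // number of elements in the pooled input
};

// Hypothetical unpooling layer that reads the indices through a pointer.
struct MaxUnpool {
    explicit MaxUnpool(const MaxPoolState* p) : pool(p) {}

    std::vector<float> forward(const std::vector<float>& pooled) const {
        assert(pool != nullptr);
        assert(pooled.size() == pool->indices.size());
        std::vector<float> out(pool->input_size, 0.0f);  // non-maximum positions stay zero
        for (std::size_t i = 0; i < pooled.size(); ++i) {
            assert(pool->indices[i] < out.size());
            out[pool->indices[i]] = pooled[i];
        }
        return out;
    }

    const MaxPoolState* pool;  // not owned; must outlive this object
};
```

Because the pointer is not owning, the pooling state has to stay alive, and its forward pass has to run first, before the unpooling step is called.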

This layer implements the forward pass of pytorch's nn.AvgPool2d module. It doesn't do anything fancy; it just takes the mean of each pooling window.
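
A minimal sketch of "the mean of each window", for one channel with non-overlapping K x K windows, is given below. The function name `avgpool2d` and the row-major `std::vector<float>` storage are assumptions for illustration; strides, padding, and multiple channels are left out.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: average pooling over one H x W plane
// with non-overlapping K x K windows, row-major storage.
std::vector<float> avgpool2d(const std::vector<float>& in,
                             std::size_t H, std::size_t W, std::size_t K) {
    assert(K >= 1 && in.size() == H * W);
    assert(H % K == 0 && W % K == 0);
    const std::size_t OH = H / K;
    const std::size_t OW = W / K;
    std::vector<float> out(OH * OW, 0.0f);
    for (std::size_t i = 0; i < OH; ++i) {
        for (std::size_t j = 0; j < OW; ++j) {
            float sum = 0.0f;
            for (std::size_t u = 0; u < K; ++u)
                for (std::size_t v = 0; v < K; ++v)
                    sum += in[(i * K + u) * W + (j * K + v)];
            out[i * OW + j] = sum / static_cast<float>(K * K);
        }
    }
    return out;
}
```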