31 #include <catch2/catch.hpp> 32 #include <torch/torch.h> 41 struct Net : torch::nn::Module
43 Net(int64_t numInputs, int64_t numNeurons)
45 W = register_parameter(
"W", torch::randn({numInputs, numNeurons}));
46 b = register_parameter(
"b", torch::randn(numNeurons));
49 torch::Tensor forward(torch::Tensor input)
51 return torch::addmm(b, input, W);
58 Net net(numInputs, numNeurons);
61 torch::Tensor inputs = torch::rand({numBatch, numInputs});
62 torch::Tensor outputs = net.forward(inputs);
67 TEST_CASE(
"Executing a very basic one layer FNN forward pass",
"[PyTorch C++ API]")
TEST_CASE("Executing a very basic one layer FNN forward pass", "[PyTorch C++ API]")
int createNetAndForward(int64_t numInputs, int64_t numNeurons, int64_t numBatch)