// LibTorch CNN forward-pass micro-benchmark (CPU vs. Apple MPS backend).
#include <time.h>

#include <chrono>
#include <iostream>

#include <torch/torch.h>
#define USE_MPS 1
using namespace std;
struct Net : torch::nn::Module { Net() { conv1 = register_module("conv1", torch::nn::Conv2d(3, 64, 3)); conv2 = register_module("conv2", torch::nn::Conv2d(64, 128, 3)); conv3 = register_module("conv3", torch::nn::Conv2d(128, 256, 3)); fc1 = register_module("fc1", torch::nn::Linear(256, 128)); fc2 = register_module("fc2", torch::nn::Linear(128, 56)); fc3 = register_module("fc3", torch::nn::Linear(56, 10)); global_pool = register_module("global_pool", torch::nn::AdaptiveAvgPool2d(1)); }
torch::Tensor forward(torch::Tensor x) { x = torch::relu(conv1->forward(x)); x = torch::max_pool2d(x, {2, 2}); x = torch::relu(conv2->forward(x)); x = torch::max_pool2d(x, {2, 2}); x = torch::relu(conv3->forward(x)); x = torch::max_pool2d(x, {2, 2}); x = global_pool->forward(x); x = torch::relu(fc1->forward(x.reshape({x.size(0), -1}))); x = torch::relu(fc2->forward(x)); x = torch::log_softmax(fc3->forward(x), 1);
return x; }
torch::nn::Linear fc1{nullptr}, fc2{nullptr}, fc3{nullptr}; torch::nn::Conv2d conv1{nullptr}, conv2{nullptr}, conv3{nullptr}; torch::nn::AdaptiveAvgPool2d global_pool{nullptr}; };
int main(int argc, char* argv[]) { auto net = std::make_shared<Net>(); torch::Tensor data = torch::ones({8, 3, 128, 128});
#ifdef USE_MPS net->to(torch::Device(torch::kMPS)); data = data.to("mps"); // torch::Tensor data = torch::ones({8, 3, 128, 128}).to("mps"); #endif
torch::Tensor y; clock_t start, end; start = clock(); for (int i = 0; i < 100; ++i) { y = net->forward(data); } end = clock(); cout << "Time: " << double(end - start) / CLOCKS_PER_SEC << endl;
return 0; }