33 #include <torch/all.h> 40 torch::set_default_dtype(torch::scalarTypeToTypeMeta(torch::kDouble));
43 module = torch::jit::load(filename);
57 auto resetMethod =
module.find_method(
"reset");
58 if (resetMethod.has_value()) {
59 torch::jit::Stack stack;
60 resetMethod->run(stack);
67 torch::Tensor obs_to = torch::from_blob(observation->ele, {observation->m}, torch::dtype(torch::kDouble));
70 torch::Tensor act_to =
module.forward({obs_to}).toTensor();
73 torch::Tensor act_out = torch::from_blob(action->ele, {action->m}, torch::dtype(torch::kDouble));
74 act_out.copy_(act_to);
static ControlPolicyRegistration< TorchPolicy > RegTorchPolicy("torch")
torch::jit::script::Module module
TorchPolicy(const char *filename)
virtual void computeAction(MatNd *action, const MatNd *observation)