關於深度學習實時檢測的三種方法(二)——pytorch訓練模型後再使用libtorch加載

碼道人 2021-08-15 20:41:57 阅读数:872

本文一共[544]字,预计阅读时长:1分钟~
深度 方法 pytorch 模型 使用

上一篇文章我曾經提到過,libtorch在進行圖像實時檢測中性能並不突出,很多時候無法滿足我們的需求,後來我在想能不能現在pytorch上訓練模型,只將libtorch作為一個加載模型的工具呢?經過嘗試我發現這種方法是可行的,並且無論是在運行時間上還是在預測的准確率上都要優於前者,本文章將介紹如何在pytorch上訓練模型並用libtorch進行加載預測。

pytorch安裝

pytorch官方為我們提供了非常方便的安裝渠道,在官網上選擇適配自己電腦的選項後用官方給出的命令行下載即可,如下圖:

驗證安裝成功

終端鍵入

python3
import torch

若如下圖成功執行,即錶示安裝成功(記得選擇對應的python版本)

pytorch構建網絡

主要用到的就是torch中的nn、autograd、torchvision和tqdm模塊,在nn.Module的基礎上繼承並進行構建,下面給出例程代碼,同樣以LeNet5為例:

class LeNet5(nn.Module):
def __init__(self, num_class):
super(LeNet5, self).__init__()
#卷積層
self.Conv = nn.Sequential(
nn.Conv2d(1, 6, 5, stride=1, padding=2),
nn.BatchNorm2d(6),
nn.ReLU(True),
#池化 
nn.MaxPool2d(2, 2),
nn.Conv2d(6, 16, 5, stride=1, padding=0),
nn.BatchNorm2d(16),
nn.ReLU(True),
#池化
nn.MaxPool2d(2, 2)
)
#全連接層
self.FC = nn.Sequential(
nn.Linear(400, 120),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(120, 84),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(84, num_class)
)
def forward(self, x):
#卷積層
out = self.Conv(x)
out = out.view(out.size(0), -1)
#全連接層
out = self.FC(out)
return out

訓練(核心代碼)

# 向前傳播
out = model(img)
# 計算損失函數
loss = criterion(out, label)
running_loss = loss.item() * label.size(0)
_, pred = torch.max(out, dim=1)
num_correct = (pred == label).sum()
running_acc += num_correct.item()
# 手動清空梯度
optimizer.zero_grad()
# 向後傳播
loss.backward()
optimizer.step()

保存模型

# 保存模型
torch.save(model, '../model/Vision_NumDetect.pth')
# pth模型轉成pt模型
with torch.no_grad():
model.eval()
trace_script_modile=torch.jit.trace(model, img)
trace_script_modile.save(r"../model/NumDetect.pt") #壓縮好的模型存出來

以上均是在pytorch下完成的,然後再通過libtorch載入並預測:

libtorch載入模型

torch::jit::load(path);

預測

/**
* @brief 使用torch將圖片傳入模型中進行預測
* @return 返回預測結果
*/
int torchForward(torch::jit::script::Module &module, const Mat &src)
{
std::vector<int64_t> sizes = {1, 1, src.rows, src.cols};
at::TensorOptions options(at::ScalarType::Byte);
//將opencv的圖像數據轉為Tensor張量數據
at::Tensor tensor_image = torch::from_blob(src.data, at::IntList(sizes), options);
//轉為浮點型張量數據
tensor_image = tensor_image.toType(at::kFloat);
// 前饋預測
at::Tensor result = module.forward({tensor_image}).toTensor();
auto max_result = result.max(1, true);
int max_index = std::get<1>(max_result).item<int>();
return max_index;
}

總結

由於libtorch是從pytorch上移植的,所以在構建網絡的邏輯部分並不完善,構建出的網絡性能對於pytorch較差,采用pytorch構建網絡並訓練模型,再用libtorch進行加載,既解决了實時性檢測的性能問題,又解决在C++上的適配問題。如果有什麼錯誤的地方,歡迎留言交流!

版权声明:本文为[碼道人]所创,转载请带上原文链接,感谢。 https://gsmany.com/2021/08/20210815204156300l.html