博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
tensorflow C++接口调用图像分类pb模型代码
阅读量:4687 次
发布时间:2019-06-09

本文共 4311 字,大约阅读时间需要 14 分钟。

#include 
#include
#include
#include
#include
#include "tensorflow/cc/ops/const_op.h"#include "tensorflow/cc/ops/image_ops.h"#include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/graph.pb.h"#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/graph/default_device.h"#include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/lib/core/errors.h"#include "tensorflow/core/lib/core/stringpiece.h"#include "tensorflow/core/lib/core/threadpool.h"#include "tensorflow/core/lib/io/path.h"#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/public/session.h"#include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/core/platform/env.h"#include "tensorflow/core/platform/init_main.h"#include "tensorflow/core/platform/logging.h"#include "tensorflow/core/platform/types.h" #include "opencv2/opencv.hpp" using namespace tensorflow::ops;using namespace tensorflow;using namespace std;using namespace cv;using tensorflow::Flag;using tensorflow::Tensor;using tensorflow::Status;using tensorflow::string;using tensorflow::int32 ; // 定义一个函数讲OpenCV的Mat数据转化为tensor,python里面只要对cv2.read读进来的矩阵进行np.reshape之后,// 数据类型就成了一个tensor,即tensor与矩阵一样,然后就可以输入到网络的入口了,但是C++版本,我们网络开放的入口// 也需要将输入图片转化成一个tensor,所以如果用OpenCV读取图片的话,就是一个Mat,然后就要考虑怎么将Mat转化为// Tensor了void CVMat_to_Tensor(Mat img,Tensor* output_tensor,int input_rows,int input_cols){ //imshow("input image",img); //图像进行resize处理 resize(img,img,cv::Size(input_cols,input_rows)); //imshow("resized image",img); //归一化 img.convertTo(img,CV_32FC1); img=1-img/255; //创建一个指向tensor的内容的指针 float *p = output_tensor->flat
().data(); //创建一个Mat,与tensor的指针绑定,改变这个Mat的值,就相当于改变tensor的值 cv::Mat tempMat(input_rows, input_cols, CV_32FC1, p); img.convertTo(tempMat,CV_32FC1); // waitKey(0); } int main(int argc, char** argv ){ /*--------------------------------配置关键信息------------------------------*/ string model_path="../inception_v3_2016_08_28_frozen.pb"; string image_path="../test.jpg"; int input_height =299; int input_width=299; string input_tensor_name="input"; string output_tensor_name="InceptionV3/Predictions/Reshape_1"; /*--------------------------------创建session------------------------------*/ Session* session; Status status = NewSession(SessionOptions(), &session);//创建新会话Session /*--------------------------------从pb文件中读取模型--------------------------------*/ GraphDef graphdef; //Graph Definition for current model Status status_load = ReadBinaryProto(Env::Default(), model_path, &graphdef); //从pb文件中读取图模型; if (!status_load.ok()) { cout << "ERROR: Loading model failed..." << model_path << std::endl; cout << status_load.ToString() << "\n"; return -1; } Status status_create = session->Create(graphdef); //将模型导入会话Session中; if (!status_create.ok()) { cout << "ERROR: Creating graph in session failed..." << status_create.ToString() << std::endl; return -1; } cout << "<----Successfully created session and load graph.------->"<< endl; /*---------------------------------载入测试图片-------------------------------------*/ cout<
<<"<------------loading test_image-------------->"<
"<
outputs; string output_node = output_tensor_name; Status status_run = session->Run({ {input_tensor_name, resized_tensor}}, {output_node}, {}, &outputs); if (!status_run.ok()) { cout << "ERROR: RUN failed..." << std::endl; cout << status_run.ToString() << "\n"; return -1; } //把输出值给提取出来 cout << "Output tensor size:" << outputs.size() << std::endl; for (std::size_t i = 0; i < outputs.size(); i++) { cout << outputs[i].DebugString()<
(); // Tensor Shape: [batch_size, target_class_num] int output_dim = t.shape().dim_size(1); // Get the target_class_num from 1st dimension // Argmax: Get Final Prediction Label and Probability int output_class_id = -1; double output_prob = 0.0; for (int j = 0; j < output_dim; j++) { cout << "Class " << j << " prob:" << tmap(0, j) << "," << std::endl; if (tmap(0, j) >= output_prob) { output_class_id = j; output_prob = tmap(0, j); } } // 输出结果 cout << "Final class id: " << output_class_id << std::endl; cout << "Final class prob: " << output_prob << std::endl; return 0;}

 

转载于:https://www.cnblogs.com/cnugis/p/11507872.html

你可能感兴趣的文章
SQL case when else
查看>>
MVc Identity登陆锁定
查看>>
cdn连接失败是什么意思_关于CDN的原理、术语和应用场景那些事
查看>>
ultraedit26 运行的是试用模式_免费试用U盘数据恢复工具 – 轻松找回U盘丢失的各种数据!...
查看>>
python sum函数导入list_python sum函数iterable参数为二维list,start参数为“[]”该如何理解...
查看>>
UVa540 Team Queue
查看>>
android 练习之路 (八)
查看>>
tp5 中 model 的聚合查询
查看>>
android wear开发之:增加可穿戴设备功能到通知中 - Adding Wearable Features to Notifications...
查看>>
压缩文件函数库(转载)
查看>>
【转】ubuntu12.04没有/var/log/messages解决
查看>>
Oracle EBS 初始化用户密码
查看>>
SYS_CONTEXT 详细用法
查看>>
Pycharm配置autopep8让Python代码更符合pep8规范
查看>>
函数的复写
查看>>
17_重入锁ReentrantLock
查看>>
winform窗口关闭提示
查看>>
64款工具,总有合适您的那款
查看>>
我的第一篇博客
查看>>
大数据学习线路整理
查看>>