C# WPF调用python-tensorflow2深度学习模型
- 一 环境介绍
- 二 tensorflow模型的训练和生成
- 1 加载数据训练模型
- 2 h5文件转pb
- 二 C#加载模型并预测
- 1 vs2017环境搭建
- 2 调用模型
- 三 最终效果
python在研究深度学习人工智能领域十分强大,但在工业项目开发中仍常常使用C#和C++来做软件,C++有Caffe深度学习框架,但C#尚且没有成熟的深度学习框架(有个tensroflow.net尚在开发中,有兴趣可以去研究研究)。现在实验室项目开发又要用C#,经过实践最终决定在C#端利用OpencvSharp4的DNN模块加载python端tensorflow2训练的模型进行预测,其速度还可以。
一 环境介绍
python:Python3.7 tensorflow2.1
c#: vs2017 .net framework 4.6.1
二 tensorflow模型的训练和生成
1 加载数据训练模型
1.1 数据集采用猫狗二分类数据。
数据集网盘链接:链接:https://pan.baidu.com/s/15LR7-tgvglzwW9n4eFsFgg
提取码:iz64
1.2 创建图片数据输入管道
代码实现:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import glob
gpu = tf.config.experimental.list_physical_devices(device_type='GPU')
tf.config.experimental.set_virtual_device_configuration(
gpu[0],
[tf.config.experimental.VirtualDeviceConfiguration(memory_limit=5120)])
image_path = glob.glob('./datasets/dc/train/*.jpg')
image_label = [int(path.split('\\')[1].split('.')[0]=='cat') for path in image_path]
def get_image_data(path, label):
image = tf.io.read_file(path)
image = tf.image.decode_jpeg(image, channels=3)
image = tf.image.resize(image, (224, 224))
image = tf.cast(image, tf.float32)/255
label = tf.reshape(label, [1])
return image, label
dataset = tf.data.Dataset.from_tensor_slices((image_path, image_label))
dataset = dataset.map(get_image_data)
train_count = int(len(image_path)*0.8)
test_count = len(image_path)-train_count
train_dataset = dataset.skip(test_count)
test_dataset = dataset.take(test_count)
train_dataset = train_dataset.shuffle(len(image_path)).repeat().batch(BATCH_SIZE)
test_dataset = test_dataset.batch(BATCH_SIZE)
1.3 搭建并训练模型最后保存模型及参数,保存的格式的.h5文件,最终的准确度基本在99%以上。
代码实现:
MobileNet = tf.keras.applications.MobileNetV2(include_top=False, weights='imagenet', input_shape=(224, 224, 3))
model = tf.keras.Sequential()
model.add(MobileNet)
model.add(tf.keras.layers.GlobalAveragePooling2D())
model.add(tf.keras.layers.Dense(256, activation='relu'))
model.add(tf.keras.layers.Dense(1, activation='sigmoid'))
model.compile(optimizer='adam',
loss=tf.keras.losses.binary_crossentropy,
metrics=['acc'])
model.fit(train_dataset,
epochs=10,
steps_per_epoch=train_count//BATCH_SIZE,
validation_data=test_dataset,
validation_steps=test_count//BATCH_SIZE)
model.save('./model_h5/mobilenet.h5')
2 h5文件转pb
Opencv的DNN模块接收tensorflow模型文件为pb文件,先将h5文件转换成pb文件,在tensorflow2.0端完成文件类型的转换。
转换代码:
#参数1为h5文件的路径,参数2为要将pb文件保存到那个文件夹的路径,最后一个参数为pb文件的名称
def convert_h5to_pb(h5_path, pb_path, pb_name):
model = tf.keras.models.load_model(h5_path, compile=False)
model.summary()
full_model = tf.function(lambda Input: model(Input))
full_model = full_model.get_concrete_function(tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))
frozen_func = convert_variables_to_constants_v2(full_model)
frozen_func.graph.as_graph_def()
layers = [op.name for op in frozen_func.graph.get_operations()]
print("-" * 50)
print("Frozen model layers: ")
for layer in layers:
print(layer)
print("-" * 50)
print("Frozen model inputs: ")
print(frozen_func.inputs)
print("Frozen model outputs: ")
print(frozen_func.outputs)
tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
logdir=pb_path,
name=pb_name,
as_text=False
二 C#加载模型并预测
1 vs2017环境搭建
在项目属性中设置平台目标为x64,
目标框架选择.net framework 4.6.1,没有该框架的可去官网下载安装。
进入NuGet程序包管理界面,搜索并下载如下三个包,有可能由于网络问题无法下载,可根据提示网站进入下载。
2 调用模型
1 xml前端界面
<Window
xmlns="http://schemas.microsoft.com/winfx/2006/xaml/presentation"
xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml"
xmlns:d="http://schemas.microsoft.com/expression/blend/2008"
xmlns:mc="http://schemas.openxmlformats.org/markup-compatibility/2006"
xmlns:telerik="http://schemas.telerik.com/2008/xaml/presentation" x:Class="WpfApp1.MainWindow"
mc:Ignorable="d"
Title="猫狗分类" Height="300" Width="500" WindowStartupLocation="CenterScreen">
<Grid>
<Grid.ColumnDefinitions>
<ColumnDefinition Width="200"/>
<ColumnDefinition/>
</Grid.ColumnDefinitions>
<Grid Grid.Column="0">
<Grid.RowDefinitions>
<RowDefinition Height="1*"/>
<RowDefinition Height="3*"/>
</Grid.RowDefinitions>
<telerik:RadButton x:Name="read_image" Content="读取图片" Click="Read_image_Click" Margin="70,15,50,15"/>
<Grid Grid.Row="1">
<Grid.ColumnDefinitions>
<ColumnDefinition Width="70"/>
<ColumnDefinition/>
</Grid.ColumnDefinitions>
<Grid.RowDefinitions>
<RowDefinition/>
<RowDefinition/>
<RowDefinition/>
<RowDefinition Height="20"/>
</Grid.RowDefinitions>
<Label Content="得分:" HorizontalAlignment="Center" VerticalAlignment="Center"/>
<TextBox x:Name="score" HorizontalAlignment="Left" VerticalAlignment="Center" Grid.Row="0" Grid.Column="1" Width="120"/>
<Label Content="类别:" HorizontalAlignment="Center" VerticalAlignment="Center" Grid.Row="1"/>
<TextBox x:Name="classes" HorizontalAlignment="Left" VerticalAlignment="Center" Grid.Row="1" Grid.Column="1" Width="120"/>
<Label Content="时间:" HorizontalAlignment="Center" VerticalAlignment="Center" Grid.Row="2"/>
<TextBox x:Name="time" HorizontalAlignment="Left" VerticalAlignment="Center" Grid.Row="2" Grid.Column="1" Width="120"/>
</Grid>
</Grid>
<Border BorderBrush="Black" BorderThickness="1" Grid.Column="1" HorizontalAlignment="Center" Height="214" VerticalAlignment="Center" Width="265">
<Image x:Name="img"/>
</Border>
</Grid>
</Window>
2 C#后端实现预测
//引入OpencvSharp和Dnn模块
using System;
using System.Windows;
using System.Windows.Media.Imaging;
using OpenCvSharp.Dnn;
using OpenCvSharp;
using Microsoft.Win32;
namespace WpfApp1
{
/// <summary>
/// MainWindow.xaml 的交互逻辑
/// </summary>
public partial class MainWindow : System.Windows.Window
{
public MainWindow()
{
InitializeComponent();
}
public void Dnn_Classification(Mat image)
{
String model_path = ".//mobilenet.pb";//模型路径
Net net = CvDnn.ReadNetFromTensorflow(model_path);//加载模型
if (net.Empty())
{
MessageBox.Show("pd文件错误");
return;
}
Mat input_image = CvDnn.BlobFromImage(image, 1 / 255.0, new OpenCvSharp.Size(224, 224)); //图片归一化和resize
net.SetInput(input_image);
Mat result = net.Forward();//载入图片并前向计算
float result_score = result.Get<float>(0, 0);//获得计算结果
score.Text = result_score.ToString();
if (result_score >= 0.5)
{
classes.Text = "Cat";
}
else
{
classes.Text = "Dog";
}
}
private void Read_image_Click(object sender, RoutedEventArgs e)
{
OpenFileDialog ofd = new OpenFileDialog();
ofd.InitialDirectory = @"C:\Users\LemonQiu\Desktop";
ofd.Filter = "JPG图片|*.jpg|PNG图片|*.png";
if (ofd.ShowDialog() == true)
{
img.Source = new BitmapImage(new Uri(ofd.FileName));
Mat image = Cv2.ImRead(ofd.FileName);
System.Diagnostics.Stopwatch watch = new System.Diagnostics.Stopwatch();
watch.Start();
Dnn_Classification(image);
watch.Stop();
TimeSpan timespan = watch.Elapsed;
time.Text = (timespan.TotalMilliseconds).ToString() + "ms";
}
else
{
MessageBox.Show("没有选择图片");
}
}
}
}
三 最终效果
检测时间基本稳定在100ms每张。