C# WPF调用python-tensorflow2深度学习模型

   日期:2020-09-07     浏览:338    评论:0    
核心提示:C# WPF调用python-tensorflow2深度学习模型一 环境介绍二 tensorflow模型的训练和生成1 加载数据训练模型2 h5文件转pb二 C#加载模型并预测1 vs2017环境搭建2 调用模型三 最终效果python在研究深度学习人工智能领域十分强大,但在工业项目开发中仍常常使用C#和C++来做软件,C++有Caffe深度学习框架,但C#尚且没有成熟的深度学习框架(有个tensroflow.net尚在开发中,有兴趣可以去研究研究)。现在实验室项目开发又要用C#,经过实践最终决定在C#端

C# WPF调用python-tensorflow2深度学习模型

    • 一 环境介绍
    • 二 tensorflow模型的训练和生成
      • 1 加载数据训练模型
      • 2 h5文件转pb
    • 二 C#加载模型并预测
      • 1 vs2017环境搭建
      • 2 调用模型
    • 三 最终效果

python在研究深度学习人工智能领域十分强大,但在工业项目开发中仍常常使用C#和C++来做软件,C++有Caffe深度学习框架,但C#尚且没有成熟的深度学习框架(有个tensroflow.net尚在开发中,有兴趣可以去研究研究)。现在实验室项目开发又要用C#,经过实践最终决定在C#端利用OpencvSharp4的DNN模块加载python端tensorflow2训练的模型进行预测,其速度还可以。

一 环境介绍

python:Python3.7 tensorflow2.1

c#: vs2017 .net framework 4.6.1

二 tensorflow模型的训练和生成

1 加载数据训练模型

1.1 数据集采用猫狗二分类数据。
数据集网盘链接:链接:https://pan.baidu.com/s/15LR7-tgvglzwW9n4eFsFgg
提取码:iz64


1.2 创建图片数据输入管道
代码实现:

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import glob

gpu = tf.config.experimental.list_physical_devices(device_type='GPU')
tf.config.experimental.set_virtual_device_configuration( 
    gpu[0],
    [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=5120)])
    
image_path = glob.glob('./datasets/dc/train/*.jpg')
image_label = [int(path.split('\\')[1].split('.')[0]=='cat') for path in image_path]

def get_image_data(path, label):
    image = tf.io.read_file(path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, (224, 224))
    image = tf.cast(image, tf.float32)/255
    label = tf.reshape(label, [1])
    return image, label
    
dataset = tf.data.Dataset.from_tensor_slices((image_path, image_label))
dataset = dataset.map(get_image_data)
train_count = int(len(image_path)*0.8)
test_count = len(image_path)-train_count
train_dataset = dataset.skip(test_count)
test_dataset = dataset.take(test_count)
train_dataset = train_dataset.shuffle(len(image_path)).repeat().batch(BATCH_SIZE)
test_dataset = test_dataset.batch(BATCH_SIZE)

1.3 搭建并训练模型最后保存模型及参数,保存的格式的.h5文件,最终的准确度基本在99%以上。
代码实现:

MobileNet = tf.keras.applications.MobileNetV2(include_top=False, weights='imagenet', input_shape=(224, 224, 3))
model = tf.keras.Sequential()
model.add(MobileNet)
model.add(tf.keras.layers.GlobalAveragePooling2D())
model.add(tf.keras.layers.Dense(256, activation='relu'))
model.add(tf.keras.layers.Dense(1, activation='sigmoid'))

model.compile(optimizer='adam',
              loss=tf.keras.losses.binary_crossentropy,
              metrics=['acc'])
model.fit(train_dataset, 
          epochs=10, 
          steps_per_epoch=train_count//BATCH_SIZE, 
          validation_data=test_dataset, 
          validation_steps=test_count//BATCH_SIZE)
model.save('./model_h5/mobilenet.h5')

2 h5文件转pb

Opencv的DNN模块接收tensorflow模型文件为pb文件,先将h5文件转换成pb文件,在tensorflow2.0端完成文件类型的转换。
转换代码:

#参数1为h5文件的路径,参数2为要将pb文件保存到那个文件夹的路径,最后一个参数为pb文件的名称
def convert_h5to_pb(h5_path, pb_path,  pb_name):
    model = tf.keras.models.load_model(h5_path, compile=False)
    model.summary()
    full_model = tf.function(lambda Input: model(Input))
    full_model = full_model.get_concrete_function(tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))
 
    frozen_func = convert_variables_to_constants_v2(full_model)
    frozen_func.graph.as_graph_def()
    layers = [op.name for op in frozen_func.graph.get_operations()]
    print("-" * 50)
    print("Frozen model layers: ")
    for layer in layers:
        print(layer)
    print("-" * 50)
    print("Frozen model inputs: ")
    print(frozen_func.inputs)
    print("Frozen model outputs: ")
    print(frozen_func.outputs)
    tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                      logdir=pb_path,
                      name=pb_name,
                      as_text=False

二 C#加载模型并预测

1 vs2017环境搭建

在项目属性中设置平台目标为x64,

目标框架选择.net framework 4.6.1,没有该框架的可去官网下载安装。

进入NuGet程序包管理界面,搜索并下载如下三个包,有可能由于网络问题无法下载,可根据提示网站进入下载。

2 调用模型

1 xml前端界面

<Window
        xmlns="http://schemas.microsoft.com/winfx/2006/xaml/presentation"
        xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml"
        xmlns:d="http://schemas.microsoft.com/expression/blend/2008"
        xmlns:mc="http://schemas.openxmlformats.org/markup-compatibility/2006"
        xmlns:telerik="http://schemas.telerik.com/2008/xaml/presentation" x:Class="WpfApp1.MainWindow"
        mc:Ignorable="d"
        Title="猫狗分类" Height="300" Width="500" WindowStartupLocation="CenterScreen">
    <Grid>
        <Grid.ColumnDefinitions>
            <ColumnDefinition Width="200"/>
            <ColumnDefinition/>
        </Grid.ColumnDefinitions>
        <Grid Grid.Column="0">
            <Grid.RowDefinitions>
                <RowDefinition Height="1*"/>
                <RowDefinition Height="3*"/>
            </Grid.RowDefinitions>
            <telerik:RadButton x:Name="read_image"  Content="读取图片" Click="Read_image_Click" Margin="70,15,50,15"/>
            <Grid Grid.Row="1">
                <Grid.ColumnDefinitions>
                    <ColumnDefinition Width="70"/>
                    <ColumnDefinition/>
                </Grid.ColumnDefinitions>
                <Grid.RowDefinitions>
                    <RowDefinition/>
                    <RowDefinition/>
                    <RowDefinition/>
                    <RowDefinition Height="20"/>
                </Grid.RowDefinitions>
                <Label Content="得分:" HorizontalAlignment="Center"  VerticalAlignment="Center"/>
                <TextBox x:Name="score" HorizontalAlignment="Left"   VerticalAlignment="Center"  Grid.Row="0" Grid.Column="1" Width="120"/>
                <Label Content="类别:" HorizontalAlignment="Center"  VerticalAlignment="Center" Grid.Row="1"/>
                <TextBox x:Name="classes" HorizontalAlignment="Left"   VerticalAlignment="Center"  Grid.Row="1" Grid.Column="1" Width="120"/>
                <Label Content="时间:" HorizontalAlignment="Center"  VerticalAlignment="Center" Grid.Row="2"/>
                <TextBox x:Name="time" HorizontalAlignment="Left"   VerticalAlignment="Center"  Grid.Row="2" Grid.Column="1" Width="120"/>
            </Grid>
        </Grid>

        <Border BorderBrush="Black" BorderThickness="1" Grid.Column="1" HorizontalAlignment="Center" Height="214"  VerticalAlignment="Center" Width="265">
            <Image x:Name="img"/>
        </Border>
    </Grid>
</Window>

2 C#后端实现预测

//引入OpencvSharp和Dnn模块
using System;
using System.Windows;
using System.Windows.Media.Imaging;
using OpenCvSharp.Dnn;
using OpenCvSharp;
using Microsoft.Win32;

namespace WpfApp1
{
    /// <summary>
    /// MainWindow.xaml 的交互逻辑
    /// </summary>
    public partial class MainWindow : System.Windows.Window
    {
        public MainWindow()
        {
            InitializeComponent();
        }

        public void Dnn_Classification(Mat image)
        {
            String model_path = ".//mobilenet.pb";//模型路径
            Net net = CvDnn.ReadNetFromTensorflow(model_path);//加载模型
            if (net.Empty())
            {
                MessageBox.Show("pd文件错误");
                return;
            }
            Mat input_image = CvDnn.BlobFromImage(image, 1 / 255.0, new OpenCvSharp.Size(224, 224)); //图片归一化和resize
            net.SetInput(input_image);
            Mat result = net.Forward();//载入图片并前向计算
            float result_score = result.Get<float>(0, 0);//获得计算结果
            score.Text = result_score.ToString();
            if (result_score >= 0.5)
            {
                classes.Text = "Cat";
            }
            else
            {
                classes.Text = "Dog";
            }
        }

        private void Read_image_Click(object sender, RoutedEventArgs e)
        {
            OpenFileDialog ofd = new OpenFileDialog();
            ofd.InitialDirectory = @"C:\Users\LemonQiu\Desktop";
            ofd.Filter = "JPG图片|*.jpg|PNG图片|*.png";
            if (ofd.ShowDialog() == true)
            {
                img.Source = new BitmapImage(new Uri(ofd.FileName));
                Mat image = Cv2.ImRead(ofd.FileName);

                System.Diagnostics.Stopwatch watch = new System.Diagnostics.Stopwatch();
                watch.Start();
                Dnn_Classification(image);
                watch.Stop();
                TimeSpan timespan = watch.Elapsed;
                time.Text = (timespan.TotalMilliseconds).ToString() + "ms";
            }
            else
            {
                MessageBox.Show("没有选择图片");
            }
        }
    }
}

三 最终效果

检测时间基本稳定在100ms每张。

 
打赏
 本文转载自:网络 
所有权利归属于原作者,如文章来源标示错误或侵犯了您的权利请联系微信13520258486
更多>最近资讯中心
更多>最新资讯中心
0相关评论

推荐图文
推荐资讯中心
点击排行
最新信息
新手指南
采购商服务
供应商服务
交易安全
关注我们
手机网站:
新浪微博:
微信关注:

13520258486

周一至周五 9:00-18:00
(其他时间联系在线客服)

24小时在线客服