当前位置:首页 > 技术知识 > 正文内容

深度|Matlab编程之——卷积神经网络CNN代码解析

maynowei6个月前 (10-19)技术知识87

DeepLearnToolbox-master是一个深度学习matlab包,里面含有很多机器学习算法,如卷积神经网络CNN,深度信念网络DBN,自动编码AutoEncoder(堆栈SAE,卷积CAE)的作者是 RasmusBerg Palm。

今天给介绍deepLearnToolbox-master中的CNN部分。

DeepLearnToolbox-master中CNN内的函数:

调用关系为:

该模型使用了mnist的数字mnist_uint8.mat作为训练样本,作为cnn的一个使用样例,每个样本特征为一个28*28=的向量。

网络结构为:

让我们来分析各个函数:

一、Test_example_CNN

三、cnntrain.m.

四、cnnff.m.

五、cnnbp.m.

五、cnnapplygrads.m.

六、cnntest.m.

一、Test_example_CNN:

1、设置CNN的基本参数规格,如卷积、降采样层的数量,卷积核的大小、降采样的降幅

2、cnnsetup函数 初始化卷积核、偏置等

3、cnntrain函数 训练cnn,把训练数据分成batch,然后调用

3.1 cnnff 完成训练的前向过程,

3.2 cnnbp计算并传递神经网络的error,并计算梯度(权重的修改量)

3.3 cnnapplygrads 把计算出来的梯度加到原始模型上去

4、cnntest函数,测试当前模型的准确率

该模型采用的数据为mnist_uint8.mat,

含有70000个手写数字样本其中60000作为训练样本,10000作为测试样本。

把数据转成相应的格式,并归一化。

二、Cnnsetup.m

该函数你用于初始化CNN的参数。

设置各层的mapsize大小,初始化卷积层的卷积核、bias尾部单层感知机的参数设置bias统一设置为0,权重设置为:-1~1之间的随机数/sqrt(6/(输入神经元数量+输出神经元数量))

对于卷积核权重,输入输出为fan_in, fan_out

fan_out= net.layers{l}.outputmaps * net.layers{l}.kernelsize ^ 2;

%卷积核初始化,1层卷积为1*6个卷积核,2层卷积一共6*12=72个卷积核。对于每个卷积输出featuremap,%fan_in= 表示该层的一个输出map,所对应的所有卷积核,包含的神经元的总数。1*25,6*25

fan_in =numInputmaps * net.layers{l}.kernelsize ^ 2;

fin=1*25 or 6*25

fout=1*6*25 or 6*12*25

net.layers{l}.k{i}{j} =(rand(net.layers{l}.kernelsize) – 0.5) * 2 * sqrt(6 / (fan_in + fan_out));

1、卷积降采样的参数初始化

2、尾部单层感知机的参数(权重和偏量)设置:

三、cnntrain.m

该函数用于训练CNN。

生成随机序列,每次选取一个batch(50)个样本进行训练。

批训练:计算50个随机样本的梯度,求和之后一次性更新到模型权重中。

在批训练过程中调用:

Cnnff.m 完成前向过程

Cnnbp.m 完成误差传导和梯度计算过程

Cnnapplygrads.m把计算出来的梯度加到原始模型上去

四、cnnff.m

3、尾部单层感知机的数据处理,需要把subFeatureMap2连接成为一个(4*4)*12=192的向量,但是由于采用了50样本批训练的方法,subFeatureMap2被拼合成为一个192*50的特征向量fv;

Fv作为单层感知机的输入,全连接的方式得到输出层

五、cnnbp.m

该函数实现2部分功能,计算并传递误差,计算梯度

3、把单层感知机的输入层featureVector的误差矩阵,恢复为subFeatureMap2的4*4二维矩阵形式

插播一张图片:

4、误差在特征提取网络【卷积降采样层】的传播

如果本层是卷积层,它的误差是从后一层(降采样层)传过来,误差传播实际上是用降采样的反向过程,也就是降采样层的误差复制为2*2=4份。卷积层的输入是经过sigmoid处理的,所以,从降采样层扩充来的误差要经过sigmoid求导处理。

如果本层是降采样层,他的误差是从后一层(卷积层)传过来,误差传播实际是用卷积的反向过程,也就是卷积层的误差,反卷积(卷积核转180度)卷积层的误差,原理参看插图。

5、计算特征抽取层和尾部单层感知机的梯度

五、cnnapplygrads.m

该函数完成权重修改,更新模型的功能

1、更新特征抽取层的权重 weight+bias

2、更新末尾单层感知机的权重 weight+bias

六、cnntest.m

验证测试样本的准确率

点击“阅读原文”

相关文章

Android之自定义ListView(一)(android 自定义view绘制流程)

PS:自定义View是Android中高手进阶的路线.因此我也打算一步一步的学习.看了鸿洋和郭霖这两位大牛的博客,决定一步一步的学习,循序渐进.学习内容:1.自定义View实现ListView的Ite...

C# 中的多线程同步机制:lock、Monitor 和 Mutex 用法详解

在多线程编程中,线程同步是确保多个线程安全地访问共享资源的关键技术。C# 提供了几种常用的同步机制,其中 lock、Monitor 和 Mutex 是最常用的同步工具。本文将全面介绍这三种同步机制的用...

如何优雅地使用嵌入式事件标志组?

事件标志组嵌入式事件标志组是一种在嵌入式系统中广泛使用的同步机制,主要用于实现多任务间的同步与通信。事件标志组是一组事件标志位的集合,每个位代表一个事件是否发生。它允许任务等待特定的事件发生,当事件发...

btrace 3.0 重磅新增 iOS 支持!免插桩原理大揭秘!

重磅更新btrace 是由字节跳动抖音基础技术团队自主研发的面向移动端的性能数据采集工具,它能够高效的助力移动端应用采集性能 Trace 数据,深入剖析代码的运行状况,进而辅助优化提升移动端应用的性能...

如何正确理解Java领域中的并发锁,我们应该具体掌握到什么程度?

苍穹之边,浩瀚之挚,眰恦之美; 悟心悟性,善始善终,惟善惟道! —— 朝槿《朝槿兮年说》写在开头对于Java领域中的锁,其实从接触Java至今,我相信每一位Java Developer都会有这样的一个...

Oracle数据库无法连接问题排查(oracle数据库连接不成功)

数据库告警日志 如下图 。发现 问题时间段,没有 数据库服务故障 报错,但是存在较多 TNS-12535 、 12560 、 12170 、 00505 错误:通过检查问题时间段应用日志, 也记录了...