3 Star 6 Fork 2

Fan WenJie / LeNet5-Vulkan

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
贡献代码
同步代码
取消
提示: 由于 Git 不支持空文件夾,创建文件夹后会生成空的 .keep 文件
Loading...
README

LeNet5-Vulkan

介绍

通过Vulkan的计算着色器编写LeNet5神经网络,并完成训练和推理的过程,根据YANN LECUN的论文《Gradient-based Learning Applied To Document Recognition》设计的LeNet-5神经网络,C语言写成,不依赖任何第三方库。 MNIST手写字符集训练识别率97%

DEMO

main.c文件为MNIST数据集的识别DEMO,直接编译即可运行,训练集60000张,测试集10000张。打开项目直接编译即可

项目环境

  1. 安装Visual Studio 2019
  2. 安装Vulkan SDK

API 说明

初始化设备上下文

ctx: 需要初始化的设备上下文 VkResult CreateDeviceContext(DeviceContext* ctx);

销毁设备上下文

ctx: 需要销毁的设备上下文 void DestroyDeviceContext(DeviceContext* ctx);

初始化训练缓存

ctx: 设备上下文 cache: 训练缓存 batchSize: 批量训练数 VkResult CreateTrainCache(DeviceContext* ctx, TrainCache* cache, const uint32_t batchSize);

销毁训练缓存

ctx: 设备上下文 cache: 训练缓存 void DestroyTrainCache(DeviceContext* ctx, TrainCache* cache);

从主存中加载模型到设备内存中

void LoadModel(DeviceContext* lenet, LeNet5* data);

将设备内存中的模型加载到主存中

void SaveModel(DeviceContext* lenet, LeNet5* data);

预测模型结果

ctx: 设备上下文 feature: 特征数据 uint32_t Predict(DeviceContext* ctx, Feature* feature);

批量训练模型

ctx: 设备上下文 cache: 训练缓存 feature: 参与训练的特征数据 label: 训练的标签 void TrainBatch(DeviceContext* ctx, TrainCache* cache, Feature* feature, uint32_t* label);

特别说明

内推字节跳动抖音/tiktok图形图像方面的人才,主要岗位有图形引擎开发、算法等,有意者请私信gitee.com/fanwenjie

空文件

简介

通过Vulkan的计算着色器编写LeNet5神经网络,并完成训练和推理的过程 展开 收起
C
取消

发行版

暂无发行版

贡献者

全部

近期动态

加载更多
不能加载更多了
C
1
https://gitee.com/fanwenjie/LeNet5-Vulkan.git
git@gitee.com:fanwenjie/LeNet5-Vulkan.git
fanwenjie
LeNet5-Vulkan
LeNet5-Vulkan
master

搜索帮助