1 Star 0 Fork 3

立冬 / deep-text

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
贡献代码
同步代码
取消
提示: 由于 Git 不支持空文件夾,创建文件夹后会生成空的 .keep 文件
Loading...
README

Deep Text

简介

Deep Text是一个基于Tensorflow的NLP算法深度学习模型集成库,包含文本分类,序列标注,文本匹配,文本向量化,文本生成,OCR等多种算法实现,目前实现了部分基本深度学习NLP算法。后续会增加更多算法模型的实现,并准备提供Java版的模型预测调用接口。

版本

当前版本 1.3.7 tensorflow 1.8+

已实现模型

  • TextCNN(文本分类,支持多标签)
  • TextRNN(文本分类,支持多标签)
  • BiGRU+CRF(文本序列标注)
  • Skip-thoughts(生成句子向量)
  • CDSSM(文本匹配模型)
  • CNN+BLSTM+CTC(OCR文字识别)

工程结构

.
├── common     基础类&工具类
├── deepcls    分类算法模型
├── deepcrf    序列标注
├── deepemb    句子向量
├── deeplm     语言模型
├── deepmatch  文本匹配模型
└── deepocr    OCR

Deep Text项目里将算法模型分为配置生成、数据读取转换(transform)、训练(train)、模型(model)和预测(graph model)几个部分,每个算法模型都如此实现,数据读取部分提供默认实现,支持实现自定义数据格式的读取。

安装

源码安装:

git clone https://gitee.com/wangsihong/deep-text.git
python setup.py install

pip安装:

pip install deep-text

生成配置文件

model name :

  • "deepcrf" or "deep-crf"
  • "textcnn" or "clscnn"
  • "textrnn" or "clsrnn"
  • "lm" or "deeplm"
  • "cdssm" or "cdssm"
  • "skip-thought" or "stemb"
  • "deepocr" or "lstm-ctc"
deeptext_gen_config -m <model name> -o <output file>

模型训练

命令基本格式如下:

-t    --train 训练数据文件
-e    --eval  测试数据文件
-c    --config 配置文件(不同模型的配置文件不同)
-m    --model 模型保存文件名
-f            执行数据读取&转换
-w            预训练词向量

[model option] 用来区分分类具体要使用的模型,目前只有在文本分类模型中使用。

cnn    TextCNN
rnn    TextRNN
cnn_m  TextCNN(多标签分类)
rnn_m  TextRNN(多标签分类)
deepxxxx_learn [model option] -t <trainfile> -e <testfile> -c <configfile> -m <model save path> [-f <do transform> -w <build word2vec>]

模型训练完会保存成模型文件,也可以通过tensorflow的checkpoint保存模型

deepxxxx_save -p <checkpoint path> -c <config path> -m <model file path>

模型调用

模型预测代码实现在各模块的model.py中,预测类命名方式为GraphXXXXXModel(model_file, config_file) 以下是几个使用python调用模型预测的例子:

TextCNN:

from deepcls import GraphTextCNNModel
from common import config_ops

#load model
config = config_ops.load_config(config_file_path)
model = GraphTextCNNModel(model_path, config)

## predict
ret = model.predict([text], 10)

TextRNN:

from deepcls import GraphTextRNNModel
from common import config_ops

#load model
config = config_ops.load_config(config_file_path)
model = GraphTextRNNModel(model_path, config)

#predict
rets = model.predict([text], 10)

DeepCRF:

from common.config_ops import load_config
from deepcrf.model import GraphDeepCRFModel

#load model
config = load_config(config_path)
model = GraphDeepCRFModel(model_path, config, None)

#predict
rets = model.predict([texts])

DeepOCR:

from common.config_ops import load_config
from deepocr.model import GraphDeepOCRModel
from PIL import Image
import PIL.ImageOps

#load model
config = load_config("config.json")
model = GraphDeepOCRModel(model_path, config)

#predict
image = Image.open("image_file")
rets = model.predict([image])

空文件

简介

暂无描述 展开 收起
取消

发行版

暂无发行版

贡献者

全部

近期动态

加载更多
不能加载更多了
1
https://gitee.com/wangsihong/deep-text.git
git@gitee.com:wangsihong/deep-text.git
wangsihong
deep-text
deep-text
master

搜索帮助