1 Star 0 Fork 2

小码编程AI / bert_train

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
demo2.py 1.03 KB
一键复制 编辑 原始数据 按行查看 历史
胡文祥 提交于 2019-02-20 22:17 . 第一次提交
import tensorflow as tf
from bert import modeling
import os
pathname = "chinese_L-12_H-768_A-12/bert_model.ckpt" # 模型地址
bert_config = modeling.BertConfig.from_json_file("chinese_L-12_H-768_A-12/bert_config.json")# 配置文件地址。
configsession = tf.ConfigProto()
configsession.gpu_options.allow_growth = True
sess = tf.Session(config=configsession)
input_ids = tf.placeholder(shape=[64, 128], dtype=tf.int32, name="input_ids")
input_mask = tf.placeholder(shape=[64, 128], dtype=tf.int32, name="input_mask")
segment_ids = tf.placeholder(shape=[64, 128], dtype=tf.int32, name="segment_ids")
with sess.as_default():
model = modeling.BertModel(
config=bert_config,
is_training=True,
input_ids=input_ids,
input_mask=input_mask,
token_type_ids=segment_ids,
use_one_hot_embeddings=False)
saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())# 这里尤其注意,先初始化,在加载参数。这里和demo1是有区别的
saver.restore(sess, pathname)
print(1)
Python
1
https://gitee.com/small_code_programming_ai/bert_train.git
git@gitee.com:small_code_programming_ai/bert_train.git
small_code_programming_ai
bert_train
bert_train
master

搜索帮助