1 Star 0 Fork 2

小码编程AI / bert_train

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
demo3.py 1.93 KB
一键复制 编辑 原始数据 按行查看 历史
lee 提交于 2019-04-21 23:44 . 新增一个demo测试用
import tensorflow as tf
from bert import modeling
import create_input
import tokenization
import os
# 这里是下载下来的bert配置文件
bert_config = modeling.BertConfig.from_json_file("chinese_L-12_H-768_A-12/bert_config.json")
# 创建bert的输入
input_ids=tf.placeholder (shape=[64,128],dtype=tf.int32,name="input_ids")
input_mask=tf.placeholder (shape=[64,128],dtype=tf.int32,name="input_mask")
segment_ids=tf.placeholder (shape=[64,128],dtype=tf.int32,name="segment_ids")
# 创建bert模型
model = modeling.BertModel(
config=bert_config,
is_training=True,
input_ids=input_ids,
input_mask=input_mask,
token_type_ids=segment_ids,
use_one_hot_embeddings=False # 这里如果使用TPU 设置为True,速度会快些。使用CPU 或GPU 设置为False ,速度会快些。
)
#bert模型参数初始化的地方
init_checkpoint = "chinese_L-12_H-768_A-12/bert_model.ckpt"
use_tpu = False
# 获取模型中所有的训练参数。
tvars = tf.trainable_variables()
# 加载BERT模型
(assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars,
init_checkpoint)
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
tf.logging.info("**** Trainable Variables ****")
# 打印加载模型的参数
for var in tvars:
init_string = ""
if var.name in initialized_variable_names:
init_string = ", *INIT_FROM_CKPT*"
tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
init_string)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
vocab_file="chinese_L-12_H-768_A-12/vocab.txt"
tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file) #
a = create_input.convert_single_example(80 ,tokenizer, '我测试一下')
b = create_input.convert_single_example(80 ,tokenizer, '我打你一下')
print(a)
print(b)
Python
1
https://gitee.com/small_code_programming_ai/bert_train.git
git@gitee.com:small_code_programming_ai/bert_train.git
small_code_programming_ai
bert_train
bert_train
master

搜索帮助