# policy_value_net_keras.py
# -*- coding: utf-8 -*-
"""
An implementation of the policyValueNet with Keras
Tested under Keras 2.0.5 with tensorflow-gpu 1.2.1 as backend

@author: Mingxu Zhang
"""

from __future__ import print_function

# Keras 2.0.x import paths (keras.engine.topology and friends were moved or
# renamed in later Keras releases, so these imports match the version noted
# in the docstring above).
from keras.engine.topology import Input
from keras.engine.training import Model
from keras.layers.convolutional import Conv2D
from keras.layers.core import Activation, Dense, Flatten
from keras.layers.merge import Add
from keras.layers.normalization import BatchNormalization
from keras.regularizers import l2
from keras.optimizers import Adam
import keras.backend as K
from keras.utils import np_utils

import numpy as np
import pickle
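
# Hedged porting note (not part of the original file): if one were running this
# under tf.keras instead of standalone Keras 2.0.x, the rough equivalents of the
# imports actually used below would look like the sketch here. This is an
# untested assumption for newer environments, not what the author used.
#
#     from tensorflow.keras.layers import Input, Conv2D, Dense, Flatten
#     from tensorflow.keras.models import Model
#     from tensorflow.keras.regularizers import l2
#     from tensorflow.keras.optimizers import Adam
#     import tensorflow.keras.backend as K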

class PolicyValueNet():
    """policy-value network """

    def __init__(self, board_width, board_height, model_file=None):
        self.board_width = board_width
        self.board_height = board_height
        self.l2_const = 1e-4  # coef of l2 penalty
        # build the Keras model and the training closures
        self.create_policy_value_net()
        self._loss_train_op()
        if model_file:
            # restore previously saved weights (see save_model below)
            net_params = pickle.load(open(model_file, 'rb'))
            self.model.set_weights(net_params)

    def create_policy_value_net(self):
        """create the policy value network """
        in_x = network = Input((4, self.board_width, self.board_height))
        # conv layers
        network = Conv2D(filters=32, kernel_size=(3, 3), padding="same",
                         data_format="channels_first", activation="relu",
                         kernel_regularizer=l2(self.l2_const))(network)
        network = Conv2D(filters=64, kernel_size=(3, 3), padding="same",
                         data_format="channels_first", activation="relu",
                         kernel_regularizer=l2(self.l2_const))(network)
        network = Conv2D(filters=128, kernel_size=(3, 3), padding="same",
                         data_format="channels_first", activation="relu",
                         kernel_regularizer=l2(self.l2_const))(network)
        # action policy layers
        policy_net = Conv2D(filters=4, kernel_size=(1, 1),
                            data_format="channels_first", activation="relu",
                            kernel_regularizer=l2(self.l2_const))(network)
        policy_net = Flatten()(policy_net)
        self.policy_net = Dense(self.board_width*self.board_height,
                                activation="softmax",
                                kernel_regularizer=l2(self.l2_const))(policy_net)
        # state value layers
        value_net = Conv2D(filters=2, kernel_size=(1, 1),
                           data_format="channels_first", activation="relu",
                           kernel_regularizer=l2(self.l2_const))(network)
        value_net = Flatten()(value_net)
        value_net = Dense(64, kernel_regularizer=l2(self.l2_const))(value_net)
        self.value_net = Dense(1, activation="tanh",
                               kernel_regularizer=l2(self.l2_const))(value_net)
        self.model = Model(in_x, [self.policy_net, self.value_net])

        # expose batched inference as self.policy_value: given a batch of board
        # states, return [action_probs, state_values] from the Keras model
        def policy_value(state_input):
            state_input_union = np.array(state_input)
            results = self.model.predict_on_batch(state_input_union)
            return results
        self.policy_value = policy_value

    def policy_value_fn(self, board):
        """
        input: board
        output: a list of (action, probability) tuples for each available
            action and the score of the board state
        """
        legal_positions = board.availables
        current_state = board.current_state()
        act_probs, value = self.policy_value(
            current_state.reshape(-1, 4, self.board_width, self.board_height))
        # keep only the probabilities of the legal moves
        act_probs = zip(legal_positions, act_probs.flatten()[legal_positions])
        return act_probs, value[0][0]
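
    # Hedged illustration (not from the original file): policy_value_fn only
    # needs a board object exposing `availables` (a list of legal move indices)
    # and `current_state()` returning a (4, width, height) numpy array, so a
    # minimal stand-in such as the hypothetical DummyBoard below would work,
    # where `net` is a PolicyValueNet instance:
    #
    #     class DummyBoard:
    #         def __init__(self, width, height):
    #             self.width, self.height = width, height
    #             self.availables = list(range(width * height))
    #         def current_state(self):
    #             return np.zeros((4, self.width, self.height))
    #
    #     act_probs, value = net.policy_value_fn(DummyBoard(6, 6))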

    def _loss_train_op(self):
        """
        Three loss terms:
        loss = (z - v)^2 - pi^T * log(p) + c||theta||^2
        """
        # get the train op: the categorical cross-entropy gives the
        # -pi^T*log(p) policy term, the mean squared error gives the (z - v)^2
        # value term, and the l2 kernel regularizers added in
        # create_policy_value_net supply the c||theta||^2 penalty
        opt = Adam()
        losses = ['categorical_crossentropy', 'mean_squared_error']
        self.model.compile(optimizer=opt, loss=losses)

        def self_entropy(probs):
            # average entropy of the predicted move distributions
            # (monitored during training, not part of the loss)
            return -np.mean(np.sum(probs * np.log(probs + 1e-10), axis=1))

        def train_step(state_input, mcts_probs, winner, learning_rate):
            state_input_union = np.array(state_input)
            mcts_probs_union = np.array(mcts_probs)
            winner_union = np.array(winner)
            # evaluate the loss before updating, then fit one step with the
            # requested learning rate
            loss = self.model.evaluate(state_input_union,
                                       [mcts_probs_union, winner_union],
                                       batch_size=len(state_input), verbose=0)
            action_probs, _ = self.model.predict_on_batch(state_input_union)
            entropy = self_entropy(action_probs)
            K.set_value(self.model.optimizer.lr, learning_rate)
            self.model.fit(state_input_union,
                           [mcts_probs_union, winner_union],
                           batch_size=len(state_input), verbose=0)
            return loss[0], entropy
        self.train_step = train_step

    def get_policy_param(self):
        net_params = self.model.get_weights()
        return net_params

    def save_model(self, model_file):
        """ save model params to file """
        net_params = self.get_policy_param()
        pickle.dump(net_params, open(model_file, 'wb'), protocol=2)
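

# ---------------------------------------------------------------------------
# Hedged usage sketch (illustrative, not part of the original file): a minimal
# smoke test of PolicyValueNet on random data. The 6x6 board size, batch size
# of 8 and learning rate of 2e-3 are assumptions for illustration; the array
# shapes mirror what policy_value and train_step above expect.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    width, height = 6, 6
    net = PolicyValueNet(width, height)

    # a batch of 8 random board states: (N, 4, width, height)
    states = np.random.random((8, 4, width, height)).astype('float32')
    # random move distributions over the width*height actions: (N, width*height)
    mcts_probs = np.random.dirichlet(np.ones(width * height), size=8)
    # random game outcomes z in {+1, -1}: (N,)
    winners = np.random.choice([1.0, -1.0], size=8)

    probs, values = net.policy_value(states)
    print('policy output shape:', probs.shape)   # (8, 36)
    print('value output shape:', values.shape)   # (8, 1)

    loss, entropy = net.train_step(states, mcts_probs, winners, 2e-3)
    print('loss:', loss, 'entropy:', entropy)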