
陈狗翔 / dueling-DQN-pytorch

This repository has not declared an open-source license file (LICENSE). Before using it, please check the project description and the upstream dependencies of its code.
ddqn.py 4.90 KB
chen committed on 2019-11-07 16:27: dqn update parameters file deleted
import gym
import torch
import torch.nn as nn
import numpy as np
from collections import deque
import random
from itertools import count
import torch.nn.functional as F
from tensorboardX import SummaryWriter
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
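
# Dueling Q-network: a shared feature layer feeds separate value and advantage
# streams, which are combined as Q(s, a) = V(s) + A(s, a) - mean_a A(s, a).
# The input size (4) and number of actions (2) are hard-coded for CartPole-v0.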
class QNetwork(nn.Module):
    def __init__(self):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(4, 64)
        self.relu = nn.ReLU()
        self.fc_value = nn.Linear(64, 256)
        self.fc_adv = nn.Linear(64, 256)

        self.value = nn.Linear(256, 1)
        self.adv = nn.Linear(256, 2)

    def forward(self, state):
        y = self.relu(self.fc1(state))
        value = self.relu(self.fc_value(y))
        adv = self.relu(self.fc_adv(y))

        value = self.value(value)
        adv = self.adv(adv)

        advAverage = torch.mean(adv, dim=1, keepdim=True)
        Q = value + adv - advAverage
        return Q

    def select_action(self, state):
        with torch.no_grad():
            Q = self.forward(state)
            action_index = torch.argmax(Q, dim=1)
            return action_index.item()
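

# Simple FIFO replay buffer backed by a deque; sample() can draw either a
# contiguous slice or a uniform random batch (continuous=False is used below).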
class Memory(object):
    def __init__(self, memory_size: int) -> None:
        self.memory_size = memory_size
        self.buffer = deque(maxlen=self.memory_size)

    def add(self, experience) -> None:
        self.buffer.append(experience)

    def size(self):
        return len(self.buffer)

    def sample(self, batch_size: int, continuous: bool = True):
        if batch_size > len(self.buffer):
            batch_size = len(self.buffer)
        if continuous:
            rand = random.randint(0, len(self.buffer) - batch_size)
            return [self.buffer[i] for i in range(rand, rand + batch_size)]
        else:
            indexes = np.random.choice(np.arange(len(self.buffer)), size=batch_size, replace=False)
            return [self.buffer[i] for i in indexes]

    def clear(self):
        self.buffer.clear()
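

# Environment and hyperparameters. n_state/n_action are read from the
# environment for reference; the network above hard-codes the same sizes.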
env = gym.make('CartPole-v0')
n_state = env.observation_space.shape[0]
n_action = env.action_space.n
onlineQNetwork = QNetwork().to(device)
targetQNetwork = QNetwork().to(device)
targetQNetwork.load_state_dict(onlineQNetwork.state_dict())
optimizer = torch.optim.Adam(onlineQNetwork.parameters(), lr=1e-4)
GAMMA = 0.99
EXPLORE = 20000
INITIAL_EPSILON = 0.1
FINAL_EPSILON = 0.0001
REPLAY_MEMORY = 50000
BATCH = 16
UPDATE_STEPS = 4
memory_replay = Memory(REPLAY_MEMORY)
epsilon = INITIAL_EPSILON
learn_steps = 0
writer = SummaryWriter('logs/ddqn')
begin_learn = False
episode_reward = 0
# onlineQNetwork.load_state_dict(torch.load('ddqn-policy.para'))
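
# Training loop: act epsilon-greedily, store transitions, and once the buffer
# holds more than 128 transitions, update the online network every step and
# copy its weights into the target network every UPDATE_STEPS learning steps.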
for epoch in count():
    state = env.reset()
    episode_reward = 0
    for time_steps in range(200):
        # Epsilon-greedy action selection
        p = random.random()
        if p < epsilon:
            action = random.randint(0, 1)
        else:
            tensor_state = torch.FloatTensor(state).unsqueeze(0).to(device)
            action = onlineQNetwork.select_action(tensor_state)
        next_state, reward, done, _ = env.step(action)
        episode_reward += reward
        memory_replay.add((state, next_state, action, reward, done))

        if memory_replay.size() > 128:
            if begin_learn is False:
                print('learn begin!')
                begin_learn = True
            learn_steps += 1
            # Periodically sync the target network with the online network
            if learn_steps % UPDATE_STEPS == 0:
                targetQNetwork.load_state_dict(onlineQNetwork.state_dict())

            batch = memory_replay.sample(BATCH, False)
            batch_state, batch_next_state, batch_action, batch_reward, batch_done = zip(*batch)

            batch_state = torch.FloatTensor(batch_state).to(device)
            batch_next_state = torch.FloatTensor(batch_next_state).to(device)
            batch_action = torch.FloatTensor(batch_action).unsqueeze(1).to(device)
            batch_reward = torch.FloatTensor(batch_reward).unsqueeze(1).to(device)
            batch_done = torch.FloatTensor(batch_done).unsqueeze(1).to(device)

            with torch.no_grad():
                # Double DQN target: the online network picks the greedy next
                # action, the target network evaluates it
                onlineQ_next = onlineQNetwork(batch_next_state)
                targetQ_next = targetQNetwork(batch_next_state)
                online_max_action = torch.argmax(onlineQ_next, dim=1, keepdim=True)
                y = batch_reward + (1 - batch_done) * GAMMA * targetQ_next.gather(1, online_max_action.long())

            loss = F.mse_loss(onlineQNetwork(batch_state).gather(1, batch_action.long()), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            writer.add_scalar('loss', loss.item(), global_step=learn_steps)

            # Linearly anneal epsilon towards FINAL_EPSILON
            if epsilon > FINAL_EPSILON:
                epsilon -= (INITIAL_EPSILON - FINAL_EPSILON) / EXPLORE

        if done:
            break
        state = next_state

    writer.add_scalar('episode reward', episode_reward, global_step=epoch)
    if epoch % 10 == 0:
        torch.save(onlineQNetwork.state_dict(), 'ddqn-policy.para')
        print('Ep {}\tEpisode reward: {:.2f}'.format(epoch, episode_reward))
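

# --- Optional evaluation sketch (not part of the original script) ---
# A minimal greedy rollout using the checkpoint saved above. It assumes the
# same CartPole-v0 environment and the 'ddqn-policy.para' file; kept commented
# out because the training loop above runs indefinitely.
# evalQNetwork = QNetwork().to(device)
# evalQNetwork.load_state_dict(torch.load('ddqn-policy.para', map_location=device))
# state = env.reset()
# total_reward = 0
# for _ in range(200):
#     tensor_state = torch.FloatTensor(state).unsqueeze(0).to(device)
#     action = evalQNetwork.select_action(tensor_state)
#     state, reward, done, _ = env.step(action)
#     total_reward += reward
#     if done:
#         break
# print('greedy rollout reward:', total_reward)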