如果A大于B,请求解DQN

技术标签: Python  神经网络  深度学习  凯拉斯

我最近进入了神经网络。我目前正在尝试DQN。我已经能够让他们与众多的atari教程一起工作,但发现我无法理解整个问题,因此我为DQN编写了一个简单的模式,该模式仅比B大,如果给出1,如果没有给出1给出0。然后,根据是否正确,将神经网的得分为1或0。不幸的是,我无法让它学习这个简单的问题。有人可以帮助我吗?

# -*- coding: utf-8 -*-
import random
import numpy as np
import env
import gym
from collections import deque
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import Adam

EPISODES = 1000


class DQNAgent:
    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=2000)
        #self.gamma = 0.95    # discount rate
        self.gamma = 0  # discount rate
        self.epsilon = 0.5  # exploration rate
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.985
        self.learning_rate = 0.001
        self.model = self._build_model()

    def _build_model(self):
        # Neural Net for Deep-Q learning Model
        model = Sequential()
        model.add(Dense(2, input_dim=self.state_size, activation='relu'))
        model.add(Dense(2, activation='relu'))
        model.add(Dense(2, activation='relu'))
        model.add(Dense(self.action_size, activation='relu'))
        model.compile(loss='mse',
                      optimizer=Adam(lr=self.learning_rate))
        return model

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        if np.random.rand() <= self.epsilon:
           return random.randrange(self.action_size)
        action = np.argmax(self.model.predict(state)[0])
        return action

    def replay(self, batch_size):
        minibatch = random.sample(self.memory, batch_size)
        for state, action, reward, next_state, done in minibatch:
            print("stating")
            target = reward
            if not done:
                target = reward + self.gamma * np.amax(self.model.predict(next_state)[0])
            target_f = self.model.predict(state)
            target_f[0][action] = target
            print("Reward: " + str(reward))
            print("Target: " + str(target))
            print(action)
            print(self.gamma*np.amax(self.model.predict(next_state)[0]))
            print(state)
            print(target_f)
            self.model.fit(state,target_f, epochs=1, verbose=0)
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

    def load(self, name):
        self.model.load_weights(name)

    def save(self, name):
        self.model.save_weights(name)


if __name__ == "__main__":
    state_size = 2
    action_size = 2
    #timeRange = 1440
    timeRange = 998
    agent = DQNAgent(state_size, action_size)
    # agent.load("./save/cartpole-master.h5")
    done = False
    batch_size = 500
    totalScore=0
    for e in range(EPISODES):
        env_state = env.GameState()
        env_state.reset()
        state = env_state.step(0)[1]
        #state = np.reshape(state, [1, state_size])
        totalreward = 0
        for time in range(timeRange):
            #if time==timeRange-1:
                #done = True
            action = agent.act(state)
            reward, next_state = env_state.step(action) 

            totalreward += reward
            next_state = np.reshape(next_state, [1, state_size])
            agent.remember(state, action, reward, next_state, done)
            state = next_state
            if done:
                print("episode: {}/{}, score: {}, e: {:.2}"
                      .format(e, EPISODES, env_state.money+env_state.shares*env_state.sharePrice, agent.epsilon))
                totalScore+=env_state.money+env_state.shares*env_state.sharePrice
                done = False
                break
        if len(agent.memory) > batch_size:
            print("replaying")
            print("Total reward: "+str(totalreward))
            totalreward = 0
            #print(agent.memory)
            agent.replay(batch_size)
        # if e % 10 == 0:
        #     agent.save("./save/cartpole.h5")

看答案

这来自您在模型中使用的激活:

def _build_model(self):
    # Neural Net for Deep-Q learning Model
    model = Sequential()
    model.add(Dense(2, input_dim=self.state_size, activation='relu'))
    model.add(Dense(2, activation='relu'))
    model.add(Dense(2, activation='relu'))
    model.add(Dense(self.action_size, activation='relu'))
    model.compile(loss='mse',
                  optimizer=Adam(lr=self.learning_rate))
    return model

使用Relu激活无法学习&GT;或&lt;操作。当您考虑一下时,Relu不能仅通过线性操作来学习一个数字是否大于另一个数字。

但是,当您更改激活时:

def _build_model(self):
    # Neural Net for Deep-Q learning Model
    model = Sequential()
    model.add(Dense(2, input_dim=self.state_size, activation='sigmoid'))
    model.add(Dense(2, activation='sigmoid'))
    model.add(Dense(2, activation='sigmoid'))
    model.add(Dense(self.action_size, activation='sigmoid'))
    model.compile(loss='mse',
                  optimizer=Adam(lr=self.learning_rate))
    return model

现在,您还出现了其他类型的非线性性,使网络可以学习那些显然“简单”的比较任务。

我希望这有帮助 :-)


智能推荐

A页面发起请求跳转到B页面后,如何在B页面获取A页面的数据?

如何获取在B页面获取对应请求A页面的数据? 我们需要运用以下的几点知识就能进行操作: 1、首先需要在A页面运用ajax()的请求来获取自己想要的数据。 2、我们需要运用浏览器的缓存localStorage或者sessionStorage,在浏览器中保存自己想要获取的数据。 3、在跳转到B页面后,就可以在B页面根据localStorage或者sessionStorage缓存来获取你想要拿到的数据。 ...

a = b?

1. 问题描述: 个人觉得代码没有问题,但是结果就是不对 2. 原因 在函数调用的过程中可以改变传输过去的变量,就算这个变量没有返还回来。最好不要改变原变量,新建一个变量。 3.另外要说明的两个list或者array特别容易出现的问题: 在改变b的时候,a也被改变了。...

A/B

1 2 3 4 5 6 7 8 9 10 import java.util.Scanner; import java.math.BigInteger; public class Main{     public static void main(String[] args){        ...

矩阵分析:求解Ax=b

b=0时,为齐次线性方程组。R(A)=n时,即A的行列式D不为0,有唯一零解;R(A)<n时,即D=0,无穷多解。 b不为0,非齐次线性方程组,R(A)=R(B)增广矩阵时,方程有解,否则方程无解。R(A)=R(B)=n,有唯一零解;R(A)=R(B)<n,无穷解。 求解Ax=b:可解性和解的结构 对于求解Ax=b,首先我们要判断:   ① 是否有解?   ② 若有...

矩阵分析:求解Ax=b

b=0时,为齐次线性方程组。R(A)=n时,即A的行列式D不为0,有唯一零解;R(A)<n时,即D=0,无穷多解。 b不为0,非齐次线性方程组,R(A)=R(B)增广矩阵时,方程有解,否则方程无解。R(A)=R(B)=n,有唯一零解;R(A)=R(B)<n,无穷解。 求解Ax=b:可解性和解的结构 对于求解Ax=b,首先我们要判断:  ① 是否有解?  ② 若有解,解是否唯一? 先化为...

猜你喜欢

HTTPS请求和HTTP请求解析

https默认端口号 443 http默认端口号 80 https和http对比: 安全性:https相对于http请求安全性高,因为https采用的是混合加密算法(对称加密和非对称加密混合),http采用的是对称加密算法 效率:http效率相对于https高,因为https请求请求链长 https请求解析: 查看证书:...

PAT (Basic Level)1022 D进制的A+B 2/4测试点不通过求解答

1022 D进制的A+B (20 分) 输入两个非负 10 进制整数 A 和 B (≤230−1),输出 A+B 的 D (1<D≤10)进制数。 输入格式: 输入在一行中依次给出 3 个整数 A、B 和 D。 输出格式: 输出 A+B 的 D 进制数。 输入样例: 123 456 8 输出样例: 1103 请问为什么会有不通过的部分...

自定义类加载器

 自定义类加载器 我们如果需要自定义类加载器,只需要继承ClassLoader类,并覆盖掉findClass方法即可。 自定义文件类加载器     自定义网络类加载器 热部署类加载器 当我们调用loadClass方法加载类时,会采用双亲委派模式,即如果类已经被加载,就从缓存中获取,不会重新加载。如果同一个class被同一个类加载器多次加载,则会报错。因此,我们要实现热...

用户界面和兼容性测试

用户界面测试 1 、导航测试 导航直观 Web系统的主要部分可通过主页存取 Web系统不需要站点地图、搜索引擎或其他的导航帮助 Web应用系统的页面结构、导航、菜单、连接的风格一致 2 、图形测试 图形有明确的用途 所有页面字体的风格一致。 背景颜色与字体颜色和前景颜色相搭配。 图片的大小减小到 30k 以下 文字回绕正确 3 、内容测试 Web应用系统提供的信息是正确的 信息无语法或拼写错误 可...

基于ECS部署LAMP环境搭建Drupal网站,云计算技术与应用报告

实验环境: 建站环境:Windows操作系统,基于ECS部署LAMP环境,阿里云资源, Web服务器:Apache,关联的数据库:MySQ PHP:Drupal 8 要求的PHP版本為7.0.33的版本 实验内容和要求:基于ECS部署LAMP环境搭建Drupal网站,drupal是一个好用且功能强大的内容管理系统(CMS),通常也被称为是内容管理框架(CMF),由来自全世界各地的开发人员共同开发和...

问答精选

How we can create Dataproc cluster through rest API or http request?

I am new in python, Here I want to create dataproc cluster using http request. I am following below dataproc documentation where they mentioned in REST API section. see below https://cloud.google.com/...

AddWithValue method on ASP.NET

I am using AddStringWithValue method in ASP.NET using C# My HTML code is My C# code for the method is: The problem is, it is giving red underline under email and password. Shouldn't we identify them w...

How to apply css using a condition?

I'm trying to apply this css: this works well, the problem is that the web app can set a class on the body called white-content, if the white-content class is setted, then I can't see the text of h2, ...

Tile game collision detection with sprite moving to arbitary (x,y)

So I am struggling with some logic for collision detection in my game. I have a grid of tiles(images), all representative of a value in a 2D array, so the location of tile N would be (column m, row n)...

Kotin sort by descending then ascending

Im trying to order a list on multiple parameters.. for example, one value descending, second value ascending, third value descending. is there a way like this to do it? (i know is incorrect) people = ...

相关问题

相关文章

热门文章

推荐文章

相关标签

推荐问答