This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package rl.dqn.reinforcement.dqn.nn | |
import java.util | |
import org.deeplearning4j.rl4j.learning.Learning | |
import org.deeplearning4j.rl4j.learning.sync.Transition | |
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning | |
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning.{QLConfiguration, QLStepReturn} | |
import org.deeplearning4j.rl4j.mdp.MDP | |
import org.deeplearning4j.rl4j.network.dqn.{DQN, IDQN} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
unstacked_output_type operator()( | |
const std::vector<Tensor>& step_inputs, | |
const hidden_type& input_hidden, | |
const cell_params& params, | |
bool pre_compute_input = false) const { | |
std::vector<Tensor> step_outputs; | |
auto hidden = input_hidden; | |
for (const auto& input : step_inputs) { | |
std::cout << "step input " << std::endl; |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
output_type operator()( | |
const Tensor& inputs, | |
const hidden_type& input_hidden, | |
const cell_params& params) const override { | |
if (inputs.device().is_cpu()) { | |
const auto inputs_w = params.linear_ih(inputs); | |
auto hidden = cell_(inputs_w, input_hidden, params, true); | |
return {hidden_as_output(hidden), hidden}; | |
} | |
auto unstacked_output = (*this)(inputs.unbind(0), input_hidden, params); |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* | |
* ex10.cu | |
* | |
* Created on: Aug 23, 2022 | |
* Author: | |
*/ | |
#include <stdlib.h> | |
#include <string.h> |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
void testLockFloatAdd() { | |
int n = 1; | |
int nBytes = n * sizeof(float); | |
float* ha = (float*)malloc(nBytes); | |
*ha = 0; | |
float *da = nullptr; | |
CHECK(cudaMalloc((float **)&da, nBytes)); | |
CHECK(cudaMemcpy(da, ha, nBytes, cudaMemcpyHostToDevice)); |