DQN Agent for Time Series in JavaScript
import ml5 from 'ml5';
import { TimeSeries, DataView } from 'pondjs';
import { DQN } from 'rl-js-dqn';
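// fetchTimeSeriesData() is never defined in this gist. A minimal hypothetical
// stub is sketched below (the endpoint and response shape are assumptions)
// so the data-loading step has something concrete behind it.
async function fetchTimeSeriesData() {
  // Assumed: a JSON API returning [{ timestamp, value }, ...]
  const res = await fetch('https://example.com/api/timeseries');
  return res.json();
}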
const data = await fetchTimeSeriesData();
const timeseries = new TimeSeries({
  name: 'timeseries',
  columns: ['time', 'value'],
  points: data.map(d => [d.timestamp, d.value])
});
// NOTE: the DataView/resample calls below are a sketch of hourly averaging;
// pondjs's real aggregation API may differ, so treat this as pseudocode.
const view = new DataView();
view.addColumn('value', timeseries.column('value'));
const processedSeries = view.resample({
  period: '1h',
  aggregation: 'avg'
}).toJSON();
const environment = {
  // Define the possible actions the agent can take
  actions: ['buy', 'hold', 'sell'],
  // Track the current state so the reward and transition functions can read it
  currentState: null,
  // Function to compute the reward for a given action
  computeReward(action) {
    // Calculate the reward based on the current state of the environment
    // and the selected action
    return calculateReward(this.currentState, action);
  },
  // Function to compute the next state of the environment
  computeNextState(action) {
    // Use the current state and selected action to determine the next state,
    // and advance the environment to it
    this.currentState = calculateNextState(this.currentState, action);
    return this.currentState;
  },
  // Other parameters and settings for the environment
  maxSteps: 1000,
  initialState: processedSeries[0]
};
environment.currentState = environment.initialState;
const agent = new DQN({
  environment: environment,
  // Other parameters and settings for the DQN agent
  hiddenLayers: [32, 32],
  gamma: 0.9,
  epsilon: 0.1
});
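// The agent is constructed but never explicitly trained before the evaluation
// loop below. Assuming rl-js-dqn exposes a train() method (an assumption; the
// package's API is not shown in this gist), training might look like:
await agent.train({ episodes: 100, maxStepsPerEpisode: environment.maxSteps });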
// Use the trained agent to make predictions on the test data
// (processedTestSeries and testData are assumed to be prepared the same way
// as the training series above)
const predictions = [];
let totalReward = 0;
for (let i = 0; i < processedTestSeries.length; i++) {
  // Step the environment to the current test observation
  environment.currentState = processedTestSeries[i];
  const action = agent.act(environment.currentState);
  const reward = environment.computeReward(action);
  totalReward += reward;
  predictions.push(action);
}
// Compute evaluation metrics (testData is assumed to hold the ground-truth
// action labels, and initialInvestment the starting capital)
const accuracy = calculateAccuracy(predictions, testData);
const roi = calculateROI(totalReward, initialInvestment);
console.log(`Test accuracy: ${accuracy}`);
console.log(`Total ROI: ${roi}`);
// Bundle and persist the model; ml5's top-level save() is assumed here
// (ml5 models usually expose their own save() method instead)
const model = {
  agent: agent,
  environment: environment
};
await ml5.save(model, 'time-series-model.json');
// --- Alternative: a standalone DQN agent built directly on TensorFlow.js ---
import * as tf from "@tensorflow/tfjs-node";
// Define the DQN agent
const agent = {
  // Create the model
  model: tf.sequential(),
  // Set the learning rate
  learningRate: 0.01,
  // Define the discount factor
  discountFactor: 0.95,
  // Define the exploration factor
  explorationFactor: 0.1,
  // Initialize the model
  init: function () {
    this.model.add(
      tf.layers.dense({ units: 32, inputShape: [8], activation: "relu" })
    );
    this.model.add(tf.layers.dense({ units: 4, activation: "linear" }));
    this.model.compile({ loss: "meanSquaredError", optimizer: "adam" });
  },
  // Select an action for a given state (a [1, 8] tensor)
  selectAction: async function (state) {
    // Choose a random action with probability equal to the exploration factor
    if (Math.random() < this.explorationFactor) {
      return Math.floor(Math.random() * 4);
    }
    // Otherwise, choose the action with the highest predicted Q-value
    const qValues = this.model.predict(state);
    return tf.argMax(qValues, 1).dataSync()[0];
  },
  // Update the model using a given batch of experiences
  update: async function (batch) {
    // Create the input and target tensors
    const inputs = [];
    const targets = [];
    for (const [state, action, nextState, reward, done] of batch) {
      // Compute the Q-values for the current state
      const qValues = this.model.predict(state);
      // If the episode has ended, the Q-value for the next state is 0
      let nextQValue = 0;
      if (!done) {
        // Otherwise, use the predicted Q-values for the next state
        // to compute the target Q-value
        const nextQValues = this.model.predict(nextState);
        nextQValue = tf.max(nextQValues).dataSync()[0];
      }
      // Update the target Q-value for the given action
      // (Array.from converts the typed array so tf.tensor2d accepts it)
      targets.push(Array.from(qValues.dataSync()));
      targets[targets.length - 1][action] =
        reward + this.discountFactor * nextQValue;
      // Use the current state as the input
      inputs.push(Array.from(state.dataSync()));
    }
    // Train the model on the inputs and targets
    await this.model.fit(tf.tensor2d(inputs), tf.tensor2d(targets), {
      epochs: 1,
    });
  },
};
// Initialize the DQN agent
agent.init();
// Replay buffer for experience replay (size is an assumed hyperparameter)
const replayBuffer = [];
const replayBufferMaxSize = 10000;
// Placeholder reward/transition helpers for this synthetic task (assumptions,
// not part of any library)
const computeReward = action => Math.random();
const computeNextState = (state, action) => tf.randomNormal([1, 8]);
// Generate some synthetic data for training
const xs = tf.randomNormal([100, 8]);
// Train the agent using the data
for (let i = 0; i < 100; i++) {
  // Slice out the i-th row as a [1, 8] tensor (tensors cannot be indexed with [])
  const state = xs.slice([i, 0], [1, 8]);
  // Select an action for the current state
  const action = await agent.selectAction(state);
  // Compute the reward and next state based on the action
  const reward = computeReward(action);
  const nextState = computeNextState(state, action);
  // Update the replay buffer
  replayBuffer.push([state, action, nextState, reward, false]);
  if (replayBuffer.length > replayBufferMaxSize) {
    replayBuffer.shift();
  }
  // Sample a random batch of experiences from the replay buffer
  const batch = [];
  for (let j = 0; j < 32; j++) {
    const index = Math.floor(Math.random() * replayBuffer.length);
    batch.push(replayBuffer[index]);
  }
  // Update the agent using the experience batch
  await agent.update(batch);
  // Decrease the exploration factor over time
  agent.explorationFactor *= 0.99;
}
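// Persist the trained network; tf.LayersModel.save() with the file:// scheme
// is supported by @tensorflow/tfjs-node (the path name is an assumption).
await agent.model.save("file://./dqn-timeseries-model");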
// Function to calculate the reward for a given action
// (states are assumed to expose pondjs-style get('value') and a next()
// accessor that returns the following observation)
function calculateReward(state, action) {
  // Get the current value and the next value from the state
  const currentValue = state.get('value');
  const nextValue = state.next().get('value');
  let reward = 0;
  // Calculate the reward based on the action and the change in value
  if (action === 'buy') {
    // Buying is rewarded by the gain, or penalized if the value falls
    reward = nextValue > currentValue ? nextValue - currentValue : -1;
  } else if (action === 'hold') {
    // Holding earns a small reward when the value rises or stays flat
    if (nextValue > currentValue) {
      reward = 0.1;
    } else if (nextValue === currentValue) {
      reward = 0.01;
    } else {
      reward = -0.1;
    }
  } else if (action === 'sell') {
    // Selling is rewarded by the avoided loss, or penalized if the value rises
    reward = nextValue < currentValue ? currentValue - nextValue : -1;
  }
  return reward;
}
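// Quick sanity check with a hypothetical two-point state (the get/next
// accessors match the assumption above):
const mockState = {
  get: () => 100,
  next: () => ({ get: () => 105 }),
};
console.log(calculateReward(mockState, 'buy')); // 5: the value rose after buying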
// Function to calculate the next state based on the current state and action
function calculateNextState(state, action) {
  // The series simply advances one step; in this sketch the action
  // does not affect the market itself
  return state.next();
}
// Function to calculate the accuracy of predictions
function calculateAccuracy(predictions, trueValues) {
  let numCorrect = 0;
  for (let i = 0; i < predictions.length; i++) {
    if (predictions[i] === trueValues[i]) {
      numCorrect++;
    }
  }
  return numCorrect / predictions.length;
}
// Function to calculate the return on investment
function calculateROI(totalReward, initialInvestment) {
  return totalReward / initialInvestment;
}
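// Example metric computation with hypothetical values: two of four predicted
// actions match the labels, and 250 was earned on 1000 invested.
console.log(calculateAccuracy(['buy', 'hold', 'sell', 'buy'], ['buy', 'hold', 'buy', 'sell'])); // 0.5
console.log(calculateROI(250, 1000)); // 0.25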