Skip to content

Instantly share code, notes, and snippets.

@rzilleruelo
Last active November 13, 2021 09:11
Show Gist options
  • Save rzilleruelo/8b2fd82f62202b7ed0c26d79d0194073 to your computer and use it in GitHub Desktop.
Save rzilleruelo/8b2fd82f62202b7ed0c26d79d0194073 to your computer and use it in GitHub Desktop.
/*
Copyright 2021 Ricardo Zilleruelo
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation
files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy,
modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
/*
The controller inference code is designed to be injected into the SPACEX - ISS Docking Simulator (https://iss-sim.spacex.com/).
Within the simulator webpage open the javascript console.
1. Import dependencies:
```javascript
await import('https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@3.4.0/dist/tf.min.js');
```
2. Paste the code in this gist.
3. Start controller:
```javascript
let plan = initializePlan();
```
4. Stop controller:
```javascript
plan.stop();
```
*/
// DOM elements whose text holds the simulator readouts: attitude values (parsed as degrees further
// below) and position values (parsed with a " m" suffix further below). `$x` is the DevTools console
// XPath helper, so this script is meant to run from the browser console; the first match is taken.
const [pitchSource] = $x('/html//div[@id="pitch"]/*[1]');
const [yawSource] = $x('/html//div[@id="yaw"]/*[1]');
const [rollSource] = $x('/html//div[@id="roll"]/*[1]');
const [xSource] = $x('/html//div[@id="x-range"]/div');
const [ySource] = $x('/html//div[@id="y-range"]/div');
const [zSource] = $x('/html//div[@id="z-range"]/div');
/**
 * Proportional term of the controller: the error `target - measurement`.
 *
 * `inputs` is `[target, measurement]`; the output shape is reported as the shape of the first input.
 * Registered with tf.serialization so saved models referencing 'ProportionalLayer' can be loaded.
 */
class ProportionalLayer extends tf.layers.Layer {
  static get className() {
    return 'ProportionalLayer';
  }

  computeOutputShape(inputShape) {
    return inputShape[0];
  }

  call(inputs, kwargs) {
    return tf.tidy(() => {
      this.invokeCallHook(inputs, kwargs);
      const target = inputs[0];
      const measured = inputs[1];
      return target.sub(measured);
    });
  }

  getConfig() {
    return {...super.getConfig(), name: this.name};
  }
}
tf.serialization.registerClass(ProportionalLayer);
/**
 * Proportional term for angles: the signed error `target - measurement` with the smallest magnitude
 * among the direct difference and the differences wrapped by +360 and -360.
 *
 * `inputs` is `[target, measurement]`. The values are presumably degrees in [0, 360) as produced by
 * getDegreesFullAngle — TODO confirm against the trained model. Ties keep the earlier candidate
 * (direct first, then +360, then -360), as the original two-way comparison did.
 */
class AngularProportionalLayer extends tf.layers.Layer {
  computeOutputShape(inputShape) {
    return inputShape[0];
  }

  call(inputs, kwargs) {
    return tf.tidy(() => {
      this.invokeCallHook(inputs, kwargs);
      const targetX = inputs[0];
      const x = inputs[1];
      const direct = targetX.sub(x);
      const wrappedUp = targetX.add(tf.tensor(360.0)).sub(x);
      // Without this candidate a target above the measurement by more than 180 degrees
      // (e.g. target 350, measurement 10) would steer the long way round (+340 instead of -20).
      const wrappedDown = targetX.sub(tf.tensor(360.0)).sub(x);
      const best = tf.where(
        tf.lessEqual(tf.abs(direct), tf.abs(wrappedUp)),
        direct,
        wrappedUp,
      );
      return tf.where(
        tf.lessEqual(tf.abs(best), tf.abs(wrappedDown)),
        best,
        wrappedDown,
      );
    });
  }

  getConfig() {
    const config = super.getConfig();
    Object.assign(config, {name: this.name});
    return config;
  }

  static get className() {
    return 'AngularProportionalLayer';
  }
}
tf.serialization.registerClass(AngularProportionalLayer);
/**
 * Integral term of the controller with reset-on-sign-change and clamping.
 *
 * `inputs` is `[lastIntegral, proportional, dt]`. The previous integral is kept only where its sign
 * equals the sign of the current error (otherwise it is dropped to 0). Then `proportional * dt` is
 * added, and the result is clipped to `[-maxError, maxError]`, where `maxError` is the sum of the
 * layer's `max_error` weight vector (length `units`).
 */
class IntegralLayer extends tf.layers.Layer {
  constructor(args) {
    super(args);
    this.units = args.units;
  }

  build(inputShape) {
    // Kept on the layer instance (not on the global `self`/window) so that each IntegralLayer uses its
    // own `max_error` weight instead of whichever instance happened to be built last.
    // The initializer only provides starting values; presumably overwritten by loaded weights.
    this.maxError = this.addWeight(
      'max_error',
      [this.units],
      'float32',
      tf.initializers.glorotNormal({}),
    );
  }

  computeOutputShape(inputShape) {
    return inputShape[0];
  }

  call(inputs, kwargs) {
    return tf.tidy(() => {
      this.invokeCallHook(inputs, kwargs);
      const lastIntegral = inputs[0];
      const proportional = inputs[1];
      const dt = inputs[2];
      // NOTE(review): clipByValue requires min <= max, so this assumes the summed weight is >= 0.
      const maxError = this.maxError.read().sum().arraySync();
      return (
        lastIntegral
          .sign()
          .equal(proportional.sign())
          .cast('float32')
          .mul(lastIntegral)
          .add(proportional.mul(dt))
          .clipByValue(-maxError, maxError)
      );
    });
  }

  getConfig() {
    const config = super.getConfig();
    Object.assign(config, {units: this.units, name: this.name});
    return config;
  }

  static get className() {
    return 'IntegralLayer';
  }
}
tf.serialization.registerClass(IntegralLayer);
/**
 * Derivative term of the controller: the change in error divided by the elapsed time.
 *
 * `inputs` is `[proportional, lastProportional, dt]`; the output shape is reported as the shape of
 * the first input. No guard against `dt` being 0 is applied here.
 */
class DerivativeLayer extends tf.layers.Layer {
  static get className() {
    return 'DerivativeLayer';
  }

  computeOutputShape(inputShape) {
    return inputShape[0];
  }

  call(inputs, kwargs) {
    return tf.tidy(() => {
      this.invokeCallHook(inputs, kwargs);
      const current = inputs[0];
      const previous = inputs[1];
      const elapsed = inputs[2];
      const change = current.sub(previous);
      return change.div(elapsed);
    });
  }

  getConfig() {
    return {...super.getConfig(), name: this.name};
  }
}
tf.serialization.registerClass(DerivativeLayer);
// Pinned, timestamped versions of the two TF.js Layers models (translation and attitude controllers).
const modelVersion = '20210820T203741';
const angularModelVersion = '20210820T041020';
const modelUri = `https://storage.googleapis.com/ec98ba871cca4e92888f62e6e4212756/spacex/models/controller_model/${modelVersion}/js/model.json`;
const angularModelUri = `https://storage.googleapis.com/ec98ba871cca4e92888f62e6e4212756/spacex/models/angular_controller_model/${angularModelVersion}/js/model.json`;
// Loaded one after the other (translation model first); the custom layers above must already be
// registered so deserialization can resolve them.
const controllerModel = await tf.loadLayersModel(modelUri);
const angularControllerModel = await tf.loadLayersModel(angularModelUri);
/** Parse a position readout element (text like "12.5 m") into a number; the " m" suffix is stripped. */
const getDistanceValue = (element) => Number.parseFloat(element.textContent.replace(' m', ''));

/** Parse an attitude readout element (text like "-45.2°") into a number; the "°" is stripped. */
const getDegreesValue = (element) => Number.parseFloat(element.textContent.replace('°', ''));

/**
 * Shift an angle into the positive range: values in [0, 180] are returned unchanged, anything else
 * (e.g. negative readings) has 360 added.
 */
const mapAngleToFullAngle = (angle) => (angle >= 0 && angle <= 180.0 ? angle : angle + 360.0);

/** Read an attitude element and convert it with mapAngleToFullAngle. */
const getDegreesFullAngle = (source) => {
  const degrees = getDegreesValue(source);
  return mapAngleToFullAngle(degrees);
};
/** Current attitude as `[roll, yaw, pitch]`, each read through getDegreesFullAngle. */
const getAttitudeState = () => [rollSource, yawSource, pitchSource].map(
  (source) => getDegreesFullAngle(source),
);

/** Current position as `[x, y, z]`, each read through getDistanceValue. */
const getTranslationState = () => [xSource, ySource, zSource].map(
  (source) => getDistanceValue(source),
);
/**
 * Wrap a loaded TF.js Layers model into a stateful per-axis controller.
 *
 * @param {object} controllerModel - Loaded model; must contain a layer named 'memory' whose
 *   `cell.units` sizes the recurrent state carried between ticks.
 * @param {Array<{target: number, nudgeCount: number}>} configs - One config per controlled axis.
 *   `target` is read on every tick, so changing it takes effect immediately.
 * @param {number[]} initialValues - Current reading of each axis, used to seed the first error.
 * @param {string} label - Name used in error messages.
 * @returns {object} Controller with the carried state and `nextAction(x, dt)`, which returns one
 *   action index per axis.
 * @throws {Error} If the model has no layer named 'memory'.
 */
const initializeController = (controllerModel, configs, initialValues, label) => {
  const stateSize = configs.length;
  // One entry per axis; a function argument is called with the axis index, so array-valued entries
  // get a separate array per axis.
  const stateArray = (value) => (
    typeof value === 'function'
      ? Array.from({length: stateSize}, (_, i) => value(i))
      : new Array(stateSize).fill(value)
  );
  const memoryLayer = controllerModel.layers.find((layer) => layer.name === 'memory');
  if (memoryLayer === undefined) {
    throw new Error(`${label} controller model has no layer named "memory"`);
  }
  const memorySize = memoryLayer.cell.units;
  return {
    lastProportional: stateArray((i) => [configs[i].target - initialValues[i]]),
    lastIntegral: stateArray(() => [0.0]),
    lastMemoryState: stateArray(() => new Array(memorySize).fill(0.0)),
    nudgeCount: stateArray(0),
    /**
     * When the model keeps choosing action 2 while the error is non-zero but its derivative is
     * exactly 0, force a step towards the target every `configs[i].nudgeCount` such ticks:
     * action 0 for a positive error, action 1 otherwise.
     */
    nudge(i, action, derivative) {
      const error = this.lastProportional[i][0];
      if (action === 2 && derivative === 0 && error !== 0) {
        this.nudgeCount[i] += 1;
        if (this.nudgeCount[i] >= configs[i].nudgeCount) {
          this.nudgeCount[i] = 0;
          return error > 0 ? 0 : 1;
        }
      }
      return action;
    },
    /**
     * Run one inference step for the current readings `x` after `dt` seconds, update the carried
     * state from the model outputs and return one (possibly nudged) action per axis.
     */
    nextAction(x, dt) {
      const targets = stateArray((i) => configs[i].target);
      // tf.tidy releases the six input tensors and the model's output tensors once their values are
      // copied out; without it every tick leaked them (GPU textures on the WebGL backend).
      const [probabilities, proportional, integral, derivatives, memoryState] = tf.tidy(() => {
        const outputs = controllerModel.predict([
          tf.tensor(targets),
          tf.tensor(stateArray(dt)),
          tf.tensor(x),
          tf.tensor(this.lastProportional),
          tf.tensor(this.lastIntegral),
          tf.tensor(this.lastMemoryState),
        ]);
        return outputs.map((output) => output.arraySync());
      });
      // Index of the most probable action per axis (first index wins on ties).
      const actions = stateArray((i) => (
        probabilities[i].reduce((best, p, j, row) => (p > row[best] ? j : best), 0)
      ));
      this.lastProportional = proportional;
      this.lastIntegral = integral;
      this.lastMemoryState = memoryState;
      return stateArray((i) => this.nudge(i, actions[i], derivatives[i][0]));
    },
  };
};
/**
 * Start the closed control loop against the simulator page.
 *
 * Builds an attitude controller (roll, yaw, pitch) and a translation controller (x, y, z), then runs
 * one tick immediately and roughly every `config.controlRate` milliseconds afterwards. Each tick reads
 * the current state, asks both controllers for one action per axis (with `dt` in seconds since the
 * previous tick) and clicks the axis button mapped to action 0 or 1; any other action clicks nothing.
 *
 * @param {object} config - Plan configuration: `controlRate` (ms) plus `attitude` / `translation`
 *   sections each holding `controllerModel` and the per-axis configs.
 * @returns {{stop: function(): void}} Handle whose `stop()` halts the loop and cancels the next tick.
 */
const startControl = (config) => {
  const attitudeController = initializeController(
    config.attitude.controllerModel,
    [config.attitude.roll, config.attitude.yaw, config.attitude.pitch],
    getAttitudeState(),
    'attitude',
  );
  const translationController = initializeController(
    config.translation.controllerModel,
    [config.translation.x, config.translation.y, config.translation.z],
    getTranslationState(),
    'translation',
  );
  // Per axis, in controller output order: [selector clicked for action 0, selector for action 1].
  const attitudeButtons = [
    ['#roll-left-button', '#roll-right-button'],
    ['#yaw-left-button', '#yaw-right-button'],
    ['#pitch-up-button', '#pitch-down-button'],
  ];
  const translationButtons = [
    ['#translate-backward-button', '#translate-forward-button'],
    ['#translate-right-button', '#translate-left-button'],
    ['#translate-up-button', '#translate-down-button'],
  ];
  // Click, axis by axis, the button selected by that axis' action (0 or 1); other actions do nothing.
  const clickButtons = (buttons, actions) => {
    buttons.forEach(([onZero, onOne], axis) => {
      const action = actions[axis];
      if (action === 0) {
        $(onZero).click();
      } else if (action === 1) {
        $(onOne).click();
      }
    });
  };
  let lastT = Date.now() / 1000.0;
  const control = () => {
    const t = Date.now() / 1000.0;
    const dt = t - lastT;
    // Both controllers decide from the same readings before any button is clicked.
    const attitudeActions = attitudeController.nextAction(getAttitudeState(), dt);
    const translationActions = translationController.nextAction(getTranslationState(), dt);
    lastT = t;
    clickButtons(attitudeButtons, attitudeActions);
    clickButtons(translationButtons, translationActions);
  };
  let running = true;
  let timer = null;
  const tick = () => {
    if (!running) {
      return;
    }
    const startedAt = Date.now();
    control();
    // Always go back through the event loop, even when the tick overran its budget: calling tick()
    // directly would recurse without bound and keep the page's own scripts (which update the
    // readouts being read) from running.
    const delay = Math.max(0, config.controlRate - (Date.now() - startedAt));
    timer = setTimeout(tick, delay);
  };
  tick();
  return {
    stop: () => {
      running = false;
      clearTimeout(timer);
    },
  };
};
/**
 * Start a docking plan that drives every attitude and translation axis towards 0.
 *
 * Builds the plan configuration (200 ms control period, nudge after 5 stalled ticks on every axis),
 * starts the control loop with it and returns `{stop, config}`; `stop()` halts the loop.
 */
const initializePlan = () => {
  const angularNudgeCount = 5;
  const translationNudgeCount = 5;
  // Fresh object per axis so each axis config can be edited independently.
  const axisConfig = (nudgeCount) => ({target: 0.0, nudgeCount});
  const config = {
    controlRate: 200,
    attitude: {
      controllerModel: angularControllerModel,
      roll: axisConfig(angularNudgeCount),
      yaw: axisConfig(angularNudgeCount),
      pitch: axisConfig(angularNudgeCount),
    },
    translation: {
      controllerModel,
      x: axisConfig(translationNudgeCount),
      y: axisConfig(translationNudgeCount),
      z: axisConfig(translationNudgeCount),
    },
  };
  const controller = startControl(config);
  return {
    stop() {
      controller.stop();
    },
    config,
  };
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment