Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save rzilleruelo/fae8b5f4e0948e28d2c1bde23d0830e9 to your computer and use it in GitHub Desktop.
Save rzilleruelo/fae8b5f4e0948e28d2c1bde23d0830e9 to your computer and use it in GitHub Desktop.
/*
Copyright 2021 Ricardo Zilleruelo
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation
files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy,
modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
/*
The controller inference code is designed to be injected into the SPACEX - ISS Docking Simulator (https://iss-sim.spacex.com/).
Within the simulator webpage open the javascript console.
1. Import dependencies:
```javascript
await import('https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@3.4.0/dist/tf.min.js');
```
2. Paste the code in this gist.
3. Start controller for ride around plan:
```javascript
let plan = initializeRideAroundPlan();
```
or start controller for relocate plan:
```javascript
let plan = initializeRelocatePlan();
```
4. Stop controller:
```javascript
plan.stop();
```
*/
// Elements of the simulator page that the controller reads its state from.
// `$x` is the XPath helper available in the browser JavaScript console; only the
// first node matched by each query is kept.
// Attitude readouts: first child of the pitch / yaw / roll containers.
const [pitchSource] = $x('/html//div[@id="pitch"]/*[1]');
const [yawSource] = $x('/html//div[@id="yaw"]/*[1]');
const [rollSource] = $x('/html//div[@id="roll"]/*[1]');
// Position readouts: first <div> below each *-range container.
const [xSource] = $x('/html//div[@id="x-range"]/div');
const [ySource] = $x('/html//div[@id="y-range"]/div');
const [zSource] = $x('/html//div[@id="z-range"]/div');
/**
 * Custom TensorFlow.js layer producing the proportional (error) term of the
 * controller: the element-wise difference between its first input (the target)
 * and its second input (the current value).
 */
class ProportionalLayer extends tf.layers.Layer {
  /** Name under which the class is registered for (de)serialization. */
  static get className() {
    return 'ProportionalLayer';
  }

  /** The output has the same shape as the first input. */
  computeOutputShape(inputShape) {
    const [targetShape] = inputShape;
    return targetShape;
  }

  /**
   * @param inputs two tensors, `[target, current]`.
   * @param kwargs forwarded to the layer's call hook.
   * @returns tensor holding `target - current`.
   */
  call(inputs, kwargs) {
    return tf.tidy(() => {
      this.invokeCallHook(inputs, kwargs);
      const [target, current] = inputs;
      return tf.sub(target, current);
    });
  }

  /** Base layer config extended with the layer name. */
  getConfig() {
    return {...super.getConfig(), name: this.name};
  }
}
tf.serialization.registerClass(ProportionalLayer);
/**
 * Custom TensorFlow.js layer producing the proportional (error) term for angles
 * expressed in degrees. Two candidate errors are computed, `target - current`
 * and `target + 360 - current`, and the one with the smaller magnitude is
 * returned element-wise (the first one wins ties).
 */
class AngularProportionalLayer extends tf.layers.Layer {
  /** Name under which the class is registered for (de)serialization. */
  static get className() {
    return 'AngularProportionalLayer';
  }

  /** The output has the same shape as the first input. */
  computeOutputShape(inputShape) {
    const [targetShape] = inputShape;
    return targetShape;
  }

  /**
   * @param inputs two tensors, `[target, current]`.
   * @param kwargs forwarded to the layer's call hook.
   * @returns tensor with the smaller-magnitude candidate error per element.
   */
  call(inputs, kwargs) {
    return tf.tidy(() => {
      this.invokeCallHook(inputs, kwargs);
      const [target, current] = inputs;
      const directError = tf.sub(target, current);
      const wrappedError = tf.sub(tf.add(target, tf.tensor(360.0)), current);
      const directIsSmaller = tf.lessEqual(tf.abs(directError), tf.abs(wrappedError));
      return tf.where(directIsSmaller, directError, wrappedError);
    });
  }

  /** Base layer config extended with the layer name. */
  getConfig() {
    return {...super.getConfig(), name: this.name};
  }
}
tf.serialization.registerClass(AngularProportionalLayer);
/**
 * Custom TensorFlow.js layer producing the integral term of the controller.
 *
 * The accumulated integral is kept only where it has the same sign as the
 * current proportional error (it is zeroed otherwise), `proportional * dt` is
 * added to it, and the result is clamped to `[-maxError, maxError]`, where
 * `maxError` is the sum of the layer's `max_error` weight.
 */
class IntegralLayer extends tf.layers.Layer {
  /**
   * @param args layer arguments; `args.units` is the length of the `max_error` weight.
   */
  constructor(args) {
    super(args);
    this.units = args.units;
  }

  /**
   * Creates the `max_error` weight.
   *
   * The weight is stored on the layer instance. It used to be stored on the
   * global `self` object, so every IntegralLayer shared (and overwrote) a
   * single weight whenever more than one model was loaded.
   */
  build(inputShape) {
    this.maxError = this.addWeight(
      'max_error',
      [this.units],
      'float32',
      // The initializer factory takes a config object, not the unit count.
      tf.initializers.glorotNormal({}),
    );
  }

  /** The output has the same shape as the first input. */
  computeOutputShape(inputShape) {
    return inputShape[0];
  }

  /**
   * @param inputs three tensors, `[lastIntegral, proportional, dt]`.
   * @param kwargs forwarded to the layer's call hook.
   * @returns tensor with the updated, clamped integral.
   */
  call(inputs, kwargs) {
    return tf.tidy(() => {
      this.invokeCallHook(inputs, kwargs);
      const [lastIntegral, proportional, dt] = inputs;
      // Collapse the weight vector into one scalar clamp bound.
      const maxError = this.maxError.read().sum().arraySync();
      // 1.0 where the stored integral and the current error agree in sign, 0.0 elsewhere.
      const keepMask = lastIntegral.sign().equal(proportional.sign()).cast('float32');
      return keepMask
        .mul(lastIntegral)
        .add(proportional.mul(dt))
        .clipByValue(-maxError, maxError);
    });
  }

  /** Base layer config extended with `units` and the layer name. */
  getConfig() {
    const config = super.getConfig();
    Object.assign(config, {units: this.units, name: this.name});
    return config;
  }

  /** Name under which the class is registered for (de)serialization. */
  static get className() {
    return 'IntegralLayer';
  }
}
tf.serialization.registerClass(IntegralLayer);
/**
 * Custom TensorFlow.js layer producing the derivative term of the controller:
 * the change of the proportional error since the previous step divided by the
 * elapsed time.
 */
class DerivativeLayer extends tf.layers.Layer {
  /** Name under which the class is registered for (de)serialization. */
  static get className() {
    return 'DerivativeLayer';
  }

  /** The output has the same shape as the first input. */
  computeOutputShape(inputShape) {
    const [errorShape] = inputShape;
    return errorShape;
  }

  /**
   * @param inputs three tensors, `[proportional, lastProportional, dt]`.
   * @param kwargs forwarded to the layer's call hook.
   * @returns tensor holding `(proportional - lastProportional) / dt`.
   */
  call(inputs, kwargs) {
    return tf.tidy(() => {
      this.invokeCallHook(inputs, kwargs);
      const [currentError, previousError, elapsed] = inputs;
      const errorChange = tf.sub(currentError, previousError);
      return tf.div(errorChange, elapsed);
    });
  }

  /** Base layer config extended with the layer name. */
  getConfig() {
    return {...super.getConfig(), name: this.name};
  }
}
tf.serialization.registerClass(DerivativeLayer);
// Translation controller: version tag, location of its model.json, and the loaded model.
const modelVersion = '20210820T203741';
const modelUri = `https://storage.googleapis.com/ec98ba871cca4e92888f62e6e4212756/spacex/models/controller_model/${modelVersion}/js/model.json`;
const controllerModel = await tf.loadLayersModel(modelUri);

// Attitude (angular) controller: version tag, location of its model.json, and the loaded model.
const angularModelVersion = '20210820T041020';
const angularModelUri = `https://storage.googleapis.com/ec98ba871cca4e92888f62e6e4212756/spacex/models/angular_controller_model/${angularModelVersion}/js/model.json`;
const angularControllerModel = await tf.loadLayersModel(angularModelUri);
/**
 * Reads a distance from a readout element.
 *
 * @param sensor object whose `textContent` holds the reading, e.g. "12.5 m".
 * @returns the reading parsed as a number, with the first " m" removed beforehand.
 */
const getDistanceValue = (sensor) => {
  const reading = sensor.textContent.replace(' m', '');
  return Number.parseFloat(reading);
};
/**
 * Reads an angle from a readout element.
 *
 * @param sensor object whose `textContent` holds the reading, e.g. "-3.2°".
 * @returns the reading parsed as a number, with the first "°" removed beforehand.
 */
const getDegreesValue = (sensor) => {
  const reading = sensor.textContent.replace('°', '');
  return Number.parseFloat(reading);
};
/**
 * Maps a signed angle in degrees onto the non-negative "full angle" scale by
 * adding 360 to negative values (e.g. -10 becomes 350).
 *
 * Non-negative angles are returned unchanged. Previously any angle above 180
 * also received the +360 offset, which pushed it out of the [0, 360) range
 * (190 became 550).
 *
 * @param angle angle in degrees.
 * @returns the angle, shifted by +360 when it is negative.
 */
const mapAngleToFullAngle = (angle) => (angle < 0 ? angle + 360.0 : angle);
/**
 * Reads an angle readout and maps it through `mapAngleToFullAngle`.
 *
 * @param source readout element passed on to `getDegreesValue`.
 * @returns the mapped angle.
 */
const getDegreesFullAngle = (source) => {
  const degrees = getDegreesValue(source);
  return mapAngleToFullAngle(degrees);
};
/**
 * Reads the current attitude from the page.
 *
 * @returns `[roll, yaw, pitch]`, each obtained with `getDegreesFullAngle`.
 */
const getAttitudeState = () => (
  [rollSource, yawSource, pitchSource].map((source) => getDegreesFullAngle(source))
);
/**
 * Reads the current position from the page.
 *
 * @returns `[x, y, z]`, each obtained with `getDistanceValue`.
 */
const getTranslationState = () => (
  [xSource, ySource, zSource].map((source) => getDistanceValue(source))
);
/**
 * Builds the stateful wrapper that runs one controller model for a group of
 * axes (one entry of `configs` per axis).
 *
 * @param controllerModel model with a layer named "memory" (its `cell.units`
 *   gives the memory size) whose `predict` returns five outputs, in order:
 *   action probabilities, proportional, integral, derivative, memory state.
 * @param configs per-axis settings; `target` and `nudgeCount` are read on every
 *   call, so changes made to these objects later are picked up.
 * @param initialValues per-axis values used to seed the proportional term.
 * @param label unused; kept so existing callers keep working.
 * @returns controller object holding the recurrent state plus `nudge` and `nextAction`.
 */
const initializeController = (controllerModel, configs, initialValues, label) => {
  const stateSize = configs.length;
  // One entry per axis: either the same value for all, or value(i) when given a function.
  const stateArray = (value) => (
    typeof value === 'function'
      ? Array.from({length: stateSize}, (_, i) => value(i))
      : Array.from({length: stateSize}, () => value)
  );
  const memoryLayer = controllerModel.layers.find((layer) => layer.name === 'memory');
  if (memoryLayer === undefined) {
    throw new Error('controller model has no layer named "memory"');
  }
  const memorySize = memoryLayer.cell.units;
  return {
    lastProportional: stateArray((i) => [configs[i].target - initialValues[i]]),
    lastIntegral: stateArray([0.0]),
    lastMemoryState: stateArray(() => Array.from({length: memorySize}, () => 0.0)),
    nudgeCount: stateArray(0),
    /**
     * Overrides action 2 once it has been chosen `configs[i].nudgeCount` times
     * while the derivative is zero and an error remains: returns 0 when the
     * remaining error is positive and 1 otherwise. Any other case returns
     * `action` unchanged.
     */
    nudge(i, action, derivative) {
      const remainingError = this.lastProportional[i][0];
      if (action === 2 && derivative === 0 && remainingError !== 0) {
        this.nudgeCount[i] += 1;
        if (this.nudgeCount[i] >= configs[i].nudgeCount) {
          this.nudgeCount[i] = 0;
          return remainingError > 0 ? 0 : 1;
        }
      }
      return action;
    },
    /**
     * Runs the model for one step and returns one action per axis.
     *
     * @param x current per-axis values.
     * @param dt time elapsed since the previous step.
     * @returns per-axis action: index of the most probable class, after `nudge`.
     */
    nextAction(x, dt) {
      let outputs;
      // Convert to plain arrays inside tidy so the input and output tensors of
      // this step are released; this runs several times per second and the
      // tensors were previously never disposed.
      tf.tidy(() => {
        outputs = controllerModel.predict([
          tf.tensor(stateArray((i) => configs[i].target)),
          tf.tensor(stateArray(dt)),
          tf.tensor(x),
          tf.tensor(this.lastProportional),
          tf.tensor(this.lastIntegral),
          tf.tensor(this.lastMemoryState),
        ]).map((output) => output.arraySync());
      });
      const [actionProbabilities, proportional, integral, derivatives, memoryState] = outputs;
      const actions = stateArray((i) => actionProbabilities[i].reduce(
        (best, probability, j, all) => (probability > all[best] ? j : best),
        0,
      ));
      this.lastProportional = proportional;
      this.lastIntegral = integral;
      this.lastMemoryState = memoryState;
      return stateArray((i) => this.nudge(i, actions[i], derivatives[i][0]));
    },
  };
};
/**
 * Reads an angle readout (via `getDegreesValue`) and converts it to radians.
 *
 * @param sensor readout element passed on to `getDegreesValue`.
 * @returns the angle in radians.
 */
const getRadianValue = (sensor) => {
  const degrees = getDegreesValue(sensor);
  return degrees * Math.PI / 180.0;
};
/**
 * Projects a translation error vector using the yaw, pitch and roll currently
 * shown on the page.
 *
 * Each input error is projected onto three output axes, but only its
 * largest-magnitude component is used. The X error is written first; the Y and
 * Z errors then replace the value on their dominant axis only when their
 * contribution is larger in magnitude than what is already stored there.
 * The Z error is negated before projecting, and the yaw reading `y` is replaced
 * by `PI - y` (for `y >= 0`) or `-PI - y` (otherwise).
 *
 * @param errors `[errorX, errorY, errorZ]`.
 * @returns `[projected0, projected1, projected2]`; axes nothing was written to stay 0.
 */
const projectErrorToCapsuleReferenceFrame = ([errorX, errorY, errorZ]) => {
  // Index of the component with the largest absolute value (first one wins ties).
  const dominantIndex = (components) => components.reduce(
    (best, component, index, all) => (Math.abs(component) > Math.abs(all[best]) ? index : best),
    0,
  );

  const flippedErrorZ = -errorZ;
  const yawReading = getRadianValue(yawSource);
  const yaw = (yawReading >= 0.0 ? Math.PI : -Math.PI) - yawReading;
  const pitch = getRadianValue(pitchSource);
  const roll = getRadianValue(rollSource);

  const projected = [0.0, 0.0, 0.0];

  const fromX = [
    errorX * Math.cos(yaw) * Math.cos(pitch),
    errorX * Math.sin(yaw) * Math.cos(pitch),
    -errorX * Math.sin(pitch),
  ];
  const xAxis = dominantIndex(fromX);
  projected[xAxis] = fromX[xAxis];

  const fromY = [
    errorY * (Math.cos(yaw) * Math.sin(pitch) * Math.sin(roll) - Math.sin(yaw) * Math.cos(roll)),
    errorY * (Math.sin(yaw) * Math.sin(pitch) * Math.sin(roll) + Math.cos(yaw) * Math.cos(roll)),
    errorY * Math.cos(pitch) * Math.sin(roll),
  ];
  const fromZ = [
    flippedErrorZ * (Math.cos(yaw) * Math.sin(pitch) * Math.cos(roll) + Math.sin(yaw) * Math.sin(roll)),
    flippedErrorZ * (Math.sin(yaw) * Math.sin(pitch) * Math.cos(roll) - Math.cos(yaw) * Math.sin(roll)),
    flippedErrorZ * Math.cos(pitch) * Math.cos(roll),
  ];

  // Y first, then Z: a later error only wins an axis with a strictly larger magnitude.
  [fromY, fromZ].forEach((components) => {
    const axis = dominantIndex(components);
    if (Math.abs(projected[axis]) < Math.abs(components[axis])) {
      projected[axis] = components[axis];
    }
  });

  return projected;
};
/**
 * Starts the control loop: on every step it reads the attitude and translation
 * state from the page, asks both controllers for an action per axis, and clicks
 * the matching simulator buttons.
 *
 * The first step runs synchronously; each following step is scheduled with
 * `setTimeout` so that steps start `config.controlRate` milliseconds apart (or
 * as soon as possible when a step took longer than that). Earlier versions
 * recursed synchronously in the latter case, which blocked the event loop (so
 * `stop()` could never run) and eventually overflowed the stack.
 *
 * @param config `{controlRate, attitude: {controllerModel, roll, yaw, pitch},
 *   translation: {controllerModel, x, y, z}}`; the per-axis objects are handed
 *   to `initializeController`.
 * @returns `{stop}`; calling `stop()` halts the loop.
 */
const startControl = (config) => {
  // Selectors clicked for action 0 and action 1, in roll, yaw, pitch, x, y, z order.
  // Any other action clicks nothing.
  const buttonSelectors = [
    ['#roll-left-button', '#roll-right-button'],
    ['#yaw-left-button', '#yaw-right-button'],
    ['#pitch-up-button', '#pitch-down-button'],
    ['#translate-backward-button', '#translate-forward-button'],
    ['#translate-right-button', '#translate-left-button'],
    ['#translate-up-button', '#translate-down-button'],
  ];

  const attitudeController = initializeController(
    config.attitude.controllerModel,
    [config.attitude.roll, config.attitude.yaw, config.attitude.pitch],
    getAttitudeState(),
    'attitude',
  );
  const translationConfigs = [config.translation.x, config.translation.y, config.translation.z];
  // Position error (target - reading) projected with the current attitude and
  // then offset by the targets again, which is the form the translation
  // controller is fed with.
  const getProjectedTranslationState = () => projectErrorToCapsuleReferenceFrame(
    getTranslationState().map((value, i) => translationConfigs[i].target - value),
  ).map((value, i) => value + translationConfigs[i].target);
  const translationController = initializeController(
    config.translation.controllerModel,
    translationConfigs,
    getProjectedTranslationState(),
    'translation',
  );

  let lastT = Date.now() / 1000.0;
  const control = () => {
    const t = Date.now() / 1000.0;
    const dt = t - lastT;
    const [rollAction, yawAction, pitchAction] = attitudeController.nextAction(getAttitudeState(), dt);
    const [xAction, yAction, zAction] = translationController.nextAction(
      getProjectedTranslationState(),
      dt,
    );
    lastT = t;
    [rollAction, yawAction, pitchAction, xAction, yAction, zAction].forEach((action, axis) => {
      if (action === 0 || action === 1) {
        $(buttonSelectors[axis][action]).click();
      }
    });
  };

  let running = true;
  let pendingTimer = null;
  const step = () => {
    if (!running) {
      return;
    }
    const startedAt = Date.now();
    control();
    const remaining = config.controlRate - (Date.now() - startedAt);
    // Always yield to the event loop, even when the step overran its budget.
    pendingTimer = setTimeout(step, Math.max(remaining, 0));
  };
  step();

  return {
    stop: () => {
      running = false;
      clearTimeout(pendingTimer);
    },
  };
};
/**
 * Starts the "ride around" plan: the translation targets step through a fixed
 * list of waypoints while the yaw and pitch targets are continuously recomputed
 * from the direction between the current position and the view target (0, 0, 0).
 *
 * Every 100 ms the plan updates the targets stored in `config`; the control
 * loop started with `startControl(config)` reads them from the same object.
 * Once the last waypoint is selected, the yaw and pitch targets are held at 0.
 *
 * @returns `{stop, config}`; `stop()` halts the control loop and the target updates.
 */
const initializeRideAroundPlan = () => {
  const translationNudgeCount = 4;
  const angularNudgeCount = 4;
  const axisConfig = (nudgeCount) => ({target: 0.0, nudgeCount: nudgeCount});
  const config = {
    controlRate: 250,
    attitude: {
      controllerModel: angularControllerModel,
      roll: axisConfig(angularNudgeCount),
      yaw: axisConfig(angularNudgeCount),
      pitch: axisConfig(angularNudgeCount),
    },
    translation: {
      controllerModel: controllerModel,
      x: axisConfig(translationNudgeCount),
      y: axisConfig(translationNudgeCount),
      z: axisConfig(translationNudgeCount),
    },
  };
  const viewTarget = {x: 0.0, y: 0.0, z: 0.0};
  // Waypoints in visiting order; the next one is selected once the distance to
  // the current one is at most its `threshold`.
  const targets = [
    {x: 158.6, y: 65.7, z: 34.1, threshold: 20.0},
    {x: 98.0, y: 98.0, z: 57.4, threshold: 20.0},
    {x: 39.8, y: 96.0, z: 69.4, threshold: 20.0},
    {x: 0.0, y: 70.7, z: 70.7, threshold: 20.0},
    {x: -43.3, y: 75.0, z: 50.0, threshold: 20.0},
    {x: -83.7, y: 48.3, z: 25.9, threshold: 20.0},
    {x: -100.0, y: 0.0, z: 0.0, threshold: 20.0},
    {x: -83.7, y: -48.3, z: -25.9, threshold: 20.0},
    {x: -43.3, y: -75.0, z: -50.0, threshold: 20.0},
    {x: -0.0, y: -70.7, z: -70.7, threshold: 10.0},
    {x: 24.7, y: -59.5, z: -43.1, threshold: 10.0},
    {x: 35.9, y: -35.9, z: -21.0, threshold: 5.0},
    {x: 29.4, y: -12.2, z: -6.3, threshold: 5.0},
    {x: 29.4, y: 0.0, z: 0.0, threshold: 1.0},
    {x: 0.0, y: 0.0, z: 0.0, threshold: -1.0},
  ];
  let index = 0;

  // Copies a waypoint into the translation targets and logs the new targets.
  const selectWaypoint = (waypoint) => {
    const {translation} = config;
    translation.x.target = waypoint.x;
    translation.y.target = waypoint.y;
    translation.z.target = waypoint.z;
    console.info(
      'x:', translation.x.target.toFixed(2),
      'y:', translation.y.target.toFixed(2),
      'z:', translation.z.target.toFixed(2),
    );
  };
  selectWaypoint(targets[0]);

  const toDegrees = (radians) => radians * 180.0 / Math.PI;
  // Rounds to one decimal place through exponent-notation strings.
  const roundToTenth = (value) => Number(`${Math.round(`${value}e1`)}e-1`);

  const update = () => {
    const {attitude, translation} = config;
    if (index + 1 >= targets.length) {
      attitude.yaw.target = 0.0;
      attitude.pitch.target = 0.0;
      return;
    }

    // Unit vector from the current position towards the view target.
    const [x, y, z] = getTranslationState();
    const towardsX = viewTarget.x - x;
    const towardsY = viewTarget.y - y;
    const towardsZ = viewTarget.z - z;
    const length = Math.sqrt(towardsX * towardsX + towardsY * towardsY + towardsZ * towardsZ);
    const unitX = towardsX / length;
    const unitY = towardsY / length;
    const unitZ = towardsZ / length;
    const yaw = roundToTenth(90.0 + toDegrees(Math.atan2(unitX, -unitY)));
    const pitch = roundToTenth(toDegrees(Math.atan2(unitZ, Math.sqrt(unitX * unitX + unitY * unitY))));
    attitude.yaw.target = yaw;
    attitude.pitch.target = pitch;

    // Advance to the next waypoint when close enough to the current one.
    const offsetX = translation.x.target - getDistanceValue(xSource);
    const offsetY = translation.y.target - getDistanceValue(ySource);
    const offsetZ = translation.z.target - getDistanceValue(zSource);
    const distance = Math.sqrt(offsetX * offsetX + offsetY * offsetY + offsetZ * offsetZ);
    if (distance <= targets[index].threshold) {
      index += 1;
      selectWaypoint(targets[index]);
    }
  };

  const interval = setInterval(update, 100);
  const controller = startControl(config);
  return {
    stop: () => {
      controller.stop();
      clearInterval(interval);
    },
    config: config,
  };
};
/**
 * Starts the "relocate" plan: the translation targets step through a fixed list
 * of waypoints while the yaw and pitch targets blend between looking at
 * `initialViewTarget` and looking at `finalViewTarget`.
 *
 * Every 100 ms the plan updates the targets stored in `config`; the control
 * loop started with `startControl(config)` reads them from the same object.
 * Once the last waypoint is selected, the yaw and pitch targets are held at 0.
 *
 * @returns `{stop, config}`; `stop()` halts the control loop and the target updates.
 */
let initializeRelocatePlan = () => {
  const translationNudgeCount = 4;
  const angularNudgeCount = 4;
  const axisConfig = (nudgeCount) => ({target: 0.0, nudgeCount: nudgeCount});
  const config = {
    controlRate: 250,
    attitude: {
      controllerModel: angularControllerModel,
      roll: axisConfig(angularNudgeCount),
      yaw: axisConfig(angularNudgeCount),
      pitch: axisConfig(angularNudgeCount),
    },
    translation: {
      controllerModel: controllerModel,
      x: axisConfig(translationNudgeCount),
      y: axisConfig(translationNudgeCount),
      z: axisConfig(translationNudgeCount),
    },
  };
  const initialViewTarget = {x: -21.0, y: 12.0, z: 0.0};
  const finalViewTarget = {x: 0.0, y: 0.0, z: 0.0};
  // Waypoints in visiting order; the next one is selected once the distance to
  // the current one is at most its `threshold`.
  const targets = [
    {x: 0.0, y: 17.0, z: -5.0, threshold: 10.0},
    {x: -21.0, y: 17.0, z: -5.0, threshold: 1.0},
    {x: -21.0, y: 12.0, z: 0.0, threshold: 0.0},
    {x: -21.0, y: 17.0, z: -5.0, threshold: 1.0},
    {x: 0.0, y: 10.0, z: -10.0, threshold: 5.0},
    {x: 10.0, y: 0.0, z: 0.0, threshold: 0.3},
    {x: 0.0, y: 0.0, z: 0.0, threshold: 0.0},
  ];
  let index = 0;

  // Copies a waypoint into the translation targets and logs the new targets.
  const selectWaypoint = (waypoint) => {
    const {translation} = config;
    translation.x.target = waypoint.x;
    translation.y.target = waypoint.y;
    translation.z.target = waypoint.z;
    console.info(
      'x:', translation.x.target.toFixed(2),
      'y:', translation.y.target.toFixed(2),
      'z:', translation.z.target.toFixed(2),
    );
  };
  selectWaypoint(targets[0]);

  // Distance between the position shown on the page and `reference`.
  const getDistance = (reference) => {
    const offsetX = reference.x - getDistanceValue(xSource);
    const offsetY = reference.y - getDistanceValue(ySource);
    const offsetZ = reference.z - getDistanceValue(zSource);
    return Math.sqrt(offsetX * offsetX + offsetY * offsetY + offsetZ * offsetZ);
  };

  const toDegrees = (radians) => radians * 180.0 / Math.PI;
  // Rounds to one decimal place through exponent-notation strings.
  const roundToTenth = (value) => Number(`${Math.round(`${value}e1`)}e-1`);

  // [yaw, pitch] computed from the unit vector pointing from the current
  // position towards `viewTarget`, each rounded to one decimal place.
  const getViewAttitude = (viewTarget) => {
    const [x, y, z] = getTranslationState();
    const towardsX = viewTarget.x - x;
    const towardsY = viewTarget.y - y;
    const towardsZ = viewTarget.z - z;
    const length = Math.sqrt(towardsX * towardsX + towardsY * towardsY + towardsZ * towardsZ);
    const unitX = towardsX / length;
    const unitY = towardsY / length;
    const unitZ = towardsZ / length;
    return [
      roundToTenth(90.0 + toDegrees(Math.atan2(unitX, -unitY))),
      roundToTenth(toDegrees(Math.atan2(unitZ, Math.sqrt(unitX * unitX + unitY * unitY)))),
    ];
  };

  const update = () => {
    const {attitude, translation} = config;
    if (index + 1 >= targets.length) {
      attitude.yaw.target = 0.0;
      attitude.pitch.target = 0.0;
      return;
    }

    const [initialYaw, initialPitch] = getViewAttitude(initialViewTarget);
    const [finalYaw, finalPitch] = getViewAttitude(finalViewTarget);
    const initialDistance = getDistance(initialViewTarget);
    const finalDistance = getDistance(finalViewTarget);
    let yawTarget;
    let pitchTarget;
    if (initialDistance < 5.0) {
      // Within 5 m of the initial view target: fixed attitude.
      yawTarget = 90.0;
      pitchTarget = 0.0;
    } else if (finalDistance < 5.0) {
      // Within 5 m of the final view target: fixed attitude.
      yawTarget = 0.0;
      pitchTarget = 0.0;
    } else {
      // Blend the two view attitudes; the nearer view target gets the larger weight.
      const lambda = initialDistance / (initialDistance + finalDistance);
      yawTarget = (1 - lambda) * initialYaw + lambda * finalYaw;
      pitchTarget = (1 - lambda) * initialPitch + lambda * finalPitch;
    }
    attitude.yaw.target = yawTarget;
    attitude.pitch.target = pitchTarget;

    // Advance to the next waypoint when close enough to the current one.
    const distance = getDistance({
      x: translation.x.target,
      y: translation.y.target,
      z: translation.z.target,
    });
    if (distance <= targets[index].threshold) {
      index += 1;
      selectWaypoint(targets[index]);
    }
  };

  const interval = setInterval(update, 100);
  const controller = startControl(config);
  return {
    stop: () => {
      controller.stop();
      clearInterval(interval);
    },
    config: config,
  };
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment