Skip to content

Instantly share code, notes, and snippets.

@guru-florida
Created March 7, 2021 05:43
Show Gist options
  • Star 5 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save guru-florida/372d43ae336d34eb601bbfee68678e8d to your computer and use it in GitHub Desktop.
Save guru-florida/372d43ae336d34eb601bbfee68678e8d to your computer and use it in GitHub Desktop.
Model Training URDF parameters in Ros2 with Gazebo
#!/usr/bin/env python3
#
# model-training.py
# Use HyperOpt and Gazebo to train parameters of your robot. Performs a number
# of episodic simulations, optimizing the inputs to a xacro-enabled URDF file.
#
# Copyright 2020 FlyingEinstein.com
# Author: Colin F. MacKenzie
#
# Features:
# * will call your launch file via the Ros2 launch API.
# * will attempt to restart simulation by deleting and respawning robot without
# restarting Gazebo or launch file.
# * Sometimes ROS nodes crash or Gazebo freezes on sim restart. In this case,
# the whole launch file setup will be killed and restarted.
# * Appends each episode's result to the log file training.csv.
# * Writes the current episode's parameter values to current.txt. If you record
# the training with the OBS screen recorder, you can add a text overlay that
# reads this file and OBS will update the on-screen display as episodes play out.
#
# Requirements:
# * This was used for my project and I haven't gotten around to generalizing
# this code yet. So expect you may need to get intimate with how this code
# works.
# * Review imu_callback and odom_callback, which establish a health value,
# or write a node that emits a "health value" and just subscribe to that.
# * Review episode_async where it sets up the episode and determines when it's
# finished.
# * determine what variables in your URDF you will optimize/learn. This can be
# any value in a URDF including Gazebo parameters.
# * Convert your urdf file to xacro if you haven't already. Replace the values
# you want to optimize with xacro variables.
# * Set up the config variable in the run() method. You will want to look at the
# Hyperopt library to see how these configs work. For each episode, Hyperopt
# chooses new values for your variables based on previous health scores, and
# episode_async reparses your xacro URDF with the new set of variables and
# respawns the robot in the simulation.
#
# Example arguments for training URDF parameters for the LSS humanoid model:
# -package lss_humanoid -xacro urdf/lss_humanoid.xacro.urdf -entity humanoid -z 0.3 -episodes 100
#
# Portions of this file were based on spawn_entity.py from
# Open Source Robotics Foundation and John Hsu, Dave Coleman
import argparse
import math
import os
import sys
import time
import asyncio
import threading
import psutil
from typing import List
from typing import Text
from typing import Tuple
from collections import OrderedDict
# Ros2 node imports
import rclpy
from launch.event_handlers import OnProcessIO
from rclpy.node import Node
from rclpy.qos import QoSDurabilityPolicy
from sensor_msgs.msg import Imu
from geometry_msgs.msg import Vector3
from geometry_msgs.msg import Pose
from nav_msgs.msg import Odometry
from gazebo_msgs.srv import SpawnEntity, DeleteEntity
# Ros2 launch API
import launch
import xacro
from ros2launch.api import get_share_file_path_from_package
from ament_index_python.packages import PackageNotFoundError, get_package_share_directory
# Hyperopt Optimization API
from hyperopt import hp, fmin, tpe, STATUS_OK, STATUS_FAIL, Trials
from hyperopt.mongoexp import MongoTrials
class EntityException(RuntimeError):
    """Base error for a Gazebo entity service request (spawn or delete) that went wrong."""


class EntityOperationFailed(EntityException):
    """The entity service answered, but reported that the operation did not succeed."""


class EntityTimeout(EntityException):
    """The entity service was unavailable or did not answer within the timeout."""
class TrainModelNode(Node):
    """ROS 2 node that tunes xacro/URDF parameters with Hyperopt in Gazebo.

    Every Hyperopt evaluation ("episode") renders the xacro file with the
    sampled parameters, spawns the robot in Gazebo, watches the IMU and
    odometry topics until an end condition is reached, appends the result to
    training.csv and deletes the robot again. When spawning fails, the launch
    service dies, or the IMU stops reporting, the launch file (and Gazebo with
    it) is shut down and relaunched for the next episode.

    Hyperopt runs on a worker thread; all ROS and Gazebo interaction happens
    on the asyncio event loop created by ``run``.
    """

    simulation_task = None      # asyncio task running the launch service
    event_loop = None           # loop created by run(); episodes execute on it
    launch_service = None       # launch.LaunchService running the simulation
    optimize_thread = None      # worker thread running the Hyperopt search
    acc = None                  # low-pass filtered IMU linear acceleration (Vector3)
    acc_mix = 0.01              # weight of a new sample in the moving averages
    attempts = 0
    active = False              # True while an episode is being evaluated
    fallen = False              # fall flag computed by imu_callback
    distance = 0                # odometry distance from the origin, rounded to 0.01
    direction = 0               # atan2 heading of that displacement, rounded to 0.01
    avgAngularVelocity = None   # moving average of the IMU angular speed
    startTs = None              # stamp of the first IMU message of the episode
    currentTs = None            # stamp of the latest IMU message
    package_dir = None          # share directory of the -package argument
    xacro_urdf = None           # open file handle of the xacro model
    tasks = {}

    # Wall-clock seconds an episode may go without a new IMU stamp before it is
    # aborted and the simulation restarted (e.g. Gazebo froze or a node crashed).
    sensor_timeout = 60.0

    def __init__(self, args):
        """Parse the command line, open the xacro model and subscribe to the sensors.

        :param args: argv-style list with ROS arguments already removed;
            ``args[0]`` (the program name) is skipped.
        Exits the process with status -2 if the package or xacro file is missing.
        """
        super().__init__('train_model')
        self.args = self._build_argument_parser().parse_args(args[1:])
        self.acc = Vector3()

        # get the share location for the package containing the model we will be training
        try:
            self.package_dir = get_package_share_directory(self.args.package)
        except PackageNotFoundError:
            self.get_logger().error(f'cannot find share folder for package {self.args.package}')
            sys.exit(-2)
        xacro_urdf_file = os.path.join(self.package_dir, self.args.xacro)
        self.get_logger().info(f'using xacro urdf file at {xacro_urdf_file}')

        # The model must be a xacro file: the trained parameters are substituted
        # during xacro processing. The handle stays open and is re-read every episode.
        try:
            self.xacro_urdf = open(xacro_urdf_file)
        except OSError:
            self.get_logger().error(f'cannot open xacro file: {xacro_urdf_file}')
            sys.exit(-2)

        sensor_qos = rclpy.qos.QoSPresetProfiles.get_from_short_key('sensor_data')
        # IMU data is used to detect a fall and to average the angular speed
        self.imu_data = self.create_subscription(
            Imu, 'imu/data', self.imu_callback, sensor_qos)
        # odometry gives the distance travelled during the episode
        self.odom_data = self.create_subscription(
            Odometry, 'odom', self.odom_callback, sensor_qos)

    @staticmethod
    def _build_argument_parser():
        """Return the argparse parser for this node's command-line options."""
        parser = argparse.ArgumentParser(
            description='Spawn an entity in gazebo. Gazebo must be started with gazebo_ros_init, '
                        'gazebo_ros_factory and gazebo_ros_state for all functionalities to work')
        parser.add_argument('-package', required=True, type=str, metavar='PKG_NAME',
                            help='The package containing the model we will train')
        parser.add_argument('-xacro', required=True, type=str, metavar='FILE_NAME',
                            help='The xacro file to substitute with parameters')
        parser.add_argument('-entity', required=True, type=str, metavar='ENTITY_NAME',
                            help='Name of entity to spawn')
        parser.add_argument('-reference_frame', type=str, default='',
                            help='Name of the model/body where initial pose is defined. '
                                 'If left empty or specified as "world", gazebo world frame is used')
        parser.add_argument('-gazebo_namespace', type=str, default='',
                            help='ROS namespace of gazebo offered ROS interfaces. '
                                 'Default is without any namespace')
        parser.add_argument('-robot_namespace', type=str, default='',
                            help='change ROS namespace of gazebo-plugins')
        parser.add_argument('-timeout', type=float, default=30.0,
                            help='Number of seconds to wait for the spawn and delete services to '
                                 'become available')
        parser.add_argument('-wait', type=str, metavar='ENTITY_NAME',
                            help='Wait for entity to exist')
        parser.add_argument('-spawn_service_timeout', type=float, metavar='TIMEOUT',
                            default=15.0, help='Spawn service wait timeout in seconds')
        parser.add_argument('-episodes', type=int, default=100,
                            help='max number of training iterations')
        parser.add_argument('-mongodb', type=str,
                            help='store trials in a Mongo database')
        parser.add_argument('-expid', type=int,
                            help='optional experiment ID (use with mongo and hyperopt)')
        parser.add_argument('-x', type=float, default=0,
                            help='x component of initial position, meters')
        parser.add_argument('-y', type=float, default=0,
                            help='y component of initial position, meters')
        parser.add_argument('-z', type=float, default=0,
                            help='z component of initial position, meters')
        parser.add_argument('-R', type=float, default=0,
                            help='roll angle of initial orientation, radians')
        parser.add_argument('-P', type=float, default=0,
                            help='pitch angle of initial orientation, radians')
        parser.add_argument('-Y', type=float, default=0,
                            help='yaw angle of initial orientation, radians')
        return parser

    def imu_callback(self, msg):
        """Update the episode clock, the fall flag and the angular-speed average."""
        self.currentTs = msg.header.stamp
        if self.startTs is None:
            self.startTs = self.currentTs

        # exponential moving average of the linear acceleration
        prev_mix = 1 - self.acc_mix
        self.acc.x = self.acc.x * prev_mix + msg.linear_acceleration.x * self.acc_mix
        self.acc.y = self.acc.y * prev_mix + msg.linear_acceleration.y * self.acc_mix
        self.acc.z = self.acc.z * prev_mix + msg.linear_acceleration.z * self.acc_mix
        # NOTE(review): the author's fall criterion; the band is specific to this
        # robot and IMU mounting - confirm before reusing on another model.
        self.fallen = 9 < abs(self.acc.z) < 9.85

        angVel = math.sqrt(
            msg.angular_velocity.x * msg.angular_velocity.x +
            msg.angular_velocity.y * msg.angular_velocity.y +
            msg.angular_velocity.z * msg.angular_velocity.z
        )
        # seed the average with the first sample; a 0.0 average is a valid value
        if self.avgAngularVelocity is None:
            self.avgAngularVelocity = angVel
        else:
            self.avgAngularVelocity = self.avgAngularVelocity * prev_mix + angVel * self.acc_mix

    def odom_callback(self, msg):
        """Record the planar distance and heading from the odometry origin.

        Displacement is reported as 0 until IMU stamps exist and more than 2 s
        (IMU time) have elapsed since the episode's first IMU message.
        """
        x = msg.pose.pose.position.x
        y = msg.pose.pose.position.y
        started = self.currentTs is not None and self.startTs is not None
        if started and self.currentTs.sec - self.startTs.sec > 2:
            self.distance = round(math.sqrt(x * x + y * y), 2)
            self.direction = round(math.atan2(y, x), 2)
        else:
            self.distance = 0
            self.direction = 0

    async def spawn_entity(self, entity_xml, initial_pose, timeout=10.0):
        """Spawn the trained entity (``-entity``) in Gazebo.

        Adapted from gazebo_ros_pkgs/gazebo_ros/scripts/spawn_entity.py for asyncio.

        :param entity_xml: UTF-8 encoded bytes of the URDF/SDF model.
        :param initial_pose: geometry_msgs Pose to spawn at.
        :param timeout: seconds to wait for the service, and again for the call.
        :returns: False if ``timeout`` is negative, otherwise None on success.
        :raises EntityTimeout: service unavailable or the call timed out.
        :raises EntityOperationFailed: Gazebo reported that spawning failed.
        """
        if timeout < 0:
            self.get_logger().error('spawn_entity timeout must be greater than zero')
            return False
        service = '%s/spawn_entity' % self.args.gazebo_namespace
        self.get_logger().debug('Waiting for service %s, timeout = %.f' % (service, timeout))
        client = self.create_client(SpawnEntity, service)
        try:
            # NOTE: wait_for_service blocks the event loop for up to `timeout` seconds
            if not client.wait_for_service(timeout_sec=timeout):
                self.get_logger().error(
                    'Service %s unavailable. Was Gazebo started with GazeboRosFactory?' % service)
                raise EntityTimeout('spawn_entity service')
            req = SpawnEntity.Request()
            req.name = self.args.entity
            req.xml = str(entity_xml, 'utf-8')
            req.robot_namespace = self.args.robot_namespace
            req.initial_pose = initial_pose
            req.reference_frame = self.args.reference_frame
            self.get_logger().debug('Calling service %s' % service)
            try:
                response = await asyncio.wait_for(client.call_async(req), timeout=timeout)
            except asyncio.TimeoutError:
                raise EntityTimeout('spawn ' + req.name + ' timeout') from None
            if not response.success:
                raise EntityOperationFailed('spawn ' + req.name)
        finally:
            # a client is created per call; release it so they don't accumulate
            self.destroy_client(client)

    async def delete_entity(self, timeout=10.0, ignore_failure=False):
        """Delete the trained entity (``-entity``) from Gazebo.

        Adapted from gazebo_ros_pkgs/gazebo_ros/scripts/spawn_entity.py for asyncio.

        :param timeout: seconds to wait for the service, and again for the call.
        :param ignore_failure: when True, an unavailable service or a reported
            failure is not raised (a call that times out still raises).
        :raises EntityTimeout: service unavailable (unless ignore_failure) or call timed out.
        :raises EntityOperationFailed: Gazebo reported failure (unless ignore_failure).
        """
        service = '%s/delete_entity' % self.args.gazebo_namespace
        self.get_logger().debug('Deleting entity [{}]'.format(self.args.entity))
        client = self.create_client(DeleteEntity, service)
        try:
            if not client.wait_for_service(timeout_sec=timeout):
                self.get_logger().error(
                    'Service %s unavailable. Was Gazebo started with GazeboRosFactory?' % service)
                if not ignore_failure:
                    raise EntityTimeout('delete_entity service')
                return
            req = DeleteEntity.Request()
            req.name = self.args.entity
            self.get_logger().debug('Calling service %s' % service)
            try:
                response = await asyncio.wait_for(client.call_async(req), timeout=timeout)
            except asyncio.TimeoutError:
                raise EntityTimeout('delete ' + req.name + ' timeout') from None
            if not response.success and not ignore_failure:
                raise EntityOperationFailed('delete ' + req.name)
        finally:
            self.destroy_client(client)

    @staticmethod
    def kill_gzserver():
        """Kill every running ``gzserver`` process.

        A terminated launch often leaves gzserver running, so this is called
        before the simulation is (re)launched. Processes that exit while being
        inspected are skipped.
        """
        for proc in psutil.process_iter(['name']):
            if proc.info['name'] == 'gzserver':
                try:
                    proc.kill()
                except psutil.NoSuchProcess:
                    pass  # it exited on its own in the meantime

    @staticmethod
    def parse_launch_arguments(launch_arguments: List[Text]) -> List[Tuple[Text, Text]]:
        """Parse '<name>:=<value>' launch arguments into (name, value) tuples.

        Borrowed from launch_service.py in the ROS 2 launch API. A later
        duplicate name overrides an earlier one.

        :raises RuntimeError: an argument is not of the form '<name>:=<value>'.
        """
        parsed_launch_arguments = OrderedDict()  # type: ignore
        for argument in launch_arguments:
            count = argument.count(':=')
            if count == 0 or argument.startswith(':=') or (count == 1 and argument.endswith(':=')):
                raise RuntimeError(
                    "malformed launch argument '{}', expected format '<name>:=<value>'"
                    .format(argument))
            name, value = argument.split(':=', maxsplit=1)
            parsed_launch_arguments[name] = value  # last one wins is intentional
        return list(parsed_launch_arguments.items())

    def launch_a_launch_file(self, *, launch_file_path, launch_file_arguments, debug=False):
        """Launch a given launch file (by path) and pass it the given launch file arguments.

        Borrowed from launch_service.py in the ROS 2 launch API. The returned
        LaunchService is not started yet; the caller runs ``run_async()``.
        Output of the launched processes is printed to stdout, stdout lines
        wrapped in '>>>...<<<' and stderr lines in '***...***'.
        """
        # TODO: to select a gzserver instance, copy os.environ, modify it and
        # pass env=... to the launched Node actions.
        launch_service = launch.LaunchService(argv=launch_file_arguments, debug=debug)
        parsed_launch_arguments = self.parse_launch_arguments(launch_file_arguments)
        # Include the user provided launch file using IncludeLaunchDescription so that the
        # location of the current launch file is set.
        launch_description = launch.LaunchDescription([
            launch.actions.IncludeLaunchDescription(
                launch.launch_description_sources.AnyLaunchDescriptionSource(launch_file_path),
                launch_arguments=parsed_launch_arguments,
            ),
            launch.actions.RegisterEventHandler(
                OnProcessIO(
                    on_stdout=lambda info: print('>>>' + str(info.text) + '<<<'),
                    on_stderr=lambda info: print('***' + str(info.text) + '***'),
                )
            ),
        ])
        launch_service.include_launch_description(launch_description)
        return launch_service

    async def launch_simulation(self, timeout=10.0):
        """Start simulation2.launch.py from ``-package`` as a background task.

        Any leftover gzserver is killed first. Returns the task running the
        launch service (it is not awaited here).
        """
        path = get_share_file_path_from_package(
            package_name=self.args.package,
            file_name='simulation2.launch.py')
        self.launch_service = self.launch_a_launch_file(
            launch_file_path=path,
            launch_file_arguments=["__log_level:=error"],
            debug=False
        )
        # ensure gzserver isn't running
        self.kill_gzserver()
        self.simulation_task = asyncio.create_task(self.launch_service.run_async(
            shutdown_when_idle=True
        ))
        return self.simulation_task

    async def kill_simulation(self, timeout=10.0, msg: str = None):
        """Shut down the running launch service (and Gazebo with it).

        :param timeout: seconds to wait for the launch service task to finish.
        :param msg: optional reason, printed with the progress messages.
        """
        task = self.simulation_task
        if not task:
            return
        print('cancelling simulation\n')
        if msg:
            print(msg)
        # LaunchService.shutdown() may return a coroutine or None depending on the
        # calling thread and on whether it is already shutting down; only await
        # it when there is something to await.
        pending = self.launch_service.shutdown()
        if asyncio.iscoroutine(pending):
            await pending
        if not task.done():
            # bounded wait for the launched processes to stop
            await asyncio.wait({task}, timeout=timeout)
        self.simulation_task = None
        print('canceled simulation\n')
        await asyncio.sleep(2.0)

    def episode(self, config):
        """Hyperopt objective: run one episode on the event loop and block for its result.

        Called on the optimizer thread; the work itself runs on ``self.event_loop``.
        """
        future = asyncio.run_coroutine_threadsafe(self.episode_async(config), self.event_loop)
        return future.result()

    async def episode_async(self, config):
        """Evaluate one parameter set in simulation and return a Hyperopt result dict.

        The dict has ``status`` STATUS_OK plus ``loss`` and episode statistics,
        or only ``status`` STATUS_FAIL when the simulation could not be started,
        the robot could not be spawned, or the simulation stalled.
        """
        self._reset_episode_state()
        self.write_current_config(config)
        entity_xml = self._render_entity_xml(config)
        initial_pose = self._initial_pose()

        if self.simulation_task and self.simulation_task.done():
            print("simulation is apparently DONE")
        if not self.simulation_task or self.simulation_task.done():
            # (re)start the simulation
            print("starting simulation")
            try:
                await self.launch_simulation()
            except Exception as e:
                print("exception launching simulation: ", e)
                return {'status': STATUS_FAIL}
            # give Gazebo a head start before calling its spawn service
            await asyncio.sleep(5.0)

        try:
            await self.spawn_entity(entity_xml, initial_pose, 30)
        except EntityException as e:
            self.get_logger().error('Spawn service failed: %s' % e)
            await self.kill_simulation()
            return {'status': STATUS_FAIL}

        return await self._run_episode(config)

    def _reset_episode_state(self):
        """Clear all per-episode sensor state before a new robot is spawned."""
        self.currentTs = self.startTs = None
        self.avgAngularVelocity = None
        self.fallen = False
        self.active = True
        self.distance = 0
        self.direction = 0
        self.acc = Vector3()

    def _render_entity_xml(self, config):
        """Process the xacro model with ``config`` as mappings; return UTF-8 XML bytes."""
        mappings = {k: str(v) for (k, v) in config.items()}
        self.xacro_urdf.seek(0, 0)
        doc = xacro.parse(self.xacro_urdf)
        xacro.process_doc(doc, mappings=mappings)
        return doc.toxml('utf-8')

    def _initial_pose(self):
        """Build the spawn Pose from the -x/-y/-z/-R/-P/-Y arguments."""
        pose = Pose()
        pose.position.x = float(self.args.x)
        pose.position.y = float(self.args.y)
        pose.position.z = float(self.args.z)
        # quaternion_from_euler returns [w, x, y, z]
        (pose.orientation.w, pose.orientation.x,
         pose.orientation.y, pose.orientation.z) = quaternion_from_euler(
            self.args.R, self.args.P, self.args.Y)
        return pose

    async def _run_episode(self, config):
        """Poll the sensor state until an end condition is hit or the simulation stalls.

        The episode ends when the robot has fallen, travelled more than 1 m, or
        more than 30 s of IMU time have passed. It is aborted when the launch
        service stops or no new IMU stamp arrives for ``sensor_timeout`` seconds.
        """
        last_stamp = None
        last_progress = time.monotonic()
        while True:
            now = time.monotonic()
            stamp = (self.currentTs.sec, self.currentTs.nanosec) if self.currentTs else None
            if stamp != last_stamp:
                last_stamp, last_progress = stamp, now
            elif now - last_progress > self.sensor_timeout:
                return await self._abort_episode('no new IMU data, simulation stalled')
            if self.simulation_task is None or self.simulation_task.done():
                return await self._abort_episode('simulation stopped during the episode')

            duration = self.currentTs.sec - self.startTs.sec if self.currentTs else 0
            if self.currentTs and (self.fallen or self.distance > 1 or duration > 30):
                return await self._finish_episode(config)
            await asyncio.sleep(0.05)

    async def _abort_episode(self, reason):
        """Log ``reason``, tear the simulation down and return a failed result."""
        self.get_logger().error(f'episode aborted: {reason}')
        self.active = False
        await self.kill_simulation(msg=reason)
        return {'status': STATUS_FAIL}

    async def _finish_episode(self, config):
        """Score the episode, log it, delete the robot and build the Hyperopt result."""
        score = self.distance * self.avgAngularVelocity
        if self.fallen:
            score = score * 10
        # snapshot now: callbacks keep updating the live state while we await below
        result = {
            'startTs': self.startTs.sec,
            'currentTs': self.currentTs.sec,
            'duration': self.currentTs.sec - self.startTs.sec,
            'distance': self.distance,
            'direction': self.direction,
            'angVelocity': self.avgAngularVelocity,
            'fell': self.fallen,
            'score': score,
        }
        self.write_episode_result(config, result)

        # remove the entity (a failure here surfaces as a spawn failure next episode)
        try:
            await self.delete_entity()
            await asyncio.sleep(1.0)
        except EntityException as e:
            self.get_logger().error('Delete entity failed: %s' % e)
        self.active = False

        extra = {k: v for (k, v) in result.items() if k != 'score'}
        return {'loss': score, 'status': STATUS_OK, **extra}

    @staticmethod
    def write_current_config(config):
        """Write the trained values (mu1, mu2, kp, kd) to current.txt, one per line.

        Each line is "name: value" with two decimals; other config keys are skipped.
        """
        columns = ['mu1', 'mu2', 'kp', 'kd']
        values = "\n".join(
            k + ": " + "{:.2f}".format(v) for (k, v) in config.items() if k in columns)
        with open("current.txt", "w") as cfile:
            cfile.write(values)

    @staticmethod
    def write_episode_result(config, result):
        """Append one row of episode statistics and parameters to training.csv.

        A header row is written first when the file does not exist yet.
        :raises KeyError: a column is missing from both ``config`` and ``result``.
        """
        training_file = "training.csv"
        columns = ['startTs', 'duration', 'distance', 'direction', 'angVelocity', 'fell',
                   'score', 'mu1', 'mu2', 'kp', 'kd']
        combined = {**config, **result}
        values = [str(combined[k]) for k in columns]
        write_header = not os.path.exists(training_file)
        with open(training_file, "a") as cfile:
            if write_header:
                cfile.write(", ".join(columns) + "\n")
            cfile.write(", ".join(values) + "\n")

    def optimize(self, config):
        """Run Hyperopt's TPE search over ``config`` and print the best parameters.

        Runs on the optimizer thread; each evaluation calls ``episode``.
        Hyperopt minimizes the returned ``loss`` (the episode score).
        """
        self.get_logger().info(f'Training {self.args.entity} over {self.args.episodes} episodes')
        if self.args.mongodb:
            trials = MongoTrials(self.args.mongodb, exp_key=self.args.expid)
        else:
            trials = Trials()
        best = fmin(
            self.episode,
            space=config,
            algo=tpe.suggest,
            max_evals=self.args.episodes,
            trials=trials
        )
        print(best)

    def run(self):
        """Train: start the optimizer thread and drive ROS and asyncio until it ends.

        Returns once the optimizer finishes, ROS is shut down, or Ctrl-C is
        pressed. On the way out the simulation is shut down and the loop closed.
        """
        # Hyperopt search space for the parameters being trained; every key is
        # handed to xacro as a mapping when the model is rendered.
        config = {
            'target': 'gazebo',
            'kp': hp.uniform('kp', 1000, 40000),
            'kd': hp.uniform('kd', 1, 1000),
            'mu1': hp.uniform('mu1', 0.01, 1000),
            'mu2': hp.uniform('mu2', 0.01, 2000)
        }

        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        self.event_loop = loop
        optimizer_done = loop.create_future()

        def optimize_then_signal():
            # Hyperopt is not asyncio aware, so it runs on its own thread; when it
            # ends (normally or not) wake the loop so run() can return.
            try:
                self.optimize(config)
            finally:
                try:
                    loop.call_soon_threadsafe(
                        lambda: optimizer_done.done() or optimizer_done.set_result(None))
                except RuntimeError:
                    pass  # loop already closed: run() has finished anyway

        # daemon: a thread stuck waiting on an episode must not keep the process alive
        self.optimize_thread = threading.Thread(target=optimize_then_signal, daemon=True)

        async def rosloop():
            # process ROS callbacks for as long as ROS is up
            try:
                while rclpy.ok():
                    rclpy.spin_once(self, timeout_sec=0)
                    await asyncio.sleep(0.01)
            except asyncio.CancelledError:
                raise
            except Exception as e:
                print('ROS spin loop failed: ', e)

        try:
            ros_task = loop.create_task(rosloop())
            # start the optimizer once the loop is running
            loop.call_soon(self.optimize_thread.start)
            loop.run_until_complete(asyncio.wait(
                {ros_task, optimizer_done}, return_when=asyncio.FIRST_COMPLETED))
        except KeyboardInterrupt:
            pass
        finally:
            self._close_event_loop()

    def _close_event_loop(self):
        """Shut the simulation down, cancel leftover tasks, join the optimizer, close the loop."""
        loop = self.event_loop
        try:
            if self.simulation_task is not None:
                # don't leave Gazebo and the launched nodes running
                loop.run_until_complete(self.kill_simulation())
        except Exception as e:
            print('failed to shut down simulation: ', e)
        pending = asyncio.all_tasks(loop)
        for task in pending:
            task.cancel()
        if pending:
            loop.run_until_complete(asyncio.wait(pending, timeout=10.0))
        if self.optimize_thread is not None and self.optimize_thread.is_alive():
            self.optimize_thread.join(timeout=12.0)
        loop.close()
        self.event_loop = None
# borrowed from gazebo_ros_pkgs/gazebo_ros/scripts/spawn_entity.py
def quaternion_from_euler(roll, pitch, yaw):
    """Convert roll/pitch/yaw angles (radians) to a quaternion.

    Returns a new 4-element list ordered [w, x, y, z].
    """
    c_yaw, s_yaw = math.cos(yaw * 0.5), math.sin(yaw * 0.5)
    c_pitch, s_pitch = math.cos(pitch * 0.5), math.sin(pitch * 0.5)
    c_roll, s_roll = math.cos(roll * 0.5), math.sin(roll * 0.5)

    w = c_yaw * c_pitch * c_roll + s_yaw * s_pitch * s_roll
    x = c_yaw * c_pitch * s_roll - s_yaw * s_pitch * c_roll
    y = s_yaw * c_pitch * s_roll + c_yaw * s_pitch * c_roll
    z = s_yaw * c_pitch * c_roll - c_yaw * s_pitch * s_roll
    return [w, x, y, z]
def main(args=None):
    """Entry point: initialise ROS 2, build the training node and run it to completion.

    :param args: argv-style list including ROS arguments; defaults to sys.argv.
    """
    if args is None:
        args = sys.argv
    rclpy.init(args=args)
    try:
        args_without_ros = rclpy.utilities.remove_ros_args(args)
        train_model_node = TrainModelNode(args_without_ros)
        try:
            train_model_node.run()
        finally:
            train_model_node.destroy_node()
    finally:
        # ROS may already be shut down (e.g. by the optimizer thread or a
        # signal); shutting it down a second time raises.
        if rclpy.ok():
            rclpy.shutdown()


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment