Skip to content

Instantly share code, notes, and snippets.

@chriselion
Created November 7, 2019 22:57
Show Gist options
  • Save chriselion/3714d05255eea2f9132b96a182fbdcaa to your computer and use it in GitHub Desktop.
Save chriselion/3714d05255eea2f9132b96a182fbdcaa to your computer and use it in GitHub Desktop.
One-off script used to convert demo protos.
"""
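One-off script used to convert demo protos.

For every agent step stored in a demo file, this copies the old stacked vector
observation data into a new Observation proto that is appended to the step's
observations, then writes the converted protos back out.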
"""
import pathlib
import logging
import os
from typing import List, Tuple
from mlagents.trainers.buffer import Buffer
from mlagents.envs.brain import BrainParameters, BrainInfo
from mlagents.envs.communicator_objects.agent_info_pb2 import AgentInfoProto
from mlagents.envs.communicator_objects.brain_parameters_pb2 import BrainParametersProto
from mlagents.envs.communicator_objects.demonstration_meta_pb2 import (
DemonstrationMetaProto,
)
from mlagents.envs.communicator_objects.observation_pb2 import ObservationProto, CompressionTypeProto
from google.protobuf.internal.decoder import _DecodeVarint32 # type: ignore
from google.protobuf.internal.encoder import _EncodeVarint
from mlagents.trainers.demo_loader import load_demonstration
# Byte offset at which the brain-parameters record starts. The region before
# it is reserved for the metadata record: write_demo() seeks to this offset
# after writing the metadata, and load_convert_demonstration() jumps to it
# after parsing the metadata.
# NOTE(review): the original comment read "First 32 bytes of file dedicated to
# meta-data", but an offset of 33 leaves bytes 0-32 (33 bytes) in front of the
# brain parameters -- confirm which figure is intended.
INITIAL_POS = 33
def load_convert_demonstration(file_path: str):
    """
    Loads demonstration data and converts every agent step to the new format.

    Each AgentInfoProto read from the file gets a new ObservationProto appended
    to its ``observations``; the new proto carries the contents of the step's
    ``stacked_vector_observation`` as uncompressed float data.

    :param file_path: Location of a demonstration file (.demo), or of a
        directory whose ``.demo`` files are all loaded.
    :return: Tuple ``(meta_data_proto, brain_param_proto, agent_info_protos,
        brain_params, brain_infos)``. The two lists accumulate across all files
        read; ``meta_data_proto`` and ``brain_param_proto`` are the ones parsed
        from the last file; ``brain_params`` is built once, from the first
        agent step encountered.
    :raises ValueError: If a directory contains no ``.demo`` files, or a single
        file does not have the ``.demo`` extension.
    :raises FileNotFoundError: If ``file_path`` is neither a file nor a
        directory.
    """
    file_paths = []
    if os.path.isdir(file_path):
        file_paths = [
            os.path.join(file_path, _file)
            for _file in os.listdir(file_path)
            if _file.endswith(".demo")
        ]
        # Check the filtered list rather than the raw listing, so a directory
        # that only holds other kinds of files is rejected as well.
        if not file_paths:
            raise ValueError("There are no '.demo' files in the provided directory.")
    elif os.path.isfile(file_path):
        if pathlib.Path(file_path).suffix != ".demo":
            raise ValueError(
                "The file is not a '.demo' file. Please provide a file with the "
                "correct extension."
            )
        file_paths.append(file_path)
    else:
        raise FileNotFoundError(
            "The demonstration file or directory {} does not exist.".format(file_path)
        )

    brain_params = None
    brain_param_proto = None
    meta_data_proto = None
    brain_infos = []
    agent_info_protos = []
    total_expected = 0
    for _file_path in file_paths:
        # Read the whole file up front and close the handle immediately.
        with open(_file_path, "rb") as demo_file:
            data = demo_file.read()
        next_pos, pos, obs_decoded = 0, 0, 0
        while pos < len(data):
            # Each record is prefixed by a varint holding its byte length;
            # next_pos is that length and pos the offset of the payload.
            next_pos, pos = _DecodeVarint32(data, pos)
            if obs_decoded == 0:
                # First record: demonstration metadata.
                meta_data_proto = DemonstrationMetaProto()
                meta_data_proto.ParseFromString(data[pos : pos + next_pos])
                total_expected += meta_data_proto.number_steps
                # The next record starts at a fixed offset, not right after
                # the metadata payload.
                pos = INITIAL_POS
            if obs_decoded == 1:
                # Second record: brain parameters.
                brain_param_proto = BrainParametersProto()
                brain_param_proto.ParseFromString(data[pos : pos + next_pos])
                pos += next_pos
            if obs_decoded > 1:
                # Every following record is one agent step.
                agent_info = AgentInfoProto()
                agent_info.ParseFromString(data[pos : pos + next_pos])
                # Convert: mirror the stacked vector observation into a new
                # ObservationProto on the same agent step.
                obs = ObservationProto(
                    float_data=ObservationProto.FloatData(
                        data=agent_info.stacked_vector_observation
                    ),
                    shape=[len(agent_info.stacked_vector_observation)],
                    compression_type=CompressionTypeProto.NONE,
                )
                agent_info.observations.append(obs)
                agent_info_protos.append(agent_info)
                if brain_params is None:
                    brain_params = BrainParameters.from_proto(
                        brain_param_proto, agent_info
                    )
                brain_info = BrainInfo.from_agent_proto(0, [agent_info], brain_params)
                brain_infos.append(brain_info)
                # Stop once the number of steps announced by the metadata
                # has been read.
                if len(brain_infos) == total_expected:
                    break
                pos += next_pos
            obs_decoded += 1
    return (
        meta_data_proto,
        brain_param_proto,
        agent_info_protos,
        brain_params,
        brain_infos,
    )
def write_delimited(f, message):
    """
    Writes one length-prefixed message to a binary stream.

    The serialized size is emitted first as a varint, followed by the
    serialized message bytes.

    :param f: Writable binary file object.
    :param message: Object exposing ``SerializeToString()``.
    """
    payload = message.SerializeToString()
    _EncodeVarint(f.write, len(payload))
    f.write(payload)
def write_demo(demo_path, meta_data_proto, brain_param_proto, agent_info_protos):
    """
    Writes a demonstration file made of length-delimited records.

    The metadata record goes at the start of the file; the stream is then
    moved to INITIAL_POS, where the brain parameters and every agent info
    record are written back to back.

    :param demo_path: Path of the file to create (overwritten if present).
    :param meta_data_proto: Metadata message written first.
    :param brain_param_proto: Brain parameters message written at INITIAL_POS.
    :param agent_info_protos: Iterable of agent info messages, written in order.
    """
    with open(demo_path, "wb") as out:
        # Metadata occupies the head of the file.
        write_delimited(out, meta_data_proto)
        out.seek(INITIAL_POS)
        # Brain parameters first, then each agent step, all in the same framing.
        for record in (brain_param_proto, *agent_info_protos):
            write_delimited(out, record)
def convert(demo_file):
    """
    Converts one demo file and checks that the result reloads identically.

    The converted protos are written next to the input, with ".demo" in the
    path replaced by ".converted.demo". The new file is then read back through
    ``load_demonstration`` and compared against the data produced during
    conversion; any mismatch trips an assertion.

    :param demo_file: Path of the demonstration file to convert.
    """
    loaded = load_convert_demonstration(demo_file)
    meta, brain_proto, agent_protos, original_params, original_infos = loaded

    converted_path = demo_file.replace(".demo", ".converted.demo")
    write_demo(converted_path, meta, brain_proto, agent_protos)

    reloaded_params, reloaded_infos, _ = load_demonstration(converted_path)
    assert reloaded_params
    assert reloaded_infos

    assert str(original_params) == str(reloaded_params)
    assert len(original_infos) == len(reloaded_infos)

    for before, after in zip(original_infos, reloaded_infos):
        # Visual observations are only compared when at least one side has any.
        if before.visual_observations or after.visual_observations:
            assert len(before.visual_observations) == len(after.visual_observations)
            visual_pairs = zip(before.visual_observations, after.visual_observations)
            for vis_before, vis_after in visual_pairs:
                assert len(vis_before) == len(vis_after)
                for item_before, item_after in zip(vis_before, vis_after):
                    assert (item_before == item_after).all()

        assert (before.vector_observations == after.vector_observations).all()
        assert before.rewards == after.rewards
        assert before.local_done == after.local_done
        assert before.max_reached == after.max_reached
        assert before.agents == after.agents
        assert (before.previous_vector_actions == after.previous_vector_actions).all()
        assert (before.action_masks == after.action_masks).all()
def rename(demo_file):
    """
    Moves the converted copy of a demo file over the original.

    The converted copy is expected at the path obtained by replacing ".demo"
    with ".converted.demo" in ``demo_file``.

    :param demo_file: Path of the original demonstration file.
    """
    converted_path = demo_file.replace(".demo", ".converted.demo")
    # os.replace overwrites the destination if it already exists.
    os.replace(converted_path, demo_file)
def main():
    """
    Converts every demo file in the fixed list below, in place.

    Each file is converted (and verified) first, then the converted copy is
    moved over the original.
    """
    targets = (
        "demos/Expert3DBall.demo",
        "demos/Expert3DBallHard.demo",
        "demos/ExpertBasic.demo",
        "demos/ExpertBouncer.demo",
        "demos/ExpertCrawlerDyn.demo",
        "demos/ExpertCrawlerSta.demo",
        "demos/ExpertFood.demo",
        "demos/ExpertGrid.demo",
        "demos/ExpertHallway.demo",
        "demos/ExpertPush.demo",
        "demos/ExpertPyramid.demo",
        "demos/ExpertReacher.demo",
        "demos/ExpertTennis.demo",
        "demos/ExpertWalker.demo",
        "ml-agents/mlagents/trainers/tests/test.demo",
        "ml-agents/mlagents/trainers/tests/test_demo_dir/test.demo",
        "ml-agents/mlagents/trainers/tests/test_demo_dir/test2.demo",
        "ml-agents/mlagents/trainers/tests/test_demo_dir/test3.demo",
        "ml-agents/mlagents/trainers/tests/testdcvis.demo",
    )
    for demo_file in targets:
        print(f"Converting {demo_file}...")
        convert(demo_file)
        rename(demo_file)


# Run the batch conversion only when this file is executed as a script.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment