Created
November 7, 2019 22:57
-
-
Save chriselion/3714d05255eea2f9132b96a182fbdcaa to your computer and use it in GitHub Desktop.
One-off script used to convert demo protos.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
One-off script used to convert demo protos.
This takes the old stacked vector observation data of each agent step and copies it into an ObservationProto that is appended to that step's observations.
""" | |
import pathlib | |
import logging | |
import os | |
from typing import List, Tuple | |
from mlagents.trainers.buffer import Buffer | |
from mlagents.envs.brain import BrainParameters, BrainInfo | |
from mlagents.envs.communicator_objects.agent_info_pb2 import AgentInfoProto | |
from mlagents.envs.communicator_objects.brain_parameters_pb2 import BrainParametersProto | |
from mlagents.envs.communicator_objects.demonstration_meta_pb2 import ( | |
DemonstrationMetaProto, | |
) | |
from mlagents.envs.communicator_objects.observation_pb2 import ObservationProto, CompressionTypeProto | |
from google.protobuf.internal.decoder import _DecodeVarint32 # type: ignore | |
from google.protobuf.internal.encoder import _EncodeVarint | |
from mlagents.trainers.demo_loader import load_demonstration | |
# The start of the file is reserved for the meta-data message (bytes 0-32); the brain-parameters message begins at this fixed byte offset.
INITIAL_POS = 33 | |
def _collect_demo_paths(file_path: str) -> List[str]:
    """Resolve ``file_path`` to the list of '.demo' files it designates."""
    if os.path.isdir(file_path):
        demo_paths = [
            os.path.join(file_path, _file)
            for _file in os.listdir(file_path)
            if _file.endswith(".demo")
        ]
        # Test the filtered list rather than the raw directory listing, so a
        # directory that only holds other kinds of files is rejected too.
        if not demo_paths:
            raise ValueError("There are no '.demo' files in the provided directory.")
        return demo_paths
    if os.path.isfile(file_path):
        if pathlib.Path(file_path).suffix != ".demo":
            raise ValueError(
                "The file is not a '.demo' file. Please provide a file with the "
                "correct extension."
            )
        return [file_path]
    raise FileNotFoundError(
        "The demonstration file or directory {} does not exist.".format(file_path)
    )


def load_convert_demonstration(file_path: str):
    """
    Loads, parses and converts a demonstration file (or a directory of them).

    Every AgentInfoProto read from the file gets an extra ObservationProto
    built from its ``stacked_vector_observation`` field appended to its
    ``observations``.

    :param file_path: Location of a demonstration file (.demo), or of a
        directory containing '.demo' files.
    :return: A 5-tuple ``(meta_data_proto, brain_param_proto,
        agent_info_protos, brain_params, brain_infos)``. The two lists
        accumulate across all files read; ``meta_data_proto`` and
        ``brain_param_proto`` are the ones parsed from the last file read.
    :raises ValueError: If the directory holds no '.demo' file, or the given
        file does not have the '.demo' extension.
    :raises FileNotFoundError: If ``file_path`` is neither a file nor a
        directory.
    """
    file_paths = _collect_demo_paths(file_path)

    brain_params = None
    brain_param_proto = None
    meta_data_proto = None
    brain_infos = []
    agent_info_protos = []
    total_expected = 0
    for _file_path in file_paths:
        # Read the whole file up front and close the handle immediately.
        with open(_file_path, "rb") as demo_file:
            data = demo_file.read()
        pos, obs_decoded = 0, 0
        while pos < len(data):
            # Each message is preceded by a varint holding its size in bytes;
            # the decoder returns that size and the position just past it.
            msg_size, pos = _DecodeVarint32(data, pos)
            if obs_decoded == 0:
                # First message: demonstration meta-data.
                meta_data_proto = DemonstrationMetaProto()
                meta_data_proto.ParseFromString(data[pos : pos + msg_size])
                total_expected += meta_data_proto.number_steps
                # The next message starts at a fixed offset, not right after
                # the meta-data payload.
                pos = INITIAL_POS
            elif obs_decoded == 1:
                # Second message: brain parameters.
                brain_param_proto = BrainParametersProto()
                brain_param_proto.ParseFromString(data[pos : pos + msg_size])
                pos += msg_size
            else:
                # Every following message is one agent step.
                agent_info = AgentInfoProto()
                agent_info.ParseFromString(data[pos : pos + msg_size])

                # Convert: wrap the stacked vector observation in an
                # ObservationProto and attach it to the agent info.
                obs = ObservationProto(
                    float_data=ObservationProto.FloatData(
                        data=agent_info.stacked_vector_observation
                    ),
                    shape=[len(agent_info.stacked_vector_observation)],
                    compression_type=CompressionTypeProto.NONE,
                )
                agent_info.observations.append(obs)
                agent_info_protos.append(agent_info)

                # Brain parameters are built once, from the first agent step.
                if brain_params is None:
                    brain_params = BrainParameters.from_proto(
                        brain_param_proto, agent_info
                    )
                brain_info = BrainInfo.from_agent_proto(0, [agent_info], brain_params)
                brain_infos.append(brain_info)
                # Stop once the step count announced by the meta-data is met.
                if len(brain_infos) == total_expected:
                    break
                pos += msg_size
            obs_decoded += 1
    return meta_data_proto, brain_param_proto, agent_info_protos, brain_params, brain_infos
def write_delimited(f, message):
    """
    Write ``message`` to ``f`` as a length-delimited record.

    :param f: Writable binary file object.
    :param message: Object exposing ``SerializeToString()``.
    """
    payload = message.SerializeToString()
    # Size prefix first (varint-encoded), then the serialized bytes.
    _EncodeVarint(f.write, len(payload))
    f.write(payload)
def write_demo(demo_path, meta_data_proto, brain_param_proto, agent_info_protos):
    """
    Write a demonstration file to ``demo_path``.

    :param demo_path: Path of the file to create (opened in "wb" mode).
    :param meta_data_proto: Message written at the very start of the file.
    :param brain_param_proto: Message written at offset ``INITIAL_POS``.
    :param agent_info_protos: Iterable of messages written, in order, after
        the brain parameters.
    """
    with open(demo_path, "wb") as out_file:
        # Meta-data goes first.
        write_delimited(out_file, meta_data_proto)
        # Jump to the fixed offset where the remaining messages begin.
        out_file.seek(INITIAL_POS)
        write_delimited(out_file, brain_param_proto)
        for agent_info_proto in agent_info_protos:
            write_delimited(out_file, agent_info_proto)
def convert(demo_file):
    """
    Convert ``demo_file`` and check the converted copy against the original.

    The converted data is written next to the input, with ".demo" in the path
    replaced by ".converted.demo". It is then read back with
    ``load_demonstration`` and compared to the data loaded from the original
    using ``assert`` statements.

    :param demo_file: Path of the demonstration file to convert.
    """
    loaded = load_convert_demonstration(demo_file)
    meta_data_proto, brain_param_proto, agent_info_protos, brain_params, brain_infos = loaded

    converted_path = demo_file.replace(".demo", ".converted.demo")
    write_demo(converted_path, meta_data_proto, brain_param_proto, agent_info_protos)

    # Read the converted file back; the third returned value is not needed.
    brain_params_conv, brain_infos_conv, _ = load_demonstration(converted_path)
    assert brain_params_conv
    assert brain_infos_conv
    assert str(brain_params) == str(brain_params_conv)
    assert len(brain_infos) == len(brain_infos_conv)

    # Step-by-step comparison of original and converted brain infos.
    for original, converted in zip(brain_infos, brain_infos_conv):
        if original.visual_observations or converted.visual_observations:
            assert len(original.visual_observations) == len(converted.visual_observations)
            visual_pairs = zip(original.visual_observations, converted.visual_observations)
            for original_obs, converted_obs in visual_pairs:
                assert len(original_obs) == len(converted_obs)
                for original_item, converted_item in zip(original_obs, converted_obs):
                    assert (original_item == converted_item).all()

        assert (original.vector_observations == converted.vector_observations).all()
        assert original.rewards == converted.rewards
        assert original.local_done == converted.local_done
        assert original.max_reached == converted.max_reached
        assert original.agents == converted.agents
        assert (original.previous_vector_actions == converted.previous_vector_actions).all()
        assert (original.action_masks == converted.action_masks).all()
def rename(demo_file):
    """
    Move the converted copy of ``demo_file`` over ``demo_file`` itself.

    :param demo_file: Path of the original demonstration file; the converted
        copy is looked up at the same path with ".demo" replaced by
        ".converted.demo".
    """
    converted_path = demo_file.replace(".demo", ".converted.demo")
    # os.replace overwrites the destination if it already exists.
    os.replace(converted_path, demo_file)
def main():
    """Convert each known demonstration file, then replace it with the converted copy."""
    demo_files = (
        "demos/Expert3DBall.demo",
        "demos/Expert3DBallHard.demo",
        "demos/ExpertBasic.demo",
        "demos/ExpertBouncer.demo",
        "demos/ExpertCrawlerDyn.demo",
        "demos/ExpertCrawlerSta.demo",
        "demos/ExpertFood.demo",
        "demos/ExpertGrid.demo",
        "demos/ExpertHallway.demo",
        "demos/ExpertPush.demo",
        "demos/ExpertPyramid.demo",
        "demos/ExpertReacher.demo",
        "demos/ExpertTennis.demo",
        "demos/ExpertWalker.demo",
        "ml-agents/mlagents/trainers/tests/test.demo",
        "ml-agents/mlagents/trainers/tests/test_demo_dir/test.demo",
        "ml-agents/mlagents/trainers/tests/test_demo_dir/test2.demo",
        "ml-agents/mlagents/trainers/tests/test_demo_dir/test3.demo",
        "ml-agents/mlagents/trainers/tests/testdcvis.demo",
    )
    for demo_path in demo_files:
        print(f"Converting {demo_path}...")
        # Convert (and verify) first; only then overwrite the original.
        convert(demo_path)
        rename(demo_path)


# Run the conversion only when executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment