Created
October 28, 2024 15:24
-
-
Save HiroIshida/5d0abbec0fdd0d140c5262eb1301c98e to your computer and use it in GitHub Desktop.
lerobot pr
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import yaml | |
import torch | |
from hashlib import sha256 | |
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy, DiffusionConfig | |
def run(camera_names, rgb_encoder_per_camera):
    """Build a DiffusionPolicy for the given cameras and hash one predicted action.

    The policy is configured with a 6-dimensional state input, one
    3 x 56 x 56 image input per camera name and a 6-dimensional action
    output.  Dataset statistics are read from ``stats-diffusion.yaml`` in the
    current working directory.  A single random observation is fed through
    ``policy.select_action`` and the SHA-256 digest of the resulting action
    bytes is printed, so that two code revisions can be compared by output.

    Args:
        camera_names: Iterable of camera names; each one adds an
            ``observation.image.<name>`` input to the policy.
        rgb_encoder_per_camera: If truthy, sets
            ``use_separate_rgb_encoder_per_camera = True`` on the config
            before the policy is constructed.

    Returns:
        The hexadecimal SHA-256 digest of the action tensor's raw bytes
        (the same value that is printed).
    """
    # Seed before anything random happens: both the policy construction and
    # the torch.rand observations below draw from this generator, so the
    # order of those calls must not change or the hash changes with it.
    torch.manual_seed(0)
    resol = 56
    image_keys = [f"observation.image.{name}" for name in camera_names]

    input_shapes = {"observation.state": [6]}
    for key in image_keys:
        input_shapes[key] = [3, resol, resol]
    output_shapes = {"action": [6]}

    # determine normalizer
    normalization_mode = {"observation.state": "min_max"}
    for key in image_keys:
        normalization_mode[key] = "mean_std"

    # load stats
    # NOTE(review): yaml.FullLoader can construct more than plain scalars and
    # containers. That is acceptable for a trusted local file; if the stats
    # file only holds plain lists/numbers, yaml.safe_load would be the safer
    # choice -- confirm against the file before switching.
    with open("stats-diffusion.yaml", "r", encoding="utf-8") as f:
        stats = yaml.load(f, Loader=yaml.FullLoader)

    # Only these entries are converted to tensors; every other entry of the
    # stats file is passed through to the policy untouched.
    effective_keys = {
        *output_shapes,
        *input_shapes,
        "episode_index",
        "frame_index",
        # The original script spelled the key above "frame_indx", which looks
        # like a typo; the old spelling is kept in case a stats file uses it.
        "frame_indx",
        "index",
        "next.done",
        "timestamp",
    }
    # Re-assigning existing keys while iterating is safe: the dict's key set
    # (and therefore its size) does not change.
    for key, value in stats.items():
        if key in effective_keys:
            stats[key] = {
                stat_name: torch.tensor(stat_value)
                for stat_name, stat_value in value.items()
            }

    # create policy
    conf = DiffusionConfig(
        input_shapes=input_shapes,
        output_shapes=output_shapes,
        input_normalization_modes=normalization_mode,
        crop_shape=None,
        num_inference_steps=5,
    )
    if rgb_encoder_per_camera:
        conf.use_separate_rgb_encoder_per_camera = True
    policy = DiffusionPolicy(conf, dataset_stats=stats)

    # inference (batch size 1; state first, then images, to keep RNG order)
    observation = {"observation.state": torch.rand(1, 6)}
    for key in image_keys:
        observation[key] = torch.rand(1, 3, resol, resol)
    action = policy.select_action(observation)

    # compute hash of action
    # detach()/cpu() make the conversion work even if the returned tensor is
    # attached to the autograd graph or lives on another device; for a plain
    # CPU tensor they are no-ops and the bytes are unchanged.
    hash_value = sha256(action.detach().cpu().numpy().tobytes()).hexdigest()
    print("hash value of output: ", hash_value)
    return hash_value
if __name__ == "__main__":
    # Script entry point: run the hash check for three camera/encoder setups.

    # case 1: a single camera, shared encoder
    print("test with only gripper camera")
    run(["gripper"], False)

    # case 2: two cameras, still sharing one encoder
    print("test with gripper and webcam cameras where rgb_encoder_per_camera is not set")
    run(["gripper", "webcam"], False)

    # case 3: two cameras with one encoder each -- only possible when the
    # DiffusionConfig class under test declares the corresponding field.
    print("test with gripper and webcam cameras where rgb_encoder_per_camera is set")
    if "use_separate_rgb_encoder_per_camera" in DiffusionConfig.__annotations__:
        run(["gripper", "webcam"], True)
    else:
        print("DiffusionConfig does not have attribute rgb_encoder_per_camera")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment