Skip to content

Instantly share code, notes, and snippets.

@HiroIshida
Created October 28, 2024 15:24
Show Gist options
  • Save HiroIshida/5d0abbec0fdd0d140c5262eb1301c98e to your computer and use it in GitHub Desktop.
Save HiroIshida/5d0abbec0fdd0d140c5262eb1301c98e to your computer and use it in GitHub Desktop.
lerobot pr
import yaml
import torch
from hashlib import sha256
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy, DiffusionConfig
def run(camera_names, rgb_encoder_per_camera):
    """Build a DiffusionPolicy, run one inference step and print a hash of the action.

    The torch RNG is seeded with 0 before anything else, so the policy
    construction and the random observation below draw from the same seeded
    stream on every call.

    Args:
        camera_names: Names of the cameras. Each name adds an
            ``observation.image.<name>`` input of shape ``[3, 56, 56]``.
        rgb_encoder_per_camera: When truthy, sets
            ``use_separate_rgb_encoder_per_camera = True`` on the config
            before the policy is built.

    Reads ``stats-diffusion.yaml`` from the current working directory and
    prints the SHA-256 hex digest of the raw bytes of the selected action.
    """
    torch.manual_seed(0)
    resol = 56
    image_keys = [f"observation.image.{name}" for name in camera_names]

    # State comes first, then one image entry per camera (insertion order kept).
    input_shapes = {"observation.state": [6]}
    input_shapes.update({key: [3, resol, resol] for key in image_keys})
    output_shapes = {"action": [6]}

    # determine normalizer
    normalization_mode = {"observation.state": "min_max"}
    normalization_mode.update({key: "mean_std" for key in image_keys})

    # load stats
    # NOTE(review): FullLoader can construct more than plain data types; only
    # use this with a trusted stats file. Switch to yaml.safe_load once the
    # file is confirmed to contain plain scalars/lists/dicts only.
    with open("stats-diffusion.yaml", "r", encoding="utf-8") as f:
        stats = yaml.load(f, Loader=yaml.FullLoader)

    # Only these entries are converted to tensors; any other key in the file
    # is passed through unchanged.
    effective_key_set = {
        *output_shapes,
        *input_shapes,
        "episode_index",
        "frame_index",  # was misspelled "frame_indx", so it never matched
        "index",
        "next.done",
        "timestamp",
    }
    for key in effective_key_set & stats.keys():
        stats[key] = {
            key_inner: torch.tensor(value_inner)
            for key_inner, value_inner in stats[key].items()
        }

    # create policy
    conf = DiffusionConfig(
        input_shapes=input_shapes,
        output_shapes=output_shapes,
        input_normalization_modes=normalization_mode,
        crop_shape=None,
        num_inference_steps=5,
    )
    if rgb_encoder_per_camera:
        # Set after construction so that config versions lacking this field
        # still work when the flag is off.
        conf.use_separate_rgb_encoder_per_camera = True
    policy = DiffusionPolicy(conf, dataset_stats=stats)

    # inference
    observation = {"observation.state": torch.rand(1, 6)}
    for key in image_keys:
        observation[key] = torch.rand(1, 3, resol, resol)
    action = policy.select_action(observation)

    # compute hash of action
    # detach()/cpu() leave the bytes of a grad-free CPU tensor unchanged, but
    # keep the numpy conversion valid if the tensor tracks grad or is on a GPU.
    hash_value = sha256(action.detach().cpu().numpy().tobytes()).hexdigest()
    print("hash value of output: ", hash_value)
if __name__ == "__main__":
    # Script entry point: run the three camera/encoder configurations in turn
    # and print the action hash for each.

    # case 1
    print("test with only gripper camera")
    run(["gripper"], False)

    # case 2
    print("test with gripper and webcam cameras where rgb_encoder_per_camera is not set")
    run(["gripper", "webcam"], False)

    # case 3
    print("test with gripper and webcam cameras where rgb_encoder_per_camera is set")
    # The field may be missing from the installed DiffusionConfig, so check
    # for it before running the case that sets it.
    has_rgb_encoder_per_camera = (
        "use_separate_rgb_encoder_per_camera" in DiffusionConfig.__annotations__
    )
    if has_rgb_encoder_per_camera:
        run(["gripper", "webcam"], True)
    else:
        # Report the attribute name that was actually checked.
        print("DiffusionConfig does not have attribute use_separate_rgb_encoder_per_camera")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment