
@normandipalo
Last active May 13, 2024 01:10
Code snippet for the one-shot imitation learning phase of DINOBot (alignment + replay).
"""
In this script, we demonstrate how to use DINOBot to do one-shot imitation learning.
You first need to install the following repo and its requirements: https://github.com/ShirAmir/dino-vit-features.
You can then run this file inside that repo.
There are a few setup-dependent functions you need to implement, like getting an RGBD observation from the camera
or moving the robot, which you will find at the top of this file.
"""
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms,utils
from PIL import Image
import torchvision.transforms as T
import warnings
import glob
import time
warnings.filterwarnings("ignore")
#Install this DINO repo to extract correspondences: https://github.com/ShirAmir/dino-vit-features
from correspondences import find_correspondences, draw_correspondences
#Hyperparameters for DINO correspondences extraction
num_pairs = 8
load_size = 224
layer = 9
facet = 'key'
bin=True
thresh=0.05
model_type='dino_vits8'
stride=4
#Deployment hyperparameters
ERR_THRESHOLD = 50  # Threshold on the alignment error between the two sets of 3D points (same units as project_to_3d output)
#Here are the functions you need to create based on your setup.
def camera_get_rgbd():
    """
    Outputs a tuple (rgb, depth) taken from a wrist camera.
    The two observations should have the same height and width.
    """
    raise NotImplementedError
def project_to_3d(points, depth, intrinsics):
    """
    Inputs: points: list of [x,y] pixel coordinates,
            depth (H,W,1) observations from camera.
            intrinsics: intrinsics of the camera, used to
            project pixels to 3D space.
    Outputs: point_with_depth: list of [x,y,z] coordinates.
    Projects the selected pixels to 3D space using intrinsics and
    depth value. Based on your setup the implementation may vary,
    but here you can find a simple example or the explicit formula:
    https://www.open3d.org/docs/0.6.0/python_api/open3d.geometry.create_point_cloud_from_rgbd_image.html.
    """
    raise NotImplementedError
def robot_move(t_meters, R):
    """
    Inputs: t_meters: (x,y,z) translation in end-effector frame
            R: (3x3) array - rotation matrix in end-effector frame
    Moves and rotates the robot according to the input translation and rotation.
    """
    raise NotImplementedError
def record_demo():
    """
    Records a demonstration while the end-effector is moved, and returns the
    stored velocities so they can be replayed by the "replay_demo" function.
    """
    raise NotImplementedError
def replay_demo(demo):
    """
    Inputs: demo: list of velocities that can then be executed by the end-effector.
    Replays a demonstration by moving the end-effector given recorded velocities.
    """
    raise NotImplementedError
def find_transformation(X, Y):
    """
    Inputs: X, Y: lists of 3D points
    Outputs: R - 3x3 rotation matrix, t - 3-dim translation array.
    Find transformation given two sets of correspondences between 3D points.
    """
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)
    # Calculate centroids
    cX = np.mean(X, axis=0)
    cY = np.mean(Y, axis=0)
    # Subtract centroids to obtain centered sets of points
    Xc = X - cX
    Yc = Y - cY
    # Calculate covariance matrix
    C = np.dot(Xc.T, Yc)
    # Compute SVD
    U, S, Vt = np.linalg.svd(C)
    # Determine rotation matrix
    R = np.dot(Vt.T, U.T)
    # If the SVD yields a reflection (det(R) = -1), flip the last axis so R is a proper rotation
    if np.linalg.det(R) < 0:
        Vt[-1, :] *= -1
        R = np.dot(Vt.T, U.T)
    # Determine translation vector
    t = cY - np.dot(R, cX)
    return R, t
def compute_error(points1, points2):
    # Frobenius norm of the per-point differences: one scalar over all correspondence pairs
    return np.linalg.norm(np.array(points1) - np.array(points2))
if __name__ == "__main__":
    # Placeholder: set this to your wrist camera's intrinsics, in whatever form project_to_3d expects.
    intrinsics = None
    # RECORD DEMO:
    # Move the end-effector to the bottleneck pose and store observation.
    #Get rgbd from wrist camera.
    rgb_bn, depth_bn = camera_get_rgbd()
    #Record demonstration.
    demo_vels = record_demo()
    # TEST TIME DEPLOYMENT
    # Move/change the object and move the end-effector to the home (or a random) pose.
    while True:
        error = float("inf")
        while error > ERR_THRESHOLD:
            #Collect observations at the current pose.
            rgb_live, depth_live = camera_get_rgbd()
            #Compute pixel correspondences between new observation and bottleneck observation.
            #Note: upstream find_correspondences loads its inputs from image file paths; save the
            #frames to disk or adapt the function if your version does not accept arrays.
            with torch.no_grad():
                points1, points2, image1_pil, image2_pil = find_correspondences(rgb_bn, rgb_live, num_pairs, load_size, layer,
                                                                                facet, bin, thresh, model_type, stride)
            #Given the pixel coordinates of the correspondences, and their depth values,
            #project the points to 3D space.
            points1 = project_to_3d(points1, depth_bn, intrinsics)
            points2 = project_to_3d(points2, depth_live, intrinsics)
            #Find rigid translation and rotation that aligns the points by minimising error, using SVD.
            R, t = find_transformation(points1, points2)
            #Move robot
            robot_move(t, R)
            error = compute_error(points1, points2)
        #Once error is small enough, replay demo.
        replay_demo(demo_vels)
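The project_to_3d stub is left to the reader. Below is a minimal sketch, assuming a pinhole camera whose intrinsics are passed as a hypothetical (fx, fy, cx, cy) tuple and a depth image already in metres. For each pixel it inverts u = fx·X/Z + cx and v = fy·Y/Z + cy. It follows the docstring's [x, y] (column, row) convention. Check the coordinate order and resolution of what find_correspondences returns in your copy of dino-vit-features, since its points may be expressed in the resized load_size image rather than in the depth image's resolution.

```python
import numpy as np


def project_to_3d(points, depth, intrinsics):
    """Back-project [x, y] pixels to [X, Y, Z] points in the camera frame.

    Assumptions (not from the original gist): intrinsics = (fx, fy, cx, cy),
    points are [x, y] = [column, row] in the depth image's resolution,
    and depth has shape (H, W) or (H, W, 1) with values in metres.
    """
    fx, fy, cx, cy = intrinsics
    depth = np.asarray(depth, dtype=float)
    if depth.ndim == 3:
        depth = depth[..., 0]  # drop the trailing channel axis
    points_3d = []
    for x, y in points:
        # Depth images are indexed [row, column]
        z = depth[int(round(y)), int(round(x))]
        # Invert the pinhole projection u = fx * X / Z + cx, v = fy * Y / Z + cy
        X = (x - cx) * z / fx
        Y = (y - cy) * z / fy
        points_3d.append([X, Y, z])
    return points_3d
```

Pixels with missing depth (often reported as 0) map to the camera origin in this sketch, so in practice you may want to drop those correspondences before calling find_transformation.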
@Bailey-24

Thanks for your great work. I think DINOv2 is better than DINO; will you implement it using DINOv2?

Another question: if we don't have depth_bn (e.g. when using DALL-E 3 to generate the goal image), how can we find the transformation?

@normandipalo
Author

@Bailey-24 thank you for your comment.
Regarding DINOv2, we did not find particular improvements using it; furthermore, its larger patch size (14 pixels, versus 8 for the dino_vits8 model used here) can hinder the granularity of pixel matching unless you compute more overlapping patches. We plan to keep exploring the space of vision foundation models (e.g. the use of registers in DINO, https://arxiv.org/abs/2309.16588), and will update our code if we achieve noteworthy improvements.
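To put numbers on the granularity point: a ViT with patch size p run at stride s over an H-pixel side produces 1 + (H - p) // s patch descriptors along that side, the usual count for a sliding-window patch grid. The sketch assumes a square 224 x 224 input and compares the script's dino_vits8 / stride=4 setting with a DINOv2 model (patch size 14) at its native stride and at a hypothetical overlapping stride of 7.

```python
def patches_per_side(side_px: int, patch: int, stride: int) -> int:
    """Number of (possibly overlapping) ViT patches along one image side."""
    return 1 + (side_px - patch) // stride


# (label, patch size, stride) for an assumed 224 x 224 input
configs = [
    ("dino_vits8, stride 4 (this script)", 8, 4),
    ("dino_vits8, native stride 8", 8, 8),
    ("DINOv2 ViT-S/14, native stride 14", 14, 14),
    ("DINOv2 ViT-S/14, overlapping stride 7", 14, 7),
]

for label, patch, stride in configs:
    n = patches_per_side(224, patch, stride)
    # Neighbouring descriptors are `stride` pixels apart, which bounds how finely a match can be placed
    print(f"{label}: {n} x {n} descriptors, {stride}-px spacing")
```

The script's setting gives a 55 x 55 grid with 4-pixel spacing, while native-stride DINOv2 gives 16 x 16 with 14-pixel spacing. Overlapping DINOv2 patches recover some resolution, at the cost of more tokens per image.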

If you do not have depth, you can approximate the needed movement along z by allowing a scale factor when matching the 2D points. You can then compute how much to move closer to the object from how much you needed to scale the live points to match the bottleneck points. However, this is an approximation and cannot account for things like different object dimensions.
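One possible way to realise this scale-based approximation, sketched under assumptions and not the authors' code: fit a 2D similarity transform (scale s, in-plane rotation, translation) from the live pixels to the bottleneck pixels with the Umeyama least-squares method. Then use the pinhole relation that apparent size scales as 1/Z. If the live points must be scaled by s to match the bottleneck, the bottleneck view was taken at roughly Z_live / s. The function names estimate_similarity_2d and depth_ratio_from_scale are hypothetical.

```python
import numpy as np


def estimate_similarity_2d(src, dst):
    """Least-squares s, R, t such that dst ≈ s * R @ src + t (Umeyama's method).

    src: live pixel coordinates, dst: bottleneck pixel coordinates, both (N, 2).
    """
    src = np.asarray(src, dtype=float)
    dst = np.asarray(dst, dtype=float)
    mu_s, mu_d = src.mean(axis=0), dst.mean(axis=0)
    sc, dc = src - mu_s, dst - mu_d
    cov = dc.T @ sc / len(src)  # 2x2 cross-covariance
    U, S, Vt = np.linalg.svd(cov)
    # Force a proper rotation (no reflection)
    d = 1.0 if np.linalg.det(U @ Vt) >= 0 else -1.0
    D = np.diag([1.0, d])
    R = U @ D @ Vt
    var_src = (sc ** 2).sum() / len(src)
    s = np.trace(np.diag(S) @ D) / var_src
    t = mu_d - s * R @ mu_s
    return s, R, t


def depth_ratio_from_scale(s):
    """Pinhole assumption: image size ∝ 1/Z, so Z_bottleneck / Z_live ≈ 1 / s."""
    return 1.0 / s
```

A ratio below 1 means the end-effector should move closer. Turning the ratio into metres still needs some estimate of the current distance, or a small proportional step repeated in the alignment loop. This is an assumption on our part, and the sketch inherits the object-size limitation mentioned above.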

@qidihan

qidihan commented Apr 25, 2024

This is a great project. I recently viewed a demo on your website featuring an experiment where a cup is placed on a tree branch. The multi-stage task involving bottleneck images caught my attention.

I am curious whether the technique involves, for each stage of the task, loading that stage's bottleneck image and then identifying and focusing on the corresponding object in the current image. Could you please confirm whether my understanding is correct? Additionally, I would appreciate any further details or documentation on this method.
