Skip to content

Instantly share code, notes, and snippets.

@mkocabas
Created June 18, 2021 16:45
Show Gist options
  • Star 9 You must be signed in to star a gist
  • Fork 3 You must be signed in to fork a gist
  • Save mkocabas/238faa8a24c44d60fb30ffce367f9ec7 to your computer and use it in GitHub Desktop.
Save mkocabas/238faa8a24c44d60fb30ffce367f9ec7 to your computer and use it in GitHub Desktop.

How to visualize the body part segmentation?

This script lets you visualize the body part segmentation labels of the SMPL, SMPL-H, and SMPL-X body models. Follow the instructions below to run it.

Instructions

  1. Download the body models you would like to visualize.
  2. Install the requirements
  3. Run python visualize_body_part_segmentation.py <body_model> <body_model_path>

Example command: python visualize_body_part_segmentation.py smpl data/body_models/smpl

Requirements

You need Python 3 to run this demo. Install the requirements below using the pip package manager.

pip install matplotlib numpy trimesh smplx
import os
import sys
import json
import trimesh
import subprocess
import numpy as np
from smplx import SMPL, SMPLH, SMPLX
from matplotlib import cm as mpl_cm, colors as mpl_colors
def download_url(url, outdir):
    """Download ``url`` into ``outdir`` and return the local file path.

    The local file name is the last ``/``-separated component of ``url``.
    If a non-empty file with that name already exists in ``outdir`` it is
    reused and nothing is fetched. Otherwise the data is streamed to a
    temporary ``.part`` file and renamed into place only after the transfer
    has completed, so an interrupted download never leaves a truncated file
    under the final name.

    Args:
        url: Address of the file to fetch.
        outdir: Directory to store the file in; created if it is missing.

    Returns:
        ``os.path.join(outdir, <file name>)``.

    Raises:
        urllib.error.URLError: If the resource cannot be retrieved
            (``HTTPError`` for an HTTP error status).
        OSError: If the file cannot be written to ``outdir``.
    """
    # Local imports keep this function self-contained; only the standard
    # library is needed, no external ``wget`` binary.
    import shutil
    import urllib.request

    file_path = os.path.join(outdir, url.split('/')[-1])
    if os.path.isfile(file_path) and os.path.getsize(file_path) > 0:
        print(f'Using existing file {file_path}')
        return file_path

    print(f'Downloading files from {url}')
    # An empty ``outdir`` means the current directory.
    os.makedirs(outdir or '.', exist_ok=True)
    tmp_path = file_path + '.part'
    try:
        with urllib.request.urlopen(url) as response, \
                open(tmp_path, 'wb') as tmp_file:
            shutil.copyfileobj(response, tmp_file)
        os.replace(tmp_path, file_path)
    finally:
        # Only present here if the transfer or the rename failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    return file_path
def part_segm_to_vertex_colors(part_segm, n_vertices, alpha=1.0):
    """Convert a body-part segmentation into per-vertex RGBA colors.

    Every part receives its position in ``part_segm`` as a numeric label.
    The labels are passed through a default ``Normalize`` and the 'jet'
    colormap to obtain one color per vertex.

    Args:
        part_segm: Mapping from part name to the vertex indices belonging
            to that part.
        n_vertices: Total number of vertices of the mesh.
        alpha: Value written to the alpha channel of every vertex.

    Returns:
        Array of shape ``(n_vertices, 4)`` holding RGBA values.
    """
    # Vertices that appear in no part keep label 0 and therefore share the
    # color of the first part.
    vertex_labels = np.zeros(n_vertices)
    for part_idx, vertex_ids in enumerate(part_segm.values()):
        vertex_labels[vertex_ids] = part_idx

    # ``matplotlib.cm.get_cmap`` is gone from recent Matplotlib releases;
    # prefer the colormap registry and fall back only where it is missing.
    try:
        from matplotlib import colormaps
        cmap = colormaps['jet']
    except ImportError:
        cmap = mpl_cm.get_cmap('jet')

    normalize = mpl_colors.Normalize()
    vertex_colors = np.ones((n_vertices, 4))
    vertex_colors[:, 3] = alpha
    vertex_colors[:, :3] = cmap(normalize(vertex_labels))[:, :3]
    return vertex_colors
def main(body_model='smpl', body_model_path='body_models/smpl/'):
    """Show a body model colored by its body-part segmentation.

    Builds the requested body model, downloads the matching vertex
    segmentation JSON into the current directory, colors the model's default
    output mesh by part and opens a trimesh viewer window.

    Args:
        body_model: One of ``'smpl'``, ``'smplh'`` or ``'smplx'``.
        body_model_path: Path handed to the body model constructor as
            ``model_path``.

    Raises:
        ValueError: If ``body_model`` is not one of the supported names.
    """
    main_url = 'https://raw.githubusercontent.com/Meshcapade/wiki/main/assets/SMPL_body_segmentation/'

    # model name -> (model class, segmentation file relative to main_url)
    model_table = {
        'smpl': (SMPL, 'smpl/smpl_vert_segmentation.json'),
        'smplx': (SMPLX, 'smplx/smplx_vert_segmentation.json'),
        # SMPL-H is colored with the SMPL segmentation file.
        'smplh': (SMPLH, 'smpl/smpl_vert_segmentation.json'),
    }
    if body_model not in model_table:
        raise ValueError(f'{body_model} is not defined, \"smpl\", \"smplh\" or \"smplx\" are valid body models')

    model_cls, segm_relpath = model_table[body_model]
    # URLs are joined by plain concatenation (main_url ends with '/');
    # os.path.join is meant for filesystem paths, not URLs.
    part_segm_url = main_url + segm_relpath
    model = model_cls(model_path=body_model_path)

    part_segm_filepath = download_url(part_segm_url, '.')
    # Context manager so the JSON file handle is closed deterministically.
    with open(part_segm_filepath) as segm_file:
        part_segm = json.load(segm_file)

    # Calling the model with no arguments yields its default output; take
    # the first (only) mesh of the batch.
    vertices = model().vertices[0].detach().numpy()
    faces = model.faces

    vertex_colors = part_segm_to_vertex_colors(part_segm, vertices.shape[0])

    mesh = trimesh.Trimesh(vertices, faces, process=False, vertex_colors=vertex_colors)
    mesh.show(background=(0, 0, 0, 0))
if __name__ == '__main__':
    # Usage: python visualize_body_part_segmentation.py <body_model> <body_model_path>
    # Forward at most two command-line arguments; any that are omitted fall
    # back to main()'s defaults instead of raising IndexError.
    main(*sys.argv[1:3])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment