Skip to content

Instantly share code, notes, and snippets.

@jjerphan
Last active September 29, 2023 07:22
Show Gist options
  • Save jjerphan/bb6a9545e4bbbcb4a7d1b15af8e1b5d1 to your computer and use it in GitHub Desktop.
Save jjerphan/bb6a9545e4bbbcb4a7d1b15af8e1b5d1 to your computer and use it in GitHub Desktop.
Cython Extension Types' Methods' dispatch using vtable (extracted from https://github.com/scikit-learn/scikit-learn/pull/20254#discussion_r716904109)

Does self.distance_metric.rdist use a v-table lookup? (I am curious; this may not be actionable.)

Yes, it does.


V-table implementation details
  • a struct is defined for the base class (here DistanceMetric):
/*
 * V-table layout for the DistanceMetric base class: a plain struct holding one
 * function pointer per slot. Every slot takes a pointer to the DistanceMetric
 * object struct as its first argument.
 */
struct __pyx_vtabstruct_7sklearn_7metrics_13_dist_metrics_DistanceMetric {
  /* dist / rdist: two `DTYPE_t const *` arguments and one ITYPE_t; return DTYPE_t. */
  __pyx_t_7sklearn_5utils_9_typedefs_DTYPE_t (*dist)(struct __pyx_obj_7sklearn_7metrics_13_dist_metrics_DistanceMetric *, __pyx_t_7sklearn_5utils_9_typedefs_DTYPE_t const *, __pyx_t_7sklearn_5utils_9_typedefs_DTYPE_t const *, __pyx_t_7sklearn_5utils_9_typedefs_ITYPE_t);
  __pyx_t_7sklearn_5utils_9_typedefs_DTYPE_t (*rdist)(struct __pyx_obj_7sklearn_7metrics_13_dist_metrics_DistanceMetric *, __pyx_t_7sklearn_5utils_9_typedefs_DTYPE_t const *, __pyx_t_7sklearn_5utils_9_typedefs_DTYPE_t const *, __pyx_t_7sklearn_5utils_9_typedefs_ITYPE_t);
  /* csr_dist / csr_rdist: four memoryview slices passed by value; return DTYPE_t. */
  __pyx_t_7sklearn_5utils_9_typedefs_DTYPE_t (*csr_dist)(struct __pyx_obj_7sklearn_7metrics_13_dist_metrics_DistanceMetric *, __Pyx_memviewslice, __Pyx_memviewslice, __Pyx_memviewslice, __Pyx_memviewslice);
  __pyx_t_7sklearn_5utils_9_typedefs_DTYPE_t (*csr_rdist)(struct __pyx_obj_7sklearn_7metrics_13_dist_metrics_DistanceMetric *, __Pyx_memviewslice, __Pyx_memviewslice, __Pyx_memviewslice, __Pyx_memviewslice);
  /* pdist takes two memoryview slices, cdist takes three; both return int. */
  int (*pdist)(struct __pyx_obj_7sklearn_7metrics_13_dist_metrics_DistanceMetric *, __Pyx_memviewslice, __Pyx_memviewslice);
  int (*cdist)(struct __pyx_obj_7sklearn_7metrics_13_dist_metrics_DistanceMetric *, __Pyx_memviewslice, __Pyx_memviewslice, __Pyx_memviewslice);
  /* _rdist_to_dist / _dist_to_rdist: take a single DTYPE_t and return a DTYPE_t. */
  __pyx_t_7sklearn_5utils_9_typedefs_DTYPE_t (*_rdist_to_dist)(struct __pyx_obj_7sklearn_7metrics_13_dist_metrics_DistanceMetric *, __pyx_t_7sklearn_5utils_9_typedefs_DTYPE_t);
  __pyx_t_7sklearn_5utils_9_typedefs_DTYPE_t (*_dist_to_rdist)(struct __pyx_obj_7sklearn_7metrics_13_dist_metrics_DistanceMetric *, __pyx_t_7sklearn_5utils_9_typedefs_DTYPE_t);
};
  • a struct is created for each subclass (for instance here EuclideanDistance) and wraps the original struct as __pyx_base.
/*
 * V-table layout for the EuclideanDistance subclass. Its only member is the
 * base DistanceMetric table, embedded by value as __pyx_base. Because
 * __pyx_base is the first member, a pointer to this struct may be converted to
 * a pointer to the base v-table struct.
 */
struct __pyx_vtabstruct_7sklearn_7metrics_13_dist_metrics_EuclideanDistance {
  struct __pyx_vtabstruct_7sklearn_7metrics_13_dist_metrics_DistanceMetric __pyx_base;
};
/* File-local pointer to the EuclideanDistance v-table; declared here without an initializer. */
static struct __pyx_vtabstruct_7sklearn_7metrics_13_dist_metrics_EuclideanDistance *__pyx_vtabptr_7sklearn_7metrics_13_dist_metrics_EuclideanDistance;
  • subclasses' methods are converted to functions (not shown below), __pyx_base is initialised as a copy of the base class's table, and the subclass's methods are then bound to the corresponding slots of the table:
  /* Point the EuclideanDistance v-table pointer at the EuclideanDistance table object. */
  __pyx_vtabptr_7sklearn_7metrics_13_dist_metrics_EuclideanDistance = &__pyx_vtable_7sklearn_7metrics_13_dist_metrics_EuclideanDistance;
  /* Copy the whole base-class table by value (struct assignment through the base v-table pointer). */
  __pyx_vtable_7sklearn_7metrics_13_dist_metrics_EuclideanDistance.__pyx_base = *__pyx_vtabptr_7sklearn_7metrics_13_dist_metrics_DistanceMetric;
  /*
   * Overwrite four slots (dist, rdist, _rdist_to_dist, _dist_to_rdist) with the
   * EuclideanDistance functions. Each function is cast to the slot's
   * function-pointer type, whose first parameter is a pointer to the base
   * DistanceMetric object struct. The other slots (csr_dist, csr_rdist, pdist,
   * cdist) are not reassigned in the lines shown and so keep the values copied
   * from the base table.
   */
  __pyx_vtable_7sklearn_7metrics_13_dist_metrics_EuclideanDistance.__pyx_base.dist = (__pyx_t_7sklearn_5utils_9_typedefs_DTYPE_t (*)(struct __pyx_obj_7sklearn_7metrics_13_dist_metrics_DistanceMetric *, __pyx_t_7sklearn_5utils_9_typedefs_DTYPE_t const *, __pyx_t_7sklearn_5utils_9_typedefs_DTYPE_t const *, __pyx_t_7sklearn_5utils_9_typedefs_ITYPE_t))__pyx_f_7sklearn_7metrics_13_dist_metrics_17EuclideanDistance_dist;
  __pyx_vtable_7sklearn_7metrics_13_dist_metrics_EuclideanDistance.__pyx_base.rdist = (__pyx_t_7sklearn_5utils_9_typedefs_DTYPE_t (*)(struct __pyx_obj_7sklearn_7metrics_13_dist_metrics_DistanceMetric *, __pyx_t_7sklearn_5utils_9_typedefs_DTYPE_t const *, __pyx_t_7sklearn_5utils_9_typedefs_DTYPE_t const *, __pyx_t_7sklearn_5utils_9_typedefs_ITYPE_t))__pyx_f_7sklearn_7metrics_13_dist_metrics_17EuclideanDistance_rdist;
  __pyx_vtable_7sklearn_7metrics_13_dist_metrics_EuclideanDistance.__pyx_base._rdist_to_dist = (__pyx_t_7sklearn_5utils_9_typedefs_DTYPE_t (*)(struct __pyx_obj_7sklearn_7metrics_13_dist_metrics_DistanceMetric *, __pyx_t_7sklearn_5utils_9_typedefs_DTYPE_t))__pyx_f_7sklearn_7metrics_13_dist_metrics_17EuclideanDistance__rdist_to_dist;
  __pyx_vtable_7sklearn_7metrics_13_dist_metrics_EuclideanDistance.__pyx_base._dist_to_rdist = (__pyx_t_7sklearn_5utils_9_typedefs_DTYPE_t (*)(struct __pyx_obj_7sklearn_7metrics_13_dist_metrics_DistanceMetric *, __pyx_t_7sklearn_5utils_9_typedefs_DTYPE_t))__pyx_f_7sklearn_7metrics_13_dist_metrics_17EuclideanDistance__dist_to_rdist;
  • when a subclass object is created, the __pyx_vtab pointer held in its __pyx_base part is set to point to the v-table of its own class.
/*
 * tp_new implementation for EuclideanDistance.
 *
 * Delegates object creation to the DistanceMetric tp_new (looked up through
 * PyType_GetSlot when building against the limited API, called directly
 * otherwise), then redirects the new object's v-table pointer to the
 * EuclideanDistance table.
 *
 * t, a, k are forwarded unchanged to the base tp_new.
 * Returns the new object, or a null pointer if the base tp_new returned one.
 */
static PyObject *__pyx_tp_new_7sklearn_7metrics_13_dist_metrics_EuclideanDistance(PyTypeObject *t, PyObject *a, PyObject *k) {
  PyObject *instance;
  #if CYTHON_COMPILING_IN_LIMITED_API
  instance = ((newfunc)PyType_GetSlot(__pyx_ptype_7sklearn_7metrics_13_dist_metrics_DistanceMetric, Py_tp_new))(t, a, k);
  #else
  instance = __pyx_tp_new_7sklearn_7metrics_13_dist_metrics_DistanceMetric(t, a, k);
  #endif
  if (unlikely(instance == NULL)) {
    return NULL;
  }
  /* The v-table pointer lives in the base part of the object and is typed as
   * the base v-table struct, hence the cast of the subclass table pointer. */
  ((struct __pyx_obj_7sklearn_7metrics_13_dist_metrics_EuclideanDistance *)instance)->__pyx_base.__pyx_vtab =
      (struct __pyx_vtabstruct_7sklearn_7metrics_13_dist_metrics_DistanceMetric *)__pyx_vtabptr_7sklearn_7metrics_13_dist_metrics_EuclideanDistance;
  return instance;
}
  • at call sites, the dispatch is done via a v-table lookup, e.g. for the code you commented on:
  /* "sklearn/metrics/_dist_metrics.pyx":1334
 *     @final
 *     cdef DTYPE_t proxy_dist(self, ITYPE_t i, ITYPE_t j) nogil:
 *         return self.distance_metric.rdist(&self.X[i, 0],             # <<<<<<<<<<<<<<
 *                                           &self.Y[j, 0],
 *                                           self.d)
 */
  /*
   * Dynamic dispatch: the v-table pointer is read from
   * distance_metric->__pyx_vtab, cast to the base DistanceMetric v-table
   * struct type, and its rdist slot is called with distance_metric itself as
   * the first argument. The two pointer arguments are computed from the data
   * pointers and strides[0] of X and Y plus the temporaries __pyx_t_1..__pyx_t_4.
   * A result equal to -1.0 triggers __PYX_ERR with the __pyx_L1_error label
   * (presumably a jump to the error path; the macro is not shown here).
   */
  __pyx_t_5 = ((struct __pyx_vtabstruct_7sklearn_7metrics_13_dist_metrics_DistanceMetric *)__pyx_v_self->__pyx_base.distance_metric->__pyx_vtab)->rdist(__pyx_v_self->__pyx_base.distance_metric, (&(*((__pyx_t_7sklearn_5utils_9_typedefs_DTYPE_t const  *) ( /* dim=1 */ ((char *) (((__pyx_t_7sklearn_5utils_9_typedefs_DTYPE_t const  *) ( /* dim=0 */ (__pyx_v_self->X.data + __pyx_t_1 * __pyx_v_self->X.strides[0]) )) + __pyx_t_2)) )))), (&(*((__pyx_t_7sklearn_5utils_9_typedefs_DTYPE_t const  *) ( /* dim=1 */ ((char *) (((__pyx_t_7sklearn_5utils_9_typedefs_DTYPE_t const  *) ( /* dim=0 */ (__pyx_v_self->Y.data + __pyx_t_3 * __pyx_v_self->Y.strides[0]) )) + __pyx_t_4)) )))), __pyx_v_self->d); if (unlikely(__pyx_t_5 == ((__pyx_t_7sklearn_5utils_9_typedefs_DTYPE_t)-1.0))) __PYX_ERR(1, 1334, __pyx_L1_error)
  __pyx_r = __pyx_t_5;
  goto __pyx_L0;

I do not know much about C compilers, but do static-qualified definitions help with optimizations?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment