Skip to content

Instantly share code, notes, and snippets.

@dutc
Last active August 29, 2015 14:08
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dutc/b7e82f587662d9f9a6a1 to your computer and use it in GitHub Desktop.
Save dutc/b7e82f587662d9f9a6a1 to your computer and use it in GitHub Desktop.
Did you mean? in Python (Bonus Round!)
#include <Python.h>
#include "didyoumean-safe.h"
/* Forward declarations of the file-local helpers defined below. Every one of
 * them is static; the only symbol this file exports is safe_PyObject_Dir(). */
static int safe_merge_list_attr(PyObject* dict, PyObject* obj, const char *attrname);
static int safe_merge_class_dict(PyObject* dict, PyObject* aclass);
static PyObject * safe_PyObject_GetAttr(PyObject *v, PyObject *name);
static PyObject * safe_PyObject_GetAttrString(PyObject *v, const char *name);
static PyObject * safe__generic_dir(PyObject *obj);
static PyObject * safe__specialized_dir_type(PyObject *obj);
static PyObject * safe__specialized_dir_module(PyObject *obj);
static PyObject * safe__dir_object(PyObject *obj);
/* Merge the string entries of the list attribute `attrname` of `obj` into
 * `dict` as keys (each mapped to None).
 *
 * A missing attribute, or one that is not a list, is silently ignored.
 * Non-string list items are skipped.
 *
 * Returns 0 on success and a negative value with an exception set when a
 * dict insertion fails or the py3k deprecation warning is turned into an
 * error. */
static int
safe_merge_list_attr(PyObject* dict, PyObject* obj, const char *attrname)
{
    PyObject *list;
    int result = 0;

    assert(PyDict_Check(dict));
    assert(obj);
    assert(attrname);

    list = safe_PyObject_GetAttrString(obj, attrname);
    if (list == NULL) {
        /* The attribute is optional: swallow the lookup error. */
        PyErr_Clear();
        return 0;
    }

    if (PyList_Check(list)) {
        /* Py_ssize_t matches the type of PyList_GET_SIZE(); an int index
         * would be truncated for very large lists. */
        Py_ssize_t i;
        for (i = 0; i < PyList_GET_SIZE(list); ++i) {
            PyObject *item = PyList_GET_ITEM(list, i);  /* borrowed */
            if (!PyString_Check(item))
                continue;
            result = PyDict_SetItem(dict, item, Py_None);
            if (result < 0)
                break;
        }
        /* Only warn when the merge itself succeeded, so the warning
         * machinery is never entered with an exception already pending. */
        if (result == 0 && Py_Py3kWarningFlag &&
            (strcmp(attrname, "__members__") == 0 ||
             strcmp(attrname, "__methods__") == 0)) {
            if (PyErr_WarnEx(PyExc_DeprecationWarning,
                             "__members__ and __methods__ not "
                             "supported in 3.x", 1) < 0)
                result = -1;
        }
    }

    /* Single exit path releases the attribute reference exactly once. */
    Py_DECREF(list);
    return result;
}
/* Update `dict` with the __dict__ of `aclass` and then, recursively, with
 * the __dict__ of every entry of its __bases__.
 *
 * A missing __dict__ or __bases__, or a __bases__ whose length cannot be
 * determined, is ignored. Returns 0 on success and -1 when a dict update,
 * a sequence item fetch or a recursive merge fails. */
static int
safe_merge_class_dict(PyObject* dict, PyObject* aclass)
{
    PyObject *own_dict;
    PyObject *base_seq;
    Py_ssize_t count;
    Py_ssize_t idx;
    int failed = 0;

    assert(PyDict_Check(dict));
    assert(aclass);

    /* Step 1: the class's own namespace, if it has one. */
    own_dict = safe_PyObject_GetAttrString(aclass, "__dict__");
    if (own_dict != NULL) {
        int merged = PyDict_Update(dict, own_dict);
        Py_DECREF(own_dict);
        if (merged < 0)
            return -1;
    }
    else {
        PyErr_Clear();
    }

    /* Step 2: walk the bases, if any. */
    base_seq = safe_PyObject_GetAttrString(aclass, "__bases__");
    if (base_seq == NULL) {
        PyErr_Clear();
        return 0;
    }

    /* __bases__ is not guaranteed to be a real tuple, so only the generic
     * sequence protocol is used on it. */
    count = PySequence_Size(base_seq);
    if (count < 0) {
        PyErr_Clear();
        Py_DECREF(base_seq);
        return 0;
    }

    for (idx = 0; idx < count; idx++) {
        int status;
        PyObject *parent = PySequence_GetItem(base_seq, idx);
        if (parent == NULL) {
            failed = 1;
            break;
        }
        status = safe_merge_class_dict(dict, parent);
        Py_DECREF(parent);
        if (status < 0) {
            failed = 1;
            break;
        }
    }

    Py_DECREF(base_seq);
    return failed ? -1 : 0;
}
/* Look up attribute `name` on `v` by calling the type's tp_getattro slot,
 * or tp_getattr when tp_getattro is not set.
 *
 * `name` must be a str; when Py_USING_UNICODE is defined a unicode name is
 * first converted to its default-encoded str. Any other name type raises
 * TypeError. Returns the slot's result, or NULL with AttributeError set
 * when the type has neither slot. */
static PyObject *
safe_PyObject_GetAttr(PyObject *v, PyObject *name)
{
    PyTypeObject *type = Py_TYPE(v);

    if (!PyString_Check(name)) {
        int converted = 0;
#ifdef Py_USING_UNICODE
        /* The tp_getattro slots expect a string object as name, so the
           unicode-to-string conversion happens here, before dispatch. */
        if (PyUnicode_Check(name)) {
            name = _PyUnicode_AsDefaultEncodedString(name, NULL);
            if (name == NULL)
                return NULL;
            converted = 1;
        }
#endif
        if (!converted) {
            PyErr_Format(PyExc_TypeError,
                         "attribute name must be string, not '%.200s'",
                         Py_TYPE(name)->tp_name);
            return NULL;
        }
    }

    if (type->tp_getattro != NULL)
        return type->tp_getattro(v, name);

    if (type->tp_getattr == NULL) {
        PyErr_Format(PyExc_AttributeError,
                     "'%.50s' object has no attribute '%.400s'",
                     type->tp_name, PyString_AS_STRING(name));
        return NULL;
    }

    return type->tp_getattr(v, PyString_AS_STRING(name));
}
/* Look up the attribute named by the C string `name` on `v`.
 *
 * When the type defines tp_getattr it is called directly with the C
 * string; otherwise the name is interned into a str object and the lookup
 * goes through safe_PyObject_GetAttr(). Returns the lookup result, or NULL
 * when interning or the lookup fails. */
static PyObject *
safe_PyObject_GetAttrString(PyObject *v, const char *name)
{
    PyObject *key;
    PyObject *value;
    getattrfunc by_cstring = Py_TYPE(v)->tp_getattr;

    if (by_cstring != NULL)
        return by_cstring(v, (char*)name);

    key = PyString_InternFromString(name);
    if (key == NULL)
        return NULL;

    value = safe_PyObject_GetAttr(v, key);
    Py_DECREF(key);
    return value;
}
/* Default dir() for arbitrary objects.
 *
 * Collects names from a copy of obj.__dict__ (when it is a real dict), from
 * the __members__ and __methods__ list attributes, and from everything
 * reachable through obj.__class__. Returns the list of collected keys, or
 * NULL on failure. */
static PyObject *
safe__generic_dir(PyObject *obj)
{
    PyObject *names = NULL;
    PyObject *attrs;
    PyObject *cls = NULL;

    /* Start from __dict__, which may be absent or not a real dict. */
    attrs = safe_PyObject_GetAttrString(obj, "__dict__");
    if (attrs != NULL && PyDict_Check(attrs)) {
        /* Work on a copy so the object's own __dict__ is never mutated. */
        PyObject *snapshot = PyDict_Copy(attrs);
        Py_DECREF(attrs);
        attrs = snapshot;
    }
    else {
        if (attrs == NULL)
            PyErr_Clear();
        else
            Py_DECREF(attrs);
        attrs = PyDict_New();
    }
    if (attrs == NULL)
        goto done;

    /* Old-style __members__ / __methods__ lists (gone in Python 3). */
    if (safe_merge_list_attr(attrs, obj, "__members__") < 0 ||
        safe_merge_list_attr(attrs, obj, "__methods__") < 0)
        goto done;

    /* Names reachable from the object's class; a missing __class__ is
     * simply ignored. */
    cls = safe_PyObject_GetAttrString(obj, "__class__");
    if (cls == NULL)
        PyErr_Clear();
    else if (safe_merge_class_dict(attrs, cls) != 0)
        goto done;

    names = PyDict_Keys(attrs);

done:
    Py_XDECREF(cls);
    Py_XDECREF(attrs);
    return names;
}
/* Default dir() for types and classes: the keys of a fresh dict into which
 * the __dict__ of `obj` and of all its bases have been merged. Returns NULL
 * when the dict cannot be created or the merge fails. */
static PyObject *
safe__specialized_dir_type(PyObject *obj)
{
    PyObject *names = NULL;
    PyObject *merged = PyDict_New();

    if (merged == NULL)
        return NULL;

    if (safe_merge_class_dict(merged, obj) == 0)
        names = PyDict_Keys(merged);

    Py_DECREF(merged);
    return names;
}
/* Default dir() for module objects: the keys of the module's __dict__.
 *
 * Returns NULL when __dict__ cannot be fetched or is not a dict; in the
 * latter case a TypeError is raised as long as the module name is
 * available. */
static PyObject *
safe__specialized_dir_module(PyObject *obj)
{
    PyObject *names = NULL;
    PyObject *mod_dict = safe_PyObject_GetAttrString(obj, "__dict__");

    if (mod_dict == NULL)
        return NULL;

    if (!PyDict_Check(mod_dict)) {
        char *modname = PyModule_GetName(obj);
        if (modname != NULL)
            PyErr_Format(PyExc_TypeError,
                         "%.200s.__dict__ is not a dictionary",
                         modname);
    }
    else {
        names = PyDict_Keys(mod_dict);
    }

    Py_DECREF(mod_dict);
    return names;
}
/* Implementation of dir(obj).
 *
 * Looks for a __dir__ hook (via attribute lookup for old-style instances,
 * via _PyObject_LookupSpecial otherwise). When one exists it is called and
 * must return a list; otherwise one of the default implementations for
 * modules, types/classes or generic objects is used. Returns a new
 * reference, or NULL with an exception set. */
static PyObject *
safe__dir_object(PyObject *obj)
{
    static PyObject *dir_str = NULL;
    PyObject *hook;
    PyObject *listing;

    assert(obj);

    if (!PyInstance_Check(obj)) {
        hook = _PyObject_LookupSpecial(obj, "__dir__", &dir_str);
        if (PyErr_Occurred())
            return NULL;
    }
    else {
        hook = safe_PyObject_GetAttrString(obj, "__dir__");
        if (hook == NULL) {
            /* Only a missing attribute is tolerated here. */
            if (!PyErr_ExceptionMatches(PyExc_AttributeError))
                return NULL;
            PyErr_Clear();
        }
    }

    if (hook == NULL) {
        /* No __dir__: pick the default implementation by object kind. */
        if (PyModule_Check(obj))
            return safe__specialized_dir_module(obj);
        if (PyType_Check(obj) || PyClass_Check(obj))
            return safe__specialized_dir_type(obj);
        return safe__generic_dir(obj);
    }

    /* Use the object's own __dir__. */
    listing = PyObject_CallFunctionObjArgs(hook, NULL);
    Py_DECREF(hook);
    if (listing == NULL)
        return NULL;

    /* The result must be a list; its items are not checked. */
    if (PyList_Check(listing))
        return listing;

    PyErr_Format(PyExc_TypeError,
                 "__dir__() must return a list, not %.200s",
                 Py_TYPE(listing)->tp_name);
    Py_DECREF(listing);
    return NULL;
}
/* Return the list of attribute names of `obj`, or NULL on failure.
 *
 * Two features are deliberately not provided here: introspecting the
 * locals when `obj` is NULL (not needed, so `obj` must be non-NULL), and
 * sorting the resulting list (the names are returned unsorted). */
PyObject* safe_PyObject_Dir(PyObject *obj)
{
    PyObject *names = safe__dir_object(obj);

    assert(names == NULL || PyList_Check(names));
    return names;
}
/* didyoumean-safe.h: declares the one function exported for use outside
 * its defining file. Python.h must be included before this header, since
 * PyObject is used but not declared here. */
#ifndef DIDYOUMEAN_SAFE_H
#define DIDYOUMEAN_SAFE_H
/* Returns the list of attribute names of obj, or NULL on failure. */
PyObject* safe_PyObject_Dir(PyObject *obj);
#endif
#include <Python.h>
#include <stdio.h>
#include <sys/mman.h>
#include <unistd.h>
#include <string.h>
#include "didyoumean-safe.h"
#if !(__x86_64__)
#error "This only works on x86_64"
#endif
extern PyObject* PyErr_Occurred(void);
extern PyObject* PyObject_GetAttr(PyObject *v, PyObject *name);
/* Levenshtein edit distance between the NUL-terminated strings `a` and `b`:
 * the minimum number of single-character insertions, deletions and
 * substitutions that turn one into the other.
 *
 * Only one row of the dynamic-programming table is kept, so the stack
 * usage is proportional to strlen(a) rather than strlen(a) * strlen(b). */
static int distance(const char *a, const char *b) {
    size_t len_a = strlen(a);
    size_t len_b = strlen(b);

    /* row[j] is the distance between the first i characters of b and the
     * first j characters of a, for the row i currently being processed. */
    unsigned int row[len_a + 1];

    for (size_t j = 0; j <= len_a; j++)
        row[j] = (unsigned int)j;

    for (size_t i = 1; i <= len_b; i++) {
        unsigned int diag = row[0];      /* value of cell [i-1][j-1] */
        row[0] = (unsigned int)i;
        for (size_t j = 1; j <= len_a; j++) {
            unsigned int above = row[j];             /* cell [i-1][j] */
            unsigned int best = diag + (a[j-1] == b[i-1] ? 0u : 1u);
            /* Take a true three-way minimum: ties between the two
             * neighbours must still beat a larger diagonal cost. */
            if (above + 1 < best)
                best = above + 1;
            if (row[j-1] + 1 < best)
                best = row[j-1] + 1;
            row[j] = best;
            diag = above;
        }
    }
    return (int)row[len_a];
}
/* Replacement body for attribute lookup on `v`.
 *
 * Performs the same lookup as safe_PyObject_GetAttr() (str/unicode name
 * check, then tp_getattro or tp_getattr). When the lookup fails with an
 * AttributeError whose value is a plain str message, the message is
 * extended with the name from safe_PyObject_Dir(v) that has the smallest
 * edit distance to `name`.
 *
 * Returns a new reference on success, NULL with an exception set on
 * failure. */
PyObject* trampoline(PyObject *v, PyObject *name)
{
    /* Marker instruction: the module init code scans the start of this
     * function for the NOP byte (0x90) and patches the bytes before it,
     * so this statement has to stay first. */
    __asm__("nop");
    PyObject* rv = NULL;
    PyTypeObject *tp = Py_TYPE(v);
    if (!PyString_Check(name)) {
#ifdef Py_USING_UNICODE
        /* The Unicode to string conversion is done here because the
           existing tp_getattro slots expect a string object as name
           and we wouldn't want to break those. */
        if (PyUnicode_Check(name)) {
            name = _PyUnicode_AsDefaultEncodedString(name, NULL);
            if (name == NULL)
                return NULL;
        }
        else
#endif
        {
            PyErr_Format(PyExc_TypeError,
                         "attribute name must be string, not '%.200s'",
                         Py_TYPE(name)->tp_name);
            return NULL;
        }
    }
    if (tp->tp_getattro != NULL) {
        rv = (*tp->tp_getattro)(v, name);
    }
    else if (tp->tp_getattr != NULL) {
        rv = (*tp->tp_getattr)(v, PyString_AS_STRING(name));
    }
    else {
        PyErr_Format(PyExc_AttributeError,
                     "'%.50s' object has no attribute '%.400s'",
                     tp->tp_name, PyString_AS_STRING(name));
    }

    if (rv == NULL && PyErr_Occurred() &&
        PyErr_ExceptionMatches(PyExc_AttributeError)) {
        PyObject *oldtype, *oldvalue, *oldtraceback;

        /* PyErr_Fetch hands the references over to us and clears the
         * error indicator. Reading the pointers out of the thread state
         * and then calling PyErr_Clear() would release them while they
         * are still in use. */
        PyErr_Fetch(&oldtype, &oldvalue, &oldtraceback);

        PyObject *dir = safe_PyObject_Dir(v);
        /* NOTE(review): no matching Py_EnterRecursiveCall() is visible in
         * this file; kept as in the original -- confirm it is intended. */
        Py_LeaveRecursiveCall();

        /* Only a plain string message can be extended; an exception
         * instance (or no value at all) is passed through untouched. */
        if (dir != NULL && oldvalue != NULL && PyString_Check(oldvalue)) {
            PyObject *candidate = NULL;   /* borrowed from dir */
            int candidate_dist = 0;
            Py_ssize_t count = PyList_Size(dir);

            for (Py_ssize_t i = 0; i < count; ++i) {
                PyObject *item = PyList_GetItem(dir, i);  /* borrowed */
                /* dir() results are not guaranteed to hold only strs. */
                if (item == NULL || !PyString_Check(item))
                    continue;
                int dist = distance(PyString_AS_STRING(name),
                                    PyString_AS_STRING(item));
                if (candidate == NULL || dist < candidate_dist) {
                    candidate = item;
                    candidate_dist = dist;
                }
            }

            if (candidate != NULL) {
                PyObject *newvalue =
                    PyString_FromFormat("%s\n\nMaybe you meant: .%s\n",
                                        PyString_AS_STRING(oldvalue),
                                        PyString_AS_STRING(candidate));
                /* Keep the original message if formatting failed. */
                if (newvalue != NULL) {
                    Py_DECREF(oldvalue);
                    oldvalue = newvalue;
                }
            }
        }

        /* candidate pointed into dir, so dir is released only now. */
        Py_XDECREF(dir);

        /* Re-raise the AttributeError; PyErr_Restore steals the references
         * and discards any error left behind by the code above. */
        PyErr_Restore(oldtype, oldvalue, oldtraceback);
    }
    return rv;
}
/* TODO: make less ugly!
* there's got to be a nicer way to do this! */
/* Byte template for a 13-byte x86-64 absolute jump. The module init code
 * fills in `addr` with the destination address and copies the whole struct
 * over the start of the function being redirected. The struct is packed to
 * one-byte alignment so its bytes are laid out back to back with no
 * padding. */
#pragma pack(push, 1)
struct {
/* 0x50: push rax */
char push_rax;
/* 0x48 0xb8: mov rax, imm64 (the 8 immediate bytes are `addr`) */
char mov_rax[2];
/* Destination address; not initialized here, filled in at init time. */
char addr[8];
/* 0xff 0xe0: jmp rax */
char jmp_rax[2]; }
jump_asm = {
.push_rax = 0x50,
.mov_rax = {0x48, 0xb8},
.jmp_rax = {0xff, 0xe0} };
#pragma pack(pop)
/* Method table passed to Py_InitModule3(): it holds only the terminating
 * sentinel, so the module defines no Python-callable functions. */
static PyMethodDef module_methods[] = {
{NULL} /* Sentinel */
};
/* Docstring passed to Py_InitModule3() for the module. */
PyDoc_STRVAR(module_doc,
"This module implements a \"did you mean?\" functionality on getattr/LOAD_ATTR.\n"
"(It's not so much what it does but how it does it.)");
/* Make every page that overlaps [start, start + len) readable, writable and
 * executable. `pagesize` must be a power of two and `len` non-zero.
 * Returns the result of mprotect(): 0 on success, -1 on failure. */
static int
didyoumean_unprotect(void *start, size_t len, size_t pagesize)
{
    size_t first = (size_t)start & ~(pagesize - 1);
    size_t last = ((size_t)start + len - 1) & ~(pagesize - 1);

    return mprotect((void *)first, last - first + pagesize,
                    PROT_READ | PROT_WRITE | PROT_EXEC);
}

/* Module initializer: registers the (empty) "didyoumean" module and then
 * redirects PyObject_GetAttr to trampoline() by patching machine code.
 *
 * 1. The bytes of trampoline() that precede its NOP marker are shifted up
 *    by one byte (overwriting the marker) and 0x58 (pop rax) is written as
 *    the new first byte, undoing the push rax done by the jump stub.
 * 2. The start of PyObject_GetAttr is overwritten with jump_asm, an
 *    absolute jump to trampoline().
 *
 * On any failure a message is printed to stderr and nothing further is
 * patched. */
PyMODINIT_FUNC
initdidyoumean(void) {
    __asm__("");
    if (Py_InitModule3("didyoumean", module_methods, module_doc) == NULL)
        return;

    long pagesize = sysconf(_SC_PAGE_SIZE);
    if (pagesize <= 0) {
        fprintf(stderr, "sysconf(_SC_PAGE_SIZE) failed.\n");
        return;
    }

    void* target = PyObject_GetAttr;
    void* addr = &trampoline;
    unsigned char *code = (unsigned char *)addr;

    /* Locate the NOP marker near the start of trampoline(). */
    size_t count;
    for (count = 0; count < 255; ++count)
        if (code[count] == 0x90)
            break; // found the NOP
    if (count == 255) {
        /* Without the marker there is nothing safe to shift. */
        fprintf(stderr, "NOP marker not found in trampoline().\n");
        return;
    }

    /* The patched range may straddle a page boundary, so unprotect the
     * whole range rather than just the page holding its first byte. */
    if (didyoumean_unprotect(code, count + 1, (size_t)pagesize)) {
        fprintf(stderr, "mprotect() failed.\n");
        return;
    }
    /* Shift bytes [0, count) to [1, count]; memmove handles the overlap
     * and never reads the byte before the function. */
    memmove(code + 1, code, count);
    code[0] = 0x58;

    if (didyoumean_unprotect(target, sizeof jump_asm, (size_t)pagesize)) {
        fprintf(stderr, "mprotect() failed.\n");
        return;
    }
    memcpy(jump_asm.addr, &addr, sizeof (void *));
    memcpy(target, &jump_asm, sizeof jump_asm);
}
#!/usr/bin/env python
"""Demo: the same failing attribute lookup before and after importing didyoumean."""


class Foo(object):
    """Minimal class with a single attribute, ``bar``, to look up."""

    def bar(self):
        pass


if __name__ == '__main__':
    foo = Foo()

    # Before the import: print the AttributeError for foo.baz as raised.
    print(foo.bar)
    try:
        foo.baz
    except Exception as e:
        print(e)

    # Imported only for its side effect on attribute lookup.
    import didyoumean

    # After the import: repeat the same lookups and print the error again.
    print(foo.bar)
    try:
        foo.baz
    except Exception as e:
        print(e)
# Builds didyoumean.so from both C sources, using the compiler and linker
# flags reported by python-config.
CC=gcc -std=c99 -Wall
didyoumean.so: didyoumean.c didyoumean-safe.c
	${CC} `python-config --cflags` `python-config --includes` -Wl,--export-dynamic -fPIC -shared -o $@ $^ -ldl `python-config --libs`
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment