Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Adding indexing with a list of booleans (boolean-mask indexing) to Python's built-in list type
diff --git a/Lib/test/test_list.py b/Lib/test/test_list.py
index d3da05ba84..6e640decc1 100644
--- a/Lib/test/test_list.py
+++ b/Lib/test/test_list.py
@@ -256,6 +256,53 @@ def __eq__(self, other):
         lst = [X(), X()]
         X() in lst

+    def test_indexing_with_boolean_list_ok(self):
+        a = [1, 2, 3, 4]
+        self.assertEqual(a[[True, False, True, False]], [1, 3])
+        a[[True, False, True, False]] = [-1, -2]
+        self.assertEqual(a, [-1, 2, -2, 4])
+
+    def test_indexing_with_boolean_list_wrong_indices_length(self):
+        # The mask must have exactly one entry per list element, and a
+        # rejected assignment must leave the list untouched.
+        a = [1, 2, 3, 4]
+        with self.assertRaises(IndexError):
+            a[[True, True]]
+        with self.assertRaises(IndexError):
+            a[[True, True, True, True, True]]
+        with self.assertRaises(IndexError):
+            a[[True, True]] = [1, 2]
+        with self.assertRaises(IndexError):
+            a[[True, True, True, True, True]] = [1, 2, 3, 4, 5]
+        self.assertEqual(a, [1, 2, 3, 4])
+
+    def test_indexing_with_boolean_list_assigned_value_not_a_list(self):
+        a = [1, 2, 3, 4]
+        with self.assertRaises(TypeError):
+            a[[True, True, True, True]] = "not a list"
+        self.assertEqual(a, [1, 2, 3, 4])
+
+    def test_indexing_with_boolean_list_some_indices_not_boolean(self):
+        # Plain ints such as 1 and 0 are rejected as mask entries too.
+        a = [1, 2, 3, 4]
+        with self.assertRaises(TypeError):
+            a[[1, "str", True, 0]]
+        with self.assertRaises(TypeError):
+            a[[1, "str", True, 0]] = [1, 2, 3]
+        self.assertEqual(a, [1, 2, 3, 4])
+
+    def test_indexing_with_boolean_list_value_vs_indices_length_mismatch(self):
+        a = [1, 2, 3, 4]
+        with self.assertRaises(IndexError):
+            a[[True, False, True, False]] = [1]
+        self.assertEqual(a, [1, 2, 3, 4])
+
+    def test_indexing_with_boolean_list_all_false_assignment(self):
+        # Selecting nothing and assigning an empty list is a no-op.
+        a = [1, 2, 3, 4]
+        a[[False, False, False, False]] = []
+        self.assertEqual(a, [1, 2, 3, 4])
+

 if __name__ == "__main__":
     unittest.main()
diff --git a/Objects/listobject.c b/Objects/listobject.c
index ccb9b91ba9..3c25e91f9d 100644
--- a/Objects/listobject.c
+++ b/Objects/listobject.c
@@ -2954,9 +2954,68 @@ list_subscript(PyListObject* self, PyObject* item)
             return result;
         }
     }
+    else if (PyList_Check(item)) {
+        /* Boolean-mask indexing: a[mask] returns a new list holding the
+           elements of a whose entry in mask is True.  mask must be a
+           list of bools with exactly len(a) entries. */
+        Py_ssize_t n = PyList_GET_SIZE(self);
+        if (PyList_GET_SIZE(item) != n) {
+            PyErr_Format(PyExc_IndexError,
+                         "indices length (got %zd) must be equal to target "
+                         "list length (got %zd)",
+                         PyList_GET_SIZE(item), n);
+            return NULL;
+        }
+
+        /* Validate every mask entry and count the selected elements
+           before allocating anything. */
+        PyObject **item_items = ((PyListObject *)item)->ob_item;
+        Py_ssize_t true_count = 0;
+        Py_ssize_t i;
+        for (i = 0; i < n; i++) {
+            if (!PyBool_Check(item_items[i])) {
+                PyErr_Format(PyExc_TypeError,
+                             "when indexing with list all indices must be"
+                             " boolean, element at index %zd is %.200s",
+                             i, Py_TYPE(item_items[i])->tp_name);
+                return NULL;
+            }
+            if (item_items[i] == Py_True) {
+                true_count++;
+            }
+        }
+
+        /* Nothing selected: return a fresh empty list directly, because
+           list_new_prealloc() asserts that its size is positive. */
+        if (true_count == 0) {
+            return PyList_New(0);
+        }
+
+        PyObject *result = list_new_prealloc(true_count);
+        if (result == NULL) {
+            return NULL;
+        }
+
+        /* Defensive: re-read the arrays after allocating and never copy
+           more than true_count items into the preallocated buffer. */
+        PyObject **src = self->ob_item;
+        PyObject **dest = ((PyListObject *)result)->ob_item;
+        Py_ssize_t dest_index = 0;
+        item_items = ((PyListObject *)item)->ob_item;
+        for (i = 0; i < n && dest_index < true_count; i++) {
+            if (item_items[i] == Py_True) {
+                PyObject *it = src[i];
+                Py_INCREF(it);
+                dest[dest_index++] = it;
+            }
+        }
+        Py_SET_SIZE(result, dest_index);
+        return result;
+    }
     else {
         PyErr_Format(PyExc_TypeError,
-                     "list indices must be integers or slices, not %.200s",
+                     "list indices must be integers or slices or list of "
+                     "booleans, not %.200s",
                      Py_TYPE(item)->tp_name);
         return NULL;
     }
@@ -3117,9 +3169,94 @@ list_ass_subscript(PyListObject* self, PyObject* item, PyObject* value)
             return 0;
         }
     }
+    else if (PyList_Check(item)) {
+        /* Boolean-mask assignment: a[mask] = values replaces, in order,
+           each element of a whose mask entry is True with the next item
+           of values.  Everything is validated before a is modified. */
+        if (value == NULL) {
+            /* "del a[mask]" arrives here with value == NULL. */
+            PyErr_SetString(PyExc_TypeError,
+                            "deleting with boolean list index is not "
+                            "supported");
+            return -1;
+        }
+        if (!PyList_Check(value)) {
+            PyErr_Format(PyExc_TypeError,
+                         "when assigning with boolean list index, "
+                         "assigned value must be a list, not %.200s",
+                         Py_TYPE(value)->tp_name);
+            return -1;
+        }
+
+        Py_ssize_t n = PyList_GET_SIZE(self);
+        if (PyList_GET_SIZE(item) != n) {
+            PyErr_Format(PyExc_IndexError,
+                         "indices length (got %zd) must be equal to target"
+                         " list length (got %zd)",
+                         PyList_GET_SIZE(item), n);
+            return -1;
+        }
+
+        PyObject **item_items = ((PyListObject *)item)->ob_item;
+        Py_ssize_t true_count = 0;
+        Py_ssize_t i;
+        for (i = 0; i < n; i++) {
+            if (!PyBool_Check(item_items[i])) {
+                PyErr_Format(PyExc_TypeError,
+                             "when assigning with list all indices must be"
+                             " boolean, element at index %zd is %.200s",
+                             i, Py_TYPE(item_items[i])->tp_name);
+                return -1;
+            }
+            if (item_items[i] == Py_True) {
+                true_count++;
+            }
+        }
+        if (true_count != PyList_GET_SIZE(value)) {
+            PyErr_Format(PyExc_IndexError,
+                         "number of boolean indices with True value "
+                         "must be equal to length of assigned list,"
+                         " got %zd True indices and %zd assigned values",
+                         true_count, PyList_GET_SIZE(value));
+            return -1;
+        }
+        if (true_count == 0) {
+            /* Nothing selected and nothing to assign. */
+            return 0;
+        }
+
+        /* Keep the replaced elements alive until every slot has been
+           written: Py_DECREF may run arbitrary code (e.g. __del__), which
+           must not observe a half-updated list. */
+        PyObject **garbage = PyMem_New(PyObject *, true_count);
+        if (garbage == NULL) {
+            PyErr_NoMemory();
+            return -1;
+        }
+
+        PyObject **self_items = self->ob_item;
+        PyObject **value_items = ((PyListObject *)value)->ob_item;
+        Py_ssize_t value_index = 0;
+        for (i = 0; i < n; i++) {
+            if (item_items[i] == Py_True) {
+                PyObject *ins = value_items[value_index];
+                Py_INCREF(ins);
+                garbage[value_index] = self_items[i];
+                self_items[i] = ins;
+                value_index++;
+            }
+        }
+
+        for (i = 0; i < true_count; i++) {
+            Py_DECREF(garbage[i]);
+        }
+        PyMem_Free(garbage);
+        return 0;
+    }
     else {
         PyErr_Format(PyExc_TypeError,
-                     "list indices must be integers or slices, not %.200s",
+                     "list indices must be integers or slices "
+                     "or list of booleans, not %.200s",
                      Py_TYPE(item)->tp_name);
         return -1;
     }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment