Created
April 24, 2023 08:10
-
-
Save Eatkin/a0dddcc948ee966b9991e6ed70a22be5 to your computer and use it in GitHub Desktop.
Singly linked list implementation with insert-at-position, remove-at-position, find-position, and reverse methods, plus unit tests.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import unittest | |
# Create the node class
class Node(object):
    """A singly linked list node holding a value and a link to the next node."""

    def __init__(self, d, n=None):
        """Create a node.

        Args:
            d: The value stored in this node.
            n: The next node in the chain, or None if this node is the tail.
        """
        self.data = d
        self.next_node = n

    def __repr__(self):
        # Show only this node's data: following next_node would print the
        # whole chain (and never finish if the chain contains a cycle).
        return '{}(data={!r})'.format(type(self).__name__, self.data)

    def get_next(self):
        """Return the next node, or None if this is the tail."""
        return self.next_node

    def set_next(self, n):
        """Point this node at ``n`` (None makes this node the tail)."""
        self.next_node = n

    def get_data(self):
        """Return the value stored in this node."""
        return self.data

    def set_data(self, d):
        """Replace the value stored in this node with ``d``."""
        self.data = d
# Create the linked list class
class LinkedList(object):
    """Singly linked list whose ``root`` attribute is the head node.

    ``add`` pushes onto the head, so values added in the order 1, 2, 3 are
    stored as 3 -> 2 -> 1. All positions are 0-based, counted from the head.
    """

    def __init__(self, r=None):
        """Create a list, optionally seeded with an existing chain of nodes.

        Args:
            r: Head node of an existing chain, or None for an empty list.
        """
        self.root = r
        # Count the seeded chain so get_size() is correct from the start
        # (size used to stay 0 even when a non-empty chain was passed in).
        self.size = 0
        node = r
        while node is not None:
            self.size += 1
            node = node.get_next()

    def __find_node(self, d):
        """Return the first node whose data equals ``d``, or None."""
        this_node = self.root
        while this_node is not None:
            if this_node.get_data() == d:
                return this_node
            this_node = this_node.get_next()
        return None

    def get_size(self):
        """Return the number of nodes in the list."""
        return self.size

    def add(self, d):
        """Insert ``d`` at the head of the list (position 0)."""
        self.root = Node(d, self.root)
        self.size += 1

    def remove(self, d):
        """Remove the first node whose data equals ``d``.

        Returns:
            True if a node was removed, False if ``d`` was not found.
        """
        prev_node = None
        this_node = self.root
        while this_node is not None:
            if this_node.get_data() == d:
                if prev_node is None:
                    # Removing the head: its successor becomes the new root.
                    self.root = this_node.get_next()
                else:
                    # Unlink this node by pointing its predecessor past it.
                    prev_node.set_next(this_node.get_next())
                self.size -= 1
                return True
            prev_node = this_node
            this_node = this_node.get_next()
        return False

    def find(self, d):
        """Return ``d`` if a node with equal data exists, otherwise False.

        Note: a found falsy value such as 0 is returned as-is, so test the
        result with ``is False`` when such values may be stored.
        """
        if self.__find_node(d) is None:
            return False
        return d

    def find_position(self, d):
        """Return the 0-based position of the first ``d``, or -1 if absent."""
        position = 0
        this_node = self.root
        while this_node is not None:
            if this_node.get_data() == d:
                return position
            this_node = this_node.get_next()
            position += 1
        # Reached whenever d is absent, including on an empty list (which
        # previously fell off the end of the method and returned None).
        return -1

    def insert_at_position(self, pos, d):
        """Insert ``d`` so that it ends up at index ``pos``.

        A ``pos`` of zero or less inserts at the head. A ``pos`` past the end
        appends at the tail.

        Returns:
            True (insertion always succeeds).
        """
        # pos == 0 must also go through add(). The walk below has no
        # predecessor for index 0 and used to crash with an AttributeError.
        if pos <= 0 or self.root is None:
            self.add(d)
            return True
        prev_node = self.root
        this_node = prev_node.get_next()
        position = 1
        # Advance until prev_node is at index pos - 1, or is the tail.
        while position < pos and this_node is not None:
            prev_node = this_node
            this_node = this_node.get_next()
            position += 1
        prev_node.set_next(Node(d, this_node))
        self.size += 1
        return True

    def remove_at_position(self, pos):
        """Remove the node at 0-based index ``pos``.

        Returns:
            True if a node was removed, or False if ``pos`` is negative, out
            of range, or the list is empty.
        """
        if pos < 0 or self.root is None:
            return False
        if pos == 0:
            # The head has no predecessor to relink, so move root instead
            # (this case used to crash with an AttributeError).
            self.root = self.root.get_next()
            self.size -= 1
            return True
        prev_node = self.root
        this_node = prev_node.get_next()
        position = 1
        while position < pos and this_node is not None:
            prev_node = this_node
            this_node = this_node.get_next()
            position += 1
        if this_node is None:
            # pos is past the last index.
            return False
        prev_node.set_next(this_node.get_next())
        self.size -= 1
        return True

    def reverse(self):
        """Reverse the list in place and return it, which allows chaining."""
        prev_node = None
        this_node = self.root
        while this_node is not None:
            next_node = this_node.get_next()
            # Point the current node back at its predecessor, then step on.
            this_node.set_next(prev_node)
            prev_node = this_node
            this_node = next_node
        # prev_node is the old tail (or None for an empty list).
        self.root = prev_node
        return self

    def print_list(self):
        """Print the list on one line as ``a->b->...->None``."""
        this_node = self.root
        while this_node is not None:
            print(this_node.get_data(), end='->')
            this_node = this_node.get_next()
        print('None')
# Unit tests
class TestLinkedList(unittest.TestCase):
    """Unit tests for LinkedList; each test starts from an empty list."""

    def setUp(self):
        self.ll = LinkedList()

    def test_add(self):
        self.ll.add(1)
        self.assertEqual(self.ll.get_size(), 1)
        self.ll.add(2)
        self.assertEqual(self.ll.get_size(), 2)
        self.ll.add(3)
        self.assertEqual(self.ll.get_size(), 3)

    def test_remove(self):
        self.ll.add(1)
        self.ll.add(2)
        self.ll.add(3)
        self.assertEqual(self.ll.get_size(), 3)
        self.assertTrue(self.ll.remove(2))
        self.assertEqual(self.ll.get_size(), 2)
        self.assertEqual(self.ll.find_position(2), -1)
        self.assertFalse(self.ll.remove(4))

    def test_find(self):
        self.ll.add(1)
        self.ll.add(2)
        self.ll.add(3)
        self.assertEqual(self.ll.find(2), 2)
        self.assertFalse(self.ll.find(4))

    def test_find_position(self):
        # add() prepends, so the list is 3 -> 2 -> 1.
        self.ll.add(1)
        self.ll.add(2)
        self.ll.add(3)
        self.assertEqual(self.ll.find_position(2), 1)
        self.assertEqual(self.ll.find_position(4), -1)

    def test_insert_at_position(self):
        self.ll.insert_at_position(0, 1)  # empty list -> [1]
        self.assertEqual(self.ll.get_size(), 1)
        self.ll.insert_at_position(1, 2)  # append at tail -> [1, 2]
        self.assertEqual(self.ll.get_size(), 2)
        self.ll.insert_at_position(1, 3)  # into the middle -> [1, 3, 2]
        self.assertEqual(self.ll.get_size(), 3)
        # Check ordering, not just the size.
        self.assertEqual(self.ll.find_position(1), 0)
        self.assertEqual(self.ll.find_position(3), 1)
        self.assertEqual(self.ll.find_position(2), 2)

    def test_remove_at_position(self):
        self.ll.add(1)
        self.ll.add(2)
        self.ll.add(3)
        self.assertEqual(self.ll.get_size(), 3)
        self.assertTrue(self.ll.remove_at_position(1))  # [3, 2, 1] -> [3, 1]
        self.assertEqual(self.ll.get_size(), 2)
        self.assertEqual(self.ll.find_position(2), -1)
        self.assertEqual(self.ll.find_position(1), 1)
        self.assertFalse(self.ll.remove_at_position(4))

    def test_reverse(self):
        self.ll.add(1)
        self.ll.add(2)
        self.ll.add(3)
        self.assertEqual(self.ll.get_size(), 3)
        self.ll.reverse()
        self.assertEqual(self.ll.get_size(), 3)
        self.assertEqual(self.ll.find_position(3), 2)
        self.assertEqual(self.ll.find_position(2), 1)
        self.assertEqual(self.ll.find_position(1), 0)
        self.assertEqual(self.ll.find(1), 1)
        self.assertEqual(self.ll.find(2), 2)
        self.assertEqual(self.ll.find(3), 3)

    def test_print_list(self):
        # redirect_stdout restores the previous sys.stdout even if print_list
        # raises. Resetting to sys.__stdout__ by hand would also clobber any
        # capture installed by the test runner.
        from contextlib import redirect_stdout
        from io import StringIO
        self.ll.add(1)
        self.ll.add(2)
        self.ll.add(3)
        captured_output = StringIO()
        with redirect_stdout(captured_output):
            self.ll.print_list()
        self.assertEqual(captured_output.getvalue().strip(), '3->2->1->None')
if __name__ == "__main__":
    # Executed directly (not imported): discover and run the unittest cases.
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment