-
-
Save haku117/bfb9da469c8fcbc9e19dc2957945905e to your computer and use it in GitHub Desktop.
Project1_solution
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class TreeNode:
    """Node of a binary search tree, as used by AVLTree.

    ``key`` orders the node, ``payload`` carries the associated value,
    ``leftChild``/``rightChild``/``parent`` link the node into the tree, and
    ``balanceFactor`` is the AVL balance (left-subtree height minus
    right-subtree height), starting at 0.

    The has*/is* predicates return truthy/falsy values rather than strict
    bools (several return a child node itself), so use them only in a
    boolean context.
    """

    def __init__(self, key, val, left=None, right=None, parent=None):
        self.key, self.payload = key, val
        self.leftChild, self.rightChild = left, right
        self.parent = parent
        self.balanceFactor = 0

    def hasLeftChild(self):
        """Return the left child if there is one, otherwise None."""
        return self.leftChild

    def hasRightChild(self):
        """Return the right child if there is one, otherwise None."""
        return self.rightChild

    def isLeftChild(self):
        """Truthy when this node is its parent's left child; None when there is no parent."""
        parent = self.parent
        if not parent:
            return parent
        return parent.leftChild == self

    def isRightChild(self):
        """Truthy when this node is its parent's right child; None when there is no parent."""
        parent = self.parent
        if not parent:
            return parent
        return parent.rightChild == self

    def isRoot(self):
        """True when the node has no parent."""
        return not self.parent

    def isLeaf(self):
        """True when the node has neither a left nor a right child."""
        return not self.hasAnyChildren()

    def hasAnyChildren(self):
        """Return the right child if present, else the left child (possibly None)."""
        if self.rightChild:
            return self.rightChild
        return self.leftChild

    def hasBothChildren(self):
        """Return the left child when both children exist, else a falsy value."""
        if not self.rightChild:
            return self.rightChild
        return self.leftChild

    def replaceNodeData(self, key, value, lc, rc):
        """Overwrite key, payload and children, re-parenting the new children to this node."""
        self.key, self.payload = key, value
        self.leftChild, self.rightChild = lc, rc
        for child in (lc, rc):
            if child:
                child.parent = self
class AVLTree:
    """Self-balancing (AVL) binary search tree built from TreeNode objects.

    ``root`` is the top node (None while empty) and ``size`` counts the
    ``put`` calls made.  Keys equal to an existing key are not merged: they
    are inserted into that node's right subtree.  A node's ``balanceFactor``
    is positive when its left side is taller and negative when its right
    side is taller.
    """

    def __init__(self):
        self.root = None
        self.size = 0

    def length(self):
        """Return the number of keys inserted so far."""
        return self.size

    def put(self, key, val):
        """Insert ``key`` with payload ``val``, rebalancing the tree as needed."""
        if not self.root:
            self.root = TreeNode(key, val)
        else:
            self._put(key, val, self.root)
        self.size += 1

    def _put(self, key, val, currentNode):
        """Walk down from ``currentNode`` to the insertion point and attach a new leaf."""
        node = currentNode
        while True:
            goLeft = key < node.key
            child = node.leftChild if goLeft else node.rightChild
            if child:
                node = child
                continue
            leaf = TreeNode(key, val, parent=node)
            if goLeft:
                node.leftChild = leaf
            else:
                node.rightChild = leaf
            self.updateBalance(leaf)
            return

    def updateBalance(self, node):
        """Propagate a height change upward from ``node``, rotating at the first unbalanced node."""
        while True:
            if not -1 <= node.balanceFactor <= 1:
                self.rebalance(node)
                return
            parent = node.parent
            if parent is None:
                return
            if node.isLeftChild():
                parent.balanceFactor += 1
            elif node.isRightChild():
                parent.balanceFactor -= 1
            # A parent that becomes perfectly balanced did not change height,
            # so nothing above it needs updating.
            if parent.balanceFactor == 0:
                return
            node = parent

    def rebalance(self, node):
        """Restore balance at ``node`` with a single or double rotation."""
        if node.balanceFactor < 0:
            # Right-heavy; a left-heavy right child needs a right-left double rotation.
            if node.rightChild.balanceFactor > 0:
                self.rotateRight(node.rightChild)
            self.rotateLeft(node)
        elif node.balanceFactor > 0:
            # Left-heavy; a right-heavy left child needs a left-right double rotation.
            if node.leftChild.balanceFactor < 0:
                self.rotateLeft(node.leftChild)
            self.rotateRight(node)

    def rotateLeft(self, rotRoot):
        """Rotate left around ``rotRoot`` so its right child takes its place."""
        pivot = rotRoot.rightChild
        inner = pivot.leftChild
        rotRoot.rightChild = inner
        if inner is not None:
            inner.parent = rotRoot
        pivot.parent = rotRoot.parent
        if rotRoot.isRoot():
            self.root = pivot
        elif rotRoot.isLeftChild():
            rotRoot.parent.leftChild = pivot
        else:
            rotRoot.parent.rightChild = pivot
        pivot.leftChild = rotRoot
        rotRoot.parent = pivot
        # Update the balance factors of the two moved nodes without recomputing heights.
        rotRoot.balanceFactor += 1 - min(pivot.balanceFactor, 0)
        pivot.balanceFactor += 1 + max(rotRoot.balanceFactor, 0)

    def rotateRight(self, rotRoot):
        """Rotate right around ``rotRoot`` so its left child takes its place."""
        pivot = rotRoot.leftChild
        inner = pivot.rightChild
        rotRoot.leftChild = inner
        if inner is not None:
            inner.parent = rotRoot
        pivot.parent = rotRoot.parent
        if rotRoot.isRoot():
            self.root = pivot
        elif rotRoot.isRightChild():
            rotRoot.parent.rightChild = pivot
        else:
            rotRoot.parent.leftChild = pivot
        pivot.rightChild = rotRoot
        rotRoot.parent = pivot
        # Mirror image of the rotateLeft balance-factor update.
        rotRoot.balanceFactor -= 1 + max(pivot.balanceFactor, 0)
        pivot.balanceFactor -= 1 - min(rotRoot.balanceFactor, 0)

    def heightHelper(self, current):
        """Return the height of the subtree at ``current`` (0 for None, 1 for a leaf)."""
        if current is None:
            return 0
        return 1 + max(self.heightHelper(current.leftChild),
                       self.heightHelper(current.rightChild))

    def getInfo(self):
        """Print the tree's height and size."""
        height = self.heightHelper(self.root)
        print(f"AVL Tree height is: {height}")
        print(f"AVL Tree size is: {self.size}")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
UMass ECE 241 - Advanced Programming | |
Project #1 Fall 2021 | |
project1_solution.py - Sorting and Searching | |
""" | |
import random | |
import time | |
import numpy as np | |
import yfinance as yf | |
import matplotlib.pyplot as plt | |
""" | |
Stock class for stock objects | |
""" | |
class Stock:
    """A stock record parsed from one '|'-separated line of the database.

    Expected layout: ``name|symbol|val|price_1|price_2|...`` with at least
    five fields.  Attributes: ``sname`` (name string), ``symbol`` (stock
    symbol string), ``val`` (third field as a float) and ``prices`` (the
    remaining fields as floats, in file order).

    A line with fewer than five fields prints a warning and produces a
    record with empty/zero attributes instead of a half-built object.
    """

    def __init__(self, info: str):
        tokens = info.split('|')
        # Give every attribute a value up front, so a malformed line still yields
        # an object that can be printed, sorted and searched instead of one that
        # raises AttributeError the first time it is used.
        self.sname = ''
        self.symbol = ''
        self.val = 0.0
        self.prices = []
        if len(tokens) < 5:
            print("incorrect stock information")
            return
        self.sname = tokens[0]
        self.symbol = tokens[1]
        self.val = float(tokens[2])
        # float() tolerates the trailing newline left on the last field.
        self.prices = [float(token) for token in tokens[3:]]

    def __str__(self):
        # A malformed record has no prices; report None rather than raising IndexError.
        last_price = self.prices[-1] if self.prices else None
        return ("name: " + self.sname + "; symbol: " + self.symbol +
                "; val: " + str(self.val) + "; price:" + str(last_price))
class TreeNode:
    """Node of a binary search tree, as used by AVLTree.

    ``key`` orders the node, ``payload`` carries the associated value,
    ``leftChild``/``rightChild``/``parent`` link the node into the tree, and
    ``balanceFactor`` is the AVL balance (left-subtree height minus
    right-subtree height), starting at 0.

    The has*/is* predicates return truthy/falsy values rather than strict
    bools (several return a child node itself), so use them only in a
    boolean context.
    """

    def __init__(self, key, val, left=None, right=None, parent=None):
        self.key, self.payload = key, val
        self.leftChild, self.rightChild = left, right
        self.parent = parent
        self.balanceFactor = 0

    def hasLeftChild(self):
        """Return the left child if there is one, otherwise None."""
        return self.leftChild

    def hasRightChild(self):
        """Return the right child if there is one, otherwise None."""
        return self.rightChild

    def isLeftChild(self):
        """Truthy when this node is its parent's left child; None when there is no parent."""
        parent = self.parent
        if not parent:
            return parent
        return parent.leftChild == self

    def isRightChild(self):
        """Truthy when this node is its parent's right child; None when there is no parent."""
        parent = self.parent
        if not parent:
            return parent
        return parent.rightChild == self

    def isRoot(self):
        """True when the node has no parent."""
        return not self.parent

    def isLeaf(self):
        """True when the node has neither a left nor a right child."""
        return not self.hasAnyChildren()

    def hasAnyChildren(self):
        """Return the right child if present, else the left child (possibly None)."""
        if self.rightChild:
            return self.rightChild
        return self.leftChild

    def hasBothChildren(self):
        """Return the left child when both children exist, else a falsy value."""
        if not self.rightChild:
            return self.rightChild
        return self.leftChild

    def replaceNodeData(self, key, value, lc, rc):
        """Overwrite key, payload and children, re-parenting the new children to this node."""
        self.key, self.payload = key, value
        self.leftChild, self.rightChild = lc, rc
        for child in (lc, rc):
            if child:
                child.parent = self
class AVLTree:
    """Self-balancing (AVL) binary search tree built from TreeNode objects.

    ``root`` is the top node (None while empty) and ``size`` counts the
    ``put`` calls made.  Keys equal to an existing key are not merged: they
    are inserted into that node's right subtree.  A node's ``balanceFactor``
    is positive when its left side is taller and negative when its right
    side is taller.
    """

    def __init__(self):
        self.root = None
        self.size = 0

    def length(self):
        """Return the number of keys inserted so far."""
        return self.size

    def put(self, key, val):
        """Insert ``key`` with payload ``val``, rebalancing the tree as needed."""
        if not self.root:
            self.root = TreeNode(key, val)
        else:
            self._put(key, val, self.root)
        self.size += 1

    def _put(self, key, val, currentNode):
        """Walk down from ``currentNode`` to the insertion point and attach a new leaf."""
        node = currentNode
        while True:
            goLeft = key < node.key
            child = node.leftChild if goLeft else node.rightChild
            if child:
                node = child
                continue
            leaf = TreeNode(key, val, parent=node)
            if goLeft:
                node.leftChild = leaf
            else:
                node.rightChild = leaf
            self.updateBalance(leaf)
            return

    def updateBalance(self, node):
        """Propagate a height change upward from ``node``, rotating at the first unbalanced node."""
        while True:
            if not -1 <= node.balanceFactor <= 1:
                self.rebalance(node)
                return
            parent = node.parent
            if parent is None:
                return
            if node.isLeftChild():
                parent.balanceFactor += 1
            elif node.isRightChild():
                parent.balanceFactor -= 1
            # A parent that becomes perfectly balanced did not change height,
            # so nothing above it needs updating.
            if parent.balanceFactor == 0:
                return
            node = parent

    def rebalance(self, node):
        """Restore balance at ``node`` with a single or double rotation."""
        if node.balanceFactor < 0:
            # Right-heavy; a left-heavy right child needs a right-left double rotation.
            if node.rightChild.balanceFactor > 0:
                self.rotateRight(node.rightChild)
            self.rotateLeft(node)
        elif node.balanceFactor > 0:
            # Left-heavy; a right-heavy left child needs a left-right double rotation.
            if node.leftChild.balanceFactor < 0:
                self.rotateLeft(node.leftChild)
            self.rotateRight(node)

    def rotateLeft(self, rotRoot):
        """Rotate left around ``rotRoot`` so its right child takes its place."""
        pivot = rotRoot.rightChild
        inner = pivot.leftChild
        rotRoot.rightChild = inner
        if inner is not None:
            inner.parent = rotRoot
        pivot.parent = rotRoot.parent
        if rotRoot.isRoot():
            self.root = pivot
        elif rotRoot.isLeftChild():
            rotRoot.parent.leftChild = pivot
        else:
            rotRoot.parent.rightChild = pivot
        pivot.leftChild = rotRoot
        rotRoot.parent = pivot
        # Update the balance factors of the two moved nodes without recomputing heights.
        rotRoot.balanceFactor += 1 - min(pivot.balanceFactor, 0)
        pivot.balanceFactor += 1 + max(rotRoot.balanceFactor, 0)

    def rotateRight(self, rotRoot):
        """Rotate right around ``rotRoot`` so its left child takes its place."""
        pivot = rotRoot.leftChild
        inner = pivot.rightChild
        rotRoot.leftChild = inner
        if inner is not None:
            inner.parent = rotRoot
        pivot.parent = rotRoot.parent
        if rotRoot.isRoot():
            self.root = pivot
        elif rotRoot.isRightChild():
            rotRoot.parent.rightChild = pivot
        else:
            rotRoot.parent.leftChild = pivot
        pivot.rightChild = rotRoot
        rotRoot.parent = pivot
        # Mirror image of the rotateLeft balance-factor update.
        rotRoot.balanceFactor -= 1 + max(pivot.balanceFactor, 0)
        pivot.balanceFactor -= 1 - min(rotRoot.balanceFactor, 0)

    def heightHelper(self, current):
        """Return the height of the subtree at ``current`` (0 for None, 1 for a leaf)."""
        if current is None:
            return 0
        return 1 + max(self.heightHelper(current.leftChild),
                       self.heightHelper(current.rightChild))

    def getInfo(self):
        """Print the tree's height and size."""
        height = self.heightHelper(self.root)
        print(f"AVL Tree height is: {height}")
        print(f"AVL Tree size is: {self.size}")
""" | |
StockLibrary class to manage stock objects | |
""" | |
class StockLibrary:
    """Manage a collection of Stock objects.

    ``stockList`` holds the loaded stocks.  ``quickSort`` orders it by
    symbol, ``linearSearch`` scans it, and ``buildBST`` indexes it in an
    AVLTree keyed by symbol (``avltree`` is the tree, ``bst`` its root node)
    for use by ``searchBST``.
    """

    def __init__(self):
        self.stockList = list()
        self.bst = None  # root node of self.avltree; set by buildBST
        self.avltree = None  # AVLTree built by buildBST; None until then
        self.isSorted = False  # set to True by quickSort
        self.size = 0  # len(stockList) after loadData

    def loadData(self, filename: str):
        """Read stocks from ``filename`` and append them to ``stockList``.

        The first line is treated as a header and skipped.  Blank lines (such
        as a trailing empty line) are ignored so they do not become malformed
        Stock objects.  ``size`` is refreshed afterwards.
        """
        with open(filename, 'r') as file:
            next(file, None)  # skip the header row
            for line in file:
                if line.strip():
                    self.stockList.append(Stock(line))
        self.size = len(self.stockList)

    def linearSearch(self, query: str, attribute: str):
        """Scan ``stockList`` and return ``str()`` of the first matching stock.

        ``attribute`` is matched by substring: if it contains 'name' the stock's
        ``sname`` is compared with ``query``; if it contains 'symbol' its
        ``symbol`` is compared.  Returns 'Stock not found' when nothing matches.
        """
        for stock in self.stockList:
            if 'name' in attribute and stock.sname == query:
                return str(stock)
            elif 'symbol' in attribute and stock.symbol == query:
                return str(stock)
        return 'Stock not found'

    def buildBST(self):
        """Build a fresh AVLTree keyed by symbol over ``stockList``; store it and its root."""
        self.avltree = AVLTree()
        for stock in self.stockList:
            self.avltree.put(stock.symbol, stock)
        self.bst = self.avltree.root

    def searchBST(self, query, current='dnode'):
        """Look up symbol ``query`` in the tree.

        ``current`` is the node to start from.  The default marker 'dnode' means
        "start at ``self.bst``", and None means an empty (sub)tree.  Returns the
        matching node's payload (the Stock stored under that symbol), or the
        string 'Stock not found'.
        """
        node = self.bst if current == 'dnode' else current
        while node:
            if node.key == query:
                return node.payload
            node = node.rightChild if node.key < query else node.leftChild
        return 'Stock not found'

    @staticmethod
    def heightHelper(current):
        """Return the height of the subtree at ``current`` (0 for None)."""
        if current is None:
            return 0
        return max(StockLibrary.heightHelper(current.leftChild),
                   StockLibrary.heightHelper(current.rightChild)) + 1

    def quickSort(self):
        """Sort ``stockList`` in place by stock symbol using quicksort and set ``isSorted``."""
        self.quickSortHelper(self.stockList, 0, len(self.stockList) - 1)
        self.isSorted = True

    def quickSortHelper(self, alist, first, last):
        """Quicksort ``alist[first..last]`` (inclusive bounds) in place by ``.symbol``.

        Recurses only into the smaller partition and loops over the larger one,
        so the recursion depth stays O(log n) even for unlucky pivots.
        """
        while first < last:
            splitpoint = self.partition(alist, first, last)
            if splitpoint - first < last - splitpoint:
                self.quickSortHelper(alist, first, splitpoint - 1)
                first = splitpoint + 1
            else:
                self.quickSortHelper(alist, splitpoint + 1, last)
                last = splitpoint - 1

    def partition(self, alist, first, last):
        """Partition ``alist[first..last]`` around a pivot symbol; return the pivot's final index.

        Afterwards, items left of the returned index have symbols <= the pivot
        and items to its right have symbols >= the pivot.
        """
        # Use the middle element as pivot (moved to ``first``).  With a
        # first-element pivot, already-sorted input degrades to O(n^2) time and
        # O(n) recursion depth, which overflows the recursion limit on large lists.
        mid = (first + last) // 2
        alist[first], alist[mid] = alist[mid], alist[first]
        pivotvalue = alist[first].symbol
        leftmark = first + 1
        rightmark = last
        while True:
            while leftmark <= rightmark and alist[leftmark].symbol <= pivotvalue:
                leftmark += 1
            while rightmark >= leftmark and alist[rightmark].symbol >= pivotvalue:
                rightmark -= 1
            if rightmark < leftmark:
                break
            alist[leftmark], alist[rightmark] = alist[rightmark], alist[leftmark]
        alist[first], alist[rightmark] = alist[rightmark], alist[first]
        return rightmark

    def generateRandomQuery(self, num: int, seed: int = 1):
        """Return ``num`` stocks drawn uniformly, with replacement, from ``stockList``.

        A private RNG seeded with ``seed`` makes the selection reproducible
        without reseeding the global ``random`` module.  Returns [] when the
        library is empty.
        """
        if not self.stockList:
            return []
        rng = random.Random(seed)
        count = len(self.stockList)
        # randrange excludes its upper bound; the previous randint(0, size)
        # could return index ``size`` and raise IndexError.
        return [self.stockList[rng.randrange(count)] for _ in range(num)]

    def checkBST(self, root):
        """Check that an in-order walk from ``root`` yields non-decreasing keys.

        Keys that have a ``symbol`` attribute (e.g. Stock objects) are compared
        by that symbol; other keys are compared directly.  Every adjacent pair
        is checked.  Returns ``(number_of_nodes, is_ordered)``.
        """
        nodes = self.inOrder(root)
        keys = [getattr(node.key, 'symbol', node.key) for node in nodes]
        ordered = all(a <= b for a, b in zip(keys, keys[1:]))
        return (len(nodes), ordered)

    def inOrder(self, root):
        """Return the nodes of the tree rooted at ``root`` in in-order sequence.

        Uses an explicit stack, so the result list is built in a single pass
        with no per-level list concatenation.
        """
        result = []
        stack = []
        node = root
        while stack or node:
            while node:
                stack.append(node)
                node = node.leftChild
            node = stack.pop()
            result.append(node)
            node = node.rightChild
        return result
# WRITE YOUR OWN TEST UNDER THIS IF YOU NEED
if __name__ == '__main__':
    # Demo / timing script: load the database, sort it, build the AVL tree,
    # compare linear vs. tree search, then run the Task 11 and Task 10 analyses.
    stockLib = StockLibrary()
    stockLib.loadData("stock_database.csv")
    stockLib.quickSort()
    # Spot-check a few sorted entries (bounded in case the file is short).
    for i in range(0, min(11, len(stockLib.stockList)), 5):
        print(i, stockLib.stockList[i])

    stockLib.buildBST()
    stockLib.avltree.getInfo()
    print(stockLib.checkBST(stockLib.bst))
    print(stockLib.searchBST("GE"))
    print(stockLib.searchBST("BAC"))
    print(stockLib.searchBST("GOOG"))

    queries = stockLib.generateRandomQuery(10)

    t0 = time.perf_counter()
    stockLib.buildBST()
    t1 = time.perf_counter()
    stockLib.avltree.getInfo()
    print("build BST time: " + str(t1 - t0))

    print("\n-------linear search-------")
    # Restart the clock here so getInfo() (an O(n) height walk) and the
    # printing above are not counted as linear-search time.
    t1 = time.perf_counter()
    for query in queries:
        print(stockLib.linearSearch(query.symbol, "symbol"))
    t2 = time.perf_counter()
    print("linear search time: " + str(t2 - t1))

    print("\n---------BST search---------")
    t2 = time.perf_counter()
    for query in queries:
        print(stockLib.searchBST(query.symbol, stockLib.bst))
    t3 = time.perf_counter()
    print("BST search time: " + str(t3 - t2))

    print("\n---------Task 11 search---------")
    # Find the largest and smallest relative change between first and last price.
    # Infinite sentinels plus two independent checks let every stock be a
    # candidate for both extremes; a 0 sentinel left max_stock as None whenever
    # no stock gained.
    max_stock, max_ratio = None, float('-inf')
    min_stock, min_ratio = None, float('inf')
    for s in stockLib.stockList:
        if s.prices[0] < 0.001 or s.prices[-1] < 0.001:
            # Near-zero prices make the ratio meaningless; report and skip.
            print(s, s.prices)
            continue
        ratio = s.prices[-1] / s.prices[0] - 1
        if ratio > max_ratio:
            max_ratio, max_stock = ratio, s
        if ratio < min_ratio:
            min_ratio, min_stock = ratio, s
    if max_stock is None:
        print('no stock with usable prices')
    else:
        print('max', max_ratio, max_stock, max_stock.prices)
        print('min', min_ratio, min_stock, min_stock.prices)

    print("\n---------Task 10 plot---------")
    # Plot the price history of every stock whose company name is the longest.
    longest = max((len(s.sname) for s in stockLib.stockList), default=0)
    candi = [s for s in stockLib.stockList if len(s.sname) == longest]
    print('longest company name', longest, len(candi))
    for stock in candi:
        print(stock)
        plt.plot(stock.prices)
    # Label the x axis with January 2021 trading dates (fetched via yfinance).
    xtick = yf.download("GE", '2021-01-01', '2021-02-01').index
    plt.xticks(np.arange(len(xtick)), [t.strftime('%m-%d') for t in xtick], rotation=45)
    plt.xlabel('date')
    plt.ylabel('price ($)')
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment