Skip to content

Instantly share code, notes, and snippets.

@treyhunner
Last active February 21, 2024 12:06
Show Gist options
  • Save treyhunner/1644c56401103136520ba1535967a735 to your computer and use it in GitHub Desktop.
Save treyhunner/1644c56401103136520ba1535967a735 to your computer and use it in GitHub Desktop.
Pythonic Code Refactoring Session
def add(matrix1, matrix2):
"""Add corresponding numbers in given 2-D matrices."""
combined = []
for i in range(len(matrix1)):
row = []
for j in range(len(matrix1[i])):
row.append(matrix1[i][j] + matrix2[i][j])
combined.append(row)
return combined
from functools import reduce
def all_same(iterable):
for item in iterable:
first_item = item
break
return reduce(lambda x, y: x and y == first_item, iterable, True)
def all_same(iterable):
first_item = next(iter(iterable), None)
return all(
item == first_item
for item in iterable
)
def count_words(string):
"""Return the number of times each word occurs in the string."""
count = {}
for word in string.split():
word = word.strip(',;.!?"()').lower()
if word in count:
count[word] += 1
else:
count[word] = 1
return count
import sys
old_filename = sys.argv[1]
new_filename = sys.argv[2]
old_file = open(old_filename)
rows = [
line.split('|')
for line in old_file.read().splitlines()
]
new_file = open(new_filename, mode='wt', newline='\r\n')
print("\n".join(
",".join(row)
for row in rows
), file=new_file)
old_file.close()
new_file.close()
from argparse import ArgumentParser, FileType
import csv
parser = ArgumentParser()
parser.add_argument('old_file', type=FileType('rt'))
parser.add_argument('new_file', type=FileType('wt'))
args = parser.parse_args()
reader = csv.reader(args.old_file, delimiter='|')
writer = csv.writer(args.new_file, delimiter=',')
writer.writerows(reader)
def get_earliest(date1, date2):
mdy1 = date1.split('/')
mdy2 = date2.split('/')
if mdy1[2] < mdy2[2]:
return date1
elif mdy1[2] > mdy2[2]:
return date2
elif mdy1[0] < mdy2[0]:
return date1
elif mdy1[0] > mdy2[0]:
return date2
elif mdy1[1] < mdy2[1]:
return date1
elif mdy1[1] > mdy2[1]:
return date2
else:
return date1
from datetime import datetime
def get_earliest(date1, date2):
d1 = datetime.strptime(date1, "%m/%d/%Y")
d2 = datetime.strptime(date2, "%m/%d/%Y")
return date1 if (d1 < d2) else date2
def interleave(*iterables):
"""Return iterable of one item at a time from each given iterable."""
interleaved = []
for i in range(len(iterables[0])):
for iterable in iterables:
interleaved.append(iterable[i])
return interleaved
def interleave(*iterables):
"""Return iterable of one item at a time from each given iterable."""
return (
item
for items in zip(*iterables)
for item in items
)
def get_color_ratios(colors, ratios):
"""Return dictionary of color ratios from color and ratio lists."""
assert len(colors) == len(ratios)
color_ratios = {}
for i in range(len(colors)):
color_ratios[colors[i]] = ratios[i]
return color_ratios
if __name__ == "__main__":
test_colors = ["red", "green", "blue"]
test_ratios = [0.1, 0.6, 0.3]
combined_dict = {'red': 0.1, 'green': 0.6, 'blue': 0.3}
assert get_color_ratios(test_colors, test_ratios) == combined_dict
def get_color_ratios(colors, ratios):
"""Return dictionary of color ratios from color and ratio lists."""
assert len(colors) == len(ratios)
return dict(zip(colors, ratios))
if __name__ == "__main__":
test_colors = ["red", "green", "blue"]
test_ratios = [0.1, 0.6, 0.3]
combined_dict = {'red': 0.1, 'green': 0.6, 'blue': 0.3}
assert get_color_ratios(test_colors, test_ratios) == combined_dict
print("Tests passed")
import csv
import sys
filename = sys.argv[1]
column_numbers = sys.argv[2:]
numbers = []
for n in column_numbers:
numbers.append(int(n))
with open(filename) as csv_file:
reader = csv.reader(csv_file)
rows = [row for row in reader]
sorted_rows = sorted(rows, key=lambda row: [row[n] for n in numbers])
writer = csv.writer(sys.stdout)
writer.writerows(sorted_rows)
def sum_timestamps(timestamps):
total_time = 0
for time in timestamps:
total_time += parse_time(time)
return format_time(total_time)
def parse_time(time_string):
sections = time_string.split(':')
if len(sections) == 2:
seconds = int(sections[1])
minutes = int(sections[0])
hours = 0
else:
seconds = int(sections[2])
minutes = int(sections[1])
hours = int(sections[0])
return hours*3600 + minutes*60 + seconds
def format_time(total_seconds):
hours = str(int(total_seconds / 3600))
minutes = str(int(total_seconds / 60) % 60)
seconds = str(total_seconds % 60)
if len(minutes) < 2 and hours != "0":
minutes = "0" + minutes
if len(seconds) < 2:
seconds = "0" + seconds
time = minutes + ":" + seconds
if hours != "0":
time = hours + ":" + time
return time
import re
TIME_RE = re.compile(r'''
^
(?: # Optional hours
( \d+ )
:
)?
( \d+ ) # Minutes
:
( \d+ ) # Seconds
$
''', re.VERBOSE)
def sum_timestamps(timestamps):
total_time = sum(
parse_time(time)
for time in timestamps
)
return format_time(total_time)
def parse_time(time_string):
hours, minutes, seconds = TIME_RE.search(time_string).groups()
if hours is None:
hours = 0
return int(hours)*3600 + int(minutes)*60 + int(seconds)
def format_time(total_seconds):
minutes, seconds = divmod(total_seconds, 60)
hours, minutes = divmod(minutes, 60)
if hours > 0:
return f"{hours}:{minutes:02d}:{seconds:02d}"
else:
return f"{minutes}:{seconds:02d}"
from copy import deepcopy
import unittest
from add import add
class AddTests(unittest.TestCase):
"""Tests for add."""
def test_single_items(self):
self.assertEqual(add([[5]], [[-2]]), [[3]])
def test_two_by_two_matrixes(self):
m1 = [[6, 6], [3, 1]]
m2 = [[1, 2], [3, 4]]
m3 = [[7, 8], [6, 5]]
self.assertEqual(add(m1, m2), m3)
def test_two_by_three_matrixes(self):
m1 = [[1, 2, 3], [4, 5, 6]]
m2 = [[-1, -2, -3], [-4, -5, -6]]
m3 = [[0, 0, 0], [0, 0, 0]]
self.assertEqual(add(m1, m2), m3)
def test_input_unchanged(self):
m1 = [[6, 6], [3, 1]]
m2 = [[1, 2], [3, 4]]
m1_original = deepcopy(m1)
m2_original = deepcopy(m2)
add(m1, m2)
self.assertEqual(m1, m1_original)
self.assertEqual(m2, m2_original)
if __name__ == "__main__":
unittest.main()
import unittest
from all_same import all_same
class AllSameTests(unittest.TestCase):
"""Tests for all_same."""
def test_one_item_number(self):
self.assertTrue(all_same([4]))
self.assertTrue(all_same([0]))
self.assertTrue(all_same([-1]))
def test_one_string(self):
self.assertTrue(all_same(['hello']))
def test_one_none_value(self):
self.assertTrue(all_same([None]))
def test_one_tuple(self):
self.assertTrue(all_same([()]))
self.assertTrue(all_same([(1,)]))
self.assertTrue(all_same([(1, 2)]))
def test_empty_sequence(self):
self.assertTrue(all_same([]))
self.assertTrue(all_same(()))
self.assertTrue(all_same(''))
def test_two_same_item(self):
self.assertTrue(all_same([1, 1]))
self.assertTrue(all_same([0, 0]))
self.assertTrue(all_same(['hello', 'hello']))
self.assertTrue(all_same([-1, -1]))
self.assertTrue(all_same([(1, 2), (1, 2)]))
self.assertTrue(all_same([None, None]))
def test_two_different_items(self):
self.assertFalse(all_same(['hello', 'hi']))
self.assertFalse(all_same([-1, 1]))
self.assertFalse(all_same([-1, 'hi']))
self.assertFalse(all_same([(1, 3), (1, 2)]))
self.assertFalse(all_same(['hello', (4, 5)]))
self.assertFalse(all_same([4, None]))
self.assertFalse(all_same([None, 4]))
def test_many_items(self):
self.assertTrue(all_same([1, 1, 1, 1, 1, 1]))
self.assertFalse(all_same([1, 1, 1, 1, 2, 1]))
self.assertFalse(all_same(['hi', 'hello', 'hey']))
self.assertFalse(all_same(['hello', 'hella', 'hello']))
self.assertTrue(all_same(['hi', 'hi', 'hi', 'hi', 'hi']))
self.assertTrue(all_same(['hello', 'hello', 'hello']))
self.assertTrue(all_same([(1, 2, 3), (1, 2, 3), (1, 2, 3)]))
self.assertFalse(all_same([(1, 2, 3), (1, 2, 3), (1, 4, 3)]))
def test_nonhashable_values(self):
self.assertFalse(all_same([['hi', 'hi'], ['hi', 'hi', 'hi']]))
self.assertTrue(all_same([['hi', 'hi'], ['hi', 'hi']]))
self.assertTrue(all_same([{1: 2}, {1: 2}]))
self.assertFalse(all_same([{1: 2}, {1: 3}]))
def test_nonsequences(self):
numbers = [1, 3, 5, 7, 9]
self.assertTrue(all_same({1}))
self.assertFalse(all_same({1, 2}))
self.assertFalse(all_same(n**2 for n in numbers))
self.assertTrue(all_same(n % 2 for n in numbers))
@unittest.expectedFailure
def test_return_early(self):
self.assertFalse(all_same(n**2 for n in [2, 3, {}]))
from itertools import count
self.assertFalse(all_same(count()))
if __name__ == "__main__":
unittest.main(verbosity=2)
import unittest
from count_words import count_words
class CountWordsTests(unittest.TestCase):
"""Tests for count_words."""
def test_simple_sentence(self):
actual = count_words("oh what a day what a lovely day")
expected = {'oh': 1, 'what': 2, 'a': 2, 'day': 2, 'lovely': 1}
self.assertEqual(actual, expected)
def test_apostrophe(self):
actual = count_words("don't stop believing")
expected = {"don't": 1, 'stop': 1, 'believing': 1}
self.assertEqual(actual, expected)
def test_capitalization(self):
actual = count_words("Oh what a day what a lovely day")
expected = {'oh': 1, 'what': 2, 'a': 2, 'day': 2, 'lovely': 1}
self.assertEqual(actual, expected)
def test_symbols(self):
actual = count_words("Oh what a day, what a lovely day!")
expected = {'oh': 1, 'what': 2, 'a': 2, 'day': 2, 'lovely': 1}
self.assertEqual(actual, expected)
if __name__ == "__main__":
unittest.main()
from contextlib import contextmanager, redirect_stdout
from io import StringIO
import imp
import os
import sys
from textwrap import dedent
from tempfile import NamedTemporaryFile
import unittest
class FixCSVTests(unittest.TestCase):
"""Tests for fix_csv.py"""
maxDiff = None
def test_pipe_file_to_csv_file(self):
old_contents = dedent("""
2012|Lexus|LFA
2009|GMC|Yukon XL 1500
1965|Ford|Mustang
2005|Hyundai|Sonata
1995|Mercedes-Benz|C-Class
""").lstrip()
expected = dedent("""
2012,Lexus,LFA
2009,GMC,Yukon XL 1500
1965,Ford,Mustang
2005,Hyundai,Sonata
1995,Mercedes-Benz,C-Class
""").lstrip()
with make_file(old_contents) as old, make_file("") as new:
output = run_program('fix_csv.py', args=[old, new])
with open(new) as new_file:
new_contents = new_file.read()
self.assertEqual(expected, new_contents)
self.assertEqual("", output)
def test_original_file_is_unchanged(self):
old_contents = dedent("""
2012|Lexus|LFA
2009|GMC|Yukon XL 1500
""").lstrip()
with make_file(old_contents) as old, make_file("") as new:
run_program('fix_csv.py', args=[old, new])
with open(old) as old_file:
contents = old_file.read()
self.assertEqual(old_contents, contents)
@unittest.expectedFailure
def test_delimiter_in_output(self):
old_contents = dedent("""
02|Waylon Jennings|Honky Tonk Heroes (Like Me)
04|Kris Kristofferson|To Beat The Devil
11|Johnny Cash|Folsom Prison Blues
13|Billy Joe Shaver|Low Down Freedom
21|Hank Williams III|Mississippi Mud
22|David Allan Coe|Willie, Waylon, And Me
24|Bob Dylan|House Of The Risin' Sun
""").lstrip()
expected = dedent("""
02,Waylon Jennings,Honky Tonk Heroes (Like Me)
04,Kris Kristofferson,To Beat The Devil
11,Johnny Cash,Folsom Prison Blues
13,Billy Joe Shaver,Low Down Freedom
21,Hank Williams III,Mississippi Mud
22,David Allan Coe,"Willie, Waylon, And Me"
24,Bob Dylan,House Of The Risin' Sun
""").lstrip()
with make_file(old_contents) as old, make_file("") as new:
output = run_program('fix_csv.py', args=[old, new])
with open(new) as new_file:
new_contents = new_file.read()
self.assertEqual(expected, new_contents)
self.assertEqual("", output)
@unittest.expectedFailure
def test_call_with_too_many_files(self):
with make_file("") as old, make_file("") as new:
with self.assertRaises(BaseException):
run_program('fix_csv.py', args=[old, new, old])
def run_program(path, args=[]):
"""
Run program at given path with given arguments.
If raises is specified, ensure the given exception is raised.
"""
old_args = sys.argv
assert all(isinstance(a, str) for a in args)
try:
sys.argv = [path] + args
with redirect_stdout(StringIO()) as output:
try:
if '__main__' in sys.modules:
del sys.modules['__main__']
imp.load_source('__main__', path)
except SystemExit as e:
if e.args != (0,):
raise
del sys.modules['__main__']
return output.getvalue()
finally:
sys.argv = old_args
@contextmanager
def make_file(contents=None):
"""Context manager providing name of a file containing given contents."""
with NamedTemporaryFile(mode='wt', encoding='utf-8', delete=False) as f:
if contents:
f.write(contents)
try:
yield f.name
finally:
os.remove(f.name)
if __name__ == "__main__":
unittest.main(verbosity=2)
import unittest
from get_earliest import get_earliest
class GetEarliestTests(unittest.TestCase):
"""Tests for get_earliest."""
def test_same_month_and_day(self):
newer = "01/27/1832"
older = "01/27/1756"
self.assertEqual(get_earliest(newer, older), older)
def test_february_29th(self):
newer = "02/29/1972"
older = "12/21/1946"
self.assertEqual(get_earliest(newer, older), older)
def test_smaller_month_bigger_day(self):
newer = "03/21/1946"
older = "02/24/1946"
self.assertEqual(get_earliest(older, newer), older)
def test_same_month_and_year(self):
newer = "06/24/1958"
older = "06/21/1958"
self.assertEqual(get_earliest(older, newer), older)
if __name__ == "__main__":
unittest.main(verbosity=2)
from itertools import count
import unittest
from interleave import interleave
class InterleaveTests(unittest.TestCase):
"""Tests for interleave."""
def assertIterableEqual(self, iterable1, iterable2):
self.assertEqual(list(iterable1), list(iterable2))
def test_empty_lists(self):
self.assertIterableEqual(interleave([], []), [])
def test_single_item_each(self):
self.assertIterableEqual(interleave([1], [2]), [1, 2])
def test_two_items_each(self):
self.assertIterableEqual(interleave([1, 2], [3, 4]), [1, 3, 2, 4])
def test_four_items_each(self):
in1 = [1, 2, 3, 4]
in2 = [5, 6, 7, 8]
out = [1, 5, 2, 6, 3, 7, 4, 8]
self.assertIterableEqual(interleave(in1, in2), out)
def test_none_value(self):
in1 = [1, 2, 3, None]
in2 = [4, 5, 6, 7]
out = [1, 4, 2, 5, 3, 6, None, 7]
self.assertIterableEqual(interleave(in1, in2), out)
# To test the Bonus part of this exercise, comment out the following line
@unittest.expectedFailure
def test_non_sequences(self):
in1 = [1, 2, 3, 4]
in2 = (n**2 for n in in1)
out = [1, 1, 2, 4, 3, 9, 4, 16]
self.assertIterableEqual(interleave(in1, in2), out)
if __name__ == "__main__":
unittest.main()
from contextlib import contextmanager, redirect_stdout
from io import StringIO
import imp
import os
import sys
from textwrap import dedent
from tempfile import NamedTemporaryFile
import unittest
class SortByColumnTests(unittest.TestCase):
"""Tests for sort_by_column.py"""
maxDiff = None
def test_sort_by_first_column(self):
contents = dedent("""
2012,Lexus,LFA
2009,GMC,Yukon XL 1500
1965,Ford,Mustang
2005,Hyundai,Sonata
1995,Mercedes-Benz,C-Class
""").lstrip()
expected = dedent("""
1965,Ford,Mustang
1995,Mercedes-Benz,C-Class
2005,Hyundai,Sonata
2009,GMC,Yukon XL 1500
2012,Lexus,LFA
""").lstrip().replace('\n', '\r\n')
with make_file(contents) as csv_file:
output = run_program('sort_by_column.py', args=[csv_file, '0'])
self.assertEqual(expected, output)
def test_sort_by_second_column(self):
contents = dedent("""
2012,Lexus,LFA
2009,GMC,Yukon XL 1500
1965,Ford,Mustang
2005,Hyundai,Sonata
1995,Mercedes-Benz,C-Class
""").lstrip()
expected = dedent("""
1965,Ford,Mustang
2009,GMC,Yukon XL 1500
2005,Hyundai,Sonata
2012,Lexus,LFA
1995,Mercedes-Benz,C-Class
""").lstrip().replace('\n', '\r\n')
with make_file(contents) as csv_file:
output = run_program('sort_by_column.py', args=[csv_file, '1'])
self.assertEqual(expected, output)
def test_original_file_is_unchanged(self):
old_contents = dedent("""
2012,Lexus,LFA
2009,GMC,Yukon XL 1500
""").lstrip()
with make_file(old_contents) as filename:
run_program('sort_by_column.py', args=[filename, '0'])
with open(filename) as csv_file:
new_contents = csv_file.read()
self.assertEqual(old_contents, new_contents)
def test_sorting_with_commas(self):
contents = dedent("""
"Hughes, John",Baby's Day Out
"Hughes, John",The Breakfast Club
"Hughes, Langston",A Dream Deferred
"Hughes, Langston",Dreams
""").lstrip().replace('\n', '\r\n')
expected = dedent("""
"Hughes, Langston",A Dream Deferred
"Hughes, John",Baby's Day Out
"Hughes, Langston",Dreams
"Hughes, John",The Breakfast Club
""").lstrip().replace('\n', '\r\n')
with make_file(contents) as old:
output = run_program('sort_by_column.py', args=[old, '1'])
self.assertEqual(expected, output)
def test_sort_by_one_column_only(self):
contents = dedent("""
11,Johnny Cash,Folsom Prison Blues
13,Billy Joe Shaver,Low Down Freedom
2,Waylon Jennings,Honky Tonk Heroes (Like Me)
2,Hank Williams III,Mississippi Mud
4,Kris Kristofferson,To Beat The Devil
22,David Allan Coe,"Willie, Waylon, And Me"
4,Bob Dylan,House Of The Risin' Sun
""").lstrip().replace('\n', '\r\n')
expected = dedent("""
11,Johnny Cash,Folsom Prison Blues
13,Billy Joe Shaver,Low Down Freedom
2,Waylon Jennings,Honky Tonk Heroes (Like Me)
2,Hank Williams III,Mississippi Mud
22,David Allan Coe,"Willie, Waylon, And Me"
4,Kris Kristofferson,To Beat The Devil
4,Bob Dylan,House Of The Risin' Sun
""").lstrip().replace('\n', '\r\n')
with make_file(contents) as old:
output = run_program('sort_by_column.py', args=[old, '0'])
self.assertEqual(expected, output)
def test_sort_by_multiple_columns(self):
contents = dedent("""
2005,Lexus,LFA
2009,GMC,Yukon XL 1500
1995,Ford,Mustang
2005,Hyundai,Sonata
1995,Mercedes-Benz,C-Class
""").lstrip()
expected = dedent("""
1995,Mercedes-Benz,C-Class
1995,Ford,Mustang
2005,Lexus,LFA
2005,Hyundai,Sonata
2009,GMC,Yukon XL 1500
""").lstrip().replace('\n', '\r\n')
with make_file(contents) as name:
output = run_program('sort_by_column.py', args=[name, '0', '2'])
self.assertEqual(expected, output)
def run_program(path, args=[]):
"""
Run program at given path with given arguments.
If raises is specified, ensure the given exception is raised.
"""
old_args = sys.argv
assert all(isinstance(a, str) for a in args)
try:
sys.argv = [path] + args
with redirect_stdout(StringIO()) as output:
try:
if '__main__' in sys.modules:
del sys.modules['__main__']
imp.load_source('__main__', path)
except SystemExit as e:
if e.args != (0,):
raise
return output.getvalue()
finally:
sys.argv = old_args
@contextmanager
def make_file(contents=None):
"""Context manager providing name of a file containing given contents."""
with NamedTemporaryFile(mode='wt', encoding='utf-8', delete=False) as f:
if contents:
f.write(contents)
try:
yield f.name
finally:
os.remove(f.name)
if __name__ == "__main__":
unittest.main(verbosity=2)
import unittest
from sum_timestamps import sum_timestamps
class SumTimeStampsTests(unittest.TestCase):
"""Tests for sum_timestamps."""
def test_single_timestamp(self):
self.assertEqual(sum_timestamps(['02:01']), '2:01')
self.assertEqual(sum_timestamps(['2:01']), '2:01')
def test_multiple_timestamps(self):
self.assertEqual(sum_timestamps(['02:01', '04:05']), '6:06')
self.assertEqual(sum_timestamps(['9:38', '4:45', '3:52']), '18:15')
def test_many_timestamps(self):
times = [
'3:52', '3:29', '3:23', '4:05', '3:24', '2:29', '2:16', '2:44',
'1:58', '3:21', '2:51', '2:53', '2:51', '3:32', '3:20', '2:40',
'2:50', '3:24', '3:22', '0:42']
self.assertEqual(sum_timestamps(times), '59:26')
def test_no_minutes(self):
self.assertEqual(sum_timestamps(['00:01', '00:05']), '0:06')
self.assertEqual(sum_timestamps(['0:38', '0:15']), '0:53')
# To test the Bonus part of this exercise, comment out the following line
@unittest.expectedFailure
def test_timestamps_over_an_hour(self):
times = [
'3:52', '3:29', '3:23', '4:05', '3:24', '2:29', '2:16', '2:44',
'1:58', '3:21', '2:51', '2:53', '2:51', '3:32', '3:20', '2:40',
'2:50', '3:24', '1:20', '3:22', '3:26', '0:42', '5:20']
self.assertEqual(sum_timestamps(times), '1:09:32')
times2 = [
'50:52', '34:29', '36:23', '47:05', '32:24', '20:29', '22:16',
'23:44', '19:58', '30:21', '24:51', '22:53', '23:51', '34:32',
'36:20', '25:40', '27:50', '39:24', '18:20', '36:22', '4:00',
]
self.assertEqual(sum_timestamps(times2), '10:12:04')
# To test the Bonus part of this exercise, comment out the following line
@unittest.expectedFailure
def test_allow_optional_hour(self):
self.assertEqual(sum_timestamps(['1:02:01', '04:05']), '1:06:06')
self.assertEqual(
sum_timestamps(['9:05:00', '4:45:10', '3:52']),
'13:54:02',
)
if __name__ == "__main__":
unittest.main(verbosity=2)
import unittest
from vector import Vector
class VectorTests(unittest.TestCase):
"""Tests for Vector."""
def test_attributes(self):
v = Vector(1, 2, 3)
self.assertEqual((v.x, v.y, v.z), (1, 2, 3))
def test_equality_and_inequality(self):
self.assertNotEqual(Vector(1, 2, 3), Vector(1, 2, 4))
self.assertEqual(Vector(1, 2, 3), Vector(1, 2, 3))
self.assertFalse(Vector(1, 2, 3) != Vector(1, 2, 3))
v1 = Vector(1, 2, 3)
v2 = Vector(1, 2, 4)
v3 = Vector(1, 2, 3)
self.assertNotEqual(v1, v2)
self.assertEqual(v1, v3)
def test_iterable_vector(self):
x, y, z = Vector(x=1, y=2, z=3)
self.assertEqual((x, y, z), (1, 2, 3))
def test_no_weird_extras(self):
v1 = Vector(1, 2, 3)
v2 = Vector(4, 5, 6)
with self.assertRaises(TypeError):
len(v1)
with self.assertRaises(TypeError):
v1 < v2
with self.assertRaises(TypeError):
v1 > v2
with self.assertRaises(TypeError):
v1 <= v2
with self.assertRaises(TypeError):
v1 >= v2
with self.assertRaises(TypeError):
v1 + (1, 2, 3)
with self.assertRaises(TypeError):
(1, 2, 3) + v1
with self.assertRaises(TypeError):
v1 - (1, 2, 3)
with self.assertRaises(TypeError):
v1 * 'a'
with self.assertRaises(TypeError):
v1 / v2
def test_memory_efficient_attributes(self):
v = Vector(1, 2, 3)
with self.assertRaises(AttributeError):
v.a = 3
with self.assertRaises(AttributeError):
v.__dict__
def test_shifting(self):
v1 = Vector(1, 2, 3)
v2 = Vector(4, 5, 6)
v3 = v2 + v1
v4 = v3 - v1
self.assertEqual((v3.x, v3.y, v3.z), (5, 7, 9))
self.assertEqual((v4.x, v4.y, v4.z), (v2.x, v2.y, v2.z))
def test_scaling(self):
v1 = Vector(1, 2, 3)
v2 = Vector(4, 5, 6)
v3 = v1 * 4
v4 = 2 * v2
self.assertEqual((v3.x, v3.y, v3.z), (4, 8, 12))
self.assertEqual((v4.x, v4.y, v4.z), (8, 10, 12))
if __name__ == "__main__":
unittest.main(verbosity=2)
class Vector:
__slots__ = 'x', 'y', 'z'
def __init__(self, x, y, z):
self.x = x
self.y = y
self.z = z
def __iter__(self):
return iter([self.x, self.y, self.z])
def __eq__(self, other):
return self.x == other.x and self.y == other.y and self.z == other.z
def __ne__(self, other):
return self.x != other.x and self.y != other.y and self.z != other.z
def __add__(self, other):
if not isinstance(other, Vector):
return NotImplemented
return Vector(self.x + other.x, self.y + other.y, self.z + other.z)
def __sub__(self, other):
if not isinstance(other, Vector):
return NotImplemented
return Vector(self.x - other.x, self.y - other.y, self.z - other.z)
def __mul__(self, scalar):
if not isinstance(scalar, (int, float)):
return NotImplemented
return Vector(self.x * scalar, self.y * scalar, self.z * scalar)
def __rmul__(self, scalar):
if not isinstance(scalar, (int, float)):
return NotImplemented
return Vector(self.x * scalar, self.y * scalar, self.z * scalar)
class Vector:
__slots__ = 'x', 'y', 'z'
def __init__(self, x, y, z):
self.x, self.y, self.z = x, y, z
def __iter__(self):
yield from (self.x, self.y, self.z)
def __eq__(self, other):
return tuple(self) == tuple(other)
def __add__(self, other):
if not isinstance(other, Vector):
return NotImplemented
x1, y1, z1 = self
x2, y2, z2 = other
return Vector(x1+x2, y1+y2, z1+z2)
def __sub__(self, other):
if not isinstance(other, Vector):
return NotImplemented
x1, y1, z1 = self
x2, y2, z2 = other
return Vector(x1-x2, y1-y2, z1-z2)
def __mul__(self, scalar):
if not isinstance(scalar, (int, float)):
return NotImplemented
x, y, z = self
return Vector(x*scalar, y*scalar, z*scalar)
def __rmul__(self, scalar):
if not isinstance(scalar, (int, float)):
return NotImplemented
return Vector(self.x * scalar, self.y * scalar, self.z * scalar)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment