Skip to content

Instantly share code, notes, and snippets.

@omaraflak
Created December 25, 2021 14:03
Show Gist options
  • Save omaraflak/d0e5271b4b936a61a5bb3a42ac6304af to your computer and use it in GitHub Desktop.
Save omaraflak/d0e5271b4b936a61a5bb3a42ac6304af to your computer and use it in GitHub Desktop.
Tree fractal
# geometry/point.py
import math
from dataclasses import dataclass
from typing import Optional
@dataclass
class Point:
    """A 2-D point, also used as a 2-D vector.

    Supports component-wise ``+`` / ``-``, scalar ``*`` / ``/`` (point on the
    left), unary ``-``, rotation about an arbitrary centre, and a few static
    geometric helpers.
    """

    x: float = 0
    y: float = 0

    def __add__(self, p: 'Point') -> 'Point':
        """Return the component-wise sum ``self + p``."""
        return Point(self.x + p.x, self.y + p.y)

    def __sub__(self, p: 'Point') -> 'Point':
        """Return the component-wise difference ``self - p``."""
        return Point(self.x - p.x, self.y - p.y)

    def __mul__(self, f: float) -> 'Point':
        """Return this vector scaled by the scalar ``f``."""
        return Point(self.x * f, self.y * f)

    def __truediv__(self, f: float) -> 'Point':
        """Return this vector divided by the scalar ``f``.

        Raises:
            ZeroDivisionError: if ``f`` is zero.
        """
        return self * (1 / f)

    def __neg__(self) -> 'Point':
        """Return the vector pointing in the opposite direction."""
        return Point.origin() - self

    def length(self) -> float:
        """Return the Euclidean norm of this vector."""
        return (self.x ** 2 + self.y ** 2) ** 0.5

    def rotate(self, radians: float, about: Optional['Point'] = None) -> 'Point':
        """Return this point rotated counter-clockwise by ``radians``.

        Args:
            radians: rotation angle; positive values rotate counter-clockwise.
            about: centre of rotation; defaults to the origin.
        """
        if about is None:
            about = Point.origin()
        cos_a = math.cos(radians)
        sin_a = math.sin(radians)
        # Translate so the pivot is at the origin, rotate, then translate back.
        centered = self - about
        rotated = Point(
            centered.x * cos_a - centered.y * sin_a,
            centered.x * sin_a + centered.y * cos_a,
        )
        return rotated + about

    def orthogonal(self) -> 'Point':
        """Return this vector rotated +90 degrees (counter-clockwise)."""
        return Point(-self.y, self.x)

    # Original (misspelled) name, kept so existing callers keep working.
    orthognal = orthogonal

    def unit(self) -> 'Point':
        """Return the unit vector pointing in the same direction.

        Raises:
            ZeroDivisionError: if this is the zero vector.
        """
        norm = self.length()
        if norm == 0:
            raise ZeroDivisionError("cannot normalize a zero-length vector")
        return self / norm

    def to_tuple(self) -> tuple[float, float]:
        """Return the coordinates as an ``(x, y)`` tuple."""
        return (self.x, self.y)

    @staticmethod
    def distance(p1: 'Point', p2: 'Point') -> float:
        """Return the Euclidean distance between ``p1`` and ``p2``."""
        return (p2 - p1).length()

    @staticmethod
    def middle(p1: 'Point', p2: 'Point') -> 'Point':
        """Return the midpoint of ``p1`` and ``p2``."""
        return (p2 + p1) / 2

    @staticmethod
    def deg2rad(degrees: float) -> float:
        """Convert an angle from degrees to radians."""
        return degrees * math.pi / 180

    @staticmethod
    def rad2deg(radians: float) -> float:
        """Convert an angle from radians to degrees."""
        return radians * 180 / math.pi

    @staticmethod
    def origin() -> 'Point':
        """Return the point ``(0, 0)``."""
        return Point()

    @staticmethod
    def x_unit() -> 'Point':
        """Return the unit vector along +x."""
        return Point(1, 0)

    @staticmethod
    def y_unit() -> 'Point':
        """Return the unit vector along +y."""
        return Point(0, 1)
# geometry/segment.py
from dataclasses import dataclass
from geometry.point import Point
@dataclass
class Segment:
    """A directed line segment running from ``start`` to ``end``."""

    start: Point
    end: Point

    def length(self) -> float:
        """Return the Euclidean length of the segment."""
        return Point.distance(self.start, self.end)

    def middle(self) -> Point:
        """Return the midpoint of the segment."""
        return Point.middle(self.start, self.end)

    def rotate_about_start(self, radians: float) -> 'Segment':
        """Return the segment rotated by ``radians`` with ``start`` held fixed."""
        return Segment(self.start, self.end.rotate(radians, self.start))

    def rotate_about_end(self, radians: float) -> 'Segment':
        """Return the segment rotated by ``radians`` with ``end`` held fixed."""
        # The end point is the pivot, so it must stay the segment's end.
        return Segment(self.start.rotate(radians, self.end), self.end)

    def as_vector(self) -> Point:
        """Return the displacement vector ``end - start``."""
        return self.end - self.start

    def unit_vector(self) -> Point:
        """Return the unit direction vector from ``start`` towards ``end``.

        Raises:
            ZeroDivisionError: if the segment has zero length.
        """
        return self.as_vector().unit()

    def unit_segment(self) -> 'Segment':
        """Return a length-1 segment from ``start`` in this segment's direction."""
        return Segment(self.start, self.start + self.unit_vector())

    def to_tuple(self) -> tuple[tuple[float, float], tuple[float, float]]:
        """Return the endpoints as ``((x0, y0), (x1, y1))``."""
        return (self.start.to_tuple(), self.end.to_tuple())
# ./tree.py
import math
from matplotlib import pyplot as plt, cm, collections
from geometry.point import Point
from geometry.segment import Segment
def generate_segments(
    segment: Segment,
    *,
    alpha: float = 0.5,
    beta: float = 0.4,
    gamma: float = 0.3,
    theta: float = math.pi / 4,
    r1: float = 1 / 3,
    r2: float = 5 / 7,
) -> list[Segment]:
    """Return the five child branches grown from ``segment``.

    Two pairs of side branches sprout at fractions ``r1`` and ``r2`` of the
    way along the parent. Each pair is angled ``+theta`` and ``-theta`` from
    the parent's direction. A fifth segment extends the parent past its end.

    Args:
        segment: the parent branch.
        alpha: length of the branches at ``r1``, relative to the parent length.
        beta: length of the branches at ``r2``, relative to the parent length.
        gamma: length of the straight continuation, relative to the parent.
        theta: angle in radians between each side branch and the parent.
        r1: position of the first branch pair along the parent (0 = start).
        r2: position of the second branch pair along the parent (0 = start).

    Returns:
        ``[r1 +theta, r1 -theta, r2 +theta, r2 -theta, continuation]``.

    Raises:
        ZeroDivisionError: if ``segment`` has zero length (no direction).
    """
    vector = segment.as_vector()
    length = segment.length()
    # Local orthonormal frame: i along the parent, j rotated +90 degrees from i.
    i = segment.unit_vector()
    j = i.orthognal()
    # Unit directions of the two side branches (+theta and -theta off the axis).
    plus = i * math.cos(theta) + j * math.sin(theta)
    minus = i * math.cos(theta) - j * math.sin(theta)
    base1 = segment.start + vector * r1
    base2 = segment.start + vector * r2
    return [
        Segment(base1, base1 + plus * length * alpha),
        Segment(base1, base1 + minus * length * alpha),
        Segment(base2, base2 + plus * length * beta),
        Segment(base2, base2 + minus * length * beta),
        Segment(segment.end, segment.end + vector * gamma),
    ]
def generate_tree(
    iterations: int,
    initial_segment: Segment,
    *,
    cmap=cm.viridis,
) -> list[tuple[Segment, tuple[float, float, float, float]]]:
    """Grow a fractal tree and return every branch paired with a colour.

    The trunk is depth 1. Every segment shallower than ``iterations`` is
    expanded with ``generate_segments`` and its children get depth + 1.
    Traversal is depth-first (LIFO), which fixes the draw order of the result.

    Args:
        iterations: number of depth levels, including the trunk; must be >= 1.
        initial_segment: the trunk segment.
        cmap: callable mapping ``depth / iterations`` (in (0, 1]) to a colour;
            defaults to matplotlib's viridis colormap.

    Returns:
        A list of ``(segment, colour)`` pairs. With a matplotlib colormap the
        colour is an RGBA tuple.

    Raises:
        ValueError: if ``iterations`` is less than 1.
    """
    if iterations < 1:
        # Also avoids a ZeroDivisionError in depth / iterations below.
        raise ValueError(f"iterations must be >= 1, got {iterations}")
    stack = [(initial_segment, 1)]
    result = []
    while stack:
        segment, depth = stack.pop()
        result.append((segment, cmap(depth / iterations)))
        if depth < iterations:
            stack.extend((child, depth + 1) for child in generate_segments(segment))
    return result
def plot_segments(
    segments: list[tuple[Segment, tuple[float, float, float, float]]],
) -> None:
    """Draw coloured segments in a new matplotlib figure and show it.

    Args:
        segments: ``(segment, colour)`` pairs, as produced by
            ``generate_tree``. Each colour is passed straight to
            ``LineCollection(colors=...)``.
    """
    collection = collections.LineCollection(
        segments=[segment.to_tuple() for segment, _ in segments],
        colors=[color for _, color in segments],
        linewidths=2,
    )
    _, ax = plt.subplots()
    ax.add_collection(collection)
    # Equal aspect so branch angles are not visually distorted.
    ax.set_aspect("equal")
    ax.margins(0.1)
    plt.show()
if __name__ == "__main__":
    # Vertical unit-length trunk from the origin, grown for six levels.
    trunk = Segment(Point(), Point(0, 1))
    tree = generate_tree(6, trunk)
    plot_segments(tree)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment