Last active
December 13, 2023 19:56
-
-
Save FlyingFathead/bb222836d99596141348a3ad152c816c to your computer and use it in GitHub Desktop.
CUDA-optimized code for generating a PyTorch Dataset of Fibonacci primes
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# file under: laziness check, December 2023... | |
import torch | |
from torch.utils.data import Dataset | |
from numba import cuda | |
import numpy as np | |
# CPU function to generate Fibonacci numbers within uint64 range
def generate_fibonacci_numbers(max_length):
    """Return the positive Fibonacci numbers F(1) .. F(max_length - 1) as uint64.

    Generation stops early at the first term that would exceed the uint64
    maximum, so the result never contains wrapped-around values. Since
    F(1) == F(2) == 1, the value ``1`` appears twice.

    Args:
        max_length: number of terms F(0) .. F(max_length - 1) to consider.
            Values below 2 yield an empty array.

    Returns:
        1-D ``np.uint64`` array of the positive terms, in order.
    """
    if max_length < 2:
        # Only F(0) == 0 (or nothing) is in range, and zeros are filtered out.
        return np.zeros(0, dtype=np.uint64)
    uint64_max = int(np.iinfo(np.uint64).max)
    fib_numbers = np.zeros(max_length, dtype=np.uint64)
    fib_numbers[1] = 1
    # Run the recurrence on Python ints (arbitrary precision): adding two
    # np.uint64 values wraps modulo 2**64, so a "> max" test on that sum
    # could never detect the overflow.
    prev, curr = 0, 1
    for i in range(2, max_length):
        prev, curr = curr, prev + curr
        if curr > uint64_max:
            break  # Stop before the first term that does not fit in uint64
        fib_numbers[i] = curr
    # Drop F(0) and any unfilled tail left by an early stop.
    return fib_numbers[fib_numbers > 0]
# CUDA kernel to check for primality
@cuda.jit
def check_prime_kernel(numbers, prime_flags, max_length):
    """Flag which entries of ``numbers`` are prime, one GPU thread per entry.

    Thread ``idx`` (from ``cuda.grid(1)``) tests ``numbers[idx]`` by trial
    division over candidates of the form 6k - 1 / 6k + 1 and writes the
    verdict to ``prime_flags[idx]``.

    Args:
        numbers: device array of candidates (the host launcher in this file
            passes a uint64 array).
        prime_flags: device output array; every index below ``max_length``
            receives True or False.
        max_length: number of entries to process; surplus threads from a
            rounded-up grid return immediately.

    NOTE(review): ``i`` starts as a plain (signed) integer literal while the
    elements of ``numbers`` are presumably uint64. Numba's typing of mixed
    signed/unsigned arithmetic may convert operands to a signed or floating
    type, which could make ``i * i <= n`` and ``n % i`` inexact for values
    near the top of the uint64 range -- confirm before relying on this for
    very large inputs.
    """
    idx = cuda.grid(1)
    if idx >= max_length:
        return  # Padding thread beyond the end of the data
    n = numbers[idx]
    if n <= 1:
        prime_flags[idx] = False
    elif n <= 3:
        prime_flags[idx] = True  # 2 and 3
    elif n % 2 == 0 or n % 3 == 0:
        prime_flags[idx] = False
    else:
        # Any prime > 3 is 6k - 1 or 6k + 1, so test i and i + 2 for i = 5, 11, 17, ...
        i = 5
        while i * i <= n:
            if n % i == 0 or n % (i + 2) == 0:
                prime_flags[idx] = False
                return
            i += 6
        # For a prime n the loop above runs about sqrt(n) / 6 times on this one thread.
        prime_flags[idx] = True
class FibonacciPrimeDataset(Dataset):
    """PyTorch ``Dataset`` of the Fibonacci numbers that are prime.

    Candidates are generated on the CPU by
    ``generate_fibonacci_numbers(max_length)``. Primality is tested on the GPU
    by ``check_prime_kernel``, one CUDA thread per candidate. Only the primes
    are kept, in generation order, in ``self.prime_fibs``.

    Args:
        max_length: forwarded to ``generate_fibonacci_numbers``.
        threads_per_block: CUDA block size for the kernel launch (keyword-only).
    """

    def __init__(self, max_length, *, threads_per_block=128):
        # Generate Fibonacci numbers on the CPU
        fib_numbers_host = generate_fibonacci_numbers(max_length)
        actual_length = len(fib_numbers_host)  # May be shorter than max_length
        if actual_length == 0:
            # CUDA does not accept a zero-sized grid; there is nothing to test anyway.
            self.prime_fibs = fib_numbers_host
            return

        fib_numbers_device = cuda.to_device(fib_numbers_host)
        # The kernel writes every flag, so allocate the output directly on the
        # device instead of copying an uninitialized host buffer over.
        prime_flags_device = cuda.device_array(actual_length, dtype=np.bool_)

        # Round up so every candidate gets a thread; the kernel skips the surplus.
        blocks_per_grid = (actual_length + threads_per_block - 1) // threads_per_block
        check_prime_kernel[blocks_per_grid, threads_per_block](
            fib_numbers_device, prime_flags_device, actual_length
        )

        prime_flags_host = prime_flags_device.copy_to_host()
        # Keep only the prime Fibonacci numbers
        self.prime_fibs = fib_numbers_host[prime_flags_host]

    def __len__(self):
        return len(self.prime_fibs)

    def __getitem__(self, idx):
        print(f"Accessing index {idx}", flush=True)
        # PyTorch lacks general uint64 support, so return int64. Converting the
        # numpy uint64 scalar to a Python int first makes the conversion
        # explicit. The largest Fibonacci prime below 2**64 is F(83) ~ 9.9e16,
        # so every stored value fits in int64.
        return torch.tensor(int(self.prime_fibs[idx]), dtype=torch.int64)
# Example usage (requires a CUDA-capable GPU). Guarded so that importing this
# module does not launch GPU work as a side effect.
if __name__ == "__main__":
    # 94 covers F(1) .. F(93); F(94) no longer fits in uint64.
    dataset = FibonacciPrimeDataset(94)  # Adjust the number as needed
    for prime in dataset:
        print(prime.item(), flush=True)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment