@FlyingFathead
Last active December 13, 2023 19:56
CUDA-optimized code for generating a PyTorch Dataset of Fibonacci primes
# file under: laziness check, December 2023...
import torch
from torch.utils.data import Dataset
from numba import cuda
import numpy as np

# CPU function to generate Fibonacci numbers within uint64 range
def generate_fibonacci_numbers(max_length):
    fib_numbers = np.zeros(max_length, dtype=np.uint64)
    fib_numbers[1] = 1
    for i in range(2, max_length):
        # Add as Python ints: NumPy uint64 addition wraps silently on
        # overflow, so the range check below would never fire otherwise.
        next_fib = int(fib_numbers[i - 1]) + int(fib_numbers[i - 2])
        if next_fib > np.iinfo(np.uint64).max:
            break  # Stop to prevent overflow
        fib_numbers[i] = next_fib
    return fib_numbers[fib_numbers > 0]  # Filter out the unused slots (and F(0))
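
# Quick host-side sanity check of the generator (an addition, not part of
# the original gist): the expected values are the standard Fibonacci
# sequence with the zeros filtered out.
assert list(generate_fibonacci_numbers(10)) == [1, 1, 2, 3, 5, 8, 13, 21, 34]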

# CUDA kernel to check for primality via 6k +/- 1 trial division
@cuda.jit
def check_prime_kernel(numbers, prime_flags, max_length):
    idx = cuda.grid(1)
    if idx >= max_length:
        return
    n = numbers[idx]
    if n <= np.uint64(1):
        prime_flags[idx] = False
    elif n <= np.uint64(3):
        prime_flags[idx] = True
    elif n % np.uint64(2) == 0 or n % np.uint64(3) == 0:
        prime_flags[idx] = False
    else:
        # Keep the trial divisor in uint64: mixing uint64 with signed ints
        # makes Numba promote to float64, which silently loses precision
        # for values above 2**53.
        i = np.uint64(5)
        while i * i <= n:
            if n % i == 0 or n % (i + np.uint64(2)) == 0:
                prime_flags[idx] = False
                return
            i += np.uint64(6)
        prime_flags[idx] = True
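
# A host-side reference implementation of the same 6k +/- 1 trial division
# (an addition, not part of the original gist), handy for spot-checking the
# kernel's output on arbitrary Python ints:
def is_prime_host(n: int) -> bool:
    if n <= 1:
        return False
    if n <= 3:
        return True
    if n % 2 == 0 or n % 3 == 0:
        return False
    i = 5
    while i * i <= n:
        if n % i == 0 or n % (i + 2) == 0:
            return False
        i += 6
    return True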

class FibonacciPrimeDataset(Dataset):
    def __init__(self, max_length):
        # Generate Fibonacci numbers on the CPU
        fib_numbers_host = generate_fibonacci_numbers(max_length)
        actual_length = len(fib_numbers_host)  # Actual number of Fibonacci numbers generated

        # Copy the inputs to the device; the flags array only needs to be
        # allocated there, not copied from uninitialized host memory
        fib_numbers_device = cuda.to_device(fib_numbers_host)
        prime_flags_device = cuda.device_array(actual_length, dtype=np.bool_)

        # Set up the threads and blocks
        threads_per_block = 128
        blocks_per_grid = (actual_length + (threads_per_block - 1)) // threads_per_block

        # Run the kernel to check for primes
        check_prime_kernel[blocks_per_grid, threads_per_block](
            fib_numbers_device, prime_flags_device, actual_length
        )

        # Copy the results back to the host
        prime_flags_host = prime_flags_device.copy_to_host()

        # Filter to get only the prime Fibonacci numbers
        self.prime_fibs = fib_numbers_host[prime_flags_host]

    def __len__(self):
        return len(self.prime_fibs)

    def __getitem__(self, idx):
        print(f"Accessing index {idx}", flush=True)
        # Return as torch.int64, since PyTorch does not support uint64
        return torch.tensor(self.prime_fibs[idx], dtype=torch.int64)

# Example usage: 94 covers F(0)..F(93), the largest Fibonacci number that
# still fits in uint64; adjust as needed
dataset = FibonacciPrimeDataset(94)
for i in dataset:
    print(i.item(), flush=True)
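
# The dataset also plugs into a standard DataLoader (an illustrative
# addition, not part of the original gist):
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=4)
for batch in loader:
    print(batch.tolist(), flush=True)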