Gist by @JohnAtl, created January 5, 2024.
Distrobox with Tensorflow and Nvidia support
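This gist bundles three pieces: a shell script that sanity-checks the CUDA/cuDNN libraries inside the box, a distrobox-assemble manifest for an NVIDIA TensorFlow container, and a TensorFlow MNIST training script for a quick GPU smoke test. A minimal way to build and enter the box might look like the following (a sketch; nvbox.ini is an assumed filename for the manifest further below):

# Create the "nvbox" container from the manifest, then enter it.
distrobox assemble create --file nvbox.ini
distrobox enter nvbox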
#!/bin/bash
# Check that the CUDA userspace libraries and tools are visible inside the box.
# Adapted from https://stackoverflow.com/a/47436840
function lib_installed() { /sbin/ldconfig -N -v $(sed 's/:/ /g' <<< "$LD_LIBRARY_PATH") 2>/dev/null | grep "$1"; }
function check() { lib_installed "$1" && echo "$1 is installed" || echo -e "\nERROR: $1 is NOT installed\n"; }

check libcuda.so
check libcudart
check libcudnn

if ! command -v nvcc &>/dev/null; then
    echo -e "\nnvcc is not installed\n"
else
    nvcc --version
fi

if ! command -v nvidia-smi &>/dev/null; then
    echo -e "\nnvidia-smi is not installed\n"
else
    nvidia-smi
fi

echo "Run tensorflow_mnist_test to test GPU function for training/testing on the MNIST dataset."
# distrobox-assemble manifest for an NVIDIA TensorFlow container named "nvbox"
[nvbox]
image=nvcr.io/nvidia/tensorflow:23.12-tf2-py3
init=true
nvidia=true
pull=true
home="~/.local/share/distrobox/nvbox" # Use an alternate home so the box doesn't pollute your real home
# volume="/mnt/nvme:/mnt/nvme:rw /mnt/btrfs:/mnt/btrfs:rw" # Replace with your shared folders
start_now=true
additional_packages="systemd openssh-server"
additional_packages="build-essential clang clang-tools python-is-python3" # C/C++ toolchain plus the python alias
pre_init_hooks="echo 'Port 2222' | tee -a /etc/ssh/sshd_config" # Set the ssh port to 2222
pre_init_hooks="echo 'ListenAddress 127.0.0.1' | tee -a /etc/ssh/sshd_config" # Only listen on localhost
init_hooks=sudo -u "${USER}" /usr/bin/cp -r ${HOME}/.ssh ~ # copy over ssh keys
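Because the pre_init_hooks above make sshd listen on localhost port 2222 and openssh-server is installed, you should also be able to ssh into the running box (a sketch; assumes the image's sshd service is enabled under systemd and that the init hook copied your key in):

# Connect to the box over ssh on the non-standard port.
ssh -p 2222 "${USER}@127.0.0.1"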
#!/usr/bin/env python3
# TensorFlow MNIST smoke test, adapted from
# https://www.tensorflow.org/datasets/keras_example

try:
    import tensorflow as tf
except ImportError:
    print("Please install tensorflow:")
    print("  pip install tensorflow")
    exit(-1)

try:
    import tensorflow_datasets as tfds
except ImportError:
    print("\nPlease install the tensorflow datasets:")
    print("  pip install tensorflow_datasets")
    print("You can safely ignore pip dependency errors")
    exit(-2)

# Abort if TensorFlow cannot see a GPU.
if len(tf.config.list_physical_devices('GPU')) < 1:
    print("No GPU(s) found")
    exit(-3)

# Load MNIST as (image, label) pairs.
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)


def normalize_img(image, label):
    """Normalizes images: `uint8` -> `float32`."""
    return tf.cast(image, tf.float32) / 255., label


# Training pipeline: normalize, cache, shuffle, batch, prefetch.
ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(128)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)

# Evaluation pipeline: normalize, batch, cache, prefetch.
ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)

# Simple fully connected classifier.
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10)
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

model.fit(
    ds_train,
    epochs=6,
    validation_data=ds_test,
)
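To kick off the MNIST smoke test from the host (a sketch; tensorflow_mnist_test.py is an assumed filename, matching the tensorflow_mnist_test name referenced by the check script):

# Train and evaluate the small MNIST model inside the box; it aborts if no GPU is visible.
distrobox enter nvbox -- python3 tensorflow_mnist_test.py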