Skip to content

Instantly share code, notes, and snippets.

@sritchie

sritchie/setup.py

Last active Aug 21, 2020
Embed
What would you like to do?
from setuptools import find_packages, setup
# Coordinates of the GPU jaxlib wheel; jax_artifact() joins them into a
# direct download URL (they are only used for the GPU variant).
# This follows the style of Jaxlib installation here:
# https://github.com/google/jax#pip-installation
# Wheel Python tag; pip will only install the wheel on a matching CPython.
PYTHON_VERSION = "cp37"
# CUDA build directory under BASE_URL; presumably must match the CUDA
# toolkit on the target machine (verify before changing).
CUDA_VERSION = "cuda101" # alternatives: cuda90, cuda92, cuda100, cuda101
# Wheel platform tag.
PLATFORM = "linux_x86_64" # only linux_x86_64 is listed as an option
# Root of the bucket hosting jaxlib release wheels.
BASE_URL = "https://storage.googleapis.com/jax-releases"
def jax_artifact(version, gpu=False, *, python_version=None,
                 cuda_version=None, platform=None, base_url=None):
    """Return a pip requirement string for the given jaxlib version.

    CPU builds are pinned by version (``jaxlib==<version>``). GPU builds are
    expressed as a PEP 508 direct reference (``jaxlib @ <url>``) to the wheel
    hosted under the release bucket.

    Args:
      version: jaxlib version string, e.g. "0.1.43".
      gpu: if True, return a direct-URL requirement for the CUDA wheel.
      python_version: wheel Python tag (e.g. "cp37"); defaults to the
        module-level PYTHON_VERSION.
      cuda_version: CUDA directory in the bucket (e.g. "cuda101"); defaults
        to the module-level CUDA_VERSION.
      platform: wheel platform tag; defaults to the module-level PLATFORM.
      base_url: bucket root URL; defaults to the module-level BASE_URL.

    Returns:
      A requirement string usable in ``install_requires``/``extras_require``.
    """
    if not gpu:
        return f"jaxlib=={version}"

    # Fall back to the module constants at call time rather than binding them
    # as parameter defaults, so rebinding a constant still takes effect.
    python_version = PYTHON_VERSION if python_version is None else python_version
    cuda_version = CUDA_VERSION if cuda_version is None else cuda_version
    platform = PLATFORM if platform is None else platform
    base_url = BASE_URL if base_url is None else base_url

    # Wheel filename: {dist}-{version}-{python tag}-{abi tag}-{platform}.whl
    wheel = f"jaxlib-{version}-{python_version}-none-{platform}.whl"
    return f"jaxlib @ {base_url}/{cuda_version}/{wheel}"
def readme(path="README.md"):
    """Return the text of the project README, or None if it does not exist.

    Args:
      path: file to read. Defaults to "README.md", resolved against the
        current working directory.

    Returns:
      The file contents decoded as UTF-8, or None when the file is missing,
      so setup() can still run without a README.
    """
    try:
        # Decode explicitly: the locale default encoding (e.g. cp1252) can
        # fail on, or silently mangle, non-ASCII Markdown.
        with open(path, encoding="utf-8") as rf:
            return rf.read()
    except FileNotFoundError:
        return None
# Pinned jax / jaxlib versions. These presumably need to be bumped together
# (each jax release requires a minimum jaxlib) -- check before changing one.
JAXLIB_VERSION = "0.1.43"
JAX_VERSION = "0.1.62"

# Runtime dependencies. jaxlib is not listed here; it comes from the "cpu" or
# "gpu" extra, which pin it differently (PyPI version vs. direct wheel URL).
REQUIRED_PACKAGES = [
    # Every entry needs its trailing comma: adjacent string literals are
    # silently concatenated into a single, invalid requirement otherwise.
    "pg8000>=1.16.1",
    "uv-metrics>=0.4.2",
    "fs",
    "fs-gcsfs",
    f"jax=={JAX_VERSION}",
]
def with_versioneer(f, default=None):
    """Return ``f(versioneer)`` if versioneer is importable, else ``default``.

    Lets setup.py run in environments where versioneer is not installed.

    Args:
      f: single-argument callable that receives the versioneer module.
      default: value returned when versioneer cannot be imported.
    """
    try:
        import versioneer
    except ImportError:
        return default
    # Called outside the try so errors raised by f itself are not masked.
    return f(versioneer)


setup(
    name='my_project',
    version="0.0.1",
    cmdclass=with_versioneer(lambda v: v.get_cmdclass(), {}),
    description='Getting it done.',
    long_description=readme(),
    # The long description is read from README.md; without this it is
    # interpreted as reStructuredText, which Markdown generally is not.
    long_description_content_type="text/markdown",
    author='Sam Ritchie',
    author_email='samritchie@google.com',
    url='https://github.com/google/caliban',
    packages=find_packages(exclude=('tests', 'docs')),
    install_requires=REQUIRED_PACKAGES,
    # jaxlib comes in two flavours; install with `pip install .[cpu]` or
    # `pip install .[gpu]`.
    extras_require={
        "cpu": [jax_artifact(JAXLIB_VERSION, gpu=False)],
        "gpu": [jax_artifact(JAXLIB_VERSION, gpu=True)],
    },
    include_package_data=True,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.