Skip to content

Instantly share code, notes, and snippets.

@ericmjl
Created July 21, 2020 18:30
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ericmjl/fd9ccd067244bab514dafc4d6eb56cee to your computer and use it in GitHub Desktop.
Save ericmjl/fd9ccd067244bab514dafc4d6eb56cee to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Introduction\n",
"\n",
"I'm debugging [this issue](https://github.com/ElArkk/jax-unirep/issues/59).\n",
"\n",
"To reproduce this notebook, activate an isolated Python environment, then install jax-unirep from this branch:\n",
"\n",
"```\n",
"pip install git+https://github.com/ElArkk/jax-unirep.git@oscillating-59\n",
"```\n",
"\n",
"That branch already contains all of the logging statements whose output appears below."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"DEBUG:matplotlib.pyplot:Loaded backend module://ipykernel.pylab.backend_inline version unknown.\n"
]
}
],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"%matplotlib inline\n",
"%config InlineBackend.figure_format = 'retina'"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"from jax_unirep import evotune, fit\n",
"from jax_unirep.utils import load_params_1900\n",
"import logging\n",
"\n",
"logger = logging.getLogger(\"NOTEBOOK\")\n",
"logger.setLevel(logging.DEBUG)"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "56459e27f57a42e48a8e77e6e09a5e48",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='right-padding sequences', max=3.0, style=ProgressStyle(de…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "fcf9bd8ee93147f1adc6501b96a47bff",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"DEBUG:root:{8}\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "367112df144048fc935ed698992366b9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='evotuning pairs', max=3.0, style=ProgressStyle(descriptio…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:evotuning:Length-batching done: 1 unique lengths, with average length 3.0, max length 3 and min length 3.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3bad89fead76408a9f43fb922b6aa744",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='Iteration', max=10.0, style=ProgressStyle(description_wid…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:evotuning:Starting epoch 1\n",
"DEBUG:root:Calculating average loss.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "14a3b0c3d22440e2941d47133cdcd11f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='Avg loss on dataset length 3', max=1.0, style=ProgressSty…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"DEBUG:root:Input shape: (3, 9, 10)\n",
"DEBUG:root:Output shape: (3, 9, 25)\n",
"DEBUG:layers.py:apply_fun:id of params: 5201263984\n",
"DEBUG:layers.py:apply_fun:sum of inputs: Traced<ShapedArray(float32[])>with<BatchTrace(level=0/1)>\n",
"DEBUG:layers.py:mlstm1900_batch ID of params: 5201263984\n",
"DEBUG:layers.py:mlstm1900_batch sum of batch: Traced<ShapedArray(float32[])>with<BatchTrace(level=0/1)>\n",
"DEBUG:layers.py:mlstm1900_batch sum of h_final: Traced<ShapedArray(float32[])>with<BatchTrace(level=0/1)>\n",
"DEBUG:absl:Compiling evotune_loss for args (ShapedArray(float32[7600]), ShapedArray(float32[7600]), ShapedArray(float32[1900]), ShapedArray(float32[1900]), ShapedArray(float32[7600]), ShapedArray(float32[1900,7600]), ShapedArray(float32[1900,1900]), ShapedArray(float32[10,1900]), ShapedArray(float32[10,7600]), ShapedArray(float32[1900,25]), ShapedArray(float32[25]), ShapedArray(float32[3,9,10]), ShapedArray(float32[3,9,25])).\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:evotuning:Epoch 0: Estimated average training-set loss: 0.48966851830482483. Weights dumped.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"created directory at temp\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"DEBUG:root:Input shape: (3, 9, 10)\n",
"DEBUG:root:Output shape: (3, 9, 25)\n",
"DEBUG:layers.py:apply_fun:id of params: 5185297376\n",
"DEBUG:layers.py:apply_fun:sum of inputs: Traced<ShapedArray(float32[])>with<BatchTrace(level=2/1)>\n",
"DEBUG:layers.py:mlstm1900_batch ID of params: 5185297376\n",
"DEBUG:layers.py:mlstm1900_batch sum of batch: Traced<ShapedArray(float32[])>with<BatchTrace(level=2/1)>\n",
"DEBUG:layers.py:mlstm1900_batch sum of h_final: Traced<ShapedArray(float32[])>with<BatchTrace(level=2/1)>\n",
"DEBUG:absl:Compiling step for args (ShapedArray(int32[], weak_type=True), ShapedArray(float32[3,9,10]), ShapedArray(float32[3,9,25]), ShapedArray(float32[7600]), ShapedArray(float32[7600]), ShapedArray(float32[7600]), ShapedArray(float32[7600]), ShapedArray(float32[7600]), ShapedArray(float32[7600]), ShapedArray(float32[1900]), ShapedArray(float32[1900]), ShapedArray(float32[1900]), ShapedArray(float32[1900]), ShapedArray(float32[1900]), ShapedArray(float32[1900]), ShapedArray(float32[7600]), ShapedArray(float32[7600]), ShapedArray(float32[7600]), ShapedArray(float32[1900,7600]), ShapedArray(float32[1900,7600]), ShapedArray(float32[1900,7600]), ShapedArray(float32[1900,1900]), ShapedArray(float32[1900,1900]), ShapedArray(float32[1900,1900]), ShapedArray(float32[10,1900]), ShapedArray(float32[10,1900]), ShapedArray(float32[10,1900]), ShapedArray(float32[10,7600]), ShapedArray(float32[10,7600]), ShapedArray(float32[10,7600]), ShapedArray(float32[1900,25]), ShapedArray(float32[1900,25]), ShapedArray(float32[1900,25]), ShapedArray(float32[25]), ShapedArray(float32[25]), ShapedArray(float32[25])).\n",
"INFO:evotuning:Starting epoch 2\n",
"DEBUG:root:Calculating average loss.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4a6661fa7eda42ecad5c01d279eb2f63",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='Avg loss on dataset length 3', max=1.0, style=ProgressSty…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:evotuning:Epoch 1: Estimated average training-set loss: 0.4730319082736969. Weights dumped.\n",
"INFO:evotuning:Starting epoch 3\n",
"DEBUG:root:Calculating average loss.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2cae77517da448f395c597e25154eaa1",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='Avg loss on dataset length 3', max=1.0, style=ProgressSty…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:evotuning:Epoch 2: Estimated average training-set loss: 0.45402786135673523. Weights dumped.\n",
"INFO:evotuning:Starting epoch 4\n",
"DEBUG:root:Calculating average loss.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cbe33d65d83d46e1913716074d7d8a1d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='Avg loss on dataset length 3', max=1.0, style=ProgressSty…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:evotuning:Epoch 3: Estimated average training-set loss: 0.4271482229232788. Weights dumped.\n",
"INFO:evotuning:Starting epoch 5\n",
"DEBUG:root:Calculating average loss.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1139557764b74f96b1b81a1e254a1d09",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='Avg loss on dataset length 3', max=1.0, style=ProgressSty…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:evotuning:Epoch 4: Estimated average training-set loss: 0.3994535207748413. Weights dumped.\n",
"INFO:evotuning:Starting epoch 6\n",
"DEBUG:root:Calculating average loss.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "511c3ddc9577408fb9137b64e875f15e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='Avg loss on dataset length 3', max=1.0, style=ProgressSty…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:evotuning:Epoch 5: Estimated average training-set loss: 0.3723492920398712. Weights dumped.\n",
"INFO:evotuning:Starting epoch 7\n",
"DEBUG:root:Calculating average loss.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8ee53defecd440658c5d9beb403f297a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='Avg loss on dataset length 3', max=1.0, style=ProgressSty…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:evotuning:Epoch 6: Estimated average training-set loss: 0.34603431820869446. Weights dumped.\n",
"INFO:evotuning:Starting epoch 8\n",
"DEBUG:root:Calculating average loss.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9c48e90f36d046a0976e9b4d65a022e2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='Avg loss on dataset length 3', max=1.0, style=ProgressSty…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:evotuning:Epoch 7: Estimated average training-set loss: 0.3213047683238983. Weights dumped.\n",
"INFO:evotuning:Starting epoch 9\n",
"DEBUG:root:Calculating average loss.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7c79232ef41d4e4b946382299f21d108",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='Avg loss on dataset length 3', max=1.0, style=ProgressSty…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:evotuning:Epoch 8: Estimated average training-set loss: 0.29967546463012695. Weights dumped.\n",
"INFO:evotuning:Starting epoch 10\n",
"DEBUG:root:Calculating average loss.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6ac286ee48334e2b807648d26ae5b58a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='Avg loss on dataset length 3', max=1.0, style=ProgressSty…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:evotuning:Epoch 9: Estimated average training-set loss: 0.27984508872032166. Weights dumped.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"sequences = [\"MKLVIPJ\", \"MMLVIKJP\", \"MKLVIJJ\"]\n",
"params = fit(params=None, sequences=sequences, n_epochs=10)"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray([2.896047 , 2.7695258, 2.985471 , ..., 3.035842 , 3.0082836,\n",
" 2.9999993], dtype=float32)"
]
},
"execution_count": 62,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"params[0][\"b\"]"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([2.897026 , 2.7702985, 2.9864933, ..., 3.0359561, 3.0091407,\n",
" 2.999047 ], dtype=float32)"
]
},
"execution_count": 63,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"load_params_1900()[\"b\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is the problematic behaviour: `h_avg` should take on identical values every time `get_reps` is re-run with the same sequences and the same params.\n",
"\n",
"Instead, it oscillates between two states: the even-numbered calls (0, 2, 4) produce one value and the odd-numbered calls (1, 3, 5) produce another (compare the logged `sum of h_avg` values below). See this in action below."
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"DEBUG:NOTEBOOK:\n",
"DEBUG:NOTEBOOK::::::::NO CUSTOM PARAMS:::::::\n",
"DEBUG:featurize.py:ID of params passed into get_reps: 4434875784\n",
"DEBUG:featurize.py:params is None, loading default params.\n",
"DEBUG:featurize.py:ID of params after checking None: 5162782800\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "59ac1451c9fe44c18d1a65624043f672",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='right-padding sequences', max=3.0, style=ProgressStyle(de…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"DEBUG:featurize.py:rep_same_lengths:ID of params: 5162782800\n",
"DEBUG:featurize.py:rep_same_lengths:sum of embedded_seqs: 46.749168395996094\n",
"DEBUG:layers.py:apply_fun:id of params: 5162782800\n",
"DEBUG:layers.py:apply_fun:sum of inputs: Traced<ShapedArray(float32[])>with<BatchTrace(level=0/0)>\n",
"DEBUG:layers.py:mlstm1900_batch ID of params: 5162782800\n",
"DEBUG:layers.py:mlstm1900_batch sum of batch: Traced<ShapedArray(float32[])>with<BatchTrace(level=0/0)>\n",
"DEBUG:layers.py:mlstm1900_batch sum of h_final: Traced<ShapedArray(float32[])>with<BatchTrace(level=0/0)>\n",
"DEBUG:featurize.py:rep_same_lengths:sum of h_final: 1098.8822021484375\n",
"DEBUG:featurize.py:rep_same_lengths:sum of h_avg: 1088.96533203125\n",
"DEBUG:NOTEBOOK:\n",
"DEBUG:NOTEBOOK::::::::CALL 0 WITH CUSTOM:::::::\n",
"DEBUG:featurize.py:ID of params passed into get_reps: 5203122928\n",
"DEBUG:featurize.py:params were loaded from user.\n",
"DEBUG:featurize.py:ID of params after checking None: 5203122928\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3b29c4ad9d3b4baa9f4f30b83319aa13",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='right-padding sequences', max=3.0, style=ProgressStyle(de…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"DEBUG:featurize.py:rep_same_lengths:ID of params: 5203122928\n",
"DEBUG:featurize.py:rep_same_lengths:sum of embedded_seqs: 46.749168395996094\n",
"DEBUG:layers.py:apply_fun:id of params: 5203122928\n",
"DEBUG:layers.py:apply_fun:sum of inputs: Traced<ShapedArray(float32[])>with<BatchTrace(level=0/0)>\n",
"DEBUG:layers.py:mlstm1900_batch ID of params: 5203122928\n",
"DEBUG:layers.py:mlstm1900_batch sum of batch: Traced<ShapedArray(float32[])>with<BatchTrace(level=0/0)>\n",
"DEBUG:layers.py:mlstm1900_batch sum of h_final: Traced<ShapedArray(float32[])>with<BatchTrace(level=0/0)>\n",
"DEBUG:featurize.py:rep_same_lengths:sum of h_final: 1147.36279296875\n",
"DEBUG:featurize.py:rep_same_lengths:sum of h_avg: 1100.1588134765625\n",
"DEBUG:NOTEBOOK:\n",
"DEBUG:NOTEBOOK::::::::CALL 1 WITH CUSTOM:::::::\n",
"DEBUG:featurize.py:ID of params passed into get_reps: 5203122928\n",
"DEBUG:featurize.py:params were loaded from user.\n",
"DEBUG:featurize.py:ID of params after checking None: 5203122928\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "962d1d9279c14433b6d932707c7159b7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='right-padding sequences', max=3.0, style=ProgressStyle(de…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"DEBUG:featurize.py:rep_same_lengths:ID of params: 5203122928\n",
"DEBUG:featurize.py:rep_same_lengths:sum of embedded_seqs: 46.749168395996094\n",
"DEBUG:layers.py:apply_fun:id of params: 5203122928\n",
"DEBUG:layers.py:apply_fun:sum of inputs: Traced<ShapedArray(float32[])>with<BatchTrace(level=0/0)>\n",
"DEBUG:layers.py:mlstm1900_batch ID of params: 5203122928\n",
"DEBUG:layers.py:mlstm1900_batch sum of batch: Traced<ShapedArray(float32[])>with<BatchTrace(level=0/0)>\n",
"DEBUG:layers.py:mlstm1900_batch sum of h_final: Traced<ShapedArray(float32[])>with<BatchTrace(level=0/0)>\n",
"DEBUG:featurize.py:rep_same_lengths:sum of h_final: 1133.589599609375\n",
"DEBUG:featurize.py:rep_same_lengths:sum of h_avg: 1099.611328125\n",
"DEBUG:NOTEBOOK:\n",
"DEBUG:NOTEBOOK::::::::CALL 2 WITH CUSTOM:::::::\n",
"DEBUG:featurize.py:ID of params passed into get_reps: 5203122928\n",
"DEBUG:featurize.py:params were loaded from user.\n",
"DEBUG:featurize.py:ID of params after checking None: 5203122928\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b9d4f4a82e0545918477f25470a7eef5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='right-padding sequences', max=3.0, style=ProgressStyle(de…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"DEBUG:featurize.py:rep_same_lengths:ID of params: 5203122928\n",
"DEBUG:featurize.py:rep_same_lengths:sum of embedded_seqs: 46.749168395996094\n",
"DEBUG:layers.py:apply_fun:id of params: 5203122928\n",
"DEBUG:layers.py:apply_fun:sum of inputs: Traced<ShapedArray(float32[])>with<BatchTrace(level=0/0)>\n",
"DEBUG:layers.py:mlstm1900_batch ID of params: 5203122928\n",
"DEBUG:layers.py:mlstm1900_batch sum of batch: Traced<ShapedArray(float32[])>with<BatchTrace(level=0/0)>\n",
"DEBUG:layers.py:mlstm1900_batch sum of h_final: Traced<ShapedArray(float32[])>with<BatchTrace(level=0/0)>\n",
"DEBUG:featurize.py:rep_same_lengths:sum of h_final: 1147.36279296875\n",
"DEBUG:featurize.py:rep_same_lengths:sum of h_avg: 1100.1588134765625\n",
"DEBUG:NOTEBOOK:\n",
"DEBUG:NOTEBOOK::::::::CALL 3 WITH CUSTOM:::::::\n",
"DEBUG:featurize.py:ID of params passed into get_reps: 5203122928\n",
"DEBUG:featurize.py:params were loaded from user.\n",
"DEBUG:featurize.py:ID of params after checking None: 5203122928\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c7db1b7412a7437282c95413312f92ff",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='right-padding sequences', max=3.0, style=ProgressStyle(de…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"DEBUG:featurize.py:rep_same_lengths:ID of params: 5203122928\n",
"DEBUG:featurize.py:rep_same_lengths:sum of embedded_seqs: 46.749168395996094\n",
"DEBUG:layers.py:apply_fun:id of params: 5203122928\n",
"DEBUG:layers.py:apply_fun:sum of inputs: Traced<ShapedArray(float32[])>with<BatchTrace(level=0/0)>\n",
"DEBUG:layers.py:mlstm1900_batch ID of params: 5203122928\n",
"DEBUG:layers.py:mlstm1900_batch sum of batch: Traced<ShapedArray(float32[])>with<BatchTrace(level=0/0)>\n",
"DEBUG:layers.py:mlstm1900_batch sum of h_final: Traced<ShapedArray(float32[])>with<BatchTrace(level=0/0)>\n",
"DEBUG:featurize.py:rep_same_lengths:sum of h_final: 1133.589599609375\n",
"DEBUG:featurize.py:rep_same_lengths:sum of h_avg: 1099.611328125\n",
"DEBUG:NOTEBOOK:\n",
"DEBUG:NOTEBOOK::::::::CALL 4 WITH CUSTOM:::::::\n",
"DEBUG:featurize.py:ID of params passed into get_reps: 5203122928\n",
"DEBUG:featurize.py:params were loaded from user.\n",
"DEBUG:featurize.py:ID of params after checking None: 5203122928\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "572d4af4694d45cf95f4f0e5df1752de",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='right-padding sequences', max=3.0, style=ProgressStyle(de…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"DEBUG:featurize.py:rep_same_lengths:ID of params: 5203122928\n",
"DEBUG:featurize.py:rep_same_lengths:sum of embedded_seqs: 46.749168395996094\n",
"DEBUG:layers.py:apply_fun:id of params: 5203122928\n",
"DEBUG:layers.py:apply_fun:sum of inputs: Traced<ShapedArray(float32[])>with<BatchTrace(level=0/0)>\n",
"DEBUG:layers.py:mlstm1900_batch ID of params: 5203122928\n",
"DEBUG:layers.py:mlstm1900_batch sum of batch: Traced<ShapedArray(float32[])>with<BatchTrace(level=0/0)>\n",
"DEBUG:layers.py:mlstm1900_batch sum of h_final: Traced<ShapedArray(float32[])>with<BatchTrace(level=0/0)>\n",
"DEBUG:featurize.py:rep_same_lengths:sum of h_final: 1147.3629150390625\n",
"DEBUG:featurize.py:rep_same_lengths:sum of h_avg: 1100.1588134765625\n",
"DEBUG:NOTEBOOK:\n",
"DEBUG:NOTEBOOK::::::::CALL 5 WITH CUSTOM:::::::\n",
"DEBUG:featurize.py:ID of params passed into get_reps: 5203122928\n",
"DEBUG:featurize.py:params were loaded from user.\n",
"DEBUG:featurize.py:ID of params after checking None: 5203122928\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5cc9f23c9ae5484184398164639870b0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='right-padding sequences', max=3.0, style=ProgressStyle(de…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"DEBUG:featurize.py:rep_same_lengths:ID of params: 5203122928\n",
"DEBUG:featurize.py:rep_same_lengths:sum of embedded_seqs: 46.749168395996094\n",
"DEBUG:layers.py:apply_fun:id of params: 5203122928\n",
"DEBUG:layers.py:apply_fun:sum of inputs: Traced<ShapedArray(float32[])>with<BatchTrace(level=0/0)>\n",
"DEBUG:layers.py:mlstm1900_batch ID of params: 5203122928\n",
"DEBUG:layers.py:mlstm1900_batch sum of batch: Traced<ShapedArray(float32[])>with<BatchTrace(level=0/0)>\n",
"DEBUG:layers.py:mlstm1900_batch sum of h_final: Traced<ShapedArray(float32[])>with<BatchTrace(level=0/0)>\n",
"DEBUG:featurize.py:rep_same_lengths:sum of h_final: 1133.589599609375\n",
"DEBUG:featurize.py:rep_same_lengths:sum of h_avg: 1099.611328125\n"
]
}
],
"source": [
"from jax_unirep import get_reps\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# params = load_params(folderpath=\"20200707/iter_1/\")[0]\n",
"mut_seq = ['PROTEIN', 'PROTEI', 'PROTEIN']\n",
"\n",
"logger.debug(\"\")\n",
"logger.debug(\":::::::NO CUSTOM PARAMS:::::::\")\n",
"mut_rep = get_reps(mut_seq)[0]\n",
"# mut_rep_rerun = get_reps(mut_seq)[0]\n",
"\n",
"for i in range(6):\n",
" logger.debug(\"\")\n",
" logger.debug(f\":::::::CALL {i} WITH CUSTOM:::::::\")\n",
" h_avg = get_reps(mut_seq, params=params[0])[0]\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This _oscillatory_ behaviour suggests to me that some caching might be going wrong.\n",
"\n",
"I've spent a few hours trying to narrow down the bug. The main hypothesis I tested was that the params were being unexpectedly reloaded somewhere else, so I logged the ID of the params object at each stage. This turned out to be a dead end: the ID of the params passed in stays the same across all loop iterations, which means that the same params object is being used on every call.\n",
"\n",
"At this point, the parts I don't know how to introspect are the JAX parts, i.e. `vmap` and `lax.scan`.\n",
"\n",
"This is where I might need to ask the JAX devs for help."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For convenience, here is some of the relevant source code, which might help:"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\u001b[0;31mSignature:\u001b[0m\n",
"\u001b[0mmLSTM1900_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mDict\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlax_numpy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlax_numpy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTuple\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlax_numpy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlax_numpy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlax_numpy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mSource:\u001b[0m \n",
"\u001b[0;32mdef\u001b[0m \u001b[0mmLSTM1900_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mDict\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTuple\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;34m\"\"\"\u001b[0m\n",
"\u001b[0;34m LSTM layer implemented according to UniRep,\u001b[0m\n",
"\u001b[0;34m found here:\u001b[0m\n",
"\u001b[0;34m https://github.com/churchlab/UniRep/blob/master/unirep.py#L43,\u001b[0m\n",
"\u001b[0;34m for a batch of data.\u001b[0m\n",
"\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m This function processes a single embedded sequence,\u001b[0m\n",
"\u001b[0;34m passed in as a two dimensional array,\u001b[0m\n",
"\u001b[0;34m with number of rows being number of sequence positions,\u001b[0m\n",
"\u001b[0;34m and the number of columns being the embedding of each sequence letter.\u001b[0m\n",
"\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m :param params: All weights and biases for a single\u001b[0m\n",
"\u001b[0;34m mLSTM1900 RNN cell.\u001b[0m\n",
"\u001b[0;34m :param batch: One sequence embedded in a (n, 10) matrix,\u001b[0m\n",
"\u001b[0;34m where `n` is the number of sequences\u001b[0m\n",
"\u001b[0;34m :returns:\u001b[0m\n",
"\u001b[0;34m \"\"\"\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mh_t\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzeros\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"wmh\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mc_t\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzeros\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"wmh\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mlogger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdebug\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"mlstm1900_batch ID of params: {id(params)}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mstep_func\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpartial\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmLSTM1900_step\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mlogger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdebug\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"mlstm1900_batch sum of batch: {np.sum(batch)}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mh_final\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mc_final\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscan\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mstep_func\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minit\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mh_t\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mc_t\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mlogger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdebug\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"mlstm1900_batch sum of h_final: {np.sum(h_final)}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mh_final\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mc_final\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mFile:\u001b[0m ~/github/software/jax-unirep/jax_unirep/layers.py\n",
"\u001b[0;31mType:\u001b[0m function\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from jax_unirep.layers import mLSTM1900_batch, mLSTM1900\n",
"from jax_unirep.featurize import rep_same_lengths\n",
"\n",
"mLSTM1900_batch??"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Does `lax.scan` do some caching?**"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\u001b[0;31mSignature:\u001b[0m\n",
"\u001b[0mmLSTM1900\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0moutput_dim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1900\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mW_init\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m<\u001b[0m\u001b[0mfunction\u001b[0m \u001b[0mvariance_scaling\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m<\u001b[0m\u001b[0mlocals\u001b[0m\u001b[0;34m>\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minit\u001b[0m \u001b[0mat\u001b[0m \u001b[0;36m0x1281b6560\u001b[0m\u001b[0;34m>\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mb_init\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m<\u001b[0m\u001b[0mfunction\u001b[0m \u001b[0mnormal\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m<\u001b[0m\u001b[0mlocals\u001b[0m\u001b[0;34m>\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minit\u001b[0m \u001b[0mat\u001b[0m \u001b[0;36m0x1281b6ef0\u001b[0m\u001b[0;34m>\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mSource:\u001b[0m \n",
"\u001b[0;32mdef\u001b[0m \u001b[0mmLSTM1900\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput_dim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1900\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mW_init\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mglorot_normal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb_init\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnormal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;34m\"\"\"\u001b[0m\n",
"\u001b[0;34m mLSTM cell from the UniRep paper, stax compatible\u001b[0m\n",
"\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m This function works on a per-sequence basis,\u001b[0m\n",
"\u001b[0;34m meaning that mapping over batches of sequences\u001b[0m\n",
"\u001b[0;34m needs to happen outside this function, like this:\u001b[0m\n",
"\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m .. code-block:: python\u001b[0m\n",
"\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m def apply_fun_vmapped(x):\u001b[0m\n",
"\u001b[0;34m return apply_fun(params=params, inputs=x)\u001b[0m\n",
"\u001b[0;34m h_final, c_final, outputs = vmap(apply_fun_vmapped)(emb_seqs)\u001b[0m\n",
"\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m It returns the average hidden, final hidden and final cell states\u001b[0m\n",
"\u001b[0;34m of the mlstm.\u001b[0m\n",
"\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m \"\"\"\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0minit_fun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrng\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_shape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;34m\"\"\"\u001b[0m\n",
"\u001b[0;34m Initialize parameters for mLSTM1900\u001b[0m\n",
"\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m output_dim:\u001b[0m\n",
"\u001b[0;34m mlstm cell size -> (1900,)\u001b[0m\n",
"\u001b[0;34m input_shape:\u001b[0m\n",
"\u001b[0;34m one embedded sequence -> (n_letters, 10)\u001b[0m\n",
"\u001b[0;34m output_shape:\u001b[0m\n",
"\u001b[0;34m one sequence in 1900 dims -> (n_letters, 1900)\u001b[0m\n",
"\u001b[0;34m \"\"\"\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0minput_dim\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minput_shape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mk1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mk2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mk3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mk4\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrng\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mwmx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwmh\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwh\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mW_init\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mk1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0minput_dim\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput_dim\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mW_init\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mk2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0moutput_dim\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput_dim\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mW_init\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mk3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0minput_dim\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput_dim\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mW_init\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mk4\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0moutput_dim\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput_dim\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mk1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mk2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mk3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mk4\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mk1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mgmx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgmh\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgh\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mb_init\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mk1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0moutput_dim\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mb_init\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mk2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0moutput_dim\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mb_init\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mk3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0moutput_dim\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mb_init\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mk4\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0moutput_dim\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mk1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mk2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mk1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mb_init\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mk1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0moutput_dim\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mparams\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;34m\"wmx\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mwmx\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;34m\"wmh\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mwmh\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;34m\"wx\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mwx\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;34m\"wh\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mwh\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;34m\"gmx\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mgmx\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;34m\"gmh\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mgmh\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;34m\"gx\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mgx\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;34m\"gh\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mgh\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;34m\"b\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0moutput_shape\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0minput_shape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput_dim\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0moutput_shape\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mapply_fun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mlogger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdebug\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"apply_fun:id of params: {id(params)}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mlogger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdebug\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"apply_fun:sum of inputs: {np.sum(inputs)}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmLSTM1900_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0minit_fun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mapply_fun\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mFile:\u001b[0m ~/github/software/jax-unirep/jax_unirep/layers.py\n",
"\u001b[0;31mType:\u001b[0m function\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"mLSTM1900??"
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\u001b[0;31mSignature:\u001b[0m \u001b[0mrep_same_lengths\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mseqs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mIterable\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mDict\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTuple\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnumpy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnumpy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mSource:\u001b[0m \n",
"\u001b[0;32mdef\u001b[0m \u001b[0mrep_same_lengths\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mseqs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mIterable\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mDict\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTuple\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;34m\"\"\"\u001b[0m\n",
"\u001b[0;34m This function generates representations of protein sequences that have the same length,\u001b[0m\n",
"\u001b[0;34m by passing them through the UniRep mLSTM.\u001b[0m\n",
"\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m :param seqs: A list of same length sequences as strings.\u001b[0m\n",
"\u001b[0;34m If passing only a single sequence, it also needs to be passed inside a list.\u001b[0m\n",
"\u001b[0;34m :returns: A tuple of np.arrays containing the reps.\u001b[0m\n",
"\u001b[0;34m Each `np.array` has shape (n_sequences, 1900).\u001b[0m\n",
"\u001b[0;34m \"\"\"\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mapply_fun\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmLSTM1900\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0membedded_seqs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_embeddings\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mseqs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mlogger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdebug\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"rep_same_lengths:ID of params: {id(params)}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mlogger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdebug\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"rep_same_lengths:sum of embedded_seqs: {np.sum(embedded_seqs)}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mapply_fun\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpartial\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mapply_fun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mh_final\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mc_final\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mh\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mapply_fun\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0membedded_seqs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mlogger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdebug\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"rep_same_lengths:sum of h_final: {np.sum(h_final)}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mh_avg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mh\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mlogger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdebug\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"rep_same_lengths:sum of h_avg: {np.sum(h_avg)}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mh_avg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mh_final\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mc_final\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mFile:\u001b[0m ~/github/software/jax-unirep/jax_unirep/featurize.py\n",
"\u001b[0;31mType:\u001b[0m function\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"rep_same_lengths??"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Could it be that `vmap` does some caching?**"
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\u001b[0;31mSignature:\u001b[0m\n",
"\u001b[0mget_reps\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mseqs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mIterable\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mDict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mNoneType\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTuple\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnumpy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnumpy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mSource:\u001b[0m \n",
"\u001b[0;32mdef\u001b[0m \u001b[0mget_reps\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mseqs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mIterable\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mDict\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTuple\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;34m\"\"\"\u001b[0m\n",
"\u001b[0;34m This function generates representations of protein sequences using the\u001b[0m\n",
"\u001b[0;34m 1900 hidden-unit mLSTM model with pre-trained weights from the UniRep\u001b[0m\n",
"\u001b[0;34m paper (https://github.com/churchlab/UniRep).\u001b[0m\n",
"\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m Each element of the output 3-tuple is a `np.array`\u001b[0m\n",
"\u001b[0;34m of shape (n_input_sequences, 1900):\u001b[0m\n",
"\u001b[0;34m - `h_avg`: Average hidden state of the mLSTM over the whole sequence.\u001b[0m\n",
"\u001b[0;34m - `h_final`: Final hidden state of the mLSTM\u001b[0m\n",
"\u001b[0;34m - `c_final`: Final cell state of the mLSTM\u001b[0m\n",
"\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m You should not use this function\u001b[0m\n",
"\u001b[0;34m if you want to do further JAX-based computations\u001b[0m\n",
"\u001b[0;34m on the output vectors!\u001b[0m\n",
"\u001b[0;34m In that case, the `DeviceArray` futures returned by `mLSTM1900`\u001b[0m\n",
"\u001b[0;34m should be passed directly into the next step\u001b[0m\n",
"\u001b[0;34m instead of converting them to `np.array`s.\u001b[0m\n",
"\u001b[0;34m The conversion to `np.array`s is done\u001b[0m\n",
"\u001b[0;34m in the dispatched `rep_x_lengths` functions\u001b[0m\n",
"\u001b[0;34m to force python to wait with returning the values\u001b[0m\n",
"\u001b[0;34m until the computation is completed.\u001b[0m\n",
"\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m The keys of the ``params`` dictionary must be:\u001b[0m\n",
"\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m b, gh, gmh, gmx, gx, wh, wmh, wmx, wx\u001b[0m\n",
"\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m :param seqs: A list of sequences as strings or a single string.\u001b[0m\n",
"\u001b[0;34m :param params: A dictionary of mLSTM1900 weights.\u001b[0m\n",
"\u001b[0;34m :returns: A 3-tuple of `np.array`s containing the reps.\u001b[0m\n",
"\u001b[0;34m Each `np.array` has shape (n_sequences, 1900).\u001b[0m\n",
"\u001b[0;34m \"\"\"\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mlogger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdebug\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"ID of params passed into get_reps: {id(params)}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mparams\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mlogger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdebug\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"params is None, loading default params.\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mparams\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mload_params_1900\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mlogger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdebug\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"params were loaded from user.\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mlogger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdebug\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"ID of params after checking None: {id(params)}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;31m# Check that params have correct keys and shapes\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mvalidate_mLSTM1900_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;31m# If single string sequence is passed, package it into a list\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mseqs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mseqs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mseqs\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;31m# Make sure list is not empty\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mseqs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mSequenceLengthsError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Cannot pass in empty list of sequences.\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;31m# Differentiate between two cases:\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;31m# 1. All sequences in the list have the same length\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;31m# 2. There are sequences of different lengths in the list: we right-pad before calculating reps\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mseq_lengths\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ms\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0ms\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mseqs\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mseq_lengths\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mseqs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mright_pad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mseqs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mseq_lengths\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mh_avg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mh_final\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mc_final\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrep_same_lengths\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mseqs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mh_avg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mh_final\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mc_final\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mFile:\u001b[0m ~/github/software/jax-unirep/jax_unirep/featurize.py\n",
"\u001b[0;31mType:\u001b[0m function\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"get_reps??"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "jax-unirep",
"language": "python",
"name": "jax-unirep"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment