@dominicrufa
Created June 13, 2022 19:55

host:guest force-matching template (to place on gpu)
{
"cells": [
{
"cell_type": "markdown",
"id": "790bf47e-8a05-4e49-a2e4-a795032e09ee",
"metadata": {},
"source": [
"# G0\n",
"Let's fit a model for G0!"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "326e140d-7da3-4185-951b-e8fe66858f44",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/dominic/anaconda3/envs/aquaregia/lib/python3.10/site-packages/google/colab/data_table.py:30: UserWarning: IPython.utils.traitlets has moved to a top-level traitlets package.\n",
" from IPython.utils import traitlets as _traitlets\n",
"WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
]
}
],
"source": [
"from jax.config import config\n",
"config.update(\"jax_enable_x64\", True)\n",
"import jax\n",
"import jax.numpy as jnp\n",
"from jax import random\n",
"import numpy as np\n",
"from functools import partial\n",
"import haiku as hk\n",
"import functools\n",
"from aquaregia.utils import Array, ArrayTree\n",
"from aquaregia.tfn import DEFAULT_EPSILON"
]
},
{
"cell_type": "markdown",
"id": "34fcb9a1-4df9-4817-8083-cacc3dfa49ee",
"metadata": {},
"source": [
"load data..."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "8af05ad6-c13e-4106-9b0f-23d79b0c1c8e",
"metadata": {},
"outputs": [],
"source": [
"from openmm import XmlSerializer"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "84d5ddca-5a59-47ec-a26b-63a9bbca9352",
"metadata": {},
"outputs": [],
"source": [
"with open('data_G0/complex.5.decoupled.iter0.xml', 'r') as infile:\n",
" xml_readable = infile.read()\n",
"xml_deserialized = XmlSerializer.deserialize(xml_readable)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "df72eacc-01f9-4351-9cb9-3823e33f4397",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<openmm.openmm.HarmonicBondForce; proxy of <Swig Object of type 'OpenMM::HarmonicBondForce *' at 0x7f66a237ae20> >,\n",
" <openmm.openmm.HarmonicAngleForce; proxy of <Swig Object of type 'OpenMM::HarmonicAngleForce *' at 0x7f66a237aee0> >,\n",
" <openmm.openmm.PeriodicTorsionForce; proxy of <Swig Object of type 'OpenMM::PeriodicTorsionForce *' at 0x7f66a237af70> >,\n",
" <openmm.openmm.NonbondedForce; proxy of <Swig Object of type 'OpenMM::NonbondedForce *' at 0x7f66a237b000> >,\n",
" <openmm.openmm.CMMotionRemover; proxy of <Swig Object of type 'OpenMM::CMMotionRemover *' at 0x7f66a237b090> >,\n",
" <openmm.openmm.MonteCarloBarostat; proxy of <Swig Object of type 'OpenMM::MonteCarloBarostat *' at 0x7f66a237b120> >]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"xml_deserialized.getForces()"
]
},
{
"cell_type": "markdown",
"id": "fcf4ef83-d7dd-4dc9-8fd0-c46cd36cdabe",
"metadata": {},
"source": [
"gather data."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "c2fe9b90-cf68-48fb-9349-7fc7276913c5",
"metadata": {},
"outputs": [],
"source": [
"def gather_data(path_iterable):\n",
" \n",
" full_dict = {'partial_charges': [], 'positions': [], 'forces': [], 'indices': [], 'energies': [], 'masses': []}\n",
" for idx, path in enumerate(path_iterable):\n",
" _dict = jnp.load(path, allow_pickle=True)['arr_0'].item()\n",
" for key, val in _dict.items():\n",
" #print(val.shape)\n",
" full_dict[key].append(val)\n",
" \n",
" \n",
" out_dict = {key: None for key in full_dict.keys()}\n",
" for key, val in full_dict.items():\n",
" out_dict[key] = jnp.concatenate([_q[jnp.newaxis, ...] for _q in val])\n",
" \n",
" # reshape positions and forces\n",
" num_files, samples_per_file, num_atoms, three = out_dict['positions'].shape\n",
" _reshape = (num_files*samples_per_file,num_atoms, three)\n",
" for key in ['positions', 'forces']:\n",
" val = out_dict[key]\n",
" out_dict[key] = val.reshape(_reshape)\n",
" \n",
" return out_dict"
]
},
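{
"cell_type": "markdown",
"id": "added-inspect-npz-md",
"metadata": {},
"source": [
"As a quick sanity check (an added sketch, not part of the original run), a single `.npz` file can be inspected directly; `gather_data` assumes each file holds a pickled dict with the keys listed above."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-inspect-npz",
"metadata": {},
"outputs": [],
"source": [
"# added sketch: peek at one snapshot file to confirm the keys/shapes gather_data expects\n",
"single = jnp.load('data_G0/complex.5.decoupled.iter0.npz', allow_pickle=True)['arr_0'].item()\n",
"{key: getattr(val, 'shape', None) for key, val in single.items()}"
]
},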
{
"cell_type": "code",
"execution_count": 6,
"id": "ba216cb2-d90f-46d4-8bf5-14d2c389e2d0",
"metadata": {},
"outputs": [],
"source": [
"data = gather_data([f\"data_G0/complex.5.decoupled.iter{i}.npz\" for i in range(12)])"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "8001e43b-af45-49c6-a452-3c844a728ee5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(150000, 144, 3)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data['positions'].shape"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "bf166b42-f2e5-408a-9a51-05cf3b42ec2f",
"metadata": {},
"outputs": [],
"source": [
"WIDTH=4\n",
"BATCH_SIZE=16\n",
"N_TRAIN_SAMPLES = int(data['positions'].shape[0])\n",
"N_VALIDATE_SAMPLES = int(data['positions'].shape[0]*.0001)\n",
"N_TRAIN_STEPS = 5000\n",
"LEARNING_RATE = 1e-3\n",
"MAX_L = 1"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "d274f12c-f29b-4754-b0d4-8d81f569cf81",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"15"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"N_VALIDATE_SAMPLES"
]
},
{
"cell_type": "markdown",
"id": "53eb5461-ac7a-4243-8f5b-74b24f21b7b1",
"metadata": {},
"source": [
"make a base energy fn."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "bf4e60d2-dee0-4172-891a-01e01991b32b",
"metadata": {},
"outputs": [],
"source": [
"from aquaregia.openmm import make_canonical_energy_fn\n",
"from aquaregia.utils import get_vacuum_neighbor_list\n",
"import jax_md\n",
"num_particles = data['indices'][0].shape[0]\n",
"displacement_fn, shift_fn = jax_md.space.free()\n",
"vacuum_neighbor_list = get_vacuum_neighbor_list(num_particles)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "8b0949e1-47e2-433c-b56e-837374e1d3aa",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/dominic/anaconda3/envs/aquaregia/lib/python3.10/site-packages/src-0.1.0-py3.10.egg/aquaregia/openmm.py:569: UserWarning: force CMMotionRemover is not currently handled. Omitting\n",
" warnings.warn(f\"force {force_name} is not currently handled. Omitting\")\n",
"/home/dominic/anaconda3/envs/aquaregia/lib/python3.10/site-packages/src-0.1.0-py3.10.egg/aquaregia/openmm.py:569: UserWarning: force MonteCarloBarostat is not currently handled. Omitting\n",
" warnings.warn(f\"force {force_name} is not currently handled. Omitting\")\n"
]
}
],
"source": [
"u_params, u_fn = make_canonical_energy_fn(system = xml_deserialized, \n",
" displacement_fn=displacement_fn,\n",
" particle_indices = data['indices'][0].tolist(),\n",
" allow_constraints=False,\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "8538ade9-9bf5-45c9-8774-0bd0baf1a2e0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray(130.21396248, dtype=float64)"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"u_fn(data['positions'][10], vacuum_neighbor_list, u_params)"
]
},
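{
"cell_type": "markdown",
"id": "added-prior-force-check-md",
"metadata": {},
"source": [
"The NN correction ultimately has to account for the gap between the stored reference forces and the MM prior's forces. As an illustrative added check, the prior forces on one frame can be obtained from `jax.grad` of `u_fn` with respect to positions and compared against the stored forces."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-prior-force-check",
"metadata": {},
"outputs": [],
"source": [
"# added check: MM-prior forces vs. the stored reference forces on a single frame\n",
"prior_forces_frame = -1. * jax.grad(u_fn)(data['positions'][10], vacuum_neighbor_list, u_params)\n",
"reference_forces_frame = data['forces'][10]\n",
"jnp.max(jnp.abs(prior_forces_frame - reference_forces_frame))"
]
},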
{
"cell_type": "code",
"execution_count": 13,
"id": "a3d10734-f079-4be3-a003-757970d750c7",
"metadata": {},
"outputs": [],
"source": [
"from aquaregia.nnp import make_energy_module\n",
"\n",
"\n",
"conv_shapes_dict = {0: {'output_sizes': [WIDTH,WIDTH], 'activation': jax.nn.swish},\n",
" 1: {'output_sizes': [WIDTH,WIDTH], 'activation': jax.nn.swish}}\n",
"tf_mlp_shapes_dict = {0: {0: {'output_sizes': [WIDTH,WIDTH], 'nonlinearity': jax.nn.swish},\n",
" 1: {'output_sizes': [WIDTH,WIDTH], 'nonlinearity': jax.nn.swish}},\n",
" 1: {0: {'output_sizes': [WIDTH,1], 'nonlinearity': jax.nn.swish}}\n",
" }\n",
"SinusoidalBasis_kwargs = {'r_switch': .4,\n",
" 'r_cut': .5,\n",
" 'basis_init' : hk.initializers.Constant(constant=jnp.linspace(0.1, 8, WIDTH))}\n",
"\n",
"if MAX_L == 0:\n",
" del tf_mlp_shapes_dict[0][1]\n",
"\n",
"_energy_fn, _constructor = make_energy_module(max_L = MAX_L, \n",
" num_particles=num_particles,\n",
" conv_shapes_dict=conv_shapes_dict,\n",
" tf_mlp_shapes_dict=tf_mlp_shapes_dict,\n",
" mask_output=False)\n",
"\n",
"energy_fn = hk.without_apply_rng(hk.transform(_energy_fn))\n",
"init_positions = random.normal(random.PRNGKey(25), shape=(num_particles, 3))\n",
"feature_dict = {0:random.normal(random.PRNGKey(253), shape=(num_particles, WIDTH,1)), \n",
" 1: None}\n",
"\n",
"init_params = energy_fn.init(random.PRNGKey(347), init_positions, feature_dict, epsilon = DEFAULT_EPSILON, mask_val = 0., lifting_val = 0.)"
]
},
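{
"cell_type": "markdown",
"id": "added-nnp-smoke-test-md",
"metadata": {},
"source": [
"A small added smoke test: evaluate the untrained module once, using the same `apply` signature as in `total_energy` below, and count its parameters."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-nnp-smoke-test",
"metadata": {},
"outputs": [],
"source": [
"# added smoke test: the untrained module should return a finite scalar for the init inputs\n",
"test_e = energy_fn.apply(init_params, init_positions, feature_dict, DEFAULT_EPSILON, 0., 0.)\n",
"num_nn_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(init_params))\n",
"test_e, num_nn_params"
]
},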
{
"cell_type": "code",
"execution_count": 14,
"id": "b68ba57e-07e6-4ad6-a1fc-8df850d576c2",
"metadata": {},
"outputs": [],
"source": [
"u_prior = partial(u_fn, neighbor_list = vacuum_neighbor_list, parameters = u_params)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "95fa1585-baa0-4ff1-8df0-b0402236d8fc",
"metadata": {},
"outputs": [],
"source": [
"def total_energy(positions, nn_parameters, input_feature_dict):\n",
" prior_e = u_prior(positions)\n",
" implicit_e = energy_fn.apply(nn_parameters, positions, input_feature_dict, DEFAULT_EPSILON, 0., 0.)\n",
" return prior_e + implicit_e\n",
"\n",
"in_f = jnp.repeat(data['partial_charges'][0][..., jnp.newaxis], repeats=WIDTH, axis=1)\n",
"INPUT_FEATURES_DICT = {0: in_f[..., jnp.newaxis], 1: None}\n",
"\n",
"total_energy = functools.partial(total_energy, input_feature_dict = INPUT_FEATURES_DICT)\n",
"total_neg_forces = jax.grad(total_energy)\n",
"prior_neg_forces = jax.grad(u_prior)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "eee36b5b-8884-4b75-a78a-1d635e4a661a",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "50f25b46-d897-48b8-a4dd-8b3f79b3c589",
"metadata": {},
"outputs": [],
"source": [
"indices = np.arange(data['positions'].shape[0])\n",
"np.random.shuffle(indices)\n",
"\n",
"random_positions = Array(data['positions'][indices])\n",
"random_forces = Array(data['forces'][indices])"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "b4ce9473-569d-432a-a754-a8ff0cdac621",
"metadata": {},
"outputs": [],
"source": [
"FORCES_MEAN, FORCES_STD = jnp.mean(random_forces), jnp.std(random_forces)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "a795fb79-ade1-4a32-8276-37e996bc60ad",
"metadata": {},
"outputs": [],
"source": [
"def singular_loss(nn_energy_params, xs, force_matcher):\n",
" \"\"\"get the force matching loss of a single snapshot; inputs are in canonical units\"\"\"\n",
" nn_forces = -1. * total_neg_forces(xs, nn_energy_params)\n",
" squared_force_differences = jnp.square((nn_forces - force_matcher) / FORCES_STD)\n",
" return jnp.mean(squared_force_differences)\n",
"\n",
"#just batch the singular loss fn\n",
"def batch_loss(nn_energy_params, batch_xs, batch_force_matcher):\n",
" batch_differences = jax.vmap(singular_loss, in_axes=(None, 0,0))(nn_energy_params, batch_xs, batch_force_matcher)\n",
" return jnp.mean(batch_differences)"
]
},
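{
"cell_type": "markdown",
"id": "added-loss-smoke-test-md",
"metadata": {},
"source": [
"Before launching the full training loop, the per-snapshot loss can be evaluated once on the untrained parameters as a cheap added smoke test."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-loss-smoke-test",
"metadata": {},
"outputs": [],
"source": [
"# added smoke test: single-snapshot force-matching loss with the untrained network\n",
"jax.jit(singular_loss)(init_params, random_positions[0], random_forces[0])"
]
},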
{
"cell_type": "code",
"execution_count": 20,
"id": "b662fc43-8629-4c19-90f3-6a82441a13c2",
"metadata": {},
"outputs": [],
"source": [
"def get_train(train_positions,\n",
" validate_positions,\n",
" train_forces, \n",
" validate_forces, \n",
" batch_size,\n",
" optimizer,\n",
" validate_frequency = 100):\n",
" n_train_samples = len(train_positions)\n",
" n_validate_samples = len(validate_positions)\n",
" \n",
" assert train_positions.shape == train_forces.shape\n",
" if validate_positions is not None: \n",
" assert validate_positions.shape == validate_forces.shape\n",
" \n",
" assert batch_size < n_train_samples\n",
" \n",
" init_fun, update_fun, get_params = optimizer\n",
" \n",
" def step(seed, step, opt_state):\n",
" train_seed, val_seed = random.split(seed)\n",
" train_indices = random.choice(train_seed, jnp.arange(n_train_samples, dtype=jnp.int32), shape=(batch_size,), replace=False)\n",
" batch_train_positions, batch_train_forces = train_positions[train_indices], train_forces[train_indices]\n",
" #print(batch_train_positions.shape, batch_train_forces.shape)\n",
" mean_val, grads = jax.value_and_grad(batch_loss)(get_params(opt_state), batch_xs = batch_train_positions, batch_force_matcher = batch_train_forces)\n",
" opt_state = update_fun(step, grads, opt_state)\n",
" return mean_val, opt_state\n",
" \n",
" \n",
" def train(init_params, seed, num_iters):\n",
" import tqdm\n",
" train_values, validate_values = [], []\n",
" opt_state = init_fun(init_params)\n",
" trange = tqdm.trange(num_iters, desc=f\"Bar desc\", leave=True)\n",
" for i in trange:\n",
" run_seed, val_seed, seed = random.split(seed, num=3)\n",
" train_value, opt_state = jax.jit(step)(seed = run_seed, step=i, opt_state = opt_state)\n",
" \n",
" #validation\n",
" if i % validate_frequency == 0:\n",
" _validate_values = jax.jit(jax.vmap(singular_loss, in_axes=(None, 0, 0)))(\n",
" get_params(opt_state), \n",
" validate_positions, \n",
" validate_forces)\n",
" \n",
" # appends\n",
" scaled_train_value = train_value * FORCES_STD**2/100.\n",
" scaled_validate_values = _validate_values * FORCES_STD**2/100.\n",
" train_values.append(scaled_train_value)\n",
" validate_values.append(scaled_validate_values)\n",
" \n",
" # trange\n",
" trange.set_description(f\"test loss: {scaled_train_value}; validate_loss: {jnp.mean(scaled_validate_values)} std {jnp.std(scaled_validate_values)}\")\n",
" trange.refresh() # halt this?\n",
" \n",
" return opt_state, Array(train_values), Array(validate_values)\n",
" \n",
" return train, step "
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "84f6c61d-6c02-4eec-9c56-5efb8fcc21b0",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/dominic/anaconda3/envs/aquaregia/lib/python3.10/site-packages/jax/experimental/optimizers.py:28: FutureWarning: jax.experimental.optimizers is deprecated, import jax.example_libraries.optimizers instead\n",
" warnings.warn('jax.experimental.optimizers is deprecated, '\n"
]
}
],
"source": [
"from jax.experimental.optimizers import adam"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "eec377c2-1184-4a39-8356-6110ecf7c497",
"metadata": {},
"outputs": [],
"source": [
"train, step = get_train(train_positions = random_positions[:N_TRAIN_SAMPLES],\n",
" validate_positions = random_positions[N_TRAIN_SAMPLES: N_TRAIN_SAMPLES + N_VALIDATE_SAMPLES],\n",
" train_forces = random_forces[:N_TRAIN_SAMPLES], \n",
" validate_forces = random_forces[N_TRAIN_SAMPLES: N_TRAIN_SAMPLES + N_VALIDATE_SAMPLES], \n",
" batch_size = BATCH_SIZE,\n",
" optimizer = jax.experimental.optimizers.adam(LEARNING_RATE))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9217d1db-4141-4cfd-905a-32c36db07624",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"test loss: 95.06356365656171; validate_loss: nan std nan: 8%|█████▊ | 423/5000 [27:09<4:16:04, 3.36s/it]"
]
}
],
"source": [
"new_opt_state, trains, validates = train(init_params, random.PRNGKey(342), N_TRAIN_STEPS)"
]
},
{
"cell_type": "markdown",
"id": "cc63eb62-cbc1-4a99-8b83-019e95787806",
"metadata": {},
"source": [
"need to place this on gpu."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4edf3a67-0cde-48ff-a079-99ff6b35803d",
"metadata": {},
"outputs": [],
"source": []
}
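,
{
"cell_type": "markdown",
"id": "added-gpu-placement-md",
"metadata": {},
"source": [
"A minimal sketch for the GPU move (added here; it assumes a CUDA-enabled `jaxlib` is installed, which the warning at the top indicates was not the case for this run): verify that JAX can see a GPU, then move the training arrays onto it with `jax.device_put` so each step avoids host-to-device transfers. Jit-compiled functions such as `step` then execute on the device holding their inputs."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-gpu-placement",
"metadata": {},
"outputs": [],
"source": [
"# added sketch (assumes a CUDA-enabled jaxlib build)\n",
"print(jax.devices())  # expect something like [GpuDevice(id=0)] when a GPU backend is available\n",
"\n",
"gpu = jax.devices('gpu')[0]  # raises RuntimeError if no GPU backend is present\n",
"\n",
"# move the training/validation arrays onto the GPU once, up front\n",
"random_positions = jax.device_put(random_positions, gpu)\n",
"random_forces = jax.device_put(random_forces, gpu)\n",
"\n",
"# jitted functions (e.g. `step` above) then run on the device that holds their inputs"
]
}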
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}