Skip to content

Instantly share code, notes, and snippets.

@rmcgibbo
Last active February 19, 2016 09:49
Show Gist options
  • Save rmcgibbo/a1d1b932aae626191ca6 to your computer and use it in GitHub Desktop.
Save rmcgibbo/a1d1b932aae626191ca6 to your computer and use it in GitHub Desktop.
Multiple trajectories + checkpointing with OpenMM
#!/bin/bash
#SBATCH --partition=long
#SBATCH --job-name=<job-name>
#SBATCH --output=<job-name>-%j.stdout
#SBATCH --error=<job-name>-%j.stderr
#SBATCH --time=24:00:00
#SBATCH --qos=normal
#SBATCH --nodes=1
#SBATCH --mail-type=ALL
#SBATCH --mail-user=<your_email>
#SBATCH --ntasks-per-node=10
#SBATCH --gres=gpu:6

# Launch six independent simulate.py processes on this node, one per index
# 0-5 (the index is passed as the first argument and matches the six GPUs
# requested with --gres=gpu:6).  Each process runs in its own
# <hostname>-<index> directory so its output files do not collide with the
# others'.

export PYTHONUNBUFFERED="true"
module load cuda60/toolkit/6.0.37

# Abort early if the submit directory is unavailable: every relative path
# below (../simulate.py, ../setup/...) is resolved from inside it.
cd "$SLURM_SUBMIT_DIR" || exit 1
HOSTNAME=$(hostname)

for i in $(seq 0 5); do
    cd "$SLURM_SUBMIT_DIR" || exit 1
    mkdir -p "$HOSTNAME-$i"
    # Never start a simulation from the wrong directory; it would write its
    # output files next to (or over) another run's files.
    cd "$HOSTNAME-$i" || exit 1
    python ../simulate.py "$i" ../setup/system.xml ../setup/state.xml ../setup/integrator.xml ../setup/initial.pdb \
        > "simulate.$SLURM_JOB_ID.$i.stdout" \
        2> "simulate.$SLURM_JOB_ID.$i.stderr" &
done

# Keep the job alive until every background simulation has exited.
wait
from __future__ import print_function
import os
import re
import glob
import socket
import argparse
from simtk.openmm import app
import simtk.openmm as mm
from simtk import unit
# Print the name of every platform OpenMM reports, so the job's stdout shows
# what was available when the script started.
for platform_index in range(mm.Platform.getNumPlatforms()):
    print('Platform', mm.Platform.getPlatform(platform_index).getName())

# Step counts, obtained by dividing a time span by a 2 fs step length.
# reportInterval: number of steps between reporter outputs (100 ps apart).
reportInterval = int((100 * unit.picosecond) / (2 * unit.femtosecond))
# runLength: number of steps performed by one invocation (1 microsecond).
runLength = int((1 * unit.microsecond) / (2 * unit.femtosecond))
def _last_trajectory_index(filenames):
    """Return the highest index among ``trajectory-<digits>.dcd`` names.

    Names that do not follow that pattern are ignored.  Returns ``None``
    when no name matches.
    """
    indices = []
    for name in filenames:
        match = re.match(r'trajectory-(\d+)\.dcd$', os.path.basename(name))
        if match is not None:
            indices.append(int(match.group(1)))
    return max(indices) if indices else None


def main():
    """Run one trajectory segment, resuming from the last checkpoint if any.

    Command-line arguments (all positional):
        gpuindex     value passed to the CUDA platform as 'CudaDeviceIndex'
        system       path to a serialized System XML file
        state        path to a serialized State XML file
        integrator   path to a serialized Integrator XML file
        topologypdb  path to a PDB file supplying the topology

    The current directory is scanned for ``trajectory-*.dcd`` files.  If none
    exist the run starts from the supplied state as segment 0.  Otherwise the
    checkpoint of the highest-numbered segment is loaded and a new segment
    with the next number is written.  Each segment produces
    ``checkpoint-NNNNNNNN.chk``, ``trajectory-NNNNNNNN.dcd`` and
    ``statedata-NNNNNNNN.log``.

    Raises:
        IOError: if a previous segment exists but one of its three output
            files is missing, so it cannot be resumed from.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('gpuindex')
    parser.add_argument('system')
    parser.add_argument('state')
    parser.add_argument('integrator')
    parser.add_argument('topologypdb')
    args = parser.parse_args()

    with open(args.system) as f:
        system = mm.XmlSerializer.deserialize(f.read())
    with open(args.integrator) as f:
        integrator = mm.XmlSerializer.deserialize(f.read())
    with open(args.state) as f:
        state = mm.XmlSerializer.deserialize(f.read())
    topology = app.PDBFile(args.topologypdb).topology

    platform = mm.Platform.getPlatformByName('CUDA')
    # gpuindex is deliberately left as the string argparse produced.
    properties = {'CudaPrecision': 'mixed', 'CudaDeviceIndex': args.gpuindex}
    simulation = app.Simulation(topology, system, integrator, platform=platform,
                                platformProperties=properties)
    simulation.context.setState(state)

    last_traj = _last_trajectory_index(glob.glob('trajectory-*.dcd'))
    if last_traj is None:
        traj_number = 0
    else:
        traj_number = last_traj + 1
        required = [pattern % last_traj for pattern in (
            'checkpoint-%08d.chk', 'trajectory-%08d.dcd', 'statedata-%08d.log')]
        missing = [path for path in required if not os.path.exists(path)]
        if missing:
            # An explicit error rather than `assert`, which `python -O` strips.
            raise IOError('cannot resume from segment %d, missing: %s'
                          % (last_traj, ', '.join(missing)))
        # Checkpoints are binary; read them as bytes.  Loading the checkpoint
        # replaces the state that was set from the XML file above.
        with open('checkpoint-%08d.chk' % last_traj, 'rb') as f:
            simulation.context.loadCheckpoint(f.read())

    simulation.reporters.append(
        app.CheckpointReporter('checkpoint-%08d.chk' % traj_number, reportInterval))
    simulation.reporters.append(
        app.DCDReporter('trajectory-%08d.dcd' % traj_number, reportInterval))

    # Line-buffered (buffering=1) so each report reaches disk promptly;
    # buffering=0 is rejected for text files on Python 3.
    with open('statedata-%08d.log' % traj_number, 'w', 1) as log_file:
        simulation.reporters.append(app.StateDataReporter(
            log_file, reportInterval,
            step=True, time=True, potentialEnergy=True, totalEnergy=True,
            temperature=True, density=True, remainingTime=True, speed=True,
            separator='\t', totalSteps=runLength))
        simulation.step(runLength)


if __name__ == '__main__':
    main()
@bougui505
Copy link

I didn't know about the CheckpointReporter. Is it possible to get a particular frame from the .chk file to restart a simulation from?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment