Skip to content

Instantly share code, notes, and snippets.

@rajarshi
Created April 2, 2018 16:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rajarshi/34257786ce41233773f8f09844df8f97 to your computer and use it in GitHub Desktop.
Save rajarshi/34257786ce41233773f8f09844df8f97 to your computer and use it in GitHub Desktop.
import chemaxon.struc.Molecule;
import chemaxon.util.MolHandler;
import gov.nih.ncgc.algo.graph.VFLib2;
import gov.nih.ncgc.descriptor.MolecularFramework;
import gov.nih.ncgc.util.ChemUtil;
import gov.nih.ncgc.util.MolStandardizer;
import org.apache.commons.dbcp2.BasicDataSource;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.logging.Level;
import java.util.logging.Logger;
/**
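* Fragments the compounds listed in molregnos.txt into molecular frameworks,
* stores them in the fragment_class / fragment_instances tables and then
* updates per-class aggregate statistics (snr, apt, instances).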
*
* @author Rajarshi Guha
*/
public class GenerateFragment {
/** Logger shared by the outer class and all of its inner tasks. */
static final Logger logger =
Logger.getLogger(GenerateFragment.class.getName());
/**
 * Sentinel placed on the work queue to tell a worker to stop. Workers compare
 * against it by reference, so it must stay a single shared instance.
 */
static private final Molecule DONE = new Molecule();
/**
 * Worker that takes molecules from the shared queue, fragments each one with
 * {@link MolecularFramework} and stores the resulting fragments in the
 * fragment_class and fragment_instances tables.
 *
 * Each worker owns its own connection and prepared statements; they are
 * created lazily on the first molecule and released when the worker stops,
 * whether it stops normally or because of an error.
 */
class FragmentRunner implements Runnable {
    String name;
    MolecularFramework mf = new MolecularFramework();
    PreparedStatement pstm1, pstm2, pstm3;
    Connection con;
    int count = 0;

    /**
     * @param name label used as the thread name and as a prefix in log messages
     */
    FragmentRunner(String name) throws SQLException {
        this.name = name;
        mf.setGenerateAtomMapping(true);
        mf.setAllowBenzene(false);
        mf.setNumThreads(1);
        mf.setKeepStereo(false);
        //mf.setKeepFusedRings(true); // generate only ring systems
    }

    /**
     * Processes queued molecules until the DONE sentinel is taken, then
     * releases the SQL resources.
     */
    @Override
    public void run() {
        Thread.currentThread().setName(name);
        try {
            logger.info(name + ": waiting for queue...");
            for (Molecule mol; (mol = queue.take()) != DONE; ) {
                process(mol);
            }
            closeSQL();
            logger.info("Fragment thread " + name + " is done!");
        } catch (InterruptedException ex) {
            // Restore the flag so the executor can observe the interruption.
            Thread.currentThread().interrupt();
            logger.log(Level.SEVERE, "Fragment thread " + name, ex);
        } catch (Exception ex) {
            logger.log(Level.SEVERE, "Fragment thread " + name, ex);
        } finally {
            // No-op after a normal shutdown; on a failure path this is what
            // hands the connection back instead of leaking it.
            try {
                closeSQL();
            } catch (SQLException ex) {
                logger.log(Level.WARNING,
                        name + ": error releasing SQL resources", ex);
            }
        }
    }

    /** Obtains a fresh connection and prepares the three statements used by load(). */
    void newSQL() throws SQLException {
        logger.info(name + ": ** creating new SQL objects; count="
                + count + " **");
        con = getConnection();
        // insert a new fragment class, asking for the generated class_id back
        pstm1 = con.prepareStatement
                ("insert into fragment_class (smiles,acount,bcount,symmetry,"
                        + "complexity,hashkey1,hashkey2,hashkey3) "
                        + "values (?,?,?,?,?,?,?,?)", new String[]{"class_id"});
        // look up an existing fragment class by its SMILES
        pstm2 = con.prepareStatement
                ("select class_id,hashkey3 from fragment_class "
                        + "where smiles = ?");
        // record one occurrence of a fragment class in a compound
        pstm3 = con.prepareStatement
                ("insert into fragment_instances (molregno,chembl_id,class_id"
                        + ",smiles,hashkey1,hashkey2,hashkey3,adiff) values"
                        + "(?,?,?,?,?,?,?,?)");
    }

    /**
     * Makes sure a usable connection and statements exist, reconnecting when
     * the current connection is closed or fails a trivial probe.
     */
    void checkSQL() throws SQLException {
        if (con != null) {
            try {
                if (!con.isClosed()) {
                    Statement probe = con.createStatement();
                    probe.close();
                    return;
                }
            } catch (SQLException ex) {
                // connection is unusable; fall through and reconnect
            }
            // Release the stale handles before replacing them so they are not
            // leaked; failures here are expected on a dead connection.
            try {
                closeSQL();
            } catch (SQLException ignored) {
                // nothing more can be done with a broken connection
            }
        }
        newSQL();
    }

    /**
     * Closes the statements and the connection. Every resource is attempted
     * even if an earlier close fails; the first failure is rethrown afterwards.
     * Safe to call more than once.
     */
    void closeSQL() throws SQLException {
        SQLException failure = null;
        Statement[] statements = {pstm1, pstm2, pstm3};
        for (Statement s : statements) {
            if (s != null) {
                try {
                    s.close();
                } catch (SQLException ex) {
                    if (failure == null) {
                        failure = ex;
                    }
                }
            }
        }
        if (con != null) {
            try {
                con.close();
            } catch (SQLException ex) {
                if (failure == null) {
                    failure = ex;
                }
            }
        }
        pstm1 = null;
        pstm2 = null;
        pstm3 = null;
        con = null;
        if (failure != null) {
            throw failure;
        }
    }

    /**
     * Fragments one molecule and loads every fragment. A failure on a single
     * fragment is logged and does not stop the remaining fragments.
     */
    void process(Molecule mol) throws Exception {
        checkSQL();
        if (count % 500 == 0) {
            logger.info(name + ": processing " + mol.getName()
                    + "; " + count + " fragments processed!");
        }
        mf.setMolecule(mol);
        mf.run();
        int i = 1;
        for (Enumeration<Molecule> en = mf.getFragments();
             en.hasMoreElements(); ++i) {
            Molecule f = en.nextElement();
            f.setProperty("MOLREGNO", mol.getProperty("MOLREGNO"));
            f.setProperty("CHEMBL_ID", mol.getProperty("CHEMBL_ID"));
            String smiles = MolStandardizer.canonicalSMILES(f, false);
            f.setProperty("SMILES", smiles);
            // the fragment's current name is kept as its hash key before the
            // name is overwritten with "<parent name>-<index>"
            f.setProperty("HASHKEY", f.getName());
            f.setName(mol.getName() + "-" + i);
            try {
                load(f, mol.getAtomCount() - f.getAtomCount());
            } catch (Exception ex) {
                logger.log(Level.SEVERE, "Failed to load fragment "
                        + smiles + " for " + mol.getName(), ex);
            }
        }
    }

    /**
     * Stores one fragment: inserts its class (or looks up the existing class
     * when the insert fails) and then inserts the instance row.
     *
     * @param m     fragment carrying SMILES, MOLREGNO, CHEMBL_ID and HASHKEY properties
     * @param adiff atom count of the parent molecule minus that of the fragment
     * @throws IllegalArgumentException if the hash key does not have three '-' separated parts
     */
    void load(Molecule m, int adiff) throws Exception {
        String smiles = m.getProperty("SMILES");
        long molregno = Long.parseLong(m.getProperty("MOLREGNO"));
        String chembl = m.getProperty("CHEMBL_ID");
        String hashkey = m.getProperty("HASHKEY");
        String[] hk = hashkey == null ? new String[0] : hashkey.split("-");
        if (hk.length < 3) {
            throw new IllegalArgumentException("Malformed hash key '"
                    + hashkey + "' for fragment " + smiles);
        }
        String hk1 = hk[0], hk2 = hk[0] + hk[1], hk3 = hk[0] + hk[1] + hk[2];
        long classId = 0;
        try {
            pstm1.setString(1, smiles);
            pstm1.setInt(2, m.getAtomCount());
            pstm1.setInt(3, m.getBondCount());
            {
                // work on an aromatized copy so the fragment itself is untouched
                Molecule f = m.cloneMolecule();
                f.aromatize();
                int complexity = ChemUtil.complexity(f);
                VFLib2 vf = VFLib2.automorphism(f);
                int[][] hits = vf.findAll();
                // symmetry = number of automorphism matches
                pstm1.setInt(4, hits.length);
                pstm1.setInt(5, complexity);
            }
            pstm1.setString(6, hk1);
            pstm1.setString(7, hk2);
            pstm1.setString(8, hk3);
            pstm1.executeUpdate();
            try (ResultSet rset = pstm1.getGeneratedKeys()) {
                if (rset.next()) {
                    classId = rset.getLong(1);
                }
            }
        } catch (SQLException ex) {
            // NOTE(review): an insert failure is treated as "class already
            // exists" (presumably a unique constraint on smiles - confirm
            // against the schema); fall back to looking the class up.
            pstm2.setString(1, smiles);
            try (ResultSet rset = pstm2.executeQuery()) {
                if (rset.next()) {
                    classId = rset.getLong(1);
                    String key = rset.getString(2);
                    if (!hk3.equals(key)) {
                        logger.warning("Hash key mismatch; expecting " + key
                                + " but got " + hk3);
                    }
                }
            }
        }
        if (classId == 0) {
            logger.log(Level.SEVERE, name + ": ** Can't get class id for "
                    + molregno + "; " + smiles);
        } else {
            pstm3.setLong(1, molregno);
            pstm3.setString(2, chembl);
            pstm3.setLong(3, classId);
            pstm3.setString(4, m.toFormat("smiles:q"));
            pstm3.setString(5, hk1);
            pstm3.setString(6, hk2);
            pstm3.setString(7, hk3);
            pstm3.setInt(8, adiff);
            int r = pstm3.executeUpdate();
            if (r > 0) {
                if (++count % 1000 == 0) {
                    logger.info(name + ": " + String.format("%1$7d", count)
                            + " " + String.format("%1$6d", classId)
                            + " " + hashkey);
                    //con.commit();
                }
            }
        }
    }
}
/**
 * Runnable wrapper around doFragmentation() that reports failures through the
 * class logger instead of letting them escape the executor thread.
 */
class FragmentationTask implements Runnable {
    @Override
    public void run() {
        try {
            doFragmentation();
        } catch (SQLException | IOException ex) {
            logger.log(Level.SEVERE, "Fragmentation task failed!", ex);
        } catch (InterruptedException ex) {
            // Restore the flag so the owning executor can see the interruption.
            Thread.currentThread().interrupt();
            logger.log(Level.SEVERE, "Fragmentation task interrupted!", ex);
        }
    }
}
/**
 * Runnable that recomputes the aggregate fragment_class statistics by
 * delegating to doFragmentStats(); SQL failures are logged, not rethrown.
 */
class FragmentStatsTask implements Runnable {
    @Override
    public void run() {
        try {
            GenerateFragment.this.doFragmentStats();
        } catch (SQLException failure) {
            logger.log(Level.SEVERE, "Fragment stats task failed!", failure);
        }
    }
}
// Connection pool; configured in init() and used by getConnection().
private BasicDataSource ds;
// Bounded hand-off queue of molecules: filled by doFragmentation(), drained by FragmentRunner.
private BlockingQueue<Molecule> queue;
// Compound ids read from molregnos.txt in init(); consumed by doFragmentation().
private BlockingQueue<String> molregnos;
// Executor that runs the FragmentRunner workers and the FragmentationTask.
private ExecutorService service;
// Number of FragmentRunner workers; overwritten by init().
private int nthreads = 4;
// Futures of the FragmentRunner workers, filled in by kickStart().
private Future[] tasks;
// Reset to false in init(); not read anywhere in this file.
private AtomicBoolean started = new AtomicBoolean();
/**
 * Configures the PostgreSQL data source, the molecule queue and the executor,
 * and loads the compound ids to process from the file "molregnos.txt" in the
 * working directory.
 *
 * @param dbUrl    JDBC URL handed to the data source
 * @param qsize    capacity of the molecule queue
 * @param nthreads number of fragment workers to start later
 * @throws IOException if molregnos.txt cannot be read
 */
public void init(String dbUrl, int qsize, int nthreads) throws IOException {
    ds = new BasicDataSource();
    ds.setDriverClassName("org.postgresql.Driver");
    ds.setUrl(dbUrl);
    queue = new ArrayBlockingQueue<Molecule>(qsize);
    this.nthreads = nthreads;
    service = Executors.newCachedThreadPool();
    started.set(false);
    ArrayList<String> lines = new ArrayList<>();
    // try-with-resources so the file handle is released even on a read error
    try (BufferedReader reader = new BufferedReader(new FileReader("molregnos.txt"))) {
        String line;
        while ((line = reader.readLine()) != null) {
            String id = line.trim();
            // blank lines would otherwise turn into invalid queries later on
            if (!id.isEmpty()) {
                lines.add(id);
            }
        }
    }
    // ArrayBlockingQueue rejects a capacity of 0, so an empty file still
    // yields a valid (empty) queue.
    molregnos = new ArrayBlockingQueue<String>(Math.max(1, lines.size()));
    molregnos.addAll(lines);
    logger.info("Read "+molregnos.size()+" molregno's");
}
/**
 * Hands out a connection from the data source; calls are serialized on this
 * instance. The caller is responsible for closing the returned connection.
 */
synchronized Connection getConnection() throws SQLException {
    final Connection pooled = ds.getConnection();
    return pooled;
}
/**
 * Submits one FragmentRunner per configured thread, remembering their futures
 * in tasks, and then submits the FragmentationTask that feeds them.
 */
void kickStart() throws Exception {
    tasks = new Future[nthreads];
    int slot = 0;
    while (slot < nthreads) {
        Runnable worker = new FragmentRunner("FragmentThread-" + slot);
        tasks[slot] = service.submit(worker);
        slot++;
    }
    Runnable producer = new FragmentationTask();
    service.submit(producer);
}
/**
 * Initializes the pipeline, starts the background tasks and logs a snapshot
 * of the current state (queue fill, active workers, table row counts).
 * The method returns right after the snapshot; the submitted tasks keep
 * running on the executor's threads.
 *
 * @param connectURI JDBC URL of the database
 * @param nthreads   number of fragment workers
 * @throws IOException if the molregno list cannot be read
 */
private void run(String connectURI, int nthreads) throws IOException {
    init(connectURI, 20000, nthreads);
    try {
        kickStart();
    } catch (Exception ex) {
        logger.log(Level.SEVERE, "Unable to start fragmentation tasks", ex);
        // Workers submitted before the failure would otherwise wait on the
        // queue forever; interrupt them so the JVM can exit.
        service.shutdownNow();
        return;
    }
    // Everything has been submitted: let the pool wind down once the tasks
    // finish. shutdown() does not cancel already-submitted tasks.
    service.shutdown();
    int running = tasks.length;
    for (Future<?> f : tasks) {
        if (f.isDone()) {
            --running;
        }
    }
    logger.info("Current time: " + new java.util.Date());
    logger.info("Fragment queue: " + queue.size());
    logger.info("Queue remaining capacity: " + queue.remainingCapacity());
    logger.info("Number of threads: " + nthreads
            + " (" + running + " still active)");
    // try-with-resources so the connection goes back to the pool even when a
    // count query fails
    try (Connection con = getConnection();
         Statement stm = con.createStatement()) {
        logRowCount(stm, "fragment_class", "Fragment classes: ");
        logRowCount(stm, "fragment_instances", "Fragment instances: ");
    } catch (SQLException ex) {
        logger.log(Level.SEVERE, "Unable to query fragment table sizes", ex);
    }
}

/** Logs label followed by the row count of the given table. */
private void logRowCount(Statement stm, String table, String label)
        throws SQLException {
    try (ResultSet rset = stm.executeQuery("select count(1) from " + table)) {
        if (rset.next()) {
            logger.info(label + rset.getString(1));
        }
    }
}
/**
 * Command-line entry point.
 *
 * Expects the worker thread count as the first argument and the JDBC URL in
 * the "dburi" system property. Prints a usage message and exits with status 1
 * when either is missing or the thread count is not a positive integer.
 */
public static void main(String[] args)
        throws IOException {
    String connectURI = System.getProperty("dburi");
    int nthreads = 0;
    if (args.length > 0) {
        try {
            nthreads = Integer.parseInt(args[0]);
        } catch (NumberFormatException ex) {
            nthreads = 0; // reported through the usage message below
        }
    }
    // With no workers the producer would fill the queue and block forever,
    // so a thread count below 1 is rejected up front.
    if (nthreads < 1 || connectURI == null) {
        System.err.println("Usage: java -Ddburi=<jdbc-url> GenerateFragment <nthreads>");
        System.err.println("  <nthreads> must be a positive integer");
        System.exit(1);
        return;
    }
    GenerateFragment gf = new GenerateFragment();
    gf.run(connectURI, nthreads);
}
/**
 * Producer side of the pipeline: loads every compound listed in molregnos,
 * queues those that contain at least one ring for the fragment workers,
 * signals the workers to stop, waits for them and finally refreshes the
 * aggregate statistics.
 *
 * An earlier variant (removed here) selected all ACTIVE compounds without
 * fragment_instances rows directly from the database instead of reading the
 * ids from molregnos.txt.
 *
 * @throws SQLException         if the compound query or the statistics update fails
 * @throws InterruptedException if interrupted while taking ids or queuing molecules
 */
void doFragmentation() throws SQLException, IOException, InterruptedException {
    int count;
    try {
        count = queueCompounds();
    } finally {
        // Always hand one DONE sentinel to each worker, even when queuing
        // failed, so no worker stays blocked on the queue forever.
        for (int i = 0; i < nthreads; ++i) {
            try {
                queue.put(DONE);
            } catch (InterruptedException ex) {
                logger.log(Level.SEVERE, "Queue interrupted", ex);
            }
        }
    }
    logger.info(count + " compound(s) queued!");
    // if there are new instances, we need to update snr, apt, and
    // instances in fragment_class
    if (count > 0) {
        logger.info("Waiting for fragmentation jobs to finish...");
        // waiting for all threads to finish
        for (Future<?> f : tasks) {
            try {
                f.get();
            } catch (Exception ex) {
                logger.log(Level.WARNING, "Fragmentation thread interrupted", ex);
            }
        }
    }
    doFragmentStats();
}

/**
 * Reads each pending molregno, fetches its structure and puts ring-containing
 * molecules on the work queue (blocking while the queue is full).
 *
 * @return number of molecules queued
 */
private int queueCompounds() throws SQLException, InterruptedException {
    logger.info("Prepare to generate fragments...");
    MolHandler mh = new MolHandler();
    int count = 0;
    // Bind the id as a parameter instead of concatenating file content into
    // the SQL text; resources are closed even if a query fails.
    try (Connection con = getConnection();
         PreparedStatement pstm = con.prepareStatement(
                 "select * from chembl_id_lookup d, compound_structures b "
                         + " where b.molregno = ?"
                         + " and d.entity_id = b.molregno")) {
        pstm.setQueryTimeout(0);
        while (!molregnos.isEmpty()) {
            String molregno = molregnos.take();
            long id;
            try {
                id = Long.parseLong(molregno);
            } catch (NumberFormatException ex) {
                logger.log(Level.SEVERE,
                        "Can't process compound " + molregno, ex);
                continue;
            }
            pstm.setLong(1, id);
            Molecule ready = null;
            try (ResultSet rset = pstm.executeQuery()) {
                if (!rset.next()) {
                    logger.warning("No structure found for compound " + molregno);
                    continue;
                }
                try {
                    mh.setMolecule(rset.getString("molfile"));
                    Molecule mol = mh.getMolecule();
                    mol.setProperty("MOLREGNO", molregno);
                    mol.setProperty("CHEMBL_ID", rset.getString("chembl_id"));
                    mol.setProperty("INCHI_KEY",
                            rset.getString("standard_inchi_key"));
                    mol.setName(mol.getProperty("CHEMBL_ID"));
                    // only molecules with at least one ring are fragmented
                    int[][] sssr = mol.getSSSR();
                    if (sssr.length > 0) {
                        ready = mol;
                    }
                } catch (Exception ex) {
                    logger.log(Level.SEVERE,
                            "Can't process compound " + molregno, ex);
                }
            }
            if (ready != null) {
                if (++count % 10 == 0) {
                    logger.info("queuing " + ready.getName() + " "
                            + queue.size() + "/"
                            + queue.remainingCapacity());
                }
                // blocks if the queue is full; kept outside the broad catch
                // above so an interruption is propagated, not swallowed
                queue.put(ready);
            }
        }
    }
    return count;
}
/**
 * Fills in snr, apt and instances for every fragment_class row whose snr is
 * still null.
 *
 * For each class the distinct parent compounds are loaded; with N distinct
 * compounds, fragment atom count a and per-compound size difference d:
 * <ul>
 * <li>N &lt;= 1: snr = 0 and apt = 0</li>
 * <li>otherwise snr = a / sqrt(sum(d*d) / N), or a when the sum is 0, and
 *     apt = mean Tanimoto similarity of the compounds' fingerprints over all
 *     ordered pairs</li>
 * </ul>
 * A failure on one class is logged and the remaining classes are still processed.
 *
 * @throws SQLException if the connection or the driving queries cannot be set up
 */
void doFragmentStats() throws SQLException {
    logger.info("Updating aggregated stats for fragment_class");
    // try-with-resources releases everything in reverse order, also on failure
    try (Connection con = getConnection();
         PreparedStatement pstm = con.prepareStatement
                 ("update fragment_class set snr= ?, apt = ?,instances=? "
                         + "where class_id = ?");
         PreparedStatement pstm2 = con.prepareStatement
                 ("select a.molfile,a.molregno "
                         + "from compound_structures a, fragment_instances b where "
                         + "b.class_id = ? and a.molregno = b.molregno");
         Statement stm = con.createStatement();
         ResultSet rset = stm.executeQuery
                 ("select smiles,class_id "
                         + "from fragment_class where snr is null")) {
        MolHandler mh = new MolHandler();
        int count = 0;
        while (rset.next()) {
            String smiles = rset.getString(1);
            long classId = rset.getLong(2);
            try {
                mh.setMolecule(smiles);
                int acount = mh.getMolecule().getAtomCount();
                pstm2.setLong(1, classId);
                Map<String, int[]> processed =
                        new HashMap<String, int[]>();
                double SNR = 0.;
                try (ResultSet rs = pstm2.executeQuery()) {
                    while (rs.next()) {
                        String molregno = rs.getString("molregno");
                        // a compound can hold several instances of the same
                        // class; count it once and skip re-parsing it
                        if (processed.containsKey(molregno)) {
                            continue;
                        }
                        mh.setMolecule(rs.getString("molfile"));
                        mh.aromatize();
                        int adiff = mh.getMolecule().getAtomCount() - acount;
                        SNR += adiff * adiff;
                        processed.put(molregno,
                                mh.generateFingerprintInInts(16, 2, 6));
                    }
                }
                int[][] N = processed.values().toArray(new int[0][]);
                double APT = 0.;
                if (N.length > 1) {
                    APT = meanPairwiseTanimoto(N);
                    if (SNR > 0.) {
                        SNR = acount / Math.sqrt(SNR / N.length);
                    } else {
                        SNR = acount;
                    }
                } else {
                    SNR = 0.;
                    APT = 0.;
                }
                pstm.setDouble(1, SNR);
                pstm.setDouble(2, APT);
                pstm.setInt(3, N.length);
                pstm.setLong(4, classId);
                if (pstm.executeUpdate() > 0) {
                    if (++count % 1000 == 0) {
                        logger.info(classId + ": SNR=" + SNR + " APT=" + APT
                                + " N=" + N.length + " size=" + acount);
                    }
                }
            } catch (Exception ex) {
                logger.log(Level.SEVERE, "Update error", ex);
            }
        }
    }
}

/**
 * Average Tanimoto coefficient over all ordered pairs (i != j) of bit
 * fingerprints. Requires at least two fingerprints.
 */
private static double meanPairwiseTanimoto(int[][] fps) {
    int n = fps.length;
    // number of set bits per fingerprint, computed once instead of per pair
    int[] bits = new int[n];
    for (int i = 0; i < n; ++i) {
        for (int k = 0; k < fps[i].length; ++k) {
            bits[i] += Integer.bitCount(fps[i][k]);
        }
    }
    double total = 0.;
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) {
            if (i != j) {
                int common = 0;
                for (int k = 0; k < fps[j].length; ++k) {
                    common += Integer.bitCount(fps[i][k] & fps[j][k]);
                }
                total += (double) common / (bits[i] + bits[j] - common);
            }
        }
    }
    // divide in double: n * (n - 1) overflows int once n exceeds ~46k
    return total / ((double) n * (n - 1));
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment