Skip to content

Instantly share code, notes, and snippets.

@dsehnal
Last active February 21, 2024 02:45
Show Gist options
  • Save dsehnal/b06f5555fa9145da69fe69abfeab6eaf to your computer and use it in GitHub Desktop.
BinaryCIF parser in Python
# BinaryCIF Parser
# Copyright (c) 2021 David Sehnal <david.sehnal@gmail.com>, licensed under MIT.
#
# Resources:
# - https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1008247
# - https://github.com/molstar/BinaryCIF & https://github.com/molstar/BinaryCIF/blob/master/encoding.md
#
# Implementation based on Mol*:
# - https://github.com/molstar/molstar/blob/master/src/mol-io/common/binary-cif/encoding.ts
# - https://github.com/molstar/molstar/blob/master/src/mol-io/common/binary-cif/decoder.ts
# - https://github.com/molstar/molstar/blob/master/src/mol-io/reader/cif/binary/parser.ts
from typing import Any, Dict, List, Optional, TypedDict, Union
import msgpack
import numpy as np
class EncodingBase(TypedDict):
    """Common shape of every encoding step; `kind` selects the decoder in `_decoders`."""
    kind: str
class EncodedData(TypedDict):
    """Raw bytes plus the ordered list of encodings that produced them."""
    encoding: List[EncodingBase]
    data: bytes
class EncodedColumn(TypedDict):
    """One encoded CIF column: its values and an optional presence mask."""
    name: str
    data: EncodedData
    mask: Optional[EncodedData]
class EncodedCategory(TypedDict):
    """One encoded CIF category: its name, row count, and columns."""
    name: str
    rowCount: int
    columns: List[EncodedColumn]
class EncodedDataBlock(TypedDict):
    """One encoded CIF data block: a header plus its categories."""
    header: str
    categories: List[EncodedCategory]
class EncodedFile(TypedDict):
    """Top-level msgpack object of a BinaryCIF file."""
    version: str
    encoder: str
    dataBlocks: List[EncodedDataBlock]
def _decode(encoded_data: EncodedData) -> Union[np.ndarray, List[str]]:
    """Recover column data by undoing the recorded encodings.

    Encodings were applied front-to-back when the file was written, so they
    are decoded back-to-front here. Raises ValueError on an unknown kind.
    """
    payload = encoded_data["data"]
    for step in reversed(encoded_data["encoding"]):
        kind = step["kind"]
        decoder = _decoders.get(kind)
        if decoder is None:
            raise ValueError(f"Unsupported encoding '{kind}'")
        payload = decoder(payload, step)  # type: ignore
    return payload  # type: ignore
class DataTypes:
    """Numeric type codes used by the BinaryCIF ByteArray encoding."""
    Int8 = 1
    Int16 = 2
    Int32 = 3
    Uint8 = 4
    Uint16 = 5
    Uint32 = 6
    Float32 = 32
    Float64 = 33
# Maps a DataTypes code to the numpy dtype character code. Byte order is not
# included here; _decode_byte_array prepends "<" (little-endian) itself.
_dtypes = {
    DataTypes.Int8: "i1",
    DataTypes.Int16: "i2",
    DataTypes.Int32: "i4",
    DataTypes.Uint8: "u1",
    DataTypes.Uint16: "u2",
    DataTypes.Uint32: "u4",
    DataTypes.Float32: "f4",
    DataTypes.Float64: "f8",
}
def _get_dtype(type: int) -> str:
    """Translate a BinaryCIF type code into a numpy dtype string.

    Raises ValueError for codes not present in `_dtypes`.
    (NOTE: the parameter shadows the builtin `type`; kept for API compatibility.)
    """
    dtype = _dtypes.get(type)
    if dtype is None:
        raise ValueError(f"Unsupported data type '{type}'")
    return dtype
class ByteArrayEncoding(EncodingBase):
    """Encodes a typed numeric array as raw bytes; `type` is a DataTypes code."""
    type: int
class FixedPointEncoding(EncodingBase):
    """Floats stored as integers scaled by `factor`."""
    factor: float
    srcType: int
class IntervalQuantizationEncoding(EncodingBase):
    """Floats quantized into `numSteps` equal steps over [min, max]."""
    min: float
    max: float
    numSteps: int
    srcType: int
class RunLengthEncoding(EncodingBase):
    """Data stored as interleaved (value, repeat-count) pairs."""
    srcType: int
    srcSize: int
class DeltaEncoding(EncodingBase):
    """Data stored as consecutive differences, starting from `origin`."""
    origin: int
    srcType: int
class IntegerPackingEncoding(EncodingBase):
    """32-bit integers packed into 1- or 2-byte values with overflow chaining."""
    byteCount: int
    isUnsigned: bool
    srcSize: int
class StringArrayEncoding(EncodingBase):
    """Strings stored as one shared string plus encoded offsets and indices."""
    dataEncoding: List[EncodingBase]
    stringData: str
    offsetEncoding: List[EncodingBase]
    offsets: bytes
def _decode_byte_array(data: bytes, encoding: ByteArrayEncoding) -> np.ndarray:
    """Reinterpret raw little-endian bytes as a typed numpy array (zero-copy view)."""
    dtype = "<" + _get_dtype(encoding["type"])
    return np.frombuffer(data, dtype=dtype)
def _decode_fixed_point(data: np.ndarray, encoding: FixedPointEncoding) -> np.ndarray:
    """Undo fixed-point encoding: cast to the source float type and divide by `factor`."""
    restored = np.array(data, dtype=_get_dtype(encoding["srcType"]))
    return restored / encoding["factor"]
def _decode_interval_quantization(
    data: np.ndarray, encoding: IntervalQuantizationEncoding
) -> np.ndarray:
    """Map quantized integer steps back onto the [min, max] interval."""
    lower = encoding["min"]
    step = (encoding["max"] - lower) / (encoding["numSteps"] - 1)
    values = np.array(data, dtype=_get_dtype(encoding["srcType"]))
    return values * step + lower
def _decode_run_length(data: np.ndarray, encoding: RunLengthEncoding) -> np.ndarray:
    """Expand interleaved (value, count) pairs: even slots are values, odd slots counts."""
    values = np.array(data[::2], dtype=_get_dtype(encoding["srcType"]))
    counts = data[1::2]
    return np.repeat(values, repeats=counts)
def _decode_delta(data: np.ndarray, encoding: DeltaEncoding) -> np.ndarray:
    """Undo delta encoding: seed the first element with `origin`, then prefix-sum."""
    deltas = np.array(data, dtype=_get_dtype(encoding["srcType"]))
    if encoding["origin"]:
        deltas[0] += encoding["origin"]
    # In-place cumulative sum; the same array is returned.
    np.cumsum(deltas, out=deltas)
    return deltas
def _decode_integer_packing_signed(
data: np.ndarray, encoding: IntegerPackingEncoding
) -> np.ndarray:
upper_limit = 0x7F if encoding["byteCount"] == 1 else 0x7FFF
lower_limit = -upper_limit - 1
n = len(data)
output = np.zeros(encoding["srcSize"], dtype="i4")
i = 0
j = 0
while i < n:
value = 0
t = data[i]
while t == upper_limit or t == lower_limit:
value += t
i += 1
t = data[i]
value += t
output[j] = value
i += 1
j += 1
return output
def _decode_integer_packing_unsigned(
data: np.ndarray, encoding: IntegerPackingEncoding
) -> np.ndarray:
upper_limit = 0xFF if encoding["byteCount"] == 1 else 0xFFFF
n = len(data)
output = np.zeros(encoding["srcSize"], dtype="i4")
i = 0
j = 0
while i < n:
value = 0
t = data[i]
while t == upper_limit:
value += t
i += 1
t = data[i]
value += t
output[j] = value
i += 1
j += 1
return output
def _decode_integer_packing(
data: np.ndarray, encoding: IntegerPackingEncoding
) -> np.ndarray:
if len(data) == encoding["srcSize"]:
return data
if encoding["isUnsigned"]:
return _decode_integer_packing_unsigned(data, encoding)
else:
return _decode_integer_packing_signed(data, encoding)
def _decode_string_array(data: bytes, encoding: StringArrayEncoding) -> List[str]:
    """Rebuild a string column from shared string data, offsets, and indices.

    Offsets delimit substrings of `stringData`; each decoded index selects
    one substring. Slot 0 of the table is the empty string, so an index of
    -1 maps to "".
    """
    offsets = _decode(
        EncodedData(encoding=encoding["offsetEncoding"], data=encoding["offsets"])
    )
    indices = _decode(EncodedData(encoding=encoding["dataEncoding"], data=data))
    pool = encoding["stringData"]
    table = [""]
    table.extend(pool[start:end] for start, end in zip(offsets, offsets[1:]))  # type: ignore
    return [table[i + 1] for i in indices]  # type: ignore
# Decoder dispatch table: maps an encoding "kind" string to its decoding
# function. _decode() walks a column's encoding list in reverse through this.
_decoders = {
    "ByteArray": _decode_byte_array,
    "FixedPoint": _decode_fixed_point,
    "IntervalQuantization": _decode_interval_quantization,
    "RunLength": _decode_run_length,
    "Delta": _decode_delta,
    "IntegerPacking": _decode_integer_packing,
    "StringArray": _decode_string_array,
}
##################################################################################
class CifValueKind:
    """Per-value presence codes stored in a column's mask."""
    Present = 0
    # Expressed in CIF as `.`
    NotPresent = 1
    # Expressed in CIF as `?`
    Unknown = 2
class CifField:
    """A single decoded CIF column.

    Holds the column's values (a numpy array for numbers, a list for
    strings) plus an optional per-row value-kind mask (see CifValueKind).
    """

    def __init__(
        self,
        name: str,
        values: Union[np.ndarray, List[str]],
        value_kinds: Optional[np.ndarray],
    ):
        self.name = name
        self._values = values
        self._value_kinds = value_kinds
        self.row_count = len(values)

    def __getitem__(self, idx: int) -> Union[str, float, int, None]:
        """Return the value at `idx`, or None when the mask marks it `.`/`?`.

        Bug fix: the original `if self._value_kinds and ...` raised
        ValueError ("truth value of an array is ambiguous") for any numpy
        mask with more than one element; test against None instead.
        """
        if self._value_kinds is not None and self._value_kinds[idx]:
            return None
        return self._values[idx]

    def __len__(self):
        return self.row_count

    @property
    def values(self):
        """
        A numpy array of numbers or a list of strings.
        """
        return self._values

    @property
    def value_kinds(self):
        """
        value_kinds represent the presence or absence of particular "CIF value".
        - If the mask is not set, every value is present:
            - 0 = Value is present
            - 1 = . = value not specified
            - 2 = ? = value unknown
        """
        return self._value_kinds
class CifCategory:
    """A decoded CIF category (e.g. `_atom_site`).

    Columns stay encoded until first access when constructed with
    lazy=True; decoded CifFields are cached in `_field_cache`.
    """

    def __init__(self, category: "EncodedCategory", lazy: bool):
        self.field_names = [c["name"] for c in category["columns"]]
        # name -> decoded CifField, or None when not decoded yet (lazy).
        self._field_cache = {
            c["name"]: None if lazy else _decode_column(c) for c in category["columns"]
        }
        self._columns: "Dict[str, EncodedColumn]" = {
            c["name"]: c for c in category["columns"]
        }
        self.row_count = category["rowCount"]
        # Stored category names carry a leading underscore ("_atom_site").
        self.name = category["name"][1:]

    def __getattr__(self, name: str) -> Any:
        # Attribute access falls through to field lookup (cat.label_comp_id).
        return self[name]

    def __getitem__(self, name: str) -> "Optional[CifField]":
        """Return the decoded field, decoding and caching it on first access;
        None when the category has no such column."""
        if name not in self._field_cache:
            return None
        # Bug fix: the original used `if not self._field_cache[name]`, which
        # re-decoded any falsy cached field (a zero-row CifField is falsy via
        # __len__) on every access. Test for None explicitly.
        if self._field_cache[name] is None:
            self._field_cache[name] = _decode_column(self._columns[name])
        return self._field_cache[name]

    def __contains__(self, key: str):
        return key in self._columns
class CifDataBlock:
    """One CIF data block: a header plus its categories, keyed by name."""

    def __init__(self, header: str, categories: "Dict[str, CifCategory]"):
        self.header = header
        self.categories = categories

    def __getattr__(self, name: str) -> Any:
        # Attribute access falls through to category lookup (block.atom_site).
        return self.categories[name]

    def __getitem__(self, name: str) -> "CifCategory":
        return self.categories[name]

    def __contains__(self, key: str):
        return key in self.categories
class CifFile:
    """A parsed BinaryCIF file: ordered data blocks with lookup by header."""

    def __init__(self, data_blocks: "List[CifDataBlock]"):
        self.data_blocks = data_blocks
        self._block_map = {b.header: b for b in data_blocks}

    def __getitem__(self, index_or_name: Union[int, str]):
        """
        Access a data block by index or header (case sensitive).
        Returns None when the header is absent or the index is out of range.
        """
        if isinstance(index_or_name, str):
            return self._block_map.get(index_or_name)
        if index_or_name < len(self.data_blocks):
            return self.data_blocks[index_or_name]
        return None

    def __len__(self):
        return len(self.data_blocks)

    def __contains__(self, key: str):
        return key in self._block_map
def _decode_column(column: EncodedColumn) -> CifField:
    """Decode one encoded column (data + optional mask) into a CifField."""
    mask = column["mask"]
    value_kinds = _decode(mask) if mask else None  # type: ignore
    values = _decode(column["data"])
    return CifField(name=column["name"], values=values, value_kinds=value_kinds)  # type: ignore
def loads(data: Union[bytes, EncodedFile], lazy=True) -> CifFile:
    """
    Parse a BinaryCIF message into a CifFile.

    - data: msgpack encoded blob or EncodedFile object
    - lazy:
        - True: individual columns are decoded only when accessed
        - False: decode all columns immediately
    """
    if isinstance(data, dict) and "dataBlocks" in data:
        file: EncodedFile = data  # type: ignore
    else:
        file = msgpack.loads(data)  # type: ignore
    blocks: List[CifDataBlock] = []
    for block in file["dataBlocks"]:
        # Category names are stored with a leading underscore; strip it for keys.
        categories = {
            cat["name"][1:]: CifCategory(category=cat, lazy=lazy)
            for cat in block["categories"]
        }
        blocks.append(CifDataBlock(header=block["header"], categories=categories))
    return CifFile(data_blocks=blocks)
# --- Demo / usage script (a separate gist file; expects the parser above to
# --- be saved as `bcif.py`). Requires network access, msgpack, and requests.
import msgpack
import bcif
import requests

# Fetch and fully decode an mmCIF structure (lazy=False decodes every column).
print("mmCIF test")
data = requests.get("https://models.rcsb.org/1tqn.bcif").content
parsed = bcif.loads(data, lazy=False)
# Blocks are addressable by header ("1TQN") or index; categories/fields by
# attribute or item access.
atom_site = parsed["1TQN"].atom_site
entity = parsed[0]["entity"]
label_comp_id = atom_site.label_comp_id
Cartn_x = atom_site["Cartn_x"]
print("id" in entity)  # test if field is present in category
print("atom_site" in parsed[0])  # test if category is present data block
print(entity.field_names)
print(atom_site.row_count)
print(label_comp_id[0])
print(label_comp_id.values[-1])
print(len(label_comp_id.values))
print(len(label_comp_id))
print(Cartn_x.values[0:10])
# Masked column: compare the first value-kind against the NotPresent code.
print(atom_site["label_alt_id"].value_kinds[0] == bcif.CifValueKind.NotPresent)
# print([[f"_{c.name}.{f}" for f in c.field_names] for c in parsed[0].categories.values()])

# Fetch volumetric (x-ray density) data; the payload is already msgpack, so
# loads() also accepts the pre-decoded dict form.
print("Volume Data test")
data = requests.get(
    "https://ds.litemol.org/x-ray/1tqn/box/-22.367,-33.367,-21.634/-7.106,-10.042,-0.937?detail=1"
).content
parsed = bcif.loads(msgpack.loads(data))
print([b.header for b in parsed.data_blocks])
print(parsed[1].categories.keys())
print(parsed["FO-FC"].categories.keys())
print(parsed[1]["volume_data_3d_info"].field_names)
print(parsed[1]["volume_data_3d"].row_count)
print(parsed[1]["volume_data_3d"]["values"].values[0:10])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment