Last active
February 21, 2024 02:45
-
-
Save dsehnal/b06f5555fa9145da69fe69abfeab6eaf to your computer and use it in GitHub Desktop.
BinaryCIF parser in Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# BinaryCIF Parser | |
# Copyright (c) 2021 David Sehnal <david.sehnal@gmail.com>, licensed under MIT. | |
# | |
# Resources: | |
# - https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1008247 | |
# - https://github.com/molstar/BinaryCIF & https://github.com/molstar/BinaryCIF/blob/master/encoding.md | |
# | |
# Implementation based on Mol*: | |
# - https://github.com/molstar/molstar/blob/master/src/mol-io/common/binary-cif/encoding.ts | |
# - https://github.com/molstar/molstar/blob/master/src/mol-io/common/binary-cif/decoder.ts | |
# - https://github.com/molstar/molstar/blob/master/src/mol-io/reader/cif/binary/parser.ts | |
from typing import Any, Dict, List, Optional, TypedDict, Union | |
import msgpack | |
import numpy as np | |
class EncodingBase(TypedDict):
    """Header shared by every encoding descriptor: the name of the encoding."""

    kind: str  # a key of `_decoders`, e.g. "ByteArray", "Delta", "StringArray"


class EncodedData(TypedDict):
    """An encoded payload plus the chain of encodings that produced it."""

    # Listed in the order the encodings were applied; decoding walks it backwards.
    encoding: List[EncodingBase]
    data: bytes


class EncodedColumn(TypedDict):
    """One CIF field (column) in encoded form."""

    name: str
    data: EncodedData
    # Per-row CifValueKind codes (0 present, 1 `.`, 2 `?`); falsy when every
    # value is present.
    mask: Optional[EncodedData]


class EncodedCategory(TypedDict):
    """One CIF category: a group of columns that share a row count."""

    # The parser drops the first character of this name (presumably the CIF
    # leading underscore) when building CifCategory.name.
    name: str
    rowCount: int
    columns: List[EncodedColumn]


class EncodedDataBlock(TypedDict):
    """One CIF data block."""

    header: str
    categories: List[EncodedCategory]


class EncodedFile(TypedDict):
    """Top-level object of a msgpack-decoded BinaryCIF file."""

    version: str
    encoder: str
    dataBlocks: List[EncodedDataBlock]
def _decode(encoded_data: EncodedData) -> Union[np.ndarray, List[str]]:
    """
    Undo the encoding chain of `encoded_data` and return the decoded values.

    Encodings are recorded in the order they were applied, so they are undone
    last-to-first; each decoder receives the previous decoder's output.

    Raises ValueError when an encoding kind has no entry in `_decoders`.
    """
    decoded = encoded_data["data"]
    for step in reversed(encoded_data["encoding"]):
        kind = step["kind"]
        decoder = _decoders.get(kind)
        if decoder is None:
            raise ValueError(f"Unsupported encoding '{kind}'")
        decoded = decoder(decoded, step)  # type: ignore
    return decoded  # type: ignore
class DataTypes:
    """
    Numeric type codes used by the `type` / `srcType` fields of encodings.

    Each code maps to a NumPy dtype string in `_dtypes`.
    """

    # Signed integers (8, 16, 32 bit)
    Int8 = 1
    Int16 = 2
    Int32 = 3
    # Unsigned integers (8, 16, 32 bit)
    Uint8 = 4
    Uint16 = 5
    Uint32 = 6
    # Floating point (32, 64 bit)
    Float32 = 32
    Float64 = 33
# NumPy dtype string for each DataTypes code. No byte order is encoded here;
# _decode_byte_array prefixes "<" when it reinterprets raw bytes.
_dtypes = {
    # signed integers
    DataTypes.Int8: "i1",
    DataTypes.Int16: "i2",
    DataTypes.Int32: "i4",
    # unsigned integers
    DataTypes.Uint8: "u1",
    DataTypes.Uint16: "u2",
    DataTypes.Uint32: "u4",
    # floating point
    DataTypes.Float32: "f4",
    DataTypes.Float64: "f8",
}
def _get_dtype(type: int) -> str:
    """
    Return the NumPy dtype string for the DataTypes code `type`.

    Raises ValueError when the code is not present in `_dtypes`.
    """
    dtype = _dtypes.get(type)
    if dtype is None:
        raise ValueError(f"Unsupported data type '{type}'")
    return dtype
class ByteArrayEncoding(EncodingBase):
    """Raw bytes reinterpreted as a little-endian array."""

    type: int  # DataTypes code of the array elements


class FixedPointEncoding(EncodingBase):
    """Numbers stored as integers; decoded as value / factor."""

    factor: float
    srcType: int  # DataTypes code of the decoded array


class IntervalQuantizationEncoding(EncodingBase):
    """Values quantized to `numSteps` evenly spaced steps over [min, max]."""

    min: float
    max: float
    numSteps: int
    srcType: int  # DataTypes code of the decoded array


class RunLengthEncoding(EncodingBase):
    """Alternating (value, repeat count) pairs."""

    srcType: int  # DataTypes code of the decoded array
    srcSize: int  # length of the decoded array


class DeltaEncoding(EncodingBase):
    """Successive differences, decoded by a running sum shifted by `origin`."""

    origin: int
    srcType: int  # DataTypes code of the decoded array


class IntegerPackingEncoding(EncodingBase):
    """Integers packed into 8- or 16-bit elements with continuation values."""

    byteCount: int  # 1 selects 8-bit limits; any other value 16-bit limits
    isUnsigned: bool
    srcSize: int  # number of decoded values


class StringArrayEncoding(EncodingBase):
    """Strings stored as one concatenated string plus offsets and row indices."""

    dataEncoding: List[EncodingBase]  # encoding chain of the per-row indices
    stringData: str  # all distinct strings, concatenated
    offsetEncoding: List[EncodingBase]  # encoding chain of `offsets`
    offsets: bytes
def _decode_byte_array(data: bytes, encoding: ByteArrayEncoding) -> np.ndarray:
    """
    Reinterpret `data` as a little-endian array of the encoding's `type`.

    np.frombuffer does not copy: the result is a view over `data` (read-only
    when `data` is an immutable bytes object).
    """
    little_endian_dtype = "<" + _get_dtype(encoding["type"])
    return np.frombuffer(data, dtype=little_endian_dtype)
def _decode_fixed_point(data: np.ndarray, encoding: FixedPointEncoding) -> np.ndarray:
    """
    Undo fixed-point encoding: convert to `srcType`, then divide by `factor`.
    """
    as_source_type = np.asarray(data, dtype=_get_dtype(encoding["srcType"]))
    return as_source_type / encoding["factor"]
def _decode_interval_quantization(
    data: np.ndarray, encoding: IntervalQuantizationEncoding
) -> np.ndarray:
    """
    Map quantized step indices back onto the interval [min, max].

    Step k decodes to min + k * (max - min) / (numSteps - 1), so step 0 is
    `min` and step numSteps - 1 is `max`.
    """
    low = encoding["min"]
    step_size = (encoding["max"] - low) / (encoding["numSteps"] - 1)
    steps = np.array(data, dtype=_get_dtype(encoding["srcType"]))
    return steps * step_size + low
def _decode_run_length(data: np.ndarray, encoding: RunLengthEncoding) -> np.ndarray:
    """
    Expand run-length pairs: `data` alternates values and repeat counts.

    The decoded length follows from the counts; `srcSize` is not consulted.
    """
    run_values = np.array(data[::2], dtype=_get_dtype(encoding["srcType"]))
    run_lengths = data[1::2]
    return np.repeat(run_values, repeats=run_lengths)
def _decode_delta(data: np.ndarray, encoding: DeltaEncoding) -> np.ndarray:
    """
    Reconstruct values from successive differences.

    output[0] = data[0] + origin and output[i] = output[i - 1] + data[i].
    The running sum is computed in place on a copy of `data` converted to
    `srcType`, so the result has that dtype. Empty input yields an empty array.
    """
    result = np.array(data, dtype=_get_dtype(encoding["srcType"]))
    # The origin only shifts the first element; an empty column has no first
    # element, and indexing it would raise IndexError.
    if encoding["origin"] and result.size:
        result[0] += encoding["origin"]
    return np.cumsum(result, out=result)
def _decode_integer_packing_signed(
    data: np.ndarray, encoding: IntegerPackingEncoding
) -> np.ndarray:
    """
    Unpack signed 8/16-bit packed integers into an int32 array of `srcSize`.

    A packed element equal to the type's maximum (0x7F / 0x7FFF) or minimum
    (-0x80 / -0x8000) is a continuation: it is added to the running sum and the
    next element is read. Any other element ends the sum, which becomes one
    output value.

    The sum is accumulated in Python ints. Adding NumPy int8/int16 scalars
    keeps the narrow dtype under NumPy 2 promotion rules (NEP 50), which would
    overflow as soon as a value spans more than one packed element.
    """
    upper_limit = 0x7F if encoding["byteCount"] == 1 else 0x7FFF
    lower_limit = -upper_limit - 1
    # tolist() yields Python ints: no narrow-dtype overflow, faster indexing.
    packed = np.asarray(data).tolist()
    n = len(packed)
    output = np.zeros(encoding["srcSize"], dtype="i4")
    i = 0
    j = 0
    while i < n:
        value = 0
        t = packed[i]
        while t == upper_limit or t == lower_limit:
            value += t
            i += 1
            t = packed[i]
        value += t
        output[j] = value
        i += 1
        j += 1
    return output
def _decode_integer_packing_unsigned(
    data: np.ndarray, encoding: IntegerPackingEncoding
) -> np.ndarray:
    """
    Unpack unsigned 8/16-bit packed integers into an int32 array of `srcSize`.

    A packed element equal to the type's maximum (0xFF / 0xFFFF) is a
    continuation: it is added to the running sum and the next element is read.
    Any other element ends the sum, which becomes one output value.

    The sum is accumulated in Python ints. Adding NumPy uint8/uint16 scalars
    keeps the narrow dtype under NumPy 2 promotion rules (NEP 50), which would
    wrap around as soon as a value spans more than one packed element.
    """
    upper_limit = 0xFF if encoding["byteCount"] == 1 else 0xFFFF
    # tolist() yields Python ints: no narrow-dtype overflow, faster indexing.
    packed = np.asarray(data).tolist()
    n = len(packed)
    output = np.zeros(encoding["srcSize"], dtype="i4")
    i = 0
    j = 0
    while i < n:
        value = 0
        t = packed[i]
        while t == upper_limit:
            value += t
            i += 1
            t = packed[i]
        value += t
        output[j] = value
        i += 1
        j += 1
    return output
def _decode_integer_packing(
    data: np.ndarray, encoding: IntegerPackingEncoding
) -> np.ndarray:
    """
    Unpack IntegerPacking data using the signed or unsigned decoder.

    Every decoded value consumes at least one packed element. So when the
    packed length already equals `srcSize`, no value used a continuation and
    `data` is returned unchanged (keeping its 8/16-bit dtype).
    """
    if len(data) == encoding["srcSize"]:
        return data
    unpack = (
        _decode_integer_packing_unsigned
        if encoding["isUnsigned"]
        else _decode_integer_packing_signed
    )
    return unpack(data, encoding)
def _decode_string_array(data: bytes, encoding: StringArrayEncoding) -> List[str]:
    """
    Decode a StringArray column into one Python string per row.

    `stringData` is the concatenation of all distinct strings. The decoded
    `offsets` mark their boundaries, so string k is
    stringData[offsets[k]:offsets[k + 1]]. `data` decodes to one index per row
    into that table.

    An index of -1 decodes to "". Such rows are presumably flagged as missing
    by the column mask; confirm against the encoder.
    """
    offsets = _decode(
        EncodedData(encoding=encoding["offsetEncoding"], data=encoding["offsets"])
    )
    indices = _decode(EncodedData(encoding=encoding["dataEncoding"], data=data))
    string_data = encoding["stringData"]  # avoid shadowing the builtin `str`
    # Plain Python ints make slicing and list indexing cheaper than NumPy scalars.
    bounds = np.asarray(offsets).tolist()
    # Slot 0 holds "" so that index -1 lands on it after the +1 shift below.
    table = [""]
    table.extend(string_data[start:end] for start, end in zip(bounds, bounds[1:]))
    return [table[i + 1] for i in np.asarray(indices).tolist()]
# Encoding "kind" -> decoder(data, encoding). Used by _decode; each decoder takes
# the previous stage's output and the encoding descriptor.
_decoders = dict(
    ByteArray=_decode_byte_array,  # bytes -> typed array
    FixedPoint=_decode_fixed_point,
    IntervalQuantization=_decode_interval_quantization,
    RunLength=_decode_run_length,
    Delta=_decode_delta,
    IntegerPacking=_decode_integer_packing,
    StringArray=_decode_string_array,  # indices -> list of str
)
################################################################################## | |
class CifValueKind:
    """Codes stored in a column mask, one per row."""

    Present = 0  # the row holds a regular value
    NotPresent = 1  # expressed in CIF as `.`
    Unknown = 2  # expressed in CIF as `?`
class CifField:
    """
    A decoded CIF column.

    Indexing returns the row value, or None when the mask marks the row as
    `.` or `?`. `len(field)` is the number of decoded values.
    """

    def __init__(
        self,
        name: str,
        values: Union[np.ndarray, List[str]],
        value_kinds: Optional[np.ndarray],
    ):
        self.name = name
        self._values = values
        self._value_kinds = value_kinds
        self.row_count = len(values)

    def __getitem__(self, idx: int) -> Union[str, float, int, None]:
        """Return the value at row `idx`, or None if its value kind is non-zero."""
        # Test against None explicitly: the mask is a NumPy array, and
        # bool(array) raises ValueError for arrays with more than one element.
        if self._value_kinds is not None and self._value_kinds[idx]:
            return None
        return self._values[idx]

    def __len__(self):
        return self.row_count

    @property
    def values(self):
        """
        A numpy array of numbers or a list of strings.
        """
        return self._values

    @property
    def value_kinds(self):
        """
        value_kinds represent the presence or absence of particular "CIF value".
        - If the mask is not set, every value is present (this property is None):
        - 0 = Value is present
        - 1 = . = value not specified
        - 2 = ? = value unknown
        """
        return self._value_kinds
class CifCategory:
    """
    A CIF category: named columns that share `row_count` rows.

    Fields are available as `category["field"]` or `category.field`. Both
    return None for a field that does not exist. With lazy=True, each column
    is decoded on first access and then cached.
    """

    def __init__(self, category: EncodedCategory, lazy: bool):
        self.field_names = [c["name"] for c in category["columns"]]
        # None marks a column that has not been decoded yet.
        self._field_cache = {
            c["name"]: None if lazy else _decode_column(c) for c in category["columns"]
        }
        self._columns: Dict[str, EncodedColumn] = {
            c["name"]: c for c in category["columns"]
        }
        self.row_count = category["rowCount"]
        # Drop the first character of the encoded name (presumably the CIF
        # leading underscore).
        self.name = category["name"][1:]

    def __getattr__(self, name: str) -> Any:
        # Only called when normal attribute lookup fails. Names starting with
        # "_" are never treated as fields, for two reasons:
        # - protocol probes such as hasattr(obj, "__setstate__") in copy/pickle
        #   need an AttributeError;
        # - an instance whose __init__ has not run would otherwise recurse
        #   forever through __getitem__ -> self._field_cache.
        # Such fields are still reachable as category["_name"].
        if name.startswith("_"):
            raise AttributeError(name)
        return self[name]

    def __getitem__(self, name: str) -> Optional[CifField]:
        """Return the field `name`, decoding it on first use, or None if absent."""
        if name not in self._field_cache:
            return None
        field = self._field_cache[name]
        # Compare with None rather than truthiness: CifField defines __len__,
        # so a decoded zero-row field is falsy and would be re-decoded each time.
        if field is None:
            field = self._field_cache[name] = _decode_column(self._columns[name])
        return field

    def __contains__(self, key: str):
        return key in self._columns
class CifDataBlock:
    """
    A CIF data block: a header plus its categories, keyed by category name.

    Categories are available as `block["name"]` or `block.name`.
    """

    def __init__(self, header: str, categories: Dict[str, CifCategory]):
        self.header = header
        self.categories = categories

    def __getattr__(self, name: str) -> Any:
        # Only called when normal attribute lookup fails. Names starting with
        # "_" raise AttributeError. Protocol probes such as copy.deepcopy's
        # getattr(x, "__deepcopy__", None) only tolerate AttributeError, and a
        # KeyError would escape from them. Missing category names still raise
        # KeyError, as before.
        if name.startswith("_"):
            raise AttributeError(name)
        return self.categories[name]

    def __getitem__(self, name: str) -> CifCategory:
        return self.categories[name]

    def __contains__(self, key: str):
        return key in self.categories
class CifFile:
    """A parsed BinaryCIF file: an ordered list of data blocks."""

    def __init__(self, data_blocks: List[CifDataBlock]):
        self.data_blocks = data_blocks
        # Header lookup is case sensitive; with duplicate headers the last block wins.
        self._block_map = {b.header: b for b in data_blocks}

    def __getitem__(self, index_or_name: Union[int, str]):
        """
        Access a data block by index or header (case sensitive).

        Returns None for an unknown header or an out-of-range index. Negative
        indices count from the end, as for lists.
        """
        if isinstance(index_or_name, str):
            return self._block_map.get(index_or_name)
        count = len(self.data_blocks)
        if -count <= index_or_name < count:
            return self.data_blocks[index_or_name]
        return None

    def __iter__(self):
        # Without __iter__, iter() falls back to calling __getitem__(0),
        # __getitem__(1), ... until IndexError. __getitem__ returns None
        # instead of raising, so iteration would never stop.
        return iter(self.data_blocks)

    def __len__(self):
        return len(self.data_blocks)

    def __contains__(self, key: str):
        return key in self._block_map
def _decode_column(column: EncodedColumn) -> CifField:
    """
    Decode an encoded column (values and optional mask) into a CifField.

    The "mask" key may be missing, null or empty. All three mean every row is
    present, and the field's value_kinds is None.
    """
    values = _decode(column["data"])
    # .get(): encoders are not guaranteed to write the key when there is no mask.
    mask = column.get("mask")
    value_kinds = _decode(mask) if mask else None  # type: ignore
    return CifField(name=column["name"], values=values, value_kinds=value_kinds)  # type: ignore
def loads(data: Union[bytes, EncodedFile], lazy=True) -> CifFile:
    """
    Parse a BinaryCIF file.

    - data: msgpack encoded blob, or an already unpacked EncodedFile object
      (a dict containing "dataBlocks")
    - lazy:
        - True: individual columns are decoded only when accessed
        - False: decode all columns immediately
    """
    already_unpacked = isinstance(data, dict) and "dataBlocks" in data
    file: EncodedFile = data if already_unpacked else msgpack.loads(data)  # type: ignore

    blocks: List[CifDataBlock] = []
    for encoded_block in file["dataBlocks"]:
        header = encoded_block["header"]
        categories: Dict[str, CifCategory] = {}
        for encoded_category in encoded_block["categories"]:
            # Key categories by their encoded name minus its first character.
            key = encoded_category["name"][1:]
            categories[key] = CifCategory(category=encoded_category, lazy=lazy)
        blocks.append(CifDataBlock(header=header, categories=categories))
    return CifFile(data_blocks=blocks)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Smoke test for the bcif module. It downloads a BinaryCIF model from
# models.rcsb.org and a volume-data box from ds.litemol.org, then prints
# selected parts of both.
import msgpack
import bcif
import requests


def _fetch(url: str) -> bytes:
    """Download `url` and return the response body, raising on HTTP errors."""
    # The timeout avoids hanging forever on an unresponsive server.
    # raise_for_status() stops an HTTP error page from reaching the msgpack
    # decoder, where it would fail with an unrelated-looking error.
    response = requests.get(url, timeout=60)
    response.raise_for_status()
    return response.content


print("mmCIF test")
data = _fetch("https://models.rcsb.org/1tqn.bcif")
parsed = bcif.loads(data, lazy=False)
atom_site = parsed["1TQN"].atom_site
entity = parsed[0]["entity"]
label_comp_id = atom_site.label_comp_id
Cartn_x = atom_site["Cartn_x"]
print("id" in entity)  # test if field is present in category
print("atom_site" in parsed[0])  # test if category is present data block
print(entity.field_names)
print(atom_site.row_count)
print(label_comp_id[0])
print(label_comp_id.values[-1])
print(len(label_comp_id.values))
print(len(label_comp_id))
print(Cartn_x.values[0:10])
print(atom_site["label_alt_id"].value_kinds[0] == bcif.CifValueKind.NotPresent)
# To list every "_category.field" name in the first block, uncomment:
# print([[f"_{c.name}.{f}" for f in c.field_names] for c in parsed[0].categories.values()])

print("Volume Data test")
data = _fetch(
    "https://ds.litemol.org/x-ray/1tqn/box/-22.367,-33.367,-21.634/-7.106,-10.042,-0.937?detail=1"
)
parsed = bcif.loads(msgpack.loads(data))
print([b.header for b in parsed.data_blocks])
print(parsed[1].categories.keys())
print(parsed["FO-FC"].categories.keys())
print(parsed[1]["volume_data_3d_info"].field_names)
print(parsed[1]["volume_data_3d"].row_count)
print(parsed[1]["volume_data_3d"]["values"].values[0:10])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment