Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Some handy utils for messing with MXCSR (x86-64 SSE FPU control register)
#!/usr/bin/env python
import sys, os
import platform
import ctypes as ct
import mmap
from enum import Enum
import importlib
import functools
import errno
COLOR_RED = "\033[31m"
COLOR_GREEN = "\033[32m"
COLOR_RESET = "\033[0m"
# Ensure we're on x86_64
if platform.machine() != 'x86_64' or sys.maxsize <= 2**32:
raise RuntimeError("This module only works on x86_64")
# Set up a RWX buffer so we can put some assembly into it
_code_buf = mmap.mmap(-1, mmap.PAGESIZE, prot=mmap.PROT_READ | mmap.PROT_WRITE)
_set_mxcsr_asm = (
b"\x0F\xAE\x17" # ldmxcsr [rdi]
b"\xc3" # ret
b"\x90" * 4 # padding
)
_get_mxcsr_asm = (
b"\x0F\xAE\x1F" # stmxcsr [rdi]
b"\xc3" # ret
b"\x90" * 4 # padding
)
# Copy the assembly into the buffer
_code_buf_addr = ct.addressof(ct.c_void_p.from_buffer(_code_buf))
_code_buf.write(_set_mxcsr_asm)
_set_mxcsr_addr = _code_buf_addr
_code_buf.write(_get_mxcsr_asm)
_get_mxcsr_addr = _set_mxcsr_addr+len(_set_mxcsr_asm)
# Make our code buffer read-only after we're done with it.
mprotect = ct.CDLL(None, use_errno=True).mprotect
mprotect.argtypes = [ct.c_void_p, ct.c_size_t, ct.c_int]
mprotect.restype = ct.c_int
if mprotect(_code_buf_addr, mmap.PAGESIZE, mmap.PROT_READ | mmap.PROT_EXEC) != 0:
e = ct.get_errno()
raise OSError("mprotect: " + errno.errorcode[e] + f" ({os.strerror(e)})")
##############################################################################################################################
# Bits of the MXCSR register. Diagram was # +----+--------+----+----+----+----+----+----+----+----+----+----+----+----+----+ #
# converted to ASCII-art from Figure 10-3 # | 15 | 14 13 | 12 | 11 | 10 | 9 | 8 | 7 | 6 | 5 | 4 | 3 | 2 | 1 | 0 | #
# in the Intel 64 and IA-32 Architectures # +----+--------+----+----+----+----+----+----+----+----+----+----+----+----+----+ #
# Software Developer's Manual, Volume 1. # FZ RC PM UM OM ZM DM IM DAZ PE UE OE ZE DE IE #
# ######################################### | | | | | | | | | | | | | | | #
# Flush to Zero -------------------------------' | | | | | | | | | | | | | | #
# Rounding Control -----------------------------------' | | | | | | | | | | | | | #
# Precision Mask --------------------------------------------' | | | | | | | | | | | | #
# Underflow Mask -------------------------------------------------' | | | | | | | | | | | #
# Overflow Mask -------------------------------------------------------' | | | | | | | | | | #
# Divide-by-Zero Mask ------------------------------------------------------' | | | | | | | | | #
# Denormal Operation Mask -------------------------------------------------------' | | | | | | | | #
# Invalid Operation Mask -------------------------------------------------------------' | | | | | | | #
# Denormals Are Zeros ---------------------------------------------------------------------' | | | | | | #
# Precision Flag -------------------------------------------------------------------------------' | | | | | #
# Underflow Flag ------------------------------------------------------------------------------------' | | | | #
# Overflow Flag ------------------------------------------------------------------------------------------' | | | #
# Divide-by-Zero Flag -----------------------------------------------------------------------------------------' | | #
# Denormal Flag ----------------------------------------------------------------------------------------------------' | #
# Invalid Operation Flag ------------------------------------------------------------------------------------------------' #
##############################################################################################################################
class MXCSR_bits(ct.LittleEndianStructure):
_fields_ = [
("IE", ct.c_uint32, 1),
("DE", ct.c_uint32, 1),
("ZE", ct.c_uint32, 1),
("OE", ct.c_uint32, 1),
("UE", ct.c_uint32, 1),
("PE", ct.c_uint32, 1),
("DAZ", ct.c_uint32, 1),
("IM", ct.c_uint32, 1),
("DM", ct.c_uint32, 1),
("ZM", ct.c_uint32, 1),
("OM", ct.c_uint32, 1),
("UM", ct.c_uint32, 1),
("PM", ct.c_uint32, 1),
("RC", ct.c_uint32, 2),
("FZ", ct.c_uint32, 1),
("reserved", ct.c_uint32, 16),
]
class RoundingModes(Enum):
RoundToNearest = 0
RoundDown = 1
RoundUp = 2
RoundTowardsZero = 3
def short(self):
return ["RN", "RD", "RU", "RZ"][self.value]
full_names = [
"Invalid Operation Flag",
"Denormal Flag",
"Divide-by-Zero Flag",
"Overflow Flag",
"Underflow Flag",
"Precision Flag",
"Denormals Are Zeros",
"Invalid Operation Mask",
"Denormal Operation Mask",
"Divide-by-Zero Mask",
"Overflow Mask",
"Underflow Mask",
"Precision Mask",
"Rounding Control",
"Flush to Zero",
]
# Draw FZ and DAZ in red if they are set and output is a terminal
@staticmethod
def _colorize(s, name, value):
if sys.stdout.isatty() and name in ["FZ", "DAZ"] and value:
return COLOR_RED + s + COLOR_RESET
return s
# Custom __str__ method to print MXCSR register
def __str__(self):
bits_set = [(nm if getattr(self, nm) else ' '*len(nm))
for nm,_,_ in self._fields_ if nm not in ['RC', 'reserved']]
bits_set = [self._colorize(s, s, 1) for s in bits_set]
bits_set += [f"RC={self.RoundingModes(self.RC).short()}"]
bits_str = ",".join(bits_set)
return f"MXCSR({bits_str})"
__repr__ = __str__
# Custom __setattr__ method to prevent setting reserved bits
def __setattr__(self, name, value):
if name == "reserved":
raise ValueError("Cannot set reserved bits")
super().__setattr__(name, value)
# A more verbose description
def describe(self):
MAX_NAME_LEN = max([len(n) for n in self.full_names])+1
s = "MXCSR register:"
for field, full_name in zip(self._fields_, self.full_names):
name = field[0]
full_name_col = full_name.ljust(MAX_NAME_LEN)
value = getattr(self, name)
if ('Flag' in full_name or 'Mask' in full_name or
full_name == 'Flush to Zero' or full_name == 'Denormals Are Zeros'):
value_s = 'Set' if value else 'Clear'
elif 'Rounding Control' in full_name:
value_s = self.RoundingModes(value).name
else:
raise ValueError(f"You forgot one: {full_name}")
s += self._colorize(f"\n {full_name_col}: {value_s}", name, value)
return s
class MXCSR(ct.Union):
_fields_ = [
("bits", MXCSR_bits),
("value", ct.c_uint32),
]
# Convenience function to get power-on MXCSR value
RESET_VALUE = 0x1f80
@staticmethod
def initial():
return MXCSR(value=MXCSR.RESET_VALUE)
# Convenience function to create MXCSR from a dictionary
@staticmethod
def from_dict(d):
mxcsr = MXCSR()
for k, v in d.items():
setattr(mxcsr.bits, k, v)
return mxcsr
def __repr__(self):
return f"MXCSR({self.value:#x})"
def __str__(self):
return str(self.bits)
def describe(self):
return self.bits.describe()
_set_mxcsr = ct.CFUNCTYPE(None, ct.POINTER(MXCSR))(_set_mxcsr_addr)
def set_mxcsr(val: MXCSR):
_set_mxcsr(ct.byref(val))
_get_mxcsr = ct.CFUNCTYPE(None, ct.POINTER(MXCSR))(_get_mxcsr_addr)
def get_mxcsr() -> MXCSR:
mxcsr = MXCSR()
_get_mxcsr(ct.byref(mxcsr))
return mxcsr
def ensure_clean_fpu_state(function):
@functools.wraps(function)
def decorator(*args, **kwargs):
old_mxcsr = get_mxcsr()
set_mxcsr(MXCSR.initial())
try:
return function(*args, **kwargs)
finally:
set_mxcsr(old_mxcsr)
return decorator
# Small demo. numpy's finfo will yell loudly if the FZ or DAZ bits are set.
def decorator_demo():
import numpy as np
@ensure_clean_fpu_state
def tricky_numerical_operation_safe():
np.finfo(np.float32)
def tricky_numerical_operation_unsafe():
np.finfo(np.float32)
print(f"MXCSR at power on: {MXCSR.initial()}")
print(f"MXCSR now: {get_mxcsr()}")
print("Importing gevent, which uses ffast-math...")
import gevent
print(f"MXCSR after import: {get_mxcsr()}")
print("Running np.finfo(np.float32) without FPU wrapper (you should see warnings):")
print(COLOR_RED, end='')
tricky_numerical_operation_unsafe()
print(COLOR_RESET, end='')
print("Running np.finfo(np.float32) with FPU wrapper (no warnings):")
tricky_numerical_operation_safe()
print(COLOR_GREEN+"All done!"+COLOR_RESET)
if __name__ == "__main__":
decorator_demo()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment