Created
March 26, 2022 18:22
-
-
Save chronos-tachyon/1b78becf3d07dd659b064915033ca39c to your computer and use it in GitHub Desktop.
Script to fix Unix ownership and permissions on a multi-user writable public directory.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3
#
# Written by Donald King <chronos@chronos-tachyon.net>
# Public Domain.
#
# ==== CC0 https://creativecommons.org/publicdomain/zero/1.0/ ====
# [To the extent possible under law, I have waived all copyright ]
# [and related or neighboring rights to this script. This work is]
# [published from the United States of America.                  ]
import argparse | |
import contextlib | |
import datetime | |
import hashlib | |
import math | |
import os | |
import pwd | |
import stat | |
import sys | |
import time | |
import humanize | |
import pytz | |
# Number of bytes read per chunk while hashing a file (1 MiB).
BLOCK_SIZE = (1 << 20)

# Regular files smaller than this many bytes are not checksummed.
# With 0, every regular file qualifies.
DIGEST_THRESHOLD = 0

# passwd entry of the 'pub' account; its uid/gid are the ownership every
# processed entry is changed to.  NOTE: pwd.getpwnam raises KeyError at
# import time when no such account exists.
PUB = pwd.getpwnam('pub')

UTC = pytz.utc
# Time zone used when rendering timestamps in the log.
LOCAL_TZ = pytz.timezone('America/Los_Angeles')

# Names of the extended attributes in which digests, and the mtime/size they
# were computed for, are stored on each file.
USER_MD5SUM = 'user.md5sum'
USER_SHA1SUM = 'user.sha1sum'
USER_SHA256SUM = 'user.sha256sum'
USER_MTIME = 'user.mtime'
USER_SIZE = 'user.size'

# Aware UTC timestamp taken once at start-up; files whose mtime is later than
# this are not checksummed.  datetime.now(tz) replaces the deprecated
# datetime.utcnow() and yields the same aware value.
LAUNCH_TIME = datetime.datetime.now(UTC)
def raise_it(e):
    """Re-raise the exception object *e*.

    Passed as the ``onerror`` callback of ``os.walk`` so that directory
    listing errors abort the walk instead of being ignored.
    """
    raise e
def format_int(it):
    """Return the integer *it* rendered as a plain decimal string."""
    template = '{:d}'
    return template.format(it)
def format_bytes(it, rate=False):
    """Return *it* (a byte count) as a human-readable binary size.

    When *rate* is true the value is a throughput and '/s' is appended.
    """
    text = humanize.naturalsize(it, binary=True)
    if rate:
        return text + '/s'
    return text
def format_datetime(it):
    """Return the datetime *it* as date, 12-hour time, zone name and offset."""
    pattern = '%Y-%m-%d %I:%M:%S%p %Z %z'
    return it.strftime(pattern)
def format_elapsed(seconds):
    """Return a duration in seconds formatted as ``[-][Hh][MMm]SS.sss s``.

    Examples: ``1.500s``, ``1m01.250s``, ``1h01m01.500s``.  Hours and minutes
    are only shown when non-zero (minutes are always shown once hours are).
    Negative durations get a leading ``-``.

    The value is rounded to whole milliseconds *before* it is split into
    hours/minutes/seconds, so the rounding can carry into the larger units
    (119.9996 is rendered as ``2m00.000s``, never ``1m60.000s``).
    """
    sign = ''
    if seconds < 0:
        sign = '-'
        seconds = -seconds
    # Round first: rounding only at format time could print "60.000" in the
    # seconds field without bumping the minutes.
    seconds = round(float(seconds), 3)
    hours, seconds = divmod(seconds, 3600.0)
    minutes, seconds = divmod(seconds, 60.0)
    hours = int(hours)
    minutes = int(minutes)
    if hours != 0:
        return '{}{:d}h{:02d}m{:06.3f}s'.format(sign, hours, minutes, seconds)
    if minutes != 0:
        return '{}{:d}m{:06.3f}s'.format(sign, minutes, seconds)
    return '{}{:.3f}s'.format(sign, seconds)
class Printer(contextlib.AbstractContextManager):
    """Writes one log section per path to ``ARGV.logfile``, lazily.

    Nothing is written unless ``print`` is called at least once.  The first
    call emits a header with the current local time and the path; leaving the
    ``with`` block then emits a blank line to terminate the section.
    Exceptions raised inside the ``with`` block are not suppressed.
    """

    def __init__(self, path):
        self.path = path
        # True once the "[timestamp] / path:" header has been written.
        self.printed_header = False

    def __exit__(self, exc_type, exc_value, traceback):
        # Only sections that actually produced output get a trailing blank line.
        if self.printed_header:
            print(file=ARGV.logfile, flush=True)

    def print(self, fmt, *args, **kwargs):
        """Write one message, preceded by the section header on first use.

        *fmt* is expanded with ``str.format(*args, **kwargs)`` when any
        arguments are given; otherwise it is written verbatim, so literal
        braces are safe in argument-less messages.
        """
        if not self.printed_header:
            # datetime.now(tz) replaces the deprecated datetime.utcnow().
            now = datetime.datetime.now(UTC).astimezone(LOCAL_TZ)
            header = '[{}]\npath: {!r}'.format(format_datetime(now), self.path)
            print(header, file=ARGV.logfile, flush=True)
            self.printed_header = True
        message = fmt.format(*args, **kwargs) if args or kwargs else fmt
        print(message, file=ARGV.logfile, flush=True)
def process_item(path):
    """Fix ownership and mode of one filesystem entry and refresh its digests.

    The entry is examined with ``os.lstat``, so a symlink is treated as the
    link itself:

    * owner and group are changed to those of ``PUB`` when they differ;
    * directories are set to mode 2775; regular files to 0775 when their
      owner-execute bit is set and to 0664 otherwise; any other file type
      keeps its mode;
    * regular files of at least ``DIGEST_THRESHOLD`` bytes, not modified
      after ``LAUNCH_TIME``, get MD5, SHA-1 and SHA-256 digests stored in
      extended attributes, together with the mtime (milliseconds) and size
      they were computed for.  The digests are recomputed only when those
      stored values are missing, unreadable or differ from the file.

    Every change is reported through a ``Printer``.  A path that no longer
    exists is silently skipped.
    """
    try:
        st = os.lstat(path)
    except FileNotFoundError:
        return

    with Printer(path) as printer:
        expect_uid = PUB.pw_uid
        expect_gid = PUB.pw_gid
        is_regular = stat.S_ISREG(st.st_mode)

        if stat.S_ISDIR(st.st_mode):
            expect_mode = 0o02775
        elif is_regular and (st.st_mode & 0o0100) != 0:
            expect_mode = 0o0775
        elif is_regular:
            expect_mode = 0o0664
        else:
            # Symlinks, devices, sockets, ...: leave the mode alone.
            expect_mode = None

        if st.st_uid != expect_uid or st.st_gid != expect_gid:
            printer.print('chown:\n\told={}:{}\n\tnew={}:{}', st.st_uid, st.st_gid, expect_uid, expect_gid)
            os.lchown(path, expect_uid, expect_gid)

        if expect_mode is not None:
            masked_mode = (st.st_mode & 0o7777)
            if masked_mode != expect_mode:
                printer.print('chmod:\n\told={:04o}\n\tnew={:04o}', masked_mode, expect_mode)
                os.chmod(path, expect_mode)

        if not is_regular or st.st_size < DIGEST_THRESHOLD:
            return

        actual_size = st.st_size
        # Millisecond resolution: this is the value stored in USER_MTIME.
        actual_mtime = int(round(st.st_mtime * 1000))
        # fromtimestamp(ts, tz) replaces the deprecated utcfromtimestamp().
        actual_mtime_datetime = datetime.datetime.fromtimestamp(st.st_mtime, UTC).astimezone(LOCAL_TZ)
        if actual_mtime_datetime > LAUNCH_TIME:
            # Modified after this run started; do not checksum it this time.
            return

        with open(path, 'rb') as fp:
            fd = fp.fileno()
            os.posix_fadvise(fd, 0, actual_size, os.POSIX_FADV_SEQUENTIAL)

            attrs = os.listxattr(fd)
            wanted = (USER_MTIME, USER_SIZE, USER_MD5SUM, USER_SHA1SUM, USER_SHA256SUM)
            if all(name in attrs for name in wanted):
                try:
                    stored_mtime = int(os.getxattr(fd, USER_MTIME).decode('ascii'))
                    stored_size = int(os.getxattr(fd, USER_SIZE).decode('ascii'))
                except ValueError:
                    # Unparseable stored values: treat them as stale and
                    # recompute instead of aborting the whole run.
                    pass
                else:
                    if stored_mtime == actual_mtime and stored_size == actual_size:
                        # Digests are already up to date.
                        return

            args = [
                format_int(actual_mtime),
                format_datetime(actual_mtime_datetime),
                format_int(actual_size),
                format_bytes(actual_size),
            ]
            printer.print('cksum:\n\tmtime={!r} ({})\n\tsize={!r} ({})', *args)

            md5 = hashlib.md5()
            sha1 = hashlib.sha1()
            sha256 = hashlib.sha256()
            start_time = time.clock_gettime(time.CLOCK_MONOTONIC)
            measured_size = 0
            while True:
                block = fp.read(BLOCK_SIZE)
                if not block:
                    break
                md5.update(block)
                sha1.update(block)
                sha256.update(block)
                measured_size += len(block)

            if measured_size != actual_size:
                # The file changed size since lstat(); the digests would not
                # match the recorded size, so store nothing.
                args = [
                    actual_size,
                    measured_size,
                ]
                printer.print('\terror: size mismatch: was {:d} bytes, now {:d} bytes', *args)
                return

            end_time = time.clock_gettime(time.CLOCK_MONOTONIC)
            elapsed = (end_time - start_time)
            # Guard against a zero interval (e.g. an empty file), which would
            # otherwise raise ZeroDivisionError.
            rate = (measured_size / elapsed) if elapsed > 0 else 0.0

            s0 = md5.hexdigest()
            s1 = sha1.hexdigest()
            s2 = sha256.hexdigest()
            args = [
                s0,
                s1,
                s2,
                format_elapsed(elapsed),
                format_bytes(rate, rate=True),
            ]
            printer.print('\tmd5={!r}\n\tsha1={!r}\n\tsha256={!r}\n\telapsed: {} ({})', *args)

            os.setxattr(fd, USER_MTIME, format_int(actual_mtime).encode('ascii'))
            os.setxattr(fd, USER_SIZE, format_int(actual_size).encode('ascii'))
            os.setxattr(fd, USER_MD5SUM, s0.encode('ascii'))
            os.setxattr(fd, USER_SHA1SUM, s1.encode('ascii'))
            os.setxattr(fd, USER_SHA256SUM, s2.encode('ascii'))
def process_root(root):
    """Walk the tree under *root* and call ``process_item`` on every entry.

    Entries are visited in sorted order.  Any directory named ``lost+found``
    is skipped together with everything below it.  Errors raised while
    listing a directory abort the walk (they are re-raised via ``raise_it``).
    """
    for path, dir_names, file_names in os.walk(root, onerror=raise_it):
        if os.path.basename(path) == 'lost+found':
            # Emptying the lists in place stops os.walk from descending.
            dir_names[:] = ()
            file_names[:] = ()
            continue
        # Sorting in place also fixes the order in which os.walk descends.
        dir_names.sort()
        file_names.sort()
        process_item(path)
        for child in file_names:
            process_item(os.path.join(path, child))
        # os.walk reports symlinks to directories among the directory names
        # but (without followlinks) never yields them as `path`, so they
        # would otherwise be missed.  Real subdirectories are handled when
        # the walk reaches them.
        for child in dir_names:
            child_path = os.path.join(path, child)
            if os.path.islink(child_path):
                process_item(child_path)
# Command line: an optional log file (appended to, line-buffered) and one or
# more filesystem roots to process.
parser = argparse.ArgumentParser()
parser.add_argument(
    '-l', '--logfile',
    default=sys.stdout,
    type=argparse.FileType('at', bufsize=1),
    metavar='FILE',
    help='path to logfile to append to; default /dev/stdout',
)
parser.add_argument(
    '-r', '--root',
    action='append',
    default=[],
    required=True,
    metavar='DIR',
    help='a filesystem root to fix; repeatable',
)

# Parsed arguments; Printer reads ARGV.logfile as a module-level global.
ARGV = parser.parse_args()

try:
    for root in ARGV.root:
        process_root(root)
except KeyboardInterrupt:
    # Ctrl-C: exit with a failure status instead of printing a traceback.
    sys.exit(1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment