Skip to content

Instantly share code, notes, and snippets.

@quietvoid
Created January 7, 2021 21:32
Show Gist options
  • Save quietvoid/175d7d16248848e2fb299f9eb1bd2c38 to your computer and use it in GitHub Desktop.
Save quietvoid/175d7d16248848e2fb299f9eb1bd2c38 to your computer and use it in GitHub Desktop.
Removes HDR10+ from an HEVC file
# -*- coding: utf-8 -*-
from argparse import RawTextHelpFormatter
from bitstring import BitStream
from ctypes import *
import threading
import platform
import argparse
import queue
import time
import sys
import os
# ######################################################################
# # Author: yusesope
# # Version: 0.0.4_beta
# # Info: https://www.makemkv.com/forum/viewtopic.php?f=12&t=18602
# # Modified by quietvoid to become a hdr10+ remover
# ######################################################################
class NalUnitType:
    """Named constants for the 6-bit HEVC ``nal_unit_type`` field.

    Values 0-63 cover every encodable type; ``NAL_UNIT_INVALID`` (64) lies
    outside that range and is not produced by parsing a NAL header.
    """
    # 0-31: VCL (coded slice) types, including reserved ones.
    NAL_UNIT_CODED_SLICE_TRAIL_N = 0
    NAL_UNIT_CODED_SLICE_TRAIL_R = 1
    NAL_UNIT_CODED_SLICE_TSA_N = 2
    NAL_UNIT_CODED_SLICE_TSA_R = 3
    NAL_UNIT_CODED_SLICE_STSA_N = 4
    NAL_UNIT_CODED_SLICE_STSA_R = 5
    NAL_UNIT_CODED_SLICE_RADL_N = 6
    NAL_UNIT_CODED_SLICE_RADL_R = 7
    NAL_UNIT_CODED_SLICE_RASL_N = 8
    NAL_UNIT_CODED_SLICE_RASL_R = 9
    NAL_UNIT_RESERVED_VCL_N10 = 10
    NAL_UNIT_RESERVED_VCL_R11 = 11
    NAL_UNIT_RESERVED_VCL_N12 = 12
    NAL_UNIT_RESERVED_VCL_R13 = 13
    NAL_UNIT_RESERVED_VCL_N14 = 14
    NAL_UNIT_RESERVED_VCL_R15 = 15
    # 16-23: IRAP (random access point) slice types.
    NAL_UNIT_CODED_SLICE_BLA_W_LP = 16
    NAL_UNIT_CODED_SLICE_BLA_W_RADL = 17
    NAL_UNIT_CODED_SLICE_BLA_N_LP = 18
    NAL_UNIT_CODED_SLICE_IDR_W_RADL = 19
    NAL_UNIT_CODED_SLICE_IDR_N_LP = 20
    NAL_UNIT_CODED_SLICE_CRA = 21
    NAL_UNIT_RESERVED_IRAP_VCL22 = 22
    NAL_UNIT_RESERVED_IRAP_VCL23 = 23
    NAL_UNIT_RESERVED_VCL24 = 24
    NAL_UNIT_RESERVED_VCL25 = 25
    NAL_UNIT_RESERVED_VCL26 = 26
    NAL_UNIT_RESERVED_VCL27 = 27
    NAL_UNIT_RESERVED_VCL28 = 28
    NAL_UNIT_RESERVED_VCL29 = 29
    NAL_UNIT_RESERVED_VCL30 = 30
    NAL_UNIT_RESERVED_VCL31 = 31
    # 32-47: non-VCL types (parameter sets, delimiters, SEI, reserved).
    NAL_UNIT_VPS = 32
    NAL_UNIT_SPS = 33
    NAL_UNIT_PPS = 34
    NAL_UNIT_ACCESS_UNIT_DELIMITER = 35
    NAL_UNIT_EOS = 36
    NAL_UNIT_EOB = 37
    NAL_UNIT_FILLER_DATA = 38
    NAL_UNIT_PREFIX_SEI = 39  # the only type inspected for HDR10+ metadata
    NAL_UNIT_SUFFIX_SEI = 40
    NAL_UNIT_RESERVED_NVCL41 = 41
    NAL_UNIT_RESERVED_NVCL42 = 42
    NAL_UNIT_RESERVED_NVCL43 = 43
    NAL_UNIT_RESERVED_NVCL44 = 44
    NAL_UNIT_RESERVED_NVCL45 = 45
    NAL_UNIT_RESERVED_NVCL46 = 46
    NAL_UNIT_RESERVED_NVCL47 = 47
    # 48-63: unspecified.
    NAL_UNIT_UNSPECIFIED_48 = 48
    NAL_UNIT_UNSPECIFIED_49 = 49
    NAL_UNIT_UNSPECIFIED_50 = 50
    NAL_UNIT_UNSPECIFIED_51 = 51
    NAL_UNIT_UNSPECIFIED_52 = 52
    NAL_UNIT_UNSPECIFIED_53 = 53
    NAL_UNIT_UNSPECIFIED_54 = 54
    NAL_UNIT_UNSPECIFIED_55 = 55
    NAL_UNIT_UNSPECIFIED_56 = 56
    NAL_UNIT_UNSPECIFIED_57 = 57
    NAL_UNIT_UNSPECIFIED_58 = 58
    NAL_UNIT_UNSPECIFIED_59 = 59
    NAL_UNIT_UNSPECIFIED_60 = 60
    NAL_UNIT_UNSPECIFIED_61 = 61
    NAL_UNIT_UNSPECIFIED_62 = 62
    NAL_UNIT_UNSPECIFIED_63 = 63
    # Sentinel outside the 6-bit range.
    NAL_UNIT_INVALID = 64
class NalUnit(object):
    """One NAL unit cut out of an HEVC Annex B byte stream.

    As filled in by ``Reader``, offsets and sizes are measured in bits (they
    index ``BitStream`` objects) and ``data`` is a ``BitStream`` that starts
    right after the start code, i.e. with the 2-byte NAL unit header.
    """

    def __init__(self):
        self._offset = 0         # start-code position within the current chunk
        self._global_offset = 0  # start-code position within the whole file
        self._size = 0           # length measured from the start code
        self._type = None        # nal_unit_type value (see NalUnitType)
        self._data = 0           # replaced by a BitStream when populated
        self._nuh_layer_id = 0
        self._payload = None
        self._raw_data = None    # lazily computed RBSP, see raw_data

    @property
    def offset(self):
        return self._offset

    @offset.setter
    def offset(self, val):
        self._offset = val

    @property
    def global_offset(self):
        return self._global_offset

    @global_offset.setter
    def global_offset(self, val):
        self._global_offset = val

    @property
    def size(self):
        return self._size

    @size.setter
    def size(self, val):
        self._size = val

    @property
    def type(self):
        return self._type

    @type.setter
    def type(self, val):
        self._type = val

    @property
    def data(self):
        return self._data

    @data.setter
    def data(self, val):
        self._data = val

    @property
    def nuh_layer_id(self):
        # Must read the backing field: returning the property itself
        # (self.nuh_layer_id) recursed until RecursionError.
        return self._nuh_layer_id

    @nuh_layer_id.setter
    def nuh_layer_id(self, val):
        self._nuh_layer_id = val

    @property
    def raw_data(self):
        """RBSP of this NAL unit: ``data`` with emulation prevention bytes removed.

        Computed on first access and cached afterwards.
        """
        # `is None` rather than truthiness, so an empty result is cached too.
        if self._raw_data is None:
            self._raw_data = self.remove_emulation_prevention_three_byte()
        return self._raw_data

    def remove_emulation_prevention_three_byte(self):
        """Return a new BitStream equal to ``data`` minus its 0x03 escape bytes.

        Whenever the next three bytes are 0x000003, the two zero bytes are kept
        and the 0x03 is dropped. Reading advances ``data``'s bit position to
        its end.
        """
        num_bytes = len(self._data) // 8
        rbsp_byte = BitStream()
        i = 0
        while i < num_bytes:
            if (i + 2) < num_bytes and self._data.peek('bits:24') == "0x000003":
                rbsp_byte.append(self._data.read('bits:8'))
                rbsp_byte.append(self._data.read('bits:8'))
                self._data.read('bits:8')  # discard emulation_prevention_three_byte
                i += 3
            else:
                rbsp_byte.append(self._data.read('bits:8'))
                i += 1
        return rbsp_byte
class Reader(threading.Thread):
    """Background thread that splits an HEVC Annex B file into NAL units.

    The file is read in ``chunk_size``-byte chunks. Each chunk (prefixed with
    the unfinished tail of the previous one) is scanned for 0x000001 start
    codes, and every complete NAL unit is put on ``output_queue`` as a
    ``NalUnit``. HDR10+ SEI NAL units are dropped instead, and their size is
    added to ``skipped_data``. At end of file a dummy access unit delimiter,
    positioned at ``file_size``, is queued after the last real NAL unit.

    All offsets and sizes, including ``file_size``, are in bits because they
    are compared with ``BitStream`` lengths.
    """

    def __init__(
        self,
        file,
        file_size,
    ):
        threading.Thread.__init__(self)
        self.file = file                   # path of the input file
        self.file_size = file_size         # input size in bits
        self.output_queue = queue.Queue()
        self.event = threading.Event()     # reading pauses between chunks while cleared
        self.event.set()
        self.daemon = True
        self.active = True                 # set to False once the file is exhausted
        self.chunk_size = 1024 * 1024 * 1  # bytes per read()
        self.global_offset = 0             # file position (bits) of the current buffer start
        self.skipped_data = 0              # total bits of dropped HDR10+ NAL units

    def get_dummy_nalu(self):
        """Return a placeholder access unit delimiter located at end of file."""
        dummy_nalu = NalUnit()
        dummy_nalu.global_offset = self.file_size
        dummy_nalu.type = NalUnitType.NAL_UNIT_ACCESS_UNIT_DELIMITER
        return dummy_nalu

    def get_list_offsets(self, stream):
        """Return the bit offsets of all byte-aligned start codes in ``stream``.

        When ``stream`` reaches the end of the file, ``len(stream)`` is
        appended so the final NAL unit also has a terminating offset.
        """
        list_offsets = list(stream.findall('0x000001', bytealigned=True))
        if self.global_offset + len(stream) == self.file_size:
            list_offsets.append(len(stream))
        return list_offsets

    @staticmethod
    def is_hdr10plus_sei(nal_bytes):
        """Return True if ``nal_bytes`` is a prefix SEI NAL carrying HDR10+ metadata.

        Expected layout: NAL header 0x4E01 (prefix SEI, nuh_layer_id 0,
        TemporalId 0), then payloadType 4 (user_data_registered_itu_t_t35),
        then the 0xFF-extended payloadSize, then the ITU-T T.35 prefix used
        by SMPTE ST 2094-40: country code 0xB5, provider code 0x003C,
        provider-oriented code 0x0001, application identifier 4.

        Checking the T.35 prefix matters because other T.35 SEI messages share
        the first three bytes. For example, ATSC A/53 closed captions use
        provider code 0x0031, and those must not be removed.
        """
        if nal_bytes[:3] != b"\x4e\x01\x04":
            return False
        pos = 3
        # payloadSize: any number of 0xFF bytes followed by one final byte.
        while pos < len(nal_bytes) and nal_bytes[pos] == 0xFF:
            pos += 1
        pos += 1
        return nal_bytes[pos:pos + 6] == b"\xb5\x00\x3c\x00\x01\x04"

    def parse_nalus(self, list_offsets, stream):
        """Queue every complete NAL unit delimited by ``list_offsets``.

        NAL unit k spans ``list_offsets[k]`` up to ``list_offsets[k + 1]``.
        Data after the last offset is left for the next chunk. At end of file
        the dummy delimiter is queued afterwards. Always returns True.
        """
        for start, end in zip(list_offsets, list_offsets[1:]):
            nalu = NalUnit()
            nalu.offset = start
            nalu.global_offset = self.global_offset + start
            # Skip the 24-bit start code and the forbidden_zero_bit; the next
            # 6 bits are nal_unit_type.
            nalu.type = stream[(start + 25):(start + 31)].uint
            nalu.size = end - start
            # When the next start code is the 4-byte form, its leading zero
            # byte belongs to it rather than to this NAL unit.
            if stream[(end - 8):(end + 24)] == "0x00000001":
                nalu.size -= 8
            nalu.data = stream[(start + 24):(start + nalu.size)]
            if (nalu.type == NalUnitType.NAL_UNIT_PREFIX_SEI
                    and self.is_hdr10plus_sei(nalu.data.tobytes())):
                self.skipped_data += nalu.size
                continue
            self.output_queue.put(nalu)
        if self.global_offset + len(stream) == self.file_size:
            self.output_queue.put(self.get_dummy_nalu())
        return True

    def read(self, stream):
        """Parse the complete NAL units in ``stream``.

        Returns the bit offset where unprocessed data begins.
        """
        list_offsets = self.get_list_offsets(stream)
        if not list_offsets:
            # No start code buffered yet: keep everything for the next chunk
            # instead of failing on pop() from an empty list.
            return 0
        self.parse_nalus(list_offsets, stream)
        return list_offsets[-1]

    def chunks(self, file_obj):
        """Yield ``chunk_size``-byte chunks, waiting on ``event`` after each one.

        Sets ``active`` to False when the file is exhausted.
        """
        while True:
            data = file_obj.read(self.chunk_size)
            if not data:
                self.active = False
                break
            yield data
            self.event.wait()

    def run(self):
        with open(self.file, "rb") as i_f:
            unprocessed_data = BitStream("")
            for data in self.chunks(i_f):
                stream = BitStream(bytes=data)
                stream.prepend(unprocessed_data)
                latest_valid_offset = self.read(stream)
                # Carry the incomplete last NAL unit over to the next chunk.
                unprocessed_data = stream[latest_valid_offset:]
                self.global_offset += (len(stream) - len(unprocessed_data))
class AccessUnitAnalysis(threading.Thread):
    """Background thread that groups NAL units into access units.

    NAL units are taken one at a time from ``input_queue``. A NAL unit whose
    type is in ``_starting_nalus`` (AUD, VPS, SPS, PPS, prefix SEI) closes the
    buffered access unit when the previously buffered NAL unit has any other
    type. The closed unit, a list of NAL units, is put on ``output_queue`` and
    counted in ``total``. After each NAL unit the thread waits on ``event``.
    """

    def __init__(
        self,
        input_queue,
    ):
        super().__init__()
        self.input_queue = input_queue
        self.output_queue = queue.Queue()
        self.event = threading.Event()
        self.event.set()
        self.daemon = True
        self.active = True
        self.slice_segment_address_length = None
        self.total = 0      # access units emitted so far
        self._buffer = []   # NAL units of the access unit being built
        self._starting_nalus = [
            NalUnitType.NAL_UNIT_ACCESS_UNIT_DELIMITER,
            NalUnitType.NAL_UNIT_VPS,
            NalUnitType.NAL_UNIT_SPS,
            NalUnitType.NAL_UNIT_PPS,
            NalUnitType.NAL_UNIT_PREFIX_SEI
        ]

    def reset_buffer(self, nalu=None):
        """Start a fresh buffer, optionally seeded with ``nalu``."""
        self._buffer = [nalu] if nalu else []

    def add_nalu(self, nalu):
        """Append ``nalu``, first flushing the buffer if ``nalu`` opens a new access unit."""
        starters = self._starting_nalus
        opens_new_unit = (
            nalu.type in starters
            and bool(self._buffer)
            and self._buffer[-1].type not in starters
        )
        if opens_new_unit:
            finished_unit = self._buffer
            self.output_queue.put(finished_unit)
            self.reset_buffer()
            self.total += 1
        self._buffer.append(nalu)

    def run(self):
        while True:
            self.add_nalu(self.input_queue.get())
            self.event.wait()
class Writer(threading.Thread):
    """Background thread that writes access units to ``out_file``.

    Access units (lists of NAL units) are taken from ``input_queue``. Each NAL
    unit is written with a 4-byte 0x00000001 start code in front of its data.
    ``written_data`` is the input-file position just past the last NAL unit
    written. The thread finishes, clearing ``active``, once that position
    reaches ``out_file_size``.
    """

    def __init__(
        self,
        out_file,
        out_file_size,
        input_queue,
        skipped_data,
    ):
        threading.Thread.__init__(self)
        self.out_file = out_file
        # main() passes the *input* size in bits, because completion is
        # measured in input-file positions.
        self.out_file_size = out_file_size
        self.input_queue = input_queue
        self.daemon = True
        self.active = True
        self.written_data = 0
        self.cpp_lib = None
        # Kept for interface compatibility; deliberately NOT used by
        # _is_done(). written_data is an input-file position, so NAL units
        # dropped upstream are already skipped over. The caller's value may
        # also be a stale snapshot taken while the reader is still running;
        # adding it made the old equality test unreachable, so the writer
        # never finished.
        self.skipped_data = skipped_data

    def update_written_data(self, nalu):
        """Record the input-file position just past ``nalu``."""
        self.written_data = nalu.global_offset + nalu.size

    def _is_done(self):
        """True once the last NAL unit of the input has been written."""
        return self.written_data >= self.out_file_size

    def write(self, nalu, out_file_handler):
        """Write ``nalu`` preceded by a 4-byte start code.

        The start code is prepended to ``nalu.data`` in place.
        """
        nalu.data.prepend("0x00000001")
        nalu.data.tofile(out_file_handler)

    def run(self):
        with open(self.out_file, 'wb') as out_f:
            while True:
                access_unit = self.input_queue.get()
                if not access_unit:
                    continue
                for nalu in access_unit:
                    self.write(nalu, out_f)
                    self.update_written_data(nalu)
                if self._is_done():
                    self.active = False
                    break
def progress(list_state, string=None):
    """Overwrite the current console line with one or more percentages.

    Each entry of ``list_state`` is indexed as ``[label, current, total]`` and
    rendered as ``"label: NN.N%"``, with the percentage padded to 8 columns.
    A truthy ``string`` is appended in parentheses. The line ends with a
    carriage return so the next call overwrites it. Always returns True.
    """
    columns = []
    for entry in list_state:
        label, done, whole = entry[0], entry[1], entry[2]
        pct_text = "{}%".format(round(100.0 * done / float(whole), 1))
        columns.append("{}: {}".format(label, pct_text.ljust(8, " ")))
    if string:
        columns.append("({})".format(string))
    line = "".join(columns)
    sys.stdout.write(line + "\r")
    sys.stdout.flush()
    return True
def format_seconds_to_hhmmss(seconds):
    """Render a duration as ``"HHh MMm SSs"``, or ``"MMm SSs"`` when under one hour.

    ``seconds`` may be an int or a float; each field is converted with int().
    """
    hours, remainder = divmod(seconds, 60 * 60)
    minutes, secs = divmod(remainder, 60)
    if hours == 0:
        return "{:02d}m {:02d}s".format(int(minutes), int(secs))
    return "{:02d}h {:02d}m {:02d}s".format(int(hours), int(minutes), int(secs))
def main():
    """Command-line entry point: copy an HEVC stream, dropping HDR10+ SEI NAL units.

    Three daemon threads form a pipeline. ``Reader`` splits the file into NAL
    units, ``AccessUnitAnalysis`` groups them into access units, and
    ``Writer`` writes the result to ``--output``. The main thread throttles
    the reader so that between LOAD_BALANCING_FACTOR / 2 and
    LOAD_BALANCING_FACTOR access units stay queued, and it shows progress
    until the writer is done.
    """
    parser = argparse.ArgumentParser(formatter_class=RawTextHelpFormatter)
    parser.add_argument('-i', '--input', required=True, help='input hevc file')
    parser.add_argument('-o', '--output', required=False, help='output hevc file', default='processed_file.hevc')
    parser.add_argument(
        '-lbf',
        dest='LOAD_BALANCING_FACTOR',
        type=int,
        default=500,
        help='pause reading while more than this many access units are queued',
    )
    args = parser.parse_args()
    high_water = args.LOAD_BALANCING_FACTOR
    low_water = int(args.LOAD_BALANCING_FACTOR / 2)

    start = time.time()
    in_file = args.input
    # Offsets and sizes throughout the pipeline are in bits.
    in_file_size = os.path.getsize(in_file) * 8
    reader = Reader(in_file, in_file_size)
    reader.start()
    au_analysis = AccessUnitAnalysis(reader.output_queue)
    au_analysis.start()

    sys.stdout.write("\nWait...\r")
    sys.stdout.flush()
    # Let some access units accumulate before writing starts. Stop waiting
    # once the reader thread has finished: short inputs may never produce
    # low_water access units, which used to hang here forever.
    while au_analysis.output_queue.qsize() < low_water and reader.is_alive():
        time.sleep(0.2)

    # The writer ends when written_data (the input position past the last
    # NAL unit written) plus this value equals the input size. Positions
    # already skip dropped NAL units, so 0 is the correct value. Reading
    # reader.skipped_data here would give a stale snapshot of a running
    # thread and could make that condition unreachable.
    writer = Writer(
        args.output,
        in_file_size,
        au_analysis.output_queue,
        0,
    )
    writer.start()
    # Also stop if the writer thread dies, instead of spinning forever.
    while writer.active and writer.is_alive():
        pending = au_analysis.output_queue.qsize()
        if pending > high_water:
            reader.event.clear()  # pause the reader
        elif pending < low_water:
            reader.event.set()    # resume the reader
        progress(
            [
                ["PROGRESS", writer.written_data, in_file_size]
            ]
        )
        # Poll instead of busy-spinning: the spin burned a core and competed
        # with the worker threads for the GIL.
        time.sleep(0.1)
    progress(
        [
            ["PROGRESS", writer.written_data, in_file_size]
        ]
    )
    writer.join()
    print("\n\nELAPSED TIME: {}".format(format_seconds_to_hhmmss(time.time() - start)))
if __name__ == '__main__':
    main()
    # Pause before exiting, presumably so a console window opened just for
    # this script stays readable. input() only returns on Enter, and raises
    # EOFError when stdin is closed or non-interactive (e.g. redirected from
    # /dev/null); a finished run should not end in a traceback.
    try:
        input("Press Enter to exit")
    except EOFError:
        pass
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment