Skip to content

Instantly share code, notes, and snippets.

@SpotlightKid
Created June 25, 2020 15:45
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save SpotlightKid/bf2dfa1bb80d844b9c8a71e08fae7f56 to your computer and use it in GitHub Desktop.
Save SpotlightKid/bf2dfa1bb80d844b9c8a71e08fae7f56 to your computer and use it in GitHub Desktop.
Python WAV file decoder
# -*- coding: utf-8 -*-
__all_ = [
'Error',
'FmtChunk',
'ParsingError',
'SmplChunk',
'UnsupportedCompressionError',
'WavChunk',
'WavFile'
]
import logging
import struct
from chunk import Chunk
try:
basestring
except NameError:
basestring = str
# module globals
log = logging.getLogger(__name__)
WAVE_FORMAT_PCM = 0x0001
KNOWN_CHUNKS = [
b'cue ',
b'data',
b'fact',
b'fmt ',
b'inst',
b'list',
b'plst',
b'smpl',
b'wavl',
]
# sub-chunks of 'list'
# 'ltxt',
# 'note',
# 'labl',
FORMAT_TAGS = {
0: 'Unknown',
1: 'PCM/uncompressed',
2: 'Microsoft ADPCM',
6: 'ITU G.711 a-law',
7: 'ITU G.711 u-law',
17: 'IMA ADPCM',
20: 'ITU G.723 ADPCM',
49: 'GSM 6.10',
64: 'ITU G.721 ADPCM',
80: 'MPEG',
0xFFFF: 'Experimental',
}
LOOP_TYPE_FORWARD = 0
LOOP_TYPE_ALTERNATE = 1
LOOP_TYPE_REVERSE = 2
# classes
class Error(Exception):
"""General error."""
pass
class ParseError(Error):
pass
class UnsupportedCompressionError(Error):
pass
class WavChunk(Chunk):
"""Base class for chunks in a WAVE RIFF file.
Sub-classes chunk.Chunk but offers more convenient property-based access
to chunk data. Attributes:
- name: four-character chunk tag name
- size: length of chunk data
- data: raw chunk data
Specialized sub-classes for specific chunk types may add more attributes
for parsed chunk data.
Getting the string value of an instance (e.g. via 'str()' or 'print'),
yields the binary chunk data including tag and size fields and appopriate
data padding.
"""
fourcc = b''
_fieldnames = ()
_pack_format = ''
def __init__(self, file, name=None):
self.closed = False
# whether to align to word (2-byte) boundaries
self.align = True
self.file = file
if name is None:
self.chunkname = file.read(4)
if len(self.chunkname) < 4:
raise EOFError
else:
self.name = name
try:
self.chunksize = struct.unpack('<L', file.read(4))[0]
except struct.error:
raise EOFError
self.size_read = 0
try:
self.offset = self.file.tell()
except (AttributeError, IOError):
self.seekable = False
self._data = self.read()
else:
self.seekable = True
self._data = None
@property
def name(self):
"""Four-character chunk tag."""
return self.chunkname
@name.setter
def name(self, name):
if len(name) > 4:
raise ValueError("Chunk tag name length must be 4 characters.")
self.chunkname = name + b' ' * max(0, 4 - len(name))
@property
def size(self):
if self._data is None:
return self.chunksize
else:
return len(self.data)
@property
def data(self):
if self._data is None:
log.debug("Reading data from %s", self.__class__.__name__)
self.seek(0)
self._data = self.read()
return self._data
def __repr__(self):
return (" ".join(["%02X" % c if isinstance(c, int) else ord(c)
for c in self.data[:100]]) +
(" [...]" if len(self.data) > 100 else ""))
def __str__(self):
fmt = '<4sL%is' % len(self.data)
packed_size = struct.pack('<L', len(self.data))
log.debug("Data size: %i (%r)", len(self.data), packed_size)
return struct.pack(fmt, self.name, len(self.data), self.data) + (
'\0' if len(self.data) % 2 else '')
def __getattr__(self, name):
log.debug("%s.__getattr__(%r) called.", self.__class__.__name__, name)
# attribute access triggers deferred parsing of chunk data
if self._data is None:
self._parse()
try:
return self.__dict__[name]
except KeyError:
raise AttributeError(name)
def _parse(self):
log.debug("%s._parse() called.", self.__class__.__name__)
try:
self.__dict__.update(_unpack_to_dict(self._pack_format, self.data,
0, *self._fieldnames))
except struct.error:
raise ParseError("Invalid data in '%s' chunk." % self.fourcc)
class FmtChunk(WavChunk):
fourcc = 'fmt '
_pack_format = '<hhllh'
_fieldnames = (
'format_tag',
'channels',
'samples_per_sec',
'avg_bytes_per_sec',
'block_align')
def _parse(self):
WavChunk._parse(self)
if self.format_tag == WAVE_FORMAT_PCM:
self.bits_per_sample = struct.unpack('<h', self.data[14:16])[0]
self.compressed = False
else:
self.compressed = True
if self.format_tag not in FORMAT_TAGS:
log.warn('Unknown format tag: %r', self.format_tag)
@property
def comp_name(self):
return FORMAT_TAGS.get(self.format_tag, '<unsupported>')
@property
def sample_width(self):
if self.format_tag == WAVE_FORMAT_PCM:
return (self.bits_per_sample + 7) // 8
else:
raise UnsupportedCompressionError("Can't determine sample width "
"for %s data compression format.", self.comp_name)
@property
def frame_size(self):
return self.channels * self.sample_width
class SmplChunk(WavChunk):
"""Represents a 'smpl' chunk with information for samplers."""
fourcc = 'smpl'
_pack_format = '<9l'
_loop_pack_format = '<6l'
_fieldnames = (
'manufacturer',
'product',
'sample_period',
'midi_unity_note',
'midi_pitch_fraction',
'smpte_format',
'smpte_offset',
'sample_loops',
'sampler_data')
_loop_fieldnames = (
'cue_point_id',
'type',
'start',
'end',
'fraction',
'play_count')
def _parse(self):
WavChunk._parse(self)
self.loops = []
for i in range(self.sample_loops):
self.loops.append(
_unpack_to_dict(self._loop_pack_format, self.data,
struct.calcsize(self._pack_format), *self._loop_fieldnames))
class ListChunk(WavChunk):
"""Represents a 'list' chunk, which has a type and contains sub-chunks.
The list type is available through the 'type_id' attribute, the list of
sub-chunks through the 'subchunks' attribute. Each list item is a tuple
with the four-character chunk tag as the first item and the raw chunk data
(as a byte string) as the second.
"""
_fieldnames = ('type_id',)
_pack_format = '<4s'
def _parse(self):
WavChunk._parse(self)
pos = 4
self.subchunks = []
while pos < len(self.data):
tag = self.data[pos:pos+4]
size = struct.unpack_from('<l', self.data, pos + 4)[0]
self.subchunks.append((tag, self.data[pos+8:pos+8+size]))
pos += 8 + size + (1 if size % 2 else 0)
class CueChunk(WavChunk):
"""Represents a 'cue ' chunk with the list of cue points."""
fourcc = 'cue '
_pack_format = '<l'
_cue_pack_format = '<2l4s3l'
_fieldnames = ('num_cue_points')
_loop_fieldnames = (
'id',
'position',
'data_chunk_id',
'chunk_start',
'block_start',
'sample_offset')
def _parse(self):
WavChunk._parse(self)
self.cue_points = []
for i in range(self.num_cue_points):
self.cue_points.append(
_unpack_to_dict(self._cue_pack_format, self.data,
struct.calcsize(self._pack_format), *self._cue_fieldnames))
class WavFile(object):
"""WAV file reader."""
def __init__(self, wavfile):
self._i_opened_the_file = False
if isinstance(wavfile, basestring):
self.filename = wavfile
self.file = open(self.filename, 'rb')
self._i_opened_the_file = True
else:
self.file = wavfile
try:
self.filename = wavfile.name
except AttributeError:
self.filename = None
try:
self._riff = Chunk(self.file, align=True, bigendian=False)
riff_name = self._riff.getname()
if riff_name != b'RIFF':
raise ValueError("First chunk name != 'RIFF' (value '%s')" %
riff_name)
except (EOFError, ValueError):
raise ParseError("%s: Invalid/missing RIFF tag or chunk size." %
self.filename)
if self._riff.read(4) != b'WAVE':
raise Error("%s: not a WAVE file" % self.filename)
# dict to store chunk by chunk name (four-cc tag)
self.chunks = dict()
# we keep an extra list of chunks to maintain chunk position
self._chunklist = []
while True:
try:
chunk = chunk_factory(self._riff)
except EOFError:
break
if chunk.name == b'data' and b'fmt ' not in self.chunks:
log.warn("Encountered 'data' chunk before 'fmt ' chunk.")
if chunk.name in KNOWN_CHUNKS:
if chunk.name in self.chunks:
log.warn("Ignoring extra '%s' chunk at %i bytes."
% (chunk.name, self._riff.tell()))
else:
self.chunks[chunk.name] = chunk
else:
self.chunks.setdefault(chunk.name, []).append(chunk)
self._chunklist.append(chunk)
chunk.skip()
if b'fmt ' not in self.chunks or b'data' not in self.chunks:
raise ParseError("'fmt ' chunk and/or 'data' chunk missing.")
def close(self):
if self._i_opened_the_file:
try:
self.file.close()
except:
pass
__del__ = close
def __repr__(self):
s = []
for chunk in self:
s.append("Chunk '%s': size %i (%r)\n" % (
chunk.name.decode('ascii'), chunk.size, chunk))
return "".join(s)
def __str__(self):
data = b"".join(str(chunk) for chunk in self)
packed_size = struct.pack('<l', len(data) + 4)
log.debug("Data size: %i (%r)", len(data), packed_size)
return b"RIFF" + packed_size + b"WAVE" + data
def __iter__(self):
"""Make object useable as an iterator which yields each RIFF chunk.
The original position of each chunk in the source file is kept except
for the 'fmt ' chunk, which is always returned first.
"""
yield self.chunks[b'fmt ']
for chunk in self.chunks.get(b'LIST', []):
if chunk.type_id == b'INFO':
yield chunk
for chunk in self._chunklist:
if chunk.name == b'fmt ':
continue
if chunk.name == b'LIST' and chunk.type_id == b'INFO':
continue
yield chunk
@property
def fmt(self):
try:
return self.chunks[b'fmt ']
except KeyError:
raise "'fmt ' chunk not found."
@property
def loops(self):
try:
return self.chunks[b'smpl'].loops
except KeyError:
return []
@property
def cue_points(self):
try:
return self.chunks[b'cue '].loops
except KeyError:
return []
@property
def info(self):
try:
for chunk in self.chunks.get(b'LIST', []):
if chunk.type_id == b'INFO':
return dict((key, val.rstrip('\0'))
for key, val in chunk.subchunks)
except KeyError:
pass
return dict()
def raw_frames(self):
data = self.chunks[b'data'].data
size = len(data)
fs = self.fmt.frame_size
pos = 0
while pos <= size - fs:
yield data[pos:pos+fs]
pos += fs
# utility functions
def _unpack_to_dict(format, data, offset=0, *names):
return dict(zip(names, struct.unpack_from(format, data, offset)))
_chunk_registry = {
b'cue ': CueChunk,
b'fmt ': FmtChunk,
b'smpl': SmplChunk,
b'list': ListChunk,
b'LIST': ListChunk,
None: WavChunk
}
def chunk_factory(file):
fourcc = file.read(4)
if len(fourcc) < 4:
raise EOFError
return _chunk_registry.get(fourcc, _chunk_registry[None])(file, name=fourcc)
if __name__ == '__main__':
import sys
logging.basicConfig(level=logging.DEBUG)
if len(sys.argv) >= 2:
wav = WavFile(sys.argv[1])
else:
wav = WavFile(sys.stdin)
print(repr(wav))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment