An F# implementation of the Huffman compression algorithm
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module Huffman | |
open System | |
open System.IO | |
/// One step along a Huffman code: false selects the left child, true the right.
type bit = bool
/// A Huffman code: the bits leading from the root of a tree to a leaf.
type path = bit list

/// A node of a Huffman tree: either a leaf holding a byte value and its
/// frequency, or a branch holding optional children and a frequency.
type BinaryTreeNode =
    | Leaf of byte * frequency:int
    | Branch of left:BinaryTreeNode option * right:BinaryTreeNode option * frequency:int

    /// Follows one bit from this node: `true` returns the right child,
    /// `false` the left child. A leaf has no children and returns None.
    member __.Switch(b:bit) =
        match __ with
        | Branch (_, right, _) when b -> right
        | Branch (left, _, _) -> left
        | Leaf _ -> None

    /// A leaf for byte 0 with frequency 0.
    static member Empty = Leaf(0uy, 0)

    /// The frequency stored on this node. Read straight from the case: wrapping
    /// this in a `lazy` inside a property rebuilt the Lazy on every access, so it
    /// only added allocations and never memoized anything.
    member __.Cost() =
        match __ with
        | Leaf (_, f) -> f
        | Branch (_, _, f) -> f
/// Packs bits, most significant bit first, into bytes written to `stream`.
/// A trailing partial byte is padded with zero bits when the writer is
/// closed or disposed.
type BitWriter(stream:Stream) =
    let mutable buffer = 0uy
    // Number of bits currently held in `buffer` (0..8).
    let mutable len = 0
    // True when bytes were written since `stream` was last flushed.
    let mutable unflushed = false

    // Writes the pending bits as one byte, left-aligned so that any missing
    // low-order bits are zero, and resets the buffer.
    let flush () =
        stream.WriteByte(buffer <<< (8 - len))
        buffer <- 0uy
        len <- 0
        unflushed <- true

    // Emits the buffer once it holds a full byte. The underlying stream is
    // deliberately not flushed here: flushing after every byte forces one OS
    // write per byte on buffered streams such as FileStream.
    let flushIfFull () =
        if len >= 8 then flush ()

    // Flushes `stream` only if something was written since the last flush,
    // so repeated Close/Dispose calls do not touch the stream again.
    let flushStream () =
        if unflushed then
            stream.Flush()
            unflushed <- false

    /// Writes the buffered byte if it is complete, then flushes `stream` if any
    /// bytes are pending.
    member __.Flush() =
        flushIfFull ()
        flushStream ()

    /// Writes any partial byte (zero-padded) and flushes `stream`.
    member __.Close() =
        if len > 0 then flush ()
        flushStream ()

    /// Appends a single bit.
    member __.Write(b:bit) =
        buffer <- (buffer <<< 1) ||| (if b then 1uy else 0uy)
        len <- len + 1
        flushIfFull ()

    /// Appends every bit of `bits`, in order.
    member __.Write(bits:bit list) =
        for b in bits do __.Write b

    interface IDisposable with
        member __.Dispose() =
            __.Close()
/// Reads bits, most significant bit first, from `stream`. End-of-input is
/// detected through `stream.Position`/`stream.Length`, so the stream must be
/// seekable.
type BitReader(stream:Stream) =
    let mutable buffer = 0uy
    // Number of unread bits left in `buffer`.
    let mutable len = 0
    // Number of bits consumed by Read(). Kept as a ref cell because the public
    // Position property exposes the cell itself.
    let position = ref 0L

    // Refills `buffer` with the next byte. At end of stream the buffer is
    // zero-filled; it must not end up holding `byte -1` (255).
    let loadBuffer () =
        len <- 8
        buffer <-
            match stream.ReadByte() with
            | -1 -> 0uy
            | value -> byte value

    let ensureLoaded () =
        if len <= 0 then loadBuffer ()

    // Consumes and returns the highest unread bit.
    let readBit () =
        len <- len - 1
        position.Value <- position.Value + 1L
        (buffer &&& (1uy <<< len)) <> 0uy

    // Returns the highest unread bit without consuming it.
    let peekBit () =
        (buffer &&& (1uy <<< (len - 1))) <> 0uy

    /// True once the stream is exhausted and no buffered bits remain.
    member __.End with get() = stream.Position >= stream.Length && len <= 0

    /// NOTE(review): exposes the mutable ref cell rather than its value; kept
    /// as-is because changing the type would break callers. Read `.Value`.
    member __.Position with get() = position

    /// Consumes the next bit, or returns None at end of input.
    member __.Read() =
        if __.End then None
        else
            ensureLoaded ()
            Some (readBit ())

    /// Returns the next bit without consuming it, or None at end of input.
    member __.Peek() =
        if __.End then None
        else
            ensureLoaded ()
            Some (peekBit ())
/// A Huffman tree rooted at `root`, offering byte -> code and code -> byte
/// lookups.
type BinaryTree (root:BinaryTreeNode) =
    // Code of every byte reachable in the tree, indexed by byte value and built
    // once on first use. The tree is immutable, so one walk serves every
    // GetPath call. This replaces a full rescan (with quadratic list appends)
    // per lookup. Walking left before right and keeping the first code found
    // for a byte matches a left-first depth-first search.
    let codes =
        lazy (
            let table : path option [] = Array.create 256 None
            let rec walk (node:BinaryTreeNode) (reversed:bit list) =
                match node with
                | Leaf (value, _) ->
                    if table.[int value].IsNone then
                        table.[int value] <- Some (List.rev reversed)
                | Branch (left, right, _) ->
                    left |> Option.iter (fun l -> walk l (false :: reversed))
                    right |> Option.iter (fun r -> walk r (true :: reversed))
            walk root []
            table)

    member __.Root with get() = root

    /// Returns the code for `data`: the root-to-leaf bits (false = left,
    /// true = right). Returns None if no leaf holds `data`. When the root
    /// itself is a leaf holding `data`, the code is the empty list.
    member __.GetPath (data:byte) =
        codes.Value.[int data]

    /// Follows `p` from the root and returns the byte of the leaf reached.
    /// Bits left over after reaching a leaf are ignored. Fails with
    /// "invalid path" if the bits run out on a branch or select a missing child.
    member __.GetByte(p:path) =
        let rec loop (bits:path) (current:BinaryTreeNode) =
            match bits, current with
            | true :: tail, Branch(_, Some r, _) -> loop tail r
            | false :: tail, Branch(Some l, _, _) -> loop tail l
            | _, Leaf(d, _) -> d
            | _ -> failwith "invalid path"
        loop p root
/// Builds a Huffman tree from (byte, frequency) pairs.
/// The two lowest-cost nodes are merged repeatedly. List.sortBy is stable,
/// so ties keep their current order, and a freshly merged node is placed in
/// front before re-sorting. Decoding rebuilds the tree from the serialized
/// list, so this exact procedure must not change.
/// An empty list yields a tree whose root is `BinaryTreeNode.Empty` (a
/// zero-frequency leaf), so empty input can be compressed and round-tripped
/// instead of failing.
let buildTree (frequencies: (byte*int) list) =
    let byCost (nodes:BinaryTreeNode list) =
        nodes |> List.sortBy (fun n -> n.Cost())
    let merge (left:BinaryTreeNode) (right:BinaryTreeNode) =
        Branch(Some left, Some right, left.Cost() + right.Cost())
    let rec loop (nodes:BinaryTreeNode list) =
        match byCost nodes with
        | left :: right :: tail -> loop (merge left right :: tail)
        | [single] -> single
        | [] -> BinaryTreeNode.Empty
    frequencies
    |> List.map Leaf
    |> loop
    |> BinaryTree
/// Counts the occurrences of each distinct byte in `l`.
/// Pairs are returned in the order Seq.groupBy yields its groups.
let getFrequencies (l:byte list) =
    let countGroup (value:byte, occurrences:seq<byte>) =
        value, Seq.length occurrences
    l
    |> Seq.groupBy id
    |> Seq.map countGroup
    |> List.ofSeq
/// Counts the occurrences of each byte value in `stream`, always reading from
/// the start. The position is reset to 0 before and after counting, so the
/// stream must be seekable.
/// Pairs are ordered by each byte's last occurrence, most recent first. The
/// order feeds buildTree and so determines the compressed output; it is kept
/// stable on purpose. Counting uses fixed-size arrays, so no list is rebuilt
/// for every byte read.
let getFrequencies' (stream:Stream) =
    let counts = Array.zeroCreate<int> 256
    let lastSeen = Array.zeroCreate<int64> 256
    stream.Position <- 0L
    let mutable index = 0L
    let mutable next = stream.ReadByte()
    while next >= 0 do
        counts.[next] <- counts.[next] + 1
        lastSeen.[next] <- index
        index <- index + 1L
        next <- stream.ReadByte()
    stream.Position <- 0L
    [ for value in 0 .. 255 do
        if counts.[value] > 0 then yield byte value, counts.[value] ]
    |> List.sortByDescending (fun (value, _) -> lastSeen.[int value])
/// Writes the frequency table to `stream` with a BinaryWriter: the pair
/// count as an Int32, then for each pair the byte followed by its Int32
/// frequency. Both the writer and `stream` are flushed afterwards.
let serializeTree (stream:Stream) (frequencies: (byte*int) list) =
    // Intentionally not disposed: disposing a BinaryWriter also closes `stream`.
    let writer = new BinaryWriter(stream)
    writer.Write(List.length frequencies)
    frequencies
    |> List.iter (fun (value:byte, count:int) ->
        writer.Write value
        writer.Write count)
    writer.Flush()
    stream.Flush()
/// Reads a frequency table in the layout serializeTree writes: an Int32
/// count, then that many (byte, Int32) pairs, in order. A count of zero or
/// less yields an empty list.
let deserializeFrequencies (stream:Stream) =
    // Intentionally not disposed: disposing a BinaryReader also closes `stream`.
    let reader = new BinaryReader(stream)
    let count = reader.ReadInt32()
    [ for _ in 1 .. count do
        let value = reader.ReadByte()
        let frequency = reader.ReadInt32()
        yield value, frequency ]
/// Reads a frequency table from `stream` and rebuilds the Huffman tree
/// from it with buildTree.
let deserializeTree stream =
    let frequencies = deserializeFrequencies stream
    buildTree frequencies
/// Compresses `bytes` into `output`. The frequency table is written first
/// (see serializeTree), followed by the Huffman code of every byte as a
/// packed bit stream.
/// Raises ArithmeticException if a byte has no code in the tree.
let compress (output:Stream) (bytes:byte list) =
    use writer = new BitWriter(output)
    let frequencies = getFrequencies bytes
    serializeTree output frequencies
    let tree = buildTree frequencies
    let encode (value:byte) =
        match tree.GetPath value with
        | None -> raise <| ArithmeticException "compression failed"
        | Some code -> code |> List.iter (fun b -> writer.Write b)
    bytes |> List.iter encode
    writer.Close()
/// Compresses the whole of `input` into `output`. Frequencies are counted
/// from the start of `input`, which is then rewound. The frequency table is
/// written first, followed by the packed Huffman codes. `input` must be
/// seekable.
/// Raises ArithmeticException if a byte has no code in the tree.
let compressStream (output:Stream) (input:Stream) =
    use writer = new BitWriter(output)
    let frequencies = getFrequencies' input
    serializeTree output frequencies
    let tree = buildTree frequencies
    let emit (value:byte) =
        match tree.GetPath value with
        | Some code -> code |> List.iter (fun b -> writer.Write b)
        | None -> raise <| ArithmeticException "compression failed"
    while input.Position < input.Length do
        input.ReadByte() |> byte |> emit
    writer.Close()
/// Decodes one byte by following bits from `node` down the tree.
/// Each bit is peeked first. It is consumed when it is used to step into a
/// child, or, when `node` itself is a leaf, only if the bit is `true`.
/// Returns None once the reader is at end of input. `acc` is unused.
/// Fails with "corrupted stream" when a bit selects a missing child.
let rec readPath (reader:BitReader) (node:BinaryTreeNode) acc =
    let peeked = reader.Peek()
    let consume () = reader.Read() |> ignore
    if reader.End then None
    else
        match peeked, node with
        | None, _ -> failwith "corrupted stream"
        | Some true, Leaf (value, _) -> consume (); Some value
        | Some false, Leaf (value, _) -> Some value
        | Some goRight, Branch (left, right, _) ->
            let child = if goRight then right else left
            match child with
            | Some (Leaf (value, _)) -> consume (); Some value
            | Some inner -> consume (); readPath reader inner acc
            | None -> failwith "corrupted stream"
/// Decompresses data written by compress/compressStream and returns the
/// bytes. The frequency header is read, the tree rebuilt, and then exactly
/// as many symbols are decoded as the header accounts for (the root's total
/// frequency).
/// Stopping at that count matters because the final byte of the bit stream
/// is zero-padded; decoding "until end of stream" can turn padding bits into
/// spurious trailing bytes. When the tree is a single leaf its code is empty
/// and no bits are stored, so that byte is simply repeated.
/// If the bit stream ends early, the bytes decoded so far are returned.
let decompress (stream:Stream) =
    let tree = deserializeTree stream
    let symbolCount = max 0 (tree.Root.Cost())
    match tree.Root with
    | Leaf (value, _) -> List.replicate symbolCount value
    | Branch _ ->
        let reader = BitReader stream
        let decoded = ResizeArray<byte>()
        let mutable finished = false
        while not finished && decoded.Count < symbolCount do
            match readPath reader tree.Root [] with
            | Some value -> decoded.Add value
            | None -> finished <- true
        List.ofSeq decoded
/// Decompresses data written by compress/compressStream from `stream` into
/// `output`, then flushes `output`.
/// Exactly as many symbols are decoded as the frequency header accounts for
/// (the root's total frequency). Decoding "until end of stream" can turn the
/// zero-padding bits of the final byte into spurious trailing bytes. When the
/// tree is a single leaf its code is empty and no bits are stored, so that
/// byte is simply repeated. If the bit stream ends early, decoding stops.
let decompressInStream (stream:Stream) (output:Stream) =
    let tree = deserializeTree stream
    let symbolCount = max 0 (tree.Root.Cost())
    match tree.Root with
    | Leaf (value, _) ->
        for _ in 1 .. symbolCount do
            output.WriteByte value
    | Branch _ ->
        let reader = BitReader stream
        let mutable remaining = symbolCount
        while remaining > 0 do
            match readPath reader tree.Root [] with
            | Some value ->
                output.WriteByte value
                remaining <- remaining - 1
            | None -> remaining <- 0
    output.Flush()
/// Converts each character of `s` to one byte with F#'s unchecked `byte`
/// conversion.
/// NOTE(review): characters above U+00FF keep only their low byte. This is
/// not a text encoding such as UTF-8.
let strToBytes (s:string) =
    [ for c in s -> byte c ]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#load "Huffman.fs" | |
open System | |
open System.IO | |
open Huffman | |
// Resolves a file in the sibling "assets" directory. Separate path segments
// keep this portable; a literal "..\\assets\\x" only works on Windows.
let assetPath (name:string) =
    Path.Combine(__SOURCE_DIRECTORY__, "..", "assets", name)

// compress a byte list
let file1 = File.ReadAllBytes(assetPath "photo.raw") |> Seq.toList
// File.Create truncates an existing file. File.OpenWrite does not, so stale
// trailing bytes from an earlier, longer output would survive.
do
    use file2 = File.Create(assetPath "photo.huffman")
    compress file2 file1

// compress a FileStream
#time
do
    use page1 = File.OpenRead(assetPath "page1.html")
    use output2 = File.Create(assetPath "page1.huffman")
    page1 |> compressStream output2
#time

// decompress a FileStream
#time
do
    use huff1 = File.OpenRead(assetPath "page1.huffman")
    use page1_1 = File.Create(assetPath "page1_1.html")
    decompressInStream huff1 page1_1
#time
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
namespace FSandbox.Tests | |
open NUnit.Framework | |
open System.IO | |
open Huffman | |
/// NUnit fixture covering frequency counting, frequency-table serialization,
/// code lookups and a compress/decompress round trip on a small ASCII sample.
/// Assert.AreEqual arguments are (expected, actual).
[<TestFixture>]
type ``Check if huffman compression does not alter data``() =
    [<Test>]
    member __.``when computing frequencies`` () =
        let f =
            "aaabbffppppeee kk aa"
            |> strToBytes
            |> getFrequencies
            |> List.map (fun (b,c) -> char b, c)
            |> dict
        Assert.AreEqual(5, f.['a'])
        Assert.AreEqual(2, f.['b'])
        Assert.AreEqual(2, f.['f'])
        Assert.AreEqual(4, f.['p'])
        Assert.AreEqual(3, f.['e'])
        Assert.AreEqual(2, f.['k'])
        Assert.AreEqual(2, f.[' '])

    [<Test>]
    member __.``when computing frequencies from stream`` () =
        let bytes = "aaabbffppppeee kk aa" |> strToBytes |> Seq.toArray
        let stream = new MemoryStream(bytes)
        let f =
            stream
            |> getFrequencies'
            |> List.map (fun (b,c) -> char b, c)
            |> dict
        Assert.AreEqual(5, f.['a'])
        Assert.AreEqual(2, f.['b'])
        Assert.AreEqual(2, f.['f'])
        Assert.AreEqual(4, f.['p'])
        Assert.AreEqual(3, f.['e'])
        Assert.AreEqual(2, f.['k'])
        Assert.AreEqual(2, f.[' '])

    [<Test>]
    member __.``when serializing and deserializing frequencies`` () =
        let f =
            "aaabbffppppeee kk aa"
            |> strToBytes
            |> getFrequencies
        let memory = new MemoryStream()
        serializeTree memory f
        memory.Position <- 0L
        let f2 = deserializeFrequencies memory
        Assert.AreEqual(f, f2)

    [<Test>]
    member __.``checking path consistence`` () =
        let tree1 =
            "aaabbffppppeee kk aa"
            |> strToBytes
            |> getFrequencies
            |> buildTree
        // Every byte's code must lead back to that byte.
        for c in "abfpek " do
            let code = tree1.GetPath(byte c)
            Assert.IsTrue(code.IsSome)
            Assert.AreEqual(byte c, tree1.GetByte(code.Value))

    [<Test>]
    member __.``when compressing in stream`` () =
        let data = "aaabbffppppeee kk aa" |> strToBytes
        let tree1 = data |> getFrequencies |> buildTree
        let memory = new MemoryStream()
        compress memory data
        memory.Position <- 0L
        let tree = deserializeTree memory
        // The tree rebuilt from the serialized header must hand out the same
        // code for every byte as the tree the compressor used.
        for c in "abfpek " do
            Assert.AreEqual(tree1.GetPath(byte c), tree.GetPath(byte c))
        Assert.AreEqual(tree1.Root, tree.Root)
        let tochar (b:byte option) =
            match b with Some c -> char c | None -> '0'
        let reader = BitReader memory
        for expected in "aaabbffppppeee kk aa" do
            Assert.AreEqual(expected, tochar (readPath reader tree.Root []))

    [<Test>]
    member __.``when compressing and decompressing in a stream`` () =
        let data = "aaabbffppppeee kk aa" |> strToBytes
        let memory = new MemoryStream()
        compress memory data
        memory.Position <- 0L
        let data2 = decompress memory
        Assert.AreEqual(data, data2)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment