Skip to content

Instantly share code, notes, and snippets.

@jannschu
Created July 16, 2010 15:00
Show Gist options
  • Save jannschu/478460 to your computer and use it in GitHub Desktop.
Save jannschu/478460 to your computer and use it in GitHub Desktop.
PKZIP explode/decompress in Haskell
{-
- Copyright (C) 2010 Ovillo <http://github.com/jannschu/ovillo>
- Thanks to Ben Rudiak-Gould on comp.compression for a description
- and Mark Adler for an interesting implementation with "blast.c".
-
- This file is part of Ovillo.
-
- Ovillo is free software: you can redistribute it and/or modify
- it under the terms of the GNU General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
-
- Ovillo is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU General Public License for more details.
-
- You should have received a copy of the GNU General Public License
- along with Ovillo. If not, see <http://www.gnu.org/licenses/>.
-}
module PKZIP (decompress, decompressFromFile) where
import qualified Data.ByteString.Lazy as B
import Data.Word (Word8)
import Data.Bits (testBit, shiftL, shiftR, (.&.), (.|.))
import Data.List (isPrefixOf, partition)
import qualified Data.Map as Map
-- | Turn a lazy 'B.ByteString' into the sequence of its bits.
--
-- Bytes are consumed front to back; the bit order inside each byte is the
-- one produced by 'read8Bit'.  An empty input yields an empty list.
bitList :: B.ByteString -> [Bool]
bitList byteString
  | B.null byteString = []
  | otherwise = firstBits ++ bitList remaining
  where
    -- Only demanded in the non-empty branch, so 'read8Bit' never sees an
    -- empty input.
    (firstBits, remaining) = read8Bit byteString
-- | Split off the first byte of the input and return its eight bits,
-- least-significant bit first, together with the remaining bytes.
--
-- The input must be non-empty: 'B.head' and 'B.tail' fail on an empty
-- 'B.ByteString'.
read8Bit :: B.ByteString -> ([Bool], B.ByteString)
read8Bit byteString = (lsbFirst, remainder)
  where
    firstByte = B.head byteString
    remainder = B.tail byteString
    lsbFirst = [testBit firstByte position | position <- [0 .. 7]]
-- | Decompress a raw PKZIP "explode" stream (no ZIP container around it).
--
-- The stream starts with two header bytes:
--
--   * the literal flag: @0@ means literals are stored as plain 8-bit values,
--     @1@ means they are prefix coded;
--   * the window size, which must lie between 4 and 6 and is also the number
--     of low-order distance bits read for every (size, offset) pair.
--
-- Everything after the header is handed to 'sumZipData' as a bit stream.
--
-- Returns @Left "unexpected header"@ when the input is shorter than two bytes
-- or either header byte is out of range; otherwise whatever 'sumZipData'
-- returns.
decompress :: B.ByteString -> Either String B.ByteString
decompress bytes =
  -- Peel the header off with 'B.uncons' instead of measuring the whole input:
  -- 'B.length' on a lazy ByteString walks every chunk, which would force the
  -- complete input just to validate two bytes.
  case B.uncons bytes of
    Just (literalFlag, afterFlag)
      | Just (windowSize, zippedData) <- B.uncons afterFlag
      , literalFlag <= 1
      , windowSize >= 4 && windowSize <= 6 ->
          sumZipData literalFlag windowSize (bitList zippedData) B.empty
    _ -> Left "unexpected header"
-- | Read the file at the given path and run 'decompress' on its contents.
--
-- The file must not be a ZIP archive: it has to contain just the pure
-- compressed data.
decompressFromFile :: String -> IO (Either String B.ByteString)
decompressFromFile path = fmap decompress (B.readFile path)
-- | Decode the compressed bit stream into the decompressed bytes.
--
-- Arguments, in order:
--
--   * the literal flag from the header (@0@: literals are plain 8-bit values,
--     @1@: literals are prefix coded via 'literalTree');
--   * the window size from the header, used as the number of low-order
--     distance bits;
--   * the remaining bit stream;
--   * the output produced so far, stored in REVERSE order (most recent byte
--     first) so that new bytes can be prepended cheaply.
--
-- Each item in the stream starts with a flag bit: @False@ introduces a
-- literal, @True@ a (size, offset) pair that copies already produced output.
-- A pair whose size is 519 marks the end of the data; the accumulated output
-- is then reversed and returned.
--
-- Returns @Left@ with a short description when the stream is malformed or
-- ends before the end marker.
sumZipData :: Word8 -> Word8 -> [Bool] -> B.ByteString -> Either String B.ByteString
sumZipData 0 windowSize (False : rest) result =
  -- literal value, 8-bit binary
  case splitAt 8 rest of
    (char, newRest)
      | length char == 8 ->
          sumZipData 0 windowSize newRest (fromIntegral (bitsToWord char) `B.cons` result)
      | otherwise -> Left "not enough bits for literal"
sumZipData 1 windowSize (False : rest) result =
  -- literal value, ASCII encoded
  case search rest literalTree of
    Just (value, newRest) ->
      sumZipData 1 windowSize newRest (fromIntegral value `B.cons` result)
    Nothing -> Left "invalid literal"
sumZipData f windowSize (True : stream) result =
  -- (size, offset) pair
  case readSize stream ==> readOffset of
    Left a -> Left a
    Right (size, offset, newStream)
      | size == 519 -> Right (B.reverse result)
      | otherwise ->
          -- The decoded offset is zero based, the copy distance is one based.
          case getFromDict size (offset + 1) result of
            Left a -> Left a
            Right newResult -> sumZipData f windowSize newStream newResult
  where
    -- Prepend @size@ bytes copied from @offset@ bytes back in the output.
    -- @dict@ is the reversed output, so its first @offset@ bytes are exactly
    -- the bytes the copy may read from.  When @size@ exceeds @offset@ the
    -- source region is repeated as often as needed.
    getFromDict size offset dict
      -- Measuring only the (at most @offset@ long) prefix keeps the check
      -- cheap; a short prefix means the distance points before the start of
      -- the output.
      | B.length bits < fromIntegral offset = Left "distance too far back"
      | otherwise = Right (half `B.append` (full `B.append` dict))
      where
        bits = B.take (fromIntegral offset) dict
        full = B.concat (replicate (size `quot` offset) bits)
        half = B.drop (fromIntegral (offset - (size `rem` offset))) bits

    -- Decode the copy length: a prefix-coded symbol selects a base value and
    -- a number of extra bits that are added on top of it.
    readSize bitStream =
      case search bitStream sizeTree of
        Nothing -> Left "invalid length value in pair"
        Just (n, newStream) ->
          let base = sizeBase Map.! n
              extraBitsLength = sizeExtra Map.! n
           in case splitAt extraBitsLength newStream of
                (extraBits, offsetStream)
                  | length extraBits /= extraBitsLength ->
                      Left "not enough bits for length in pair"
                  | otherwise ->
                      Right (base + bitsToWord extraBits, offsetStream)

    -- Decode the copy offset that follows a length.  The end marker (519)
    -- carries no offset.  Otherwise the prefix-coded symbol supplies the
    -- high-order bits and the following raw bits the low-order ones, so the
    -- symbol has to be shifted left before the two are combined.
    readOffset value bitStream
      | value == 519 = Right (0, bitStream)
      | otherwise =
          case search bitStream distanceTree of
            Nothing -> Left "invalid distance value"
            Just (n, newStream) ->
              case splitAt extraBits newStream of
                (extraOffset, lastStream)
                  | length extraOffset /= extraBits ->
                      Left "not enough bits for distance in pair"
                  | otherwise ->
                      Right ((n `shiftL` extraBits) .|. bitsToWord extraOffset, lastStream)
      where
        -- Copies of length 2 always use 2 low-order bits; all other lengths
        -- use the window size from the header.
        extraBits = fromIntegral (if value == 2 then 2 else windowSize)
sumZipData _ _ _ _ = Left "unexpected end"
-- | Chain two decoding steps that thread a stream through 'Either'.
--
-- The left operand yields a value and the remaining stream.  The right
-- operand receives both and yields a second value plus the stream left after
-- it.  On success the result holds the first value, the second value and the
-- final stream; the first 'Left' encountered is returned unchanged.
(==>) :: Either a (b, c) -> (b -> c -> Either a (d, c)) -> Either a (b, d, c)
firstStep ==> nextStep = do
  (value, restStream) <- firstStep
  (secondValue, resultStream) <- nextStep value restStream
  return (value, secondValue, resultStream)
-- | Interpret a list of bits as an unsigned number, least-significant bit
-- first: the bit at list position @i@ contributes @2^i@ when it is set.
-- The empty list yields 0.
bitsToWord :: [Bool] -> Int
bitsToWord bits = sum [1 `shiftL` position | (True, position) <- zip bits [0 ..]]
----------------------------------------------------------------------
-- | Binary decoding tree for a prefix code.
data Tree a
  = Empty
  -- ^ No code leads here.
  | Value a
  -- ^ Leaf holding the decoded symbol.
  | Node (Tree a) (Tree a)
  -- ^ Inner node; 'search' follows the first child on a @True@ bit and the
  -- second child on a @False@ bit.
  deriving (Show, Eq)
-- | Walk a decoding tree along the given bits.
--
-- A @True@ bit descends into the first child of a 'Node', a @False@ bit into
-- the second.  On reaching a 'Value' the symbol is returned together with the
-- bits that were not consumed.  Running into 'Empty', or running out of bits
-- while still at a 'Node', yields 'Nothing'.
search :: [Bool] -> Tree a -> Maybe (a, [Bool])
search unread (Value symbol) = Just (symbol, unread)
search _ Empty = Nothing
search [] (Node _ _) = Nothing
search (bit : unread) (Node onTrue onFalse)
  | bit = search unread onTrue
  | otherwise = search unread onFalse
-- | Decoding tree for prefix-coded literals (symbols 0 to 255, code lengths
-- 4 to 13), built by 'createBinaryTree' from the packed length table below.
literalTree :: Tree Int
literalTree = createBinaryTree 4 13 255 15 packedLengths
  where
    packedLengths =
      [ 173, 44, 45, 12, 61, 12, 45, 12, 45, 12, 13, 252, 252, 252
      , 253, 253, 253, 44, 27, 10, 24, 7, 8, 53, 10, 6, 21, 6, 5, 7
      , 11, 5, 38, 5, 38, 5, 12, 8, 12, 9, 11, 8, 11, 25, 8, 9, 7, 38
      , 11, 7, 22, 7, 6, 9, 11, 6, 24, 7, 5, 22, 7, 6, 12, 11, 9, 7
      , 11, 12, 24, 23, 8, 55, 6, 7, 8, 7, 6, 7, 9, 8, 23, 8, 10, 12
      , 10, 12, 8, 10, 4, 76, 13, 188, 7, 28, 7, 8, 124, 11
      ]
-- | Decoding tree for the prefix-coded part of a copy distance (symbols 0 to
-- 63, code lengths 2 to 8), built by 'createBinaryTree'.
distanceTree :: Tree Int
distanceTree = createBinaryTree 2 8 63 3 packedLengths
  where
    packedLengths = [248, 151, 247, 230, 53, 20, 2]
-- | Decoding tree for copy-length symbols (code lengths 2 to 7), built by
-- 'createBinaryTree'.  The decoded symbol indexes 'sizeBase' and 'sizeExtra'.
sizeTree :: Tree Int
sizeTree = createBinaryTree 2 7 16 3 packedLengths
  where
    packedLengths = [23, 38, 53, 36, 35, 2]
-- | Base copy length for each of the 16 length symbols; the value of the
-- extra bits (see 'sizeExtra') is added on top of it.
sizeBase :: Map.Map Int Int
sizeBase = Map.fromList (zip [0 ..] baseLengths)
  where
    baseLengths = [3, 2, 4, 5, 6, 7, 8, 9, 10, 12, 16, 24, 40, 72, 136, 264]
-- | Number of extra bits that follow each of the 16 length symbols.
sizeExtra :: Map.Map Int Int
sizeExtra = Map.fromList (zip [0 ..] extraBitCounts)
  where
    extraBitCounts = [0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8]
-- | Build the decoding tree of a prefix code from a packed code-length table.
--
-- Arguments, in order:
--
--   * the shortest and the longest code length to assign codes for;
--   * the highest symbol number (symbols are numbered from 0);
--   * the code given to the first symbol of the shortest length;
--   * the packed table: every entry describes a run of symbols, the low
--     nibble being the code length and the high nibble the run length minus
--     one.  The LAST entry describes the lowest-numbered symbols.
--
-- Within one code length, codes are handed out in descending order to
-- ascending symbols.  Moving on to the next length, the next unused code is
-- shifted left by one and its lowest bit is set.  Symbols whose length lies
-- outside the given range get no code.
--
-- Returns 'Empty' when the symbol count is negative or the table is empty.
createBinaryTree :: Int -> Int -> Int -> Int -> [Int] -> Tree Int
createBinaryTree minCodeLen maxCodeLen number startCode list
  | number < 0 || null list = Empty
  | otherwise = buildBinaryTree (Map.fromList codeTable)
  where
    -- One code length per symbol, lowest symbol first.
    codeLengths :: [Int]
    codeLengths = concatMap unpackRun (reverse list)

    unpackRun :: Int -> [Int]
    unpackRun packed = replicate ((packed `shiftR` 4) + 1) (packed .&. 15)

    symbolLengths :: [(Int, Int)]
    symbolLengths = zip [0 .. number] codeLengths

    -- Symbols grouped by code length, shortest length first.
    symbolsByLength :: [(Int, [Int])]
    symbolsByLength =
      [ (codeLen, [symbol | (symbol, len) <- symbolLengths, len == codeLen])
      | codeLen <- [minCodeLen .. maxCodeLen]
      ]

    codeTable :: [(Int, [Bool])]
    codeTable = assignCodes startCode symbolsByLength

    -- Give every symbol of a group its code, counting down from the group's
    -- first code, then continue with the next (longer) group.
    assignCodes :: Int -> [(Int, [Int])] -> [(Int, [Bool])]
    assignCodes _ [] = []
    assignCodes firstCode ((codeLen, symbols) : longer) =
      zip symbols (map (codeBits codeLen) [firstCode, firstCode - 1 ..])
        ++ assignCodes nextFirstCode longer
      where
        nextFirstCode = ((firstCode - length symbols) `shiftL` 1) .|. 1

    -- The lowest @codeLen@ bits of a code, most-significant bit first.
    codeBits :: Int -> Int -> [Bool]
    codeBits codeLen code =
      [testBit code position | position <- [codeLen - 1, codeLen - 2 .. 0]]
-- | Turn a table mapping symbols to their codes (as bit lists) into a
-- decoding 'Tree'.
--
-- The root is always a 'Node'.  At every node the codes are split by their
-- next bit: codes continuing with @True@ go to the first child, codes
-- continuing with @False@ to the second.  A branch with no code becomes
-- 'Empty'; a branch with exactly one code becomes the 'Value' of the
-- lowest-numbered symbol carrying that code; a branch with several codes is
-- split further.
buildBinaryTree :: Map.Map Int [Bool] -> Tree Int
buildBinaryTree symbolMap = grow [] (Map.elems symbolMap)
  where
    grow prefix codes = Node (branch True) (branch False)
      where
        branch bit =
          let extended = prefix ++ [bit]
           in case filter (extended `isPrefixOf`) codes of
                [] -> Empty
                [only] -> leafFor only
                several -> grow extended several

    -- Look the code up again to recover its symbol; every code passed in
    -- comes from the map itself, so the error branch is not expected to run.
    leafFor code =
      case [symbol | (symbol, bits) <- Map.toList symbolMap, bits == code] of
        (symbol : _) -> Value symbol
        [] -> error (show code)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment