Wednesday, August 4, 2010

Huffman coding

I've just completed a small project in C#: Huffman coding. For no particular reason, other than to keep my brain occupied. There's something oddly satisfying about making code do what it is supposed to, even if, as in this case, it is not very useful. The code does a poor job of saving disk space; Serialize() is way too greedy for that. I tried half-heartedly to save space by using a base class with only the fields essential for decoding, but it was not all that effective. It did prove to be a useful exercise, though.

I couldn't find the motivation to code up something smarter; the real challenge was keeping track of the bits, which I think I got right. At least it decodes whatever I throw at it correctly. (Note that "whatever" in this context is limited to 3 random files that happened to have a short path name on my computer, so I suppose the testing hasn't been all that thorough. Which doesn't matter, really, as it is pretty useless. :))

Huffman coding is in principle quite simple. Count how many times each symbol occurs in the input, then assign short codes to the most common symbols and long codes to the least common ones. There are a number of pages describing the exact mechanisms, but in short it builds a binary tree out of the symbols by repeatedly merging the two least frequent nodes under a new parent node, so that the path from the root to a common symbol is short and the path to a rare symbol is long. When decoding, start at the root node, read a bit from the input stream, and take the left child if it is 0, or the right child if it is 1. Get another bit from the input stream, and keep going until you hit a node with no children, which means you've arrived at the symbol the code represents. Output the symbol, go back to the root node, and get another bit from the input. Quite simple, really.
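To make the build-and-walk description concrete, here is a minimal sketch that works on a string instead of a file and uses strings of '0'/'1' characters instead of packed bits. HuffNode and HuffmanSketch are hypothetical names for illustration only; the real program further down works on bytes and differs in the details. Like the program, it assumes the input contains at least two distinct symbols.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

// Hypothetical node type for this sketch.
class HuffNode {
    public char Symbol;
    public long Count;
    public HuffNode Left, Right;
    public bool IsLeaf => Left == null && Right == null;
}

static class HuffmanSketch {
    // Count each symbol, then repeatedly merge the two least frequent nodes
    // under a new parent until a single root remains.
    public static HuffNode BuildTree(string input) {
        var nodes = input.GroupBy(c => c)
                         .Select(g => new HuffNode { Symbol = g.Key, Count = g.Count() })
                         .ToList();
        while (nodes.Count > 1) {
            nodes.Sort((a, b) => a.Count.CompareTo(b.Count));
            var parent = new HuffNode {
                Count = nodes[0].Count + nodes[1].Count,
                Left = nodes[0],
                Right = nodes[1]
            };
            nodes.RemoveRange(0, 2);
            nodes.Add(parent);
        }
        return nodes[0];
    }

    // A left edge appends '0' to the code, a right edge appends '1'.
    public static void MakeCodes(HuffNode n, string prefix, Dictionary<char, string> codes) {
        if (n.IsLeaf) { codes[n.Symbol] = prefix; return; }
        MakeCodes(n.Left, prefix + "0", codes);
        MakeCodes(n.Right, prefix + "1", codes);
    }

    public static string Encode(string input, Dictionary<char, string> codes) {
        var sb = new StringBuilder();
        foreach (char c in input) sb.Append(codes[c]);
        return sb.ToString();
    }

    // Start at the root, follow one bit at a time, and emit a symbol
    // (then restart at the root) whenever a leaf is reached.
    public static string Decode(string bits, HuffNode root) {
        var sb = new StringBuilder();
        HuffNode current = root;
        foreach (char bit in bits) {
            current = bit == '0' ? current.Left : current.Right;
            if (current.IsLeaf) { sb.Append(current.Symbol); current = root; }
        }
        return sb.ToString();
    }
}
```

For "abracadabra" (a×5, b×2, r×2, c×1, d×1) the merges have weights 2, 4, 6 and 11, and the encoded length is their sum, 23 bits, compared with 88 bits at one byte per character. Ties between equal counts can produce different codes, but the total length stays the same.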

There are a number of scenarios that will make the code fail: if the input consists of a single repeated symbol, if the input is empty, and a bunch of others I haven't considered.
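The single-symbol case fails because the root of the tree is itself a leaf, so the symbol gets a zero-length code and nothing is written for it. One common workaround, which the program below does not implement, is to give the lone symbol a dummy sibling so the root always has two children. A minimal sketch, using a hypothetical SimpleNode type; note that the padding 0's at the end would then decode as extra copies of the symbol, so this only works together with storing the original length:

```csharp
using System;

// Hypothetical node type for this sketch.
class SimpleNode {
    public byte Value;
    public ulong Count;
    public SimpleNode Left, Right;
    public bool IsLeaf => Left == null && Right == null;
}

static class SingleSymbolGuard {
    // If the whole tree is one leaf, put it under a new root next to a
    // zero-count dummy leaf, so the real symbol gets the 1-bit code "0".
    public static SimpleNode EnsureRootHasChildren(SimpleNode root) {
        if (!root.IsLeaf)
            return root;
        var dummy = new SimpleNode { Value = unchecked((byte)(root.Value + 1)), Count = 0 };
        return new SimpleNode { Count = root.Count, Left = root, Right = dummy };
    }

    // Depth of the leaf holding 'value' (its code length), or -1 if absent.
    public static int CodeLength(SimpleNode n, byte value, int depth = 0) {
        if (n == null) return -1;
        if (n.IsLeaf) return n.Value == value ? depth : -1;
        int d = CodeLength(n.Left, value, depth + 1);
        return d >= 0 ? d : CodeLength(n.Right, value, depth + 1);
    }
}
```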

It could be interesting to make the symbol size variable. Currently the code works on bytes, but in principle it could work on any fixed number of bits. The tree would be a lot larger, so finding a better way to store it on disk than Serialize() would definitely be required. Also, I suppose the compression depends on some symbols being far more prevalent than others; I'm not sure a 16-bit symbol size, for example, would work. Interesting, though.
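One compact alternative to Serialize(), sketched here under my own assumptions rather than taken from the program, is to write the tree in pre-order: a 0 bit for an internal node, followed by its left and right subtrees, and a 1 bit followed by the 8 bits of the symbol for a leaf. A tree with n leaves has n - 1 internal nodes, so this takes (2n - 1) + 8n bits, about 320 bytes even when all 256 byte values occur. TreeNode and TreeCodec are hypothetical names, and a List<bool> stands in for a real bit stream:

```csharp
using System;
using System.Collections.Generic;

// Hypothetical node type for this sketch.
class TreeNode {
    public byte Value;
    public TreeNode Left, Right;
    public bool IsLeaf => Left == null && Right == null;
}

static class TreeCodec {
    // Pre-order: false = internal node (then left and right subtrees),
    // true = leaf (then the 8 bits of its value, most significant first).
    public static void Write(TreeNode n, List<bool> bits) {
        if (n.IsLeaf) {
            bits.Add(true);
            for (int i = 7; i >= 0; i--)
                bits.Add(((n.Value >> i) & 1) == 1);
        } else {
            bits.Add(false);
            Write(n.Left, bits);
            Write(n.Right, bits);
        }
    }

    // Rebuilds the tree; pos is advanced past every bit that is consumed.
    public static TreeNode Read(List<bool> bits, ref int pos) {
        if (bits[pos++]) {
            int v = 0;
            for (int i = 0; i < 8; i++)
                v = (v << 1) | (bits[pos++] ? 1 : 0);
            return new TreeNode { Value = (byte)v };
        }
        var node = new TreeNode();
        node.Left = Read(bits, ref pos);
        node.Right = Read(bits, ref pos);
        return node;
    }
}
```

Packing the bools into bytes would use the same kind of bit buffer that encode() already has.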

The only reason I'm publishing this is that I've recently gotten into mathematics, and writing about a subject makes me think about it in a more structured way, which aids my learning. So I've planned a series of posts, starting with Huffman coding. The plan is to move on to Markov chains, then a hybrid Huffman-Markov coding, then Hidden Markov Models. If I ever get around to writing them...

Anyway, the code. For what it's worth. 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.IO;
using System.Runtime.Serialization.Formatters.Binary;
using System.Runtime.InteropServices;

namespace Huffman {
    class Program {
        static StreamReader finput;
        static StreamWriter foutput;

        // This class contains fields required when decoding.
        [Serializable]
        class DecodeNode {
            public byte value;                      // Which symbol the node represents
            public DecodeNode left = null;          // Left child-node
            public DecodeNode right = null;         // Right child-node
        }

        // This class also contains definitions only required when encoding
        class Node : DecodeNode, IComparable<Node> {
            public ulong count;             // Number of times the symbol occurs in input-file
            public uint code = 0;           // Code for this symbol
            public uint bitcount = 0;       // Number of bits in the code
            public bool isLeaf = false;     // True for leaf nodes, i.e. nodes that hold a symbol

            // This will allow us to compare the frequencies of other values easily.
            // Ends up with "lowest to highest"
            public int CompareTo(Node other) {
                return (count.CompareTo(other.count));
            }

            // There MUST be a more elegant way of losing the "extra info" in Node before serializing..
            public DecodeNode ToDecodeNode() {
                DecodeNode d = new DecodeNode();

                if (this.isLeaf) {
                    d.value = this.value;
                } else {
                    d.left = ((Node)this.left).ToDecodeNode();
                    d.right = ((Node)this.right).ToDecodeNode();
                }

                return (d);
            }

            // Count the number of nodes in the (sub)tree, including this one. Not required, just for info
            public int TreeCount() {
                int i = 1;

                if (left != null)
                    i += ((Node)left).TreeCount();
                if (right != null)
                    i += ((Node)right).TreeCount();

                return (i);
            }
        }

        // Used for the code-table, when encoding
        class CodeEntry {
            public uint code;
            public byte bitcount;
        }

        // Helper for encode().
        // Recursively traverses the tree, building the code table as it goes.
        // Assumes codes of at most 32 bits (the width of CodeEntry.code); returns a 256-entry code table for use in encoding.
        static CodeEntry[] makecodes(Node v) {

            // Set up encode table
            CodeEntry[] ct = new CodeEntry[256];

            // This overload is only called for the root, so the code and its length both start at 0
            makecodes(ct, v, 0, 0);

            return (ct);
        }

        static void makecodes(CodeEntry[] ctable, Node v, uint code, byte bitcount) {
            if (v.isLeaf) {
                CodeEntry e = new CodeEntry();
                e.bitcount = bitcount;
                e.code = code;

                ctable[v.value] = e;
                return;
            }

            bitcount++;         // Code is 1 bit longer
            code <<= 1;         // Left-shift 1 bit

            // Left child
            if (v.left != null)
                makecodes(ctable, (Node)v.left, code, bitcount);

            // Right child
            code++; // Make last bit 1
            if (v.right != null)
                makecodes(ctable, (Node)v.right, code, bitcount);
        }

        static Node maketree(List<Node> values) {

            // Remove all values where count == 0
            values.RemoveAll(delegate(Node n) { return n.count == 0; });

            // Now build a tree. 
            // Remove the two least common values, add a new node which is NOT a leaf node. Frq = count of node 1 + count of node 2
            // Insert the new node into the list. Do until only one node is left.
            Node node, node1, node2;
            while (values.Count > 1) {
                // Sort the array so that the LEAST common values are "on top". 
                values.Sort();

                // Grab 2 least frequent nodes from the list
                node1 = values[0];
                node2 = values[1];
                values.RemoveRange(0, 2);

                // Create a parent node for these
                node = new Node();
                node.count = node1.count + node2.count;
                node.left = node1;
                node.right = node2;

                // Add new parent node to list
                values.Add(node);
            }

            return (values[0]);
        }

        // For debugging
        static void dumptree(DecodeNode n, string level) {
            level += ".";
            if ((n.left == null) && (n.right == null))
                Console.WriteLine(level + "Val: " + n.value);
            else {
                Console.WriteLine(level);
                dumptree(n.left, level);
                dumptree(n.right, level);
            }
        }

        static void encode() {
            // The list never holds more than 256 nodes: it starts with one leaf per byte value,
            // and maketree() only ever shrinks it.
            List<Node> values = new List<Node>(256);

            // Add the leaves so that values[i] is the node for byte value i
            for (int i = 0; i < 256; i++) {
                Node v = new Node();
                v.value = (byte)i;
                v.isLeaf = true;
                values.Add(v);
            }

            Stream fin = finput.BaseStream;
            int b;

            // Count each byte. The values are at their offsets in the list
            while ((b = fin.ReadByte()) != -1)
                values[(byte)b].count++;

            Node root = maketree(values);

            #region Dump tree to disk
            // This is a pretty inefficient use of disk space. Serialize() uses a lot of space.

            DecodeNode decoderoot = root.ToDecodeNode();    // Convert to smaller objects, saves disk space in the output.

            // decoderoot contains all nodes, with enough information to recreate the tree at decode-time.
            Stream fout = foutput.BaseStream;
            BinaryFormatter binfmt = new BinaryFormatter();
            binfmt.Serialize(fout, decoderoot);
            decoderoot = null; 
            #endregion

            // Create codes
            CodeEntry[] codeTable = makecodes(root);

            // 64-bit bit buffer: at most 7 leftover bits plus a code of at most 32 bits always fits.
            ulong buffer = 0;
            int bits_in_buffer = 0;

            fin.Seek(0, SeekOrigin.Begin);    // Start from beginning

            // Now encode the entire stream.
            CodeEntry e;
            while ((b = fin.ReadByte()) != -1) {

                // Pull relevant entry from the codelist
                e = codeTable[b];

                // Codes longer than 32 bits do not fit in CodeEntry.code, so stop instead of writing garbage
                if (e.bitcount > 32)
                    throw new InvalidOperationException("Code longer than 32 bits in encode()");

                // Add new code to buffer. Left-shift the code so it lands right after whatever is already in the buffer
                bits_in_buffer += e.bitcount;
                buffer |= ((ulong)e.code << (64 - bits_in_buffer));

                // Write complete bytes to output, most significant byte first
                while (bits_in_buffer >= 8) {
                    fout.WriteByte((byte)(buffer >> 56));   // The top 8 bits of the 64-bit buffer
                    buffer <<= 8;                           // Shift the whole buffer 8 bits to the left
                    bits_in_buffer -= 8;
                }
            }

            // Write out any bits left in the buffer; there will be fewer than 8.
            // There is a chance decode() will output extra characters at the end of the output,
            // if a leaf node is reached within the padding 0's. This can be solved by inserting
            // a virtual "EOF node" in the tree, or simply by writing the file size into the encoded file and stopping
            // decoding when the correct number of bytes has been written.
            if (bits_in_buffer != 0)
                fout.WriteByte((byte)(buffer >> 56));

            foutput.Close();
            return;
        }

        static void decode() {
            DecodeNode root;

            Stream fout = foutput.BaseStream;
            Stream fin = finput.BaseStream;

            // Recreate the Huffman tree
            BinaryFormatter binfmt = new BinaryFormatter();
            root = (DecodeNode)binfmt.Deserialize(fin);

            // Read through the whole file, byte by byte, decoding as we go
            int v;
            byte val;

            DecodeNode current = root;
            while ((v = fin.ReadByte()) != -1) {
                val = (byte)v;

                // Do each bit, left to right
                for (int i = 0; i < 8; i++) {

                    // Walk the tree
                    if ((val & 128) == 128)  // Check if MSB is 1
                        current = current.right;
                    else
                        current = current.left;

                    // If we are at a leaf-node, output the value, and reset to root of tree for the next bits.
                    if ((current.left == null) && (current.right == null)) {
                        fout.WriteByte(current.value);
                        current = root;
                    }

                    // Shift next bit into position
                    val <<= 1;
                }
            }

            fout.Close();

            return;
        }

        static void usage() {
            System.Console.WriteLine("Usage:");
            System.Console.WriteLine("Huffman [-c|-d] <in-file> <out-file>");
            System.Console.WriteLine("-c    encode <in-file>, write result in <out-file> ");
            System.Console.WriteLine("-d    decode <in-file>, write result in <out-file> ");
        }

        static void Main(string[] args) {
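            // Validate the arguments before opening any files, so a mistyped flag does not truncate the output file
            if (args.Length != 3 || (args[0] != "-c" && args[0] != "-d")) {
                usage();
                return;
            }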

            // Open input stream
            try {
                finput = new StreamReader(args[1]);
            } catch (Exception e) {
                System.Console.WriteLine(args[1] + " is not a valid input-file:\n" + e.Message);
                usage();
                return;
            }

            // Open output stream
            try {
                foutput = new StreamWriter(args[2]);
            } catch (Exception e) {
                System.Console.WriteLine(args[2] + " is not a valid output-file:\n" + e.Message);
                usage();
                return;
            }

            if (args[0].Equals("-c"))
                encode();
            else if (args[0].Equals("-d"))
                decode();
            else
                usage();
        }
    }
}
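The comment at the end of encode() suggests writing the file size into the encoded file so decode() knows when to stop, instead of decoding the padding bits. A minimal sketch of that idea, assuming an 8-byte length in the machine's native byte order (via BitConverter) written right after the tree; LengthHeader, WriteLength and ReadLength are hypothetical names, not part of the program above:

```csharp
using System;
using System.IO;

static class LengthHeader {
    // Encoder side: store the number of original bytes right after the tree.
    public static void WriteLength(Stream output, long length) {
        byte[] buf = BitConverter.GetBytes(length);
        output.Write(buf, 0, buf.Length);
    }

    // Decoder side: read exactly 8 bytes back; Stream.Read may return fewer
    // bytes than requested, so keep reading until the header is complete.
    public static long ReadLength(Stream input) {
        byte[] buf = new byte[8];
        int read = 0;
        while (read < buf.Length) {
            int n = input.Read(buf, read, buf.Length - read);
            if (n == 0)
                throw new EndOfStreamException("Truncated length header");
            read += n;
        }
        return BitConverter.ToInt64(buf, 0);
    }
}
```

decode() would then count the bytes it writes and leave its loop as soon as that count reaches the stored length, so the padding 0's are never turned into symbols.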