Tables and shapes for message-passing algorithms
================================================
The Easy BP package provides a number of MATLAB classes to make
implementation of various message passing algorithms in MATLAB much
simpler. It is provided as a research and teaching tool, focused
specifically on ease-of-use.
The idea is the following: in message-passing algorithms for discrete
graphical models, the basic data structures are potentials and
messages. These are tables of numbers that live on subsets of
variables. For example, in loopy BP for MRFs, we have beliefs
b_ij(x_i,x_j), b_i(x_i)
and messages
m_ji(x_i)
for all edges (ij). b_ij(x_i,x_j) is a table of |X_i||X_j| numbers,
where |X_i| is the cardinality of the domain X_i of the random
variable x_i, while b_i(x_i), m_ji(x_i) are vectors of length |X_i|
associated with variable x_i.
Further, the operations we perform on these data structures are
elementwise operations (e.g. multiplication), or tensor type
operations (sum-prod, max-prod). These operations respect the
correspondence of node indices between tables. For example,
b_ij(x_i,x_j) * b_i(x_i) * b_j(x_j)
should be a table of numbers where the (x_i,x_j) entry is as given
above. If we implement b_ij as a matrix in MATLAB, and b_i, b_j as
column vectors, then to implement the multiplication above we'd have to
compute:
b_ij .* ( b_i * b_j')
Now imagine that we'd like to compute
b_c(x_c) * b_d(x_d)
or sum_{x_d} b_c(x_c) * b_d(x_d)
where c and d are some subsets of nodes. If we implement these
tables as multidimensional arrays in MATLAB, we'd have to do some
complicated reshape, repmat, and permute operations.
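To make the bookkeeping concrete, here is a minimal C sketch of the product b_c(x_c) * b_d(x_d) on plain arrays. It is not the package's code: dense_table, MAX_VARS, offset_in and table_product are names made up for this illustration, and entries are assumed to be stored with the first variable varying fastest, as in a MATLAB array.

```c
#include <assert.h>
#include <stddef.h>

#define MAX_VARS 8   /* hypothetical limit, for this sketch only */

/* A dense table over nvars variables. vars[k] is a global variable id,
   card[k] its number of states. The first variable varies fastest. */
typedef struct {
    int nvars;
    int vars[MAX_VARS];
    int card[MAX_VARS];
    const double *data;
} dense_table;

/* Linear offset into t for a joint state indexed by global variable id. */
static size_t offset_in(const dense_table *t, const int *state)
{
    size_t off = 0, stride = 1;
    for (int k = 0; k < t->nvars; k++) {
        off += stride * (size_t)state[t->vars[k]];
        stride *= (size_t)t->card[k];
    }
    return off;
}

/* out(x_u) = a(x_c) * b(x_d), where u is the union of c and d.
   uvars/ucard describe u; out must hold prod(ucard) entries. */
void table_product(const dense_table *a, const dense_table *b,
                   int nu, const int *uvars, const int *ucard, double *out)
{
    int state[MAX_VARS] = {0};   /* current state of each global variable */
    size_t total = 1;

    assert(nu <= MAX_VARS);
    for (int k = 0; k < nu; k++) {
        assert(uvars[k] >= 0 && uvars[k] < MAX_VARS);
        total *= (size_t)ucard[k];
    }

    for (size_t i = 0; i < total; i++) {
        /* Decode i into one state per union variable (first fastest). */
        size_t rest = i;
        for (int k = 0; k < nu; k++) {
            state[uvars[k]] = (int)(rest % (size_t)ucard[k]);
            rest /= (size_t)ucard[k];
        }
        out[i] = a->data[offset_in(a, state)] * b->data[offset_in(b, state)];
    }
}
```

Each table reads only the states of its own variables, so no explicit reshape or repmat is needed. The sum over x_d would accumulate these products into a smaller output instead of storing each one.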
Another useful operation for message-passing algorithms is
normalization. For the algorithms to be numerically stable,
it is often necessary to normalize the messages so that either the
entries sum to 1 or the maximum entry is 1. Further, we'd often
like to keep track of the normalization constant (e.g. for
junction tree the normalization constant gives you the value of the
partition function). However, we'd like to keep track of these
constants in the log domain, since they can often take on values that
are too large or too small for double precision arithmetic.
This package implements shapes and tables, two structures in C with MATLAB
object-oriented interfaces, that will make things much simpler for
the operations we're interested in. The above two examples are simply
done in MATLAB as
b_c .* b_d
and b_c * b_d
which are very succinct and intuitive. Some other examples:
mathematical operations:          MATLAB expressions:
max_{x_e} b_c(x_c) * b_d(x_d)     maxprod(b_c,b_d,e)
prod_c b_c(x_c)                   times(B{:})   where B is a cell of tables
b_c(x_c) / sum_{x_c} b_c(x_c)     b_c ./ sum(b_c,c)
Further, the normalization constants are automatically tracked in the
log domain so we need not worry about numerical overflows and underflows.
I'll first describe how to install the system and run a simple demo,
then describe shapes and tables as MATLAB objects, and then the C
interface. Finally, I'll describe what I actually plan to do with
this codebase.
Installing and first example
============================
After extracting the files, the directory should contain 6
subdirectories: @index, @shape, @table, CFiles, Networks and Utilities.
CFiles contains the files for the C implementation, while @index, @shape
and @table are the MATLAB object interfaces. Networks contains some code
to generate some simple MRFs, and Utilities contains code to build and test
the system.
First run
initpath
to add Networks and Utilities into your path. Run
buildsystem
in the main directory to compile the object and mex files, which will
automatically be put into @index, @shape and @table.
Run
testsystem
to test the system on a forward pass of a 10-node HMM. This should
produce output like the following:
Test the system by calculating log probability of observations in a
HMM using both brute force and forward-backward (actually just forward
suffices).
Usual MATLAB code (without message normalization):
log probability = -7.486
Brute force:
log probability = -7.486
Forward pass using tables:
log probability = -7.486
The first run uses ordinary MATLAB code for the forward pass, the second
multiplies all potentials together into one big table, and the third does
the forward pass using tables.
For more information on the index, shape and table classes, type
help index
help shape
help table
in MATLAB.
C interface and intricacies
===========================
There is a C interface to most of the functions I described above.
See shape.h and table.h in the CFiles directory. They are pretty much
self-explanatory.
Most operators have been split into two functions, table_operator
and table_basic_operator. The table_operator functions take in
arguments and return a newly created table containing the result.
They do this by allocating an output of the right size and calling the
table_basic_operator function, which actually does the computations.
The table_basic_operator function does not allocate any memory
itself. The idea is that if you were to implement
some message-passing algorithm in C, since the number and types
of messages and potentials are fixed, you'd allocate memory for
these once, and just use the table_basic_operator functions, which
perform the computations for you without the trouble of allocating
and deallocating memory.
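The split can be pictured with the following C sketch. It only mirrors the pattern described above. The flat_table type and the functions flat_basic_times, flat_times and flat_free are hypothetical stand-ins, not the declarations in table.h, and the iterator argument is left out.

```c
#include <assert.h>
#include <stdlib.h>

/* Hypothetical stand-in for a table: a flat array of n entries. */
typedef struct {
    size_t n;
    double *v;
} flat_table;

/* "Basic" variant: writes into caller-owned storage, never allocates. */
void flat_basic_times(flat_table *out, const flat_table *a, const flat_table *b)
{
    assert(out->n == a->n && a->n == b->n);
    for (size_t i = 0; i < out->n; i++) out->v[i] = a->v[i] * b->v[i];
}

/* Convenience variant: allocates the result, then delegates to the
   basic one. Returns NULL if allocation fails. */
flat_table *flat_times(const flat_table *a, const flat_table *b)
{
    flat_table *out = malloc(sizeof *out);
    if (!out) return NULL;
    out->n = a->n;
    out->v = malloc(a->n * sizeof *out->v);
    if (!out->v) {
        free(out);
        return NULL;
    }
    flat_basic_times(out, a, b);
    return out;
}

/* Release a table created by flat_times. */
void flat_free(flat_table *t)
{
    if (t) {
        free(t->v);
        free(t);
    }
}
```

In an inner loop one would call only the basic variant on buffers allocated up front, for example flat_basic_times(&m, &m, &psi) to update a message in place.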
The table_basic_operator functions all have a last argument "ITERATOR
tt". See iterator.h for the definition of this structure. As
table_basic_operator steps through the entries in each table when
doing its computations, the iterator controls how the entries are
stepped through for the output and the input tables.
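I have not reproduced the real ITERATOR structure here. As a guess at the general technique, the C sketch below walks an output table and an input table in lockstep, odometer style, using per-variable strides, with stride 0 for a variable the input does not depend on. The names walk_iter, walk_reset, walk_next and walk_times_inplace are hypothetical.

```c
#include <assert.h>
#include <stddef.h>

#define IT_MAXDIM 8   /* hypothetical limit, for this sketch only */

typedef struct {
    int ndim;                       /* number of output variables */
    int card[IT_MAXDIM];            /* states per output variable */
    size_t out_stride[IT_MAXDIM];   /* step in the output per variable */
    size_t in_stride[IT_MAXDIM];    /* step in the input, 0 if absent */
    int state[IT_MAXDIM];           /* current joint state */
    size_t out_off, in_off;         /* current linear offsets */
} walk_iter;

/* Go back to the first entry. */
void walk_reset(walk_iter *it)
{
    assert(it->ndim <= IT_MAXDIM);
    for (int k = 0; k < it->ndim; k++) it->state[k] = 0;
    it->out_off = 0;
    it->in_off = 0;
}

/* Advance one entry. Returns 0 once every entry has been visited. */
int walk_next(walk_iter *it)
{
    for (int k = 0; k < it->ndim; k++) {
        it->out_off += it->out_stride[k];
        it->in_off += it->in_stride[k];
        if (++it->state[k] < it->card[k]) return 1;
        /* Digit k wrapped around: undo its contribution and carry. */
        it->out_off -= it->out_stride[k] * (size_t)it->card[k];
        it->in_off -= it->in_stride[k] * (size_t)it->card[k];
        it->state[k] = 0;
    }
    return 0;
}

/* out(x) *= in(x restricted to the input's variables). */
void walk_times_inplace(walk_iter *it, double *out, const double *in)
{
    walk_reset(it);
    do {
        out[it->out_off] *= in[it->in_off];
    } while (walk_next(it));
}
```

For a 2 x 3 table over (x_i, x_j) multiplied by a vector over x_j, the output strides would be {1, 2} and the input strides {0, 1}.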
Most of the operators are defined in their own C files, which are
mostly trivial. The main bits of the code actually reside in
table_op_1arg.c, table_op_2arg.c, table_op_marg.c, and table_op_tensor.c.
These define macros which are expanded out by the operator C files,
substituting in different bits of code for the different behaviours
of the operators.
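The macro approach looks roughly like the C sketch below, where one template is expanded several times with a different per-entry expression. This is only an illustration of the technique. DEFINE_BINARY_OP and the demo_ functions are hypothetical names, and the real macros in the table_op_*.c files handle much more than flat arrays.

```c
#include <assert.h>
#include <stddef.h>

/* Template macro: expands to an elementwise binary operator called NAME
   whose per-entry result is EXPR, written in terms of x and y. */
#define DEFINE_BINARY_OP(NAME, EXPR)                                   \
    void NAME(double *out, const double *a, const double *b, size_t n) \
    {                                                                  \
        assert(out && a && b);                                         \
        for (size_t i = 0; i < n; i++) {                               \
            double x = a[i], y = b[i];                                 \
            out[i] = (EXPR);                                           \
        }                                                              \
    }

/* Each "operator file" then reduces to a single expansion. */
DEFINE_BINARY_OP(demo_plus, x + y)
DEFINE_BINARY_OP(demo_times, x * y)
DEFINE_BINARY_OP(demo_max, x > y ? x : y)
```

The loop and argument handling are written once, and each operator supplies only the bit of code that differs.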
table_op_1arg defines the mex function for the unary (1 argument)
operators (zeros, ones, uminus, sumnorm, maxnorm, and maxabsnorm).
The actual code for the unary operators resides in table_basic.c.
uplus is just an m-file since it just returns the input unchanged.
table_op_2arg defines both the C code and mex function for the
binary (2 argument) operators. There are three types of such
operators: comparative (==, ~=, <, >, <=, >=), additive (+, -, max,
min), and multiplicative (.*, ./, .\). These types differ in how
the zz value is treated. .^ is not treated this way and its basic
code is given in table_power.c, since it is trickier to deal with
and there is only one such function (no point defining macros and
expanding them out).
table_op_marg defines both the C code and mex function for the
marginalization operators (sum, prod, maxmarg, minmarg). Again
these are split into additive ones (sum, maxmarg, minmarg) and
multiplicative ones (prod).
table_op_tensor defines both the C code and mex function for the
tensor operators. These are split into additive-additive types
(maxsum, minsum) and additive-multiplicative (sumprod, maxprod)
types.
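As a tiny example of what the additive-multiplicative tensor operators compute, the C sketch below handles the pairwise case only: it combines a table psi(x_i,x_j) with a vector m(x_i) and reduces over x_i by either sum or max. The names reduce_kind and pair_tensor_prod are hypothetical, and psi is assumed to be stored with x_i varying fastest.

```c
#include <assert.h>
#include <stddef.h>

typedef enum { RED_SUM, RED_MAX } reduce_kind;

/* out[j] = reduce over i of psi(i,j) * m[i], where psi(i,j) is stored
   at psi[i + ni * j] and the reduction is a sum or a max. */
void pair_tensor_prod(reduce_kind kind, const double *psi, const double *m,
                      size_t ni, size_t nj, double *out)
{
    assert(ni > 0);
    for (size_t j = 0; j < nj; j++) {
        /* Start from the first term so that max also works when
           every product is negative. */
        double acc = psi[ni * j] * m[0];
        for (size_t i = 1; i < ni; i++) {
            double term = psi[i + ni * j] * m[i];
            if (kind == RED_SUM)
                acc += term;
            else if (term > acc)
                acc = term;
        }
        out[j] = acc;
    }
}
```

With RED_SUM this is the sum-product message update along one edge, and with RED_MAX the max-product one. Swapping the multiplication for an addition would give the additive-additive (maxsum) flavour.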
Future work
===========
The future work is quite obvious: actually implement the various
message-passing algorithms! With this code this should be very
easy to do, so that we may concentrate our efforts on the algorithms
themselves rather than on the implementation. Some obvious candidates
are: junction tree, brute force, loopy BP, various region graph
message-passing algorithms, Tom Heskes et al.'s algorithm represented
using primary variables rather than the dual messages and the
convergent versions, treeEP and other discrete EP, and Martin
Wainwright's convergent versions.
After that, the next step would be to deal with conditional models
like CRFs, and think of learning such models as well.