Gibbs sampling on a normal mixture model

The model we’re trying to fit is \[f(x) = \sum_{k=1}^K p_k \varphi(x; \mu_k, \sigma_k^2)\] where \(K\) is fixed and the unknown parameters are \(\mu = (\mu_1,\ldots, \mu_K) \in \mathbb{R}^K\), \(\sigma^2 = (\sigma^2_1,\ldots, \sigma^2_K) \in (\mathbb{R}^+)^K\) and \(p = (p_1,\ldots, p_K) \in (0,1)^K\) such that \(\sum_{k=1}^K p_k = 1\). Here \(\varphi\) is the Gaussian density \[\varphi(x; \mu, \sigma^2) = \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left\{-\frac{1}{2\sigma^2}(x-\mu)^2\right\}.\] In this document we first generate a synthetic dataset from the model, given a fixed parameter value. We then run a Gibbs sampler to draw from the posterior distribution of the parameters given the dataset.
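
The Gibbs sampler exploits the standard latent-variable representation of this model: each observation \(y_i\) is assigned a component label \(z_i\), with \[\begin{aligned} z_i &\sim \text{Categorical}(p_1, \ldots, p_K)\\ y_i \mid z_i = k &\sim \mathcal{N}(\mu_k, \sigma_k^2) \end{aligned}\] so that, marginally over \(z_i\), each \(y_i\) has density \(f\).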

Let us generate a synthetic dataset from the mixture model: we first fix parameter values, then sample observations given those values.

set.seed(17)  # fix the RNG seed so the results are reproducible (arbitrary seed)
K <- 2  # number of components
p_0 <- rep(1/K, K)  # weight parameter
mu_0 <- seq(from = -K, to = K, length.out = K)  # mean parameter
sigma2_0 <- rep(1, K)  # variance parameter
N <- 100  # sample size

# First sample the cluster allocation for each data point.
z <- sample(1:K, N, replace = TRUE, prob = p_0)

# Now sample the observations conditional on the cluster allocations; indexing
# mu_0 and sigma2_0 by z keeps each y[i] paired with its label z[i].
y <- rnorm(N, mean = mu_0[z], sd = sqrt(sigma2_0[z]))
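
As an optional sanity check, the empirical cluster sizes and within-cluster sample means should be roughly \(N p_k\) and \(\mu_k\):

# Cluster sizes and within-cluster sample means; these should be close to
# N * p_0 and mu_0 respectively.
table(z)
tapply(y, z, mean)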

Plot the simulated sample.

library(ggplot2)
y.df <- data.frame(y)
names(y.df) <- "observations"
g <- ggplot(data = y.df, aes(x = observations, y = ..density..))
g <- g + geom_histogram()
print(g)

We now want to sample from the posterior distribution of the parameters given the dataset. Let’s set up the prior hyperparameters. The prior is taken to be Dirichlet for \(p\), Normal for each \(\mu_k\) and Inverse Gamma for each \(\sigma^2_k\). \[\begin{aligned} \pi(p) &\propto \prod_{k=1}^K p_k^{\gamma_k-1}\\ \forall k \in \{1, \ldots, K\} \quad \pi(\mu_k) &\propto \exp\left\{-\frac{1}{2\tau^2}\left(\mu_k - m\right)^2\right\}\\ \forall k \in \{1, \ldots, K\} \quad \pi(\sigma_k^2) &\propto (\sigma_k^2)^{-\alpha - 1}\exp\left\{-\beta \sigma_k^{-2}\right\} \end{aligned}\]

# Hyperparameters for p.
gamma <- rep(1,K)

# Hyperparameters for mu_k.
m <- 0
tau <- 1 

# Hyperparameters for sigma_k^2.
alpha <- 2
beta <- 1 

We need functions to sample from each full conditional distribution of the posterior.
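
By conjugacy, these full conditionals have standard forms. Writing \(n_k = \#\{i : z_i = k\}\) for the size of cluster \(k\): \[\begin{aligned} \mathbb{P}(z_i = k \mid \text{rest}) &\propto p_k \, \varphi(y_i; \mu_k, \sigma_k^2)\\ p \mid \text{rest} &\sim \text{Dirichlet}(\gamma_1 + n_1, \ldots, \gamma_K + n_K)\\ \mu_k \mid \text{rest} &\sim \mathcal{N}\left(\frac{\tau^{-2} m + \sigma_k^{-2} \sum_{i : z_i = k} y_i}{\tau^{-2} + n_k \sigma_k^{-2}}, \; \frac{1}{\tau^{-2} + n_k \sigma_k^{-2}}\right)\\ \sigma_k^2 \mid \text{rest} &\sim \text{InvGamma}\left(\alpha + \frac{n_k}{2}, \; \beta + \frac{1}{2} \sum_{i : z_i = k} (y_i - \mu_k)^2\right) \end{aligned}\] Note that if cluster \(k\) is empty, \(n_k = 0\) and the conditionals for \(\mu_k\) and \(\sigma_k^2\) reduce to their priors.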

# Conditional distribution of z given the rest.
zconditional <- function(y, p, mu, sigma2) {
    next_z <- rep(0, N)
    for (i in 1:N) {
        prob <- p * dnorm(y[i], mean = mu, sd = sqrt(sigma2))
        # Note that 'sample' doesn't require 'prob' to be normalized.
        next_z[i] <- sample(1:K, size = 1, prob = prob)
    }
    return(next_z)
}
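
For large \(N\), the per-observation loop can be vectorized. Here is a minimal sketch (the name zconditional_vectorized is ours) that computes the \(N \times K\) matrix of unnormalized log-weights and stabilizes them before exponentiating:

# Vectorized alternative to the loop above (a sketch). Row i of 'logw' holds
# the K unnormalized log-probabilities of the allocations for observation i.
zconditional_vectorized <- function(y, p, mu, sigma2) {
    logw <- sapply(1:K, function(k) log(p[k]) + dnorm(y, mean = mu[k], sd = sqrt(sigma2[k]), log = TRUE))
    w <- exp(logw - apply(logw, 1, max))  # subtract row maxima for numerical stability
    apply(w, 1, function(row) sample(1:K, size = 1, prob = row))
}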

# Conditional distribution of p given the rest.
pconditional <- function(y, z, mu, sigma2) {
    n_by_cluster <- sapply(X = 1:K, FUN = function(j) sum(z == j))
    # Normalizing independent Gamma(gamma_k + n_k, 1) draws yields a draw from
    # the Dirichlet(gamma + n) distribution.
    next_p <- rgamma(n = K, shape = gamma + n_by_cluster, rate = 1)
    return(next_p/sum(next_p))
}

# Conditional distribution of mu given the rest.
muconditional <- function(y, z, p, sigma2) {
    n_by_cluster <- sapply(X = 1:K, FUN = function(j) sum(z == j))
    sum_y_by_cluster <- sapply(X = 1:K, FUN = function(j) sum(y[z == j]))
    cond_mean <- (tau^(-2) * m + sum_y_by_cluster * sigma2^(-1))/(tau^(-2) + 
        n_by_cluster * sigma2^(-1))
    cond_var <- 1/(tau^(-2) + n_by_cluster * sigma2^(-1))
    next_mu <- rnorm(n = K, mean = cond_mean, sd = sqrt(cond_var))
    return(next_mu)
}

# Conditional distribution of sigma^2 given the rest.
sigma2conditional <- function(y, z, p, mu) {
    n_by_cluster <- sapply(X = 1:K, FUN = function(j) sum(z == j))
    sum_square_by_cluster <- sapply(X = 1:K, FUN = function(j) sum((y[z == j] - 
        mu[j])^2))
    cond_shape <- alpha + n_by_cluster/2
    cond_rate <- beta + sum_square_by_cluster/2
    next_sigma2 <- 1/rgamma(n = K, shape = cond_shape, rate = cond_rate)
    return(next_sigma2)
}

We can now implement the Gibbs sampler, taking as arguments the dataset and a number of iterations. Note that we initialize the chain by drawing from the prior distribution, but in principle any starting value in the support of the posterior would work.

GibbsSampler <- function(y, niterations) {
    # Initialize the Markov chain.
    current_p <- rgamma(n = K, shape = gamma, rate = 1)
    current_p <- current_p/sum(current_p)
    current_mu <- rnorm(n = K, mean = m, sd = tau)
    current_sigma2 <- 1/rgamma(n = K, shape = alpha, rate = beta)
    
    # The following matrices will store the entire trajectory of the Markov
    # chain.
    p_chain <- matrix(rep(current_p, niterations), nrow = niterations, byrow = TRUE)
    mu_chain <- matrix(rep(current_mu, niterations), nrow = niterations, byrow = TRUE)
    sigma2_chain <- matrix(rep(current_sigma2, niterations), nrow = niterations, 
        byrow = TRUE)
    
    # Run the Markov chain.
    for (t in 2:niterations) {
        if (t%%100 == 0) {
            cat("iteration ", t, "/", niterations, "\n")
        }
        current_z <- zconditional(y, current_p, current_mu, current_sigma2)
        current_p <- pconditional(y, current_z, current_mu, current_sigma2)
        current_mu <- muconditional(y, current_z, current_p, current_sigma2)
        current_sigma2 <- sigma2conditional(y, current_z, current_p, current_mu)
        p_chain[t, ] <- current_p
        mu_chain[t, ] <- current_mu
        sigma2_chain[t, ] <- current_sigma2
    }
    
    # Return p, mu, sigma^2 in a list of data frames.
    return(list(p = data.frame(p_chain), mu = data.frame(mu_chain), sigma2 = data.frame(sigma2_chain)))
}

Run the Gibbs sampler.

niterations <- 10000
chain <- GibbsSampler(y, niterations)
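
As a quick numerical summary, we can discard a burn-in period and average the remaining draws; this is a minimal sketch in which the burn-in length of 1,000 iterations is an arbitrary choice.

# Posterior means after discarding a burn-in period.
burnin <- 1000
colMeans(chain$p[(burnin + 1):niterations, ])
colMeans(chain$mu[(burnin + 1):niterations, ])
colMeans(chain$sigma2[(burnin + 1):niterations, ])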

Plot histograms of the posterior marginal distributions, with the prior probability density functions overlaid in red and a blue dashed vertical line indicating the parameter values used to generate the dataset.

plot_histogram <- function(df, component) {
    g <- ggplot(data = df)
    g <- g + geom_histogram(aes_string(x = paste0("X", component), y = "..density.."), 
        binwidth = 0.03)
    return(g)
}
library(gridExtra)  # for grid.arrange
gp1 <- plot_histogram(chain$p, 1) + xlab(expression(p[1])) + xlim(0, 1)
# Under the Dirichlet prior, the marginal distribution of p_k is
# Beta(gamma_k, sum(gamma) - gamma_k).
gp1 <- gp1 + stat_function(fun = function(x) dbeta(x, gamma[1], sum(gamma) - gamma[1]), 
    colour = "red")
gp1 <- gp1 + geom_vline(aes(xintercept = p_0[1]), linetype = 2, size = 2, colour = "blue")
gp2 <- plot_histogram(chain$p, 2) + xlab(expression(p[2])) + xlim(0, 1)
gp2 <- gp2 + stat_function(fun = function(x) dbeta(x, gamma[2], sum(gamma) - gamma[2]), 
    colour = "red")
gp2 <- gp2 + geom_vline(aes(xintercept = p_0[2]), linetype = 2, size = 2, colour = "blue")
grid.arrange(gp1, gp2)

gmu1 <- plot_histogram(chain$mu, 1) + xlab(expression(mu[1])) + xlim(-K - 1, 
    K + 1)
gmu1 <- gmu1 + stat_function(fun = function(x) dnorm(x, mean = m, sd = tau), 
    colour = "red")
gmu1 <- gmu1 + geom_vline(aes(xintercept = mu_0[1]), linetype = 2, size = 2, 
    colour = "blue")
gmu2 <- plot_histogram(chain$mu, 2) + xlab(expression(mu[2])) + xlim(-K - 1, 
    K + 1)
gmu2 <- gmu2 + stat_function(fun = function(x) dnorm(x, mean = m, sd = tau), 
    colour = "red")
gmu2 <- gmu2 + geom_vline(aes(xintercept = mu_0[2]), linetype = 2, size = 2, 
    colour = "blue")
grid.arrange(gmu1, gmu2)

gsigma21 <- plot_histogram(chain$sigma2, 1) + xlab(expression(sigma[1]^2)) + 
    xlim(0, var(y))
# Normalizing constant of the inverse-gamma prior density (equals
# gamma(alpha) / beta^alpha).
constant <- integrate(function(x) x^(-alpha - 1) * exp(-beta/x), lower = 0, 
    upper = Inf)$value
gsigma21 <- gsigma21 + stat_function(fun = function(x) x^(-alpha - 1) * exp(-beta/x)/constant, 
    colour = "red")
gsigma21 <- gsigma21 + geom_vline(aes(xintercept = sigma2_0[1]), linetype = 2, 
    size = 2, colour = "blue")
gsigma22 <- plot_histogram(chain$sigma2, 2) + xlab(expression(sigma[2]^2)) + 
    xlim(0, var(y))
gsigma22 <- gsigma22 + stat_function(fun = function(x) x^(-alpha - 1) * exp(-beta/x)/constant, 
    colour = "red")
gsigma22 <- gsigma22 + geom_vline(aes(xintercept = sigma2_0[2]), linetype = 2, 
    size = 2, colour = "blue")
grid.arrange(gsigma21, gsigma22)

We can also look at traceplots of the parameters, to check whether the chain mixes quickly (ideally, the draws should look almost like independent samples from the posterior). Note that the posterior is invariant under permutations of the component labels, so on longer runs the traces may occasionally swap between components (label switching).

allparameters <- cbind(chain$p, chain$mu, chain$sigma2)
names(allparameters) <- c(paste0("p", 1:K), paste0("mu", 1:K), paste0("sigma2", 
    1:K))
allparameters$iteration <- 1:niterations
library(reshape2)  # for melt
library(ggthemes)  # for scale_color_colorblind
all.df.melted <- melt(allparameters, id = "iteration")
gtrace <- ggplot(all.df.melted, aes(x = iteration, y = value, colour = variable)) + 
    geom_line()
gtrace <- gtrace + theme(legend.position = "bottom") + scale_color_colorblind()
print(gtrace)
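
For a more quantitative look at mixing, we can compute effective sample sizes; a minimal sketch, assuming the coda package is installed:

# Effective sample size of each parameter (out of 'niterations' draws);
# values close to the number of iterations indicate near-independent draws.
library(coda)
params <- as.matrix(allparameters[, setdiff(names(allparameters), "iteration")])
print(effectiveSize(as.mcmc(params)))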