# file MASS/lda.q
# copyright (C) 1994-2002 W. N. Venables and B. D. Ripley
#
# Formal class holding a fitted linear discriminant analysis.
# Slots, as filled in by lda1() and the lda methods in this file:
#   prior   - prior probabilities of the (non-empty) groups, named by level
#   counts  - number of observations in each (non-empty) group, named by level
#   means   - matrix of group means
#   scaling - matrix of coefficients of the linear discriminants
#   lev     - all levels of the grouping factor, including empty ones
#   svd     - the singular values retained from the between-group step
#   N       - number of observations used in the fit
#   call    - the call that produced the fit
invisible(setClass("lda", representation(prior = "named",
                               counts = "named",
                               means = "matrix",
                               scaling = "matrix",
                               lev = "character",
                               svd = "numeric",
                               N = "integer",
                               call = "call")
         ))
# Fallback definition of lda(): it only reports that the class of `x` is not
# supported. The working implementations are the class-specific functions
# registered with setMethod() further down this file.
#   x, y, ... - accepted so the signature agrees with the methods; only the
#               class of `x` is inspected.
lda <- function(x, y, ...)
{
    # class(x) may hold several strings; collapse them so the message stays
    # readable instead of running the class names together.
    stop("lda not implemented for class ", paste(class(x), collapse = ", "))
}

# lda method for a formula first argument.
#   x   - the model formula (response = grouping variable).
#   y   - the data in which the formula is evaluated (passed to model.frame
#         as `data`).
#   ... - may contain `subset` and `na.action`, which are handed to
#         model.frame; everything in ... is also forwarded to lda1().
# Returns whatever lda1() returns; when that is an "lda" object its call slot
# is filled in with the matched call, with `x` replaced by the formula taken
# from the model terms.
lda.formula <- function(x, y, ...)
{
    # Rebuild the call as a model.frame() call: x -> formula, y -> data.
    m <- match.call(expand.dots = FALSE)
    m$... <- NULL
    dots <- list(...)
    m$formula <- m$x
    m$data <- m$y
    m$x <- m$y <- NULL
    if(hasArg(subset)) m$subset <- dots$subset
    if(hasArg(na.action)) m$na.action <- dots$na.action
    m[[1]] <- as.name("model.frame")
    m <- eval(m, sys.parent())
    Terms <- attr(m, "terms")
    y <- model.extract(m, "response")
    x <- model.matrix(Terms, m)
    # The intercept column carries no discriminating information; drop it.
    xint <- match("(Intercept)", dimnames(x)[[2]], nomatch=0)
    if(xint > 0) x <- x[, -xint, drop=FALSE]
    res <- lda1(x = x, grouping = y, ...)
    # With CV = TRUE lda1() returns a plain list, which has no call slot.
    # is() is used rather than comparing class(res) with ==, which misbehaves
    # when class() returns more than one string.
    if(is(res, "lda")) {
        Call <- match.call()
        Call$x <- as.call(attr(Terms, "formula"))
        res@call <- Call
    }
    res
}

## this is broken: it mis-matches the call.
# lda.data.frame <- function(x, y, ...)
# {
#     x <- as.matrix(x)
#     callGeneric()
# }

# lda method for a data frame of explanatory variables.
#   x   - data frame; converted with as.matrix() before fitting.
#   y   - grouping of the rows of x.
#   ... - forwarded to lda1() (e.g. prior, tol, method, CV, nu).
# Returns the result of lda1(); for a fitted "lda" object the call slot is
# set to the matched call.
lda.data.frame <- function(x, y, ...)
{
    res <- lda1(x = as.matrix(x), grouping = y, ...)
    # lda1(CV = TRUE) returns a plain list without slots, so only record the
    # call on a genuine "lda" object (same guard as in lda.formula).
    if(is(res, "lda")) res@call <- match.call()
    res
}

# lda.Matrix <- function(x, y, ...)
# {
#     res <- lda1(x = data.matrix(x), grouping = y, ...)
#     res@call <- match.call()
#     res
# }

# lda method for a matrix of explanatory variables.
#   x   - matrix of observations by variables (coerced with as(x, "matrix")).
#   y   - grouping of the rows of x.
#   ... - may contain `subset` (row selector applied to both x and y) and
#         `na.action` (a function applied to a data frame built from y and
#         x); everything in ... is also forwarded to lda1().
# Returns the result of lda1(); for a fitted "lda" object the call slot is
# set to the matched call.
lda.matrix <- function(x, y, ...)
{
    x <- as(x, "matrix")
    dots <- list(...)
    if(hasArg(subset)) {
        subset <- dots$subset
        x <- x[subset, , drop = FALSE]
        y <- y[subset]
    }
    if(hasArg(na.action)) {
        na.action <- dots$na.action
        # NOTE(review): data.frame() normally spreads a multi-column matrix
        # over several columns, in which case dfr$x would not recover the
        # matrix - confirm this path with a matrix of more than one column.
        dfr <- na.action(data.frame(g = y, x = x))
        y <- dfr$g
        x <- dfr$x
    }
    res <- lda1(x = x, grouping = y, ...)
    # lda1(CV = TRUE) returns a plain list without slots, so only record the
    # call on a genuine "lda" object (same guard as in lda.formula).
    if(is(res, "lda")) res@call <- match.call()
    res
}

# Register the class-specific implementations defined above as methods of
# lda(), selected by the class named in the second argument of setMethod().
setMethod("lda", "formula", lda.formula)
setMethod("lda", "matrix", lda.matrix)
#setMethod("lda", "Matrix", lda.Matrix)
setMethod("lda", "data.frame", lda.data.frame)

# Workhorse behind the lda methods: fits a linear discriminant analysis to a
# matrix of observations, or (CV = TRUE) performs leave-one-out
# cross-validation.
#
# Arguments:
#   x        - matrix, observations in rows and variables in columns.
#   grouping - group of each row of x; coerced to a factor.  Empty levels are
#              dropped from the fit with a warning.
#   prior    - prior probabilities of the non-empty groups; defaults to the
#              observed group proportions.  Must be non-negative, sum to 1
#              (to 5 decimals) and have one entry per non-empty group.
#   tol      - tolerance for detecting variables that are constant within
#              groups and for deciding the numerical rank.
#   method   - "moment" (divisor n - ng) or "mle" (divisor n) for the
#              within-group covariance, "mve" to use cov.rob(), or "t" for an
#              iteratively reweighted fit with nu degrees of freedom.
#   CV       - if TRUE return leave-one-out results; only allowed with
#              method "moment" or "mle".
#   nu       - degrees of freedom for method "t"; must exceed 2.
#   ...      - not used here; absorbs extra arguments forwarded by callers.
#
# Value:
#   CV = FALSE: an object of class "lda" with slots prior, counts, means,
#               scaling, lev, svd and N filled in (the call slot is left to
#               the caller).
#   CV = TRUE:  list(class = factor of leave-one-out predictions,
#                    posterior = matrix of posterior probabilities).
lda1 <-
  function(x, grouping, prior = proportions, tol = 1.0e-4,
           method = c("moment", "mle", "mve", "t"),
           CV = FALSE, nu = 5, ...)
{
    if(is.null(dim(x))) stop("x is not a matrix")
    n <- nrow(x)
    p <- ncol(x)
    if(n != length(grouping))
        stop("nrow(x) and length(grouping) are different")
    g <- as.factor(grouping)
    # lev keeps every level; lev1 keeps only the levels that occur.
    lev <- lev1 <- levels(g)
    counts <- as.vector(table(g))
    if(any(counts == 0)) {
        warning(paste("group(s)", paste(lev[counts == 0], collapse=" "),
                      "are empty"))
        lev1 <- lev[counts > 0]
        g <- factor(g, levels=lev1)
        counts <- as.vector(table(g))
    }
    # `proportions` is the lazily evaluated default of `prior`, so it must be
    # defined before `prior` is first used below.
    proportions <- counts/n
    ng <- length(proportions)
    if(any(prior < 0) || round(sum(prior), 5) != 1) stop("invalid prior")
    if(length(prior) != ng) stop("prior is of incorrect length")
    names(prior) <- names(counts) <- lev1
    method <- match.arg(method)
    if(CV && !(method == "moment" || method == "mle"))
        stop(paste("Cannot use leave-one-out CV with method", method))
    group.means <- tapply(x, list(rep(g, p), col(x)), mean)
    f1 <- sqrt(diag(var(x - group.means[g,  ])))
    if(any(f1 < tol))
        stop(paste("variable(s)",
                   paste(format((1:p)[f1 < tol]), collapse = " "),
                   "appear to be constant within groups"))
    # scale columns to unit variance before checking for collinearity
    scaling <- diag(1/f1,,p)
    if(method == "mve") {
        # adjust to "unbiased" scaling of covariance matrix
        cov <- n/(n-ng) * cov.rob((x - group.means[g,  ]) %*% scaling)$cov
        sX <- svd(cov, nu = 0)
        rank <- sum(sX$d > tol^2)
        if(rank < p) warning("variables are collinear")
        scaling <- scaling %*% sX$v[, 1:rank] %*%
            diag(sqrt(1/sX$d[1:rank]),,rank)
    } else if(method == "t") {
        if(nu <= 2) stop("nu must exceed 2")
        w <- rep(1, n)
        # Iterate weights and weighted group means until the weights settle.
        repeat {
            w0 <- w
            X <- x - group.means[g, ]
            sX <- svd(sqrt((1 + p/nu)*w/n) * X, nu=0)
            X <- X %*% sX$v %*% diag(1/sX$d,, p)
            w <- 1/(1 + drop(X^2 %*% rep(1, p))/nu)
            # prints a summary of the current weights on every iteration
            print(summary(w))
            group.means <- tapply(w*x, list(rep(g, p), col(x)), sum)/
                rep(tapply(w, g, sum), p)
            if(all(abs(w - w0) < 1e-2)) break
        }
        X <-  sqrt(nu/(nu-2)*(1 + p/nu)/n * w) * (x - group.means[g,  ]) %*% scaling
        X.s <- svd(X, nu = 0)
        rank <- sum(X.s$d > tol)
        if(rank < p) warning("variables are collinear")
        scaling <- scaling %*% X.s$v[, 1:rank] %*% diag(1/X.s$d[1:rank],,rank)
    } else {
        if(method == "moment") fac <- 1/(n-ng) else fac <- 1/n
        X <- sqrt(fac) * (x - group.means[g,  ]) %*% scaling
        X.s <- svd(X, nu = 0)
        rank <- sum(X.s$d > tol)
        if(rank < p) warning("variables are collinear")
        scaling <- scaling %*% X.s$v[, 1:rank] %*% diag(1/X.s$d[1:rank],,rank)
    }
    # now have variables scaled so that W is the identity
    if(CV) {
        # After projection x and dm have `rank` columns, which is smaller
        # than p when the variables are collinear, so the broadcast matrices
        # below must be sized with rank rather than p.
        x <- x %*% scaling
        dm <- group.means %*% scaling
        K <- if(method == "moment") ng else 0
        dist <- matrix(0, n, ng)
        for(i in 1:ng) {
            dev <- x - matrix(dm[i,  ], n, rank, byrow = TRUE)
            dist[, i] <- rowSums(dev^2)
        }
        # (row, own group) index pairs into the n x ng distance matrix
        ind <- cbind(1:n, g)
        nc <- counts[g]
        cc <- nc/((nc-1)*(n-K))
        dist2 <- dist
        for(i in 1:ng) {
            dev <- x - matrix(dm[i,  ], n, rank, byrow = TRUE)
            dev2 <- x - dm[g, ]
            tmp <- rowSums(dev*dev2)
            dist[, i] <- (n-1-K)/(n-K) * (dist2[, i] +  cc*tmp^2/(1 - cc*dist2[ind]))
        }
        dist[ind] <- dist2[ind] * (n-1-K)/(n-K) * (nc/(nc-1))^2 /
            (1 - cc*dist2[ind])
        dist <- 0.5 * dist - matrix(log(prior), n, ng, byrow=TRUE)
        dist <- exp(-(dist - min(dist, na.rm=TRUE)))
        # Columns of dist correspond to the non-empty levels (lev1), so the
        # winning column must be translated through lev1 before being
        # expressed on the full set of levels.
        cl <- factor(lev1[max.col(dist)], levels = lev)
        #  convert to posterior probabilities
        posterior <- dist/drop(dist %*% rep(1, length(prior)))
        dimnames(posterior) <- list(dimnames(x)[[1]], lev1)
        return(list(class = cl, posterior = posterior))
    }
    xbar <- colSums(prior %*% group.means)
    if(method == "mle") fac <-  1/ng else fac <- 1/(ng - 1)
    X <- sqrt((n * prior)*fac) * scale(group.means, center=xbar, scale=FALSE) %*% scaling
    X.s <- svd(X, nu = 0)
    rank <- sum(X.s$d > tol * X.s$d[1])
    scaling <- scaling %*% X.s$v[, 1:rank]
    if(is.null(dimnames(x)))
        dimnames(scaling) <- list(NULL, paste("LD", 1:rank, sep = ""))
    else {
        dimnames(scaling) <- list(dimnames(x)[[2]], paste("LD", 1:rank, sep = ""))
        dimnames(group.means)[[2]] <- dimnames(x)[[2]]
    }
    res <- new("lda")
    res@prior <- prior
    res@counts <- counts
    res@means <- group.means
    res@scaling <- scaling
    res@lev <- lev
    res@svd <- X.s$d[1:rank]
    res@N <- n
    res
}

# Classify observations with a fitted "lda" object.
#
# Arguments:
#   object  - fitted object of class "lda".
#   newdata - data to classify.  For a formula fit, a data frame from which
#             the model matrix is rebuilt (unless it already is a
#             "model.matrix"); otherwise a matrix, data frame or a single
#             vector (treated as one row).  If missing, the data of the
#             original fit are recovered by re-evaluating the stored call.
#   prior   - prior probabilities of the groups; defaults to those of the fit.
#   dimen   - number of discriminant dimensions to use; capped at the number
#             of singular values stored in the fit.
#   method  - "plug-in", "predictive" or "debiased".
#   ...     - not used.
#
# Value: list(class = factor of predicted groups,
#             posterior = matrix of posterior probabilities,
#             x = scores on the first `dimen` discriminants).
predict.lda <- function(object, newdata, prior = object@prior, dimen,
                        method = c("plug-in", "predictive", "debiased"), ...)
{
    if(!inherits(object, "lda")) stop("object not of class lda")
    if((missing(newdata) ||!is(newdata, "model.matrix")) &&
       is.form(form <- object@call[[2]])) { #
    # formula fit
        if(missing(newdata)) newdata <- model.frame(object)
        else newdata <- model.frame(delete.response(terms(form)), newdata,
                                    na.action=function(x) x)
        x <- model.matrix(delete.response(terms(form)), newdata)
        xint <- match("(Intercept)", dimnames(x)[[2]], nomatch=0)
        if(xint > 0) x <- x[, -xint, drop=FALSE]
    } else { #
    # matrix or data-frame fit
        if(missing(newdata)) {
            # Re-create the training data from the expressions recorded in
            # the call, applying the recorded subset and na.action if any.
            if(!is.null(sub <- object@call$subset))
                newdata <- eval(parse(text=paste(deparse(object@call$x),"[",
                                      deparse(sub),",]")), sys.parent())
            else newdata <- eval(object@call$x, sys.parent())
            if(!is.null(nas <- object@call$na.action))
                newdata <- eval(call(nas, newdata))
        }
        if(is.null(dim(newdata)))
            dim(newdata) <- c(1, length(newdata))  # a row vector
        x <- as.matrix(newdata)		# to cope with dataframes
    }

    if(ncol(x) != ncol(object@means)) stop("wrong number of variables")
    if(length(dimnames(x)[[2]]) > 0 &&
      any(dimnames(x)[[2]] != dimnames(object@means)[[2]]))
         warning("Variable names in newdata do not match those in object")
    ng <- length(prior)
#   remove overall means to keep distances small
    means <- colSums(prior*object@means)
    scaling <- object@scaling
    x <- scale(x, center=means, scale=FALSE) %*% scaling
    dm <- scale(object@means, center=means, scale=FALSE) %*% scaling
    method <- match.arg(method)
    if(missing(dimen)) dimen <- length(object@svd)
    else dimen <- min(dimen, length(object@svd))
    N <- object@N
    if(method == "plug-in") {
        dm <- dm[, 1:dimen, drop=FALSE]
        dist <- matrix(0.5 * rowSums(dm^2) - log(prior), nrow(x),
                       length(prior), byrow = TRUE) - x[, 1:dimen, drop=FALSE] %*% t(dm)
#        dist <- exp( -(dist - min(dist, na.rm=T)))
        # subtract the row minimum before exponentiating
        dist <- exp( -(dist - apply(dist, 1, min, na.rm=TRUE)))
    } else if (method == "debiased") {
        dm <- dm[, 1:dimen, drop=FALSE]
        dist <- matrix(0.5 * rowSums(dm^2), nrow(x), ng, byrow = TRUE) -
            x[, 1:dimen, drop=FALSE] %*% t(dm)
        dist <- (N - ng - dimen - 1)/(N - ng) * dist -
            matrix(log(prior) - dimen/object@counts , nrow(x), ng, byrow=TRUE)
#        dist <- exp( -(dist - min(dist, na.rm=T)))
        dist <- exp( -(dist - apply(dist, 1, min, na.rm=TRUE)))
    } else {                            # predictive
        dist <- matrix(0, nrow = nrow(x), ncol = ng)
        p <- ncol(object@means)
        # adjust to ML estimates of covariances
        X <- x * sqrt(N/(N-ng))
        for(i in 1:ng) {
            nk <- object@counts[i]
            dev <- scale(X, center=dm[i, ], scale=FALSE)
            dev <- 1 + rowSums(dev^2) * nk/(N*(nk+1))
            dist[, i] <- prior[i] * (nk/(nk+1))^(p/2) * dev^(-(N - ng + 1)/2)
        }
    }
    posterior <- dist / drop(dist %*% rep(1, ng))
    # Columns of posterior follow the groups actually fitted (the names of
    # the stored prior), which can be fewer than object@lev when the fit
    # dropped empty levels.  Translate the winning column through those
    # names so the labels stay correct, then express on the full level set.
    cl <- factor(names(object@prior)[max.col(posterior)], levels = object@lev)
    dimnames(posterior) <- list(dimnames(x)[[1]], names(prior))
    list(class = cl, posterior = posterior, x = x[, 1:dimen, drop=FALSE])
}

# Display method for "lda" objects: delegates to print.lda().
show.lda <- function(object)
{
    print.lda(object)
}

# Print a fitted "lda" object: the call, the prior probabilities, the group
# means, the discriminant coefficients and, when there is more than one
# discriminant, the proportion of trace of each.
#   x   - the fitted object.
#   ... - passed on to every print() call.
# Returns x invisibly.
print.lda <- function(x, ...)
{
    fit.call <- x@call
    if(!is.null(fit.call)) {
        # Show the leading arguments positionally by blanking their names.
        arg.names <- names(fit.call)
        blank <- match(c("formula", "data", "x", "y"), arg.names, 0)
        names(fit.call)[blank] <- ""
        cat("Call:\n")
        dput(fit.call)
    }

    cat("\nPrior probabilities of groups:\n")
    print(x@prior, ...)

    cat("\nGroup means:\n")
    print(x@means, ...)

    cat("\nCoefficients of linear discriminants:\n")
    print(x@scaling, ...)

    sing.vals <- x@svd
    names(sing.vals) <- dimnames(x@scaling)[[2]]
    if(length(sing.vals) > 1) {
        cat("\nProportion of trace:\n")
        trace.share <- sing.vals^2/sum(sing.vals^2)
        print(round(trace.share, 4), ...)
    }
    invisible(x)
}

# Plot the observations of a fitted "lda" object on its linear discriminants.
#   x      - the fitted object.
#   panel  - function drawing the points; the default (defined locally below)
#            writes each observation's group label as text.
#   ...    - passed on to pairs(), eqscplot(), the panel function or ldahist().
#   cex    - character expansion used by the default panel.
#   dimen  - number of discriminants to plot (all of them if missing).
#   abbrev - if true, group labels are abbreviated; the value is also passed
#            to abbreviate() as its second argument.
#   xlab, ylab - axis labels for the one- and two-dimensional displays.
# More than two dimensions gives a pairs plot, exactly two an eqscplot(),
# one a set of histograms via ldahist().  Returns NULL invisibly.
plot.lda <- function(x, panel = panel.lda, ..., cex=0.7,
                     dimen, abbrev = F,
                     xlab = "LD1", ylab = "LD2")
{
    # Default panel; it reads g.lda and tcex, which are assigned in frame 1
    # further down.
    panel.lda <- function(x, y, ...) {
        text(x, y, as.character(g.lda), cex=tcex, ...)
    }
    if(is.form(form <- x@call[[2]])) { #
    # formula fit
        data <- model.frame(x)
        X <- model.matrix(form, data)
        g <- model.extract(data, "response")
        xint <- match("(Intercept)", dimnames(X)[[2]], nomatch=0)
        if(xint > 0) X <- X[, -xint, drop=F]
    } else { #
    # matrix or data-frame fit
        # Recover the data and grouping by re-evaluating the expressions
        # recorded in the call (with the recorded subset, if any).
        xname <- x@call$x
        gname <- x@call[[3]]
        if(!is.null(sub <- x@call$subset)) {
            X <- eval(parse(text=paste(deparse(xname),"[", deparse(sub),",]")),
                      sys.parent())
            g <- eval(parse(text=paste(deparse(gname),"[", deparse(sub),"]")),
                      sys.parent())
        } else {
            X <- eval(xname, sys.parent())
            g <- eval(gname, sys.parent())
        }
        if(!is.null(nas <- x@call$na.action)) {
            df <- data.frame(g = g, X = X)
            df <- eval(call(nas, df))
            g <- df$g
            X <- df$X
        }
    }
    # NOTE(review): unlike pairs.lda, g is not coerced with as.factor() before
    # its levels are abbreviated - confirm this is intended for non-factor g.
    if(abbrev) levels(g) <- abbreviate(levels(g), abbrev)
    # Make the labels and text size visible to the panel function.
    # NOTE(review): frame=1 is presumably the top-level expression frame of
    # the S dialect this was written for - confirm.
    assign("g.lda", g, frame=1)
    assign("tcex", cex, frame=1)
    # Centre at the unweighted average of the group means, then project.
    means <- colMeans(x@means)
    X <- scale(X, center=means, scale=F) %*% x@scaling
    if(!missing(dimen) && dimen < ncol(X)) X <- X[, 1:dimen, drop=F]
    if(ncol(X) > 2) {
        pairs(X, panel=panel, ...)
    } else if(ncol(X) == 2)  {
        eqscplot(X[, 1:2], xlab=xlab, ylab=ylab, type="n", ...)
        panel(X[, 1], X[, 2], ...)
    } else ldahist(X[,1], g, xlab=xlab, ...)
    invisible(NULL)
}

# Draw histograms and/or density estimates of a single variable, one per
# group, on a common scale.
#
# Arguments:
#   data   - numeric vector of values; NAs are removed.
#   g      - grouping of the elements of data (same length as data).
#   nbins  - target number of bins, used to derive h when h is missing.
#   h      - bin width.
#   x0     - origin used to place the bin boundaries.
#   breaks - bin boundaries; derived from h and x0 if missing.
#   xlim   - horizontal plotting range; defaults to the range of breaks.
#   ymax   - lower bound for the top of the vertical axis.
#   width  - window width for density(); chosen by width.SJ() if missing.
#   type   - "histogram", "density" or "both".
#   sep    - if true each group gets its own panel, otherwise all groups
#            share one; defaults to true except for type "density".
#   col    - fill colour of the histogram bars.
#   xlab   - x-axis label of the shared plot; defaults to the expression
#            supplied as data.
#   bty    - box type for the plots.
#   ...    - passed on to polygon().
# Returns invisibly.
ldahist <-
function(data, g, nbins = 25, h, x0 = -h/1000, breaks,
         xlim = range(breaks), ymax = 0, width,
         type = c("histogram", "density", "both"), sep = (type != "density"),
         col = if(is.trellis()) trellis.par.get("bar.fill")$col else 2,
         xlab = deparse(substitute(data)), bty = "n", ...)
{
    # Force the default label now: it is built from the expression supplied
    # as `data`, and `data` is overwritten just below.
    xlab
    type <- match.arg(type)
    # Compute the NA mask once so data and g are filtered consistently;
    # testing is.na(data) after data has been filtered would keep all of g.
    ok <- !is.na(data)
    data <- data[ok]
    g <- g[ok]
    counts <- table(g)
    groups <- names(counts)[counts > 0]
    if(missing(breaks)) {
        if(missing(h)) h <- diff(pretty(data, nbins))[1]
        first <- floor((min(data) - x0)/h)
        last <- ceiling((max(data) - x0)/h)
        breaks <- x0 + h * c(first:last)
    }
    if(type=="histogram" || type=="both") {
        if(any(diff(breaks) <= 0)) stop("breaks must be strictly increasing")
        if(min(data) < min(breaks) || max(data) > max(breaks))
            stop("breaks do not cover the data")
        # one entry per group, addressed by group name
        est <- vector("list", length(groups))
        names(est) <- groups
        for (grp in groups){
            bin <- cut(data[g==grp], breaks, include.lowest = TRUE)
            est1 <- tabulate(bin, length(levels(bin)))
            # convert counts to a density so groups are comparable
            est1 <- est1/(diff(breaks) * length(data[g==grp]))
            ymax <- max(ymax, est1)
            est[[grp]] <- est1
        }
    }
    if(type=="density" || type == "both"){
        xd <- vector("list", length(groups))
        names(xd) <- groups
        for (grp in groups){
            if(missing(width)) width <- width.SJ(data[g==grp])
            xd1 <- density(data[g==grp], n=200, width=width,
                           from=xlim[1], to=xlim[2])
            ymax <- max(ymax, xd1$y)
            xd[[grp]] <- xd1
        }
    }
    if(!sep) plot(xlim, c(0, ymax), type = "n", xlab = xlab, ylab = "",
                  bty = bty)
    else {
        # one panel per group, stacked vertically; restore par() on exit
        oldpar <- par(mfrow=c(length(groups), 1))
        on.exit(par(oldpar))
    }
    for (grp in groups) {
        if(sep) plot(xlim, c(0, ymax), type = "n",
                     xlab = paste("group", grp), ylab = "", bty = bty)
        if(type=="histogram" || type=="both") {
            n <- length(breaks)
            # each bar is a rectangle; the NA row separates the polygons
            polygon(rbind(breaks[-1], breaks[ -n], breaks[ -n], breaks[-1], NA),
                    rbind(0, 0, est[[grp]], est[[grp]], NA), col = col,
                    border = 1, ...)
        }
        if(type=="density" || type == "both") lines(xd[[grp]])
    }
    invisible()
}

# Pairs plot of the observations of a fitted "lda" object on its linear
# discriminants.
#   x      - the fitted object.
#   labels - not referenced in the body of this function.
#   panel  - panel function for type "std"; the default (defined locally
#            below) writes each observation's group label as text.
#   dimen  - number of discriminants to plot (all of them if missing).
#   abbrev - if true, group labels are abbreviated; the value is also passed
#            to abbreviate() as its second argument.
#   ...    - passed on to pairs.default() for type "std".
#   cex    - character expansion used by the default panel.
#   type   - "std" for pairs.default(), "trellis" for a splom() display with
#            a key of group symbols (which uses panel.superpose, not `panel`).
# Returns NULL invisibly.
pairs.lda <- function(x, labels = dimnames(x)[[2]], panel = panel.lda,
                      dimen, abbrev = F, ..., cex = 0.7,
                      type = c("std", "trellis"))
{
    # Default panel; it reads g.lda and tcex, which are assigned in frame 1
    # further down.
    panel.lda <- function(x,y, ...) {
        text(x, y, as.character(g.lda), cex=tcex, ...)
    }
    type <- match.arg(type)
    if(is.form(form <- x@call[[2]])) { #
    # formula fit
        data <- model.frame(x)
        X <- model.matrix(form, data)
        g <- model.extract(data, "response")
        xint <- match("(Intercept)", dimnames(X)[[2]], nomatch=0)
        if(xint > 0) X <- X[, -xint, drop=F]
    } else { #
    # matrix or data-frame fit
        # Recover the data and grouping by re-evaluating the expressions
        # recorded in the call (with the recorded subset, if any).
        xname <- x@call$x
        gname <- x@call[[3]]
        if(!is.null(sub <- x@call$subset)) {
            X <- eval(parse(text=paste(deparse(xname),"[", deparse(sub),",]")),
                      sys.parent())
            g <- eval(parse(text=paste(deparse(gname),"[", deparse(sub),"]")),
                      sys.parent())
        } else {
            X <- eval(xname, sys.parent())
            g <- eval(gname, sys.parent())
        }
        if(!is.null(nas <- x@call$na.action)) {
            df <- data.frame(g = g, X = X)
            df <- eval(call(nas, df))
            g <- df$g
            X <- df$X
        }
    }
    g <- as.factor(g)
    if(abbrev) levels(g) <- abbreviate(levels(g), abbrev)
    # Make the labels and text size visible to the panel function.
    # NOTE(review): frame=1 is presumably the top-level expression frame of
    # the S dialect this was written for - confirm.
    assign("g.lda", g, frame=1)
    assign("tcex", cex, frame=1)
    # Centre at the unweighted average of the group means, then project.
    means <- colMeans(x@means)
    X <- scale(X, center=means, scale=F) %*% x@scaling
    # NOTE(review): no drop=F here, so dimen = 1 turns X into a vector.
    if(!missing(dimen) && dimen < ncol(X)) X <- X[, 1:dimen]
    if(type=="std") pairs.default(X, panel=panel, ...)
    else {
        print(splom(~X, groups = g, panel=panel.superpose,
                    key = list(
                    text=list(levels(g)),
                    points = Rows(trellis.par.get("superpose.symbol"),
                    seq(along=levels(g))),
                    columns = min(5, length(levels(g)))
                    )
                    ))
    }
    invisible(NULL)
}

# Rebuild the model frame of a formula-based "lda" fit by turning the stored
# call into a model.frame() call and evaluating it.
#   formula   - the fitted "lda" object (named to match the generic).
#   data      - optional replacement data; when it has non-zero length the
#               expression supplied for it is used in place of the stored one.
#   na.action - accepted but not referenced in the body.
#   ...       - not used.
model.frame.lda <-
function(formula, data = NULL, na.action = NULL, ...)
{
    frame.call <- formula@call
    # x and y of the lda call play the roles of formula and data.
    frame.call$formula <- frame.call$x
    frame.call$data <- frame.call$y
    frame.call$x <- NULL
    frame.call$y <- NULL
    # Strip the arguments that only concern the fit itself.
    frame.call$prior <- NULL
    frame.call$tol <- NULL
    frame.call$method <- NULL
    frame.call$CV <- NULL
    frame.call$nu <- NULL
    frame.call[[1]] <- as.name("model.frame")
    if(!length(data)) eval(frame.call, list())
    else {
        frame.call$data <- substitute(data)
        eval(frame.call, sys.parent())
    }
}


# Register show.lda as the show() method for objects of class "lda".
setMethod("show", "lda", show.lda)

# Re-fit (or rebuild the call of) an "lda" object with a modified formula
# and/or modified arguments.
#   object   - the fitted object whose stored call is the starting point.
#   formula  - optional update to the formula (second element of the call).
#   ...      - arguments to add to, replace in, or remove from the call.
#   evaluate - if true the new call is evaluated, otherwise it is returned.
#   class    - not referenced in the body.
update.lda <- function(object, formula, ..., evaluate = T, class)
{
    newcall <- object@call
    # the unevaluated extra arguments supplied through ...
    tempcall <- match.call(expand.dots = F)$...
    if(!missing(formula))
        newcall[[2]] <- as.vector(update.formula(object, formula, evaluate = T))
    else {
        # formula was named in the call but is missing: drop it from the call
        nc <- names(sys.call())
        if(length(nc) && any(pmatch(nc, "formula", 0))) newcall$formula <- NULL
    }
    # NOTE(review): the test is length > 1, so as written a single extra
    # argument is not merged - confirm whether tempcall carries a leading
    # element in the S dialect this was written for.
    if(length(tempcall) > 1) {
        # Match the extra arguments against the definition of the function
        # named in the call.
        # NOTE(review): treating the definition as a list (def$formula,
        # names(def)) presumably relies on that dialect's representation of
        # functions - confirm.
        def <- getFunction(newcall[[1]])
        def$formula <- NULL
        TT <- match.call(def, tempcall)
        if((ndots <- length(TT)) < length(tempcall)) {
            nt <- pmatch(names(tempcall), names(def)[ - length(def)])
            # kill some args
            nt <- names(def)[nt]
            nT <- names(TT)
            for(i in nt[is.na(match(nt, nT))]) newcall[[i]] <- NULL
        }
        if(ndots > 1) {
            # copy the matched arguments (all but the function name) over
            ndots <- names(TT)[-1]
            newcall[ndots] <- TT[ndots]
        }
    }
    if(evaluate)
        eval(newcall, sys.parent())
    else newcall
}

# Extract the second element of the stored call, which holds the formula
# when the object was fitted through the formula method.
formula.lda <- function(object)
{
    fit.call <- object@call
    fit.call[[2]]
}

# TRUE when x is a call whose function is `~`, i.e. an unevaluated formula;
# FALSE for any non-call.
is.form <- function(x)
{
    call.like <- is.call(x)
    # && short-circuits, so x[[1]] is only inspected for calls
    call.like && (x[[1]] == "~")
}

# Coefficients of the linear discriminants: the scaling slot of the fit.
#   object - the fitted object.
#   ...    - not used.
coef.lda <- function(object, ...)
{
    object@scaling
}
