# file MASS/qda.q
# copyright (C) 1994-2000 W. N. Venables and B. D. Ripley
#
# S4 class holding a fitted quadratic discriminant analysis; the slots are
# filled in by qda1() (and `call` by the qda methods).
#   prior   - prior probabilities of the groups, named by group level
#   counts  - number of observations in each group, named by level
#   means   - matrix of group means: one row per group, one column per variable
#   scaling - p x p x ng array; for group i, (x - means[i, ]) %*% scaling[,,i]
#             has squared row norms equal to the Mahalanobis distances used
#             by predict.qda
#   ldet    - per group, log determinant of that group's covariance estimate
#   lev     - the group levels
#   N       - number of observations used in the fit
#   call    - the call that produced the fit
# invisible() stops the value returned by setClass from being printed.
invisible(
setClass("qda", representation(prior = "named",
                               counts = "named",
                               means = "matrix",
                               scaling = "array",
                               ldet = "numeric",
                               lev = "character",
                               N = "integer",
                               call = "call")
         ))

# Default for the qda generic: reached only when no method is registered
# for the class of x, so report that class as an error.
qda <- function(x, y, ...)
    stop("qda not implemented for class ", class(x))

# qda method for a formula: qda(formula, data, ...).
#   x   - the model formula (response ~ predictors).
#   y   - optional data frame in which to evaluate the formula.
#   ... - passed to qda1(); `subset` and `na.action` found here are also
#         forwarded to model.frame().
# Builds the model frame, removes any intercept column from the model
# matrix and fits with qda1().  Returns a "qda" object whose call slot
# holds the formula, or (with CV = TRUE) the list returned by qda1().
qda.formula <- function(x, y, ...)
{
    m <- match.call(expand.dots = FALSE)
    m$... <- NULL
    dots <- list(...)
    # rename the generic's (x, y) to model.frame's (formula, data); y may
    # be missing, in which case the formula's environment supplies the data
    m$formula <- m$x
    m$data <- m$y
    m$x <- m$y <- NULL
    if(hasArg(subset)) m$subset <- dots$subset
    if(hasArg(na.action)) m$na.action <- dots$na.action
    m[[1]] <- as.name("model.frame")
    m <- eval(m, sys.parent())
    Terms <- attr(m, "terms")
    y <- model.extract(m, "response")
    x <- model.matrix(Terms, m)
    # a constant intercept column would make every group covariance
    # singular, so drop it
    xint <- match("(Intercept)", dimnames(x)[[2]], nomatch = 0)
    if(xint > 0) x <- x[, -xint, drop = FALSE]
    res <- qda1(x = x, grouping = y, ...)
    # with CV = TRUE qda1 returns a plain list, which has no call slot
    if(is(res, "qda")) {
        Call <- match.call()
        Call$x <- as.call(attr(Terms, "formula"))
        res@call <- Call
    }
    res
}


# qda.data.frame <- function(x, y, ...)
# {
#     x <- as.matrix(x)
#     callGeneric()
# }

# qda method for a data frame of predictors.
#   x   - data frame of numeric predictors (converted with as.matrix).
#   y   - grouping for each row.
#   ... - passed to qda1().
# Returns a "qda" object recording the call, or (with CV = TRUE) the list
# returned by qda1().
qda.data.frame <- function(x, y, ...)
{
    res <- qda1(x = as.matrix(x), grouping = y, ...)
    # with CV = TRUE qda1 returns a plain list, which has no call slot
    if(is(res, "qda")) res@call <- match.call()
    res
}

# qda.Matrix <- function(x, y, ...)
# {
#   x <- data.matrix(x)
#   callGeneric()
# }

# qda method for a numeric matrix of predictors.
#   x   - predictor matrix (coerced with as(x, "matrix")).
#   y   - grouping for each row.
#   ... - passed to qda1(); an optional `subset` (row index) and
#         `na.action` (function applied to the data) are applied here first.
# Returns a "qda" object recording the call, or (with CV = TRUE) the list
# returned by qda1().
qda.matrix <- function(x, y, ...)
{
    x <- as(x, "matrix")
    dots <- list(...)
    if(hasArg(subset)) {
        subset <- dots$subset
        x <- x[subset, , drop = FALSE]
        y <- y[subset]
    }
    if(hasArg(na.action)) {
        na.action <- dots$na.action
        # Build the data frame by hand so x stays ONE matrix column;
        # data.frame(g = y, x = x) would split it into x.1, x.2, ... and
        # dfr$x would then not be the predictor matrix.
        dfr <- na.action(structure(list(g = y, x = x), class = "data.frame",
                                   row.names = 1:nrow(x)))
        y <- dfr$g
        x <- dfr$x
    }
    res <- qda1(x = x, grouping = y, ...)
    # with CV = TRUE qda1 returns a plain list, which has no call slot
    if(is(res, "qda")) res@call <- match.call()
    res
}

# Register the class-specific fitting functions as methods of the qda
# generic, selected by the class of the first argument.  The "Matrix"
# method is currently disabled (qda.Matrix is commented out).
setMethod("qda", "formula", qda.formula)
setMethod("qda", "matrix", qda.matrix)
#setMethod("qda", "Matrix", qda.Matrix)
setMethod("qda", "data.frame", qda.data.frame)

# Workhorse for the qda methods: fit a quadratic discriminant analysis.
#   x        - numeric matrix, one row per observation.
#   grouping - group of each row (coerced to a factor).
#   prior    - prior probabilities in level order; defaults to the sample
#              proportions of the groups.
#   method   - per-group covariance estimate: "moment" (divisor n_k - 1),
#              "mle" (divisor n_k), "mve" (centre and covariance from
#              cov.mve) or "t" (iteratively reweighted multivariate-t fit
#              with nu degrees of freedom).
#   CV       - if TRUE, return leave-one-out classifications instead of a
#              fit (only for methods "moment" and "mle").
#   nu       - degrees of freedom for method "t"; must exceed 2.
#   ...      - ignored (absorbs arguments such as subset/na.action passed
#              through by the qda methods).
# Returns a "qda" object, or with CV = TRUE list(class, posterior).
qda1 <-
  function(x, grouping, prior = proportions,
           method = c("moment", "mle", "mve", "t"),
           CV = FALSE, nu = 5, ...)
{
    if(is.null(dim(x))) stop("x is not a matrix")
    n <- nrow(x)
    p <- ncol(x)
    if(n != length(grouping)) stop("nrow(x) and length(grouping) are different")
    g <- as.factor(grouping)
    lev <- levels(g)
    counts <- as.vector(table(g))
    names(counts) <- lev
    if(any(counts < p+1)) stop("some group is too small for qda")
    # default prior: `prior = proportions` is evaluated lazily, after this
    proportions <- counts/length(g)
    ng <- length(proportions)
    # allow for supplied prior
    if(any(prior < 0) || round(sum(prior), 5) != 1) stop("invalid prior")
    if(length(prior) != ng) stop("prior is of incorrect length")
    names(prior) <- lev
    # means by group (rows) and variable (columns)
    group.means <- tapply(x, list(rep(g, ncol(x)), col(x)), mean)
    scaling <- array(dim = c(p, p, ng))
    ldet <- numeric(ng)
    method <- match.arg(method)
    if(CV && !(method == "moment" || method == "mle"))
        stop(paste("Cannot use leave-one-out CV with method", method))
    for(i in 1:ng) {
        # rows of group i; drop = FALSE keeps a matrix when p == 1 (a
        # dropped vector breaks colSums() in the "t" branch)
        Xi <- x[oldUnclass(g) == i, , drop = FALSE]
        if(method == "mve") {
            cX <- cov.mve(Xi, , FALSE)
            group.means[i, ] <- cX$center
            sX <- svd(cX$cov, nu = 0)
            scaling[, , i] <- sX$v %*% diag(sqrt(1/sX$d), , p)
            ldet[i] <- sum(log(sX$d))
        } else if(method == "t") {
            if(nu <= 2) stop("nu must exceed 2")
            m <- counts[i]
            w <- rep(1, m)
            # reweight observations until no weight changes by 0.01 or more
            repeat {
                w0 <- w
                W <- scale(Xi, center = group.means[i, ], scale = FALSE)
                sX <- svd(sqrt((1 + p/nu) * w/m) * W, nu = 0)
                W <- W %*% sX$v %*% diag(1/sX$d, , p)
                w <- 1/(1 + drop(W^2 %*% rep(1, p))/nu)
                group.means[i, ] <- colSums(w * Xi)/sum(w)
                if(all(abs(w - w0) < 1e-2)) break
            }
            qx <- qr(sqrt(w) * scale(Xi, center = group.means[i, ], scale = FALSE))
            if(qx$rank < p) stop(paste("Rank deficiency in group", lev[i]))
            qx <- qx$qr * sqrt((1 + p/nu)/m)
            scaling[, , i] <- backsolve(qx[1:p, , drop = FALSE], diag(p))
            ldet[i] <- 2 * sum(log(abs(diag(qx))))
        } else {
            nk <- if(method == "moment") counts[i] - 1 else counts[i]
            X <- scale(Xi, center = group.means[i, ], scale = FALSE)/sqrt(nk)
            # R from the QR of the scaled centred data satisfies R'R = cov
            qx <- qr(X)
            if(qx$rank < p) stop(paste("Rank deficiency in group", lev[i]))
            qx <- qx$qr
            scaling[, , i] <- backsolve(qx[1:p, , drop = FALSE], diag(p))
            ldet[i] <- 2 * sum(log(abs(diag(qx))))
        }
    }
    if(CV) {
        NG <- if(method == "mle") 0 else 1
        dist <- matrix(0, n, ng)
        Ldet <- matrix(0, n, ng)
        for(i in 1:ng) {
            dev <- ((x - matrix(group.means[i, ], nrow(x),
                                p, byrow = TRUE)) %*% scaling[, , i])
            dist[, i] <- rowSums(dev^2)
            Ldet[, i] <- ldet[i]
        }
        # adjust each observation's own-group cell (picked out by ind) for
        # leaving that observation out of its group's estimates
        nc <- counts[g]
        ind <- cbind(1:n, g)
        fac <- 1 - nc/(nc-1)/(nc-NG) * dist[ind]
        fac[] <- pmax(fac, 1e-10)  # possibly degenerate dsn
        Ldet[ind] <- log(fac) + p * log((nc-NG)/(nc-1-NG)) + Ldet[ind]
        dist[ind] <- dist[ind] * (nc^2/(nc-1)^2) * (nc-1-NG)/(nc-NG) / fac
        dist <- 0.5 * dist + 0.5 * Ldet - matrix(log(prior), n, ng, byrow = TRUE)
        # subtract each row's own minimum (as predict.qda does) so that no
        # row underflows to all zeros, which would give NaN posteriors
        dist <- exp(-(dist - apply(dist, 1, min, na.rm = TRUE)))
        posterior <- dist/drop(dist %*% rep(1, length(prior)))
        cl <- max.col(posterior)
        levels(cl) <- lev
        oldClass(cl) <- "factor"
        dimnames(posterior) <- list(dimnames(x)[[1]], lev)
        return(list(class = cl, posterior = posterior))
    }
    if(is.null(dimnames(x)))
        dimnames(scaling) <- list(NULL, as.character(1:p), lev)
    else {
        dimnames(scaling) <- list(dimnames(x)[[2]], as.character(1:p), lev)
        dimnames(group.means)[[2]] <- dimnames(x)[[2]]
    }
    res <- new("qda")
    res@prior <- prior
    res@counts <- counts
    res@means <- group.means
    res@scaling <- scaling
    res@ldet <- ldet
    res@lev <- lev
    res@N <- n
    res
}

# Classify observations with a fitted "qda" object.
#   object  - a fitted "qda" object.
#   newdata - data to classify: for a formula fit a data frame (response not
#             needed); otherwise a matrix, data frame or single row vector.
#             If missing, the training data are re-evaluated from the call.
#   prior   - prior probabilities of the groups; defaults to the fitted ones.
#   method  - "plug-in", "predictive", "debiased", or "looCV" (leave-one-out
#             on the training data; only for fits with method "moment" or
#             "mle", and not with newdata).
#   ...     - ignored.
# Returns list(class = factor of predicted groups,
#              posterior = matrix of posterior probabilities).
predict.qda <- function(object, newdata, prior = object@prior,
                        method = c("plug-in", "predictive", "debiased",
                                   "looCV"), ...)
{
    if(!inherits(object, "qda")) stop("object not of class qda")
    method <- match.arg(method)
    if(method == "looCV" && !missing(newdata))
        stop("Cannot have leave-one-out CV with newdata")
    if(is.null(mt <- object@call$method)) mt <- "moment"
    if(method == "looCV" && !(mt == "moment" || mt == "mle"))
        stop(paste("Cannot use leave-one-out CV with method", mt))
    ngroup <- length(object@prior)
    # a supplied prior replaces the fitted one; validate it as qda1 does
    if(!missing(prior)) {
        if(any(prior < 0) || round(sum(prior), 5) != 1) stop("invalid prior")
        if(length(prior) != ngroup) stop("prior is of incorrect length")
    }
    if((missing(newdata) || !is(newdata, "model.matrix")) &&
       is.form(form <- object@call[[2]])) {
        # formula fit
        if(missing(newdata)) newdata <- model.frame(object)
        else newdata <- model.frame(delete.response(terms(form)), newdata,
                                    na.action = function(x) x)
        x <- model.matrix(delete.response(terms(form)), newdata)
        xint <- match("(Intercept)", dimnames(x)[[2]], nomatch = 0)
        if(xint > 0) x <- x[, -xint, drop = FALSE]
        if(method == "looCV") g <- model.extract(newdata, "response")
    } else {
        # matrix or data-frame fit
        if(missing(newdata)) {
            # re-evaluate the training data and grouping from the stored call
            if(!is.null(sub <- object@call$subset)) {
                newdata <- eval(parse(text = paste(deparse(object@call$x), "[",
                                      deparse(sub), ",]")), sys.parent())
                g <- eval(parse(text = paste(deparse(object@call[[3]]), "[",
                                deparse(sub), "]")), sys.parent())
            } else {
                newdata <- eval(object@call$x, sys.parent())
                g <- eval(object@call[[3]], sys.parent())
            }
            if(!is.null(nas <- object@call$na.action)) {
                # the recorded argument is unevaluated: a function name
                # (string or symbol) or an expression yielding a function;
                # call() alone would reject the symbol form
                naf <- if(is.character(nas)) get(nas) else eval(nas, sys.parent())
                # keep X as ONE matrix column; data.frame() would split it
                df <- naf(structure(list(g = g, X = as.matrix(newdata)),
                                    class = "data.frame",
                                    row.names = 1:nrow(newdata)))
                g <- df$g
                newdata <- df$X
            }
            g <- as.factor(g)
        }
        if(is.null(dim(newdata)))
            dim(newdata) <- c(1, length(newdata))  # a row vector
        x <- as.matrix(newdata)  # to cope with dataframes
    }
    p <- ncol(object@means)
    if(ncol(x) != p) stop("wrong number of variables")
    if(length(dimnames(x)[[2]]) > 0 &&
       any(dimnames(x)[[2]] != dimnames(object@means)[[2]]))
        warning("Variable names in newdata do not match those in object")
    dist <- matrix(0, nrow = nrow(x), ncol = ngroup)
    if(method == "plug-in") {
        for(i in 1:ngroup) {
            dev <- ((x - matrix(object@means[i, ], nrow(x),
                                ncol(x), byrow = TRUE)) %*% object@scaling[, , i])
            dist[, i] <- 0.5 * rowSums(dev^2) + 0.5 * object@ldet[i] - log(prior[i])
        }
        # subtract each row's minimum before exponentiating to avoid underflow
        dist <- exp(-(dist - apply(dist, 1, min, na.rm = TRUE)))
    } else if(method == "looCV") {
        n <- nrow(x)
        NG <- if(mt == "mle") 0 else 1
        ldet <- matrix(0, n, ngroup)
        for(i in 1:ngroup) {
            dev <- ((x - matrix(object@means[i, ], nrow(x), p, byrow = TRUE))
                    %*% object@scaling[, , i])
            dist[, i] <- rowSums(dev^2)
            ldet[, i] <- object@ldet[i]
        }
        # adjust each observation's own-group cell for leaving it out
        nc <- object@counts[g]
        ind <- cbind(1:n, g)
        fac <- 1 - nc/(nc-1)/(nc-NG) * dist[ind]
        fac[] <- pmax(fac, 1e-10)  # possibly degenerate dsn
        ldet[ind] <- log(fac) + p * log((nc-NG)/(nc-1-NG)) + ldet[ind]
        dist[ind] <- dist[ind] * (nc^2/(nc-1)^2) * (nc-1-NG)/(nc-NG) / fac
        dist <- 0.5 * dist + 0.5 * ldet -
            matrix(log(prior), n, ngroup, byrow = TRUE)
        dist <- exp(-(dist - apply(dist, 1, min, na.rm = TRUE)))
    } else if(method == "debiased") {
        for(i in 1:ngroup) {
            nk <- object@counts[i]
            # bias correction of the log determinant: the digamma sum runs
            # over the p variables, not over the groups
            Bm <- p * log((nk-1)/2) - sum(digamma(0.5 * (nk - 1:p)))
            dev <- ((x - matrix(object@means[i, ], nrow = nrow(x),
                                ncol = ncol(x), byrow = TRUE)) %*% object@scaling[, , i])
            dist[, i] <- 0.5 * (1 - (p-1)/(nk-1)) * rowSums(dev^2) +
                0.5 * object@ldet[i] - log(prior[i]) + 0.5 * Bm - p/(2*nk)
        }
        dist <- exp(-(dist - apply(dist, 1, min, na.rm = TRUE)))
    } else {
        # "predictive": dist holds (unnormalised) weighted densities directly
        for(i in 1:ngroup) {
            nk <- object@counts[i]
            dev <- ((x - matrix(object@means[i, ], nrow = nrow(x),
                                ncol = ncol(x), byrow = TRUE))
                    %*% object@scaling[, , i])
            dev <- 1 + rowSums(dev^2)/(nk+1)
            dist[, i] <- prior[i] * exp(-object@ldet[i]/2) *
                dev^(-nk/2) * (1 + nk)^(-p/2)
        }
    }
    posterior <- dist/drop(dist %*% rep(1, ngroup))
    cl <- max.col(posterior)
    levels(cl) <- object@lev
    oldClass(cl) <- "factor"
    dimnames(posterior) <- list(dimnames(x)[[1]], object@lev)
    list(class = cl, posterior = posterior)
}

# Print a "qda" fit: the call that produced it (if recorded), the prior
# probabilities and the group means.  Returns x invisibly.
print.qda <- function(x, ...)
{
    cl <- x@call
    if(!is.null(cl)) {
        # show the data arguments positionally rather than by name
        hide <- match(c("formula", "data", "x", "y"), names(cl), 0)
        names(cl)[hide] <- ""
        cat("Call:\n")
        dput(cl)
    }
    sections <- list("\nPrior probabilities of groups:\n" = x@prior,
                     "\nGroup means:\n" = x@means)
    for(heading in names(sections)) {
        cat(heading)
        print(sections[[heading]], ...)
    }
    invisible(x)
}

# Rebuild the model frame of a qda fit by reusing the lda method.
# NOTE(review): model.frame.lda is defined elsewhere; confirm it reads the
# call via object@call (qda fits are S4 objects, so $call would not work).
model.frame.qda <-  model.frame.lda

# show() method for "qda" objects: automatic printing delegates to print.qda.
show.qda <- function(object) {
    print.qda(object)
}
setMethod("show", "qda", show.qda)

# Return the first argument of the recorded call: the model formula for
# fits made by qda.formula, otherwise the expression passed as x.
formula.qda <- function(object)
{
    fit.call <- object@call
    fit.call[[2]]
}

# Refit a qda model with modified arguments by reusing the lda method.
# NOTE(review): update.lda is defined elsewhere; confirm it works on S4
# qda objects (it must read the call via @call, not $call).
update.qda <- update.lda
