# file nnet/nnet.q copyright (C) 1994-9 W. N. Venables and B. D. Ripley
#
# S4 representation of a fitted network.  Slots, as used in this file:
#   n          c(#inputs, #hidden, #outputs), set in nnet1
#   nunits     1 + sum(n); unit 0 is the bias unit (labelled "b" by coef.nnet)
#   nconn      cumulative offsets into conn, length nunits + 1 (see add.net)
#   conn       0-based source unit of each weight
#   nsunits    unit count passed to the compiled code; nnet1 subtracts the
#              output units from nunits when linear outputs are used
#   decay      weight decay as supplied to nnet1
#   entropy, softmax, censored   fitting options
#   value      the `val` output of VR_dovm (presumably the final criterion)
#   wts        fitted weights, in the order of conn
#   fitted, residuals   training-set fitted values and y - fitted
#   coefnames  model-matrix column names (formula fits only)
#   lev        class labels for entropy / softmax fits
#   Hessian    only filled when nnet1 is called with Hess = T
#   call       the call that produced the fit
invisible(
setClass("nnet", representation(n="integer", nunits="integer",
                                nconn="integer", conn="integer",
                                nsunits="integer", decay="numeric",
                                entropy="logical", softmax="logical",
                                censored="logical", value="numeric",
                                wts="numeric", fitted="matrix",
                                residuals="matrix",
                                coefnames="character", lev="character",
                                Hessian="matrix",
                                call="call")
         ))

nnet <- function(x, data, ...)
{
  # Fallback reached when no registered method matches the class of `x`:
  # always signals an error naming that class.
  offending <- class(x)
  stop("nnet not implemented for class ", offending)
}

nnet.matrix <- function(x, data, ...)
{
  # Matrix-like input: `x` holds the predictors and `data` the response.
  # Both go straight to the core fitter; the fit records this call.
  this.call <- match.call()
  fit <- nnet1(x = x, y = data, ...)
  fit@call <- this.call
  fit
}

# Formula method for nnet.  `x` is the model formula and `data` an optional
# data frame (or something as.data.frame can convert).
#
# weights, subset, na.action and contrasts may be given in `...`; they are
# used to build the model frame / model matrix and are NOT forwarded to
# nnet1.  All other arguments in `...` (size, decay, ...) go to nnet1.
#
# A factor response with two levels is fitted as one logistic output with
# entropy fitting.  With more levels it is expanded to an indicator matrix
# and fitted with softmax.  Empty levels are dropped (with a warning) first.
nnet.formula <- function(x, data, ...)
{
  class.ind <- function(cl)
  {
    n <- length(cl)
    x <- matrix(0, n, length(levels(cl)))
    x[(1:n) + n * (as.vector(oldUnclass(cl)) - 1)] <- 1
    dimnames(x) <- list(names(cl), levels(cl))
    x
  }
  m <- match.call(expand.dots = FALSE)
  m$... <- NULL
  dots <- list(...)
  names(m)[2] <- "formula"
  # Put the coerced data into the model-frame call itself; coercing only
  # the local `data` would never be seen by model.frame.
  if(!missing(data) && !is.null(data))
    m$data <- as.data.frame(data)
  if(hasArg(weights)) m$weights <- dots$weights
  if(hasArg(subset)) m$subset <- dots$subset
  if(hasArg(na.action)) m$na.action <- dots$na.action
  m[[1]] <- as.name("model.frame")
  m <- eval(m, sys.parent())
  Terms <- attr(m, "terms")
  if(!hasArg(contrasts)) contrasts <- NULL
  else  contrasts <- dots$contrasts
  x <- model.matrix(Terms, m, contrasts)
  xint <- match("(Intercept)", dimnames(x)[[2]], nomatch=0)
  if(xint > 0) x <- x[, -xint, drop=FALSE] # Bias term is used for intercepts
  w <- model.extract(m, "weights")
  if(length(w) == 0) w <- rep(1, nrow(x))
  y <- model.extract(m, "response")
  # Arguments already consumed by the model frame must not reach nnet1.
  # A `weights` entry in particular would collide with the case weights `w`
  # passed positionally, pushing `w` into nnet1's `Wts` argument.
  fitargs <- dots
  fitargs$weights <- NULL
  fitargs$subset <- NULL
  fitargs$na.action <- NULL
  fitargs$contrasts <- NULL
  if(is.factor(y)) {
    lev <- levels(y)
    counts <- table(y)
    if(any(counts == 0)) {
      warning(paste("group(s)", paste(lev[counts == 0], collapse=" "),
		    "are empty"))
      y <- factor(y, levels = lev[counts > 0])
      # Keep the labels in step with the refactored response, so that
      # the two-class test and the stored levels refer to the levels
      # actually fitted.
      lev <- lev[counts > 0]
    }
    if(length(lev) == 2) {
      y <- as.vector(oldUnclass(y)) - 1
      res <- do.call("nnet1", c(list(x, y, w, entropy=TRUE), fitargs))
    } else {
      y <- class.ind(y)
      res <- do.call("nnet1", c(list(x, y, w, softmax=TRUE), fitargs))
    }
    res@lev <- lev
  } else res <- do.call("nnet1", c(list(x, y, w), fitargs))
  res@coefnames <- dimnames(x)[[2]]
  Call <- match.call()
  Call$x <- as.call(attr(Terms, "formula"))
  res@call <- Call
  res
}

# Dispatch nnet() on the class of its first argument.  A formula goes
# through nnet.formula.  Matrix-like inputs (matrix, Matrix, data.frame) go
# through nnet.matrix, which calls nnet1 with `data` as the response.
setMethod("nnet", "formula", nnet.formula)
setMethod("nnet", "matrix", nnet.matrix)
setMethod("nnet", "Matrix", nnet.matrix)
setMethod("nnet", "data.frame", nnet.matrix)

# Core fitting routine used by the nnet methods.  Returns an "nnet" object.
#
# x, y       predictors and response (coerced to matrices; NAs not allowed)
# weights    non-negative case weights, one per row (default all 1)
# size       number of hidden units
# Wts        optional starting weights; otherwise drawn from
#            runif(-rang, rang), or all zero when rang <= 0
# mask       one entry per weight; entries != 0 are reported as "variable"
# linout, entropy, softmax, censored, skip   network / fitting options
#            (softmax and censored force linout = TRUE, entropy = FALSE)
# decay      weight decay; a single value is recycled to every weight
# maxit, trace, abstol, reltol   passed to the optimiser VR_dovm
# MaxNWts    maximum number of weights allowed
# Hess       if TRUE, also store the Hessian computed by nnet.Hess
nnet1 <-
function(x, y, weights, size, Wts, mask=rep(TRUE, length(wts)),
	 linout=FALSE, entropy=FALSE, softmax=FALSE, censored=FALSE, skip=FALSE,
	 rang=0.7, decay=0, maxit=100, Hess=FALSE, trace=TRUE,
         MaxNWts=1000, abstol=1.0e-4, reltol=1.0e-8, ...)
{
  net <- new("nnet")
  x <- as.matrix(x)
  y <- as.matrix(y)
  if(any(is.na(x))) stop("missing values in x")
  if(any(is.na(y))) stop("missing values in y")
  if(dim(x)[1] != dim(y)[1]) stop("nrows of x and y must match")
  if(linout && entropy) stop("entropy fit only for logistic units")
  ntr <- dim(x)[1]
  nout <- dim(y)[2]
  # Validate the case weights before any compiled state is allocated, so a
  # bad weights vector cannot leave the C-level network set up.
  if(missing(weights)) weights <- rep(1, ntr)
  if(length(weights) != ntr || any(is.na(weights)) || any(weights < 0))
    stop("invalid weights vector")
  if(softmax) {
    linout <- TRUE
    entropy <- FALSE
  }
  if(censored) {
    linout <- TRUE
    entropy <- FALSE
    softmax <- TRUE
  }
  net@n <- c(dim(x)[2], size, dim(y)[2])
  net@nunits <- 1 + sum(net@n)
  net@nconn <- rep(0, net@nunits+1)
  net@conn <- numeric(0)
  net <- norm.net(net)
  if(skip) net <- add.net(net, seq(1,net@n[1]),
			  seq(1+net@n[1]+net@n[2], net@nunits-1))
  if((nwts <- length(net@conn))==0) stop("No weights to fit")
  if(nwts > MaxNWts)
    stop(paste("Too many (", nwts, ") weights", sep=""))
  nsunits <- net@nunits
  if(linout) nsunits <- net@nunits - net@n[3]
  net@nsunits <- nsunits
  net@decay <- decay
  net@entropy <- entropy
  net@softmax <- softmax
  net@censored <- censored
  if(missing(Wts)) {
    if(rang > 0) wts <- runif(nwts, -rang, rang)
    else wts <- rep(0, nwts)
  } else wts <- Wts
  if(length(wts) != nwts) stop("weights vector of incorrect length")
  if(length(mask) != length(wts)) stop("incorrect length of mask")
  if(trace) {
    cat("# weights: ", length(wts))
    nw <- sum(mask != 0)
    if(nw < length(wts)) cat(" (", nw, " variable)\n",sep="")
    else cat("\n")
  }
  if(length(decay) == 1) decay <- rep(decay, length(wts))
  if(!is.loaded(symbol.C("VR_set_net")))
    stop("Compiled code has not been dynamically loaded")
  .C("VR_set_net",
     net@n, net@nconn, net@conn, decay, nsunits, entropy, softmax, censored)
  # From here the compiled code holds the network: release it even if the
  # optimiser fails or the user interrupts the fit.
  on.exit(.C("VR_unset_net"))
  Z <- as.single(cbind(x,y))
  tmp <- .C("VR_dovm",
            ntr, Z, weights,
	    length(wts), wts=wts, val=double(1),
	    maxit, trace, mask, abstol, reltol)
  net@value <- tmp$val
  net@wts <- tmp$wts
  tmp <- matrix(.C("VR_nntest",
		   ntr, Z, tclass = single(ntr*nout), net@wts
		   )$tclass,  ntr, nout)
  dimnames(tmp) <- list(dimnames(x)[[1]], dimnames(y)[[2]])
  net@fitted <- tmp
  tmp <- y - tmp
  dimnames(tmp) <- list(dimnames(x)[[1]], dimnames(y)[[2]])
  net@residuals <- tmp
  # Cancel the exit action before releasing, so the release never runs
  # twice (nnet.Hess below sets up and releases its own network).
  on.exit()
  .C("VR_unset_net")
  if(entropy) net@lev <- c("0","1")
  if(softmax) net@lev <- dimnames(y)[[2]]
  net@call <- match.call()
  if(Hess) net@Hessian <- nnet.Hess(net, x, y, weights)
  net
}

is.form <- function(x)
{
  # TRUE only for a call whose operator is `~`, i.e. a model formula.
  if(!is.call(x)) return(FALSE)
  x[[1]] == "~"
}

# Predictions from a fitted "nnet" object.
#
# newdata  optional; if missing, the training fitted values are returned.
#          For formula fits it is converted to a data frame, and rows with
#          missing predictors get NA predictions.  For matrix fits it is
#          used directly as the input matrix (a plain vector = one row).
# type     "raw" returns the output matrix.  "class" maps outputs to
#          object@lev: the column with the largest value when there are
#          several outputs, otherwise the 0.5 threshold.
predict.nnet <- function(object, newdata, type=c("raw","class"), ...)
{
  if(!inherits(object, "nnet")) stop("object not of class nnet")
  type <- match.arg(type)
  if(missing(newdata)) z <- fitted(object)
  else {
    if(!is(newdata, "model.matrix") && is.form(form <- object@call[[2]])) {
      # formula fit: rebuild the model matrix from newdata
      newdata <- as.data.frame(newdata)
      rn <- row.names(newdata)
      # work hard to predict NA for rows with missing data
      Terms <- delete.response(terms(form))
      m <- model.frame(Terms, newdata, na.action = na.omit)
      keep <- match(row.names(m), rn)
      x <- model.matrix(Terms, m)
      xint <- match("(Intercept)", dimnames(x)[[2]], nomatch=0)
      if(xint > 0) x <- x[, -xint, drop=FALSE] # Bias term is used for intercepts
    } else {
      # matrix fit: newdata is the input matrix itself
      if(is.null(dim(newdata)))
        dim(newdata) <- c(1, length(newdata)) # a row vector
      x <- as.matrix(newdata)		# to cope with dataframes
      if(any(is.na(x))) stop("missing values in x")
      # seq(length=) rather than 1:nrow(x), which gives c(1, 0) for 0 rows
      keep <- seq(length = nrow(x))
      rn <- dimnames(x)[[1]]
    }
    ntr <- nrow(x)
    nout <- object@n[3]
    if(!is.loaded(symbol.C("VR_set_net")))
      stop("Compiled code has not been dynamically loaded")
    .C("VR_set_net",
       object@n, object@nconn, object@conn, object@decay, object@nsunits,
       object@entropy, object@softmax, object@censored)
    # release the compiled network even if VR_nntest fails
    on.exit(.C("VR_unset_net"))
    z <- matrix(NA, nrow(newdata), nout,
                dimnames = list(rn, dimnames(object@fitted)[[2]]))
    # inputs are handed over in single precision, as nnet1 does for Z
    z[keep, ] <- matrix(.C("VR_nntest",
                           ntr, as.single(x), tclass = single(ntr*nout),
                           object@wts
                        )$tclass, ntr, nout)
  }
  switch(type, raw = z,
	 class = {
	   if(is.null(object@lev)) stop("inappropriate fit for class")
	   if(ncol(z) > 1) object@lev[max.col(z)]
	   else object@lev[1 + (z > 0.5)]
	 })
}

eval.nn <- function(wts)
{
  # Evaluate the compiled objective at `wts`.  The `fp` output of VR_dfunc
  # is the returned value (presumably the criterion); the `df` output is
  # attached as its "gradient" attribute.  A network must already have been
  # set up with VR_set_net.
  out <- .C("VR_dfunc", wts, df = double(length(wts)), fp = double(1))
  structure(out$fp, gradient = out$df)
}

# Add connections from every unit in `from` to every unit in `to`.
#
# Units are numbered from 0.  The incoming connections of unit i occupy
# conn[(nconn[i+1]+1):nconn[i+2]], and each conn entry is a source unit.
# The first time a unit receives connections, a link from the bias unit 0
# is inserted ahead of them.  Returns the updated network; an empty `to`
# leaves it unchanged.
add.net <- function(net, from, to)
{
  nconn <- net@nconn
  conn <- net@conn
  for(i in to) {
    end <- nconn[i+2]                 # last existing connection of unit i
    cadd <- if(nconn[i+1] == end) c(0, from) else from
    # keep everything up to and including unit i's current connections
    # (end > 0, not end > 1, so a single leading connection is kept)
    before <- if(end > 0) conn[1:end] else NULL
    after <- if(length(conn) > end) conn[(end+1):length(conn)] else NULL
    conn <- c(before, cadd, after)
    # every later unit's offsets move right by the number inserted
    later <- (i+2):(net@nunits+1)
    nconn[later] <- nconn[later] + length(cadd)
  }
  net@nconn <- nconn
  net@conn <- conn
  net
}

norm.net <- function(net)
{
  # Wire a standard single-hidden-layer network: every input feeds every
  # hidden unit, and every hidden unit feeds every output.  A net with no
  # hidden units is returned untouched.
  sizes <- net@n
  if(sizes[2] <= 0) return(net)
  last.in <- sizes[1]
  last.hid <- last.in + sizes[2]
  last.out <- last.hid + sizes[3]
  hidden <- (last.in + 1):last.hid
  net <- add.net(net, 1:last.in, hidden)
  add.net(net, hidden, (last.hid + 1):last.out)
}

which.is.max <- function(x)
{
  # Index of the maximum of x; ties are broken at random.
  ties <- seq(along = x)[x == max(x)]
  if(length(ties) <= 1) return(ties)
  sample(ties, 1)
}

# Hessian of the fitted network at net@wts, computed by the compiled routine
# VR_nnHessian (presumably of the fitting criterion, including decay).
#
# net      a fitted "nnet" object
# x, y     the data to evaluate on (coerced to matrices; rows must match)
# weights  non-negative case weights, one per row (default all 1)
#
# Returns a length(net@wts) x length(net@wts) matrix.
nnet.Hess <- function(net, x, y, weights)
{
  x <- as.matrix(x)
  y <- as.matrix(y)
  if(dim(x)[1] != dim(y)[1]) stop("dims of x and y must match")
  nw <- length(net@wts)
  decay <- net@decay
  if(length(decay) == 1) decay <- rep(decay, nw)
  ntr <- dim(x)[1]
  # Validate before touching compiled state, so an error here leaks nothing.
  if(missing(weights)) weights <- rep(1, ntr)
  if(length(weights) != ntr || any(is.na(weights)) || any(weights < 0))
    stop("invalid weights vector")
  if(!is.loaded(symbol.C("VR_set_net")))
    stop("Compiled code has not been dynamically loaded")
  .C("VR_set_net",
     net@n, net@nconn, net@conn, decay, net@nsunits,
     net@entropy, net@softmax, net@censored)
  # release the compiled network on normal return and on error
  on.exit(.C("VR_unset_net"))
  Z <- as.single(cbind(x,y))
  matrix(.C("VR_nnHessian", ntr, Z, weights,
            net@wts, H = single(nw*nw))$H,  nw, nw)
}

class.ind <- function(cl)
{
  # Indicator (one-hot) matrix: one row per observation, one column per
  # level, with a 1 in the column of each observation's level.
  cl <- as.factor(cl)
  lv <- levels(cl)
  nobs <- length(cl)
  ind <- matrix(0, nobs, length(lv))
  # column-major linear index of cell (row i, column code of cl[i])
  ind[(1:nobs) + nobs*(oldUnclass(cl)-1)] <- 1
  dimnames(ind) <- list(names(cl), lv)
  ind
}

# Auto-printing of an "nnet" object delegates to print.nnet.
setMethod("show", "nnet", function(object) print.nnet(object))

formula.nnet <- function(object)
{
  # The first argument of the stored call; nnet.formula stores the model
  # formula there.
  stored <- object@call
  stored[[2]]
}
print.nnet <- function(x, ...)
{
  # Describe the architecture, the input/output names (formula fits only)
  # and the fitting options that were used.  Returns x invisibly.
  if(!inherits(x, "nnet")) stop("Not legitimate a neural net fit")
  sizes <- x@n
  cat("a ", sizes[1], "-", sizes[2], "-", sizes[3], " network", sep="")
  cat(" with", length(x@wts), "weights\n")
  has.names <- length(x@coefnames) > 0
  if(has.names)
    cat("inputs:", x@coefnames, "\noutput(s):",
        deparse(formula(x)[[2]]), "\n")
  cat("options were -")
  fanin <- diff(x@nconn)
  # last unit has more inputs than bias + hidden units => skip layer
  if(fanin[length(fanin)] > sizes[2] + 1) cat(" skip-layer connections ")
  if(x@nunits > x@nsunits && !x@softmax) cat(" linear output units ")
  if(x@entropy) cat(" entropy fitting ")
  if(x@softmax) cat(" softmax modelling ")
  if(x@decay[1] > 0) cat(" decay=", x@decay[1], sep="")
  cat("\n")
  invisible(x)
}

coef.nnet <- function(object, ...)
{
  # Fitted weights named "source->destination".  Units are labelled b
  # (bias), i1.., h1.. and o (single output) or o1.. (several outputs).
  sizes <- object@n
  labels <- c("b", paste("i", seq(length=sizes[1]), sep=""))
  if(sizes[2] > 0)
    labels <- c(labels, paste("h", seq(length=sizes[2]), sep=""))
  out.labels <- if(sizes[3] > 1) paste("o", seq(length=sizes[3]), sep="") else "o"
  labels <- c(labels, out.labels)
  src <- labels[1 + object@conn]
  dst <- labels[1 + rep(1:object@nunits - 1, diff(object@nconn))]
  w <- object@wts
  names(w) <- paste(src, dst, sep = "->")
  w
}

summary.nnet <- function(object, ...)
{
  # Tag the fit as a summary by prepending "summary.nnet" to its class, so
  # that print.summary.nnet handles printing; nothing else changes.
  old.class <- class(object)
  class(object) <- c("summary.nnet", old.class)
  object
}

# Register c("summary.nnet", "nnet") as an old-style class; summary.nnet
# assigns exactly this class vector to an "nnet" fit.
setOldClass(c("summary.nnet", "nnet"))

print.summary.nnet <- function(x, ...)
{
  # Same header as print.nnet, followed by the weights rounded to 2
  # decimals and grouped by destination unit.  Returns x invisibly.
  sizes <- x@n
  cat("a ", sizes[1], "-", sizes[2], "-", sizes[3], " network", sep="")
  cat(" with", length(x@wts), "weights\n")
  cat("options were -")
  fanin <- diff(x@nconn)
  if(fanin[length(fanin)] > sizes[2] + 1) cat(" skip-layer connections ")
  if(x@nunits > x@nsunits && !x@softmax) cat(" linear output units ")
  if(x@entropy) cat(" entropy fitting ")
  if(x@softmax) cat(" softmax modelling ")
  if(x@decay[1] > 0) cat(" decay=", x@decay[1], sep="")
  cat("\n")
  rounded <- format(round(coef.nnet(x), 2))
  # one printed group per receiving unit, in unit order
  by.unit <- split(rounded, rep(1:x@nunits, fanin))
  for(grp in by.unit) print(grp, quote=FALSE)
  invisible(x)
}

# Declare argument classes and copy flags for the compiled routines called
# via .C() in this file.  Argument order must match those calls:
#   VR_set_net:   n, nconn, conn, decay, nsunits, entropy, softmax, censored
#   VR_dovm:      ntr, Z, weights, nwts, wts, val, maxit, trace, mask,
#                 abstol, reltol
#   VR_nntest:    ntr, input data, tclass, wts
#   VR_dfunc:     wts, df, fp
#   VR_nnHessian: ntr, Z, weights, wts, H
# copy = T is set on exactly the arguments whose values the R code reads
# back from the .C() result (wts, val, tclass, df, fp, H).
invisible({
  setInterface("VR_set_net", "C",
               classes = c(rep("integer", 3), "numeric", rep("integer", 4)),
               copy = rep(F, 8)
               )
  setInterface("VR_dovm", "C",
               classes = c("integer", "single", "single",
               "integer", "numeric", "numeric",
                 rep("integer", 3), rep("numeric", 2)),
               copy = c(F, F, F, F, T, T, F, F, F, F, F)
               )
  setInterface("VR_nntest", "C",
               classes = c("integer", rep("single", 2), "numeric"),
               copy = c(F, F, T, F)
               )
  setInterface("VR_dfunc", "C",
               classes = rep("numeric", 3),
               copy = c(F, T, T)
               )
  setInterface("VR_nnHessian", "C",
               classes = c("integer", "single", "single", "numeric", "single"),
               copy = c(F, F, F, F, T)
               )
             })

# Rebuild the model frame of a formula fit from its stored call.
#
# Only the formula, data, na.action and subset arguments of the original
# call are kept; everything else (size, decay, Hess, ...) is dropped.
# If `data` is supplied here, it replaces the data of the original call.
model.frame.nnet <-
function(formula, data = NULL, na.action = NULL, ...)
{
  oc <- formula@call
  oc[[1]] <- as.name("model.frame")
  # match.call() names every argument, so only the first one (`x`, the
  # formula) needs renaming.  Renaming position 3 as well would turn e.g.
  # `size` into `data` when the fit was made without a data argument.
  names(oc)[2] <- "formula"
  m <- match(names(oc), c("formula", "data", "na.action", "subset"))
  # drop unmatched arguments, keeping element 1 (the function name);
  # go backwards so earlier positions stay valid
  for(i in rev((seq(along=m)[is.na(m)])[-1])) oc[[i]] <- NULL
  if(length(data)) {
    oc$data <- substitute(data)
    eval(oc, sys.parent())
  }
  else eval(oc, list())
}

fitted.nnet <- function(object, ...)
{
  # Training-set fitted values stored by nnet1 (one column per output).
  stored <- object@fitted
  stored
}
residuals.nnet <- function(object, ...)
{
  # Training-set residuals (response minus fitted) stored by nnet1.
  stored <- object@residuals
  stored
}
