# file survnnet/survnnet.q by R.M. Ripley based on
# file nnet/nnet.q copyright (C) 1994-8 W. N. Venables and B. D. Ripley
#
survnnet <- function(object, ...)
{
  # S3 generic for fitting survival neural networks; dispatches on the
  # class of 'object' (e.g. survnnet.formula or survnnet.default).
  # The is.null(class()) guard comes from S: in R class() always yields a
  # value, so the assignment never runs there.
  if (is.null(class(object))) {
    class(object) <- data.class(object)
  }
  UseMethod("survnnet")
}

survnnet.formula <- function(formula, data = NULL, weights, ...,
			subset, na.action = na.fail, contrasts=NULL)
{
  # Formula interface to survnnet.default.
  #
  # Builds the model frame, drops the intercept column (the network's bias
  # unit plays that role) and calls survnnet.default with the case weights
  # and '...'.  A factor response is coded as a 0/1 vector (two levels,
  # entropy fit) or as an indicator matrix (more levels, softmax fit);
  # empty levels are dropped with a warning first.
  #
  # Returns the survnnet.default result with terms, coefnames, call and,
  # for factor responses, lev added; class c("survnnet.formula", "survnnet").

  # indicator (one-hot) matrix for a factor: one column per level
  class.ind <- function(cl)
  {
    n <- length(cl)
    x <- matrix(0, n, length(levels(cl)))
    x[(1:n) + n * (as.vector(unclass(cl)) - 1)] <- 1
    dimnames(x) <- list(names(cl), levels(cl))
    x
  }
  m <- match.call(expand.dots = FALSE)
  if(is.matrix(eval(m$data, parent.frame())))
      m$data <- as.data.frame(data)
  m$... <- m$contrasts <- NULL
  m[[1]] <- as.name("model.frame")
  m <- eval(m, parent.frame())
  Terms <- attr(m, "terms")
  x <- model.matrix(Terms, m, contrasts)
  xint <- match("(Intercept)", dimnames(x)[[2]], nomatch=0)
  if(xint > 0) x <- x[, -xint, drop= FALSE] # Bias term is used for intercepts
  w <- model.extract(m, weights)
  if(length(w) == 0) w <- rep(1, nrow(x))
  y <- model.extract(m, response)
  if(is.factor(y)) {
    lev <- levels(y)
    counts <- table(y)
    if(any(counts == 0)) {
      warning(paste("group(s)", paste(lev[counts == 0], collapse=" "),
		    "are empty"))
      y <- factor(y, levels=lev[counts > 0])
      # keep 'lev' in step with the levels actually present; otherwise the
      # binary/softmax choice and res$lev refer to the dropped levels and
      # class predictions are mapped to the wrong labels
      lev <- levels(y)
    }
    if(length(lev) == 2) {
      y <- as.vector(unclass(y)) - 1
      res <- survnnet.default(x, y, w, entropy = TRUE, ...)
    } else {
      y <- class.ind(y)
      res <- survnnet.default(x, y, w, softmax = TRUE, ...)
    }
    res$lev <- lev
  } else res <- survnnet.default(x, y, w, ...)
  res$terms <- Terms
  res$coefnames <- dimnames(x)[[2]]
  # the hazard model appends survival time as an extra network input;
  # isTRUE() copes with a numeric 0 or an NA model code
  if (isTRUE(res$model == 'hazard')) res$coefnames <- c(res$coefnames,'stime')
  res$call <- match.call()
  class(res) <- c("survnnet.formula", "survnnet")
  res
}

# Fit a feed-forward neural network with the C routines of package
# "survnnet"; when 'model' is given the network is fitted as a survival
# model.
#
# Arguments (as used in the body below):
#   x           input matrix, one row per case.
#   y           target matrix.  For survival models it must have exactly
#               two columns: survival time followed by status.
#   weights     non-negative case weights (default all 1).
#   size        number of hidden units.
#   Wts         initial weights; by default uniform on [-rang, rang]
#               (all 0 when rang <= 0).
#   mask        weights with a non-zero mask entry are counted as variable;
#               the mask is passed to the optimizer "survdovm".  Its default
#               is only evaluated after the weights exist.
#   linout      linear output units.  entropy: entropy fitting, refused
#               together with linout.  softmax and censored both force
#               linear outputs.
#   skip        add skip-layer connections from inputs to outputs.
#   rang        range of the random initial weights.
#   decay       weight decay; bias weights get decay / bias.decay.
#   maxit, trace  passed to the optimizer.
#   Hess        if TRUE, also store net$Hessian from survnnet.Hess().
#   MaxNWts     maximum number of weights allowed.
#   model       NULL (ordinary network) or one of "exp", "llog", "lnorm",
#               "weibull", "hazard", "lnormvar".
#   alpha       extra model parameter: defaults to 0.1 for llog, lnorm and
#               weibull (where it is optimized together with the weights)
#               and to 0 otherwise.
#   Nintervals, varWt  passed unchanged to the C code (set_survnet).
#
# Returns a list of class "survnnet" with the network structure (n,
# nunits, nconn, conn, nsunits), fitting options, the fitted weights (wts),
# the optimizer's criterion value (value), fitted.values, x, y and, where
# applicable, alpha, lev and Hessian.
survnnet.default <-
function(x, y, weights, size, Wts, mask=rep(1, length(wts)),
	 linout = FALSE, entropy = FALSE, softmax = FALSE, censored = FALSE,
         skip = FALSE,
	 rang=0.7, decay=0, bias.decay=1, maxit=100, Hess = FALSE,
         trace = TRUE, MaxNWts=1000, model=NULL, alpha=NULL, Nintervals=20,
         varWt=10)
{
  net <- NULL
  x <- as.matrix(x)
  y <- as.matrix(y)
  if(any(is.na(x))) stop("missing values in x")
  if(any(is.na(y))) stop("missing values in y")
  if(dim(x)[1] != dim(y)[1]) stop("nrows of x and y must match")
  if(linout && entropy) stop("entropy fit only for logistic units")
  if(softmax) {linout <- TRUE; entropy <- FALSE}
  if(censored) {linout <- TRUE; entropy <- FALSE; softmax <- TRUE}
  # survival model: split y into survival time and status
  if(!is.null(model))
    {
      linout <- TRUE;
      entropy <- FALSE;
      softmax <- FALSE
      if (ncol(y)!=2)
	stop("target must have two columns - survival time followed by status")
      stime <- y[,1]
      y <- y[,2,drop= FALSE]
      if(any(is.na(stime))) stop("missing values in survival times")
      if(dim(x)[1] != length(stime))
        stop("nrows of x and length of stime must match")
      if(is.null(alpha) && model!='exp' && model !='hazard' &&
         model!='lnormvar')
        alpha<- 0.1
      # stored as a factor; its integer code is what the C code receives.
      # NOTE(review): an unrecognised model string becomes NA here.
      model <- factor(model,levels=c("exp","llog","lnorm","weibull",
                             "hazard","lnormvar"))
      # hazard model: survival time becomes an extra network input
      if (model=='hazard')
          x<- cbind(x,stime)
      # lnormvar: a target column of 1s is appended
      if (model=='lnormvar')
          y<- cbind(y,rep(1,nrow(y)))
    }
  else
    {
    # ordinary network: model code 0 and dummy unit survival times
    model <- 0
    stime <- rep(1,nrow(x))
  }
  if(is.null(alpha)) alpha<- 0
  # layer sizes (inputs, hidden, outputs); the extra unit is the bias unit 0
  net$n <- c(dim(x)[2], size, dim(y)[2])
  net$nunits <- 1 + sum(net$n)
  net$nconn <- rep(0, net$nunits+1)
  net$conn <- numeric(0)
  net <- norm.survnet(net)
  if(skip) net <- add.survnet(net, seq(1,net$n[1]),
			  seq(1+net$n[1]+net$n[2], net$nunits-1))
  if((nwts <- length(net$conn))==0) stop("No weights to fit")
  if(nwts > MaxNWts)
    stop(paste("Too many (", nwts, ") weights", sep=""))
  # with linear outputs the output units are excluded from nsunits
  nsunits <- net$nunits
  if(linout) nsunits <- net$nunits - net$n[3]
  net$nsunits <- nsunits
  net$decay <- decay
  net$entropy <- entropy
  net$softmax <- softmax
  net$censored <- censored
  net$model <- model
  net$Nintervals<- Nintervals
  net$varWt<- varWt
  if(missing(Wts)) {
    if(rang > 0) wts <- runif(nwts, -rang, rang)
    else wts <- rep(0, nwts)
    # survival models: shift the last bias weight (conn == 0) by
    # -median(log(stime)).  The weight is located by value, so any other
    # weight with exactly the same value is shifted too.
    if(model!=0) wts[wts==rev(wts[net$conn==0])[1]] <-
      wts[wts==rev(wts[net$conn==0])[1]] - median(log(stime))
  } else wts <- Wts
  # the default mask is evaluated here, after wts has been created
  if(length(mask) != length(wts)) stop("incorrect length of mask")
  if(length(wts) != nwts) stop("weights vector of incorrect length")
     # alpha is appended to the parameter vector for these models, so the
     # mask gets one more (variable) entry
     if(model=='llog'|model=='lnorm'|model=='weibull') mask <- c(mask,1)
  if(trace) {
    cat("# weights: ", length(wts))
    nw <- sum(mask != 0)
    if(nw < length(wts)) cat(" (", nw, " variable)\n",sep="")
    else cat("\n")
  }
  # per-weight decay; bias weights (conn == 0) get decay/bias.decay.
  # net$decay keeps the scalar unless the values actually differ.
  if(length(decay) == 1) decay <- rep(decay, length(wts))
  decay <- ifelse(net$conn==0, decay/bias.decay, decay)
  if(any(decay!=decay[1])) net$decay <- decay
  # hand the network definition to the C code.
  # NOTE(review): an error between here and unset_survnet (e.g. the
  # weights check below) skips the unset calls; consider on.exit().
  .C("set_survnet",
     as.integer(net$n),
     as.integer(net$nconn),
     as.integer(net$conn),
     as.double(decay),
     as.integer(nsunits),
     as.integer(entropy),
     as.integer(softmax),
     as.integer(censored),
     as.integer(model),
     as.single(alpha),
     as.integer(Nintervals),
     as.single(varWt),
     PACKAGE="survnnet")
  ntr <- dim(x)[1]
  nout <- dim(y)[2]
  if(missing(weights)) weights <- rep(1, ntr)
  if(length(weights) != ntr || any(weights < 0))
    stop("invalid weights vector")
  # training data: inputs and targets column-bound, case weights, times
  z <- .C("set_survtrain",
	  as.integer(ntr),
	  as.single(cbind(x,y)),
	  as.single(weights),
          as.single(stime),
          PACKAGE="survnnet")
  # optimize; for llog/lnorm/weibull alpha is optimized as an extra final
  # parameter
  if (model=="llog"|model=='lnorm'|model=='weibull')
   tmp <- .C("survdovm",
             as.integer(length(wts)+1),
             wts=as.double(c(wts,alpha)),
             val=double(1),
             as.integer(maxit),
             as.integer(trace),
             as.integer(mask),
             PACKAGE="survnnet")
   else
     tmp <- .C("survdovm",
               as.integer(length(wts)),
               wts=as.double(wts),
               val=double(1),
               as.integer(maxit),
               as.integer(trace),
               as.integer(mask),
               PACKAGE="survnnet"
	    )
  .C("unset_survtrain", PACKAGE="survnnet")
  net$value <- tmp$val
  net$wts <- tmp$wts
  # split the fitted alpha back off the end of the parameter vector
  if (model=='llog'|model=='lnorm'|model=='weibull')
    {
      net$alpha <- as.double(tmp$wts[length(tmp$wts)])
      net$wts <- tmp$wts[1:(length(tmp$wts)-1)]
    }
  # network outputs on the training data
  tmp <- matrix(.C("survnntest",
		   as.integer(ntr),
		   as.single(cbind(x,y)),
		   tclass = single(ntr*nout),
		   as.double(net$wts),
                   as.integer(0),
                   as.integer(0),
                   single(ntr*nout),
                   double(ntr*Nintervals),
                   PACKAGE="survnnet"
		   )$tclass,  ntr, nout)
  dimnames(tmp) <- list(dimnames(x)[[1]], NULL)
  net$fitted.values <- tmp
  .C("unset_survnet", PACKAGE="survnnet")
  if(entropy) net$lev <- c("0","1")
  if(softmax) net$lev <- dimnames(y)[[2]]
  net$call <- match.call()
  net$x <- unclass(x)
  # store y as (survival time, status) for survival models, dropping the
  # column of 1s added for lnormvar
  if (model !=0)
   y <- cbind(stime,y)
  if (model == 'lnormvar')
      y<- y[,-3]
     net$y <- y
 # browser()
  if(Hess) net$Hessian <- survnnet.Hess(net, x, y, weights)
  structure(net, class="survnnet")
}

predict.survnnet.formula <- function(object, newdata, type=c("raw", "class"), loglike=FALSE,stime=NULL, ...)
{
  # Predict from a survnnet model fitted through the formula interface.
  #
  # object   a "survnnet.formula" fit.
  # newdata  data frame of predictors; if missing, the stored fitted values
  #          are used.
  # type     "raw" for network outputs, "class" for level labels (requires
  #          object$lev).
  # loglike, stime, ...  passed on to predict.survnnet.
  if(!inherits(object, "survnnet.formula"))
    stop("object not of class survnnet.formula")
  type <- match.arg(type)
  if(missing(newdata)) {
    # return() the switch result for every type: previously the "class"
    # value was discarded and execution fell through to model.matrix()
    # without any data
    z <- object$fitted.values
    return(switch(type,
                  raw = z,
                  class = {
                    if(is.null(object$lev)) stop("inappropriate fit for class")
                    if(ncol(z) > 1) object$lev[max.col(z)]
                    else object$lev[1 + (z > 0.5)]
                  }))
  }
  x <- model.matrix(delete.response(object$terms), newdata)
  xint <- match("(Intercept)", dimnames(x)[[2]], nomatch=0)
  if(xint > 0) x <- x[, -xint, drop= FALSE] # Bias term is used for intercepts
  predict.survnnet(object=object, x=x, type=type,loglike=loglike,stime=stime, ...)
}

predict.survnnet <- function(object, x, type=c("raw","class"),loglike=FALSE,stime=NULL,allHaz=FALSE, ...)
{
  # Predict from a fitted "survnnet" object.
  #
  # object   a "survnnet" fit.
  # x        new inputs: a matrix, or a vector treated as a single row.  If
  #          missing, the stored fitted values are used.
  # type     "raw" for network outputs, "class" for labels from object$lev.
  # loglike, allHaz  request the extra "like"/"haz" results of the C
  #          routine; only valid for the "hazard" model and only with new
  #          data.  When either is TRUE a raw prediction is
  #          list(z = outputs, zz = like, zzh = haz).
  # stime    survival times, one per row of x; required for "hazard".
  if(!inherits(object, "survnnet")) stop("object not of class survnnet")
  # defaults for fits that lack these components
  if(is.null(object$model)) object$model <- 0
  if(is.null(object$alpha)) object$alpha <- 1.0
  type <- match.arg(type)
  # isTRUE() handles both the numeric 0 (no survival model) and factor codes
  hazard_model <- isTRUE(object$model == 'hazard')
  if(missing(x)) {
    # zz/zzh are only produced by the C routine on new data; this path
    # used to fail with "object 'zz' not found"
    if(loglike || allHaz)
      stop("loglike and allHaz require new data 'x'")
    z <- object$fitted.values
  } else {
    if(is.null(dim(x))) dim(x) <- c(1, length(x))
    if(any(is.na(x))) stop("missing values in x")
    # validate before set_survnet so that a failed check cannot leave the
    # C-side network set up
    if (loglike && !hazard_model)
      stop('loglike options only applicable to hazard model')
    if (allHaz && !hazard_model)
      stop('allHaz options only applicable to hazard model')
    x <- as.matrix(x)
    ntr <- dim(x)[1]
    nout <- object$n[3]
    if (hazard_model) {
      if (is.null(stime))
        stop("For Hazard model must give stime")
      if (length(stime) != ntr)
        stop("Length of stime must equal nrow(x)")
      # the hazard network takes survival time as its last input
      x <- cbind(x, stime)
    }
    # objects with exactly 10 components are treated as softmax = FALSE
    if(length(object)==10) object$softmax <- FALSE
    .C("set_survnet",
       as.integer(object$n),
       as.integer(object$nconn),
       as.integer(object$conn),
       as.double(object$decay),
       as.integer(object$nsunits),
       as.integer(0),
       as.integer(object$softmax),
       as.integer(object$censored),
       as.integer(object$model),
       as.single(object$alpha),
       as.integer(object$Nintervals),
       as.single(object$varWt),
       PACKAGE="survnnet"
       )
    # release the C-side network even if the prediction call fails
    on.exit(.C("unset_survnet", PACKAGE="survnnet"), add = TRUE)
    res <- .C("survnntest",
              as.integer(ntr),
              as.single(x),
              tclass = single(ntr*nout),
              as.double(object$wts),
              as.integer(loglike),
              as.integer(allHaz),
              like = single(ntr*nout),
              haz = double(ntr*object$Nintervals),
              PACKAGE="survnnet"
              )
    zzh <- matrix(res$haz, ntr, object$Nintervals)
    zz <- matrix(res$like, ntr, nout)
    z <- matrix(res$tclass, ntr, nout)
    dimnames(z) <- list(dimnames(x)[[1]], dimnames(object$fitted.values)[[2]])
  }
  switch(type,
         raw = if(loglike || allHaz) list(z=z, zz=zz, zzh=zzh) else z,
         class = {
           if(is.null(object$lev)) stop("inappropriate fit for class")
           if(ncol(z) > 1) object$lev[max.col(z)]
           else object$lev[1 + (z > 0.5)]
         })
}

eval.survnn <- function(wts)
{
  # Evaluate the fitting criterion and its gradient at 'wts' through the
  # C routine "survdfunc" (presumably using the network/training state set
  # up earlier on the C side -- confirm against the C sources).
  # Returns the criterion value carrying the gradient as attribute
  # "gradient".
  out <- .C("survdfunc",
            as.double(wts),
            df = double(length(wts)),
            fp = as.double(1),
            PACKAGE = "survnnet")
  structure(out$fp, gradient = out$df)
}
add.survnet <- function(net, from, to)
{
  # Connect every unit in 'from' to every unit in 'to'.
  #
  # Connectivity is stored in compressed form:
  #   net$conn   source unit of each connection (0 = bias), grouped by
  #              destination unit;
  #   net$nconn  offsets into conn: the connections into unit i occupy
  #              positions nconn[i+1]+1 .. nconn[i+2].
  # A destination with no connections yet also gets a bias connection
  # (source 0) ahead of the new ones.
  # Returns 'net' with updated conn and nconn; an empty 'to' leaves it
  # unchanged.
  nconn <- net$nconn
  conn <- net$conn
  for(i in to){
    ns <- nconn[i+2]                   # end of unit i's block
    cadd <- from
    if(nconn[i+1] == ns) cadd <- c(0, from)
    # splice the new sources in after unit i's block; seq_len() also keeps
    # a one-element prefix (the old 'ns > 1' guard dropped it)
    rest <- if(length(conn) > ns) conn[(ns+1):length(conn)]
    conn <- c(conn[seq_len(ns)], cadd, rest)
    # every later unit's block moves right by the number of new connections
    later <- (i+2):(net$nunits+1)
    nconn[later] <- nconn[later] + length(cadd)
  }
  net$nconn <- nconn
  # assign 'conn' (always defined) rather than a loop-only variable, which
  # made an empty 'to' fail with "object not found"
  net$conn <- conn
  net
}

norm.survnet <- function(net)
{
  # Wire a fully connected single-hidden-layer network:
  # inputs -> hidden units, then hidden units -> outputs (add.survnet also
  # gives each newly connected unit a bias connection).  A network without
  # hidden units is returned unchanged.
  sizes <- net$n
  if (sizes[2] <= 0) return(net)
  last_input <- sizes[1]
  last_hidden <- last_input + sizes[2]
  last_output <- last_hidden + sizes[3]
  hidden <- (last_input + 1):last_hidden
  net <- add.survnet(net, 1:last_input, hidden)
  add.survnet(net, hidden, (last_hidden + 1):last_output)
}

survnnet.Hess <- function(net, x, y, weights)
{
  # Hessian of the fitting criterion of a "survnnet" fit at net$wts,
  # computed by the C routine "survnnHessian".
  #
  # net      a "survnnet" fit (as returned by survnnet.default).
  # x        training inputs.
  # y        training targets; for survival models two columns, survival
  #          time followed by status.
  # weights  non-negative case weights (default all 1).
  # Returns an nw x nw matrix, where nw is the number of weights plus one
  # for the llog, lnorm and weibull models (their extra alpha parameter).
  x <- as.matrix(x)
  y <- as.matrix(y)
  if(dim(x)[1] != dim(y)[1]) stop("dims of x and y must match")
  if(is.null(net$model)) net$model <- 0
  if(is.null(net$alpha)) net$alpha <- 1.0
  if (net$model != 0) {
    if (dim(y)[2] != 2) stop("y must have two columns - time and event")
    stime <- y[,1]
    y <- y[,2,drop= FALSE]
    if (net$model=='lnormvar')
      y <- cbind(y,rep(1,nrow(y)))
  } else {
    # non-survival fits: dummy unit times, as survnnet.default uses.  This
    # used to be NULL, which handed the C code a zero-length time vector.
    stime <- rep(1, nrow(x))
  }
  ntr <- dim(x)[1]
  # validate before any C state is set up
  if(missing(weights)) weights <- rep(1, ntr)
  if(length(weights) != ntr || any(weights < 0))
    stop("invalid weights vector")
  decay <- net$decay
  if(length(decay) == 1) decay <- rep(decay, length(net$wts))
  # factor codes 2..4 are llog, lnorm and weibull, which carry alpha as an
  # extra parameter
  if(!is.na(net$model) && unclass(net$model) > 1 && unclass(net$model) < 5)
    nw <- length(net$wts)+1
  else
    nw <- length(net$wts)
  .C("set_survnet",
     as.integer(net$n),
     as.integer(net$nconn),
     as.integer(net$conn),
     as.double(decay),
     as.integer(net$nsunits),
     as.integer(net$entropy),
     as.integer(net$softmax),
     as.integer(net$censored),
     as.integer(net$model),
     as.single(net$alpha),
     as.integer(net$Nintervals),
     as.single(net$varWt),
     PACKAGE="survnnet"
     )
  # undo the C setup on every exit path, training data before network
  train_set <- FALSE
  on.exit({
    if (train_set) .C("unset_survtrain", PACKAGE="survnnet")
    .C("unset_survnet", PACKAGE="survnnet")
  }, add = TRUE)
  .C("set_survtrain",
     as.integer(ntr),
     as.single(cbind(x,y)),
     as.single(weights),
     as.single(stime),
     PACKAGE="survnnet"
     )
  train_set <- TRUE
  matrix(.C("survnnHessian", as.double(net$wts), H = double(nw*nw),
            PACKAGE="survnnet")$H, nw, nw)
}

print.survnnet <- function(x, ...)
{
  # Print a short description of a "survnnet" fit: architecture, number
  # of weights, inputs/outputs (formula fits only) and fitting options.
  # Returns 'x' invisibly, with defaults filled in for missing components.
  if (!inherits(x, "survnnet"))
    stop("Not a legitimate survival neural net fit")
  # fill in components that older fits may lack
  if (length(x) == 10) x$softmax <- FALSE
  if (is.null(x$model)) x$model <- NA
  if (is.null(x$censored)) x$censored <- FALSE
  cat("a ", x$n[1], "-", x$n[2], "-", x$n[3], " network", sep = "")
  cat(" with", length(x$wts), "weights\n")
  if (length(x$coefnames)) {
    cat("inputs:", x$coefnames, "\noutput(s):",
        deparse(formula(x)[[2]]), "\n")
  }
  per_unit <- diff(x$nconn)
  # the last unit having more than bias + hidden inputs means skip-layer
  # connections were added
  opts <- c(if (per_unit[length(per_unit)] > x$n[2] + 1) " skip-layer connections ",
            if (x$nunits > x$nsunits && !x$softmax) " linear output units ",
            if (x$entropy) " entropy fitting ",
            if (x$softmax) " softmax modelling ",
            if (x$censored) " with censoring ")
  cat("options were -", opts, sep = "")
  if (!is.na(x$model) && unclass(x$model) > 0)
    cat("model=", as.vector(x$model), sep = "")
  if (any(x$decay > 0)) {
    if (any(x$decay != x$decay[1])) {
      cat("\n decay variable ")
    } else {
      cat("\n decay=", x$decay[1], sep = "")
    }
  }
  cat("\n")
  invisible(x)
}

coef.survnnet <- function(object, ...)
{
  # Weights of a "survnnet" fit, named "<from>-><to>" with units labelled
  # b (bias), i1.. (inputs), h1.. (hidden) and o or o1.. (outputs).
  # Note: the values come back as character strings, formatted after
  # rounding to 2 decimal places.
  sizes <- object$n
  labels <- c("b",
              paste0("i", seq_len(sizes[1])),
              if (sizes[2] > 0) paste0("h", seq_len(sizes[2])),
              if (sizes[3] > 1) paste0("o", seq_len(sizes[3])) else "o")
  # destination unit of each connection, from the per-unit counts
  dest_unit <- rep(seq_len(object$nunits) - 1, diff(object$nconn))
  out <- format(round(object$wts, 2))
  names(out) <- paste(labels[1 + object$conn], labels[1 + dest_unit],
                      sep = "->")
  out
}

summary.survnnet <- function(object, ...)
{
  # The summary of a "survnnet" fit is the fit itself, re-classed as
  # "summary.survnnet" so that print.summary.survnnet lists every weight.
  `class<-`(object, "summary.survnnet")
}

print.summary.survnnet <- function(x, ...)
{
  # Print a "summary.survnnet": architecture, options, model settings and
  # every weight labelled "<from>-><to>" (b = bias, i = input, h = hidden,
  # o = output), one block per destination unit.
  # Returns 'x' invisibly, with defaults filled in for missing components.
  if(length(x)==10) x$softmax <- FALSE
  if(is.null(x$model))x$model<- NA
  if(is.null(x$censored))x$censored<- FALSE
  cat("a ",x$n[1],"-",x$n[2],"-",x$n[3]," network", sep="")
  cat(" with", length(x$wts),"weights\n")
  cat("options were -")
  tconn <- diff(x$nconn)
  if(tconn[length(tconn)] > x$n[2]+1) cat(" skip-layer connections ")
  if(x$nunits > x$nsunits && !x$softmax) cat(" linear output units ")
  if(x$entropy) cat(" entropy fitting ")
  if(x$softmax) cat(" softmax modelling ")
  if(x$censored) cat(" with censoring ")
  if(!is.na(x$model) && unclass(x$model)>0)
      cat( "\nmodel=",as.vector(x$model), sep="")
  if(!is.null(x$alpha)) cat(" alpha=", round(x$alpha,2),sep="")
  if(any(x$decay > 0))
    { if(any(x$decay!=x$decay[1])) cat("\n decay variable ")
    else
      cat("\n decay=", x$decay[1], sep="")
    }
  cat("\n")
  # isTRUE(): x$model is NA when the fit has no model component (set
  # above), and a bare NA == 'hazard' made if() fail
  if (isTRUE(x$model=='hazard')) cat ('Intervals used=',x$Nintervals,'\n')
  if (isTRUE(x$model=='lnormvar')) cat('Variance proportion=',x$varWt,'\n')
  wts <- format(round(x$wts,2))
  wm <- c("b", paste("i", seq(length=x$n[1]), sep=""))
  if(x$n[2] > 0) wm <- c(wm, paste("h", seq(length=x$n[2]), sep=""))
  if(x$n[3] > 1)
    wm <- c(wm, paste("o", seq(length=x$n[3]), sep=""))
  else wm <- c(wm, "o")
  # connection names: source label -> destination label
  names(wts) <- apply(cbind(wm[1+x$conn], wm[1+rep(1:x$nunits - 1, tconn)]),
		      1, function(pair)  paste(pair, collapse = "->"))
  # one printed block per destination unit
  lapply(split(wts,rep(1:x$nunits, tconn)),
	 function(w) print(w, quote= FALSE))
  invisible(x)
}

phtnnet <- function(object, ...)
{
  # S3 generic for time-zone survival networks; dispatches on the class
  # of 'object' (e.g. phtnnet.formula or phtnnet.default).
  # The is.null(class()) guard comes from S: in R class() always yields a
  # value, so the assignment never runs there.
  if (is.null(class(object))) {
    class(object) <- data.class(object)
  }
  UseMethod("phtnnet")
}

phtnnet.formula <- function(formula, data = NULL, ...,
			subset, na.action = na.fail, contrasts=NULL)
{
  # Formula interface to phtnnet.default.
  #
  # Builds the model frame, drops the intercept column (the network's bias
  # unit plays that role) and passes the predictors and the response to
  # phtnnet.default together with '...'.  A factor response must have
  # exactly two non-empty levels and is coded 0/1; empty levels are dropped
  # with a warning first.
  #
  # Returns the fit with terms, coefnames and call added; class
  # c("phtnnet.formula", "phtnnet", "phnnet").
  m <- match.call(expand.dots = FALSE)
  if(is.matrix(eval(m$data, parent.frame())))
      m$data <- as.data.frame(data)
  m$... <- m$contrasts <- NULL
  m[[1]] <- as.name("model.frame")
  m <- eval(m, parent.frame())
  Terms <- attr(m, "terms")
  x <- model.matrix(Terms, m, contrasts)
  xint <- match("(Intercept)", dimnames(x)[[2]], nomatch=0)
  if(xint > 0) x <- x[, -xint, drop= FALSE] # Bias term is used for intercepts
  y <- model.extract(m, response)
  if(is.factor(y)) {
    lev <- levels(y)
    counts <- table(y)
    if(any(counts == 0)) {
      warning(paste("group(s)", paste(lev[counts == 0], collapse=" "),
		    "are empty"))
      y <- factor(y, levels=lev[counts > 0])
      # count only the levels that remain: a factor with two non-empty
      # levels plus empty ones is binary and must not be rejected
      lev <- levels(y)
    }
    if(length(lev) == 2)
      y <- as.vector(unclass(y)) - 1
    else
      stop("only two-level factors allowed")
  }
  res <- phtnnet.default(x, y, ...)
  res$terms <- Terms
  res$coefnames <- dimnames(x)[[2]]
  res$call <- match.call()
  class(res) <- c("phtnnet.formula", "phtnnet","phnnet")
  res
}

# Fit a neural network for survival data split into follow-up time zones,
# using the C routines of package "survnnet".  Cases are processed in
# decreasing order of follow-up time, and the output unit's bias weight
# starts at 0 and is fixed by the default mask.  This is presumably a
# Cox-type proportional-hazards network with zone (time) effects -- TODO
# confirm against the C code.
#
# Arguments (as used in the body below):
#   x           covariate matrix.  Without 'breakpts' its last column must
#               already hold the zone variable.
#   status      event status, or a two-column matrix (follow-up time,
#               status); with two columns the rows are sorted by
#               decreasing time.
#   size        number of hidden units.
#   Wts         initial weights (default: uniform on [-rang, rang] with the
#               output bias set to 0; all 0 when rang <= 0).
#   mask        non-zero entries mark weights that are optimized; the
#               default fixes every weight whose initial value is 0.
#   skip        add skip-layer connections from inputs to the output.
#   decay, bias.decay  weight decay; bias weights get decay / bias.decay.
#   breakpts    cut points for follow-up time; the zone index, scaled to
#               [0, 1], is appended to x as column "RFS".
#   MaxNWts     maximum number of weights allowed.
#   maxit, trace  passed to the C routine "phtdoit".
#   dohaz       if TRUE store net$haz as returned by the C code.
#   dohess      if TRUE store net$hess, symmetrised from the C result.
#
# Returns a list of class c("phtnnet", "phnnet") with the network
# structure, fitted weights (wts), criterion value (value), zone
# information (zone, nzone, uzone) and fitted.values (one column per zone,
# rows in the original case order).
phtnnet.default <- function(x, status, size, Wts, mask, skip = FALSE,
                            rang = 0.7, decay = 0, bias.decay=1, breakpts,
                            MaxNWts=1000, maxit = 100, trace = TRUE,  dohaz = FALSE,
                            dohess = FALSE)
{
  net <- NULL
  x <- as.matrix(x)
  status <- as.matrix(status)
  xx <- x
  yy <- status
  if(any(is.na(x)))
    stop("missing values in x")
  if(dim(x)[1] != nrow(status))
    stop("status and x must have equal lengths")
 # need to add column to x- matrix
  if(!missing(breakpts))
    {
      if (ncol(status) !=2)
	stop("must give follow-up time if break points are given")
      # 0-based zone index of each follow-up time
      zone <- as.numeric(cut(status[,1],breaks=breakpts))-1
      if(any(is.na(zone))) stop("`breakpts' must cover the data")
      if (max(zone)>0)
	RFS <- zone/max(zone)
      else
	RFS <- zone
      x <- cbind(x,RFS)
    }
  else # assume x matrix contains extra column
      zone <- match(x[,ncol(x)],sort(unique(x[,ncol(x)])))-1
  # 'ind' is the row permutation used; it restores the original order below
  if (ncol(status) ==2)
    { # need sorting
      ind <- rev(order(status[,1]))
      x <- x[ind,, drop = FALSE]
      status <- status[ind,2]
      zone <- zone[ind]
    }
  else
    {
      status <- as.vector(status)
      ind <- 1:length(status)
    }
  # a single output unit; unit 0 is the bias
  net$n <- c(dim(x)[2], size, 1)
  net$nunits <- 1 + sum(net$n)
  net$entropy <- TRUE
  net$nconn <- rep(0, net$nunits + 1)
  net$conn <- numeric(0)
  net <- norm.survnet(net)
  if(skip)
    net <- add.survnet(net, seq(1, net$n[1]), seq(1 + net$n[1] + net$
						 n[2], net$nunits - 1))
  if((nbrwts <- length(net$conn))==0) stop("No weights to fit")
  if(nbrwts > MaxNWts)
    stop(paste("Too many (", nbrwts, ") weights", sep=""))
 # nsunits <- net$nunits
 # if(linout)
    nsunits <- net$nunits - net$n[3]
  net$nsunits <- nsunits
  # NOTE(review): the rep() result is overwritten straight away by the
  # ifelse() below, which recycles the 'decay' argument itself
  if (length(decay)==1)
    net$decay <- rep(decay,length(net$conn))
  net$decay <- ifelse(net$conn==0,decay/bias.decay,decay)
  nbrwts <- length(net$conn)
  if(length(net$conn) == 0)
    stop("No weights to fit")
  # initial weights: nbrwts-1 random values with a 0 inserted where the
  # output unit's bias sits (after the (n[1]+1)*n[2] input->hidden weights,
  # or first when there is no hidden layer)
  if(missing(Wts))
    if(rang > 0)
      {
	wts <- runif(nbrwts-1,  - rang, rang)
	#wts <- c(wts[1:(nbrwts - size - 1)], 0,
		# wts[(nbrwts - size):(nbrwts - 1)])
	if(net$n[2]>0)
	  wts <- c(wts[1:((net$n[1]+1)*net$n[2])], 0,
		   wts[((net$n[1]+1)*net$n[2]+1):(nbrwts - 1)])
	  else wts <- c(0,wts)
      }
      else wts <- rep(0, length(net$conn))
    else wts <- Wts
  if(!missing(mask) && length(mask) != length(wts))
    stop("incorrect length of mask")
  if(length(wts) != length(net$conn))
    stop("weights vector of incorrect length")
  # NOTE(review): the default mask fixes every zero weight, so with
  # rang = 0 (all weights 0) nothing is optimized, and zeros in a
  # user-supplied Wts are fixed as well
  if(missing(mask))
    {
      mask <- ifelse(wts==0,0,1)
    }
  # zones are read from the last column of x
  nzone <- length(unique(x[,ncol(x)]))
  uzone <- sort(unique(x[,ncol(x)]))
  # store zone in the original case order
  net$zone[ind] <- zone
  net$nzone <- nzone
  net$uzone <- uzone
  # NOTE(review): zone is 0-based, so zone == nzone is out of range as well
  if(any(zone>nzone)) stop("zone out of range")
  if(trace) {
    cat("# weights: ", length(wts))
    nw <- sum(mask != 0)
    if(nw < length(wts))
      cat(" (", nw, " variable)\n", sep = "")
      else cat("\n")
  }
  .C("set_phtnet",
     as.integer(net$n),
     as.integer(net$nconn),
     as.integer(net$conn),
     as.integer(nsunits),
     as.single(net$decay),
     PACKAGE="survnnet"
     )
  ntr <- dim(x)[1]
  nout <- 1
  # fit the weights in C; haz and hess are filled when dohaz/dohess are set
  tmp <- .C("phtdoit",
	    as.integer(ntr),
	    as.single(x),
	    as.integer(status),
	    as.integer(length(wts)),
	    wts = as.double(wts),
	    val = double(1),
	    as.integer(maxit),
	    as.integer(trace),
	    as.integer(mask),
	    haz = double(ntr),
	    hess = double(length(wts) * length(wts)),
	    as.integer(dohaz),
	    as.integer(dohess),
	    as.integer(nzone),
	    as.integer(zone),
	    as.single(uzone),
             PACKAGE="survnnet")
  # NOTE(review): set_phttrain is not called here; presumably phtdoit sets
  # up the training data itself
  .C("unset_phttrain", PACKAGE="survnnet")
  net$value <- tmp$val
  net$wts <- tmp$wts
  if(dohaz)
    net$haz <- tmp$haz
  if(dohess) {
    hess <- matrix(tmp$hess, length(wts), length(wts))
    net$hess <- hess + t(hess) - diag(diag(hess))
  }
  # predictions for every case in every zone, returned to original order
  tmp <- matrix(.C("pred_phtnnet",
		   as.integer(ntr),
		   as.single(x),
		   as.double(net$wts),
		   pred = single(ntr*nzone),
		   as.integer(nzone),
		   as.single(uzone),
                   PACKAGE="survnnet")$pred, ntr, nzone)
  net$fitted.values <- matrix(0,ntr,nzone)
  net$fitted.values[ind,] <- tmp
  #if (exists('RFS')) net$x <- unclass(cbind(xx,RFS))
  #else
  #  net$x <- unclass(xx)
  #net$y <- yy
  .C("unset_phtnet", PACKAGE="survnnet")
  structure(net, class = c("phtnnet","phnnet"))
}

predict.phtnnet.formula <- function(object, newdata,  ...)
{
  # Predict from a "phtnnet" model fitted with the formula interface.
  # Without 'newdata' the stored fitted values are returned.  Otherwise the
  # predictors are rebuilt from object$terms, the intercept column is
  # dropped (the bias unit plays that role) and predict.phtnnet is called.
  if (!inherits(object, "phtnnet.formula"))
    stop("object not of class phtnnet.formula")
  if (missing(newdata)) return(object$fitted.values)
  mm <- model.matrix(delete.response(object$terms), newdata)
  icol <- match("(Intercept)", colnames(mm), nomatch = 0)
  if (icol > 0) mm <- mm[, -icol, drop = FALSE]
  predict.phtnnet(object = object, x = mm)
}


predict.phtnnet <- function(object, x, ...)
{
  # Predictions from a "phtnnet" fit, one column per time zone.
  #
  # object  a "phtnnet" fit.
  # x       covariates without the zone column: a matrix, or a vector
  #         treated as one row.  If missing, the stored fitted values are
  #         returned.
  # Returns an nrow(x) x object$nzone matrix.
  if(!inherits(object, "phtnnet"))
    stop("object not of class phtnnet")
  if(missing(x))
    z <- object$fitted.values
  else {
    if(is.null(dim(x)))
      dim(x) <- c(1, length(x))
    if(any(is.na(x)))
      stop("missing values in x")
    x <- as.matrix(x)
    ntr <- dim(x)[1]
    # placeholder zone column; the C routine is given the zone values in
    # object$uzone (presumably it fills this column per zone -- confirm)
    x <- cbind(x,rep(0,ntr))
    .C("set_phtnet",
       as.integer(object$n),
       as.integer(object$nconn),
       as.integer(object$conn),
       as.integer(object$nsunits),
       as.single(object$decay),
       PACKAGE="survnnet")
    # pair set_phtnet with unset_phtnet, as the fitting code does; it was
    # never released here before
    on.exit(.C("unset_phtnet", PACKAGE="survnnet"), add = TRUE)
    z <- .C("pred_phtnnet",
            as.integer(ntr),
            as.single(x),
            as.double(object$wts),
            pred = single(ntr*object$nzone),
	    as.integer(object$nzone),
	    as.single(object$uzone),
            PACKAGE="survnnet")
    z <- matrix(z$pred, ntr,object$nzone)
    dimnames(z) <- list(dimnames(x)[[1]], dimnames(object$fitted.values)[[2]])
    class(z) <- object$errfn
  }
  z
}

phtnnet.Hess <- function(net, x, status)
{
  # Hessian of the fitting criterion of a "phtnnet" fit at net$wts,
  # computed by the C routine "phtfunc2".
  #
  # net     a "phtnnet" fit (uses n, nconn, conn, nsunits, decay, wts,
  #         nzone, zone, uzone).
  # x       covariates without the zone column (matrix or single-column
  #         vector).
  # status  event status, or a two-column matrix (follow-up time, status);
  #         with two columns rows (and zones) are put in decreasing time
  #         order, as in phtnnet.default.
  # Returns a length(net$wts) x length(net$wts) matrix.
  # coerce first so that a plain vector works (nrow() of a vector is NULL)
  x <- as.matrix(x)
  # zero column in the position of the zone input
  x <- cbind(x, rep(0, nrow(x)))
  status <- as.matrix(status)
  zone <- net$zone
  if (ncol(status) == 2) {
    ind <- rev(order(status[,1]))
    x <- x[ind, , drop = FALSE]
    status <- status[ind,2]
    zone <- zone[ind]
  } else status <- as.vector(status)
  if(dim(x)[1] != length(status)) stop("dims of x and status must match")
  nw <- length(net$wts)
  ntr <- dim(x)[1]
  .C("set_phtnet",
     as.integer(net$n),
     as.integer(net$nconn),
     as.integer(net$conn),
     as.integer(net$nsunits),
     as.single(net$decay),
     PACKAGE="survnnet"
     )
  # undo the C setup on every exit path, training data before network
  train_set <- FALSE
  on.exit({
    if (train_set) .C("unset_phttrain", PACKAGE="survnnet")
    .C("unset_phtnet", PACKAGE="survnnet")
  }, add = TRUE)
  .C("set_phttrain",
     as.integer(ntr),
     as.single(x),
     as.integer(status),
     as.integer(net$nzone),
     as.integer(zone),
     as.single(net$uzone),
     PACKAGE="survnnet")
  train_set <- TRUE
  matrix(.C("phtfunc2", as.double(net$wts), H = double(nw*nw),
            PACKAGE="survnnet")$H, nw, nw)
}


phnnet <- function(object, ...)
{
  # S3 generic for survival networks; dispatches on the class of 'object'
  # (e.g. phnnet.formula or phnnet.default).
  # The is.null(class()) guard comes from S: in R class() always yields a
  # value, so the assignment never runs there.
  if (is.null(class(object))) {
    class(object) <- data.class(object)
  }
  UseMethod("phnnet")
}

phnnet.formula <- function(formula, data = NULL, ...,
			subset, na.action = na.fail, contrasts=NULL)
{
  # Formula interface to phnnet.default.
  #
  # Builds the model frame, drops the intercept column (the network's bias
  # unit plays that role) and passes the predictors and the response to
  # phnnet.default together with '...'.  A factor response must have
  # exactly two non-empty levels and is coded 0/1; empty levels are dropped
  # with a warning first.
  #
  # Returns the fit with terms, coefnames and call added; class
  # c("phnnet.formula", "phnnet").
  m <- match.call(expand.dots = FALSE)
  if(is.matrix(eval(m$data, parent.frame())))
      m$data <- as.data.frame(data)
  m$... <- m$contrasts <- NULL
  m[[1]] <- as.name("model.frame")
  m <- eval(m, parent.frame())
  Terms <- attr(m, "terms")
  x <- model.matrix(Terms, m, contrasts)
  xint <- match("(Intercept)", dimnames(x)[[2]], nomatch=0)
  if(xint > 0) x <- x[, -xint, drop= FALSE] # Bias term is used for intercepts
  y <- model.extract(m, response)
  if(is.factor(y)) {
    lev <- levels(y)
    counts <- table(y)
    if(any(counts == 0)) {
      warning(paste("group(s)", paste(lev[counts == 0], collapse=" "),
		    "are empty"))
      y <- factor(y, levels=lev[counts > 0])
      # count only the levels that remain: a factor with two non-empty
      # levels plus empty ones is binary and must not be rejected
      lev <- levels(y)
    }
    if(length(lev) == 2)
      y <- as.vector(unclass(y)) - 1
    else
      stop("only two-level factors allowed")
  }
  res <- phnnet.default(x, y, ...)
  res$terms <- Terms
  res$coefnames <- dimnames(x)[[2]]
  res$call <- match.call()
  class(res) <- c("phnnet.formula", "phnnet")
  res
}

# Fit a neural network for survival data ("phnnet") using the C routines
# of package "survnnet".  Cases are processed in decreasing order of
# follow-up time, and the output unit's bias weight starts at 0 and is
# fixed by the default mask.  This is presumably a Cox-type
# proportional-hazards network -- TODO confirm against the C code.
#
# Arguments (as used in the body below):
#   x           covariate matrix.
#   status      event status, or a two-column matrix (follow-up time,
#               status); with two columns the rows are sorted by
#               decreasing time.
#   size        number of hidden units.
#   Wts         initial weights (default: uniform on [-rang, rang] with the
#               output bias set to 0; all 0 when rang <= 0).
#   mask        non-zero entries mark weights that are optimized; the
#               default fixes every weight whose initial value is 0.
#   skip        add skip-layer connections from inputs to the output.
#   decay, bias.decay  weight decay; bias weights get decay / bias.decay.
#   MaxNWts     maximum number of weights allowed.
#   maxit, trace  passed to the C routine "phdoit".
#   dohaz       if TRUE store net$haz as returned by the C code.
#   dohess      if TRUE store net$hess, symmetrised from the C result.
#
# Returns a list of class "phnnet" with the network structure, fitted
# weights (wts), criterion value (value) and fitted.values (a plain vector
# in the original case order).
phnnet.default <- function(x, status, size, Wts, mask, skip = FALSE,
		   rang = 0.7, decay=0, bias.decay=1, MaxNWts=1000,
		   maxit = 100, trace = TRUE, dohaz = FALSE, dohess = FALSE)
{
  net <- NULL
  x <- as.matrix(x)
  status <- as.matrix(status)
  xx <- x
  yy <- status
  # 'ind' is the row permutation used; it restores the original order below
  if (ncol(status) ==2)
    { # need sorting
      ind <- rev(order(status[,1]))
      x <- x[ind,, drop= FALSE]
      status <- status[ind,2]
    }
  else
    {
      status <- as.vector(status)
      ind <- 1:length(status)
     }
  if(any(is.na(x)))
    stop("missing values in x")
  if(dim(x)[1] != length(status))
    stop("status and x must have equal lengths")
  # a single output unit; unit 0 is the bias
  net$n <- c(dim(x)[2], size, 1)
  net$nunits <- 1 + sum(net$n)
  net$entropy <- TRUE
  net$nconn <- rep(0, net$nunits + 1)
  net$conn <- numeric(0)
  net <- norm.survnet(net)
  if(skip)
    net <- add.survnet(net, seq(1, net$n[1]), seq(1 + net$n[1] + net$
						 n[2], net$nunits - 1))
  nsunits <- net$nunits
  #if(linout)
    nsunits <- net$nunits - net$n[3]
  net$nsunits <- nsunits
  # NOTE(review): the rep() result is overwritten straight away by the
  # ifelse() below, which recycles the 'decay' argument itself
  if (length(decay)==1)
    net$decay <- rep(decay,length(net$conn))
  net$decay <- ifelse(net$conn==0,decay/bias.decay,decay)
  if((nbrwts <- length(net$conn))==0) stop("No weights to fit")
  if(nbrwts > MaxNWts)
    stop(paste("Too many (", nbrwts, ") weights", sep=""))
  if(length(net$conn) == 0)
    stop("No weights to fit")
  # initial weights: nbrwts-1 random values with a 0 inserted where the
  # output unit's bias sits (after the (n[1]+1)*n[2] input->hidden weights,
  # or first when there is no hidden layer)
  if(missing(Wts))
    if(rang > 0)
      {
	wts <- runif(nbrwts-1,  - rang, rang)
	#wts <- c(wts[1:(nbrwts - size - 1)], 0,
		# wts[(nbrwts - size):(nbrwts - 1)])
if(net$n[2]>0)
	wts <- c(wts[1:((net$n[1]+1)*net$n[2])], 0,
		 wts[((net$n[1]+1)*net$n[2]+1):(nbrwts - 1)])
else wts <- c(0,wts)
      }
      else wts <- rep(0, length(net$conn))
    else wts <- Wts
  if(!missing(mask) && length(mask) != length(wts))
    stop("incorrect length of mask")
  if(length(wts) != length(net$conn))
    stop("weights vector of incorrect length")
  # NOTE(review): the default mask fixes every zero weight, so with
  # rang = 0 (all weights 0) nothing is optimized, and zeros in a
  # user-supplied Wts are fixed as well
  if(missing(mask))
    {
      mask <- ifelse(wts==0,0,1)
    }
  if(trace) {
    cat("# weights: ", length(wts))
    nw <- sum(mask != 0)
    if(nw < length(wts))
      cat(" (", nw, " variable)\n", sep = "")
      else cat("\n")
  }
  .C("set_phtnet",
     as.integer(net$n),
     as.integer(net$nconn),
     as.integer(net$conn),
     as.integer(nsunits),
     as.single(net$decay),
     PACKAGE="survnnet"
     )
  ntr <- dim(x)[1]
  nout <- 1
  # fit the weights in C; haz and hess are filled when dohaz/dohess are set
 tmp <- .C("phdoit",
	    as.integer(ntr),
	    as.single(x),
	    as.integer(status),
	    as.integer(length(wts)),
	    wts = as.double(wts),
	    val = double(1),
	    as.integer(maxit),
	    as.integer(trace),
	    as.integer(mask),
	    haz = double(ntr),
	    hess = double(length(wts) * length(wts)),
	    as.integer(dohaz),
	    as.integer(dohess),
           PACKAGE="survnnet")
  # NOTE(review): set_phtrain is not called here; presumably phdoit sets up
  # the training data itself
  .C("unset_phtrain", PACKAGE="survnnet")
  net$value <- tmp$val
  net$wts <- tmp$wts
  if(dohaz)
    net$haz <- tmp$haz
  if(dohess) {
    hess <- matrix(tmp$hess, length(wts), length(wts))
    net$hess <- hess + t(hess) - diag(diag(hess))
  }
  # network outputs for the (sorted) training cases
  tmp <- matrix(.C("pred_phnnet",
		   as.integer(ntr),
		   as.single(x),
		   as.double(net$wts),
		   pred = single(ntr),
                   PACKAGE="survnnet")$pred, ntr, nout)
  dimnames(tmp) <- list(dimnames(x)[[1]], "Output")
  #net$x <- unclass(xx)
  #net$y <- yy
  # net$fitted.values starts as NULL, so this builds a plain vector in the
  # original case order
  net$fitted.values[ind] <- tmp
  .C("unset_phtnet", PACKAGE="survnnet")
  structure(net, class = c("phnnet"))
}


print.phnnet <- function(x, ...)
{
  # Print a short description of a "phnnet" (or "phtnnet") fit:
  # architecture, number of weights, inputs/outputs (formula fits only)
  # and options.  Returns 'x' invisibly.
  if (!inherits(x, "phnnet"))
    stop("Not a legitimate neural net fit")
  cat("a ", x$n[1], "-", x$n[2], "-", x$n[3], " network", sep = "")
  cat(" with", length(x$wts), "weights\n")
  if (length(x$coefnames)) {
    cat("inputs:", x$coefnames,
        "\noutput(s):", deparse(formula(x)[[2]]), "\n")
  }
  per_unit <- diff(x$nconn)
  opts <- c(if (per_unit[length(per_unit)] > x$n[2] + 1) " skip-layer connections ",
            if (x$nunits > x$nsunits) " linear output units ",
            if (x$entropy) " entropy fitting ")
  cat("options were -", opts, sep = "")
  # every decay value is listed, separated by spaces
  if (any(x$decay > 0)) cat(" decay=", x$decay, sep = " ")
  cat("\n")
  invisible(x)
}


summary.phnnet <- function(object, ...)
{
  # The summary of a "phnnet" fit is the fit itself, re-classed as
  # "summary.phnnet" so that print.summary.phnnet lists every weight.
  `class<-`(object, "summary.phnnet")
}

print.summary.phnnet <- function(x,...)
{
  # Print a "summary.phnnet": the architecture/options header followed by
  # every weight labelled "<from>-><to>" (b = bias, i = input, h = hidden,
  # o = output), one block per destination unit.
  # Returns 'x' invisibly, as print methods should (it used to return NULL).
  if(!inherits(x, "summary.phnnet"))
    stop("Not a legitimate ph neural net fit")
  cat("a ", x$n[1], "-", x$n[2], "-", x$n[3], " network", sep = "")
  cat(" with", length(x$wts), "weights\n")
  cat("options were -")
  tconn <- diff(x$nconn)
  if(tconn[length(tconn)] > x$n[2] + 1)
    cat(" skip-layer connections ")
  if(x$nunits > x$nsunits)
    cat(" linear output units ")
  if(x$entropy)
    cat(" entropy fitting ")
  # test every decay value, as print.phnnet does; decay[1] alone belongs to
  # a bias connection and may be 0 while the others are not
  if(any(x$decay > 0))
    cat(" decay=", x$decay, sep = " ")
  cat("\n")
  wts <- format(round(x$wts, 2))
  wm <- c("b", paste("i", seq(length=x$n[1]), sep=""))
  if(x$n[2] > 0) wm <- c(wm, paste("h", seq(length=x$n[2]), sep=""))
  if(x$n[3] > 1)
    wm <- c(wm, paste("o", seq(length=x$n[3]), sep=""))
  else wm <- c(wm, "o")
  # connection names: source label -> destination label
  names(wts) <- apply(cbind(wm[1+x$conn], wm[1+rep(1:x$nunits - 1, tconn)]),
		      1, function(pair)  paste(pair, collapse = "->"))
  lapply(split(wts,rep(1:x$nunits, tconn)),
	 function(w) print(w, quote= FALSE))
  invisible(x)
}
predict.phnnet.formula <- function(object, newdata,  ...)
{
  # Predict from a "phnnet" model fitted with the formula interface.
  # Without 'newdata' the stored fitted values are returned.  Otherwise the
  # predictors are rebuilt from object$terms, the intercept column is
  # dropped (the bias unit plays that role) and predict.phnnet is called.
  if (!inherits(object, "phnnet.formula"))
    stop("object not of class phnnet.formula")
  if (missing(newdata)) return(object$fitted.values)
  mm <- model.matrix(delete.response(object$terms), newdata)
  icol <- match("(Intercept)", colnames(mm), nomatch = 0)
  if (icol > 0) mm <- mm[, -icol, drop = FALSE]
  predict.phnnet(object = object, x = mm)
}


predict.phnnet <- function(object, x, ...)
{
  # Network outputs of a "phnnet" fit for new covariates.
  #
  # object  a "phnnet" fit.
  # x       covariate matrix, or a vector treated as one row.  If missing,
  #         the stored fitted values are returned.
  # Returns an nrow(x) x object$n[3] matrix.
  if(!inherits(object, "phnnet"))
    stop("object not of class phnnet")
  if(missing(x))
    z <- object$fitted.values
  else {
    if(is.null(dim(x)))
      dim(x) <- c(1, length(x))
    if(any(is.na(x)))
      stop("missing values in x")
    x <- as.matrix(x)
    ntr <- dim(x)[1]
    nout <- object$n[3]
    .C("set_phtnet",
       as.integer(object$n),
       as.integer(object$nconn),
       as.integer(object$conn),
       as.integer(object$nsunits),
       as.single(object$decay),
       PACKAGE="survnnet")
    # pair set_phtnet with unset_phtnet, as the fitting code does; it was
    # never released here before
    on.exit(.C("unset_phtnet", PACKAGE="survnnet"), add = TRUE)
    z <- .C("pred_phnnet",
            as.integer(ntr),
            as.single(x),
            as.double(object$wts),
            pred = single(ntr),
            PACKAGE="survnnet")
    z <- matrix(z$pred, ntr, nout)
    dimnames(z) <- list(dimnames(x)[[1]], dimnames(object$fitted.values)[[2]])
    class(z) <- object$errfn
  }
  z
}

eval.pnn <- function(wts)
{
  # Evaluate the fitting criterion and its gradient at 'wts' through the
  # C routine "phtfunc" (presumably using the network/training state set
  # up earlier on the C side -- confirm against the C sources).
  # Returns the criterion value carrying the gradient as attribute
  # "gradient".
  out <- .C("phtfunc",
            as.double(wts),
            df = double(length(wts)),
            fp = as.double(1),
            PACKAGE = "survnnet")
  structure(out$fp, gradient = out$df)
}

phnnet.Hess <- function(net, x, status)
{
  # Hessian of the fitting criterion of a "phnnet" fit at net$wts,
  # computed by the C routine "phtfunc2".
  #
  # net     a "phnnet" fit (uses n, nconn, conn, nsunits, decay, wts).
  # x       covariate matrix.
  # status  event status, or a two-column matrix (follow-up time, status);
  #         with two columns rows are put in decreasing time order, as in
  #         phnnet.default.
  # Returns a length(net$wts) x length(net$wts) matrix.
  x <- as.matrix(x)
  status <- as.matrix(status)
  if (ncol(status) == 2) {
    ind <- rev(order(status[,1]))
    x <- x[ind,, drop = FALSE]
    status <- status[ind,2]
  } else status <- as.vector(status)
  if(dim(x)[1] != length(status)) stop("dims of x and status must match")
  nw <- length(net$wts)
  ntr <- dim(x)[1]
  .C("set_phtnet",
     as.integer(net$n),
     as.integer(net$nconn),
     as.integer(net$conn),
     as.integer(net$nsunits),
     as.single(net$decay),
     PACKAGE="survnnet"
     )
  # undo the C setup on every exit path (also when phtfunc2 fails),
  # training data before network
  train_set <- FALSE
  on.exit({
    if (train_set) .C("unset_phtrain", PACKAGE="survnnet")
    .C("unset_phtnet", PACKAGE="survnnet")
  }, add = TRUE)
  .C("set_phtrain",
     as.integer(ntr),
     as.single(x),
     as.integer(status),
     PACKAGE="survnnet")
  train_set <- TRUE
  matrix(.C("phtfunc2", as.double(net$wts), H = double(nw*nw),
            PACKAGE="survnnet")$H, nw, nw)
}
