/* nag_rand_subsamp_xyw (g05pwc) Example Program.
 *
 * Copyright 2017 Numerical Algorithms Group.
 *
 * Mark 26.1, 2017.
 */
/* Pre-processor includes */
#include <stdio.h>
#include <nag.h>
#include <nag_stdlib.h>
#include <nagg02.h>
#include <nagg05.h>

int main(void)
{
  /* Integer scalar and array declarations */
  Integer fn, fp, i, ip, pdx, lstate, m,
         n, nn, np, nt, nv, obs_val, pred_val, subid,
         tn, tp, j, pdv, rank, max_iter, print_iter, nsamp, samp;
  Integer exit_status = 0, lseed = 1;
  Integer *isx = 0, *state = 0;
  Integer seed[1];

  /* NAG structures and types */
  NagError fail;
  Nag_Link link;
  Nag_IncludeMean mean;
  Nag_BaseRNG genid;
  Nag_Distributions errfn;
  Nag_Boolean vfobs;
  Nag_DataByObsOrVar sordx;

  /* Double scalar and array declarations */
  double ex_power, dev, eps, tol, df, scale;
  double *b = 0, *cov = 0, *eta = 0, *pred = 0, *se = 0, *seeta = 0,
         *sepred = 0, *v = 0, *offset = 0, *wt = 0, *x = 0, *y = 0, *t = 0;

  /* Character scalar and array declarations */
  char clink[40], cmean[40], cgenid[40];

  /* Initialize the error structure */
  INIT_FAIL(fail);

  printf("nag_rand_subsamp_xyw (g05pwc) Example Program Results\n\n");

  /* Skip heading in data file */
  scanf("%*[^\n] ");

  /* Set variables required by the regression (g02gbc) ... */

  /* Read in the type of link function, whether a mean is required */
  /* and the problem size */
  scanf("%39s%39s%" NAG_IFMT "%" NAG_IFMT "%*[^\n] ", clink, cmean, &n, &m);
  link = (Nag_Link) nag_enum_name_to_value(clink);
  mean = (Nag_IncludeMean) nag_enum_name_to_value(cmean);

  /* Set storage order for g05pwc */
  /* (pick the one required by g02gbc and g02gpc) */
  sordx = Nag_DataByObs;

  pdx = m;
  if (!(x = NAG_ALLOC(pdx * n, double)) ||
      !(y = NAG_ALLOC(n, double)) ||
      !(t = NAG_ALLOC(n, double)) || !(isx = NAG_ALLOC(m, Integer)))
  {
    printf("Allocation failure\n");
    exit_status = -1;
    goto END;
  }

  /* This example is not using an offset or weights */
  offset = 0;
  wt = 0;

  /* Read in data */
  for (i = 0; i < n; i++) {
    for (j = 0; j < m; j++) {
      scanf("%lf", &x[i * pdx + j]);
    }
    scanf("%lf%lf%*[^\n] ", &y[i], &t[i]);
  }

  /* Read in variable inclusion flags */
  for (j = 0; j < m; j++) {
    scanf("%" NAG_IFMT "", &isx[j]);
  }
  scanf("%*[^\n] ");

  /* Read in control parameters for the regression */
  scanf("%" NAG_IFMT "%lf%lf%" NAG_IFMT "%*[^\n] ", &print_iter, &eps,
        &tol, &max_iter);

  /* Calculate IP */
  for (ip = 0, i = 0; i < m; i++)
    ip += (isx[i] > 0);
  if (mean == Nag_MeanInclude)
    ip++;
  /* ... End of setting variables required by the regression */

  /* Set variables required by data sampling routine (g05pwc) ... */

  /* Read in the base generator information and seed */
  scanf("%39s%" NAG_IFMT "%" NAG_IFMT "%*[^\n] ", cgenid, &subid, &seed[0]);
  genid = (Nag_BaseRNG) nag_enum_name_to_value(cgenid);

  /* Initial call to g05kfc to get size of STATE array */
  lstate = 0;
  nag_rand_init_repeatable(genid, subid, seed, lseed, state, &lstate,
                           NAGERR_DEFAULT);

  /* Allocate state array */
  if (!(state = NAG_ALLOC(lstate, Integer)))
  {
    printf("Allocation failure\n");
    exit_status = -1;
    goto END;
  }

  /* Initialize the generator to a repeatable sequence using g05kfc */
  nag_rand_init_repeatable(genid, subid, seed, lseed, state, &lstate,
                           NAGERR_DEFAULT);

  /* Read in the size of the training set required */
  scanf("%" NAG_IFMT "%*[^\n] ", &nt);

  /* Read in the number of sub-samples we will use */
  scanf("%" NAG_IFMT "%*[^\n] ", &nsamp);
  /* ... End of setting variables required by data sampling routine */

  /* Set variables required by prediction routine (g02gpc) ... */

  /* Regression is performed using g02gbc so error structure is binomial */
  errfn = Nag_Binomial;

  /* This example does not use the predicted standard errors, so */
  /* it doesn't matter what VFOBS is set to */
  vfobs = Nag_FALSE;
  /* The error and link being used in the linear model don't use scale */
  /* and ex_power so they can be set to anything */
  ex_power = 0.0;
  scale = 0.0;
  /* ... End of setting variables required by prediction routine */

  /* Calculate the size of the validation dataset */
  nv = n - nt;

  /* Allocate arrays */
  pdv = n;
  if (!(b = NAG_ALLOC(ip, double)) ||
      !(se = NAG_ALLOC(ip, double)) ||
      !(cov = NAG_ALLOC(ip * (ip + 1) / 2, double)) ||
      !(v = NAG_ALLOC(n * pdv, double)) ||
      !(eta = NAG_ALLOC(nv, double)) ||
      !(seeta = NAG_ALLOC(nv, double)) ||
      !(pred = NAG_ALLOC(nv, double)) || !(sepred = NAG_ALLOC(nv, double)))

  {
    printf("Allocation failure\n");
    exit_status = -1;
    goto END;
  }

  /* Initialize counts */
  tp = tn = fp = fn = 0;

  /* Loop over each sample */
  for (samp = 1; samp <= nsamp; samp++)
  {
    /* Use g05pwc to split the data into training and validation datasets */
    nag_rand_subsamp_xyw(nt, n, m, sordx, x, pdx, y, t, state, &fail);
    if (fail.code != NE_NOERROR) {
      printf("Error from nag_rand_subsamp_xyw (g05pwc).\n%s\n", fail.message);
      exit_status = 1;
      goto END;
    }

    /* Call g02gbc to fit generalized linear model, with Binomial */
    /* errors to training data */
    nag_glm_binomial(link, mean, nt, x, pdx, m, isx, ip, y, t, wt,
                     offset, &dev, &df, b, &rank, se, cov, v, pdv,
                     tol, max_iter, print_iter, "", eps, &fail);
    if (fail.code != NE_NOERROR) {
      printf("Error from nag_glm_binomial (g02gbc).\n%s\n", fail.message);
      exit_status = 1;
      goto END;

    }

    /* Call g02gpc to predict the response for the observations in the */
    /* validation dataset */
    /* We want to start passing X and T at the (NT+1)th observation, */
    /* These start at (i,j)=(nt+1,1), hence the (nt*pdx+0)th element */
    /* of X and the nt'th element of T */
    nag_glm_predict(errfn, link, mean, nv, &x[nt * pdx], pdx, m, isx, ip,
                    &t[nt], offset, wt, scale, ex_power, b, cov, vfobs, eta,
                    seeta, pred, sepred, &fail);
    if (fail.code != NE_NOERROR) {
      printf("Error from nag_glm_predict (g02gpc).\n%s\n", fail.message);
      exit_status = 1;
      goto END;
    }

    /* Count the true/false positives/negatives */
    for (i = 0; i < nv; i++) {
      obs_val = (Integer) y[nt + i];
      pred_val = (pred[i] >= 0.5 ? 1 : 0);
      if (obs_val) {
        /* Positive */
        if (pred_val) {
          /* True positive */
          tp++;
        }
        else {
          /* False Negative */
          fn++;
        }
      }
      else {
        /* Negative */
        if (pred_val) {
          /* False positive */
          fp++;
        }
        else {
          /* True negative */
          tn++;
        }
      }
    }
  }

  /* Display results */
  np = tp + fn;
  nn = fp + tn;
  printf("                       Observed\n");
  printf("             --------------------------\n");
  printf(" Predicted | Negative  Positive   Total\n");
  printf(" --------------------------------------\n");
  printf(" Negative  | %5" NAG_IFMT "     %5" NAG_IFMT "     %5" NAG_IFMT
         "\n", tn, fn, tn + fn);
  printf(" Positive  | %5" NAG_IFMT "     %5" NAG_IFMT "     %5" NAG_IFMT
         "\n", fp, tp, fp + tp);
  printf(" Total     | %5" NAG_IFMT "     %5" NAG_IFMT "     %5" NAG_IFMT
         "\n", nn, np, nn + np);
  printf("\n");

  if (np != 0) {
    printf(" True Positive Rate (Sensitivity): %4.2f\n",
           (double) tp / (double) np);
  }
  else {
    printf(" True Positive Rate (Sensitivity): No positives in data\n");
  }
  if (nn != 0) {
    printf(" True Negative Rate (Specificity): %4.2f\n",
           (double) tn / (double) nn);
  }
  else {
    printf(" True Negative Rate (Specificity): No negatives in data\n");
  }

END:

  NAG_FREE(isx);
  NAG_FREE(state);
  NAG_FREE(b);
  NAG_FREE(cov);
  NAG_FREE(eta);
  NAG_FREE(pred);
  NAG_FREE(se);
  NAG_FREE(seeta);
  NAG_FREE(sepred);
  NAG_FREE(t);
  NAG_FREE(x);
  NAG_FREE(y);
  NAG_FREE(v);
  NAG_FREE(offset);
  NAG_FREE(wt);

  return (exit_status);
}