
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "split_stdin.h"

// Baseline digit-classifier client driven over stdin/stdout.
//
// Each input line is split on spaces by split_stdin():
//   * 64 fields: pixel values of one image. They are stored in image[1..64]
//     (image[0] is held at 1) and a prediction is printed on stdout. This
//     skeleton always predicts 0.
//   * 10 fields: feedback for the previous prediction. Field 0 must be
//     "correct" or "incorrect" (anything else is ignored) and field 3 is
//     parsed as the true label.
//   * any other count: reported on stderr and skipped.
// The loop ends when split_stdin() reports -1 fields (presumably end of
// input -- confirm against split_stdin.h). A running error rate is printed
// every 1000 scored examples, and the label-by-prediction confusion matrix
// is printed on exit.
//
// Returns 0.
int main(int argc, char *argv[]) {
  enum {
    NUM_CLASSES = 10,      // digits 0..9
    NUM_PIXELS = 64,       // fields in an image line
    FEEDBACK_FIELDS = 10,  // fields in a feedback line
    LABEL_FIELD = 3,       // index of the true label in a feedback line
    REPORT_INTERVAL = 1000 // examples between error-rate reports
  };

  int num_fields;
  int i, j;
  int prediction = 0;
  int correct = 0;
  int incorrect = 0;
  int image[NUM_PIXELS + 1];
  // The confusion matrix counts label-prediction pairs.
  int confusion[NUM_CLASSES][NUM_CLASSES] = {{0}};
  char **response;

  (void)argc;
  (void)argv;

  image[0] = 1;

  for (;;) {
    response = split_stdin(" ", &num_fields);
    if (num_fields == -1)
      break;

    if (num_fields == NUM_PIXELS) {
      // Input the image and make a prediction.
      for (j = 1; j <= NUM_PIXELS; j++)
        image[j] = atoi(response[j - 1]);
      free_split_stdin(response, num_fields);

      prediction = 0;

      printf("%d\n", prediction);
      fflush(stdout);

    } else if (num_fields == FEEDBACK_FIELDS) {
      // Input the correct label and record result.
      // Should be updating weights.
      long label;

      if (0 == strcmp("correct", response[0])) {
        correct++;
      } else if (0 == strcmp("incorrect", response[0])) {
        incorrect++;
      } else {
        free_split_stdin(response, num_fields);
        continue;
      }

      // strtol() yields the same value as atoi() for every representable
      // input but saturates instead of invoking undefined behaviour on
      // overflow, so the range check below is reliable.
      label = strtol(response[LABEL_FIELD], NULL, 10);
      if (label >= 0 && label < NUM_CLASSES)
        confusion[label][prediction]++;
      else
        // Never index the matrix with an out-of-range label.
        fprintf(stderr, "ignoring out-of-range label \"%s\"\n",
                response[LABEL_FIELD]);

      // Output error rate every REPORT_INTERVAL examples.
      if ((correct + incorrect) % REPORT_INTERVAL == 0) {
        int total = correct + incorrect;
        printf("user error rate = %d/%d = %g%%\n",
               incorrect, total, 100.0 * incorrect / total);
        fflush(stdout);
      }

      free_split_stdin(response, num_fields);

    } else {
      fprintf(stderr, "response has %d fields\n", num_fields);
      // The other branches show the caller owns the split result, so
      // release it here too instead of leaking one buffer per odd line.
      if (num_fields > 0)
        free_split_stdin(response, num_fields);
    }
  }

  // Print out confusion matrix

  printf("user confusion matrix, labels are rows, predictions are columns\n");
  printf("user         0     1     2     3     4     5     6     7     8     9\n");
  printf("user   +-------------------------------------------------------------+\n");
  for (i = 0; i < NUM_CLASSES; i++) {
    printf("user %d |", i);
    for (j = 0; j < NUM_CLASSES; j++)
      printf(" %5d", confusion[i][j]);
    printf(" |\n");
  }
  printf("user   +-------------------------------------------------------------+\n");
  fflush(stdout);

  return 0;
}

