import java.util.*;

/**
 * Command-line driver that reads one of three datasets (toy, housing,
 * spambase) and fits either a linear or a logistic regression model on it.
 *
 * <p>The dataset and the regressor are chosen by two selectors that are
 * currently hard-coded in {@link #main(String[])}:
 * <ul>
 *   <li>{@code dbSelector}: {@code -1} = toy data, {@code 0} = housing data,
 *       any other value = spambase data with 10-fold cross validation.</li>
 *   <li>{@code regSelector}: {@code 0} = linear regression, any other value =
 *       logistic regression.</li>
 * </ul>
 *
 * <p>NOTE(review): all dataset paths below are absolute paths on one
 * developer machine; they must be adjusted before running elsewhere.
 */
public class EntryPoint{

    //required objects
    static DataReader reader = null;

    //Dataset constants for toy data
    static final String TOY_DATASET_PATH =
            "/home/rahul/data/src/ml/dataset/toy/ex3.csv";
    static final int TOY_ROW_COUNT = 47;
    static final int TOY_COL_COUNT = 3;
    static final String TOY_DATA_TYPE = "CSV";
    static final int[] TOY_NUMERIC_COL_IDXS =
            {0,1};//started with 0 for convenience
    static final int[] TOY_NOMINAL_COL_IDXS = {};//0 based
    static final int TOY_LABEL_COL_IDX = 2;// 0 based
    //Dataset constants for housing data
    static final String HOUSING_DATASET_PATH =
            "/home/rahul/data/src/ml/dataset/hw1/housing/housing_train.txt";
    static final String HOUSING_TESTSET_PATH =
            "/home/rahul/data/src/ml/dataset/hw1/housing/housing_test.txt";
    static final int HOUSING_ROW_COUNT = 433;
    static final int HOUSING_TEST_ROW_COUNT = 74;
    static final int HOUSING_COL_COUNT = 14;
    static final String HOUSING_DATA_TYPE = "SSV";
    static final int[] HOUSING_NUMERIC_COL_IDXS =
            {0,1,2,4,5,6,7,8,9,10,11,12};//started with 0 for convenience
    static final int[] HOUSING_NOMINAL_COL_IDXS = {3};//0 based
    static final int HOUSING_LABEL_COL_IDX = 13;// 0 based
    //Dataset constants for spambase data---------------------------------------
    static final String SPAMBASE_DATASET_PATH =
            "/home/rahul/data/src/ml/dataset/hw1/spam/spambase.data";
    static final int SPAMBASE_ROW_COUNT = 4601;
    static final int SPAMBASE_COL_COUNT = 58;
    static final String SPAMBASE_DATA_TYPE = "CSV";
    static final int[] SPAMBASE_NUMERIC_COL_IDXS =
            {0,1,2,3,4,5,6,7,8,9,10,
                    11,12,13,14,15,16,17,18,19,20,
                    21,22,23,24,25,26,27,28,29,30,
                    31,32,33,34,35,36,37,38,39,40,
                    41,42,43,44,45,46,47,48,49,50,
                    51,52,53,54,55,56};//started with 0 for convenience
    static final int[] SPAMBASE_NOMINAL_COL_IDXS = {};//0 based
    static final int SPAMBASE_LABEL_COL_IDX = 57;// 0 based

    //Number of folds used when cross validating on the spambase data
    private static final int SPAMBASE_FOLD_COUNT = 10;

    /*--------------------------------------------------------------------------
      Functions */

    /**
     * Runs the experiment chosen by the hard-coded selectors.
     *
     * @param args currently ignored; the selectors are fixed to
     *             {@code dbSelector = 1} and {@code regSelector = 1}
     *             (spambase data, logistic regression)
     * @throws Exception any failure raised while reading data or building a
     *                   model; it is reported on standard output and rethrown
     */
    public static void main(String[] args) throws Exception {
        try {
            // The selectors were once read from args[0] / args[1]; they are
            // fixed here, so command-line arguments have no effect.
            int dbSelector = 1;
            int regSelector = 1;
            reader = new DataReader();
            System.out.println("");
            DataProcessor processor = new DataProcessor();
            if (dbSelector == -1) {
                runToyExperiment(regSelector);
                // the toy run ends here, without the trailing blank lines
                return;
            }
            if (dbSelector == 0) {
                runHousingExperiment(regSelector);
            } else {
                runSpamCrossValidation(processor, regSelector);
            }
            System.out.println("");
            System.out.println("");
        } catch(Exception e){
            System.out.println("Runtime Error occurred: ");
            System.out.println(e);
            throw e;
        }
    }

    /**
     * Reads the toy dataset and fits the selected model on all of it.
     *
     * @param regSelector {@code 0} for linear regression, anything else for
     *                    logistic regression
     */
    private static void runToyExperiment(int regSelector) throws Exception {
        double[][] data = reader.readFile(TOY_DATASET_PATH,
                TOY_DATA_TYPE,
                TOY_ROW_COUNT,
                TOY_COL_COUNT);
        trainOnWholeDataset(getToyDataset(data), regSelector);
    }

    /**
     * Reads the housing training set, fits the selected model on it and then
     * reads the housing test set.
     *
     * @param regSelector {@code 0} for linear regression, anything else for
     *                    logistic regression
     */
    private static void runHousingExperiment(int regSelector)
            throws Exception {
        // The training file is read exactly once; a second read whose result
        // was discarded has been removed.
        double[][] data = reader.readFile(HOUSING_DATASET_PATH,
                HOUSING_DATA_TYPE,
                HOUSING_ROW_COUNT,
                HOUSING_COL_COUNT);
        trainOnWholeDataset(getHousingDataset(data), regSelector);
        Log.write("Reading test dataset...");
        // NOTE(review): the test set is read but not evaluated yet; the call
        // is kept so the log line above stays truthful.
        reader.readFile(HOUSING_TESTSET_PATH,
                HOUSING_DATA_TYPE,
                HOUSING_TEST_ROW_COUNT,
                HOUSING_COL_COUNT);
        Log.write("");
        Log.write("");
    }

    /**
     * Fits the selected model on the complete dataset; the returned model is
     * not used further.
     *
     * @param dataset     dataset to train on
     * @param regSelector {@code 0} for linear regression, anything else for
     *                    logistic regression
     */
    private static void trainOnWholeDataset(Dataset dataset, int regSelector)
            throws Exception {
        if (regSelector == 0) {
            Regressor linearPredictor = new Regressor(dataset);
            linearPredictor.buildLinearRegressionModel();
        } else {
            LogisticRegressor predictor =
                    new LogisticRegressor(dataset, true);
            predictor.buildLogisticRegressionModel();
        }
    }

    /**
     * Reads and normalizes the spambase data, then trains and evaluates the
     * selected model once per cross-validation fold.
     *
     * @param processor   splitter used to cut each fold into train/test parts
     * @param regSelector {@code 0} for linear regression, anything else for
     *                    logistic regression
     */
    private static void runSpamCrossValidation(DataProcessor processor,
                                               int regSelector)
            throws Exception {
        double[][] dataFromFile = reader.readFile(SPAMBASE_DATASET_PATH,
                SPAMBASE_DATA_TYPE,
                SPAMBASE_ROW_COUNT,
                SPAMBASE_COL_COUNT);
        //normalize whole dataset before making any folds
        Regressor normalizer = new Regressor(getSpamDataset(dataFromFile));
        double[][] data = normalizer.normalizeDataset();

        int foldRecordCount = data.length / SPAMBASE_FOLD_COUNT;
        for (int k = 0; k < SPAMBASE_FOLD_COUNT; k++) {
            int testStartIDx = foldRecordCount * k;
            // The last fold absorbs the remainder rows left by the integer
            // division above, so it runs to the end of the data.
            int testEndIDx = (k == SPAMBASE_FOLD_COUNT - 1)
                    ? data.length
                    : testStartIDx + foldRecordCount;
            Log.write("");
            Log.write("Test set Start Index: "+ testStartIDx);
            // The end passed to splitTrainTest appears to be exclusive
            // (TODO confirm), so the last index is logged as end - 1 for
            // every fold, including the final one.
            Log.write("Test set End Index: "+ (testEndIDx - 1));

            List<double[][]> folds = processor.splitTrainTest(data,
                    testStartIDx, testEndIDx);
            double[][] trainset = folds.get(0);
            double[][] testset = folds.get(1);

            if (regSelector == 0) {
                evaluateLinearFold(trainset, testset);
            } else {
                evaluateLogisticFold(trainset, testset);
            }
        } // cross validation loop ends here
        Log.write("");
        Log.write("");
    }

    /**
     * Builds a linear regression model on one fold's training rows and logs
     * MSE and ROC statistics for both the training and the test rows.
     *
     * @param trainset training rows of the fold
     * @param testset  held-out rows of the fold
     */
    private static void evaluateLinearFold(double[][] trainset,
                                           double[][] testset)
            throws Exception {
        //build model
        Regressor linearPredictor = new Regressor(getSpamSubset(trainset));
        double[] linearModel = linearPredictor.buildLinearRegressionModel();

        //get predictions for training set; a fresh Dataset wrapper is used
        //rather than the one handed to the regressor
        double[][] trainingPredictions =
                Regressor.predict(getSpamSubset(trainset), linearModel);
        Log.write("MSE for training set is:" +
                Regressor.computeMSE(trainingPredictions));
        Log.write("ROC Stats for training set are: ");
        Log.write("");
        computeRocStats(trainingPredictions);

        //get predictions for test set
        double[][] predictions =
                Regressor.predict(getSpamSubset(testset), linearModel);
        Log.write("ROC Stats for test set are: ");
        Log.write("");
        Log.write("MSE for test set is:" +
                Regressor.computeMSE(predictions));
        computeRocStats(predictions);
    }

    /**
     * Builds a logistic regression model on one fold's training rows and logs
     * MSE and ROC statistics for both the training and the test rows.
     *
     * @param trainset training rows of the fold
     * @param testset  held-out rows of the fold
     */
    private static void evaluateLogisticFold(double[][] trainset,
                                             double[][] testset)
            throws Exception {
        //build model
        LogisticRegressor predictor =
                new LogisticRegressor(getSpamSubset(trainset));
        double[] model = predictor.buildLogisticRegressionModel();

        //get predictions for training set; a fresh Dataset wrapper is used
        //rather than the one handed to the regressor
        double[][] trainingPredictions =
                LogisticRegressor.predict(getSpamSubset(trainset), model);
        Log.write("MSE for training set is:" +
                LogisticRegressor.computeMSE(trainingPredictions));
        Log.write("ROC Stats for training set are: ");
        Log.write("");
        computeRocStats(trainingPredictions);

        //get predictions for test set
        double[][] predictions =
                LogisticRegressor.predict(getSpamSubset(testset), model);
        Log.write("ROC Stats for test set are: ");
        Log.write("");
        Log.write("MSE for test set is:" +
                LogisticRegressor.computeMSE(predictions));
        computeRocStats(predictions);
    }

    /**
     * Runs the ROC statistics computation over all prediction rows with the
     * flag combination used throughout this driver.
     *
     * @param predictions prediction rows to summarize
     */
    private static void computeRocStats(double[][] predictions)
            throws Exception {
        PredictorStatsCalculator.computeROCStats(predictions,
                null,
                new int[]{predictions.length},
                false, true, false);
    }

    /**
     * Populates a new {@code Dataset} descriptor with the given metadata and
     * rows. The index arrays and {@code data} are stored by reference, not
     * copied.
     */
    private static Dataset buildDataset(String filePath,
                                        int rowCount,
                                        int colCount,
                                        int[] numericColIDXs,
                                        int[] nominalColIDXs,
                                        int labelColIDx,
                                        double[][] data){
        Dataset dataset = new Dataset();
        dataset.filePath = filePath;
        dataset.rowCount = rowCount;
        dataset.colCount = colCount;
        dataset.numericColIDXs = numericColIDXs;
        dataset.nominalColIDXs = nominalColIDXs;
        dataset.labelColIDx = labelColIDx;
        dataset.data = data;
        return dataset;
    }

    /** Wraps {@code data} with the housing dataset constants. */
    private static Dataset getHousingDataset(double[][] data){
        return buildDataset(HOUSING_DATASET_PATH,
                HOUSING_ROW_COUNT,
                HOUSING_COL_COUNT,
                HOUSING_NUMERIC_COL_IDXS,
                HOUSING_NOMINAL_COL_IDXS,
                HOUSING_LABEL_COL_IDX,
                data);
    }

    /**
     * Wraps a slice of the spambase rows; unlike
     * {@link #getSpamDataset(double[][])} the row count is taken from
     * {@code data.length}.
     */
    private static Dataset getSpamSubset(double[][] data){
        return buildDataset(SPAMBASE_DATASET_PATH,
                data.length,
                SPAMBASE_COL_COUNT,
                SPAMBASE_NUMERIC_COL_IDXS,
                SPAMBASE_NOMINAL_COL_IDXS,
                SPAMBASE_LABEL_COL_IDX,
                data);
    }

    /**
     * Wraps {@code data} with the spambase dataset constants, using the fixed
     * {@code SPAMBASE_ROW_COUNT} as the row count.
     */
    private static Dataset getSpamDataset(double[][] data){
        return buildDataset(SPAMBASE_DATASET_PATH,
                SPAMBASE_ROW_COUNT,
                SPAMBASE_COL_COUNT,
                SPAMBASE_NUMERIC_COL_IDXS,
                SPAMBASE_NOMINAL_COL_IDXS,
                SPAMBASE_LABEL_COL_IDX,
                data);
    }

    /** Wraps {@code data} with the toy dataset constants. */
    private static Dataset getToyDataset(double[][] data){
        return buildDataset(TOY_DATASET_PATH,
                TOY_ROW_COUNT,
                TOY_COL_COUNT,
                TOY_NUMERIC_COL_IDXS,
                TOY_NOMINAL_COL_IDXS,
                TOY_LABEL_COL_IDX,
                data);
    }
}
