#include "rFunctionLinearExtension.h"
#include "rwrapper.h"

//************************************
//************************************
//************************************

// Builds an R-function-backed functional on linear extensions.
//
// The R function is evaluated once on the first linear extension of `poset`
// (passed as a character vector of element names, in extension order) to
// discover the shape of the matrix it returns. The shape goes into `__shape`
// (rows, then columns) and the dimnames into `__rows_name` / `__cols_name`.
// When the result has no dimnames, or only one of the two components is set,
// the missing names default to "1", "2", ...
//
// @param poset       poset whose linear extensions are fed to the R function.
// @param r_function  R function taking a character vector and returning a matrix.
// @throws (via throw_line) if evaluating the R function raises an R error or
//         its result is not a matrix.
FLERinterface::FLERinterface(std::shared_ptr<POSet> poset, SEXP r_function) : FunctionLinearExtension(poset) {
    // NOTE(review): r_function is stored without R_PreserveObject, so it is
    // kept alive only as long as something on the R side still references it.
    __r_function = r_function;

    LinearExtension le(__poset->size());
    __poset->FirstLE(le);

    int conta_proteggi = 0;
    // Releases everything protected so far. It is called before every throw,
    // so an exception never leaves the R protection stack unbalanced.
    auto rilascia = [&conta_proteggi]() {
        if (conta_proteggi > 0) UNPROTECT(conta_proteggi);
        conta_proteggi = 0;
    };

    // Element names of the first linear extension, in extension order.
    auto le_r = Proteggi(Rf_allocVector(STRSXP, (long) le.size()), conta_proteggi);
    for (std::uint_fast64_t k = 0; k < le.size(); ++k) {
        auto v = le.getVal(k);
        std::string s = poset->GetEName(v);
        SET_STRING_ELT(le_r, (long) k, Rf_mkChar(s.c_str()));
    }
    auto R_fcall = Proteggi(Rf_lang2(__r_function, le_r), conta_proteggi);

    // R_tryEval instead of Rf_eval: on an R error, Rf_eval would longjmp
    // across this C++ frame and skip the destructors of `le` and `poset`.
    int errore = 0;
    SEXP ans = R_tryEval(R_fcall, R_GlobalEnv, &errore);
    if (errore) {
        rilascia();
        std::string err_str = "Error evaluating the R function";
        throw_line(err_str);
    }
    ans = Proteggi(ans, conta_proteggi);

    if (!Rf_isMatrix(ans)) {
        rilascia();
        std::string err_str = "Not a matrix";
        throw_line(err_str);
    }

    // Rf_isMatrix guarantees an integer dim attribute of length 2.
    auto ans_dim = Proteggi(Rf_getAttrib(ans, R_DimSymbol), conta_proteggi);
    std::uint_fast64_t ans_nrow = (std::uint_fast64_t) INTEGER(ans_dim)[0];
    std::uint_fast64_t ans_ncol = (std::uint_fast64_t) INTEGER(ans_dim)[1];

    __shape.push_back(ans_nrow);
    __shape.push_back(ans_ncol);

    // Each dimnames component may independently be NULL, for example when
    // the function sets only the column names.
    auto ans_dim_names = Proteggi(Rf_getAttrib(ans, R_DimNamesSymbol), conta_proteggi);
    SEXP nomi_righe = (ans_dim_names == R_NilValue) ? R_NilValue : VECTOR_ELT(ans_dim_names, 0);
    SEXP nomi_colonne = (ans_dim_names == R_NilValue) ? R_NilValue : VECTOR_ELT(ans_dim_names, 1);

    __rows_name.resize(ans_nrow);
    for (std::uint_fast64_t k = 0; k < ans_nrow; ++k) {
        if (nomi_righe == R_NilValue) {
            __rows_name.at(k) = std::to_string(k + 1);
        } else {
            __rows_name.at(k) = CHAR(STRING_ELT(nomi_righe, (long) k));
        }
    }
    __cols_name.resize(ans_ncol);
    for (std::uint_fast64_t k = 0; k < ans_ncol; ++k) {
        if (nomi_colonne == R_NilValue) {
            __cols_name.at(k) = std::to_string(k + 1);
        } else {
            __cols_name.at(k) = CHAR(STRING_ELT(nomi_colonne, (long) k));
        }
    }

    __data.clear();
    rilascia();
}
    
// ***********************************************
// ***********************************************
// ***********************************************

// Evaluates the R function on the linear extension `x` and stores the result.
//
// The elements of `x` are passed to the R function as a character vector of
// element names, in extension order. The returned matrix must be logical,
// double or integer. `__data` is cleared, then every cell is appended as a
// (row, column, value) tuple, iterating rows in the outer loop and columns in
// the inner loop. Logical/integer NA cells are stored as NA_REAL rather than
// as their raw integer encoding, matching how NA is stored for double
// matrices.
//
// @param x  linear extension to evaluate.
// @throws (via throw_line) if evaluating the R function raises an R error,
//         the result is not a matrix, or the matrix is not numeric/logical.
void FLERinterface::operator()(std::shared_ptr<LinearExtension> x) {
    ++__calls;

    int conta_proteggi = 0;
    // Releases everything protected so far. It is called before every throw,
    // so an exception never leaves the R protection stack unbalanced.
    auto rilascia = [&conta_proteggi]() {
        if (conta_proteggi > 0) UNPROTECT(conta_proteggi);
        conta_proteggi = 0;
    };

    auto le_r = Proteggi(Rf_allocVector(STRSXP, (long) x->size()), conta_proteggi);
    for (std::uint_fast64_t k = 0; k < x->size(); ++k) {
        auto v = x->getVal(k);
        std::string s = __poset->GetEName(v);
        SET_STRING_ELT(le_r, (long) k, Rf_mkChar(s.c_str()));
    }
    // The call object must be protected: the evaluation below allocates and
    // may trigger garbage collection.
    auto R_fcall = Proteggi(Rf_lang2(__r_function, le_r), conta_proteggi);

    // R_tryEval instead of Rf_eval: on an R error, Rf_eval would longjmp
    // across this C++ frame and skip the destructor of `x`.
    int errore = 0;
    SEXP ans = R_tryEval(R_fcall, R_GlobalEnv, &errore);
    if (errore) {
        rilascia();
        std::string err_str = "Error evaluating the R function";
        throw_line(err_str);
    }
    ans = Proteggi(ans, conta_proteggi);

    if (!Rf_isMatrix(ans)) {
        rilascia();
        std::string err_str = "Not a matrix";
        throw_line(err_str);
    }

    // Rf_isMatrix guarantees an integer dim attribute of length 2.
    // NOTE(review): the dimensions are not compared with __shape recorded in
    // the constructor; confirm that callers tolerate a differently-sized result.
    auto ans_dim = Rf_getAttrib(ans, R_DimSymbol);
    std::uint_fast64_t ans_nrow = (std::uint_fast64_t) INTEGER(ans_dim)[0];
    std::uint_fast64_t ans_ncol = (std::uint_fast64_t) INTEGER(ans_dim)[1];

    // Resolve the element type once, instead of once per cell.
    const bool is_real = Rf_isReal(ans);
    const int* valori_int = nullptr;      // logical or integer payload
    const double* valori_real = nullptr;  // double payload
    if (Rf_isLogical(ans)) {
        valori_int = LOGICAL(ans);
    } else if (is_real) {
        valori_real = REAL(ans);
    } else if (Rf_isInteger(ans)) {
        valori_int = INTEGER(ans);
    } else {
        rilascia();
        std::string err_str = "Not a numerical matrix";
        throw_line(err_str);
    }

    __data.clear();
    for (std::uint_fast64_t riga = 0; riga < ans_nrow; ++riga) {
        for (std::uint_fast64_t col = 0; col < ans_ncol; ++col) {
            // R matrices are stored column-major.
            const std::uint_fast64_t idx = riga + ans_nrow * col;
            double v = 0.0;
            if (is_real) {
                v = valori_real[idx];
            } else {
                // NA_LOGICAL and NA_INTEGER share the same encoding.
                const int iv = valori_int[idx];
                v = (iv == NA_INTEGER) ? NA_REAL : static_cast<double>(iv);
            }
            __data.push_back(std::make_tuple(riga, col, v));
        }
    }

    rilascia();
}

