//
// Created by Margherita Milani on 22/09/22.
//
#include "lib.h"

/*
 * Runs one forward pass of the model on the data prepared by
 * initialize_all_matrix(): patch embedding, CLS token + positional embedding,
 * a stack of transformer blocks, then the classification head with softmax.
 * Prints the resulting matrix and the elapsed wall-clock time.
 */
int main(void) {
    // Number of transformer blocks applied in sequence.
    enum { NUM_TRANSFORMER_BLOCKS = 4 };

    struct allMatrix allMatrix;
    initialize_all_matrix(&allMatrix);

    time_t begin = time(NULL);

    //// INPUT & PATCH EMBEDDING LAYER
    // PATCH EMBEDDING: convolve the input with the embedding weights.
    convolution(&allMatrix.input_patch_embedding, &allMatrix.embeddingLayer.W, &allMatrix.Z_patch_embedding);

    // CLS TOKEN: concatenate the class token with the patch embeddings into Z_CLS.
    cat_matrix_on2dim(&allMatrix.token.cls_token, &allMatrix.Z_patch_embedding,
                      &allMatrix.Z_CLS); //embedDim * numPatches+1 * numSamples

    // POS EMBEDDING: sum_matrix accumulates the second operand into the first.
    sum_matrix(&allMatrix.Z_CLS, &allMatrix.pos_embed_layer.pos_embed); //embedDim * numPatches+1 * numSamples

    // Keep the un-normalised block input for the first residual addition.
    duplicate_matrix(&allMatrix.Z_CLS, &allMatrix.copy);

    //// TRANSFORMER BLOCKS
    for (int i = 0; i < NUM_TRANSFORMER_BLOCKS; i++) {
        const int is_last_block = (i == NUM_TRANSFORMER_BLOCKS - 1);
        float avg;
        float dev;

        // NORMALIZATION LAYER (applied in place to the block input)
        calculate_average(&allMatrix.Z_CLS, &avg);
        calculate_deviation(&allMatrix.Z_CLS, &dev, &avg);
        calculate_normalization(&allMatrix.Z_CLS, &avg, &dev);

        // ATTENTION LAYER -> Z_att
        attention_layer(&allMatrix.Z_CLS, &allMatrix.att.W_qkv, &allMatrix.att.b_qkv, &allMatrix.Z_att,
                        &allMatrix.att.W_proj);

        // FIRST RESIDUAL: Z_att += block input; keep the sum for the second residual.
        sum_matrix(&allMatrix.Z_att, &allMatrix.copy);
        duplicate_matrix(&allMatrix.Z_att, &allMatrix.result_first_add);

        // NORMALIZATION LAYER
        calculate_average(&allMatrix.Z_att, &avg);
        calculate_deviation(&allMatrix.Z_att, &dev, &avg);
        calculate_normalization(&allMatrix.Z_att, &avg, &dev);

        // MLP, first layer
        matrix_product_one_face(&allMatrix.MLP_layer.W, &allMatrix.Z_att, &allMatrix.res_MLP);
        sum_vector_matrix(&allMatrix.res_MLP, &allMatrix.MLP_layer.b);

        // GELU
        GELU(&allMatrix.res_MLP);

        // MLP, second layer
        matrix_product_one_face(&allMatrix.MLP_layer.W2, &allMatrix.res_MLP, &allMatrix.res2_MLP);
        sum_vector_matrix(&allMatrix.res2_MLP, &allMatrix.MLP_layer.b2);

        // SECOND RESIDUAL
        sum_matrix(&allMatrix.res2_MLP, &allMatrix.result_first_add);

        // Per-block temporaries (re-created on the next iteration).
        free(allMatrix.result_first_add.val);
        free(allMatrix.res_MLP.val);
        free(allMatrix.Z_att.val);

        // OUTPUT = INPUT for the next block. The last block's output must stay
        // alive in res2_MLP because the classification head reads it below;
        // freeing it here (as the old always-true `i < 5` guard did) was a
        // use-after-free.
        if (!is_last_block) {
            copy_matrix_from_1_to_2(&allMatrix.res2_MLP, &allMatrix.Z_CLS);
            copy_matrix_from_1_to_2(&allMatrix.res2_MLP, &allMatrix.copy);
            free(allMatrix.res2_MLP.val);
        }
    }

    /// CLASSIFICATION
    // NORMALIZATION LAYER on the encoder output
    float avg = 0;
    float dev = 0;
    calculate_average(&allMatrix.res2_MLP, &avg);
    calculate_deviation(&allMatrix.res2_MLP, &dev, &avg);
    calculate_normalization(&allMatrix.res2_MLP, &avg, &dev);

    // FORWARD ONLY CLS
    forward_only_cls(&allMatrix.res2_MLP, &allMatrix.Z_forward_only);

    // FULLY CONNECTED (MLP). The product below stores its result in res2_MLP,
    // replacing the encoder output buffer, so remember that buffer and release
    // it once Z_forward_only has been consumed. The pointer comparison guards
    // against the helper having reused the same buffer.
    void *encoder_out = allMatrix.res2_MLP.val;
    matrix_product_one_face(&allMatrix.MLP_layer.W3, &allMatrix.Z_forward_only, &allMatrix.res2_MLP);
    // NOTE(review): the head adds MLP_layer.b2, the transformer MLP's
    // second-layer bias; confirm a dedicated head bias is not intended.
    sum_vector_matrix(&allMatrix.res2_MLP, &allMatrix.MLP_layer.b2);
    if (encoder_out != allMatrix.res2_MLP.val) {
        free(encoder_out);
    }

    // SOFTMAX
    soft_max(&allMatrix.res2_MLP, &allMatrix.create_for_softmax);

    printf("RESULT\n");
    print_mat(&allMatrix.res2_MLP);

    /// TIME
    // time_t has no portable printf specifier; difftime yields a double.
    time_t end = time(NULL);
    printf("The elapsed time is %.0f seconds\n", difftime(end, begin));

    return 0;
}