//
// Created by Margherita Milani on 22/08/22.
//
#include "lib.h"

//GENERAL
// File-scope model configuration, derived from `image` and `patch` when this
// translation unit is initialized.
// NOTE(review): estimate_patchDim() and the products below are not constant
// expressions. A C compiler rejects such file-scope initializers; they are only
// accepted under C++ dynamic initialization. Confirm how this file is built.
// NOTE(review): `image` has no initializer here, so estimate_patchDim() sees
// zero-initialized dimensions unless lib.h gives `struct image` default
// member values. Confirm.
struct image image;
struct patch patch = estimate_patchDim();
// Integer division: any remainder of the image area is dropped.
int numPatches = (image.width * image.height) / (patch.height * patch.width);
int numSamples = 1;
int embedDim = patch.width * patch.height * image.faces; //256 or 512
int nHead_original = 1;
int numHead = 2;
// Automatic head-count estimation is disabled; numHead is fixed above.
//estimate_num_head(nHead_original);

//MLP
// The MLP keeps the embedding width: both input and output are embedDim.
int MLP_inputDim = embedDim;
int MLP_outputDim = MLP_inputDim;


///GENERAL
/*
 * Reads row * col * faces whitespace-separated floats from `f` into
 * image1->val, in storage order. image1->val must already be allocated.
 *
 * If `f` is NULL (for example after a failed fopen), nothing is read.
 * Reading stops at the first value fscanf cannot parse (end of file or
 * malformed input). The remaining entries of image1->val keep whatever they
 * held before, and a warning is printed to stderr.
 */
void read_image(struct matrix *image1, FILE *f) {
    if (f == NULL) {
        fprintf(stderr, "read_image: no input stream, matrix left unchanged\n");
        return;
    }

    const int total = image1->row * image1->col * image1->faces;
    for (int i = 0; i < total; i++) {
        if (fscanf(f, "%f", &image1->val[i]) != 1) {
            fprintf(stderr, "read_image: expected %d values, read only %d\n", total, i);
            return;
        }
    }
}

/*
 * res = mat1 * mat2, using only face 0 of each operand (row-major storage).
 * res is allocated here and the caller frees res->val. res->faces is copied
 * from mat1, but only face 0 is computed; any further faces are zero.
 */
void matrix_product_one_face(struct matrix *mat1, struct matrix *mat2, struct matrix *res) {

    res->row = mat1->row;
    res->col = mat2->col;
    res->faces = mat1->faces;
    // calloc: the loop accumulates with +=, so every sum must start at zero.
    res->val = (float *) calloc((size_t) res->row * res->col * res->faces, sizeof(float));
    if (res->val == NULL) {
        return;
    }

    for (int i = 0; i < mat1->row; i++) {
        for (int j = 0; j < mat1->col; j++) {
            for (int k = 0; k < mat2->col; k++) {
                res->val[i * mat2->col + k] +=
                        mat1->val[i * mat1->col + j] * mat2->val[j * mat2->col + k];
            }
        }
    }
}

/*
 * Multiplies mat1 (face 0) by each face of mat2 and stores every column of
 * each product as its own single-column face of res:
 *   input = embedDim * numPatch + 1 * numSamples -> every column of input becomes a face
 *   W     = 3embedDim * embedDim
 *   res   = 3embedDim * 1 * (numPatch + 1) per input face
 * Face (k + h * mat2->col) of res holds column k of mat1 * (face h of mat2).
 * When mat2 has a single face this is column k -> face k.
 * res is allocated here and the caller frees res->val.
 */
void matrix_product_one_face2(struct matrix *mat1, struct matrix *mat2, struct matrix *res) {
    res->row = mat1->row;
    res->col = 1;
    // One output face per column of every input face. The face index below
    // includes h, so faces from different mat2 faces neither overlap nor
    // run past the end of the buffer.
    res->faces = mat2->col * mat2->faces;
    // calloc: the loop accumulates with +=, so every sum must start at zero.
    res->val = (float *) calloc((size_t) res->row * res->col * res->faces, sizeof(float));
    if (res->val == NULL) {
        return;
    }

    for (int h = 0; h < mat2->faces; h++) {
        for (int i = 0; i < mat1->row; i++) {
            for (int j = 0; j < mat1->col; j++) {
                for (int k = 0; k < mat2->col; k++) {
                    const int out_face = k + h * mat2->col;
                    res->val[i * res->col + out_face * res->row * res->col] +=
                            mat1->val[i * mat1->col + j] * mat2->val[j * mat2->col + k + h * mat2->row * mat2->col];
                }
            }
        }
    }
}

/*
 * Dumps every face of `mat` to stdout. Each row is printed as "%f "
 * values followed by a newline, and each face is followed by a blank line.
 */
void print_mat(struct matrix *mat) {
    const int plane = mat->row * mat->col;
    for (int face = 0; face < mat->faces; ++face) {
        const int face_start = face * plane;
        for (int r = 0; r < mat->row; ++r) {
            const int row_start = face_start + r * mat->col;
            for (int c = 0; c < mat->col; ++c) {
                printf("%f ", mat->val[row_start + c]);
            }
            putchar('\n');
        }
        putchar('\n');
    }
}

/*
 * Fills every element of `mat` (all faces) with a pseudo-random value
 * from {0.0, 0.1, ..., 0.9}, drawn with rand().
 */
void rand_matrix(struct matrix *mat) {
    const int count = mat->row * mat->col * mat->faces;
    for (int idx = 0; idx < count; ++idx) {
        const int digit = rand() % 10;
        mat->val[idx] = (float) digit / (float) 10;
    }
}

/*
 * Adds vector->val[r] to every element of row r, in every column and every
 * face of X (a per-row bias broadcast across columns and faces). In place.
 */
void sum_vector_matrix(struct matrix *X, struct matrix *vector) {
    const int plane = X->row * X->col;
    for (int face = 0; face < X->faces; ++face) {
        for (int r = 0; r < X->row; ++r) {
            const int row_start = face * plane + r * X->col;
            for (int c = 0; c < X->col; ++c) {
                X->val[row_start + c] = X->val[row_start + c] + vector->val[r];
            }
        }
    }
}

/*
 * Broadcast-adds `vector` (length vector->row) down the rows of `mat`, in
 * place. Within each face, row r receives vector->val[r % vector->row] in
 * every column. Only the first (mat->row / vector->row) * vector->row rows
 * are touched; any leftover rows are not.
 */
void sum_vector_matrix2(struct matrix *mat, struct matrix *vector) {
    const int repeats = mat->row / vector->row;
    const int covered_rows = repeats * vector->row;
    const int plane = mat->row * mat->col;

    for (int face = 0; face < mat->faces; ++face) {
        for (int r = 0; r < covered_rows; ++r) {
            const int row_start = face * plane + r * mat->col;
            for (int c = 0; c < mat->col; ++c) {
                mat->val[row_start + c] = mat->val[row_start + c] + vector->val[r % vector->row];
            }
        }
    }
}

/* Multiplies every element of `mat` (all faces) by `scale`, in place. */
void scale_matrix(float scale, struct matrix *mat) {
    for (int k = mat->col * mat->faces * mat->row; k-- > 0;) {
        mat->val[k] *= scale;
    }
}


/*
 * Face-by-face matrix product. For each face h of mat1:
 *   face h of res = (face h of mat1) * (face h of mat2)
 * All storage is row-major, and mat2 must have at least mat1->faces faces.
 * res (mat1->row x mat2->col x mat1->faces) is allocated here; the caller
 * frees res->val.
 */
void matrix_product_per_face(struct matrix *mat1, struct matrix *mat2, struct matrix *res) {

    res->row = mat1->row;
    res->col = mat2->col;
    res->faces = mat1->faces;
    // calloc: the loop accumulates with +=, so every sum must start at zero.
    res->val = (float *) calloc((size_t) res->row * res->col * res->faces, sizeof(float));
    if (res->val == NULL) {
        return;
    }

    for (int h = 0; h < mat1->faces; h++) {
        for (int i = 0; i < mat1->row; i++) {
            for (int j = 0; j < mat1->col; j++) {
                for (int k = 0; k < mat2->col; k++) {
                    res->val[i * mat2->col + k + h * mat1->row * mat2->col] +=
                            mat1->val[i * mat1->col + j + h * mat1->row * mat1->col] *
                            mat2->val[j * mat2->col + k + h * mat2->row * mat2->col];
                }
            }
        }
    }
}

/*
 * Face-by-face product that reads the first operand column-major: element
 * (i, j) of face h of mat1 is taken from mat1->val[i + j * mat1->row + ...].
 * If a row-major matrix A has had its row/col fields swapped, this computes
 * A^T * mat2 per face without moving any data. mat2 is read row-major and
 * must have at least mat1->faces faces.
 * res (mat1->row x mat2->col x mat1->faces) is allocated here; the caller
 * frees res->val.
 */
void matrix_product_transpose_per_face(struct matrix *mat1, struct matrix *mat2, struct matrix *res) {

    res->row = mat1->row;
    res->col = mat2->col;
    res->faces = mat1->faces;
    // calloc: the loop accumulates with +=, so every sum must start at zero.
    res->val = (float *) calloc((size_t) res->row * res->col * res->faces, sizeof(float));
    if (res->val == NULL) {
        return;
    }

    for (int h = 0; h < mat1->faces; h++) {
        for (int i = 0; i < mat1->row; i++) {
            for (int j = 0; j < mat1->col; j++) {
                for (int k = 0; k < mat2->col; k++) {
                    res->val[i * mat2->col + k + h * mat1->row * mat2->col] +=
                            mat1->val[i + j * mat1->row + h * mat1->row * mat1->col] *
                            mat2->val[j * mat2->col + k + h * mat2->row * mat2->col];
                }
            }
        }
    }
}

/* Sets every element of `mat` (all faces) to 0. */
void fill_matrix_with_zero(struct matrix *mat) {
    const int count = mat->row * mat->col * mat->faces;
    int k = 0;
    while (k < count) {
        mat->val[k++] = 0;
    }
}

/*
 * Element-wise (Hadamard) product, in place: mat[k] *= mult[k] for every
 * element of mat. mult must hold at least as many elements as mat.
 */
void matrix_product_member_by_member(struct matrix *mat, struct matrix *mult) {
    const int count = mat->row * mat->col * mat->faces;
    for (int k = 0; k < count; ++k) {
        mat->val[k] *= mult->val[k];
    }
}

/*
 * Element-wise sum, in place: mat1[k] += mat2[k] for every element of mat1.
 * mat2 must hold at least as many elements as mat1.
 */
void sum_matrix(struct matrix *mat1, struct matrix *mat2) {
    const int count = mat1->col * mat1->faces * mat1->row;
    for (int k = 0; k < count; ++k) {
        mat1->val[k] += mat2->val[k];
    }
}

/*
 * Copies mat1's elements (row * col * faces of them) into the already
 * allocated buffer of mat2, front to back. mat2's dimensions are not changed.
 */
void copy_matrix_from_1_to_2(struct matrix *mat1, struct matrix *mat2) {
    const int count = mat1->row * mat1->col * mat1->faces;
    for (int k = 0; k < count; ++k) {
        mat2->val[k] = mat1->val[k];
    }
}


/// MLP
void initialize_MLP(struct allMatrix *allMatrix) {

    allMatrix->MLP_layer.W.row = MLP_outputDim;
    allMatrix->MLP_layer.W.col = MLP_inputDim;
    allMatrix->MLP_layer.W.faces = 1;
    allMatrix->MLP_layer.W.val = (float *) malloc(
            allMatrix->MLP_layer.W.row * allMatrix->MLP_layer.W.col * sizeof(float));
    rand_matrix(&allMatrix->MLP_layer.W);


    allMatrix->MLP_layer.b.row = MLP_outputDim;
    allMatrix->MLP_layer.b.col = 1;
    allMatrix->MLP_layer.b.faces = 1;
    allMatrix->MLP_layer.b.val = (float *) malloc(allMatrix->MLP_layer.b.row * sizeof(float));
    rand_matrix(&allMatrix->MLP_layer.b);
}

void initialize_MLP2(struct allMatrix *allMatrix) {

    allMatrix->MLP_layer.W2.row = MLP_inputDim;
    allMatrix->MLP_layer.W2.col = MLP_outputDim;
    allMatrix->MLP_layer.W2.faces = 1;
    allMatrix->MLP_layer.W2.val = (float *) malloc(
            allMatrix->MLP_layer.W2.row * allMatrix->MLP_layer.W2.col * sizeof(float));
    rand_matrix(&allMatrix->MLP_layer.W2);


    allMatrix->MLP_layer.b2.row = MLP_inputDim;
    allMatrix->MLP_layer.b2.col = 1;
    allMatrix->MLP_layer.b2.faces = 1;
    allMatrix->MLP_layer.b2.val = (float *) malloc(allMatrix->MLP_layer.b2.row * sizeof(float));
    rand_matrix(&allMatrix->MLP_layer.b2);
}

/*
 * Allocates the third MLP layer and fills it with rand_matrix:
 *   W3: 2 x embedDim, one face
 *   b3: 1 x 2, one face
 */
void initialize_MLP3(struct allMatrix *allMatrix) {

    allMatrix->MLP_layer.W3.row = 2;
    allMatrix->MLP_layer.W3.col = embedDim;
    allMatrix->MLP_layer.W3.faces = 1;
    allMatrix->MLP_layer.W3.val = (float *) malloc(
            allMatrix->MLP_layer.W3.row * allMatrix->MLP_layer.W3.col * sizeof(float));
    rand_matrix(&allMatrix->MLP_layer.W3);


    allMatrix->MLP_layer.b3.row = 1;
    allMatrix->MLP_layer.b3.col = 2;
    allMatrix->MLP_layer.b3.faces = 1;
    // rand_matrix writes row * col * faces values (here 2). Sizing by row alone
    // would hold only one float and overflow the heap block.
    allMatrix->MLP_layer.b3.val = (float *) malloc(
            (size_t) allMatrix->MLP_layer.b3.row * allMatrix->MLP_layer.b3.col *
            allMatrix->MLP_layer.b3.faces * sizeof(float));
    rand_matrix(&allMatrix->MLP_layer.b3);
}


/// ATTENTION

void
attention_layer(struct matrix *input, struct matrix *W_qkv, struct matrix *b, struct matrix *Z, struct matrix *W_proj) {

    /// X
    /*printf("X ATT\n");
    print_mat(input);

    /// W_qkv
    printf("W_qkv ATT\n");
    print_mat(W_qkv);*/

    /// res = W_qkv * X
    //res = 3embedDim * 1 * numPatch+1
    struct matrix QKV;
    matrix_product_per_face(W_qkv, input, &QKV);
    /*printf("res ATT\n");
    print_mat(&QKV);*/

    /// b_qkv
    /*printf("VECTOR b ATT\n");
    print_mat(&allMatrix.att.b_qkv);*/

    /// QKV = (W_qkv * X + b_qkv)
    sum_vector_matrix2(&QKV, b);
    /*printf("QKV\n");
    print_mat(&allMatrix.QKV_att.init);*/

    /// split matrix QKV in Q,K,V
    int splitSize = QKV.row / 3;
    /// create Q
    //(embedDim / numHead) * 1 * numHead * numPatch+1;
    struct matrix Q;
    split_QKV(0, splitSize, &QKV, &Q, splitSize);
    /*printf("Q\n");
    print_mat(&Q);*/

    /// create K
    struct matrix K;
    split_QKV(splitSize, 2 * splitSize, &QKV, &K, splitSize);
    /*printf("K\n");
    print_mat(&K);*/

    /// create V
    struct matrix V;
    split_QKV(2 * splitSize, 3 * splitSize, &QKV, &V, splitSize);
    /*printf("V\n");
    print_mat(&V);*/


    /// resKQ = K (transposed) * Q
    //1 * (embedDim / numHead) * numhead * numPatches+1  x (embedDim / numHead) * 1 * numHead * numPatch+1
    //= 1 * 1 * numHead * numPatches + 1
    int tmp = K.row;
    K.row = K.col;
    K.col = tmp;
    //printf("K reverse\n");
    //print_mat(&K.init);
    struct matrix resQK;
    //
    matrix_product_transpose_per_face(&K, &Q, &resQK);
    /*printf("K trasposto * Q = \n");
    print_mat(&resQK);*/

    /// scale resKQ
    //scale_matrix(3.74, &resQK);
    /*printf("scale K * Q\n");
    print_mat(&resQK);*/

    /// softmax
    struct matrix create_for_softmax;
    soft_max(&resQK, &create_for_softmax);
    /*printf("create for softmax\n");
    print_mat(&create_for_softmax);
    printf("SOFTMAX\n");
    print_mat(&resQK);*/
/*
    /// DropOut
    //dovrei creare ps simile ad att: io lo tengo uguale
    //struct matrix ps = resQK;
    apply_dropout(&allMatrix->resQK.init, allMatrix->att.attention_DropOutProb);
    /*printf("DROPOUT 1\n");
    print_mat(&resQK.init);*/

    /// Y = V * RES_QK
    // (embedDim / numHead) * 1 * numHead * numPatch+1   x   1 * 1 * numHead * numPatch + 1
    //= embedDim/numHead * 1 * numHead * numPatch + 1*/
    struct matrix Y;
    matrix_product_per_face(&V, &resQK, &Y);
    /*printf("V\n");
    print_mat(&V.init);
    printf("RES QK\n");
    print_mat(&resQK.init);

     */
    /*printf("Y = (V * QK) \n");
    print_mat(&Y);*/

    //printf("Z (Y reshape)\n");*/

    /// transpose Y
    // = embedDim/NumHead * numPatch + 1 *  numHead
    // = (embedDim) * 1 * numPatches + 1
    Y.row = Y.row * Y.faces;
    Y.faces = 1;

    /*printf("Z (Y reshape)\n");
    print_mat(&allMatrix.Y_att.init);
    print_mat(&Y.init);*/

    /// W proj
    /*printf("W_proj\n");
    print_mat(W_proj);*/

    /// Z = W_proj * Y (transposed)
    matrix_product_one_face(W_proj, &Y, Z);
    /*printf("Z (W proj * Y)\n");
    print_mat(Z);*/
/*
    /// apply dropout
    //dovrei creare ps simile a Z ma tengo Z
    apply_dropout(&allMatrix->Z_att, allMatrix->att.attention_projDropOutProb);
    //printf("DROPOUT Z\n");
    //print_mat(&Z_att);
    */
    free(QKV.val);
    free(Q.val);
    free(K.val);
    free(V.val);
    free(Y.val);
    free(resQK.val);
    free(create_for_softmax.val);

}

/*
 * Allocates the attention-layer parameters:
 *   W_qkv : (3 * embedDim) x embedDim. Filled by rand_matrix, then overwritten
 *           with the values in W_qkv.txt when that file can be opened.
 *   b_qkv : embedDim x 1, all zeros (presumably broadcast over the Q/K/V
 *           row blocks by sum_vector_matrix2 — confirm at the call site).
 *   W_proj: embedDim x embedDim, filled by rand_matrix.
 * NOTE(review): the weight file path is absolute and machine-specific.
 */
void initialize_attention_layer(struct allMatrix *allMatrix) {

    //W_qkv
    allMatrix->att.W_qkv.row = embedDim * 3;
    allMatrix->att.W_qkv.col = embedDim;
    allMatrix->att.W_qkv.faces = 1;
    allMatrix->att.W_qkv.val = (float *) malloc(
            allMatrix->att.W_qkv.row * allMatrix->att.W_qkv.col * allMatrix->att.W_qkv.faces * sizeof(float));
    rand_matrix(&allMatrix->att.W_qkv);
    FILE *W = fopen("/Users/margheritamilani/desktop/uni/tesi/W_qkv.txt", "r");
    if (W != NULL) {
        read_image(&allMatrix->att.W_qkv, W);
        fclose(W);  // release the stream once the weights have been read
    } else {
        fprintf(stderr, "initialize_attention_layer: cannot open W_qkv.txt, keeping random W_qkv\n");
    }

    //b_qkv
    allMatrix->att.b_qkv.row = embedDim;
    allMatrix->att.b_qkv.col = 1;
    allMatrix->att.b_qkv.faces = 1;
    allMatrix->att.b_qkv.val = (float *) malloc(
            allMatrix->att.b_qkv.row * allMatrix->att.b_qkv.col * allMatrix->att.b_qkv.faces * sizeof(float));
    memset(allMatrix->att.b_qkv.val, 0,
           allMatrix->att.b_qkv.row * allMatrix->att.b_qkv.col * allMatrix->att.b_qkv.faces * sizeof(float));

    //W_proj
    allMatrix->att.W_proj.row = embedDim;
    allMatrix->att.W_proj.col = embedDim;
    allMatrix->att.W_proj.faces = 1;
    allMatrix->att.W_proj.val = (float *) malloc(
            allMatrix->att.W_proj.row * allMatrix->att.W_proj.col * sizeof(float));
    rand_matrix(&allMatrix->att.W_proj);

}

/*
 * Copies rows [start, start + (splitSize / numHead) * numHead) of face 0 of
 * QKV into `split`, reshaped into numHead faces. Face h holds the
 * splitSize / numHead consecutive rows that follow
 * start + h * (splitSize / numHead). All columns are kept.
 * `fine` (the end row) is not used. split->val is allocated here; the caller
 * frees it.
 */
void split_QKV(int start, int fine, struct matrix *QKV, struct matrix *split,
               int splitSize) {
    (void) fine;

    const int rows_per_head = splitSize / numHead;
    split->row = rows_per_head;
    split->col = QKV->col;
    split->faces = numHead;
    split->val = (float *) malloc(
            split->row * split->col * split->faces * sizeof(float));

    const int width = split->col;
    for (int head = 0; head < split->faces; ++head) {
        for (int r = 0; r < rows_per_head; ++r) {
            const int dst_row = head * rows_per_head + r;
            const int src_row = start + dst_row;
            for (int c = 0; c < width; ++c) {
                split->val[dst_row * width + c] = QKV->val[src_row * width + c];
            }
        }
    }
}


/*
 * Computes the softmax denominators of `mat`. For each face h and column j:
 *   create[j] of face h = sum over i of e^(mat->val[i + j * mat->row + face offset])
 * The buffer of `mat` is traversed column-major, matching soft_max.
 * `create` is allocated here as 1 x mat->col x mat->faces; the caller frees
 * create->val.
 * NOTE(review): there is no max-subtraction, so large scores can overflow
 * to inf.
 */
void create_matrix_for_softmax(struct matrix *mat, struct matrix *create) {

    create->row = 1;
    create->col = mat->col;
    create->faces = mat->faces;
    // calloc: the sums below accumulate with +=, so they must start at zero.
    create->val = (float *) calloc((size_t) create->row * create->col * create->faces, sizeof(float));
    if (create->val == NULL) {
        return;
    }

    for (int h = 0; h < mat->faces; h++) {
        for (int j = 0; j < mat->col; j++) {
            for (int i = 0; i < mat->row; i++) {
                create->val[j + h * create->col * create->row] +=
                        pow(M_E, mat->val[i + j * mat->row + h * mat->row * mat->col]);
            }
        }
    }
}

/*
 * In-place softmax of `mat`, normalizing each column of each face. The
 * buffer is indexed column-major: element (i, j) is at i + j * row.
 * `create` receives the per-column sums of exponentials from
 * create_matrix_for_softmax; the caller frees create->val.
 */
void soft_max(struct matrix *mat, struct matrix *create) {
    create_matrix_for_softmax(mat, create);

    const int rows = mat->row;
    const int plane = mat->row * mat->col;
    for (int face = 0; face < mat->faces; ++face) {
        for (int c = 0; c < mat->col; ++c) {
            const int column_start = face * plane + c * rows;
            const int sum_index = c + face * create->col * create->row;
            for (int r = 0; r < rows; ++r) {
                const double numer = pow(M_E, mat->val[column_start + r]);
                mat->val[column_start + r] = numer / create->val[sum_index];
            }
        }
    }
}

/*
 * Masks and rescales `mat` in place. Each element v becomes
 *     v * ((v < probability ? 0 : 1) / probability)
 * The result is computed per element, so no temporary copy of the matrix is
 * allocated.
 * NOTE(review): the mask compares each element with `probability` rather than
 * a random draw, and survivors are divided by `probability` rather than
 * (1 - probability). This is not standard inverted dropout, and
 * probability == 0 divides by zero. Confirm intent before enabling it.
 */
void apply_dropout(struct matrix *mat, double probability) {
    const int count = mat->row * mat->col * mat->faces;
    for (int k = 0; k < count; k++) {
        const float value = mat->val[k];
        // Keep the mask as a float, as when it was stored in the matrix, so the
        // rounding of the rescaled mask is unchanged.
        float mask = (value < probability) ? 0.0f : 1.0f;
        mask /= probability;
        mat->val[k] = mask * value;
    }
}

/// CLS_TOKEN
void initialize_CLS_token(struct allMatrix *allMatrix) {

    allMatrix->token.cls_token.row = embedDim;
    allMatrix->token.cls_token.col = 1;
    allMatrix->token.cls_token.faces = numSamples;
    allMatrix->token.cls_token.val = (float *) malloc(allMatrix->token.cls_token.row * sizeof(float));
    fill_matrix_with_zero(&allMatrix->token.cls_token);;

}

/*
 * Builds `concat` (allocated here; the caller frees concat->val) with
 * mat1->row rows, numPatches + 1 columns and mat1->faces faces. It is written
 * sequentially in storage order.
 *
 * The output is produced as a stream. For each face h and each column j of
 * mat2, one value from mat1 is written, followed by column j of mat2 read top
 * to bottom:
 *     mat1->val[h], mat2(0, j), mat2(1, j), ..., mat2(mat2->row - 1, j)
 * `count` tracks the position inside the current output row, so the mat1
 * value is emitted once every concat->col writes. This lines up with one
 * output row per mat2 column only when mat2->row == numPatches.
 *
 * NOTE(review): the inserted value is mat1->val[h], the same element for every
 * output row of face h, rather than a per-row entry of mat1. With a
 * zero-filled CLS token this is invisible, but confirm it is intended.
 * NOTE(review): if mat2->row > numPatches, the `i = 0` reset restarts the
 * current column each time a mat1 value is inserted. The while loop then
 * appears never to terminate and writes past the end of concat. Confirm
 * mat2's shape at the call sites.
 * Side effect: on return, mat2's dimensions are overwritten with
 * embedDim x numPatches x numSamples.
 */
void cat_matrix_on2dim(struct matrix *mat1, struct matrix *mat2, struct matrix *concat) {

    //print_mat(mat2);

    concat->row = mat1->row;
    concat->col = numPatches + 1;
    concat->faces = mat1->faces;
    concat->val = (float *) malloc(concat->faces * concat->row * concat->col * sizeof(float));
    int index = 0;  // next write position in concat->val
    int count = 0;  // writes so far in the current output row (cycles 1..concat->col)
    int i2 = 0;
    int i = 0;      // current row of mat2 within column j
    for (int h = 0; h < mat2->faces; h++) {
        for (int j = 0; j < mat2->col; j++) {
            while (i < mat2->row) {

                // Start of an output row (or the very first write): emit the mat1 value.
                if (count == concat->col || (count == 0 && i == 0 && j == 0 && h == 0)) {
                    concat->val[index] = mat1->val[h];
                    i = 0;

                } else {
                    // Element (i, j) of face h of mat2 (row-major).
                    concat->val[index] = mat2->val[h * mat2->row * mat2->col + j +
                                                   (i) * mat2->col];
                    i++;
                }
                if (count < concat->col) {
                    count++;
                } else {
                    count = 1;
                }
                index++;
            }
            i = 0;

        }
    }

    // Relabel mat2's shape for later stages (its data is not touched).
    mat2->row = embedDim;
    mat2->col = numPatches;
    mat2->faces = numSamples;

}

/// PATCH EMBEDDING
/*
 * Allocates the patch-embedding inputs and loads them from text files:
 *   embeddingLayer.W.init : patch.width x patch.height x
 *                           (image.faces * embedDim) filter, from filter.txt
 *   input_patch_embedding : image.width x image.height x image.faces image,
 *                           from imageProva.txt
 * Both buffers are zero-initialized, so a missing or short file leaves zeros
 * rather than indeterminate values. A warning is printed if a file cannot
 * be opened.
 * NOTE(review): the file paths are absolute and machine-specific.
 */
void initialize_patch_embedding_layer(struct allMatrix *allMatrix) {

    allMatrix->embeddingLayer.W.init.row = patch.width;
    allMatrix->embeddingLayer.W.init.col = patch.height;
    allMatrix->embeddingLayer.W.faces1 = image.faces;
    allMatrix->embeddingLayer.W.faces2 = embedDim;
    allMatrix->embeddingLayer.W.init.faces = allMatrix->embeddingLayer.W.faces1 * allMatrix->embeddingLayer.W.faces2;
    allMatrix->embeddingLayer.W.init.val = (float *) calloc(
            (size_t) allMatrix->embeddingLayer.W.init.row * allMatrix->embeddingLayer.W.init.col *
            allMatrix->embeddingLayer.W.init.faces, sizeof(float));
    FILE *W = fopen("/Users/margheritamilani/desktop/uni/tesi/filter.txt", "r");
    if (W != NULL) {
        read_image(&allMatrix->embeddingLayer.W.init, W);
        fclose(W);
    } else {
        fprintf(stderr, "initialize_patch_embedding_layer: cannot open filter.txt\n");
    }

    allMatrix->input_patch_embedding.row = image.width;
    allMatrix->input_patch_embedding.col = image.height;
    allMatrix->input_patch_embedding.faces = image.faces;
    allMatrix->input_patch_embedding.val = (float *) calloc(
            (size_t) allMatrix->input_patch_embedding.row * allMatrix->input_patch_embedding.col *
            allMatrix->input_patch_embedding.faces, sizeof(float));
    FILE *neo = fopen("/Users/margheritamilani/desktop/uni/tesi/imageProva.txt", "r");
    if (neo != NULL) {
        read_image(&allMatrix->input_patch_embedding, neo);
        fclose(neo);
    } else {
        fprintf(stderr, "initialize_patch_embedding_layer: cannot open imageProva.txt\n");
    }
}

//no padding
/*
 * Patch-embedding convolution, no padding.
 *
 * For every output channel h2 (filter->faces2 of them) and every input face h,
 * the filter slice (filter->init.row x filter->init.col) is applied with a
 * fixed stride of 2. Output (i, j) of channel h2 accumulates
 *     mat1(2i + m - kCenterY + 1, 2j + n - kCenterX + 1) * filter(m, n)
 * over the taps that fall inside mat1.
 * The output is sqrt(numPatches) x sqrt(numPatches) x filter->faces2 and is
 * allocated here; the caller frees res->val.
 *
 * The loops run over the OUTPUT dimensions, so no write can land outside res.
 * NOTE(review): with a 2x2 filter this tiles mat1 into non-overlapping 2x2
 * windows. The stride stays 2 for other filter sizes — confirm.
 */
void convolution(struct matrix *mat1, struct matrix4D *filter, struct matrix *res) {

    res->row = sqrt(numPatches);
    res->col = sqrt(numPatches);
    res->faces = filter->faces2;
    // calloc: outputs accumulate with +=, so they must start at zero.
    res->val = (float *) calloc((size_t) res->row * res->col * res->faces, sizeof(float));
    if (res->val == NULL) {
        return;
    }

    int kCenterX = filter->init.row / 2;
    int kCenterY = filter->init.col / 2;

    for (int h2 = 0; h2 < filter->faces2; h2++) {
        for (int h = 0; h < mat1->faces; h++) {
            for (int i = 0; i < res->row; ++i) {
                for (int j = 0; j < res->col; ++j) {
                    for (int m = 0; m < filter->init.row; ++m) {
                        for (int n = 0; n < filter->init.col; ++n) {

                            int ii = (i * 2 + (m - kCenterY) + 1);
                            int jj = (j * 2 + (n - kCenterX) + 1);

                            if (ii >= 0 && ii < mat1->row && jj >= 0 && jj < mat1->col) {

                                res->val[i * res->col + j +
                                         h2 * res->row * res->col] +=
                                        mat1->val[ii * mat1->col + jj +
                                                  h * mat1->row * mat1->col] *
                                        filter->init.val[m * filter->init.col + n +
                                                         h * filter->init.row * filter->init.col +
                                                         h2 * filter->init.row * filter->init.col *
                                                         filter->faces1];
                            }
                        }
                    }
                }
            }
        }
    }
}


/// POS EMBEDDING
void initialize_pos_embedding_layer(struct allMatrix *allMatrix) {

    allMatrix->pos_embed_layer.pos_embed.row = embedDim;
    allMatrix->pos_embed_layer.pos_embed.col = numPatches + 1;
    allMatrix->pos_embed_layer.pos_embed.faces = numSamples;
    allMatrix->pos_embed_layer.pos_embed.val = (float *) malloc(
            allMatrix->pos_embed_layer.pos_embed.row * allMatrix->pos_embed_layer.pos_embed.col *
            allMatrix->pos_embed_layer.pos_embed.faces *
            sizeof(float));
    fill_matrix_with_zero(&allMatrix->pos_embed_layer.pos_embed);
}


/// GELU
/*
 * Applies the tanh approximation of GELU element-wise, in place, to every
 * element of `mat` (all rows, columns and faces):
 *     GELU(x) = 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
 */
void GELU(struct matrix *mat) {
    const double k = sqrt(2.0 / M_PI);
    const int count = mat->row * mat->col * mat->faces;

    for (int i = 0; i < count; i++) {
        const double x = mat->val[i];
        // tanh must wrap the whole scaled cubic term, not just the constant.
        mat->val[i] = (float) (0.5 * x * (1.0 + tanh(k * (x + 0.044715 * x * x * x))));
    }
}

/// ESTIMATE
/*
 * Stores every positive divisor of `size`, in increasing order, in `vect` as a
 * (divisor count) x 1 x 1 column. vect->val is allocated here and the caller
 * frees it. For size <= 0 the vector is empty.
 */
void create_vector_for_estimate(int size, struct matrix *vect) {
    int divisors = 0;
    for (int d = 1; d <= size; ++d) {
        divisors += (size % d == 0);
    }

    vect->row = divisors;
    vect->col = 1;
    vect->faces = 1;
    vect->val = (float *) malloc(vect->row * sizeof(float));

    int slot = 0;
    for (int d = 1; d <= size; ++d) {
        if (size % d != 0) {
            continue;
        }
        vect->val[slot++] = d;
    }
}

/*
 * Chooses a patch size for the current `image` from the divisors of its
 * dimensions.
 *
 * A table patchSizeComb is built with (number of divisors of image.width) + 1
 * rows and (number of divisors of image.height) columns:
 *   row 0     : image.width  / (j-th divisor of image.height)
 *   row 1 + i : image.height / (i-th divisor of image.width)  (same in every column)
 * The patch width is the entry at column col / 2 of row 0. The patch height is
 * the entry at column col / 2 of row row / 2 + 1, clamped to the last row.
 * Both are truncated to int.
 * If either image dimension has no positive divisors (dimension <= 0), or the
 * table cannot be allocated, a 1 x 1 patch is returned.
 * NOTE(review): row 0 divides the width by divisors of the height. Confirm that
 * pairing is intended.
 */
struct patch estimate_patchDim() {
    struct patch patch1;
    patch1.width = 1;
    patch1.height = 1;

    struct matrix divRow;
    create_vector_for_estimate(image.height, &divRow);
    struct matrix divCol;
    create_vector_for_estimate(image.width, &divCol);

    if (divRow.row > 0 && divCol.row > 0) {
        struct matrix patchSizeComb;
        patchSizeComb.row = divCol.row + 1;
        patchSizeComb.col = divRow.row;
        patchSizeComb.faces = 1;
        patchSizeComb.val = (float *) malloc((size_t) patchSizeComb.row * patchSizeComb.col * sizeof(float));

        if (patchSizeComb.val != NULL) {
            for (int j = 0; j < patchSizeComb.col; j++) {
                patchSizeComb.val[j] = image.width / divRow.val[j];
            }
            // Rows 1..divCol.row. The loop bound is divCol.row, not
            // patchSizeComb.row, so neither the table nor divCol.val is
            // accessed out of bounds.
            for (int i = 0; i < divCol.row; i++) {
                for (int j = 0; j < patchSizeComb.col; j++) {
                    patchSizeComb.val[(1 + i) * patchSizeComb.col + j] = image.height / divCol.val[i];
                }
            }

            const int half_col = patchSizeComb.col / 2;
            int pick_row = patchSizeComb.row / 2 + 1;
            if (pick_row > patchSizeComb.row - 1) {
                pick_row = patchSizeComb.row - 1;  // only happens when the width has a single divisor
            }

            int bestChoice1 = patchSizeComb.val[half_col];
            int bestChoice2 = patchSizeComb.val[pick_row * patchSizeComb.col + half_col];
            patch1.width = bestChoice1;
            patch1.height = bestChoice2;
            free(patchSizeComb.val);
        }
    }

    free(divRow.val);
    free(divCol.val);
    return patch1;
}


int estimate_num_head(int nHead_original) {

    struct matrix vector_head;
    create_vector_for_estimate(embedDim, &vector_head);

    int numHead = nHead_original;
    for (int i = 0; i < vector_head.row; i++) {
        if (vector_head.val[i] <= nHead_original) {
            numHead = vector_head.val[i];
        }
    }
    free(vector_head.val);
    return numHead;
}


/// FORWARD ONLY_CLS
/*
 * Extracts column 0 of every face of X into Z, giving an
 * X->row x 1 x X->faces matrix. Column 0 is presumably the CLS-token
 * position — confirm. Z->val is allocated here and the caller frees it.
 */
void forward_only_cls(struct matrix *X, struct matrix *Z) {
    Z->row = X->row;
    Z->col = 1;
    Z->faces = X->faces;
    Z->val = (float *) malloc(Z->row * Z->col * Z->faces * sizeof(float));

    int out = 0;
    for (int face = 0; face < Z->faces; ++face) {
        const int face_start = face * X->row * X->col;
        for (int r = 0; r < Z->row; ++r) {
            Z->val[out++] = X->val[face_start + r * X->col];
        }
    }
}

/// NORMALIZATION LAYER
/*
 * Allocates the `average` and `deviation` buffers, each
 * (patch.width * patch.height) x (numPatches + 1) x numSamples.
 * Contents are left uninitialized.
 */
void initialize_normalization_layer(struct allMatrix *allMatrix) {

    allMatrix->average.row = patch.width * patch.height;
    allMatrix->average.col = numPatches + 1;
    allMatrix->average.faces = numSamples;
    // Include faces in the size so the buffer matches the declared shape
    // when numSamples > 1.
    allMatrix->average.val = (float *) malloc(
            (size_t) allMatrix->average.row * allMatrix->average.col *
            allMatrix->average.faces * sizeof(float));

    allMatrix->deviation.row = patch.width * patch.height;
    allMatrix->deviation.col = numPatches + 1;
    allMatrix->deviation.faces = numSamples;
    allMatrix->deviation.val = (float *) malloc(
            (size_t) allMatrix->deviation.row * allMatrix->deviation.col *
            allMatrix->deviation.faces * sizeof(float));

}

/*
 * Stores in *avg the arithmetic mean of the first row * col elements of
 * image->val (face 0 only), accumulated in float.
 */
void calculate_average(struct matrix *image, float *avg) {
    const int pixels = image->row * image->col;
    const float *p = image->val;

    *avg = 0;
    for (int remaining = pixels; remaining > 0; --remaining) {
        *avg += *p++;
    }
    *avg = *avg / (float) (pixels);
}

/*
 * Stores in *dev the population variance (mean squared difference from *avg)
 * of the first row * col elements of image->val (face 0 only). Despite the
 * name this is the variance, not the standard deviation;
 * calculate_normalization takes the square root itself.
 */
void calculate_deviation(struct matrix *image, float *dev, float *avg) {
    const int pixels = image->row * image->col;

    *dev = 0;
    for (int k = 0; k < pixels; ++k) {
        *dev += pow(image->val[k] - *avg, 2);
    }
    *dev = (float) *dev / (float) (pixels);
}


/*
 * Standardizes every element of `image` (all faces) in place:
 *     x -> (x - *avg) / sqrt(*dev + 1e-5)
 * where *dev is a variance and 1e-5 guards against division by zero.
 */
void calculate_normalization(struct matrix *image, float *avg, float *dev) {
    const int count = image->row * image->col * image->faces;
    for (int k = 0; k < count; ++k) {
        const float centered = image->val[k] - (*avg);
        image->val[k] = centered / sqrt(*dev + pow(10, -5));
    }
}


/// ADDITION
/*
 * Makes mat2 a deep copy of mat1: same dimensions and a newly allocated
 * buffer holding the same values. The caller frees mat2->val.
 */
void duplicate_matrix(struct matrix *mat1, struct matrix *mat2) {
    const int count = mat1->row * mat1->col * mat1->faces;

    mat2->row = mat1->row;
    mat2->col = mat1->col;
    mat2->faces = mat1->faces;
    mat2->val = (float *) malloc(mat2->row * mat2->col * mat2->faces * sizeof(float));

    for (int k = 0; k < count; ++k) {
        mat2->val[k] = mat1->val[k];
    }
}


/// ALL
void initialize_all_matrix(struct allMatrix *allMatrix) {

    initialize_MLP(allMatrix);
    initialize_MLP2(allMatrix);
    initialize_MLP3(allMatrix);
    initialize_attention_layer(allMatrix);
    initialize_CLS_token(allMatrix);
    initialize_patch_embedding_layer(allMatrix);
    initialize_pos_embedding_layer(allMatrix);
    initialize_normalization_layer(allMatrix);
}

/*
 * Releases the model buffers owned by allMatrix, in a fixed order.
 * NOTE(review): average/deviation are not freed here (previously commented
 * out). Neither are token.cls_token, embeddingLayer.W.init and
 * pos_embed_layer.pos_embed, although initialize_all_matrix allocates them.
 * Confirm whether they are released elsewhere.
 */
void freeAll(struct allMatrix *allMatrix) {
    float *const owned[] = {
            allMatrix->MLP_layer.W.val,
            allMatrix->MLP_layer.b.val,
            allMatrix->att.W_qkv.val,
            allMatrix->att.b_qkv.val,
            allMatrix->att.W_proj.val,
            allMatrix->Z_CLS.val,
            allMatrix->input_patch_embedding.val,
            allMatrix->Z_patch_embedding.val,
            allMatrix->Z_forward_only.val,
            allMatrix->MLP_layer.W2.val,
            allMatrix->MLP_layer.b2.val,
            allMatrix->MLP_layer.W3.val,
            allMatrix->MLP_layer.b3.val,
            allMatrix->res3_MLP.val,
    };
    const int n_owned = (int) (sizeof owned / sizeof owned[0]);

    for (int k = 0; k < n_owned; ++k) {
        free(owned[k]);
    }
}