//
// Created by Margherita Milani on 22/08/22.
//

#pragma once

/* _USE_MATH_DEFINES only has an effect if it is defined before <math.h> /
 * <cmath> is first included; MSVC's math headers check it to expose M_PI and
 * the other M_* constants. It is still ineffective if a translation unit has
 * already included <math.h> before this header. Guarded so a -D on the
 * command line does not trigger a redefinition warning. */
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES
#endif

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <cmath>
#include <time.h>


/*
 * Shared hyper-parameters. They are only declared here; each one must be
 * defined exactly once in a single .c/.cpp file that includes this header.
 */

/* General model sizes. numHead and embedDim are also read by the default
 * member initializers of struct attention_layer. */
extern int numPatches, numSamples, embedDim;
extern int nHead_original, numHead;

/* MLP input / output widths. */
extern int MLP_inputDim, MLP_outputDim;

/* Attention dropout probabilities. NOTE(review): struct attention_layer has
 * members with these same names that default to 0.4 independently of these
 * globals — confirm which values are meant to be used. */
extern double attention_DropOutProb, attention_projDropOutProb;



/* Dimensions of an input image: width x height with `faces` planes
 * (presumably channels/bands — verify). In C++ a default-initialized object
 * gets the defaults below (50 x 50 x 116).
 * This is a type definition only, so it carries no storage-class specifier;
 * `extern` is reserved for object/function declarations. */
struct image {
    int width = 50;
    int height = 50;
    int faces = 116;
};


/* Patch size (width x height); this is the return type of
 * estimate_patchDim(). Presumably the size of the tiles an image is split
 * into — verify. A plain type definition: `extern` only applies to objects
 * and functions. */
struct patch {
    int width;
    int height;
};


/* Dense float tensor descriptor: row x col x faces extents plus a pointer to
 * the element buffer. NOTE(review): element layout (index order) and who
 * allocates/frees `val` are decided by the implementation — confirm there. */
struct matrix {
    int row, col, faces;   /* extents */
    float *val;            /* element buffer */
};

/* Four-dimensional tensor: a struct matrix extended with two further extents.
 * Used as the convolution filter type. NOTE(review): how faces1/faces2 relate
 * to init.faces is not visible in this header — confirm in the implementation. */
struct matrix4D {
    struct matrix init;   /* base descriptor and element storage */
    int faces1, faces2;   /* additional extents */
};

/* Parameters of a three-stage MLP, stored as (weight, bias) pairs —
 * presumably one pair per linear layer. */
struct MLP_layer {
    struct matrix W,  b;    /* stage 1; b is a vector */
    struct matrix W2, b2;   /* stage 2 */
    struct matrix W3, b3;   /* stage 3 */
};

/* Parameters of the multi-head self-attention block.
 *
 * headDimension is the per-head width, embedDim / numHead (integer division,
 * like the usual head_dim = dim // num_heads). The operand order matters:
 * numHead / embedDim would truncate to 0 whenever numHead < embedDim. If
 * numHead is not positive, 0.0 is used instead of dividing by zero. Both
 * globals are read each time an object of this type is constructed.
 *
 * NOTE(review): scale is a hard-coded 3.7. The conventional value is
 * pow(headDimension, -0.5); it is left unchanged here so model output does
 * not change — confirm which value is intended.
 *
 * NOTE(review): the two dropout members default to 0.4 and do not read the
 * global variables of the same names. */
struct attention_layer {
    struct matrix W_qkv;
    struct matrix b_qkv;   /* bias vector */
    struct matrix W_proj;
    double headDimension = numHead > 0 ? (double) (embedDim / numHead) : 0.0;
    double scale = 3.7;
    double attention_DropOutProb = 0.4;
    double attention_projDropOutProb = 0.4;
};

/* Wrapper holding the CLS-token matrix. */
struct CLS_token { struct matrix cls_token; };

/* Patch-embedding weights; W has the same 4-D type that convolution() takes
 * as its filter. */
struct patch_embedding_layer { struct matrix4D W; };

/* Positional-embedding table, stored as a single matrix. */
struct pos_embedding_layer { struct matrix pos_embed; };

/* Aggregate of every layer's parameters plus the intermediate/result
 * matrices, grouped by the model stage that owns them. This is the object
 * initialize_all_matrix() and the initialize_* functions operate on. */
struct allMatrix {
    /* MLP: weights and three intermediate results */
    struct MLP_layer MLP_layer;
    struct matrix res_MLP, res2_MLP, res3_MLP;

    /* Attention: weights, fused QKV output, split Q/V/K, intermediates */
    struct attention_layer att;
    struct matrix QKV_att, Q, V, K, resQK, Y_att, Z_att;

    /* CLS token */
    struct CLS_token token;
    struct matrix Z_CLS;

    /* Patch embedding */
    struct matrix input_patch_embedding;
    struct patch_embedding_layer embeddingLayer;
    struct matrix Z_patch_embedding;

    /* Positional embedding */
    struct pos_embedding_layer pos_embed_layer;

    /* Forward-only CLS output */
    struct matrix Z_forward_only;

    /* Normalization layer statistics */
    struct matrix deviation, average;

    /* Residual addition */
    struct matrix copy, result_first_add;

    /* Softmax scratch */
    struct matrix create_for_softmax;
};


/* ---- General matrix utilities ------------------------------------------ */

/* Input / output */
void read_image(struct matrix *image1, FILE *f);
void print_mat(struct matrix *mat);

/* Creation, filling and copying */
void initialize_all_matrix(struct allMatrix *allMatrix);
void rand_matrix(struct matrix *mat);
void fill_matrix_with_zero(struct matrix *mat);
void copy_matrix_from_1_to_2(struct matrix *mat1, struct matrix *mat2);
void duplicate_matrix(struct matrix *mat1, struct matrix *mat2);

/* Products. Presumably the *_one_face / *_per_face variants differ in how the
 * `faces` extent is handled — see the implementation for the exact rule. */
void matrix_product_one_face(struct matrix *mat1, struct matrix *mat2, struct matrix *res);
void matrix_product_one_face2(struct matrix *mat1, struct matrix *mat2, struct matrix *res);
void matrix_product_per_face(struct matrix *mat1, struct matrix *mat2, struct matrix *res);
/* NOTE(review): `mat` is passed by value (a shallow copy that shares
 * mat.val), unlike every other helper here; kept as-is for link compatibility. */
void matrix_product_member_by_member(struct matrix mat, struct matrix *mult);

/* Element-wise helpers */
void sum_vector_matrix(struct matrix *mat, struct matrix *vector);
void sum_vector_matrix2(struct matrix *mat, struct matrix *vector);
void scale_matrix(float scale, struct matrix *mat);



/* ---- MLP ----------------------------------------------------------------- */
/* Presumably initialize the MLP members of *allMatrix. What distinguishes the
 * numbered variants is not visible in this header. */
void initialize_MLP(struct allMatrix *allMatrix);
void initialize_MLP2(struct allMatrix *allMatrix);
void initialize_MLP3(struct allMatrix *allMatrix);

/* ---- Attention layer ----------------------------------------------------- */

/* Attention forward pass over *input using W_qkv, bias b and output
 * projection W_proj; the result presumably goes to *Z — verify.
 * NOTE(review): this function shares its name with struct attention_layer.
 * In C++ the function hides the type name from here on, so later code must
 * spell the type as `struct attention_layer`. Declared exactly once. */
void attention_layer(struct matrix *input, struct matrix *W_qkv, struct matrix *b,
                     struct matrix *Z, struct matrix *W_proj);

/* Extract a slice of the fused QKV matrix into *split. `partenza` / `fine`
 * are Italian for "start" / "end"; whether `fine` is inclusive and the exact
 * role of splitSize are defined by the implementation. */
void split_QKV(int partenza, int fine, struct matrix *QKV, struct matrix *split,
               int splitSize);

/* Softmax support (roles inferred from names — verify): *create is prepared
 * by create_matrix_for_softmax and then used by soft_max on *mat. */
void create_matrix_for_softmax(struct matrix *mat, struct matrix *create);
void soft_max(struct matrix *mat, struct matrix *create);

/* Dropout on *mat with the given probability. */
void apply_dropout(struct matrix *mat, double probability);

/* ---- CLS token ----------------------------------------------------------- */
void initialize_CLS_token(struct allMatrix *allMatrix);
/* Concatenate *mat1 and *mat2 into *concat; the name indicates along the
 * second dimension — verify against the implementation. */
void cat_matrix_on2dim(struct matrix *mat1, struct matrix *mat2, struct matrix *concat);

/* ---- Forward-only CLS ---------------------------------------------------- */
void initialize_forward_only_CLS(struct allMatrix *allMatrix);
void forward_only_cls(struct matrix *X, struct matrix *Z);


/* ---- Patch embedding layer ----------------------------------------------- */
void initialize_patch_embedding_layer(struct allMatrix *allMatrix);
/* Convolve *mat1 with the 4-D filter *filter (result presumably in *res). */
void convolution(struct matrix *mat1, struct matrix4D *filter, struct matrix *res);

/* ---- Positional embedding layer ------------------------------------------ */
/* NOTE(review): unlike the other initialize_* functions, this one takes
 * (X, Z) instead of a struct allMatrix *; confirm that is intended. */
void initialize_pos_embedding_layer(struct matrix *X, struct matrix *Z);
/* Combine *pos_embed into *mat1 (presumably an in-place addition). */
void sum_matrix(struct matrix *mat1, struct matrix *pos_embed);

/* ---- Activation ---------------------------------------------------------- */
/* GELU on *mat; presumably in place, given the void return and single operand. */
void GELU(struct matrix *mat);


/* ---- Estimation helpers -------------------------------------------------- */
/* Fill *vect based on `size` (exact contents defined by the implementation). */
void create_vector_for_estimate(int size, struct matrix *vect);
/* Choose a patch size; returned by value as a struct patch. */
struct patch estimate_patchDim();
/* Return a head count derived from the requested nHead_original. */
int estimate_num_head(int nHead_original);


/* ---- Normalization layer ------------------------------------------------- */
/* Statistics helpers used to normalize *image. avg and dev are caller-supplied
 * float buffers; their required length is set by the implementation
 * (presumably one entry per row or per face — verify).
 * NOTE(review): the argument order differs between calls —
 * calculate_deviation takes (dev, avg) but calculate_normalization takes
 * (avg, dev). Check call sites so the buffers are not swapped. */
void calculate_average(struct matrix *image, float *avg);
void calculate_deviation(struct matrix *image, float *dev, float *avg);
void calculate_normalization(struct matrix *image, float *avg, float *dev);


