
#include<iostream>
#include<cstring>
#include<cstdio>
#include<map>
#include<vector>
#include<string>
#include<sstream>
#include<ctime>
#include<cmath>
#include<cstdlib>
#include<algorithm>
#include<fstream>
using namespace std;


// Value of pi used by normal() when evaluating the Gaussian density.
#define pi 3.1415926535897932384626433832795

// Train compares rand()%1000 against alpha*1000 for every sample: below the
// threshold an entity of the triple is replaced to build the negative sample,
// otherwise the relation is replaced. With 1.0 the entity branch is always taken.
double alpha = 1.0;
// Non-zero: Train scores triples with absolute differences (L1);
// zero: with squared differences.
bool L1_flag=1;
// entity_cnt[id] is incremented by prepare() each time id appears as the left
// or right entity of a triple. tri_check is not referenced in this file.
int entity_cnt[3000], tri_check[10100];

//normal distribution
// Map one draw of the C library rand() linearly onto the interval that
// starts at min and has width (max - min).
double rand(double min, double max)
{
    const double width = max - min;
    // Dividing by RAND_MAX + 1.0 keeps the scaled fraction strictly below 1.
    const double offset = width * rand() / (RAND_MAX + 1.0);
    return min + offset;
}
// Density of the normal distribution with mean miu and standard deviation
// sigma, evaluated at x.
double normal(double x, double miu,double sigma)
{
    const double offset = x - miu;
    const double exponent = -1 * offset * offset / (2 * sigma * sigma);
    const double scale = 1.0 / sqrt(2 * pi) / sigma;
    return scale * exp(exponent);
}
// Rejection sampler: propose x with rand(min, max) and accept it when a second
// draw from [0, density at the mean) does not exceed the density at x.
// Returns the first accepted proposal.
double randn(double miu,double sigma, double min ,double max)
{
    // Density at the mean; it does not change between iterations.
    const double peak = normal(miu,miu,sigma);
    for (;;)
    {
        const double candidate = rand(min,max);
        const double density = normal(candidate,miu,sigma);
        const double threshold = rand(0.0,peak);
        // Written as a negated '>' so that a NaN comparison accepts,
        // exactly as a do/while(threshold > density) loop would.
        if (!(threshold > density))
            return candidate;
    }
}

// Square of x.
double sqr(double x)
{
    const double squared = x * x;
    return squared;
}

// Euclidean (L2) length of a: the square root of the sum of squared
// components. Returns 0 for an empty vector.
// The vector is only read, so it is taken by const reference; const vectors
// and temporaries are accepted as well as the non-const lvalues used so far.
double vec_len(const vector<double> &a)
{
	double sum_sq=0;
	// size_type index avoids comparing a signed int against a.size().
	for (vector<double>::size_type i=0; i<a.size(); ++i)
		sum_sq+=a[i]*a[i];
	return sqrt(sum_sq);
}

// Label of the sampling method ("bern" or "unif"); set and printed by main().
string version;
// buf receives each whitespace-delimited token that prepare() reads from the
// training file. buf1 is not referenced in this file.
char buf[100000],buf1[100000];
// Number of relations / entities; assigned by prepare() and used by Train to
// size and index the embedding tables.
int relation_num,entity_num;
// Name <-> id lookup tables; not referenced in this file.
map<string,int> relation2id,entity2id;
map<int,string> id2entity,id2relation;


// left_entity[r][e]  : number of triples with relation r whose left entity is e.
// right_entity[r][e] : number of triples with relation r whose right entity is e.
// Both are filled by prepare().
map<int,map<int,int> > left_entity,right_entity;
// left_num[r] / right_num[r] : for relation r, the total of the counts above
// divided by the number of distinct left / right entities; computed by prepare().
map<int,double> left_num,right_num;

// Margin-based translation-embedding trainer (the usage comment in main()
// names the binary Train_TransE).
//
// A triple (e1, e2, rel) is scored by the distance between entity[e2] and
// entity[e1] + relation[rel] (L1 or squared L2, selected by the global
// L1_flag). Training draws a random known triple, builds a corrupted copy of
// it and, whenever the corrupted copy does not score at least `margin` worse,
// moves the embeddings by one gradient step of size `rate`.
//
// Relies on the file-level globals relation_num, entity_num, alpha and L1_flag
// and on the helpers randn(), vec_len() and sqr().
class Train{

public:
	// Known triples: ok[(left entity, relation)][right entity] is set to 1 by add().
	map<pair<int,int>, map<int,int> > ok;

	// Register one training triple.
	//   x : left entity id
	//   y : right entity id
	//   z : relation id
	void add(int x,int y,int z)
	{
		fb_h.push_back(x); // left entity
		fb_r.push_back(z); // relation
		fb_l.push_back(y); // right entity
		ok[make_pair(x,z)][y]=1;
	}

	// Initialise all embeddings and run the training loop.
	//   n_in      : embedding dimension
	//   rate_in   : learning rate
	//   margin_in : ranking margin
	//   method_in : stored only; the sampling code does not consult it
	//   edge_cnt  : accepted for interface compatibility; not used
	// entity_num and relation_num must be set and all triples added beforehand.
	// If any triple refers to an id outside those ranges an error is printed
	// and nothing is trained (indexing the tables would be out of bounds).
	void run(int n_in,double rate_in,double margin_in,int method_in,char* edge_cnt)
	{
		(void)edge_cnt;
		n = n_in;
		rate = rate_in;
		margin = margin_in;
		method = method_in;

		if (!triples_in_range())
		{
			cerr<<"Train::run: a triple uses an entity or relation id outside "
			    <<"[0,"<<entity_num<<") / [0,"<<relation_num<<"); training aborted"<<endl;
			return;
		}

		// The working copies (relation_tmp / entity_tmp) are assigned from
		// these tables at the start of every batch, so they need no sizing here.
		relation_vec.assign(relation_num, vector<double>(n));
		entity_vec.assign(entity_num, vector<double>(n));

		// INITIALIZATION of each entity / relation vector: every component is
		// drawn with randn(mean 0, sigma 1/n) restricted to +-6/sqrt(n).
		// Original author's note: the procedure follows "Understanding the
		// difficulty of training deep feedforward neural networks" (AISTATS 2010).
		const double bound = 6/sqrt(static_cast<double>(n));
		for (int i=0; i<relation_num; i++)
			for (int ii=0; ii<n; ii++)
				relation_vec[i][ii] = randn(0,1.0/n,-bound,bound);
		for (int i=0; i<entity_num; i++)
		{
			for (int ii=0; ii<n; ii++)
				entity_vec[i][ii] = randn(0,1.0/n,-bound,bound);
			// keep |e| <= 1
			norm(entity_vec[i]);
		}
		bfgs();
	}

private:
	int n,method;        // embedding dimension; sampling method flag (stored only)
	double res;          // loss accumulated over the current epoch
	double rate,margin;  // learning rate; ranking margin
	vector<int> fb_h,fb_l,fb_r;  // per triple: left entity, right entity, relation
	// Embeddings that scores and gradients are computed from.
	vector<vector<double> > relation_vec,entity_vec;
	// Working copies that receive the updates during a batch.
	vector<vector<double> > relation_tmp,entity_tmp;

	// True when every stored triple indexes inside the embedding tables.
	bool triples_in_range() const
	{
		for (vector<int>::size_type t=0; t<fb_h.size(); ++t)
		{
			if (fb_h[t]<0 || fb_h[t]>=entity_num)
				return false;
			if (fb_l[t]<0 || fb_l[t]>=entity_num)
				return false;
			if (fb_r[t]<0 || fb_r[t]>=relation_num)
				return false;
		}
		return true;
	}

	// True when (e1, e2, rel) was registered through add(). Uses find() so that
	// probing an unknown (e1, rel) pair does not insert an empty entry into ok.
	bool is_known(int e1,int e2,int rel) const
	{
		map<pair<int,int>, map<int,int> >::const_iterator it = ok.find(make_pair(e1,rel));
		return it!=ok.end() && it->second.count(e2)>0;
	}

	// Scale a down to unit length when it is longer than 1; vectors of length
	// <= 1 are left unchanged.
	void norm(vector<double> &a)
	{
		const double len = vec_len(a);
		if (len>1)
			for (vector<double>::size_type ii=0; ii<a.size(); ii++)
				a[ii]/=len;
	}

	// Random integer in [0, x) built from the product of two rand() draws.
	// x must be positive. The product is computed in long long because
	// int*int overflows (undefined behaviour) when RAND_MAX is 2^31-1.
	int rand_max(int x)
	{
		const long long product = static_cast<long long>(rand())*rand();
		return static_cast<int>(product%x);
	}

	// Training loop. Despite the name this is mini-batch stochastic gradient
	// descent: 1000 epochs of 100 batches, each batch drawing
	// (number of triples)/100 random samples. Updates within a batch are
	// accumulated in the *_tmp tables and committed when the batch ends.
	// After every epoch the loss is printed and the embeddings are saved.
	void bfgs()
	{
		res=0;

		const int nbatches = 100;  // batches per epoch
		const int nepoch = 1000;   // number of epochs
		const int batchsize = static_cast<int>(fb_h.size()/nbatches);  // samples per batch
		for (int epoch=0; epoch<nepoch; epoch++)
		{
			res=0;
			for (int batch=0; batch<nbatches; batch++)
			{
				relation_tmp = relation_vec;
				entity_tmp = entity_vec;
				for (int k=0; k<batchsize; k++)
					train_sample();
				relation_vec = relation_tmp;
				entity_vec = entity_tmp;
			}
			cout<<"epoch:"<<epoch<<' '<<res<<endl;
			save_embeddings();
		}
	}

	// Draw one random triple, corrupt it and train on the pair.
	// NOTE: the entity-replacement loops retry until they hit an entity that
	// does not form a known triple; they do not terminate if every entity does.
	void train_sample()
	{
		const int i = rand_max(static_cast<int>(fb_h.size()));  // random triple
		int j = rand_max(entity_num);                           // random entity
		const int e1 = fb_h[i];
		const int e2 = fb_l[i];
		const int rel = fb_r[i];

		// replacing left or right entity with equal probability
		const double pr = 500;
		if (rand()%1000 < static_cast<int>(alpha*1000))
		{
			if (rand()%1000 < pr)
			{
				// negative sample: replace the right entity
				while (is_known(e1,j,rel))
					j = rand_max(entity_num);
				train_kb(e1,e2,rel,e1,j,rel);
			}
			else
			{
				// negative sample: replace the left entity
				while (is_known(j,e2,rel))
					j = rand_max(entity_num);
				train_kb(e1,e2,rel,j,e2,rel);
			}
			norm(relation_tmp[rel]);
			norm(entity_tmp[e1]);
			norm(entity_tmp[e2]);
			norm(entity_tmp[j]);
		}
		else
		{
			// negative sample: replace the relation. With fewer than two
			// relations no different relation exists, so skip the sample.
			if (relation_num<2)
				return;
			int fake_rel = rel;
			while (fake_rel==rel)
				fake_rel = rand()%relation_num;
			train_kb(e1,e2,rel,e1,e2,fake_rel);
			norm(relation_tmp[rel]);
			norm(relation_tmp[fake_rel]);
			norm(entity_tmp[e1]);
			norm(entity_tmp[e2]);
		}
	}

	// Write the first `rows` vectors of m to out, one vector per line,
	// each of the n components printed as "%.6lf" followed by a tab.
	void write_rows(FILE* out,const vector<vector<double> > &m,int rows) const
	{
		for (int i=0; i<rows; i++)
		{
			for (int ii=0; ii<n; ii++)
				fprintf(out,"%.6lf\t",m[i][ii]);
			fprintf(out,"\n");
		}
	}

	// Overwrite the files "relation2vec" and "entity2vec" in the working
	// directory with the current embeddings. If either file cannot be opened
	// an error is printed and nothing is written.
	void save_embeddings() const
	{
		FILE* f2 = fopen("relation2vec","w");
		FILE* f3 = fopen("entity2vec","w");
		if (f2==NULL || f3==NULL)
		{
			fprintf(stderr,"Train: cannot open relation2vec / entity2vec for writing\n");
			if (f2!=NULL)
				fclose(f2);
			if (f3!=NULL)
				fclose(f3);
			return;
		}
		write_rows(f2,relation_vec,relation_num);
		write_rows(f3,entity_vec,entity_num);
		fclose(f2);
		fclose(f3);
	}

	// Distance between entity[e2] and entity[e1] + relation[rel]:
	// sum of absolute differences when L1_flag is set, sum of squares otherwise.
	double calc_sum(int e1,int e2,int rel)
	{
		double sum=0;
		if (L1_flag)
			for (int ii=0; ii<n; ii++)
				sum+=fabs(entity_vec[e2][ii]-entity_vec[e1][ii]-relation_vec[rel][ii]);
		else
			for (int ii=0; ii<n; ii++)
				sum+=sqr(entity_vec[e2][ii]-entity_vec[e1][ii]-relation_vec[rel][ii]);
		return sum;
	}

	// Derivative factor of the distance for component ii of (e1, e2, rel):
	// 2*(e2 - e1 - rel) for the squared distance, or +1 / -1 (its sign, with
	// zero mapped to -1) for the L1 distance.
	double slope(int e1,int e2,int rel,int ii) const
	{
		double x = 2*(entity_vec[e2][ii]-entity_vec[e1][ii]-relation_vec[rel][ii]);
		if (L1_flag)
			x = (x>0) ? 1 : -1;
		return x;
	}

	// One gradient step on the working copies: shrink the distance of the
	// positive triple (*_a) and grow the distance of the negative triple (*_b).
	// The slopes are taken from the tables as they were at the start of the batch.
	void gradient(int e1_a,int e2_a,int rel_a,int e1_b,int e2_b,int rel_b)
	{
		for (int ii=0; ii<n; ii++)
		{
			double x = slope(e1_a,e2_a,rel_a,ii);
			relation_tmp[rel_a][ii] += rate*x;
			entity_tmp[e1_a][ii] += rate*x;
			entity_tmp[e2_a][ii] -= rate*x;

			x = slope(e1_b,e2_b,rel_b,ii);
			relation_tmp[rel_b][ii] -= rate*x;
			entity_tmp[e1_b][ii] -= rate*x;
			entity_tmp[e2_b][ii] += rate*x;
		}
	}

	// Train on one pair of triples.
	//   (e1_a, e2_a, rel_a) : positive (known) triple
	//   (e1_b, e2_b, rel_b) : negative (corrupted) triple
	// When margin + d(positive) - d(negative) is positive, that amount is
	// added to the epoch loss and a gradient step is applied.
	void train_kb(int e1_a,int e2_a,int rel_a,int e1_b,int e2_b,int rel_b)
	{
		const double sum1 = calc_sum(e1_a,e2_a,rel_a);
		const double sum2 = calc_sum(e1_b,e2_b,rel_b);
		if (sum1+margin>sum2)
		{
			res += margin + sum1 - sum2;
			gradient(e1_a,e2_a,rel_a,e1_b,e2_b,rel_b);
		}
	}
};

// Global trainer instance: prepare() feeds it triples through add(),
// main() starts training through run().
Train train;
// Average of the counts stored in `counts`: their total divided by the number
// of entries. Returns 0 for an empty map instead of dividing by zero.
static double mean_count(const map<int,int> &counts)
{
	if (counts.empty())
		return 0.0;
	double total=0;
	for (map<int,int>::const_iterator it = counts.begin(); it!=counts.end(); ++it)
		total+=it->second;
	return total/counts.size();
}

// Load the training triples from `filename` and set up the global state.
//
// The file is read as whitespace-delimited tokens of the form
// "left,right,relation": left and right are integer entity ids, and the first
// character of the relation field selects the relation id ('n' -> 0,
// 's' -> 1, anything else -> 2).
//
//   filename     : path of the training file
//   entity_count : number of entities; becomes entity_num
//
// Each accepted triple updates left_entity / right_entity / entity_cnt and is
// passed to train.add(). Tokens with fewer than three fields, or with an
// entity id outside [0, entity_count), are skipped and counted in a warning,
// since they would otherwise index outside the embedding tables.
// Afterwards entity_num and relation_num (fixed to 3) are set and
// left_num / right_num are computed.
// Exits the program with status 1 when the file cannot be opened.
void prepare(char* filename, int entity_count)
{
	FILE* f1 = (filename != NULL) ? fopen(filename, "r") : NULL;
	if (f1 == NULL)
	{
		fprintf(stderr, "prepare: cannot open training file %s\n",
		        filename != NULL ? filename : "(null)");
		exit(1);
	}

	// entity_cnt is a fixed-size global array; only ids below its length are counted.
	const int tracked = static_cast<int>(sizeof(entity_cnt)/sizeof(entity_cnt[0]));
	char temp_buf[100000];
	int skipped = 0;

	// The field width keeps the token (plus terminator) within buf's 100000 bytes.
	while (fscanf(f1,"%99999s",buf)==1)
	{
		// strtok() modifies its input, so split a copy and leave buf intact.
		strcpy(temp_buf, buf);
		int e_l = 0, e_r = 0, r = 0;
		int id_cnt = 0;
		for (char *ch = strtok(temp_buf, ","); ch != NULL; ch = strtok(NULL, ","))
		{
			id_cnt ++;
			if (id_cnt == 1)
				e_l = atoi(ch);
			else if (id_cnt == 2)
				e_r = atoi(ch);
			else if (id_cnt == 3)
			{
				if (ch[0] == 'n')
					r = 0;
				else if (ch[0] == 's')
					r = 1;
				else
					r = 2;
			}
		}

		const bool complete = (id_cnt >= 3);
		const bool in_range = (e_l >= 0 && e_l < entity_count &&
		                       e_r >= 0 && e_r < entity_count);
		if (!complete || !in_range)
		{
			skipped ++;
			continue;
		}

		left_entity[r][e_l] ++;
		right_entity[r][e_r] ++;
		if (e_l < tracked)
			entity_cnt[e_l] ++;
		if (e_r < tracked)
			entity_cnt[e_r] ++;
		train.add(e_l,e_r,r);
	}
	fclose(f1);

	if (skipped > 0)
		fprintf(stderr, "prepare: skipped %d malformed or out-of-range triple(s)\n", skipped);

	entity_num = entity_count;
	relation_num = 3;
	printf("%d %d\n",entity_num, relation_num);

	// Per relation: average number of triples per distinct left entity
	// (left_num) and per distinct right entity (right_num). The original
	// author's note describes these averages as the basis for labelling a
	// relation argument as "1" (average below 1.5) or "MANY".
	for (int i=0; i<relation_num; i++)
	{
		left_num[i] = mean_count(left_entity[i]);
		right_num[i] = mean_count(right_entity[i]);
	}
	cout<<"relation_num = "<<relation_num<<endl;
	cout<<"entity_num = "<<entity_num<<endl;
}

// Find the option `str` among argv[1..argc-1].
// Returns the index of the first match, or -1 when it is absent.
// If the match is the last argument (so no value can follow it), prints a
// message and exits with status 1.
int ArgPos(char *str, int argc, char **argv) {
  int idx = 1;
  while (idx < argc) {
    if (strcmp(str, argv[idx]) == 0) {
      if (idx + 1 == argc) {
        printf("Argument missing for %s\n", str);
        exit(1);
      }
      return idx;
    }
    ++idx;
  }
  return -1;
}

// Remainder of one rand() draw divided by i.
int t_rand(int i)
{
  const int draw = rand();
  return draw % i;
}

// Entry point.
//
// Usage: <program> <triples_file> <entity_count> [edge_cnt]
//                  [-size N] [-margin M] [-method 0|1]
//   argv[1] : training file, passed to prepare()
//   argv[2] : number of entities, passed to prepare()
//   argv[3] : passed through to Train::run() as edge_cnt (may be absent)
//   -size   : embedding dimension (default 100)
//   -margin : ranking margin, may be fractional (default 1)
//   -method : 0 selects the label "unif", any other value "bern" (default 1)
// Returns 1 after printing the usage when the two positional arguments are missing.
int main(int argc,char**argv)
{
    if (argc < 3)
    {
        fprintf(stderr,
                "Usage: %s <triples_file> <entity_count> [edge_cnt] [-size N] [-margin M] [-method 0|1]\n",
                argc > 0 ? argv[0] : "Train_TransE");
        return 1;
    }

    srand((unsigned) time(NULL));
    // method 0 : -unif
    // method 1 : -bern
    int method = 1;
    int n = 100;
    double rate = 0.001;
    double margin = 1;
    int i;
    if ((i = ArgPos((char *)"-size", argc, argv)) > 0) n = atoi(argv[i + 1]);
    // atof, not atoi: margin is a double and atoi would drop the fraction.
    if ((i = ArgPos((char *)"-margin", argc, argv)) > 0) margin = atof(argv[i + 1]);
    if ((i = ArgPos((char *)"-method", argc, argv)) > 0) method = atoi(argv[i + 1]);
    cout<<"size = "<<n<<endl;
    cout<<"learing rate = "<<rate<<endl;
    cout<<"margin = "<<margin<<endl;
    if (method)
        version = "bern";
    else
        version = "unif";
    cout<<"method = "<<version<<endl;

    prepare(argv[1], atoi(argv[2]));
    // argv[argc] is a null pointer, so argv[3] is safe to pass when argc == 3.
    train.run(n,rate,margin,method,argv[3]);
    return 0;
}


