#include "includes.h"
#include "grid.h"
#include "vector.h"

#include <fstream.h>
extern ofstream errorfs;

#include "declare.h"

/* This file contains source code for functions called before the
   estimation routines (some may not be used) */

//source code
void countalleles(ivector &allelecounts, const igrid &freqallele)
{
  // Fills allelecounts(j) with the number of alleles listed for locus j.
  //
  // Column j of freqallele is scanned from row 0; the count is the index
  // of the first entry equal to 0, or freqallele.dimx() when the column
  // contains no 0 at all.
  //
  //   allelecounts : output; allelecounts.dim() loci are processed
  //   freqallele   : allele table, read as freqallele(row,locus)
  //
  // A warning is written to errorfs for every locus whose count is 0.

  int i,j;
  const int numloc=allelecounts.dim();
  const int maxnum=freqallele.dimx();

  for (j=0;j<numloc;j++) {
    //default: no 0 entry in the column, so every row holds an allele
    int count=maxnum;
    for (i=0;i<maxnum;i++) {
      if (freqallele(i,j)==0) {
	count=i;
	break; } }
    allelecounts(j)=count;

    if (count<=0) {
      //terminate and flush the record so that it is not glued to the
      //next message written to the error log
      errorfs << "WARNING: count 0 alleles at locus " << j
	      << " +1 by function countalleles()" << "\n";
      errorfs.flush(); }
  }
}


double getfreq(const igrid &freqallele, const dgrid &freqfreq, 
	       const int &Lallele, const int &L)
{
  // Retrieves the frequency stored for allele Lallele at locus L.
  //
  //   freqallele : allele table, read as freqallele(row,locus)
  //   freqfreq   : frequency table indexed like freqallele
  //   Lallele    : allele to look up in column L
  //   L          : locus (column) index
  //
  // Returns freqfreq(row,L), where row is the first row of column L of
  // freqallele that holds Lallele.  If the allele is not present, an
  // error is logged to errorfs and 0.0 is returned (previously the table
  // was read at row -1).

  int i;
  int row=-1;
  const int numrows=freqallele.dimx();

  //find location of Lallele in column L
  for (i=0;i<numrows;i++)
    {
      if (freqallele(i,L) == Lallele)
	{
	  row=i;
	  break;
	}
    }
  if (row==-1)
    {
      errorfs << "Error in getting allele frequency;" 
	      << " Allele " << Lallele 
	      << " not found at locus (" << L << " + 1) column"  
	      << " by getfreq()" << endl;
      return 0.0;
    }

  const double freq=freqfreq(row,L);

  //sanity check with a small tolerance; the negated form also reports NaN
  if ( !(-0.0001<freq && 1.0001>freq) )
    errorfs << "WARNING: freq. retrieved from freqfreq is outside [0,1]"
	    << endl;

  return freq;

}//getfreq


double getjointfreq2C(const igrid &freqallele, const dgrid &jointfreq2C, 
		      const int &Lallele, const int &L, const int &callele,
		      const int &C)
{
  // Retrieves the haplotype frequency for allele Lallele at locus L
  // together with allele callele at locus C.
  //
  //   freqallele  : allele table, read as freqallele(row,locus)
  //   jointfreq2C : table read at (rowC*freqallele.dimx()+rowL, L), where
  //                 rowL / rowC are the rows of Lallele / callele in
  //                 columns L / C of freqallele
  //
  // If either allele is not present, an error is logged to errorfs for
  // each missing allele and 0.0 is returned (previously a row of -1 was
  // used to build the index).  Otherwise the stored value is asserted to
  // be > 0 and returned.

  int i;
  int rowL=-1,rowC=-1;
  const int numrows=freqallele.dimx();

  //find location of Lallele in column L
  for (i=0;i<numrows;i++)
    {
      if (freqallele(i,L) == Lallele)
	{
	  rowL=i;
	  break;
	}
    }
  if (rowL==-1)
    errorfs << "Error in getting allele frequency; Allele " << Lallele 
	    << " not found at locus (" << L << " + 1) column"  
	    << " by getjointfreq2C()" << endl;

  //find location of callele in column C
  for (i=0;i<numrows;i++)
    {
      if (freqallele(i,C) == callele)
	{
	  rowC=i;
	  break;
	}
    }
  if (rowC==-1)
    errorfs << "Error in getting allele frequency; Allele " << callele 
	    << " not found at locus (" << C << " + 1) column" 
	    << " by getjointfreq2C()" << endl;

  //both problems have been reported above; do not index with -1
  if (rowL==-1 || rowC==-1)
    return 0.0;

  const int index=rowC*numrows+rowL;

  assert(jointfreq2C(index,L)>0);

  return jointfreq2C(index,L);
}

double getjointfreq2(const igrid &freqallele, const ivector &allelecounts,
		     const dgrid &jointfreq2, const int &ilocus, 
		     const int &iallele, const int &jallele)
{
  // Retrieves the 2-locus haplotype frequency for allele iallele at
  // locus ilocus and allele jallele at locus ilocus+1.
  //
  //   freqallele   : allele table, read as freqallele(row,locus)
  //   allelecounts : number of alleles per locus (bounds the searches)
  //   jointfreq2   : table read at (ipos*allelecounts(ilocus+1)+jpos,
  //                  ilocus), ipos / jpos being the positions of the two
  //                  alleles in their columns of freqallele
  //
  // If either allele is not present, a warning is logged to errorfs for
  // each missing allele and 0.0 is returned (previously a position of -1
  // went into the index, which could address another haplotype's cell).

  int k;
  int ipos=-1, jpos=-1;
  const int numj=allelecounts(ilocus+1);

  //find positions for iallele and jallele
  for (k=0;k<allelecounts(ilocus);k++) {
    if (freqallele(k,ilocus)==iallele) { ipos=k; break; } }
  for (k=0;k<numj;k++) {
    if (freqallele(k,ilocus+1)==jallele) { jpos=k; break; } }
  if (ipos==-1) errorfs << "WARNING: Allele " << iallele << " not found at locus "
		     << ilocus << " +1 by getjointfreq2()" << endl;
  if (jpos==-1) errorfs << "WARNING: Allele " << jallele << " not found at locus "
		     << ilocus+1 << " +1 by getjointfreq2()" << endl;

  //both problems have been reported above; do not index with -1
  if (ipos==-1 || jpos==-1) return 0.0;

  const double freq=jointfreq2(ipos*numj+jpos,ilocus);

  //sanity check with a small tolerance; the negated form also reports NaN
  if ( !(-0.0001<freq && 1.0001>freq) )
    errorfs << "WARNING: freq. retrieved from jointfreq2 is outside [0,1]"
	    << endl;
  
  return freq;
}

double getcondfreq2l(const igrid &freqallele, const ivector &allelecounts,
		     const dgrid &condfreq2l, const int &ilocus, 
		     const int &iallele, const int &jallele)
{
  // Retrieves the conditional frequency stored for allele iallele at
  // locus ilocus and allele jallele at locus ilocus+1.
  //
  //   freqallele   : allele table, read as freqallele(row,locus)
  //   allelecounts : number of alleles per locus (bounds the searches)
  //   condfreq2l   : table read at (ipos*allelecounts(ilocus+1)+jpos,
  //                  ilocus), ipos / jpos being the positions of the two
  //                  alleles in their columns of freqallele
  //
  // If either allele is not present, a warning is logged to errorfs for
  // each missing allele and 0.0 is returned (previously a position of -1
  // went into the index, which could address another cell of the table).

  int k;
  int ipos=-1, jpos=-1;
  const int numj=allelecounts(ilocus+1);

  //find positions for iallele and jallele
  for (k=0;k<allelecounts(ilocus);k++) {
    if (freqallele(k,ilocus)==iallele) { ipos=k; break; } }
  for (k=0;k<numj;k++) {
    if (freqallele(k,ilocus+1)==jallele) { jpos=k; break; } }
  if (ipos==-1) errorfs << "WARNING: Allele " << iallele << " not found at locus "
		     << ilocus << " +1 by getcondfreq2l()" << endl;
  if (jpos==-1) errorfs << "WARNING: Allele " << jallele << " not found at locus "
		     << ilocus+1 << " +1 by getcondfreq2l()" << endl;

  //both problems have been reported above; do not index with -1
  if (ipos==-1 || jpos==-1) return 0.0;

  const double freq=condfreq2l(ipos*numj+jpos,ilocus);

  //sanity check with a small tolerance; the negated form also reports NaN
  if ( !(-0.0001<freq && 1.0001>freq) )
    errorfs << "WARNING: freq. retrieved from condfreq2l is outside [0,1]"
	    << endl;

  return freq;
}

double getcondfreq2r(const igrid &freqallele, const ivector &allelecounts,
		     const dgrid &condfreq2r, const int &ilocus, 
		     const int &iallele, const int &jallele, const int &C)
{
  // Retrieves the conditional frequency stored for allele iallele at
  // locus ilocus and allele jallele at locus ilocus+1.
  //
  //   freqallele   : allele table, read as freqallele(row,locus)
  //   allelecounts : number of alleles per locus (bounds the searches)
  //   condfreq2r   : table read at (ipos*allelecounts(ilocus+1)+jpos,
  //                  ilocus-C), ipos / jpos being the positions of the
  //                  two alleles in their columns of freqallele
  //   C            : offset subtracted from ilocus to get the column of
  //                  condfreq2r (presumably the center locus - confirm
  //                  against the caller)
  //
  // If either allele is not present, a warning is logged to errorfs for
  // each missing allele and 0.0 is returned (previously a position of -1
  // went into the index, which could address another cell of the table).

  int k;
  int ipos=-1, jpos=-1;
  const int numj=allelecounts(ilocus+1);

  //find positions for iallele and jallele
  for (k=0;k<allelecounts(ilocus);k++) {
    if (freqallele(k,ilocus)==iallele) { ipos=k; break; } }
  for (k=0;k<numj;k++) {
    if (freqallele(k,ilocus+1)==jallele) { jpos=k; break; } }
  if (ipos==-1) errorfs << "WARNING: Allele " << iallele << " not found at locus "
		     << ilocus << " +1 by getcondfreq2r()" << endl;
  if (jpos==-1) errorfs << "WARNING: Allele " << jallele << " not found at locus "
		     << ilocus+1 << " +1 by getcondfreq2r()" << endl;

  //both problems have been reported above; do not index with -1
  if (ipos==-1 || jpos==-1) return 0.0;

  const double freq=condfreq2r(ipos*numj+jpos,ilocus-C);

  //sanity check with a small tolerance; the negated form also reports NaN
  if ( !(-0.0001<freq && 1.0001>freq) )
    errorfs << "WARNING: freq. retrieved from condfreq2r is outside [0,1]"
	    << endl;

  return freq;
}

double getjointfreq3(const igrid &freqallele, const ivector &allelecounts,
		     const dvector &jointfreq3, const int &ilocus, 
		     const int &iallele, const int &jallele, 
		     const int &kallele)
{
  // Retrieves the 3-locus haplotype frequency for alleles iallele,
  // jallele and kallele at loci ilocus, ilocus+1 and ilocus+2.
  //
  //   freqallele   : allele table, read as freqallele(row,locus)
  //   allelecounts : number of alleles per locus (bounds the searches)
  //   jointfreq3   : vector read at
  //                  ipos*allelecounts(ilocus+1)*allelecounts(ilocus+2)
  //                  + jpos*allelecounts(ilocus+2) + kpos,
  //                  ipos / jpos / kpos being the positions of the three
  //                  alleles in their columns of freqallele
  //
  // If any allele is not present, a warning is logged to errorfs for
  // each missing allele and 0.0 is returned (previously a position of -1
  // went into the index, which could address another haplotype's cell).

  int k;
  int ipos=-1, jpos=-1, kpos=-1;
  const int numj=allelecounts(ilocus+1);
  const int numk=allelecounts(ilocus+2);

  //find positions for iallele and jallele and kallele
  for (k=0;k<allelecounts(ilocus);k++) {
    if (freqallele(k,ilocus)==iallele) { ipos=k; break; } }
  for (k=0;k<numj;k++) {
    if (freqallele(k,ilocus+1)==jallele) { jpos=k; break; } }
  for (k=0;k<numk;k++) {
    if (freqallele(k,ilocus+2)==kallele) { kpos=k; break; } }
  if (ipos==-1) errorfs << "WARNING: Allele " << iallele << " not found at locus "
		     << ilocus << " +1 by getjointfreq3()" << endl;
  if (jpos==-1) errorfs << "WARNING: Allele " << jallele << " not found at locus "
		     << ilocus+1 << " +1 by getjointfreq3()" << endl;
  if (kpos==-1) errorfs << "WARNING: Allele " << kallele << " not found at locus "
		     << ilocus+2 << " +1 by getjointfreq3()" << endl;

  //all problems have been reported above; do not index with -1
  if (ipos==-1 || jpos==-1 || kpos==-1) return 0.0;

  const int index=ipos*numj*numk + jpos*numk + kpos;
  const double freq=jointfreq3(index);

  //sanity check with a small tolerance; the negated form also reports NaN
  if ( !(-0.0001<freq && 1.0001>freq) )
    errorfs << "WARNING: freq. retrieved from jointfreq3 is outside [0,1]"
	    << endl;
  
  return freq;
}

double getcondfreq3F(const igrid &freqallele, const ivector &allelecounts,
		     const dgrid &condfreq3F, const int &C, 
		     const int &iallele, const int &kallele,
		     const int &cond)
{
  // Retrieves the 3-locus conditional haplotype frequency for allele
  // iallele at locus C-1 and allele kallele at locus C+1.
  //
  //   freqallele   : allele table, read as freqallele(row,locus)
  //   allelecounts : number of alleles per locus (bounds the searches)
  //   condfreq3F   : table read at (ipos*allelecounts(C+1)+kpos, cond),
  //                  ipos / kpos being the positions of the two alleles
  //                  in columns C-1 / C+1 of freqallele
  //   cond         : column of condfreq3F to read
  //
  // If either allele is not present, a warning is logged to errorfs for
  // each missing allele and 0.0 is returned (previously a position of -1
  // went into the index, which could address another cell of the table).

  int k;
  int ipos=-1, kpos=-1;
  const int numk=allelecounts(C+1);

  //find positions for iallele and kallele
  for (k=0;k<allelecounts(C-1);k++) {
    if (freqallele(k,C-1)==iallele) { ipos=k; break; } }
  for (k=0;k<numk;k++) {
    if (freqallele(k,C+1)==kallele) { kpos=k; break; } }
  if (ipos==-1) errorfs << "WARNING: Allele " << iallele << " not found at locus "
		     << C-1 << " +1 by getcondfreq3F()" << endl;
  if (kpos==-1) errorfs << "WARNING: Allele " << kallele << " not found at locus "
		     << C+1 << " +1 by getcondfreq3F()" << endl;

  //both problems have been reported above; do not index with -1
  if (ipos==-1 || kpos==-1) return 0.0;
  
  return condfreq3F(ipos*numk+kpos,cond);
}

void makesmallerdataset(igrid &augdata, const igrid &data, 
			const int &C_loc, const int &C_all,
			const int &data_type, const int &debug)
{
  // Builds augdata from the rows of data that carry the center allele;
  // haplotypes / genotypes without a copy of it are dropped because
  // including them in the estimation wastes computations.
  //
  //   augdata   : output; re-initialised via init() to the kept row count
  //               by data.dimy() columns, then filled in input order
  //   data      : input, read as data(row,locus)
  //   C_loc     : column (locus) of the center allele
  //   C_all     : value of the center allele
  //   data_type : 1 = every row is a haplotype, kept when
  //               data(row,C_loc)==C_all; any other value = rows 2*i and
  //               2*i+1 form genotype i, and both rows are kept when at
  //               least one of them carries C_all
  //   debug     : when non-zero, augdata is printed to cout
  //
  // Each branch counts first (to size augdata) and then copies; both
  // passes must apply exactly the same selection test.

  int i,j;
  const int numhap=data.dimx();
  const int numgen=data.dimx()/2;
  const int numloc=data.dimy();

  int count=0; 
  if (data_type==1) {
    for (i=0;i<numhap;i++) { if (data(i,C_loc)==C_all) count+=1; } 
    augdata.init(count,numloc);
    count=0;
    for (i=0;i<numhap;i++) { 
      if (data(i,C_loc)==C_all) {
	for (j=0;j<numloc;j++) {
	  augdata(count,j)=data(i,j); }
	count+=1; }
    }
  }
  else {//data_type==2 
    for (i=0;i<numgen;i++) { 
      if (data(2*i  ,C_loc)==C_all || data(2*i+1,C_loc)==C_all) count+=1; } 
    augdata.init(2*count,numloc);
    count=0;
    for (i=0;i<numgen;i++) { 
      //same test as the counting pass: first row of genotype i is 2*i
      //(it used to read row i, so the passes could disagree and write
      //past the end of augdata or leave rows unfilled)
      if (data(2*i  ,C_loc)==C_all || data(2*i+1,C_loc)==C_all) {
	for (j=0;j<numloc;j++) {
	  augdata(2*count  ,j)=data(2*i  ,j);
	  augdata(2*count+1,j)=data(2*i+1,j); }
	count+=1; }
    }
  }

  if (debug) {
    cout << "augdata: " << augdata.dimx() << " " << augdata.dimy() << endl;
    for (i=0;i<augdata.dimx();i++) {
      for (j=0;j<augdata.dimy();j++) {
	cout << augdata(i,j) << " "; }
      cout << endl; }
  }

}