#include "includes.h"
#include "grid.h"
#include "vector.h"

#include <fstream.h>
extern ofstream errorfs;

#include "declare.h"

/*This file contains the functions called by the HMM when the data
  are haplotypes.*/

void makenullshare(dgrid &obsprobs, const igrid &data,
		   const igrid &freqallele, const dgrid &freqfreq,
		   const dgrid &condfreq2l, const dgrid &condfreq2r,
		   const dgrid &condfreq2C,
		   const ivector &allelecounts,
		   const int &LE, const int &RE, const int &C)
{
  //this function generates parts of observation probabilities for HMM
  //that do not involve parameters, i.e., conditional allele frequencies 
  //
  //For each haplotype i this fills row 2*i+1 of obsprobs (row 2*i is
  //filled by makemodlshare) for loci LE..RE, and also for locus C when
  //C lies outside [LE,RE].
  //  - An allele coded 0 is treated as missing and gets probability 1.
  //  - The end of the chain farthest from C uses the marginal frequency
  //    returned by getfreq.
  //  - Every other locus is conditioned on its neighbour on the side away
  //    from C: getcondfreq2l for loci left of C, getcondfreq2r for loci
  //    right of C, getjointfreq2C for C itself.  If that neighbour is
  //    missing, the marginal frequency (getfreq) is used instead.
  //NOTE(review): the branches test LE<C<RE, then C<LE, and otherwise
  //assume C>RE; C==LE or C==RE would land in the C>RE branch --
  //presumably callers never pass such C; confirm.
  //NOTE(review): when LE<C<RE, obsprobs(2*i+1,C) is not written here but
  //is still checked by the final assertion loop (j runs over LE..RE);
  //presumably it is initialised elsewhere.
  int i,j;
  const int numhap=data.dimx();

  for (i=0;i<numhap;i++) {
    
    if (LE<C && C<RE)
      {
	
	//make first column
	if (data(i,LE)==0) obsprobs(2*i+1,LE)=1;
	else obsprobs(2*i+1,LE)=
	       getfreq(freqallele, freqfreq, data(i,LE), LE); 
      
	//make columns LE+1:C-1
	//(each locus conditioned on its left neighbour j-1, away from C)
	for (j=LE+1;j<C;j++) {
	  if (data(i,j)==0) obsprobs(2*i+1,j)=1;
	  else {
	    if (data(i,j-1)==0) obsprobs(2*i+1,j)=
				  getfreq(freqallele, freqfreq, data(i,j), j);
	    else obsprobs(2*i+1,j)=getcondfreq2l(freqallele, allelecounts, 
						 condfreq2l,
						 j-1, data(i,j-1), 
						 data(i,j)); } }
      
	//make columns C+1:RE-1
	//(each locus conditioned on its right neighbour j+1, away from C)
	for (j=C+1;j<RE;j++) {
	  if (data(i,j)==0) obsprobs(2*i+1,j)=1;
	  else {
	    if (data(i,j+1)==0) obsprobs(2*i+1,j)=
				  getfreq(freqallele, freqfreq, data(i,j), j);
	    else obsprobs(2*i+1,j)=getcondfreq2r(freqallele, allelecounts, 
						 condfreq2r,
						 j, data(i,j), 
						 data(i,j+1), C); } 
	} 

	//make last column 
	if (data(i,RE)==0) obsprobs(2*i+1,RE)=1;
	else obsprobs(2*i+1,RE)=
	       getfreq(freqallele, freqfreq, data(i,RE), RE); 
      }

    else {
      if (C<LE) {

	//make first column
	//(C is conditioned on LE, the nearest window locus)
	if (data(i,C)==0) obsprobs(2*i+1,C)=1;
	else {
	  if (data(i,LE)==0) obsprobs(2*i+1,C) = 
			       getfreq(freqallele, freqfreq, data(i,C), C);
	  else obsprobs(2*i+1,C)=
		 getjointfreq2C(freqallele, condfreq2C, 
				data(i,LE), LE, data(i,C), C); }


	//make column LE:RE-1
	for (j=LE;j<RE;j++) {
	  if (data(i,j)==0) obsprobs(2*i+1,j)=1;
	  else {
	    if (data(i,j+1)==0) obsprobs(2*i+1,j)=
				  getfreq(freqallele, freqfreq, data(i,j), j);
	    else obsprobs(2*i+1,j)=getcondfreq2r(freqallele, allelecounts, 
						 condfreq2r,
						 j, data(i,j), 
						 data(i,j+1), C); } } 
      
	//make column RE
	if (data(i,RE)==0) obsprobs(2*i+1,RE)=1;
	else obsprobs(2*i+1,RE)=
	       getfreq(freqallele, freqfreq, data(i,RE), RE); 
      }
      else {//C>RE
	//make column LE
	if (data(i,LE)==0) obsprobs(2*i+1,LE)=1;
	else obsprobs(2*i+1,LE)=
	       getfreq(freqallele, freqfreq, data(i,LE), LE); 
      
	//make columns LE+1: RE
	for (j=LE+1;j<=RE;j++) {
	  if (data(i,j)==0) obsprobs(2*i+1,j)=1;
	  else {
	    if (data(i,j-1)==0) obsprobs(2*i+1,j)=
				  getfreq(freqallele, freqfreq, data(i,j), j);
	    else obsprobs(2*i+1,j)=getcondfreq2l(freqallele, allelecounts, 
						 condfreq2l,
						 j-1, data(i,j-1), 
						 data(i,j)); } } 

      
	//make column C
	//(C is conditioned on RE, the nearest window locus)
	if (data(i,C)==0) obsprobs(2*i+1,C)=1;
	else 
	  {
	    if (data(i,RE)==0) obsprobs(2*i+1,C) = 
				 getfreq(freqallele, freqfreq, data(i,C), C);
	    else obsprobs(2*i+1,C)=
		   getjointfreq2C(freqallele, condfreq2C, 
				  data(i,RE), RE, data(i,C), C); }

      }
    }
  }

  //sanity check: every entry of row 2*i+1 in the window (and at C when C
  //is outside the window) must be a probability, within a 1e-5 tolerance
  for (i=0;i<numhap;i++) {
    if (C<LE || C>RE) 
      assert(-0.00001<obsprobs(2*i+1,C) && obsprobs(2*i+1,C)<1.00001);
    for (j=LE;j<=RE;j++)
      assert(-0.00001<obsprobs(2*i+1,j) && obsprobs(2*i+1,j)<1.00001); }

}

void makenullshareCnum(dgrid &obsprobsCnum, const igrid &data,
		       const igrid &freqallele, const dgrid &freqfreq,
		       const ivector &allelecounts,
		       const dgrid &jointfreq2, const dvector &condfreq3l,
		       const dvector &condfreq3r, const int &C,
		       const int &LE, const int &RE,
		       const double &m, const ivector &anchap)
{//note: assumes missing values not allowed at center

  //this function generates parts of observation probabilities that
  //do not depend on parameters, used in numerator of alpha and beta
  //objects at C
  //
  //Only acts when LE<C<RE.  For haplotype i it writes rows 8*i..8*i+7
  //in two columns: column 0 is read by makealpha and column 1 by makebeta.
  //Row offset 4*a+2*b+c corresponds to the states a at C-1, b at C and
  //c at C+1 (0 = A, 1 = N), matching the index 8*h+4*i+2*j+k used in
  //makealpha/makebeta.  Offsets 2, 3 and 6 are state combinations the
  //chain cannot produce and are fixed at 0.
  //A flanking allele (at C-1 or C+1) coded 0 is treated as missing, and
  //the joint/conditional frequency helpers are chosen accordingly.
  //When m==0 and a flanking allele differs from anchap, the entries that
  //need the ancestral allele at that locus are set to 0.
  //NOTE(review): makemodlshareCnum later overwrites offsets 0 and 4 of
  //column 0 and offsets 0 and 1 of column 1, including the zeros written
  //here for m==0 -- confirm the intended call order.
  int i;
  const int numhap=data.dimx();

  if (LE<C && C<RE) {
  for (i=0;i<numhap;i++) {
    //    obsprobsCnum(8*i  ,0) = obsprobsCnum(8*i  ,1) = 
    //      obsprobsCnum(8*i+4,0) = obsprobsCnum(8*i+1,1) = 1;
    obsprobsCnum(8*i+2,0) = obsprobsCnum(8*i+2,1) = 
      obsprobsCnum(8*i+3,0) = obsprobsCnum(8*i+3,1) = 
      obsprobsCnum(8*i+6,0) = obsprobsCnum(8*i+6,1) = 0;//states imposs.

    //case 1: both flanking alleles observed
    if (data(i,C-1)!=0) {
      if (data(i,C+1)!=0) {

	//if is necessary because we may be conditioning on a set of pr.0 
	if (m==0 && data(i,C-1)!=anchap(C-1)) {
	  obsprobsCnum(8*i  ,0)= 0;
	  obsprobsCnum(8*i+1,0)= 0; }
	else {
	  obsprobsCnum(8*i  ,0)= 1;
	  obsprobsCnum(8*i+1,0)= 
	    getfreq(freqallele, freqfreq, data(i,C+1),C+1);}

	obsprobsCnum(8*i+4,0)= 1;
	obsprobsCnum(8*i+5,0)= getfreq(freqallele, freqfreq, data(i,C+1),C+1);
	obsprobsCnum(8*i+7,0)= getjointfreq3(freqallele, allelecounts,
					      condfreq3l, C-1,
					      data(i,C-1), data(i,C),
					      data(i,C+1));

	if (m==0 && data(i,C+1)!=anchap(C+1)) {
	  obsprobsCnum(8*i  ,1)= 0;
	  obsprobsCnum(8*i+4,1)= 0; }
	else {
	  obsprobsCnum(8*i  ,1)= 1; 
	  obsprobsCnum(8*i+4,1)= 
	    getfreq(freqallele, freqfreq, data(i,C-1),C-1); }

	obsprobsCnum(8*i+1,1)= 1;
	obsprobsCnum(8*i+5,1)= getfreq(freqallele, freqfreq, data(i,C-1),C-1);
	obsprobsCnum(8*i+7,1)= getjointfreq3(freqallele, allelecounts,
					     condfreq3r, C-1,
					     data(i,C-1), data(i,C),
					     data(i,C+1)); }
      //case 2: C-1 observed, C+1 missing
      else {

	if (m==0 && data(i,C-1)!=anchap(C-1)) {
	  obsprobsCnum(8*i  ,0)= 0;
	  obsprobsCnum(8*i+1,0)= 0; }
	else {
	  obsprobsCnum(8*i  ,0)= 1;
	  obsprobsCnum(8*i+1,0)= 1; }

	obsprobsCnum(8*i+4,0)= 1;	
	obsprobsCnum(8*i+5,0)= 1;
	obsprobsCnum(8*i+7,0)= getjointfreq2(freqallele, allelecounts,
					      jointfreq2, C-1,
					      data(i,C-1), data(i,C))/
	  getfreq(freqallele, freqfreq, data(i,C-1),C-1);


	obsprobsCnum(8*i  ,1)= 1; 
	obsprobsCnum(8*i+1,1)= 1;
  	obsprobsCnum(8*i+4,1)= getfreq(freqallele, freqfreq, data(i,C-1),C-1);
	obsprobsCnum(8*i+5,1)= getfreq(freqallele, freqfreq, data(i,C-1),C-1);
	obsprobsCnum(8*i+7,1)= getjointfreq2(freqallele, allelecounts,
					      jointfreq2, C-1,
					      data(i,C-1), data(i,C)); } }
    else {
      //case 3: C-1 missing, C+1 observed
      if (data(i,C+1)!=0) {
	obsprobsCnum(8*i  ,0)= 1;
	obsprobsCnum(8*i+1,0)= getfreq(freqallele, freqfreq, data(i,C+1),C+1);
	obsprobsCnum(8*i+4,0)= 1;
	obsprobsCnum(8*i+5,0)= getfreq(freqallele, freqfreq, data(i,C+1),C+1);
	obsprobsCnum(8*i+7,0)= getjointfreq2(freqallele, allelecounts,
					      jointfreq2, C,
					      data(i,C), data(i,C+1));

	if (m==0 && data(i,C+1)!=anchap(C+1)) {
	  obsprobsCnum(8*i  ,1)= 0;
	  obsprobsCnum(8*i+4,1)= 0; }
	else {
	  obsprobsCnum(8*i  ,1)= 1; 
	  obsprobsCnum(8*i+4,1)= 1; }

	obsprobsCnum(8*i+1,1)= 1;
	obsprobsCnum(8*i+5,1)= 1;
	obsprobsCnum(8*i+7,1)= getjointfreq2(freqallele, allelecounts,
					      jointfreq2, C,
					      data(i,C), data(i,C+1))/
	  getfreq(freqallele, freqfreq, data(i,C+1),C+1); }
      //case 4: both flanking alleles missing; only C's marginal remains
      else {
	obsprobsCnum(8*i  ,0)= 1;
	obsprobsCnum(8*i+1,0)= 1;
	obsprobsCnum(8*i+4,0)= 1;
	obsprobsCnum(8*i+5,0)= 1;
	obsprobsCnum(8*i+7,0)= getfreq(freqallele, freqfreq, data(i,C),C);

	obsprobsCnum(8*i  ,1)= 1; 
	obsprobsCnum(8*i+1,1)= 1;
  	obsprobsCnum(8*i+4,1)= 1;
	obsprobsCnum(8*i+5,1)= 1;
	obsprobsCnum(8*i+7,1)= getfreq(freqallele, freqfreq, data(i,C),C); } }
 
  }
  }
}

void makenullshareCden(dgrid &obsprobsCden, const igrid &data, 
		       const igrid &freqallele,
		       const dgrid &freqfreq, const int &C, 
		       const int &LE, const int &RE)
{
  //makes parts of observation probabilities that do not depend on
  //parameters, for denominators of alpha and beta at C
  //
  //Only acts when C lies strictly inside (LE,RE).  For haplotype h the
  //rows 4*h..4*h+3 are written in two columns: column 0 depends on the
  //allele at C+1 and column 1 on the allele at C-1.  A flanking allele
  //coded 0 (missing) contributes 1 in place of its getfreq frequency.
  //Offset 2 of column 0 and offset 1 of column 1 are fixed at 0.
  if (!(LE<C && C<RE)) return;

  const int numhap=data.dimx();
  for (int hap=0; hap<numhap; ++hap) {
    const int base=4*hap;

    obsprobsCden(base  ,0)=1;
    obsprobsCden(base  ,1)=1;
    obsprobsCden(base+2,0)=0;
    obsprobsCden(base+1,1)=0;

    //column 0: marginal frequency of the allele at C+1
    double rightfreq=1;
    if (data(hap,C+1)!=0)
      rightfreq=getfreq(freqallele, freqfreq, data(hap,C+1),C+1);
    obsprobsCden(base+3,0)=rightfreq;
    obsprobsCden(base+1,0)=rightfreq;

    //column 1: marginal frequency of the allele at C-1
    double leftfreq=1;
    if (data(hap,C-1)!=0)
      leftfreq=getfreq(freqallele, freqfreq, data(hap,C-1),C-1);
    obsprobsCden(base+3,1)=leftfreq;
    obsprobsCden(base+2,1)=leftfreq;
  }
}

void makemutmatch(dvector &mutmatch, dvector &mutnomatch, 
		  const dvector &mut,
		  const ivector &allelecounts, const double &tau, const int &C)
{
  //makes m(l,tau,i,j), mutation probabilities, for i=j and i != j
  //
  //For every locus l (0..allelecounts.dim()-1), with k = allelecounts(l)
  //and keep = (1-mut(l))^tau:
  //  mutmatch(l)   = keep + (1-keep)/k
  //  mutnomatch(l) = (1-mutmatch(l))/(k-1)
  //The centre C, loci with k<=1 and loci with mut(l)<=0 instead get
  //mutmatch=1 and mutnomatch=0.
  const int numloc=allelecounts.dim();

  for (int locus=0; locus<numloc; ++locus) {
    const bool canmutate =
      (locus!=C) && allelecounts(locus)>1 && mut(locus)>0;

    if (!canmutate) {
      mutmatch(locus)=1;
      mutnomatch(locus)=0;
      continue;
    }

    const double keep=pow(1-mut(locus),tau);
    mutmatch(locus)=keep + (1-keep)/(allelecounts(locus));
    mutnomatch(locus)=(1-mutmatch(locus))/(allelecounts(locus)-1);
  }
}
    
void makemodlshare(dgrid &obsprobs, const igrid &data,
		   const dvector &mutmatch, dvector &mutnomatch,
		   const ivector &anchap, 
		   const int &C, const int &LE, const int &RE)
{
  //completes observation probabilities generated in makenullshare;
  //includes components that depend on tau through m(l,tau,i,j)
  //
  //Fills row 2*h of obsprobs for haplotype h at every window locus
  //except C: 1 for a missing allele (coded 0), mutmatch(l) when the
  //allele equals anchap(l), mutnomatch(l) otherwise.  When C lies
  //outside [LE,RE], column C of that row is set to 1.
  const int numhap=data.dimx();
  const bool centreoutside = (C<LE || RE<C);

  for (int hap=0; hap<numhap; ++hap) {
    const int row=2*hap;

    if (centreoutside) obsprobs(row,C)=1;//condition on matching at C

    for (int locus=LE; locus<=RE; ++locus) {
      if (locus==C) continue;

      if (data(hap,locus)==0)
	obsprobs(row,locus)=1;
      else if (data(hap,locus)==anchap(locus))
	obsprobs(row,locus)=mutmatch(locus);
      else
	obsprobs(row,locus)=mutnomatch(locus);
    }
  }
}

void makemodlshareCnum(dgrid &obsprobsCnum, const igrid &data,
		       const dgrid &obsprobs, const int &C, 
		       const int &LE, const int &RE)
{
  //completes observation probabilities generated in makenullshareCnum;
  //includes components that depend on tau through m(l,tau,i,j)
  //
  //Only acts when C lies strictly inside (LE,RE).  For haplotype h:
  //  column 0, offsets 0 and 4 <- obsprobs(2*h,C+1)
  //  column 1, offsets 0 and 1 <- obsprobs(2*h,C-1)
  //A flanking allele coded 0 (missing) contributes 1 instead.
  //assumes no mutation at center
  if (!(LE<C && C<RE)) return;

  const int numhap=data.dimx();
  for (int hap=0; hap<numhap; ++hap) {
    const int base=8*hap;
    const double fromright = (data(hap,C+1)!=0) ? obsprobs(2*hap,C+1) : 1.0;
    const double fromleft  = (data(hap,C-1)!=0) ? obsprobs(2*hap,C-1) : 1.0;

    obsprobsCnum(base  ,0)=fromright;
    obsprobsCnum(base+4,0)=fromright;
    obsprobsCnum(base  ,1)=fromleft;
    obsprobsCnum(base+1,1)=fromleft;
  }
}

void makemodlshareCden(dgrid &obsprobsCden, const igrid &data,
		       const dgrid &obsprobs, 
		       const int &C, 
		       const int &LE, const int &RE)
{
  //completes observation probabilities generated in makenullshareCden;
  //includes components that depend on tau through m(l,tau,i,j)
  //
  //Only acts when C lies strictly inside (LE,RE).  For haplotype h, offset
  //0 of column 0 takes obsprobs(2*h,C+1) and offset 0 of column 1 takes
  //obsprobs(2*h,C-1); a missing flanking allele (coded 0) gives 1.
  if (!(LE<C && C<RE)) return;

  const int numhap=data.dimx();
  for (int hap=0; hap<numhap; ++hap) {
    obsprobsCden(4*hap,0) = (data(hap,C+1)!=0) ? obsprobs(2*hap,C+1) : 1.0;
    obsprobsCden(4*hap,1) = (data(hap,C-1)!=0) ? obsprobs(2*hap,C-1) : 1.0;
  }
}

void makeinit(dgrid &initstore, const dvector &x, const double &tau, 
	      const double &p, const int &C, const int &LE, const int &RE)
{
  //computes (unconditional) distribution of MC at 4 locations
  //
  //Row 0 holds the probability of state A, row 1 of state N (= 1 - A).
  //  column 0: chain at LE   (A = 1-p when C<LE, else (1-p)exp(-tau*x[LE]))
  //  column 1: chain at C    (A = 1-p)
  //  column 2: chain at RE   (A = 1-p when RE<C, else (1-p)exp(-tau*x[RE]))
  //  column 3: chain at C-1  (uses x[RE] when RE<C, else x[C-1])
  //x[l] is used as the distance argument of exp(-tau*x[l]); assumes x is
  //indexed 0-based like the loci -- TODO confirm against caller.
  //
  //Fix: when C==0 (only possible with C<LE) there is no locus C-1, and the
  //original read x[-1].  Column 3 then falls back to 1-p, the zero-distance
  //value; it is only consumed (by makebeta) when LE<C<RE.

  //initial distribution of chain on LHS
  if (C<LE) initstore(0,0)=(1-p); //A 
  else initstore(0,0)=(1-p)* exp( -tau * x[LE]);//A
  initstore(1,0)=1-initstore(0,0);//N

  //make unconditional distribution of chain at C
  initstore(0,1)=(1-p);
  initstore(1,1)=p;

  //initial distribution of chain on RHS
  if (RE<C) initstore(0,2)=(1-p); //A 
  else initstore(0,2)=(1-p)* exp( -tau * x[RE]);//A
  initstore(1,2)=1-initstore(0,2);//N

  //unconditional distribution of chain at C-1
  if (RE<C) initstore(0,3)=(1-p)*exp( -tau * x[RE]); //A 
  else if (C>0) initstore(0,3)=(1-p)*exp( -tau * x[C-1]); //A 
  else initstore(0,3)=(1-p); //A; no locus C-1 exists, avoid x[-1]
  initstore(1,3)=1-initstore(0,3);//N
}


void maketrans(dgrid &transstore, const dvector &x, 
	       const double &tau, const double &p, 
	       const int &C, const int &LE, const int &RE)
{
  //computes transition probabilities in the direction of C
  int i;

  if (LE<C && C<RE) {
    for (i=LE;i<C;i++)
      {
	transstore(0,i)=1;//A->A  
	transstore(1,i)=0;//A->N 
	transstore(3,i)= (1 - (1-p) * exp( -tau * x[i+1]) )/
	  (1 - (1-p) * exp( -tau * x[i]) );//N->N
	transstore(2,i)=1-transstore(3,i);//N->A
      }
    for (i=C;i<RE;i++)
      {
	transstore(0,i)=1;//A<-A  
	transstore(2,i)=0;//N<-A
	transstore(3,i)= (1 - (1-p) * exp( -tau * x[i]) )/
	  (1 - (1-p) * exp( -tau * x[i+1]) );//N<-N
	transstore(1,i)=1-transstore(3,i); //A<-N
      }
  }
  else { 
    if (C<LE) {
      transstore(0,C)=1;//A<-A  
      transstore(2,C)=0;//N<-A
      transstore(3,C)= p/(1 - (1-p) * exp( -tau * x[LE]) );//N<-N
      transstore(1,C)=1-transstore(3,C);//A<-N
      
      for (i=LE;i<RE;i++) {
	transstore(0,i)=1;//A<-A  
	transstore(2,i)=0;//N<-A
	transstore(3,i)= (1 - (1-p) * exp( -tau * x[i]) )/
	  (1 - (1-p) * exp( -tau * x[i+1]) );//N<-N
	transstore(1,i)=1-transstore(3,i); //A<-N
      }
    }
    else {//(RE<C)
      for (i=LE;i<RE;i++) {
	transstore(0,i)=1;//A->A  
	transstore(1,i)=0;//A->N
	transstore(3,i)= (1 - (1-p) * exp( -tau * x[i+1]) )/
	  (1 - (1-p) * exp( -tau * x[i]) );//N->N
	transstore(2,i)=1-transstore(3,i); } //N->A
      
      transstore(0,RE)=1;//A->A  
      transstore(1,RE)=0;//A->N
      transstore(3,RE)= p/(1 - (1-p) * exp( -tau * x[RE]) );//N->N
      transstore(2,RE)=1-transstore(3,RE); //N->A
    }
  }
}

void maketransC(dgrid &transstoreC, const dvector &x, 
		const double &tau, const double &p, 
		const int &C, const int &LE, const int &RE)
{
  //computes transition probabilities for jumps away from C
  //
  //Column 0 describes the step from C to C-1 and is filled only when LE<C.
  //Column 1 describes the step from C to C+1 and is filled only when C<RE.
  //In both, the A-stays-A probability is exp(-tau*x[.]) of the neighbour,
  //and N never returns to A.  A column whose neighbour is outside the
  //window is zeroed.  p is not used.
  int row;

  if (LE<C) {
    const double keepleft = exp( -tau * x[C-1]);
    transstoreC(0,0)=keepleft;   //A<-A
    transstoreC(1,0)=0;          //A<-N
    transstoreC(2,0)=1-keepleft; //N<-A
    transstoreC(3,0)=1;          //N<-N
  }
  else {
    for (row=0; row<4; ++row) transstoreC(row,0)=0;
  }

  if (C<RE) {
    const double keepright = exp( -tau * x[C+1]);
    transstoreC(0,1)=keepright;   //A->A
    transstoreC(1,1)=1-keepright; //A->N
    transstoreC(2,1)=0;           //N->A
    transstoreC(3,1)=1;           //N->N
  }
  else {
    for (row=0; row<4; ++row) transstoreC(row,1)=0;
  }
}

void makealpha(dgrid &alpha, const dgrid &initstore, 
	       const dgrid &transstore, const dgrid &transstoreC,
	       const dgrid &obsprobs, const dgrid &obsprobsCnum,
	       const dgrid &obsprobsCden, 
	       const int &C, const int &LE, const int &RE)
{
  //assembles alpha object (forward recursion of the HMM)
  //
  //alpha(2*h+s,t) is the forward quantity for haplotype h in state s
  //(0 = A, 1 = N) at locus t.  Columns LE..RE are filled, plus column C
  //when C lies outside [LE,RE].
  //  - C<LE : start at C from obsprobs, step to LE, then LE+1..RE.
  //  - LE<C : start at LE from initstore(.,0), run to C-1 (or RE).
  //  - LE<C<RE : column C is numerator/denominator.  The numerator mixes
  //    over the states at C-1 and C+1 using obsprobsCnum column 0; the
  //    denominator uses initstore(.,1), transstoreC column 1 and
  //    obsprobsCden column 0.  0/0 is left as 0; nonzero/0 aborts.
  //  - RE<C : column C is computed from RE and divided by initstore(.,1).
  //On exit every filled entry must lie in (-0.00001, 1.00001); otherwise
  //the offending value is written to errorfs and the program exits.
  //
  //Change: the original ended with a second, looser range check
  //([-0.0001, 1.0001]) over the same entries.  Any value passing the first
  //check also passes the second, so that loop could never fire and has
  //been removed.

  int i,j,k,h,t;
  const int numhap=alpha.dimx()/2;
  int LHSle, LHSre, RHSle, RHSre;

  double sum,sum1;

  //these are endpoints for t, rather than t+1
  //(an empty range is encoded as le=0, re=-1)
  if (LE<C && C<RE) {
    LHSle=LE; LHSre=C-2;
    RHSle=C;  RHSre=RE-1; }
  else {
    if (C<LE) {
      LHSle=0,LHSre=-1;
      RHSle=LE;  RHSre=RE-1; }
    else {//RE<C
      LHSle=LE; LHSre=RE-1;
      RHSle=0,RHSre=-1; }
  }

  for (h=0;h<numhap;h++) {

    if (C<LE) {
      //make column(s) C
      for (i=0;i<2;i++) {
        alpha(2*h+i,C)=obsprobs(2*h+i,C); }

      //make column(s) LE
      for (j=0;j<2;j++) {
        sum=0;
        for (i=0;i<2;i++) {
          sum+=
            alpha(2*h+i,C)*transstore(2*i+j,C)*obsprobs(2*h+j,LE); }
        alpha(2*h+j,LE)=sum; } }
    else {//LE<C
      //make column(s) LE
      for (i=0;i<2;i++) {
        alpha(2*h+i,LE)=initstore(i,0)*obsprobs(2*h+i,LE);
      } }

    //make column(s) LE+1:C-1 or LE+1:RE
    for (t=LHSle;t<=LHSre;t++) {
      for (j=0;j<2;j++) {
        sum=0;
        for (i=0;i<2;i++) {
          sum+=
            alpha(2*h+i,t)*transstore(2*i+j,t)*obsprobs(2*h+j,t+1);
        }
        alpha(2*h+j,t+1)=sum; } }

    if (LE<C && C<RE) {
      //make column(s) C
      for (j=0;j<2;j++) {

        //make numerator: sum over state i at C-1 and state k at C+1
        sum=0;
        for (i=0;i<2;i++) {
          sum1=0;
          for (k=0;k<2;k++) {
            sum1+=transstoreC(2*j+k,1)*obsprobsCnum(8*h+4*i+2*j+k,0);
          }
          sum+=alpha(2*h+i,C-1)*transstore(2*i+j,C-1)*sum1;
        }
        alpha(2*h+j,C)=sum;

        //make denominator
        sum=0;
        for (k=0;k<2;k++) {
          sum+=initstore(j,1)*transstoreC(2*j+k,1)*
            obsprobsCden(4*h+2*j+k,0); }

        if (sum==0) {
          if (alpha(2*h+j,C)!=0) {
            errorfs << "ERROR: alpha(2*" << h << "+" << j << ",C) = "
                 << "P(A|B) = P(A&B)/P(B) =!0/0" << endl;
            exit(1);
          } }
        else alpha(2*h+j,C)/=sum;

      }
    }

    //make column(s) C+1:RE or LE+1:RE
    for (t=RHSle;t<=RHSre;t++) {
      for (j=0;j<2;j++) {
        sum=0;
        for (i=0;i<2;i++) {
          sum+=
            alpha(2*h+i,t)*transstore(2*i+j,t)*obsprobs(2*h+j,t+1);

        }
        alpha(2*h+j,t+1)=sum; } }

    if (RE<C) {
      //make column(s) C
      for (j=0;j<2;j++) {
        if (initstore(j,1)==0) alpha(2*h+j,C)=0;
        else {
          sum=0;
          for (i=0;i<2;i++) {
            sum+=
              alpha(2*h+i,RE)*transstore(2*i+j,RE)*obsprobs(2*h+j,C);

          }
          sum/=initstore(j,1);
          alpha(2*h+j,C)=sum; } }
    }

  }

  //range check on every filled entry (column C is always filled)
  for (h=0;h<numhap;h++) {
    for (i=0;i<2;i++) {
      if ( !(-0.00001<alpha(2*h+i,C) && alpha(2*h+i,C)<1.00001) ) {
        errorfs << "alpha (h) out of bounds "
                << h << " " << alpha(2*h+i,C) << endl;
        exit(1);
      }
      for (j=LE;j<=RE;j++) {
        if ( !(-0.00001<alpha(2*h+i,j) && alpha(2*h+i,j)<1.00001) ) {
          errorfs << "alpha (h) out of bounds "
                  << h << " " << alpha(2*h+i,j) << endl;
          exit(1);
        }
      }
    }
  }

}

void makebeta(dgrid &beta, const dgrid &initstore, 
	      const dgrid &transstore, const dgrid &transstoreC,
	      const dgrid &obsprobs, const dgrid &obsprobsCnum,
	      const dgrid &obsprobsCden, 
	      const int &C, const int &LE, const int &RE)
{
  //assembles beta object  
  //
  //beta(2*h+s,t) is the backward quantity for haplotype h in state s
  //(0 = A, 1 = N) at locus t.  Columns LE..RE are filled, plus column C
  //when C lies outside [LE,RE].
  //  - RE<C : beta(.,C) = initstore(.,2), beta(.,RE) = 1, then recurse
  //    from RE-1 down to LE.
  //  - C<RE : beta(.,RE) = initstore(.,2), then recurse down to C (when
  //    LE<C<RE) or to LE (when C<LE).
  //  - LE<C<RE : column C-1 is numerator/denominator.  The numerator mixes
  //    over the states at C and C+1 using obsprobsCnum column 1, with a
  //    normalised weight over the C+1 state.  The denominator is
  //    initstore(.,3)*obsprobsCden(.,1).  0/0 is left as 0; nonzero/0
  //    aborts via errorfs.  The recursion then continues from C-2 to LE.
  //  - C<LE : column C is filled last, from column LE.
  //On exit every filled entry must lie in [-0.0001, 1.0001]; otherwise the
  //value is written to errorfs and the program exits.
  int i,j,k,h,t;
  const int numhap=beta.dimx()/2;
  int LHSle, LHSre, RHSle, RHSre;

  double sum,sum1,sum2;

  //loop endpoints for t; an empty range is encoded as le=0, re=-1
  if (LE<C && C<RE) {
    LHSle=LE; LHSre=C-2;
    RHSle=C;  RHSre=RE-1; }
  else {
    if (C<LE) {
      LHSle=0,LHSre=-1;
      RHSle=LE;  RHSre=RE-1; }
    else {//RE<C
      LHSle=LE; LHSre=RE-1;
      RHSle=0,RHSre=-1; }
  }

  for (h=0;h<numhap;h++) {

    if (RE<C) {
      //make column(s) C
      for (i=0;i<2;i++) {
	beta(2*h+i,C)=initstore(i,2); }
    
      //make column(s) RE
      for (i=0;i<2;i++) {
	beta(2*h+i,RE)=1; }//bc condition on matching at C in haps. See notes!
    }
    else {//C<RE
      //make column(s) RE
      for (i=0;i<2;i++) {
	beta(2*h+i,RE)= initstore(i,2); } 
    }
     
    //make column(s) RE-1:C or RE-1:LE
    for (t=RHSre;t>=RHSle;t--) {
      for (i=0;i<2;i++) {
	sum=0;
	for (j=0;j<2;j++) {
	  sum+=
	    beta(2*h+j,t+1)*transstore(2*i+j,t)*obsprobs(2*h+j,t+1); 
	}
	beta(2*h+i,t)=sum; } }
      

    if (LE<C && C<RE) {

      //make column(s) C-1
      for (i=0;i<2;i++) {
	//make numerator
	sum=0;
	for (j=0;j<2;j++) {
	  sum1=0;
	  //make tCjk
	  //(sum2 is the normaliser of the weights over the C+1 state k)
	  sum2=0;
	  for (k=0;k<2;k++) {
	    sum2+=initstore(j,1)*transstoreC(2*j+k,1)*
	      obsprobsCden(4*h+2*j+k,0); 
	  }
	  //terms are skipped when the normaliser is zero
	  for (k=0;k<2;k++) {
	    if (sum2!=0) { 
	      sum1+=(initstore(j,1)*transstoreC(2*j+k,1)*
		     obsprobsCden(4*h+2*j+k,0)/sum2)*
		//these two lines were transstoreC(2*j+k,1)
		obsprobsCnum(8*h+4*i+2*j+k,1); 
	    }
	  }
	  sum+=beta(2*h+j,C)*transstoreC(2*i+j,0)*sum1; 

	}
	beta(2*h+i,C-1)=sum;
      }

      //make denominators
      for (i=0;i<2;i++) {
	sum=0;
	sum+=initstore(i,3)*obsprobsCden(4*h+2*i  ,1); 

	if (sum==0) {
	  if (beta(2*h+i,C-1)!=0) {
	    errorfs << "ERROR: beta(2*" << h << "+" << i << ",C-1) = "
		 << "P(A|B) = P(A&B)/P(B) =!0/0" << endl; 
	    exit(1);
	  } }
	else beta(2*h+i,C-1)/=sum; 
      }
  
    }
    
    //make column(s) C-2:LE or RE-1:LE
    for (t=LHSre;t>=LHSle;t--) {
      for (i=0;i<2;i++) {
	sum=0;
	for (j=0;j<2;j++) {
	  sum+=
	    beta(2*h+j,t+1)*transstore(2*i+j,t)*obsprobs(2*h+j,t+1); }
	beta(2*h+i,t)=sum; 
      } }
  
    if (C<LE) {
      //make column(s) C
      for (i=0;i<2;i++) {
	sum=0;
	for (j=0;j<2;j++) {
	  sum+=
	    beta(2*h+j,LE)*transstore(2*i+j,C)*obsprobs(2*h+j,LE); }
	beta(2*h+i,C)=sum; } }

  }

  //range check on every filled entry
  for (h=0;h<numhap;h++) {
    for (i=0;i<2;i++) {
      if (C<LE || RE<C) 
	if ( !(-0.0001<=beta(2*h+i,C) && beta(2*h+i,C)<=1.0001) ) {
	  errorfs << "ERROR: beta(2*" << h << "+" << i << ",C) is not between 0 and 1 ("
	       << beta(2*h+i,C) << ")" << endl;
	  exit(1); }
      for (j=LE;j<=RE;j++)
	if ( !(-0.0001<=beta(2*h+i,j) && beta(2*h+i,j)<=1.0001) ) {
	  errorfs << "ERROR: beta(2*" << h << "+" << i << "," << j <<") is not between 0 and 1 ("
	       << beta(2*h+i,j) << ")" << endl;
	  exit(1); }
    }
  }

}

void makegamma(dgrid &gamma, const dgrid &alpha, const dgrid &beta,
	       const int &C, const int &LE, const int &RE)
{
  //assembles gamma object
  //
  //gamma(2*h+s,t) = alpha(2*h+s,t)*beta(2*h+s,t) / L_h, where the
  //haplotype likelihood L_h = sum_s alpha(2*h+s,C)*beta(2*h+s,C).
  //Columns LE..RE are filled, plus column C when C lies outside [LE,RE].
  //At every window locus the two entries must sum to within (0.999,
  //1.001); otherwise the sum is written to errorfs and the program exits.
  //Finally every filled entry is asserted to lie in (-0.0001, 1.0001).
  const int numhap=gamma.dimx()/2;
  const bool centreoutside = (C<LE || RE<C);

  for (int hap=0; hap<numhap; ++hap) {
    const int row=2*hap;

    //calculate likelihood for haplotype
    double lik=0;
    for (int s=0; s<2; ++s) lik+=alpha(row+s,C)*beta(row+s,C);

    if (centreoutside) {
      for (int s=0; s<2; ++s)
	gamma(row+s,C)=( alpha(row+s,C)*beta(row+s,C) )/lik;
    }

    for (int locus=LE; locus<=RE; ++locus) {
      for (int s=0; s<2; ++s)
	gamma(row+s,locus)=( alpha(row+s,locus)*beta(row+s,locus) )/lik;

      //check that gamma's sum to 1 at each locus
      double total=0;
      for (int s=0; s<2; ++s) total+=gamma(row+s,locus);
      if ( !(0.999 <total && total<1.001) ) {
	errorfs << "ERROR: gamma's do not sum to 1.0 (" << total << ") for hap. " 
	     << hap << "+1 and locus " << locus << "+1" << endl;
	exit(1); 
      }
    }
  }

  for (int r=0; r<gamma.dimx(); ++r) {
    if (centreoutside)
      assert(-0.0001<gamma(r,C) && gamma(r,C)<1.0001);
    for (int locus=LE; locus<=RE; ++locus)
      assert(-0.0001<gamma(r,locus) && gamma(r,locus)<1.0001);
  }
}

void makecstar(dvector &cstar, const dgrid &gamma, 
	       const int &C, const int &LE, const int &RE)
{
  //makes cstar (expected number of haps sharing by descent from
  //ancestral hapl., conditional on model,data)
  //
  //cstar(t) = sum over haplotypes h of gamma(2*h,t) (the A row), for every
  //window locus t and also for t=C when C lies outside [LE,RE].
  const int numhap=gamma.dimx()/2;

  if (C<LE || RE<C) {
    double total=0;
    for (int hap=0; hap<numhap; ++hap) total+=gamma(2*hap,C);
    cstar(C)=total;
  }

  for (int locus=LE; locus<=RE; ++locus) {
    double total=0;
    for (int hap=0; hap<numhap; ++hap) total+=gamma(2*hap,locus);
    cstar(locus)=total;
  }
}

void makecstarmut(dvector &cstarmut, const dvector &cstar,
		  const dgrid &gamma,
		  const igrid &data, const ivector &anchap,
		  const dvector &mutmatch,
		  const int &C, const int &LE, const int &RE)
{
  //makes counts of cstar that do not match the ancestral haplotype
  //due to mutation-- these are already included in cstar but must
  //be part of complete data likelihood
  //
  //For each window locus t != C, sums gamma(2*h,t) over haplotypes whose
  //allele differs from anchap(t).  A missing allele (coded 0) that differs
  //from anchap(t) is weighted by 1-mutmatch[t].  cstarmut(C) is always 0.
  //The cstar argument is not read.
  const int numhap=gamma.dimx()/2;

  //assumes no mutation allowed at center
  cstarmut(C)=0;

  for (int locus=LE; locus<=RE; ++locus) {
    cstarmut(locus)=0;
    if (locus==C) continue;

    for (int hap=0; hap<numhap; ++hap) {
      if (data(hap,locus)==anchap(locus)) continue;

      if (data(hap,locus)==0)
	cstarmut(locus)+= (1-mutmatch[locus])*gamma(2*hap,locus);
      else
	cstarmut(locus)+=gamma(2*hap,locus);
    }
  }
}

double bisecmle(const dvector &cstar, const dvector &cstarmut, 
		const dvector &y, 
		const dvector &x,
		const int &C, const int &LE, const int &RE, 
		const dvector &mut, const ivector &allelecounts,
		const double &m)
{
  //uses bisection method to find critical point of complete data
  //likelihood to generate new estimate of tau; here complete data 
  //are cstar and cstarmut

  int k;
  double low,high,tau,likder;
  double part1=0,part2=0;
  low= 1e-10;
  high= 1e10;

  for (k=1; k<=1000; k++)
    {
      assert(high-low>0);
      if ( (high-low)/high < 0.00001 || (high-low) < 0.00001) break;
      tau=(high+low)/2;

      part1=loglikder(tau,y,x,cstar,C,LE,RE);
      if (m>0) part2=mutloglikder(tau,mut,
				  cstar, cstarmut,allelecounts, C, LE, RE);
      else part2=0;

      likder = part1+part2;
      
      if (likder > 0)  low = tau;
      else high = tau;

    }

  return tau;

}//bisecmle
    
//Contribution of one interval of length dist to d/dtau of the complete-data
//loglik: -outer*dist + (inner-outer)*dist/(exp(tau*dist)-1), where inner is
//the cstar value on the side nearer C and outer the one farther from C.
static double loglikder_term(const double outer, const double inner,
			     const double dist, const double tau)
{
  return -outer * dist + (inner - outer) * dist/(exp(tau * dist)-1);
}

double loglikder(double &tau, const dvector &y, const dvector &x,
		 const dvector &cstar, 
		 const int &C, const int &LE, const int &RE)
{
  //computes the derivative w.r.t. tau of the complete data loglik for cstar;
  //used in bisection method
  //
  //The chain is walked outward from C.  When C is outside the window, the
  //first interval joins C to the nearest window end and uses x[LE] or
  //x[RE] as its length.  Intervals inside the window use y[l], the length
  //between loci l and l+1.
  double value=0;

  if (C<LE) {
    value += loglikder_term(cstar[LE], cstar[C], x[LE], tau);
    for (int locus=LE+1; locus<=RE; ++locus)
      value += loglikder_term(cstar[locus], cstar[locus-1], y[locus-1], tau);
  }
  else if (C>RE) {
    value += loglikder_term(cstar[RE], cstar[C], x[RE], tau);
    for (int locus=RE-1; locus>=LE; --locus)
      value += loglikder_term(cstar[locus], cstar[locus+1], y[locus], tau);
  }
  else {
    //left of C, walking outward
    for (int locus=C-1; locus>=LE; --locus)
      value += loglikder_term(cstar[locus], cstar[locus+1], y[locus], tau);
    //right of C, walking outward
    for (int locus=C+1; locus<=RE; ++locus)
      value += loglikder_term(cstar[locus], cstar[locus-1], y[locus-1], tau);
  }

  return value;
}//loglikder

double mutloglikder(const double &tau, const dvector &mut,//mrescaled 
		    const dvector &cstar, const dvector &cstarmut,
		    const ivector &allelecounts, const int &C, 
		    const int &LE, const int &RE)
{
  //computes the derivative w.r.t. tau of the complete data loglik for 
  //mutation process in cstarmut; code generated by maple
  //used in bisection method
  //
  //note: mut(t) = m*( n(t) + 1 )/n(t)
  //note: assumes no mutation at C
  //
  //Returns 0 outright when tau<0.1.  Otherwise it sums a closed-form term
  //over window loci t != C with mut(t) != 0.  A NaN result is reported on
  //cerr and the program exits.
  if (tau<0.1) return 0;

  double total=0;

  for (int locus=LE; locus<=RE; ++locus) {
    if (locus==C || mut(locus)==0) continue;

    const double keep    = 1.0-mut(locus);
    const double logkeep = log(keep);
    const double keep1   = pow(keep,1.0*tau);
    const double nkeep1  = keep1*allelecounts(locus);
    const double ncstar  = allelecounts(locus)*cstar[locus];
    const double keep2   = pow(keep,2.0*tau);

    total += logkeep/(-1.0+keep1)/(nkeep1+1.0-keep1)*
      (-ncstar*keep1+ncstar*keep2+cstar[locus]*keep1-cstar[locus]*keep2
       +nkeep1*cstarmut[locus]);
  }

  const bool isnan_total = !(total<0 || total>=0);
  if (isnan_total) {
    cerr << "ERROR in mutloglikder in bisecmle; mutation contribution" 
	 << " to derivative of loglik is NaN" << endl;
    exit(1); }

  return total;
}//mutloglikder

double makep(const dvector &cstar, const int &numhap, const int &C)
{
  //gets new estimate of heterogeneity parameter p from cstar and numb. of 
  //haps: p = 1 - cstar[C]/numhap, i.e. one minus the expected fraction of
  //haplotypes in the A state at C
  const double ancfraction = cstar[C]/numhap;
  return 1-ancfraction;
}

void makeloglik(dvector &loglikvec, const dgrid &alpha, 
		const dgrid &initstore,
		const int &C, const int &RE)
{
  //computes vector of loglikelihoods from alpha and 1 unconditional distr.;
  //each entry corresponds to one haplotype
  //
  //For haplotype h the likelihood is sum_s alpha(2*h+s,L)*initstore(s,2),
  //where L is RE when C<RE and C otherwise.  It is asserted to lie in
  //(0, 1.00001] before its log is stored in loglikvec[h].
  const int last = (C<RE) ? RE : C;
  const int numhap=alpha.dimx()/2;

  for (int hap=0; hap<numhap; ++hap) {
    loglikvec[hap]=0;

    double lik=0;
    for (int s=0; s<2; ++s)
      lik+=alpha(2*hap+s,last)*initstore(s,2);

    assert(0.0<lik && lik<=1.00001);
    loglikvec[hap]=log(lik);
  }
}