#include <ncbi.h>
#include <blast/blast.h>

/*
 * Forward declarations for the file-local helpers.  PROTO() comes from the
 * NCBI headers; presumably it keeps the parameter list for ANSI compilers
 * and drops it otherwise (confirm in ncbi.h).  The definitions further down
 * use K&R-style parameter lists to match.
 */
static void	score_forward PROTO((BLAST_HSPPtr PNTR, size_t));
static void	score_reverse PROTO((BLAST_HSPPtr PNTR, size_t));
static int	isforward PROTO((BLAST_HSPPtr,BLAST_HSPPtr));
static int	isreverse PROTO((BLAST_HSPPtr,BLAST_HSPPtr));

/* HeapSort comparators; consist_sum passes them through an (int (*)()) cast. */
static int	fwdcmp PROTO((BLAST_HSPPtr PNTR, BLAST_HSPPtr PNTR));
static int	revcmp PROTO((BLAST_HSPPtr PNTR, BLAST_HSPPtr PNTR));

/*
 * consist_sum -- determine the consistent sum statistics for each HSP.
 *
 * hp0 is the head of a list of HSPs linked through ->next; a NULL list is a
 * no-op.  bcp is not referenced by this function.
 *
 * Step 1: every HSP is reset to a chain of one (n = 1, fwdptr = revptr =
 * NULL) and given the normalized score
 *     xscore = sumscore = score * kbp->Lambda
 *                         - (kbp->logK + fct_ilog(query efflen)
 *                                      + fct_ilog(subject efflen)).
 * HSPs are bucketed by the sign of q_seg.frame and of s_seg.frame
 * ({-1, 0, +1} x {-1, 0, +1} = 9 buckets).  The logK/efflen term is computed
 * from the first HSP seen in a bucket and reused for the rest of it.
 * NOTE(review): that reuse assumes kbp->logK and both efflen values are the
 * same for every HSP in a bucket -- confirm against callers.
 *
 * Step 2: in each bucket holding two or more HSPs, each HSP's extent is
 * trimmed by `fraction` of its length at both ends (cd.qbeg/qend/sbeg/send).
 * A forward pass (fwdcmp order, score_forward) then finds the best chain
 * ending at each HSP, and a reverse pass (revcmp order, score_reverse) finds
 * the best chain starting at it.  Afterwards
 *     n        = forward chain length + reverse chain length - 1
 *     sumscore = forward chain score  + reverse chain score  - xscore
 * so the HSP itself is counted only once.
 * NOTE(review): both passes rely on cd.qbeg <= cd.qend and cd.sbeg <= cd.send
 * for each HSP, i.e. presumably callers pass fraction <= 0.5.
 *
 * Buckets of up to 200 HSPs are sorted in a stack array; larger ones use
 * BlastMalloc/BlastFree.  The BlastMalloc result is not NULL-checked --
 * presumably it does not return on failure; confirm.
 */
void LIBCALL
consist_sum(bcp, hp0, fraction)
	BLAST_ConfigPtr	bcp;
	BLAST_HSPPtr	hp0;
	double		fraction;
{
	BLAST_HSPPtr	stackbuf[200];
	BLAST_HSPPtr	PNTR vec;
	register BLAST_HSPPtr	cur;
	size_t	k, nsame, tally[9];
	double	normterm[DIM(tally)];
	register int	qs, ss, slot;

	if (hp0 == NULL)
		return;

	if (hp0->next == NULL) {
		/* A lone HSP has nothing to chain with: stand-alone statistics only */
		hp0->n = 1;
		hp0->fwdptr = hp0->revptr = NULL;
		hp0->sumscore = hp0->xscore =
			hp0->score * hp0->kbp->Lambda - hp0->kbp->logK
			- fct_ilog(hp0->q_seg.sp->efflen) - fct_ilog(hp0->s_seg.sp->efflen);
		return;
	}

	/* One counter per (query sign, subject sign) combination */
	for (slot = 0; slot < DIM(tally); ++slot)
		tally[slot] = 0;

	/* Step 1: normalize every HSP and count bucket populations */
	for (cur = hp0; cur != NULL; cur = cur->next) {
		qs = SIGN(cur->q_seg.frame);
		ss = SIGN(cur->s_seg.frame);
		slot = (qs + 1) * 3 + ss + 1;
		if (tally[slot] == 0)
			normterm[slot] = cur->kbp->logK + fct_ilog(cur->q_seg.sp->efflen)
								+ fct_ilog(cur->s_seg.sp->efflen);
		tally[slot]++;
		cur->n = 1;
		cur->fwdptr = cur->revptr = NULL;
		cur->sumscore = cur->xscore = cur->score * cur->kbp->Lambda - normterm[slot];
	}

	/* Step 2: chain the HSPs of each bucket */
	for (slot = 0; slot < DIM(tally); ++slot) {
		nsame = tally[slot];
		if (nsame < 2)
			continue;	/* keeps its step-1 statistics */

		if (nsame > DIM(stackbuf))
			vec = (BLAST_HSPPtr PNTR) BlastMalloc(sizeof(*vec) * nsame);
		else
			vec = stackbuf;

		/* Gather this bucket into vec[] and set up the trimmed extents */
		for (cur = hp0, k = 0; k < nsame; cur = cur->next) {
			qs = SIGN(cur->q_seg.frame);
			ss = SIGN(cur->s_seg.frame);
			if ((qs + 1) * 3 + ss + 1 != slot)
				continue;
			vec[k++] = cur;

			cur->cd.tot_score = -HUGE_VAL;	/* "not yet reached" marker */
			cur->cd.qbeg = cur->q_seg.offset + fraction * cur->q_seg.len;
			cur->cd.sbeg = cur->s_seg.offset + fraction * cur->s_seg.len;
			cur->cd.qend = cur->q_seg.offset + (1.-fraction) * cur->q_seg.len;
			cur->cd.send = cur->s_seg.offset + (1.-fraction) * cur->s_seg.len;
		}

		/* Forward pass: best chain ending at each HSP */
		HeapSort((CharPtr)vec, nsame, sizeof(*vec), (int (*)())fwdcmp);
		score_forward(vec, nsame);
		for (k = 0; k < nsame; ++k) {
			cur = vec[k];
			cur->n = cur->cd.n;
			cur->cd.xscore = cur->cd.tot_score;	/* stash forward result */
			cur->cd.tot_score = -HUGE_VAL;
		}

		/* Reverse pass: best chain starting at each HSP */
		HeapSort((CharPtr)vec, nsame, sizeof(*vec), (int (*)())revcmp);
		score_reverse(vec, nsame);
		for (k = 0; k < nsame; ++k) {
			cur = vec[k];
			/* both passes include the HSP itself; count it once */
			cur->n += cur->cd.n - 1;
			cur->sumscore = cur->cd.xscore + cur->cd.tot_score
						- cur->xscore;
		}

		if (vec != stackbuf)
			BlastFree(vec);
	}
}

/*
 * fwdcmp -- HeapSort comparator giving the forward-pass order.
 *
 * Orders HSPs by ascending cd.qbeg, breaking ties by ascending cd.sbeg, and
 * returns 1, -1 or 0 in the usual comparator sense.
 *
 * The coordinates are compared in their own type.  They used to be copied
 * into unsigned long first, but consist_sum assigns them from floating-point
 * expressions (offset + fraction * len).  If the cd fields are floating-point
 * (their declaration is not visible here), that narrowing could make two
 * distinct starts compare equal.  It could also be undefined for a negative
 * value.  isforward() compares the same fields natively, and score_forward()
 * needs every forward predecessor to sort strictly earlier, so the comparator
 * must agree with that predicate.
 */
static int
fwdcmp(hpp1, hpp2)
	BLAST_HSPPtr	PNTR hpp1, PNTR hpp2;
{
	register BLAST_HSPPtr	hp1 = *hpp1, hp2 = *hpp2;

	/* primary key: trimmed query start, ascending */
	if (hp1->cd.qbeg > hp2->cd.qbeg)
		return 1;
	if (hp1->cd.qbeg < hp2->cd.qbeg)
		return -1;

	/* secondary key: trimmed subject start, ascending */
	if (hp1->cd.sbeg > hp2->cd.sbeg)
		return 1;
	if (hp1->cd.sbeg < hp2->cd.sbeg)
		return -1;
	return 0;
}

/*
 * isforward -- nonzero when hp2 lies strictly after hp1 on both sequences,
 * i.e. hp2's trimmed query and subject starts are beyond hp1's trimmed ends.
 * This makes hp1 an admissible predecessor of hp2 in the forward pass.
 */
static int
isforward(hp1, hp2)
	register BLAST_HSPPtr	hp1, hp2;
{
	return (hp2->cd.qbeg > hp1->cd.qend && hp2->cd.sbeg > hp1->cd.send);
}


/*
 * score_forward -- forward pass of the chaining dynamic program.
 *
 * list[0..nelem-1] must already be in fwdcmp order, and every cd.tot_score is
 * expected to be -HUGE_VAL on entry (the "not yet reached" marker set by
 * consist_sum).
 *
 * The HSPs are visited in order.  A visited HSP that no predecessor reached
 * starts a chain of its own (cd.tot_score = xscore, cd.n = 1).  Its chain
 * score is then offered to every later HSP that isforward() of it.  A later
 * HSP accepts the offer (its own xscore plus the chain score) when the offer
 * is at least as large as its current cd.tot_score.  On acceptance it records
 * the new chain length in cd.n and the predecessor in revptr.
 *
 * Indices are size_t so they match nelem's type.  An int index would mix
 * signed and unsigned in the loop test and could overflow.
 */
static void
score_forward(list, nelem)
	BLAST_HSPPtr	PNTR list;
	size_t	nelem;
{
	register BLAST_HSPPtr	hp1, hp2;
	register unsigned	n1plus1;
	register double	xscore, ts;
	register size_t		i, j;

	for (i = 0; i < nelem; ++i) {
		hp1 = list[i];
		xscore = hp1->cd.tot_score;
		if (xscore == -HUGE_VAL) {
			/* no predecessor reached hp1: it starts a chain of one */
			xscore = hp1->cd.tot_score = hp1->xscore;
			hp1->cd.n = 1;
		}
		n1plus1 = hp1->cd.n + 1;
		for (j = i + 1; j < nelem; ++j) {
			hp2 = list[j];
			if (!isforward(hp1, hp2))
				continue;
			if ((ts = xscore + hp2->xscore) < hp2->cd.tot_score)
				continue;
			hp2->revptr = hp1;
			hp2->cd.tot_score = ts;
			hp2->cd.n = n1plus1;
		}
	}
}

/*
 * revcmp -- HeapSort comparator giving the reverse-pass order.
 *
 * Orders HSPs by descending cd.qend, breaking ties by descending cd.send, and
 * returns 1, -1 or 0 in the usual comparator sense.
 *
 * The coordinates are compared in their own type.  They used to be copied
 * into unsigned long first, but consist_sum assigns them from floating-point
 * expressions (offset + (1 - fraction) * len).  If the cd fields are
 * floating-point (their declaration is not visible here), that narrowing
 * could make two distinct ends compare equal.  It could also be undefined for
 * a negative value.  isreverse() compares the same fields natively, and
 * score_reverse() needs every reverse predecessor to sort strictly earlier,
 * so the comparator must agree with that predicate.
 */
static int
revcmp(hpp1, hpp2)
	BLAST_HSPPtr	PNTR hpp1, PNTR hpp2;
{
	register BLAST_HSPPtr	hp1 = *hpp1, hp2 = *hpp2;

	/* primary key: trimmed query end, descending */
	if (hp1->cd.qend < hp2->cd.qend)
		return 1;
	if (hp1->cd.qend > hp2->cd.qend)
		return -1;

	/* secondary key: trimmed subject end, descending */
	if (hp1->cd.send < hp2->cd.send)
		return 1;
	if (hp1->cd.send > hp2->cd.send)
		return -1;
	return 0;
}

/*
 * isreverse -- nonzero when hp1 lies strictly after hp2 on both sequences,
 * i.e. hp2's trimmed query and subject ends fall before hp1's trimmed starts.
 * This makes hp1 an admissible predecessor of hp2 in the reverse pass.
 */
static int
isreverse(hp1, hp2)
	register BLAST_HSPPtr	hp1, hp2;
{
	return (hp2->cd.qend < hp1->cd.qbeg && hp2->cd.send < hp1->cd.sbeg);
}

/*
 * score_reverse -- reverse pass of the chaining dynamic program.
 *
 * list[0..nelem-1] must already be in revcmp order, and every cd.tot_score is
 * expected to be -HUGE_VAL on entry (consist_sum resets it after the forward
 * pass).
 *
 * The HSPs are visited in order.  A visited HSP that no predecessor reached
 * starts a chain of its own (cd.tot_score = xscore, cd.n = 1).  Its chain
 * score is then offered to every later HSP that it isreverse() of, i.e. HSPs
 * lying wholly before it.  A later HSP accepts the offer (its own xscore plus
 * the chain score) when the offer is at least as large as its current
 * cd.tot_score.  On acceptance it records the new chain length in cd.n and
 * the predecessor in fwdptr.
 *
 * Indices are size_t so they match nelem's type.  An int index would mix
 * signed and unsigned in the loop test and could overflow.
 */
static void
score_reverse(list, nelem)
	BLAST_HSPPtr	PNTR list;
	size_t	nelem;
{
	register BLAST_HSPPtr	hp1, hp2;
	register unsigned	n1plus1;
	register double	xscore, ts;
	register size_t		i, j;

	for (i = 0; i < nelem; ++i) {
		hp1 = list[i];
		xscore = hp1->cd.tot_score;
		if (xscore == -HUGE_VAL) {
			/* no predecessor reached hp1: it starts a chain of one */
			xscore = hp1->cd.tot_score = hp1->xscore;
			hp1->cd.n = 1;
		}
		n1plus1 = hp1->cd.n + 1;
		for (j = i + 1; j < nelem; ++j) {
			hp2 = list[j];
			if (!isreverse(hp1, hp2))
				continue;
			if ((ts = xscore + hp2->xscore) < hp2->cd.tot_score)
				continue;
			hp2->fwdptr = hp1;
			hp2->cd.tot_score = ts;
			hp2->cd.n = n1plus1;
		}
	}
}

