/* ===========================================================================
*
*                            PUBLIC DOMAIN NOTICE
*               National Center for Biotechnology Information
*
*  This software/database is a "United States Government Work" under the
*  terms of the United States Copyright Act.  It was written as part of
*  the author's official duties as a United States Government employee and
*  thus cannot be copyrighted.  This software/database is freely available
*  to the public for use. The National Library of Medicine and the U.S.
*  Government have not placed any restriction on its use or reproduction.
*
*  Although all reasonable efforts have been taken to ensure the accuracy
*  and reliability of the software and data, the NLM and the U.S.
*  Government do not and cannot warrant the performance or results that
*  may be obtained by using this software or data. The NLM and the U.S.
*  Government disclaim all warranties, express or implied, including
*  warranties of performance, merchantability or fitness for any particular
*  purpose.
*
*  Please cite the author in any work or product based on this material.
*
* ===========================================================================*/
#include <ncbi.h>
#include <blast/blast.h>

#define COMMENT_CHR	'#'
#define TOKSTR	" \t\n\r"

/*
 * BlastScoreBlkMatRead -- read a substitution matrix for sbp from fp.
 *
 * Text format accepted:
 *   - Before the column header, lines starting with '#' are saved (without
 *     the '#' and the newline) in sbp->comments.
 *   - Elsewhere '#' starts a comment running to end of line; blank lines
 *     are skipped.
 *   - The first non-blank line lists the column residues (alphabet a2);
 *     only the first character of each whitespace-separated token is used.
 *   - Every following non-blank line holds a row residue (alphabet a1) and
 *     then up to one score per column.  A score is either a number, rounded
 *     half away from zero, or "na" (case-insensitive) = BLAST_SCORE_1MIN.
 *
 * Residues are written in the printable alphabet and translated into a1/a2
 * via alphabet maps when those alphabets are not the printable one.
 * Matrix cells not given in the file keep BLAST_SCORE_MIN.
 *
 * Returns the result of BlastScoreBlkSort() on success; on failure an
 * error code (also stored in blast_errno):
 *   BLAST_ERR_FFORMAT      line too long, header/row missing or too wide
 *   BLAST_ERR_LETTER       residue not representable in a1/a2
 *   BLAST_ERR_ALPHASIZE    too many rows/columns
 *   BLAST_ERR_CONV         score is not a number
 *   BLAST_ERR_SCORE_DOMAIN score outside [BLAST_SCORE_1MIN, BLAST_SCORE_1MAX]
 */
BLAST_Error
BlastScoreBlkMatRead(sbp, fp)
	BLAST_ScoreBlkPtr	sbp;
	FILE	*fp;
{
	BLAST_AlphaMapPtr	m1, m2;
	BLAST_AlphabetPtr	printap, a1, a2;
	char	buf[512+3];
	CharPtr	cp;
	BLAST_LetterPtr lp;
	BLAST_Letter		ch;
	BLAST_ScoreMat	matrix;
	BLAST_ScorePtr	m;
	BLAST_Score	score;
	int		a1cnt = 0, a2cnt = 0;
	BLAST_Letter	a2chars[BLAST_SCOREMAT_INDEX_MAX+1];	/* column residues, in a2 */
	double	xscore;
	register int	i, j;

	printap = BlastAlphabetFindByID(BLAST_ALPHA_PRINT);
	if (printap == NULL)
		return blast_errno;

	a1 = sbp->a1;
	a2 = sbp->a2;

	/* NOTE(review): these two limits look transposed relative to the
	   buffers they guard (a2chars holds a2 letters, a1cnt is capped at
	   BLAST_ALPHASIZE_MAX).  The per-letter checks below are what keep
	   the parser in bounds; confirm the intended limits. */
	if (a1->alphasize > DIM(a2chars) || a2->alphasize > BLAST_ALPHASIZE_MAX)
		return blast_errno = BLAST_ERR_ALPHASIZE;

	/* File residues are printable; map them unless the target alphabet
	   already is the printable one. */
	if (printap == a1)
		m1 = NULL;
	else {
		m1 = BlastAlphaMapFind(printap, a1);
		if (m1 == NULL)
			return blast_errno;
	}
	if (printap == a2)
		m2 = NULL;
	else {
		m2 = BlastAlphaMapFind(printap, a2);
		if (m2 == NULL)
			return blast_errno;
	}

	matrix = sbp->matrix;
	for (i = a1->minval; i <= (int)a1->maxval; ++i)
		for (j = a2->minval; j <= (int)a2->maxval; ++j)
			matrix[i][j] = BLAST_SCORE_MIN;

	/* Read the residue names for the second alphabet */
	while (Nlm_FileGets(buf, sizeof(buf), fp) != NULL) {
		/* A missing newline means the line did not fit in buf, unless
		   this is the last line of the file. */
		if ((cp = Nlm_StrChr(buf, '\n')) != NULL)
			*cp = NULLB;
		else if (!feof(fp))
			return blast_errno = BLAST_ERR_FFORMAT;
		if (buf[0] == COMMENT_CHR) {
			/* save the comment line in a linked list */
			ValNodeCopyStr(&sbp->comments, 0, buf+1);
			continue;
		}
		if ((cp = Nlm_StrChr(buf, COMMENT_CHR)) != NULL)
			*cp = NULLB;
		lp = (BLAST_LetterPtr)Nlm_StrTok(buf, TOKSTR);
		if (lp == NULL) /* skip blank lines */
			continue;
		do {
			ch = *lp;
			if (m2 != NULL) {
				if (!BlastAlphaMapTst(m2, ch))
					return blast_errno = BLAST_ERR_LETTER;
				ch = BlastAlphaMapChr(m2, ch);
			}
			/* column letters index a matrix row; keep them in a2's range */
			if ((int)ch < (int)a2->minval || (int)ch > (int)a2->maxval)
				return blast_errno = BLAST_ERR_LETTER;
			if (a2cnt >= DIM(a2chars))
				return blast_errno = BLAST_ERR_ALPHASIZE;
			a2chars[a2cnt++] = ch;
		} while ((lp = (BLAST_LetterPtr)Nlm_StrTok(NULL, TOKSTR)) != NULL);
		break;
	}

	if (a2cnt <= 1)
		return blast_errno = BLAST_ERR_FFORMAT;

	/* Remaining lines: a row residue from a1 followed by its scores */
	while (Nlm_FileGets(buf, sizeof(buf), fp) != NULL) {
		if ((cp = Nlm_StrChr(buf, '\n')) != NULL)
			*cp = NULLB;
		else if (!feof(fp))
			return blast_errno = BLAST_ERR_FFORMAT;
		if ((cp = Nlm_StrChr(buf, COMMENT_CHR)) != NULL)
			*cp = NULLB;
		if ((lp = (BLAST_LetterPtr)Nlm_StrTok(buf, TOKSTR)) == NULL)
			continue;
		ch = *lp;
		if (m1 != NULL) {
			if (!BlastAlphaMapTst(m1, ch))
				return blast_errno = BLAST_ERR_LETTER;
			ch = BlastAlphaMapChr(m1, ch);
		}
		if ((cp = Nlm_StrTok(NULL, TOKSTR)) == NULL)
			return blast_errno = BLAST_ERR_FFORMAT;
		if (a1cnt >= BLAST_ALPHASIZE_MAX)
			return blast_errno = BLAST_ERR_ALPHASIZE;
		/* the row letter selects a matrix row; reject anything outside
		   the rows allocated for a1 */
		if ((int)ch < (int)a1->minval || (int)ch > (int)a1->maxval)
			return blast_errno = BLAST_ERR_LETTER;
		++a1cnt;
		m = &matrix[ch][0];
		j = 0;
		do {
			if (j >= a2cnt)
				return blast_errno = BLAST_ERR_FFORMAT;
			/* cp is a NUL-terminated token inside buf; parse it in place */
			if (Nlm_StrICmp(cp, "na") == 0) {
				score = BLAST_SCORE_1MIN;
			}
			else {
				if (sscanf(cp, "%lg", &xscore) != 1)
					return blast_errno = BLAST_ERR_CONV;
				if (xscore > BLAST_SCORE_1MAX || xscore < BLAST_SCORE_1MIN)
					return blast_errno = BLAST_ERR_SCORE_DOMAIN;
				/* round half away from zero */
				xscore += (xscore >= 0. ? 0.5 : -0.5);
				score = (BLAST_Score)xscore;
			}
			m[a2chars[j++]] = score;
		} while ((cp = Nlm_StrTok(NULL, TOKSTR)) != NULL);
	}

	if (a1cnt <= 1)
		return blast_errno = BLAST_ERR_FFORMAT;

	return BlastScoreBlkSort(sbp);
}

BLAST_Error
BlastScoreBlkMatchWeights(sbp, reward, penalty, dmp1, dmp2, rfp1, rfp2)
	BLAST_ScoreBlkPtr	sbp;
	BLAST_Score	reward, penalty;
	BLAST_DegenMapPtr	dmp1;
	BLAST_DegenMapPtr	dmp2;
	BLAST_ResFreqPtr	rfp1;
	BLAST_ResFreqPtr	rfp2;
{
	char	buf[128];
	BLAST_ScoreMatPtr	matrix;
	BLAST_AlphabetPtr	a1, a2;
	BLAST_AlphaMapPtr	amp;
	BLAST_DegenListPtr	dlp1, dlp2;
	BLAST_Letter	c1, c11, c2, c22;
	double	p, q, x;
	int		i, j, ii, jj;

	if (sbp == NULL || rfp1 == NULL || rfp2 == NULL)
		return blast_errno = BLAST_ERR_INVAL;
	if (BlastScoreChk(penalty, reward) != BLAST_ERR_NONE)
		return blast_errno;

	a1 = sbp->a1;
	a2 = sbp->a2;

	if (dmp1 != NULL && dmp1->ap != rfp1->ap)
		return blast_errno = BLAST_ERR_INVAL;
	if (dmp1 == NULL && a1 != rfp1->ap)
		return blast_errno = BLAST_ERR_INVAL;
	if (dmp2 != NULL && dmp2->ap != rfp2->ap)
		return blast_errno = BLAST_ERR_INVAL;
	if (dmp2 == NULL && a2 != rfp2->ap)
		return blast_errno = BLAST_ERR_INVAL;
	if (dmp1 == NULL && dmp2 == NULL && a1 != a2)
		return blast_errno = BLAST_ERR_INVAL;

	amp = NULL;
	if (a1 != a2 && (amp = BlastAlphaMapFindCreate(a1, a2)) == NULL)
		return blast_errno;

	matrix = &sbp->_matrix0;
	for (i = a1->minval; i <= (int)a1->maxval; ++i)
		for (j = a2->minval; j <= (int)a2->maxval; ++j)
			(*matrix)[i][j] = BLAST_SCORE_MIN;

	for (i = 0; i < a1->alphasize; ++i) {
		c1 = a1->alist[i];
		if (dmp1 != NULL) {
			dlp1 = &dmp1->degen[c1];
			for (j = 0; j < a2->alphasize; ++j) {
				c2 = a2->alist[j];
				p = q = 0.;
				if (dmp2 != NULL) {
					dlp2 = &dmp2->degen[c2];
					for (jj = 0; jj < dlp2->cnt; ++jj) {
						c22 = dlp2->list[jj];
						for (ii = 0; ii < dlp1->cnt; ++ii) {
							c11 = dlp1->list[ii];
							x = rfp1->prob[c11] * rfp2->prob[c22];
							if (amp != NULL) {
								if (!BlastAlphaMapTst(amp, c11)) {
									q += x;
									continue;
								}
								c11 = BlastAlphaMapChr(amp, c11);
							}
							if (c11 == c22)
								p += x;
							else
								q += x;
						}
					}
				}
				else {
					for (ii = 0; ii < dlp1->cnt; ++ii) {
						c11 = dlp1->list[ii];
						x = rfp1->prob[c11] * rfp2->prob[c2];
						if (amp != NULL) {
							if (!BlastAlphaMapTst(amp, c11)) {
								q += x;
								continue;
							}
							c11 = BlastAlphaMapChr(amp, c11);
						}
						if (c11 == c2)
							p += x;
						else
							q += x;
					}
				}
				if (p != 0. || q != 0.)
					(*matrix)[c1][c2] = Nlm_Nint((p * reward + q * penalty) / (p + q));
			}
			continue;
		}
		for (j = 0; j < a2->alphasize; ++j) {
			c2 = a2->alist[j];
			p = q = 0.;
			if (dmp2 != NULL) {
				dlp2 = &dmp2->degen[c2];
				for (jj = 0; jj < dlp2->cnt; ++j) {
					c22 = dlp2->list[jj];
					x = rfp1->prob[c1] * rfp2->prob[c22];
					c11 = c1;
					if (amp != NULL) {
						if (!BlastAlphaMapTst(amp, c11)) {
							q += x;
							continue;
						}
						c11 = BlastAlphaMapChr(amp, c11);
					}
					if (c11 == c22)
						p += x;
					else
						q += x;
				}
			}
			else { /* dmp1 == NULL && dmp2 == NULL */
				x = rfp1->prob[c1] * rfp2->prob[c2];
				c11 = c1;
				if (amp != NULL) {
					if (!BlastAlphaMapTst(amp, c11)) {
						q += x;
						continue;
					}
					c11 = BlastAlphaMapChr(amp, c11);
				}
				if (c11 == c2)
					p = x;
				else
					q = x;
			}
			if (p != 0. || q != 0.)
				(*matrix)[c1][c2] = Nlm_Nint((p * reward + q * penalty) / (p + q));
		}
	}

	sprintf(buf, "%+ld,%+ld", (long)reward, (long)penalty);
	sbp->name = StrSave(buf);

	if (sbp->comments != NULL)
		ValNodeCopyStr(&sbp->comments, 0, "");
	ValNodeCopyStr(&sbp->comments, 0, "  This matrix was created by the function BlastScoreBlkMatchWeights()");
	sprintf(buf, "  in the BLAST function library version %d.", BLAST_VERSION);
	ValNodeCopyStr(&sbp->comments, 0, buf);
	sprintf(buf, "  Match reward = %lg, Mismatch penalty = %lg",
				(double)reward, (double)penalty);
	ValNodeCopyStr(&sbp->comments, 0, buf);

	return BlastScoreBlkSort(sbp);
}

BLAST_Error
BlastScoreBlkSort(sbp)
	BLAST_ScoreBlkPtr	sbp;
{
	BLAST_AlphabetPtr	a1, a2;
	BLAST_ScorePtr	m;
	BLAST_Score	score;
	BLAST_Letter	ch1, t;
	BLAST_LetterPtr	o, p, q;
	int		i, j;

	a1 = sbp->a1; a2 = sbp->a2;

	/* For unspecified scores, the cost will default to the lowest possible */
	for (i=0; i < DIM(sbp->maxcost); ++i)
		sbp->maxcost[i] = BLAST_SCORE_MIN / BLAST_WORDSIZE_MAX;

	sbp->loscore = BLAST_SCORE_1MAX;
	sbp->hiscore = BLAST_SCORE_1MIN;
	for (i=0; i < a1->alphasize; ++i) {
		m = sbp->matrix[ch1 = a1->alist[i]];
		for (j=0; j < a2->alphasize; ++j) {
			score = m[a2->alist[j]];
			if (score <= BLAST_SCORE_MIN || score >= BLAST_SCORE_MAX)
				continue;
			if (sbp->loscore > score)
				sbp->loscore = score;
			if (sbp->hiscore < score)
				sbp->hiscore = score;
		}

		/* bubble-sort residue names by substitution cost for i-th residue */
		Nlm_MemCpy((char *)(o = sbp->order[ch1]), (char *)a2->alist,
					a2->alphasize * sizeof(BLAST_Letter));
		for (q = o+(a2->alphasize-1); q > o; --q) {
			for (p = o; p < q; ++p)
				if (m[*p] < m[*q]) {
					t = *p;
					*p = *q;
					*q = t;
				}
		}
		sbp->maxcost[ch1] = m[*o];
	}
	if (sbp->loscore < BLAST_SCORE_1MIN || sbp->hiscore > BLAST_SCORE_1MAX)
		return blast_errno = BLAST_ERR_SCORE_DOMAIN;

	return BlastScoreChk(sbp->loscore, sbp->hiscore);
}


/*
 * BlastScoreBlkNew -- allocate a score block with rows from alphabet a1
 * and columns from alphabet a2.
 *
 * The substitution matrix gets one row per value in [a1->minval,
 * a1->maxval]; sbp->matrix is offset so it can be indexed directly by a1
 * letter values (sbp->_matrix0 keeps the allocation base).  The order table
 * is sized and offset the same way using a1->letter_min/letter_max.
 * Neither table is filled in here.  When a1 != a2, alphabet maps in both
 * directions are looked up or created; NOTE(review): their NULL results
 * are not checked.
 *
 * Returns the new block, or NULL when a2->maxval exceeds
 * BLAST_SCOREMAT_INDEX_MAX (blast_errno = BLAST_ERR_ALPHAVAL) or when an
 * allocation fails (presumably BlastCalloc/BlastMalloc set blast_errno).
 */
BLAST_ScoreBlkPtr
BlastScoreBlkNew(a1, a2)
	BLAST_AlphabetPtr	a1, a2;
{
	BLAST_ScoreBlkPtr	blk;
	BLAST_ScoreMat	rows;
	BLAST_LetterMat	ord;
	size_t	nrows, nletters;

	/*
	Only a2's range has to be validated: the second matrix dimension is
	fixed, while the first is allocated to fit a1.
	*/
	if (a2->maxval > BLAST_SCOREMAT_INDEX_MAX) {
		blast_errno = BLAST_ERR_ALPHAVAL;
		return NULL;
	}

	blk = (BLAST_ScoreBlkPtr)BlastCalloc(sizeof(*blk));
	if (blk == NULL)
		return NULL;
	blk->a1 = a1;
	blk->a2 = a2;

	nrows = a1->maxval - a1->minval + 1;
	rows = (BLAST_ScoreMat) BlastMalloc(sizeof(*rows)*nrows);
	if (rows == NULL)
		goto fail;
	blk->_matrix0 = rows;
	blk->matrix = rows - a1->minval;

	nletters = a1->letter_max - a1->letter_min + 1;
	ord = (BLAST_LetterMat) BlastMalloc(sizeof(*ord)*nletters);
	if (ord == NULL)
		goto fail;
	blk->_order0 = ord;
	blk->order = ord - a1->letter_min;

	if (a1 != a2) {
		blk->amp12 = BlastAlphaMapFindCreate(a1, a2);
		blk->amp21 = BlastAlphaMapFindCreate(a2, a1);
	}
	return blk;

fail:
	/* releases whatever was allocated so far */
	BlastScoreBlkDestruct(blk);
	return NULL;
}

/*
 * BlastScoreBlkDestruct -- release a score block created by
 * BlastScoreBlkNew().  Frees the matrix and order tables (via their
 * unadjusted base pointers) and the comment list, clears the structure,
 * then frees it.  A NULL sbp is ignored.
 *
 * NOTE(review): sbp->name (set with StrSave by BlastScoreBlkMatchWeights)
 * is not released here; confirm who owns it.
 */
void
BlastScoreBlkDestruct(sbp)
	BLAST_ScoreBlkPtr	sbp;
{
	if (sbp != NULL) {
		/* matrix/order are offset views; free the allocation bases */
		if (sbp->matrix != NULL)
			BlastFree(sbp->_matrix0);
		if (sbp->order != NULL)
			BlastFree(sbp->_order0);

		ValNodeFreeData(sbp->comments);

		/* wipe the block before releasing it, presumably so stale
		   pointers to it show up quickly */
		Nlm_MemSet((CharPtr)sbp, 0, sizeof(*sbp));
		BlastFree(sbp);
	}
}

/*
 * BlastScoreChk -- validate a score range.
 *
 * lo must be negative and hi positive, both inside
 * [BLAST_SCORE_1MIN, BLAST_SCORE_1MAX], and hi - lo must not exceed
 * BLAST_SCORE_RANGE_MAX.
 *
 * Returns BLAST_ERR_NONE, or BLAST_ERR_SCORE_DOMAIN / BLAST_ERR_SCORE_RANGE
 * (also stored in blast_errno).
 */
BLAST_Error
BlastScoreChk(lo, hi)
	BLAST_Score	lo, hi;
{
	int	in_domain;

	in_domain = (lo < 0 && hi > 0 &&
			lo >= BLAST_SCORE_1MIN && hi <= BLAST_SCORE_1MAX);
	if (!in_domain)
		return blast_errno = BLAST_ERR_SCORE_DOMAIN;

	if (hi - lo <= BLAST_SCORE_RANGE_MAX)
		return BLAST_ERR_NONE;
	return blast_errno = BLAST_ERR_SCORE_RANGE;
}

BLAST_ScoreFreqPtr
BlastScoreFreqNew(score_min, score_max)
	BLAST_Score	score_min, score_max;
{
	BLAST_ScoreFreqPtr	sfp;
	BLAST_Score	range;

	if (BlastScoreChk(score_min, score_max) != BLAST_ERR_NONE)
		return NULL;

	sfp = (BLAST_ScoreFreqPtr) BlastCalloc(sizeof(*sfp));
	if (sfp == NULL)
		return NULL;

	range = score_max - score_min + 1;
	sfp->sprob = (double PNTR) BlastCalloc(sizeof(sfp->sprob[0]) * range);
	if (sfp->sprob == NULL) {
		BlastScoreFreqDestruct(sfp);
		return NULL;
	}
	sfp->_sprob0 = sfp->sprob;
	sfp->sprob -= score_min;
	sfp->score_min = score_min;
	sfp->score_max = score_max;
	sfp->obs_min = sfp->obs_max = 0;
	sfp->score_avg = 0.0;
	return sfp;
}

/*
 * BlastScoreFreqDestruct -- release a table created by BlastScoreFreqNew().
 * Frees the probability array via its unadjusted base pointer, clears the
 * structure, then frees it.  A NULL sfp is ignored, matching
 * BlastScoreBlkDestruct().
 */
void
BlastScoreFreqDestruct(sfp)
	BLAST_ScoreFreqPtr	sfp;
{
	if (sfp == NULL)
		return;

	/* sprob is an offset view; free the allocation base */
	if (sfp->_sprob0 != NULL)
		BlastFree(sfp->_sprob0);
	Nlm_MemSet((CharPtr)sfp, 0, sizeof(*sfp));
	BlastFree(sfp);
}

/*
 * BlastScoreFreqCalc -- compute the probability of each substitution score.
 *
 * sfp->sprob[s] becomes the sum of rfp1->prob[c1] * rfp2->prob[c2] over all
 * letter pairs (c1 in sbp->a1, c2 in sbp->a2) whose matrix score is s.
 * Entries below sbp->loscore (e.g. unset BLAST_SCORE_MIN cells) or above
 * sfp->score_max are skipped.  sfp->obs_min/obs_max become the lowest and
 * highest scores with non-zero probability (both BLAST_SCORE_MIN if none).
 * If the total probability exceeds 1e-5, the table is renormalized to sum
 * to 1 (unless already within 1e-12 of 1), and sfp->score_avg is set to the
 * expected score.  Otherwise score_avg is 0.
 *
 * Presumably sbp->loscore/hiscore come from a prior BlastScoreBlkSort().
 *
 * Returns BLAST_ERR_NONE, BLAST_ERR_ALPHACONFLICT if rfp1/rfp2 are not over
 * sbp's alphabets, or BLAST_ERR_SCORE_RANGE if sfp cannot hold
 * [loscore, hiscore].  Errors are also stored in blast_errno.
 */
BLAST_Error
BlastScoreFreqCalc(sfp, sbp, rfp1, rfp2)
	BLAST_ScoreFreqPtr	sfp;
	BLAST_ScoreBlkPtr	sbp;
	BLAST_ResFreqPtr	rfp1, rfp2;
{
	BLAST_ScoreMat	matrix;
	BLAST_Score	s, obs_min, obs_max;
	BLAST_AlphabetPtr	a1, a2;
	BLAST_Letter	ch1, ch2;
	double	score_sum, score_avg;
	int		i, j, renorm;

	a1 = rfp1->ap;
	a2 = rfp2->ap;
	if (sbp->a1 != a1 || sbp->a2 != a2)
		return blast_errno = BLAST_ERR_ALPHACONFLICT;

	if (sbp->loscore < sfp->score_min || sbp->hiscore > sfp->score_max)
		return blast_errno = BLAST_ERR_SCORE_RANGE;

	for (s = sfp->score_min; s <= sfp->score_max; ++s)
		sfp->sprob[s] = 0.;

	matrix = sbp->matrix;

	for (i = 0; i < a1->alphasize; ++i) {
		ch1 = a1->alist[i];
		for (j = 0; j < a2->alphasize; ++j) {
			ch2 = a2->alist[j];
			s = matrix[ch1][ch2];
			/* Skip unset cells and anything sprob cannot index, e.g. a
			   BLAST_SCORE_MAX sentinel (which BlastScoreBlkSort ignores). */
			if (s < sbp->loscore || s > sfp->score_max)
				continue;
			sfp->sprob[s] += rfp1->prob[ch1] * rfp2->prob[ch2];
		}
	}

	score_sum = 0.;
	obs_min = obs_max = BLAST_SCORE_MIN;	/* BLAST_SCORE_MIN = "none seen yet" */
	for (s = sfp->score_min; s <= sfp->score_max; ++s) {
		if (sfp->sprob[s] > 0.) {
			score_sum += sfp->sprob[s];
			obs_max = s;
			if (obs_min == BLAST_SCORE_MIN)
				obs_min = s;
		}
	}
	sfp->obs_min = obs_min;
	sfp->obs_max = obs_max;

	score_avg = 0.;
	if (score_sum > 0.00001) {
		/* renormalize only when the total is not already 1 */
		renorm = (score_sum < (1.-1.e-12) || score_sum > (1.+1.e-12));
		for (s = obs_min; s <= obs_max; ++s) {
			if (renorm)
				sfp->sprob[s] /= score_sum;
			score_avg += s * sfp->sprob[s];
		}
	}
	sfp->score_avg = score_avg;

	return BLAST_ERR_NONE;
}

BLAST_Score
BlastScoreMaxAchievable(sbp, sp, start, len)
	BLAST_ScoreBlkPtr	sbp;
	BLAST_StrPtr	sp;
	BLAST_Coord	start, len;
{
	register BLAST_LetterPtr	lp, lpmax;
	register BLAST_ScorePtr	maxcost;
	register BLAST_Score	sum_max, sum;

	if (sp->lpb != 1) {
		blast_errno = BLAST_ERR_ENCODING;
		return -1;
	}

	if (sp->ap != sbp->a1 || sp->ap != sbp->a2) {
		blast_errno = BLAST_ERR_ALPHACONFLICT;
		return -1;
	}

	if (start + len > sp->len) {
		blast_errno = BLAST_ERR_DOMAIN;
		return -1;
	}

	lp = sp->str + start;
	maxcost = sbp->maxcost;

	lpmax = lp + len;
	sum_max = sum = 0;
	while (lp < lpmax) {
		sum += maxcost[*lp++];
		if (sum > sum_max)
			sum_max = sum;
		else
			if (sum < 0)
				sum = 0;
	}
	return sum_max;
}
