#!/usr/bin/python
# -*- coding: iso-8859-15 -*-

import sys
import copy
import numpy
from numpy import array, zeros, ones, mean, average, tile, nonzero, \
     newaxis, isnan, isfinite
import scipy.sparse
from scipy.sparse import lil_matrix
import time
import munkres

"""
Co-occurrence based Metric for Morphological Analysis (CoMMA)

Version: 1.0
Last modified: 2012-01-05

----------------------------------------------------------------------

Copyright (C) 2011-2012 Sami Virpioja <sami.virpioja@aalto.fi>

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <http://www.gnu.org/licenses/>.

----------------------------------------------------------------------

This program calculates the precision and recall of predicted
morpheme labels when compared to the reference labels of a gold
standard analysis.

The users of this program are requested to refer to the following
article in their scientific publications:

  Sami Virpioja, Ville T. Turunen, Sebastian Spiegler, Oskar Kohonen,
  and Mikko Kurimo. Empirical comparison of evaluation methods for
  unsupervised learning of morphology. Traitement Automatique des
  Langues, 52(2), 2011.

To run the program, the Munkres module by Brian M. Clapper is
required. See

* http://pypi.python.org/pypi/munkres/ or
* http://software.clapper.org/munkres/

There are various options available for the calculations. The versions
described in the article above correspond to the following options:

* CoMMA-B0: --bestalts
* CoMMA-B1: --bestalts --diagonal
* CoMMA-S0: --strictalts
* CoMMA-S1: --strictalts --diagonal

The default (no options given) is CoMMA-S0.

Input files should be in a format similar to Morpho Challenge results:
the word and its analyses are separated by a tab character, any
alternative analyses by a comma and a space, and the labels of each
analysis by a single space. Lines starting with '#' are ignored. Example:

brushes	brush_N +3SG, brush_N +PL

The effect of each word on the overall precision and recall can be
modified by giving numerical weights (--weightFile). The file should
contain one word and its weight per line, separated by whitespace.
"""

class Analysis:
    """Store the (possibly ambiguous) morphological analyses of words.

    Attributes:
        data: dict mapping each word to a list of alternative analyses;
            each analysis is a dict mapping a morpheme label to the number
            of times it occurs in that analysis.
        wlist: list of the loaded words, in the order they were loaded
            (or in the reference's order after equalize()).
        morph_count: dict mapping each morpheme label to its total number
            of occurrences over all analyses read by load().
    """

    def __init__(self):
        self.data = {}
        self.wlist = []
        self.morph_count = {}

    def get_data(self):
        """Return the word -> list of analyses dictionary"""
        return self.data

    def get_words(self):
        """Return the list of words"""
        return self.wlist

    def get_size(self):
        """Return the number of words"""
        return len(self.wlist)

    def get_analysis(self, w):
        """Return the list of alternative analyses of word w"""
        return self.data[w]

    def load(self, filename, wdict=None):
        """Load analyses from the given file.

        Each line has the form "word<TAB>analyses", where alternative
        analyses are separated by ", " and the morpheme labels of one
        analysis by whitespace. Lines starting with '#' and blank lines
        are skipped.

        Given a dictionary-like object wdict, load only the words found
        in it.

        Raises ValueError if a non-blank, non-comment line does not
        contain exactly one tab character.
        """
        # 'with' guarantees the file is closed even if parsing fails.
        with open(filename, 'r') as fobj:
            for line in fobj:
                if line.startswith('#') or not line.strip():
                    continue
                w, a = line.split("\t")
                if wdict is not None and w not in wdict:
                    continue
                self.wlist.append(w)
                self.data[w] = []
                for mstr in a.split(", "):
                    d = {}
                    # split() without arguments also drops the newline
                    for m in mstr.split():
                        d[m] = d.get(m, 0) + 1
                        self.morph_count[m] = self.morph_count.get(m, 0) + 1
                    self.data[w].append(d)

    def equalize(self, ref):
        """Restrict this analysis to the words of another Analysis.

        Drops the analyses of words that are not in ref and replaces the
        word list with a copy of ref's word list, so the word order
        follows ref. morph_count is left unchanged.
        """
        for w in self.wlist:
            if w not in ref.data:
                # pop() tolerates words listed twice in the input file
                self.data.pop(w, None)
        self.wlist = copy.copy(ref.wlist)

    def common_morphs(self, w1, w2):
        """Return the number of common morphs for each pair of analyses.

        The result is a list of lists: entry [i][j] is the size of the
        multiset intersection of the i-th analysis of w1 and the j-th
        analysis of w2.
        """
        cl = []
        for mdict1 in self.data[w1]:
            cll = []
            for mdict2 in self.data[w2]:
                s = 0
                for m in mdict1:
                    if m in mdict2:
                        s += min(mdict1[m], mdict2[m])
                cll.append(s)
            cl.append(cll)
        return cl

    def morph_product(self, w1, w2, wfunc=None):
        """Return weighted count products for each pair of analyses.

        Entry [i][j] is the sum over the morphs m shared by the i-th
        analysis of w1 and the j-th analysis of w2 of
        count1(m) * count2(m) / wfunc(morph_count[m]), where the divisor
        is 1.0 when wfunc is None.
        """
        cl = []
        for mdict1 in self.data[w1]:
            cll = []
            for mdict2 in self.data[w2]:
                s = 0.0
                for m in mdict1:
                    if m in mdict2:
                        if wfunc is None:
                            n = 1.0
                        else:
                            n = wfunc(self.morph_count[m])
                        s += mdict1[m] * float(mdict2[m]) / n
                cll.append(s)
            cl.append(cll)
        return cl

    def word_similarity(self, w1, w2, sim_func):
        """Return the similarity matrix of two words for the given measure.

        sim_func is one of 'intersection', 'product', 'wproduct',
        'sqwproduct' or 'logwproduct'. Raises ValueError (a subclass of
        StandardError on Python 2) for any other value.
        """
        if sim_func == 'intersection':
            return self.common_morphs(w1, w2)
        elif sim_func == 'product':
            return self.morph_product(w1, w2, None)
        elif sim_func == 'wproduct':
            return self.morph_product(w1, w2, lambda x: x)
        elif sim_func == 'sqwproduct':
            return self.morph_product(w1, w2, lambda x: x**2)
        elif sim_func == 'logwproduct':
            return self.morph_product(w1, w2, lambda x: 1.0+numpy.log(x))
        else:
            raise ValueError("Unknown similarity measure '%s'\n" % sim_func)
            raise StandardError("Unknown similarity measure '%s'\n" % sim_func)

def analysis2matrix(analysis, diags=False, sim_func='intersection'):
    """Build a symmetric word-by-word co-occurrence matrix.

    Cell (i, j) holds the largest similarity between any alternative
    analysis of word i and any alternative analysis of word j, i.e. the
    best matching alternatives are selected for each word pair. The
    diagonal is filled only when diags is true. Returns a CSR matrix.
    """
    words = analysis.get_words()
    size = len(words)
    matrix = lil_matrix((size, size), dtype=float)
    for row, first in enumerate(words):
        # Only the upper triangle is computed; it is mirrored below.
        start = row if diags else row + 1
        for col in range(start, size):
            sims = analysis.word_similarity(first, words[col], sim_func)
            best = max(max(alts) for alts in sims)
            matrix[row, col] = best
            matrix[col, row] = best
    return matrix.tocsr()

def recall_eval(Mpre, Mref, saxis=1, weights=None):
    """Calculate the recall of matrix Mpre with respect to matrix Mref

    Only the nonzero cells of Mref are evaluated. For each slice along
    saxis (rows for saxis=1) that has at least one nonzero cell in Mref,
    the slice recall is one minus the mean relative shortfall
    max(Mref - Mpre, 0) / Mref over those cells. Slices without nonzero
    reference cells are ignored.

    Precision can be calculated by switching the parameters.

    Parameters:
        Mpre, Mref: matrices of equal shape (sparse or dense); not modified.
        saxis: axis summed over per evaluated slice.
        weights: optional per-slice weights (any shape with one value per
            slice, e.g. an (n, 1) array); enables a weighted average.

    Returns (recall, number of evaluated slices), or (1.0, 0) if no
    slice could be evaluated.
    """
    # Work on CSR copies so that the callers' matrices are never modified
    # and explicitly stored zeros are not counted as co-occurrences.
    ref = scipy.sparse.csr_matrix(Mref, dtype=float, copy=True)
    ref.eliminate_zeros()
    pre = scipy.sparse.csr_matrix(Mpre, dtype=float)

    diff = ref - pre
    err = (abs(diff) + diff) / 2.0           # recall errors: max(ref-pre, 0)
    # Normalize per cell by multiplying with 1/ref on ref's pattern. This
    # avoids sparse/sparse division, which in recent SciPy versions
    # returns a dense matrix with NaN outside the divisor's pattern.
    inv_ref = ref.copy()
    inv_ref.data = 1.0 / inv_ref.data
    err = err.multiply(inv_ref)

    hits = ref.copy()
    hits.data[:] = 1.0                       # indicator of co-occurrences
    tot = numpy.asarray(hits.sum(saxis), dtype=float).ravel()
    errsum = numpy.asarray(err.sum(saxis), dtype=float).ravel()
    with numpy.errstate(divide='ignore', invalid='ignore'):
        recv = (tot - errsum) / tot          # recall vector; NaN if tot == 0

    if recv.size == 1:
        rec = recv[0]
        if not isfinite(rec):
            return 1.0, 0
        return rec, 1

    valid = isfinite(recv)
    n = int(valid.sum())
    if n == 0:
        return 1.0, 0
    if weights is None:
        rec = mean(recv[valid])
    else:
        wvec = numpy.asarray(weights, dtype=float).ravel()
        rec = average(recv[valid], weights=wvec[valid])
    return rec, n

def opt_alt_match(pre, ref):
    """Calculate recall for the analyses of a single word.

    pre and ref hold one row per alternative analysis (predicted and
    reference, e.g. as built by recall_eval_amb). The Munkres algorithm
    pairs reference rows with predicted rows so that the summed recall is
    maximal (cost = 1 - recall). Reference alternatives for which
    recall_eval finds nothing to evaluate are excluded from the average.
    Precision can be calculated by switching the parameters.

    Returns (recall, number of counted reference alternatives, list of
    (ref_row, pre_row) pairs); recall is 1.0 when nothing was counted.
    """
    n_ref = ref.shape[0]
    n_pre = pre.shape[0]
    size = max(n_ref, n_pre)
    # Padding cells and zero-hit cells keep the maximal cost of 1.0.
    cost = [[1.0] * size for _ in range(size)]
    empty_rows = set()
    for r in range(n_ref):
        for c in range(n_pre):
            value, count = recall_eval(pre[c, :], ref[r, :], 1)
            if count == 0:
                empty_rows.add(r)
            else:
                cost[r][c] = 1.0 - value
    pairs = munkres.Munkres().compute(cost)
    total = sum(1.0 - cost[r][c] for r, c in pairs if r not in empty_rows)
    counted = n_ref - len(empty_rows)
    if counted == 0:
        return 1.0, 0, pairs
    return total / counted, counted, pairs

def recall_eval_amb(Apre, Aref, diags=False, weights=None,
                    sim_func='intersection'):
    """Evaluation using matching of alternative analyses separately
    for best precision and recall

    For every word w of Aref, one row per alternative analysis of w is
    built (for both Aref and Apre). Entry j of a row is the best similarity
    between that alternative and any alternative of the j-th word. The
    alternatives are then paired by opt_alt_match. Swap Apre and Aref to
    obtain precision.

    weights, if given, holds one weight per word of Aref.get_words()
    (e.g. an (n, 1) array).

    Returns (average recall, number of counted words) or, with weights,
    (weighted average, summed weight of counted words); (1.0, 0) if no
    word could be counted.
    """
    wlist = Aref.get_words()
    wnum = len(wlist)
    # Flatten once so that weights[i] is a scalar, not a 1-element array,
    # and avoid 'weights == None', which is elementwise for ndarrays.
    wvec = None
    if weights is not None:
        wvec = numpy.asarray(weights, dtype=float).ravel()
    total = 0
    s = 0.0
    for i, w in enumerate(wlist):
        ref_altnum = len(Aref.get_analysis(w))
        pre_altnum = len(Apre.get_analysis(w))
        ref_v = lil_matrix((ref_altnum, wnum), dtype=float)
        pre_v = lil_matrix((pre_altnum, wnum), dtype=float)
        for j, w2 in enumerate(wlist):
            if not diags and i == j:
                continue
            refsim = Aref.word_similarity(w, w2, sim_func)
            presim = Apre.word_similarity(w, w2, sim_func)
            # Best match per alternative of w (list works on Python 2 and 3)
            for k, val in enumerate([max(row) for row in refsim]):
                ref_v[k, j] = val
            for k, val in enumerate([max(row) for row in presim]):
                pre_v[k, j] = val
        t, n, I = opt_alt_match(pre_v.tocsr(), ref_v.tocsr())
        if n > 0:
            if wvec is None:
                s += t
                total += 1
            else:
                s += wvec[i] * t
                total += wvec[i]
    if total == 0:
        return 1.0, 0
    return s / total, total

def f_opt_alt_match(pre, ref):
    """Calculate precision and recall for one word (strict matching).

    pre and ref hold one row per predicted / reference alternative
    analysis. For every (ref_row, pre_row) pair, the precision and recall
    are computed with recall_eval (0.0 if nothing could be evaluated).
    The Munkres algorithm then pairs the alternatives so that the summed
    F-measure is maximal (cost = 1 - F). The matched precisions and
    recalls are averaged over the alternatives whose rows have a positive
    sum, or set to 1.0 if there are none.

    Returns (precision, recall, counted predicted alternatives, counted
    reference alternatives, list of (ref_row, pre_row) pairs).
    """
    n_ref = ref.shape[0]
    n_pre = pre.shape[0]
    size = max(n_ref, n_pre)
    pre_nnz = sum(1 for k in range(n_pre) if pre[k, :].sum() > 0)
    rec_nnz = sum(1 for k in range(n_ref) if ref[k, :].sum() > 0)

    scores = {}
    cost = [[1.0] * size for _ in range(size)]
    for r in range(size):
        for c in range(size):
            if r < n_ref and c < n_pre:
                p_val, p_cnt = recall_eval(ref[r, :], pre[c, :], 1)
                if p_cnt == 0:
                    p_val = 0.0
                r_val, r_cnt = recall_eval(pre[c, :], ref[r, :], 1)
                if r_cnt == 0:
                    r_val = 0.0
            else:
                # padding cell of the square cost matrix
                p_val = r_val = 0.0
            scores[(r, c)] = (p_val, r_val)
            denom = p_val + r_val
            fmeas = 0.0 if denom == 0 else 2.0 * p_val * r_val / denom
            cost[r][c] = 1.0 - fmeas

    pairs = munkres.Munkres().compute(cost)
    pre_sum = sum(scores[(r, c)][0] for r, c in pairs)
    rec_sum = sum(scores[(r, c)][1] for r, c in pairs)
    precision = pre_sum / pre_nnz if pre_nnz != 0 else 1.0
    recall = rec_sum / rec_nnz if rec_nnz != 0 else 1.0
    return precision, recall, pre_nnz, rec_nnz, pairs

def global_eval_amb(Apre, Aref, diags=False, weights=None,
                    sim_func='intersection'):
    """Evaluation using strict matching of alternative analyses

    For every word w of Aref, one row per alternative analysis of w is
    built (for both Aref and Apre). Entry j of a row is the best similarity
    between that alternative and any alternative of the j-th word. The
    alternatives are then paired by f_opt_alt_match, which maximizes the
    F-measure.

    weights, if given, holds one weight per word of Aref.get_words()
    (e.g. an (n, 1) array).

    Returns (precision, recall, precision word count, recall word count).
    With weights, the counts are summed weights and the scores are
    weighted averages. A score is 1.0 when no word was counted for it.
    """
    wlist = Aref.get_words()
    wnum = len(wlist)
    # Flatten once so that weights[i] is a scalar, not a 1-element array,
    # and avoid 'weights == None', which is elementwise for ndarrays.
    wvec = None
    if weights is not None:
        wvec = numpy.asarray(weights, dtype=float).ravel()
    pre_total = 0
    rec_total = 0
    pre_sum = 0.0
    rec_sum = 0.0
    for i, w in enumerate(wlist):
        ref_altnum = len(Aref.get_analysis(w))
        pre_altnum = len(Apre.get_analysis(w))
        ref_v = lil_matrix((ref_altnum, wnum), dtype=float)
        pre_v = lil_matrix((pre_altnum, wnum), dtype=float)
        for j, w2 in enumerate(wlist):
            if not diags and i == j:
                continue
            refsim = Aref.word_similarity(w, w2, sim_func)
            presim = Apre.word_similarity(w, w2, sim_func)
            # Best match per alternative of w (list works on Python 2 and 3)
            for k, val in enumerate([max(row) for row in refsim]):
                ref_v[k, j] = val
            for k, val in enumerate([max(row) for row in presim]):
                pre_v[k, j] = val
        pre_t, rec_t, pre_n, rec_n, I = f_opt_alt_match(pre_v.tocsr(),
                                                        ref_v.tocsr())
        wt = 1 if wvec is None else wvec[i]
        if pre_n > 0:
            pre_sum += wt * pre_t
            pre_total += wt
        if rec_n > 0:
            rec_sum += wt * rec_t
            rec_total += wt
    pre_r = 1.0 if pre_total == 0 else pre_sum / pre_total
    rec_r = 1.0 if rec_total == 0 else rec_sum / rec_total
    return pre_r, rec_r, pre_total, rec_total


if __name__ == "__main__":
    # Command-line entry point: parse options, load the gold standard and
    # predicted analyses, run the selected evaluation and print the scores.
    from optparse import OptionParser
    parser = OptionParser("Usage: %prog [options]")
    parser.add_option("-g", "--goldFile", dest="goldFile",
                      default = None,
                      help="gold standard analysis file")
    parser.add_option("-p", "--predFile", dest="predFile",
                      default = None,
                      help="predicted analysis file")
    parser.add_option("-b", "--bestalts", dest="bestalts",
                      default = False, action = "store_true",
                      help="reduce alternative analyses to best matching"+
                      " analysis")
    parser.add_option("-a", "--alternatives", dest="alts",
                      default = False, action = "store_true",
                      help="evaluate each alternative analysis separately"+
                      " for precision and recall")
    parser.add_option("-s", "--strictalts", dest="strictalts",
                      default = True, action = "store_true",
                      help="strict matching of alternative analyses based"+
                      " on F-measure (default)")
    parser.add_option("-d", "--diagonal", dest="diag",
                      default = False, action = "store_true",
                      help="include diagonals of similarity matrix")
    parser.add_option("-m", "--measure", dest="similarity",
                      default='intersection',
                      help="Word similarity measure: "+
                      "intersection (default), product, wproduct, "+
                      "sqwproduct or logwproduct")
    parser.add_option("-w", "--weightFile", dest="weightFile",
                      default = None,
                      help="read word weights from file")
    (options, args) = parser.parse_args()

    if options.goldFile is None or options.predFile is None:
        parser.print_help()
        sys.exit()

    ts = time.time()

    if options.weightFile is not None:
        wdict = {}
        with open(options.weightFile, 'r') as fobj:
            for line in fobj:
                fields = line.split()
                if not fields:
                    continue  # tolerate blank lines
                word, weight = fields
                wdict[word] = float(weight)
    else:
        wdict = None

    goldstd = Analysis()
    goldstd.load(options.goldFile, wdict)
    predicted = Analysis()
    # Only predictions for gold standard words are loaded; the gold
    # standard is then restricted to (and ordered like) the predictions.
    predicted.load(options.predFile, goldstd.get_data())
    goldstd.equalize(predicted)

    if wdict is not None:
        # One weight per word, aligned with the evaluated word order.
        wlist = goldstd.get_words()
        weights = zeros((len(wlist), 1))
        for i in range(len(wlist)):
            word = wlist[i]
            weights[i] = wdict[word]
    else:
        weights = None

    print("# Gold standard file: %s" % options.goldFile)
    print("# Predictions file  : %s" % options.predFile)
    print("# Evaluation options:")
    print("# - word similarity measure (--measure): %s" % options.similarity)
    if options.weightFile is not None:
        print("# - word weights loaded from %s" % options.weightFile)
    if options.diag:
        print("# - word self-hits included (--diagonal)")
    else:
        print("# - word self-hits excluded")
    if options.alts:
        print("# - alternative analyses evaluated separately (--alternatives)")
        rec, n = recall_eval_amb(predicted, goldstd, options.diag, weights,
                                 options.similarity)
        print("# Recall based on %s words" % n)
        pre, n = recall_eval_amb(goldstd, predicted, options.diag, weights,
                                 options.similarity)
        print("# Precision based on %s words" % n)
    elif options.bestalts:
        print("# - alternative analyses reduced to best matching analysis "+
              "(--bestalts)")
        goldM = analysis2matrix(goldstd, options.diag, options.similarity)
        predM = analysis2matrix(predicted, options.diag, options.similarity)
        rec, n = recall_eval(predM, goldM, weights=weights)
        print("# Recall based on %s words" % n)
        pre, n = recall_eval(goldM, predM, weights=weights)
        print("# Precision based on %s words" % n)
    elif options.strictalts:
        print("# - strict matching of alternative analyses (--strictalts)")
        pre, rec, pren, recn = global_eval_amb(predicted, goldstd,
                                               options.diag, weights,
                                               options.similarity)
        print("# Recall based on %s words" % recn)
        print("# Precision based on %s words" % pren)

    te = time.time()
    print("# Evaluation time: %.2fs" % (te-ts))
    print("")

    # Harmonic mean written as 2PR/(P+R) so that a zero precision or
    # recall yields F = 0 instead of a ZeroDivisionError.
    if pre + rec == 0:
        f = 0.0
    else:
        f = 2.0 * pre * rec / (pre + rec)
    print("precision: %s" % pre)
    print("recall   : %s" % rec)
    print("fmeasure : %s" % f)

