#!/usr/bin/env python

"""
Code for running the experiments described at
http://compprag.christopherpotts.net/swda-clausetyping.html

The main method re-runs all the experiments sequentially, printing
summaries of the results to standard output.
"""

__author__ = "Christopher Potts"
__copyright__ = "Copyright 2011, Christopher Potts"
__credits__ = []
__license__ = "Creative Commons Attribution-NonCommercial-ShareAlike 3.0 Unported License: http://creativecommons.org/licenses/by-nc-sa/3.0/"
__version__ = "1.0"
__maintainer__ = "Christopher Potts"
__email__ = "See the author's website"

######################################################################
		
import re, csv
from collections import defaultdict
from operator import itemgetter
from random import shuffle
from nltk.tree import Tree
from swda import CorpusReader, Utterance

######################################################################

def root_is_labeled_SQ(tree):
    """
    Say whether tree is rooted at an SQ node of some kind.

    Arguments:
    tree -- an object whose root label is available as the string tree.node

    Value: True if the root label begins with 'SQ' (any suffix after
    'SQ' is accepted), else False.
    """
    root_label = tree.node
    return root_label.startswith('SQ')

def has_leftmost_aux_daughter(tree):
    """
    Say whether the first "relevant" daughter of tree is an aux tree.

    A daughter is relevant if its label ends in NP or SBAR, begins
    with V, or contains MD.  All other daughters are ignored, so an
    auxiliary may be preceded by, e.g., adverbial material and still
    count as leftmost.

    Arguments:
    tree -- an iterable of daughters, each with a string .node label

    Value: True if there is at least one relevant daughter and the
    first one satisfies is_aux_tree, else False.
    """
    relevant_daughters = []
    for daughter in tree:
        if re.search(r'(NP$|SBAR$|^V|MD)', daughter.node):
            relevant_daughters.append(daughter)
    if not relevant_daughters:
        return False
    # Only the earliest relevant daughter decides the outcome.
    return bool(is_aux_tree(relevant_daughters[0]))

def is_aux_tree(tree):
    """
    Say whether tree has the form (V*|MD aux).

    Three conditions must all hold:
    1. tree satisfies is_preterminal;
    2. its label (tree.node) begins with V or contains MD;
    3. its first child (tree[0]) is one of the listed auxiliary forms.
    Both pattern matches ignore case.

    Arguments:
    tree -- the candidate tree

    Value: True if all three conditions hold, else False.
    """
    if not is_preterminal(tree):
        return False
    label_pattern = re.compile(r'(^V|MD)', re.I)
    if not label_pattern.search(tree.node):
        return False
    # NOTE(review): 'Wo' presumably covers the first token of a split
    # "won't" -- confirm against the corpus tokenization.
    word_pattern = re.compile(r'^(Is|Are|Was|Were|Have|Has|Had|Can|Could|Shall|Should|Will|Would|May|Might|Must|Do|Does|Did|Wo)$', re.I)
    return bool(word_pattern.search(tree[0]))

def is_preterminal(tree):
    """
    Say whether tree has the form (PARENT CHILD).

    The check is that tree is a Tree instance and that tree.subtrees()
    yields exactly one tree (presumably tree itself, i.e., none of its
    children are themselves trees -- confirm against the NLTK version
    in use).

    Arguments:
    tree -- any object; non-Tree values are simply rejected

    Value: True if the check succeeds, else False.
    """
    if not isinstance(tree, Tree):
        return False
    subtree_count = sum(1 for _ in tree.subtrees())
    return subtree_count == 1

def is_polar_interrogative(tree):
    """
    Say whether tree looks like a polar interrogative.

    Arguments:
    tree -- the tree to classify

    Value: True if tree is rooted at an SQ node (root_is_labeled_SQ)
    and has a leftmost auxiliary daughter (has_leftmost_aux_daughter),
    else False.
    """
    # Cheap root-label test first; the daughters are inspected only for
    # SQ-rooted trees.
    if not root_is_labeled_SQ(tree):
        return False
    return bool(has_leftmost_aux_daughter(tree))

######################################################################

def inspect_trees(characterizing_function, output_filename, num=100):
    """
    Write out up to num trees that characterizing_function says True
    to, and up to num trees that characterizing_function says False to.

    Only utterances for which utt.tree_is_perfect_match() holds are
    considered, and only the first tree of each (utt.trees[0]) is used.
    The trees in each group are shuffled before the first num are taken.

    Arguments:
    characterizing_function - a function that says True or False for trees
    output_filename (str) -- filename for the output
    num (int) -- the number of randomly selected trees to print out (default: 100)

    Value: stores the trees in output_filename; each entry is a
    separator line, the label ('True' or 'False'), and the tree.
    """
    yes_trees = []
    no_trees = []
    for utt in CorpusReader('swda').iter_utterances(display_progress=True):
        if utt.tree_is_perfect_match():
            tree = utt.trees[0]
            if characterizing_function(tree):
                yes_trees.append(tree)
            else:
                no_trees.append(tree)
    # Collect the pieces and join once, rather than growing a string.
    pieces = []
    for trees, label in ((yes_trees, 'True'), (no_trees, 'False')):
        shuffle(trees)
        for t in trees[:num]:
            pieces.append('======================================================================\n')
            pieces.append(label + "\n")
            # Assumes t.pprint() returns the formatted tree as a string
            # (as the original code did) -- confirm for the installed NLTK.
            pieces.append(t.pprint() + "\n")
    # The with-block guarantees the file is flushed and closed.
    with open(output_filename, 'w') as output_file:
        output_file.write(''.join(pieces))
           
######################################################################            
 
def inversion_and_qy_to_csv(tree_classifying_function, output_filename):
    """
    Runs a clause-typing experiment, by relating dialog act tags to
    the values given by the supplied function.  The results are put
    into CSV format:

        Tag, PolarIntClauseType
        t1   func_value1
        t2   func_value2
        ...

    Only utterances for which utt.tree_is_perfect_match() holds are
    included, and the function is applied to the first tree of each
    (utt.trees[0]).

    Arguments:
    tree_classifying_function -- any function from nltk.tree.Tree
                                 objects with outputs that can be
                                 sensibly stringified
    output_filename (str) -- the output CSV filename

    Value: stores a header row plus one row per utterance in
    output_filename.
    """
    rows = []
    for utt in CorpusReader('swda').iter_utterances(display_progress=True):
        if utt.tree_is_perfect_match():
            # Capitalize this string so that R treats it as a boolean:
            is_polar_int = str(tree_classifying_function(utt.trees[0])).upper()
            rows.append([utt.act_tag, is_polar_int])
    # The with-block guarantees that all rows are flushed to disk and the
    # file is closed, instead of relying on garbage collection.
    with open(output_filename, 'w') as csv_file:
        csvwriter = csv.writer(csv_file)
        csvwriter.writerow(['Tag', 'PolarIntClauseType'])
        csvwriter.writerows(rows)

######################################################################
    
def inspect_false_negatives(output_filename, num=50):
    """
    Identify false negatives and print num randomly selected instances
    to output_filename for inspection.

    A false negative is an utterance whose act tag is 'qy' but whose
    first tree (utt.trees[0]) is not classified as a polar
    interrogative by is_polar_interrogative.  Only utterances for which
    utt.tree_is_perfect_match() holds are considered.

    Arguments:
    output_filename (str) -- filename for the output
    num (int) -- the maximum number of trees to write (default: 50)

    Value: stores the selected trees in output_filename, each preceded
    by a separator line.
    """
    trees = []
    for utt in CorpusReader('swda').iter_utterances(display_progress=False):
        if utt.tree_is_perfect_match():
            if utt.act_tag == 'qy' and not is_polar_interrogative(utt.trees[0]):
                trees.append(utt.trees[0])
    shuffle(trees)
    pieces = []
    for tree in trees[:num]:
        pieces.append('======================================================================\n')
        # Assumes tree.pprint() returns the formatted tree as a string
        # (as the original code did) -- confirm for the installed NLTK.
        pieces.append(tree.pprint() + '\n')
    # Write to the caller-supplied filename (previously a hard-coded name
    # was used and the argument was ignored), and close the file reliably.
    with open(output_filename, 'w') as output_file:
        output_file.write(''.join(pieces))

######################################################################    

if __name__ == '__main__':
    # Runs the main experiment and generates the evaluation files.
    #
    # Each print call takes a single parenthesized string, which prints
    # the same text under Python 2 and also parses under Python 3.

    print("Test the polar interrogatives function by writing 100 True and 100 False trees to swda-clausetyping-polar-int-test.txt:")
    inspect_trees(is_polar_interrogative, 'swda-clausetyping-polar-int-test.txt', num=100)

    print("Generate a CSV file swda-clausetyping-results.csv of the evaluation results (especially useful studying the pattern of false positives):")
    inversion_and_qy_to_csv(is_polar_interrogative, 'swda-clausetyping-results.csv')

    print("Inspect the false negative trees:")
    inspect_false_negatives('swda-clausetyping-fn-trees.txt', num=50)

    print("Experiment complete; let the analysis begin...")
    


