#!/usr/bin/env python

"""
The Team Banana Wugs model for predicting Implicit coherence relations
in the Penn Discourse Treebank 2.

For background, see http://compprag.christopherpotts.net/.

This page implements the model for predicting Implicit coherence
relations that was developed by Team Banana Wugs in class on July 29
and then implemented by Chris.

FEATURE FUNCTIONS:

The feature function names all end with _feature or _features. Each
one takes a datum as input and returns a dict mapping feature names
to values, which can be boolean, int, or float.
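
For example, tree_height_features(datum) returns a dict like
{'Arg1_Max_Tree_Height': 9, 'Arg2_Max_Tree_Height': 7}, and
negation_pairs_feature(datum) returns a dict like
{('neg', 'non-neg'): True}, whose single key is a tuple of values.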

These feature functions depend on review_functions.py and
swda_experiment_clausetyping.py.

The function banana_wugs_feature_function() pools all these
dictionaries into a single function on datum instances. This is the
``theory'' of PDTB data for predicting Implicit relations.

DEMO:

The bottom of the file defines a main block that displays simplified
text versions of the training set examples along with the features
determined by banana_wugs_feature_function(). Thus, running

python pdtb_competition_team_banana_wugs.py

will show you what the functions are saying about the training
examples.  For this to work, you need to have the training file
'pdtb-competition-implicit-train-indices.pickle' and the corpus file
'pdtb2.csv' in the current directory along with this file.
"""

######################################################################

__author__ = "Christopher Potts and the Banana Wugs team"
__copyright__ = "Copyright 2011, Christopher Potts"
__credits__ = []
__license__ = "Creative Commons Attribution-NonCommercial-ShareAlike 3.0 Unported License: http://creativecommons.org/licenses/by-nc-sa/3.0/"
__version__ = "1.0"
__maintainer__ = "Christopher Potts"
__email__ = "See the author's website"

######################################################################

import os
import re
import pickle
from math import log
from itertools import product
import nltk.tree
try:
    import pdtb
except ImportError:
    raise Exception('Please get pdtb.py from the course website and put it in the current directory.')
try:
    import review_functions # for get_all_imdb_scores()
except ImportError:
    raise Exception('Please get review_functions.py from the course website and put it in the current directory.')
try:
    import swda_experiment_clausetyping # for is_preterminal()
except ImportError:
    raise Exception('Please get swda_experiment_clausetyping.py from the course website and put it in the current directory.')
import pdtb_competition_model_to_csv

######################################################################

# Sentiment scoring for arg_sentiment_features().
review_scorer_filename = 'imdb-words-assess.csv'
if not os.path.exists(review_scorer_filename):
    raise Exception('Please get %s from the course website and put it in the current directory.' % review_scorer_filename)
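# REVIEW_SCORER maps each scored word to a dict that includes its logistic
# regression coefficient ('Coef') and associated p-value ('P'), as used by
# arg_sentiment_features() below.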
REVIEW_SCORER = review_functions.get_all_imdb_scores(review_scorer_filename)

NEGATIONS = set(['never', 'not', "n't", 'neither'])

######################################################################

def negation_pairs_feature(datum):
    """    
    Banana Wugs: return negation sets (Pos, Pos), (Neg, Neg), (Pos, Neg), (Neg, Pos)
    
    CP interpretation: return the one-membered dictionary {(val1,
    val2): True} where val1 is the return value of negated() applied
    to the tokenized words in Arg1 and val2 is the return value of
    negated() applied to the tokenized words in Arg2. Boolean-valued.

    {neg,non-neg} x {neg,non-neg} -> True    
    """    
    def negated(words):
        """Return neg if words contains a member of NEGATIONS, else non-neg."""
        words = [w.lower() for w in words]
        if set(words) & NEGATIONS: # Python treats emptyset as False, non-empty as True.
            return 'neg'
        else:
            return 'non-neg'
    arg1_val = negated(datum.arg1_words())
    arg2_val = negated(datum.arg2_words())
    return {(arg1_val, arg2_val): True}

def sentential_negation_feature(datum):
    """
    Banana Wugs: return T if the sentential negation is imbalanced, else F

    CP interpretation: define sentential negation as a verbal node
    with a negation daughter. The feature value is True iff the two
    Args agree, i.e., both have or both lack sentential negation.
    
    Sentential_Negation_Balanced -> {True, False}
    """
    def sentential_negation(trees):
        """
        Return True if there is a tree t in trees with a daughter that
        is a verbal node whose own daughter is the preterminal for a
        word in NEGATIONS. Return False if no such t is found.
        """
        pred_re = re.compile(r'^V', re.I)
        for tree in trees:
            for subtree in tree:
                if isinstance(subtree, nltk.tree.Tree) and pred_re.search(subtree.node):
                    for daught in subtree:
                        if swda_experiment_clausetyping.is_preterminal(daught) and daught[0] in NEGATIONS:
                            return True
        return False
    arg1_val = sentential_negation(datum.Arg1_Trees)
    arg2_val = sentential_negation(datum.Arg2_Trees)
    return {'Sentential_Negation_Balanced': arg1_val == arg2_val}

def constituent_negation_feature(datum):
    """
    Banana Wugs: return T if the constituent negation is imbalanced, else F

    CP interpretation: define constituent negation as a nonverbal node
    with a negation daughter. The feature value is True iff the two
    Args agree, i.e., both have or both lack constituent negation.
    
    Constituent_Negation_Balanced -> {True, False}
    """
    def constituent_negation(trees):
        """
        Return True if there is a tree t in trees with a daughter that
        is a non-verbal node whose own daughter is the preterminal
        for a word in NEGATIONS. Return False if no such t is found.
        """
        pred_re = re.compile(r'^V', re.I)
        for tree in trees:
            for subtree in tree:
                if isinstance(subtree, nltk.tree.Tree) and not pred_re.search(subtree.node):
                    for daught in subtree:
                        if swda_experiment_clausetyping.is_preterminal(daught) and daught[0] in NEGATIONS:
                            return True
        return False
    arg1_val = constituent_negation(datum.Arg1_Trees)
    arg2_val = constituent_negation(datum.Arg2_Trees)
    return {'Constituent_Negation_Balanced': arg1_val == arg2_val}

def arg_sentiment_features(datum, threshold=0.1):
    """
    Banana Wugs: return sum of sentiment value for both arguments
    
    CP interpretation: for each Arg, the value is the sum of the IMDB
    regression coefficients for its lemmas whose p-values are at or
    below threshold.

    Arg1_Coef_Sum -> float
    Arg2_Coef_Sum -> float
    """
    def arg_sentiment_value(arg_pos, threshold=0.1):
        """
        Sum the IMDB logistic regression coefficients for the lemmas in
        arg_pos whose p-values are at or below threshold (default: 0.1).
        """
        score = 0.0
        for lem in arg_pos:
            if lem in REVIEW_SCORER and REVIEW_SCORER[lem]['P'] <= threshold:
                score += REVIEW_SCORER[lem]['Coef']
        return score
    feats = {}
    feats['Arg1_Coef_Sum'] = arg_sentiment_value(datum.arg1_pos(lemmatize=True))
    feats['Arg2_Coef_Sum'] = arg_sentiment_value(datum.arg2_pos(lemmatize=True))
    return feats

def normalized_shared_token_count_feature(datum):    
    """
    Banana Wugs: normalized number of shared tokens
    
    CP interpretation: lemmatize the words for each argument and turn
    them into a set (this collapses multiple tokens of the same word
    into one). The value is the intersection cardinality divided by
    the union cardinality.
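
    This is the Jaccard similarity of the two lemma sets; e.g., lemma
    sets {a, b, c} and {b, c, d} share 2 of 4 distinct lemmas, for a
    value of 0.5.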

    Normalized_Shared_Token_Count -> float
    """
    arg1_words = set(datum.arg1_words(lemmatize=True))
    arg2_words = set(datum.arg2_words(lemmatize=True))
    return {'Normalized_Shared_Token_Count': float(len(arg1_words & arg2_words)) / float(len(arg1_words | arg2_words))}

def embedded_clause_features(datum):
    """
    Banana Wugs: attitude predicate (embedded CP) returns true

    CP interpretation: for each Arg, create a feature iff it has a VP
    node with an SBAR daughter:

    Arg1_Clausal_Complement -> True
    Arg2_Clausal_Complement -> True

    Might create just one or neither of these.    
    """
    def clausal_complement(arg_trees):
        """Returns True if at least one of the arg_trees has a V node that has a daughter S-BAR, else False."""
        for tree in arg_trees:
            for subtree in tree.subtrees():
                if subtree.node.startswith('V'):
                    for daught in subtree:
                        # Penn Treebank clausal complements are labeled 'SBAR' (not 'S-BAR').
                        if isinstance(daught, nltk.tree.Tree) and daught.node.startswith('SBAR'):
                            return True
        return False
    feats = {}
    if clausal_complement(datum.Arg1_Trees):
        feats['Arg1_Clausal_Complement'] = True
    if clausal_complement(datum.Arg2_Trees):        
        feats['Arg2_Clausal_Complement'] = True
    return feats

def tree_height_features(datum):
    """
    Banana Wugs: max tree height for both arguments

    CP interpretation: for each Arg, create a feature whose value is
    the max height of its trees.

    Arg1_Max_Tree_Height -> int
    Arg2_Max_Tree_Height -> int
    """
    def max_tree_height(trees):
        return max([tree.height() for tree in trees])
    feats = {}
    feats['Arg1_Max_Tree_Height'] = max_tree_height(datum.Arg1_Trees)
    feats['Arg2_Max_Tree_Height'] = max_tree_height(datum.Arg2_Trees)
    return feats
    
def tree_height_ratio_feature(datum):
    """
    Banana Wugs: ratio of argument tree heights

    CP interpretation: create a single feature with value
    max_tree_height(datum.Arg1_Trees) / max_tree_height(datum.Arg2_Trees)

    Tree_Height_Ratio -> float
    """
    def max_tree_height(trees):
        return max([tree.height() for tree in trees])
    r = float(max_tree_height(datum.Arg1_Trees)) / float(max_tree_height(datum.Arg2_Trees))
    return {'Tree_Height_Ratio':r}

def s_node_count(trees):
    """Count the number of 'S'-initial nodes. This is not a feature function, but rather
    a helper for clause_count_features() and clause_count_ratio_feature()."""
    c = 0
    for tree in trees:
        for subtree in tree.subtrees():
            if subtree.node.lower().startswith('s'):
                c += 1
    return c

def clause_count_features(datum):
    """
    Banana Wugs: clause counts for each argument

    CP interpretation: same as above, with clause-counts defined by s_node_count():

    Arg1_Clause_Count -> int
    Arg2_Clause_Count -> int
    """
    feats = {}
    feats['Arg1_Clause_Count'] = s_node_count(datum.Arg1_Trees)
    feats['Arg2_Clause_Count'] = s_node_count(datum.Arg2_Trees)
    return feats

def clause_count_ratio_feature(datum):
    """
    Banana Wugs: an arg with just a matrix clause has a value of 1,
    an arg with an embedded/relative clause in addition to the matrix
    clause has a value of 2, etc.  Output the ratio of these counts
    for the two args.

    CP interpretation: generate a single feature that is the
    S-initial-node count of Arg1 divided by the S-initial-node count
    of Arg2.

    Clause_Count_Ratio -> float
    """    
    arg1_val = s_node_count(datum.Arg1_Trees)
    arg2_val = s_node_count(datum.Arg2_Trees)
    if arg2_val:
        r = float(arg1_val) / float(arg2_val)
    else:
        r = 0    
    return {'Clause_Count_Ratio': r}
    
def pronominal_subject_balance_feature(datum):
    """
    Banana Wugs: return true if the subject of an arg is a pronoun and
    false if it isn't.  The output would be { (T/T) , (T/F), (F/T),
    (F/F) }.

    CP interpretation: for each argument, determine whether it has a
    tree with an NP-SBJ daughter that is pronominal.

    {'pro-subj', 'no-pro-subj'} x {'pro-subj', 'no-pro-subj'} -> True
    """
    def has_pronominal_matrix_subject(trees):
        pro_re = re.compile(r"""
	                     \b(
                             I|Me|Myself|My|Mine|
                             We|Us|Ourselves|Our|Ours|
                             You|Yourself|Yourselves|Your|Yours|
                             He|Him|Himself|His|
                             She|Her|Herself|Hers|
                             It|Itself|Its|
                             They|Them|Themselves|Their|Theirs
                             )\b
                             """, 
                            re.VERBOSE | re.I)
        for tree in trees:
            for daught in tree:
                if isinstance(daught, nltk.tree.Tree) and daught.node == 'NP-SBJ':
                    leaf_str = " ".join(daught.leaves())
                    if pro_re.search(leaf_str):
                        return 'pro-subj'
        return 'no-pro-subj'
    arg1_val = has_pronominal_matrix_subject(datum.Arg1_Trees)
    arg2_val = has_pronominal_matrix_subject(datum.Arg2_Trees)
    return {(arg1_val, arg2_val): True}

def arg2_it_seems_feature(datum):
    """
    Banana Wugs: return false if first bigram of second linear argument is not ``it seems''

    CP interpretation: as above

    Arg2_Startswith_It_Seems -> False
    """
    feats = {}
    # The feature concerns Arg2 (the second linear argument), per the name below.
    if not datum.Arg2_RawText.lower().strip().startswith('it seems'):
        feats['Arg2_Startswith_It_Seems'] = False
    return feats

def tense_matching_feature(datum):
    """
    Banana Wugs: How about tense mismatches between Arg1 and Arg2?
    E.g. Mismatch between simple past in one arg and past perfect in
    another arg.  (e.g. 'Before John was a lawyer, he had been a
    butcher.')  Generalization: If one argument is past perfect and
    another is simple past, it's probably temporal.  Return T if
    feature types match; return F if they don't.

    CP interpretation: Since Args are not necessarily single connected
    structures, but rather might be sets of trees, I take an
    approximate view, by comparing the V-like POS tags of the two Args
    for matches and mismatches. The single feature created is an
    integer-valued one in which disagreement contributes -1 and
    agreement contributes +1.
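
    For example, if Arg1 contributes tags [VBD, VBN] and Arg2
    contributes [VBD], the pairs (VBD, VBD) and (VBN, VBD) contribute
    +1 and -1 respectively, for a feature value of 0.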

    Arg_Tense_Agreement -> int (negative or positive)
    """
    def get_v_tags(arg_pos):
        return [lem[1] for lem in arg_pos if lem[1].lower().startswith('v')]
    arg1_tags = get_v_tags(datum.arg1_pos())
    arg2_tags = get_v_tags(datum.arg2_pos())
    feats = {}
    feats['Arg_Tense_Agreement'] = 0
    for t1, t2 in product(arg1_tags, arg2_tags):
        if t1 == t2:
            feats['Arg_Tense_Agreement'] += 1
        else:
            feats['Arg_Tense_Agreement'] -= 1
    return feats

def modal_balance_feature(datum):
    """
    Banana Wugs: Including modals, may be good indicator of
    contingency Return T T if both args have modal verbs, return T F
    if only arg 1 has a modal, return F T if only arg 2 has a modal,
    return F F if neither arg has a modal.

    CP interpretation: return the one-membered dictionary {(val1,
    val2): True} where val1 is the return value of
    modalized(datum.arg1_pos()) and val2 is the return value of
    modalized(datum.arg2_pos()).

    {modal, non-modal} x {modal, non-modal} -> True
    """
    def modalized(arg_pos):
        for pos in arg_pos:
            if pos[1].lower() == 'md':
                return 'modal'
        return 'non-modal'
    arg1_val = modalized(datum.arg1_pos())
    arg2_val = modalized(datum.arg2_pos())
    return {(arg1_val, arg2_val): True}

def arg_length_ratio_feature(datum, log_transform=True, absolute_transform=True):
    """
    Banana Wugs: Ratio of length arg1 : arg2 (or log of the ratio - or
    both) Nonmonotonic? Could it be that when the two arguments are
    more balanced, they're more likely to be something? We could
    include absolute value of log of the ratio as well.

    CP interpretation: calculate the ratio of the lengths in words,
    and default to taking the absolute value of the log. The keyword
    arguments allow us to vary the behavior.
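
    For example, Args of 12 and 8 words give a ratio of 1.5 and a
    feature value of log(1.5) ~ 0.405; swapping the Args gives 8/12,
    whose log is -0.405, so the absolute value makes the feature
    symmetric in the two Args.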

    Arg_Length_Ratio -> float
    """
    r = 0.0
    arg1_length = float(len(datum.arg1_words()))
    arg2_length = float(len(datum.arg2_words()))
    if arg2_length:    
        r = arg1_length / arg2_length
    if log_transform:
        # r is a ratio of non-negative lengths, so it is never negative;
        # guard against log(0) in case either Arg is empty.
        if r > 0:
            r = log(r)
    if absolute_transform:
        r = abs(r)
    return {'Arg_Length_Ratio': r}

######################################################################

def banana_wugs_feature_function(datum):
    """
    Pool all of the above feature functions into a single
    feature-function dictionary.
    """
    feat_functions = [
        negation_pairs_feature,
        sentential_negation_feature,
        constituent_negation_feature,
        arg_sentiment_features,
        normalized_shared_token_count_feature,    
        embedded_clause_features,
        tree_height_features,
        tree_height_ratio_feature,
        clause_count_features,
        clause_count_ratio_feature,
        pronominal_subject_balance_feature,
        arg2_it_seems_feature,
        tense_matching_feature,
        modal_balance_feature,
        arg_length_ratio_feature
        ]
    feats = {}
    for feat_func in feat_functions:
        feats.update(feat_func(datum))
    return feats
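
# A minimal sketch of how the pooled features could feed a classifier. This is
# not part of the competition pipeline; the `train_data` iterable and the choice
# of NLTK trainer are assumptions for illustration only:
#
#   train_feats = [(banana_wugs_feature_function(d), d.primary_semclass1())
#                  for d in train_data]
#   classifier = nltk.classify.NaiveBayesClassifier.train(train_feats)
#
# NLTK's NaiveBayesClassifier treats each (feature name, value) pair as nominal,
# so the tuple-valued keys produced by negation_pairs_feature() and friends work
# as ordinary feature names.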

######################################################################

def model_to_csv():
    """
    Create a CSV file where the columns correspond to features. The
    list below currently includes the full Banana Wugs feature set;
    prune it to the feature functions with relatively few output
    values if the number of columns becomes unmanageable.
    """
    output_filename = 'pdtb-competition-banana-wugs-model.csv'
    feat_functions = [
        negation_pairs_feature,
        sentential_negation_feature,
        constituent_negation_feature,
        arg_sentiment_features,
        normalized_shared_token_count_feature,    
        embedded_clause_features,
        tree_height_features,
        tree_height_ratio_feature,
        clause_count_features,
        clause_count_ratio_feature,
        pronominal_subject_balance_feature,
        arg2_it_seems_feature,
        tense_matching_feature,
        modal_balance_feature,
        arg_length_ratio_feature
        ]
    pdtb_competition_model_to_csv.model_to_csv(feat_functions, output_filename)
    
######################################################################
    
if __name__ == '__main__':
    """
    Cycle through the training data and print representations of the
    examples before and after banana_wugs_feature_function().
    """    
    implicit_train_picklename = 'pdtb-competition-implicit-train-indices.pickle'
    train_set_indices = pickle.load(open(implicit_train_picklename, 'rb'))
    corpus = pdtb.CorpusReader('pdtb2.csv')
    for i, datum in enumerate(corpus.iter_data(display_progress=False)):
        if datum.Relation == 'Implicit' and i in train_set_indices:
            print "======================================================================"
            print "Arg1:", datum.Arg1_RawText
            print "Arg2:", datum.Arg2_RawText
            print "Semantics:", datum.primary_semclass1()
            print "Features:"            
            feats = banana_wugs_feature_function(datum)
            for name, val in feats.items():
                print "%s: %s" % (str(name).rjust(35), str(val).ljust(18))

