#!/usr/bin/env python

"""
The Team Banana Slugs model for predicting Implicit coherence
relations in the Penn Discourse Treebank 2.

For background, see http://compprag.christopherpotts.net/.

This module implements the model for predicting Implicit coherence
relations that was developed by Team Banana Slugs in class on July 29
and then implemented by Chris.

FEATURE FUNCTIONS:

The feature function names all finish with _feature or _features. Each
one takes a datum as input and returns a dict mapping feature names
to values, which can be boolean, int, or float.

The feature function genre_feature() depends on having the set of
genre files derived from Bonnie Webber's genre lists for the Penn
Treebank: http://www.let.rug.nl/~bplank/metadata/genre_files_updated.html

The function banana_slugs_feature_function() pools all these
dictionaries into a single function on datum instances. This is the
``theory'' of PDTB data for predicting Implicit relations.

DEMO:

The bottom of the file defines a main block that displays simplified
text versions of the training set examples along with the features
determined by banana_slugs_feature_function(). Thus, running

python pdtb_competition_team_banana_slugs.py

will show you what the functions are saying about the training
examples.  For this to work, you need to have the training file
'pdtb-competition-implicit-train-indices.pickle' in the current
directory along with this file.
"""

######################################################################

# Module metadata: authorship, license, version and contact details.
__author__ = "Christopher Potts and the Banana Slugs team"
__copyright__ = "Copyright 2011, Christopher Potts"
__credits__ = []
__license__ = "Creative Commons Attribution-NonCommercial-ShareAlike 3.0 Unported License: http://creativecommons.org/licenses/by-nc-sa/3.0/"
__version__ = "1.0"
__maintainer__ = "Christopher Potts"
__email__ = "See the author's website"

######################################################################

import os
import re
import glob
import pickle
from itertools import product
import nltk.tree
from nltk.corpus import wordnet as wn
# pdtb.py is distributed with the course materials rather than as an
# installable package, so point the user at it when the import fails.
# Only ImportError is translated: a bare except would also disguise
# unrelated problems (e.g. KeyboardInterrupt or a syntax error inside
# pdtb.py) as a missing file.
try:
    import pdtb
except ImportError:
    raise ImportError('Get pdtb.py from the course website and put it in the current directory.')
import pdtb_competition_model_to_csv

######################################################################

# Lower-case tokens that count as negation markers for the negation
# feature functions below:
NEGATIONS = set(['never', 'not', "n't", 'neither'])

# This should be the set of files containing the genre lists, with the
# genre name equivalent to the filename (minus .txt extension):
GENRE_FILENAMES = glob.glob('ptb-genres/*.txt')
if len(GENRE_FILENAMES) != 5:
    raise Exception('There should be five genre files in ptb-genres/. %s files found.' % len(GENRE_FILENAMES))

# Build the genre dictionary, mapping genre-names to lists of filenames:
GENRES = {}
for filename in GENRE_FILENAMES:
    # splitext() removes only the trailing extension, whereas
    # str.replace() would also strip ".txt" from the middle of a name.
    genre_name = os.path.splitext(os.path.basename(filename))[0]
    # Close each genre file as soon as it has been read.
    with open(filename) as genre_file:
        GENRES[genre_name] = genre_file.read().splitlines()

######################################################################

def negated(words):
    """
    Classify a list of word strings by whether it contains a negation.

    This is a helper rather than a feature function; it is called by
    arg_negation_features() and negation_pairs_feature(). The words
    are lower-cased before being compared with NEGATIONS.

    Returns the string 'neg' if at least one word is a member of
    NEGATIONS, else the string 'non-neg'.
    """
    lowered = set(str.lower(w) for w in words)
    return 'neg' if lowered & NEGATIONS else 'non-neg'

def arg_negation_features(datum):
    """
    Banana Slugs: ((negatives_regex.search(Arg1RawText) is not None), (negatives_regex.search(Arg2RawText) is not None))

    CP interpretation: one boolean feature per argument, present only
    when negated() classifies that argument's words as 'neg':

    Arg1_Negated -> True
    Arg2_Negated -> True

    An argument without a negation contributes no feature at all, so
    the returned dict may be empty.
    """
    feats = {}
    per_argument = (
        ('Arg1_Negated', datum.arg1_words()),
        ('Arg2_Negated', datum.arg2_words()),
        )
    for feat_name, words in per_argument:
        if negated(words) == 'neg':
            feats[feat_name] = True
    return feats

def arg_negation_counts_features(datum):
    """
    Banana Slugs: (len(negatives_regex.findall(Arg1RawText)), len(negatives_regex.findall(Arg2RawText)))

    CP interpretation: create int-valued features for each argument,
    giving the number of negations in that argument:

    Arg1_Negation_Count -> int
    Arg2_Negation_Count -> int

    Every token whose lower-cased form is in NEGATIONS is counted, so
    repeated negations ("not ... not") count once per occurrence and
    capitalised ones ("Never") are included, matching the
    case-insensitive test in negated(). Both features are always
    present; the count is 0 when an argument has no negation.
    """
    def negation_count(words):
        """Return the number of tokens in words that are negations."""
        return sum(1 for word in words if word.lower() in NEGATIONS)
    return {
        'Arg1_Negation_Count': negation_count(datum.arg1_words()),
        'Arg2_Negation_Count': negation_count(datum.arg2_words()),
        }
    
def negation_pairs_feature(datum):
    """
    Banana Slugs: (negatives_regex.search(Arg1RawText) is not None) == (negatives_regex.search(Arg2RawText) is not None)

    CP interpretation: a single feature whose name is the pair of
    negated() results for the two arguments, Arg1 first:

    {neg,non-neg} x {neg,non-neg} -> True

    The returned dict always has exactly one entry, keyed by a tuple
    of two strings.
    """
    polarity_pair = (negated(datum.arg1_words()), negated(datum.arg2_words()))
    return {polarity_pair: True}

def ngram_features(datum, n=1):
    """
    Banana Slugs: Create a presence vector for each argument, where
    the values are based on token N-grams in the two vectors combined
    such that if N=1 (i.e., unigrams), each single token that is
    present in an argument is assigned a value of one and zero
    otherwise. If N=2 (i.e., bigrams), each two consecutive tokens are
    assigned a value of one if present, but zero otherwise, etc. N can
    be anything we want it, but I have seen people usually use up to
    N=3 (or, less frequently, 4).

    CP interpretation: create a feature for each n-gram at the desired
    level, relativized to the Arg, keeping track only of presence, not
    counts. The n-grams are taken over the lemmatized words of each
    argument and their tokens are joined with a single space.

    Arg1_Word_W -> True for n-grams W in Arg1
    Arg2_Word_W -> True for n-grams W in Arg2

    Arguments:
    datum -- object providing arg1_words(lemmatize=True) and
             arg2_words(lemmatize=True), each returning a list of strings
    n -- the n-gram length (default 1, i.e., unigrams)

    An argument with fewer than n words contributes no features.
    """
    def get_ngram_features(words, argname, n=1):
        """Return a dict mapping argname + n-gram to True for each n-gram in words."""
        ngrams = {}
        # The range is empty when words is shorter than n.
        for i in range(len(words) - n + 1):
            ngrams[argname + " ".join(words[i: i+n])] = True
        return ngrams
    feats = get_ngram_features(datum.arg1_words(lemmatize=True), 'Arg1_Word_', n=n)
    # The Arg2 features must come from the Arg2 words; reading Arg1
    # here would simply duplicate the Arg1 n-grams under the Arg2 prefix.
    feats.update(get_ngram_features(datum.arg2_words(lemmatize=True), 'Arg2_Word_', n=n))
    return feats

def main_verb(trees):
    """
    Heuristically pick out a main verb from a sequence of parse trees.

    Only the immediate daughters of each tree are inspected. For the
    first daughter that is an nltk.tree.Tree whose node label starts
    with 'v' (case-insensitively) and that contains at least one token
    whose tag also starts with 'v', the first such token is returned,
    lower-cased.

    Returns None if no daughter of any tree qualifies.
    """
    for tree in trees:
        for child in tree:
            if not isinstance(child, nltk.tree.Tree):
                continue
            if not child.node.lower().startswith('v'):
                continue
            # (word, tag) pairs under this daughter whose tag is verbal.
            verbal = [pair for pair in child.pos() if pair[1].lower().startswith('v')]
            if verbal:
                return verbal[0][0].lower()
    return None

def main_verb_match_feature(datum):
    """
    Banana Slugs: Arg1Tree.MainVerb.penn_pos == Arg2Tree.MainVerb.penn_pos

    CP interpretation: use the heuristic in main_verb() to find a main
    verb in datum.Arg1_Trees and in datum.Arg2_Trees, and record
    whether the two return values are equal:

    Main_Verb_Match -> {True, False}

    The feature is produced only if a main verb could be found in both
    arguments; otherwise the returned dict is empty. Arg2 is examined
    only when Arg1 yielded a verb.
    """
    arg1_verb = main_verb(datum.Arg1_Trees)
    if arg1_verb is None:
        return {}
    arg2_verb = main_verb(datum.Arg2_Trees)
    if arg2_verb is None:
        return {}
    return {'Main_Verb_Match': arg1_verb == arg2_verb}

def main_verb_features(datum):
    """
    Banana Slugs: (Arg1Tree.MainVerb.penn_pos, Arg2Tree.MainVerb.penn_pos)

    CP interpretation: for each Arg, generate a word-level feature
    whose name embeds V, the value returned by main_verb() for that
    argument's trees:

    Arg1_Main_Verb_V -> True for main verb V if one can be found
    Arg2_Main_Verb_V -> True for main verb V if one can be found

    Can generate both, just one, or neither feature, depending on
    whether main_verb() finds a verb in each argument.
    """
    feats = {}
    per_argument = (
        ('Arg1_Main_Verb_%s', datum.Arg1_Trees),
        ('Arg2_Main_Verb_%s', datum.Arg2_Trees),
        )
    for name_template, trees in per_argument:
        verb = main_verb(trees)
        if verb is not None:
            feats[name_template % verb] = True
    return feats

def arg_length_ratio_feature(datum):
    """
    Banana Slugs: len(Arg1RawText.tokens) / len(Arg2RawText.tokens) # hopefully picks out expansions

    CP interpretation: the number of words in Arg1 divided by the
    number of words in Arg2, as a float:

    Arg_Length_Ratio -> float

    The ratio is reported as 0.0 when Arg2 has no words, which avoids
    a division by zero.
    """
    arg1_size = len(datum.arg1_words())
    arg2_size = len(datum.arg2_words())
    ratio = float(arg1_size) / arg2_size if arg2_size else 0.0
    return {'Arg_Length_Ratio': ratio}

def antonym_feature(datum):
    """
    Banana Slugs: len([wn.synset(word_a).is_antonym(wn.synset(word_b)) for a in arg_1 for b in arg_2]) # hopefully picks out comparisons

    CP interpretation: as above, but adjusting for the fact that
    antonyms is defined only for lemmas and a given word can
    correspond to multiple different lemmas.

    Antonym_Count -> int

    The count is the number of (Arg1 lemma, Arg2 lemma) pairs, taken
    over every lemma of every word, in which the Arg2 lemma is among
    the antonyms of the Arg1 lemma. The feature is always present and
    is 0 when no such pair exists.
    """
    def get_lemmas(word):
        """Get the lemmas consistent with word == (string, tag) tuple."""
        # Restrict the lookup by part of speech only when the tag is
        # one of 'a', 'n', 'v', 'r'; otherwise search on the string alone.
        if word[1] in ('a', 'n', 'v', 'r'):
            return wn.lemmas(word[0], word[1])
        return wn.lemmas(word[0])
    arg1_words = datum.arg1_pos(lemmatize=True)
    arg2_words = datum.arg2_pos(lemmatize=True)
    # The Arg2 lemmas do not depend on the Arg1 lemma under
    # consideration, so look them up once instead of once per Arg1 lemma.
    arg2_lemmas = []
    for w2 in arg2_words:
        arg2_lemmas.extend(get_lemmas(w2))
    antonym_count = 0
    for w1 in arg1_words:
        for lem1 in get_lemmas(w1):
            # Likewise, fetch the antonyms once per Arg1 lemma rather
            # than once per (Arg1 lemma, Arg2 lemma) pair.
            antonyms = lem1.antonyms()
            if not antonyms:
                continue
            antonym_count += sum(1 for lem2 in arg2_lemmas if lem2 in antonyms)
    return {'Antonym_Count': antonym_count}

def pdtb_genre(datum):
    """
    Look up the genre of a datum in the GENRES dictionary (built from
    Webber's genre lists).

    The lookup key is "wsj_" followed by datum.Section and
    datum.FileNumber. Returns the name of the first genre whose file
    list contains that key, or None if no genre lists it.
    """
    # NOTE(review): presumably Section and FileNumber are already in the
    # zero-padded form used in the genre lists -- confirm against pdtb.py.
    wsj_name = "wsj_%s%s" % (datum.Section, datum.FileNumber)
    for genre_name in GENRES:
        if wsj_name in GENRES[genre_name]:
            return genre_name
    return None

def genre_feature(datum):
    """
    Banana Slugs: original genre as string from Penn Discourse Treebank original labels

    CP interpretation: as above, with the genre determined by
    pdtb_genre().

    Genre-values -> True

    The feature name is 'Genre_' followed by the genre name. When the
    datum's file is not listed under any genre, pdtb_genre() returns
    None and the feature is therefore named 'Genre_None'. Exactly one
    feature is always returned.
    """
    # Delegate to the shared helper rather than repeating its lookup
    # loop here, so the two cannot drift apart.
    return {'Genre_%s' % pdtb_genre(datum): True}

def modal_features(datum):
    """
    Banana Slugs:
    (len(filter(lambda tree: tree.contains_penn_pos('MD'), arg_1.Arg1_Trees)) > 0,
    len(filter(lambda tree: tree.contains_penn_pos('MD'), arg_1.Arg2_Trees)) > 0)

    CP interpretation: one int-valued feature per argument, giving the
    number of (word, tag) pairs in that argument whose tag is 'md',
    compared case-insensitively:

    Arg1_Modal_Count -> int
    Arg2_Modal_Count -> int

    Both features are always present; the count is 0 for an argument
    without modals.
    """
    def modal_count(arg_pos):
        """Return how many (word, tag) pairs in arg_pos carry the modal tag."""
        return sum(1 for pair in arg_pos if pair[1].lower() == 'md')
    return {
        'Arg1_Modal_Count': modal_count(datum.arg1_pos()),
        'Arg2_Modal_Count': modal_count(datum.arg2_pos()),
        }

def wordnet_hypernym_pair_features(datum):
    """
    Banana Slugs:
    for a in arg_1.tokens:
      for b in arg_2.tokens:
          yield (wn.synset(a).hypernyms(), wn.synset(b).hypernyms())
          # sorry about the computational intensity on this one.

    CP interpretation: a cross-product feature for hypernyms;
    generate a feature for every pair of hypernyms for every word in
    the two args.

    Synsets x Synsets -> True

    Each feature name is the tuple (hyp1.name, hyp2.name), where hyp1
    is a hypernym of some synset of an Arg1 word and hyp2 is a
    hypernym of some synset of an Arg2 word. Returns an empty dict if
    either argument yields no hypernyms.

    NOTE: Turns out to generate too many features; NLTK/Python ends up
    with memory allocation problems when trying to instantiate the
    feature encoding. See wordnet_hypernym_count_features() for an
    approximate alternative. For that reason this function is
    commented out of the feature lists in this file.
    """
    def get_arg_hypernyms(words):
        """Return the set of hypernyms of all synsets of the (string, tag) tuples in words."""
        hyp = []
        for word in words:
            # Restrict the synset lookup by part of speech only when the
            # tag is one of 'a', 'n', 'v', 'r'.
            if word[1] in ('a', 'n', 'v', 'r'):
                synsets = wn.synsets(word[0], word[1])
            else:
                synsets = wn.synsets(word[0])
            for synset in synsets:
                hyp += synset.hypernyms()
        return set(hyp)
    arg1_hypernyms = get_arg_hypernyms(datum.arg1_pos(lemmatize=True))
    arg2_hypernyms = get_arg_hypernyms(datum.arg2_pos(lemmatize=True))
    # The result dict has to exist before the loop fills it; without
    # this initialisation every call ended in a NameError.
    feats = {}
    for hyp1, hyp2 in product(arg1_hypernyms, arg2_hypernyms):
        # NOTE(review): .name is read as an attribute here; newer NLTK
        # releases appear to expose it as a method -- confirm against
        # the installed NLTK version.
        feats[(hyp1.name, hyp2.name)] = True
    return feats

def wordnet_hypernym_count_features(datum):
    """
    Replacement for wordnet_hypernym_pair_features(): counts the
    hypernym relations going from Arg1 to Arg2 and from Arg2 to Arg1.

    Arg2_Hypernym_Of_Arg1_Count -> int
    Arg1_Hypernym_Of_Arg2_Count -> int

    Every synset of every Arg1 word is paired with every synset of
    every Arg2 word. A pair adds one to Arg2_Hypernym_Of_Arg1_Count if
    the Arg2 synset is among the direct hypernyms of the Arg1 synset,
    and one to Arg1_Hypernym_Of_Arg2_Count in the converse case. Both
    features are always present, with value 0 if no relation is found.
    """
    def get_arg_synsets(words):
        """Return the list of synsets of all the (string, tag) tuples in words."""
        syns = []
        for word in words:
            # Restrict the synset lookup by part of speech only when the
            # tag is one of 'a', 'n', 'v', 'r'.
            if word[1] in ('a', 'n', 'v', 'r'):
                syns += wn.synsets(word[0], word[1])
            else:
                syns += wn.synsets(word[0])
        return syns
    arg1_synsets = get_arg_synsets(datum.arg1_pos(lemmatize=True))
    arg2_synsets = get_arg_synsets(datum.arg2_pos(lemmatize=True))
    # Compute each synset's hypernyms once up front instead of once per
    # pair in the cross product.
    arg2_with_hypernyms = [(syn2, syn2.hypernyms()) for syn2 in arg2_synsets]
    feats = {'Arg2_Hypernym_Of_Arg1_Count': 0, 'Arg1_Hypernym_Of_Arg2_Count': 0}
    for syn1 in arg1_synsets:
        syn1_hypernyms = syn1.hypernyms()
        for syn2, syn2_hypernyms in arg2_with_hypernyms:
            if syn2 in syn1_hypernyms:
                feats['Arg2_Hypernym_Of_Arg1_Count'] += 1
            if syn1 in syn2_hypernyms:
                feats['Arg1_Hypernym_Of_Arg2_Count'] += 1
    return feats
    
def calendar_features(datum):
    """
    Banana Slugs:
    len(filter(lambda token: token in ['january', 'february', 'march', 'april', 'may', 'june', 'july', 'august',
    'september', 'october', 'november', 'december', 'next', 'previous', 'last',
    'week', 'month', 'future', 'late'], arg_1.lemmatized_tokens() + arg_2.lemmatized_tokens())) # hopefully picks out temporal

    CP interpretation: count the tokens of Arg1 and Arg2 together
    that are calendar words, comparing case-insensitively so that
    capitalised forms such as "January" are counted:

    Calendar_Words_Count -> int

    Every occurrence is counted, and the feature is always present
    (0 when no calendar word occurs).
    """
    # NOTE(review): the team's specification asks for lemmatized tokens,
    # but the unlemmatized words are used here, so inflected forms such
    # as "weeks" are not counted -- confirm whether that is intended.
    cal_words = frozenset([
        'january', 'february', 'march', 'april', 'may', 'june', 'july', 'august',
        'september', 'october', 'november', 'december', 'next', 'previous', 'last',
        'week', 'month', 'future', 'late'])
    words = datum.arg1_words() + datum.arg2_words()
    # The calendar words are all lower-case, so lower-case each token
    # before the membership test.
    cal_count = sum(1 for w in words if w.lower() in cal_words)
    return {'Calendar_Words_Count': cal_count}
    
######################################################################

def banana_slugs_feature_function(datum):
    """
    Pool all of the above feature functions into a single
    feature-function dictionary.

    Each feature function in the list below is applied to datum and
    the resulting dicts are merged into one, which is returned. If
    two functions were to produce the same feature name, the value
    from the function later in the list would win.
    """
    feat_functions = [
        arg_negation_features,
        arg_negation_counts_features,
        negation_pairs_feature,
        ngram_features,
        main_verb_match_feature,
        main_verb_features,
        arg_length_ratio_feature,
        antonym_feature,
        genre_feature,
        modal_features,
        #wordnet_hypernym_pair_features,
        wordnet_hypernym_count_features,
        calendar_features
        ]
    feats = {}
    for feat_func in feat_functions:
        # Merge in place: rebuilding the dict from concatenated item
        # lists re-copies everything collected so far on each pass.
        feats.update(feat_func(datum))
    return feats

######################################################################

def model_to_csv():
    """
    Write the model to a CSV file whose columns correspond to
    features, by handing a list of feature functions and the output
    filename to pdtb_competition_model_to_csv.model_to_csv().

    Only the feature functions with relatively few output values are
    included, so that the number of columns stays manageable.
    """
    # Left out: ngram_features, main_verb_features and
    # wordnet_hypernym_pair_features.
    selected_functions = [
        arg_negation_features,
        arg_negation_counts_features,
        negation_pairs_feature,
        main_verb_match_feature,
        arg_length_ratio_feature,
        antonym_feature,
        genre_feature,
        modal_features,
        wordnet_hypernym_count_features,
        calendar_features,
        ]
    pdtb_competition_model_to_csv.model_to_csv(
        selected_functions,
        'pdtb-competition-banana-slugs-model.csv')
    
######################################################################
    
if __name__ == '__main__':
    # Cycle through the training data and print representations of the
    # examples before and after banana_slugs_feature_function().
    implicit_train_picklename = 'pdtb-competition-implicit-train-indices.pickle'
    # NOTE: unpickling can execute arbitrary code, so only load the
    # pickle file distributed with the course materials.
    # Pickles are binary data; open in 'rb' and close the file when done.
    with open(implicit_train_picklename, 'rb') as pickle_file:
        # A set makes the per-datum membership test below constant-time.
        train_set_indices = set(pickle.load(pickle_file))
    corpus = pdtb.CorpusReader('pdtb2.csv')
    for i, datum in enumerate(corpus.iter_data(display_progress=False)):
        if datum.Relation == 'Implicit' and i in train_set_indices:
            # Single-argument print(...) calls produce the same output
            # whether print is a statement or a function.
            print("======================================================================")
            print("Arg1: %s" % (datum.Arg1_RawText,))
            print("Arg2: %s" % (datum.Arg2_RawText,))
            print("Semantics: %s" % (datum.primary_semclass1(),))
            print("Features:")
            feats = banana_slugs_feature_function(datum)
            for name, val in feats.items():
                print("%s: %s" % (str(name).rjust(60), str(val).ljust(18)))
