#!/usr/bin/env python

"""
Functions for exploring WordNet and helping with various kinds of
data extraction.
"""

__author__ = "Christopher Potts"
__copyright__ = "Copyright 2011, Christopher Potts"
__credits__ = []
__license__ = "Creative Commons Attribution-NonCommercial-ShareAlike 3.0 Unported License: http://creativecommons.org/licenses/by-nc-sa/3.0/"
__version__ = "1.0"
__maintainer__ = "Christopher Potts"
__email__ = "See the author's website"
		
######################################################################

import sys
import csv
import re
from collections import defaultdict
from nltk.corpus import wordnet as wn
from nltk.stem import WordNetLemmatizer
import review_functions

######################################################################

def wn_pos_dist():
    """
    Count the Synsets in each WordNet POS category and print the
    per-tag counts, followed by the overall total, to standard output.

    NOTE(review): Python 2 print statements, and attribute-style
    `synset.pos` (a method in newer NLTK releases) — this module
    targets Python 2 with an older NLTK.
    """
    # One-dimensional count dict with 0 as the default value:
    cats = defaultdict(int)
    for synset in wn.all_synsets():
        cats[synset.pos] += 1
    # Print the results to the screen:
    for tag, count in cats.items():
         print tag, count
    # Total number (sum of the above):
    print 'Total', sum(cats.values())

######################################################################

def synset_method_values(synset):
    """
    Collect (method_name, value) pairs for the standard Synset relation
    methods on the given synset.

    Argument:
    synset -- a WordNet Synset (any object exposing the methods below)

    Value:
    list of (method_name, values) pairs, in the fixed order given below
    """
    # The Synset relation methods we probe, in a fixed order:
    method_names = ['hypernyms', 'instance_hypernyms', 'hyponyms', 'instance_hyponyms',
                    'member_holonyms', 'substance_holonyms', 'part_holonyms',
                    'member_meronyms', 'substance_meronyms', 'part_meronyms',
                    'attributes', 'entailments', 'causes', 'also_sees', 'verb_groups',
                    'similar_tos']
    # Look each method up by its string name and call it:
    return [(name, getattr(synset, name)()) for name in method_names]

def synset_methods():
    """
    Build a two-level mapping over all WordNet Synsets:

        method_name --> pos --> count

    where pos is a WordNet pos and count is the number of Synsets of
    that pos for which method_name returns a non-empty value.
    """
    # Two-dimensional count dict with 0 as the default final value:
    counts = defaultdict(lambda : defaultdict(int))
    for synset in wn.all_synsets():
        for method_name, vals in synset_method_values(synset):
            # Count only synsets where this relation is non-empty:
            if vals:
                counts[method_name][synset.pos] += 1
    return counts

######################################################################

def lemma_method_values(lemma):
    """
    For a given lemma, get all the (method_name, value) pairs
    for that lemma. Returns the list of such pairs.

    Argument:
    lemma -- a WordNet Lemma

    Value:
    name_value_pairs -- list of (method_name, values) pairs, one pair
    per distinct method name, in the order listed below

    Fix: 'derivationally_related_forms' and 'pertainyms' were listed
    twice in method_names, so their pairs appeared twice in the output
    and were double-counted by lemma_methods(). The duplicates have
    been removed.
    """
    name_value_pairs = []
    # All the available lemma methods (each listed exactly once):
    method_names = [# These are sometimes non-empty for Lemmas:
                    'antonyms', 'derivationally_related_forms',
                    'also_sees', 'verb_groups', 'pertainyms',
                    # These were undefined for Lemmas in earlier versions of NLTK but are now defined:
                    'topic_domains', 'region_domains', 'usage_domains',
                    # These are always empty for Lemmas:
                    'hypernyms', 'instance_hypernyms',
                    'hyponyms', 'instance_hyponyms',
                    'member_holonyms', 'substance_holonyms',
                    'part_holonyms', 'member_meronyms',
                    'substance_meronyms', 'part_meronyms',
                    'attributes', 'entailments', 'causes', 'similar_tos']
    for method_name in method_names:
        # Check to make sure the method is defined (older NLTK versions
        # lack some of these on Lemma objects):
        if hasattr(lemma, method_name):
            method = getattr(lemma, method_name)
            # Get the values from running that method:
            vals = method()
            name_value_pairs.append((method_name, vals))
    return name_value_pairs

def lemma_methods():
    """
    Build a two-level mapping over all WordNet Lemmas:

        method_name --> pos --> count

    where pos is a WordNet pos and count is the number of Lemmas (via
    their synset's pos) for which method_name returns a non-empty value.
    """
    # Two-dimensional count dict with 0 as the default final value:
    counts = defaultdict(lambda : defaultdict(int))
    for synset in wn.all_synsets():
        # Attribute-style synset.lemmas (older NLTK API):
        for lemma in synset.lemmas:
            for method_name, vals in lemma_method_values(lemma):
                # Count only lemmas where this relation is non-empty:
                if vals:
                    counts[method_name][synset.pos] += 1
    return counts

######################################################################

def wordnet_relations(word1, word2):
    """
    Uses the lemmas and synsets associated with word1 and word2 to
    gather all relationships between these two words. There is
    imprecision in this, since we range over all the lemmas and
    synsets consistent with each (string, pos) pair, but it seems
    to work well in practice.

    Arguments:
    word1, word2 (str, str) -- (string, pos) pairs

    Value:
    rels (set of str) -- the set of all WordNet relations that hold between word1 and word2
    """
    # Normalize both inputs to well-formed WordNet (string, pos) pairs
    # (pos may be None):
    str1, pos1 = wordnet_sanitize(word1)
    str2, pos2 = wordnet_sanitize(word2)
    # Output set of relation-name strings:
    rels = set([])
    for lem1 in wn.lemmas(str1, pos1):
        # Compute lem1's relation values once, outside the inner loop:
        lemma_pairs = lemma_method_values(lem1)
        synset_pairs = synset_method_values(lem1.synset)
        for lem2 in wn.lemmas(str2, pos2):
            # Lemma-level relations:
            rels.update(rel for rel, vals in lemma_pairs if lem2 in vals)
            # Synset-level relations:
            rels.update(rel for rel, vals in synset_pairs if lem2.synset in vals)
    return rels

def sentiwordnet_and_reviews_to_csv():
    """
    Prints the SentiWordNet and review data side-by-side in a CSV file
    for comparison. The output file created is called
    'sentiwordnet_and_reviews.csv'.

    Fix: the output file is now opened with `open()` inside a `with`
    block so it is flushed and closed even on error (the original used
    the `file()` builtin and never closed the handle).
    """
    from sentiwordnet import SentiWordNetCorpusReader
    swn = SentiWordNetCorpusReader('SentiWordNet_3.0.0_20100705.txt')
    reviews = review_functions.get_all_imdb_scores("imdb-lemmas-assess.csv")
    items = {}
    for word, tag in reviews.iterkeys():
        vals = reviews[(word, tag)]
        p = vals['P']
        coef = vals['Coef']
        # Keep only words whose review coefficient is significant:
        if p <= 0.01:
            for senti_synset in swn.senti_synsets(word, tag):
                items[(word, tag)] = [word, tag, coef, senti_synset.pos_score, senti_synset.neg_score]
    # Write the header and the accumulated rows, then close the file:
    with open('sentiwordnet_and_reviews.csv', 'w') as outfile:
        csvwriter = csv.writer(outfile)
        csvwriter.writerow(['Word', 'Tag', 'Coef', 'SentiPos', 'SentiNeg'])
        csvwriter.writerows(sorted(items.values()))
    
#sentiwordnet_and_reviews_to_csv()

######################################################################

def wn_string_lemmatizer(s):
    """
    WordNet lemmatizer for strings.

    Argument:
    s -- a string of word/tag pairs, separated by spaces. If a word is
    missing a tag, or if its tag is not one of the WordNet pos values
    (a, n, s, r, v), then its tag is ignored. (It seems that the
    lemmatizer does much less in such cases.)

    Output:
    lemmatized (list) -- the lemmatized strings (no tags)

    Fix: the bare `except:` is narrowed to `except ValueError`, which
    is the only exception the tuple unpacking can raise (zero or more
    than one slash in the token). The bare form also swallowed
    KeyboardInterrupt/SystemExit.
    """
    # Instantiate the lemmatizer:
    wnl = WordNetLemmatizer()
    # Split on whitespace to create a list of word/tag strings:
    lemma_strs = re.split(r'\s+', s)
    # The output list:
    lemmatized = []
    # Now loop through the string_tag string pairs trying to lemmatize them:
    for sl in lemma_strs:
        word = ''
        tag = None
        try: # If there is no single slash divider,
            word, tag = re.split(r'/', sl)
            tag = tag.lower()
        except ValueError: # treat the whole unit as a word.
            word = sl
        # Lemmatize with the tag only if it is WordNet-kosher:
        if tag in ('a', 'n', 's', 'r', 'v'):
            lemmatized.append(wnl.lemmatize(word, tag))
        else:
            lemmatized.append(wnl.lemmatize(word))
    return lemmatized

def wordnet_sanitize(word):
    """
    Ensure that word is a (string, pos) pair that WordNet can
    understand.

    Argument: word (str, str) -- a (string, pos) pair; the pos may
    also be None, in which case it passes through unchanged.

    Value: A possibly modified (string, pos) pair, where pos=None if
    the input pos is outside of WordNet.

    Fix: a None pos previously crashed on `tag.lower()` — even though
    None is this function's own output form for non-WordNet tags — so
    it is now handled explicitly.
    """
    string, tag = word
    string = string.lower()
    # A None pos is already canonical; pass it through unchanged:
    if tag is None:
        return (string, None)
    tag = tag.lower()
    # Map Penn Treebank-style tags onto WordNet's single-letter tags:
    if tag.startswith('v'):    tag = 'v'
    elif tag.startswith('n'):  tag = 'n'
    elif tag.startswith('j'):  tag = 'a'
    elif tag.startswith('rb'): tag = 'r'
    if tag in ('a', 'n', 'r', 'v'):
        return (string, tag)
    else:
        # Outside WordNet's pos inventory:
        return (string, None)
