#!/usr/bin/env python

"""
Balanced train/test splits for the PDTB2 competition described here:
http://compprag.christopherpotts.net/pdtb-competition.html
"""

__author__ = "Christopher Potts and the Banana Wugs team"
__copyright__ = "Copyright 2011, Christopher Potts"
__credits__ = []
__license__ = "Creative Commons Attribution-NonCommercial-ShareAlike 3.0 Unported License: http://creativecommons.org/licenses/by-nc-sa/3.0/"
__version__ = "1.0"
__maintainer__ = "Christopher Potts"
__email__ = "See the author's website"

######################################################################

import re
import pickle
import random
from collections import defaultdict
import pdtb

######################################################################

def balanced_train_test_split(train_picklename, test_picklename, train_category_size=600, test_category_size=200, relations=('Implicit',)):
    """
    Create a train/test split of corpus indices that is balanced with
    respect to the four primary semantic class values, and pickle the
    two resulting index sets.

    There are only 826 Temporal examples in the Implicit subset, which
    is why the defaults are set at 600 train and 200 test. The keyword
    argument relations can be used to go beyond the smallish set of
    Implicit examples (most sensibly, to include Explicit too).

    Arguments:
    train_picklename -- path of the file to receive the pickled set of
        training indices
    test_picklename -- path of the file to receive the pickled set of
        test indices
    train_category_size -- maximum number of training indices drawn
        from each semantic class
    test_category_size -- maximum number of test indices drawn from
        each semantic class
    relations -- collection of Relation values to include; a datum is
        kept when its Relation attribute is a member of this collection

    The indices are positions in the enumeration of
    pdtb.CorpusReader('pdtb2.csv').iter_data(). The split is drawn with
    the random module, so it is reproducible only if the caller seeds
    that module first. A class with fewer than
    train_category_size + test_category_size examples simply
    contributes fewer indices; no error is raised.
    """
    # Group the corpus indices by primary semantic class, keeping only
    # the requested relation types:
    indices_by_class = defaultdict(list)
    for i, datum in enumerate(pdtb.CorpusReader('pdtb2.csv').iter_data()):
        if datum.Relation in relations:
            indices_by_class[datum.primary_semclass1()].append(i)
    # Create split: shuffle each class, then take disjoint slices so that
    # no index can land in both train and test.
    train = []
    test = []
    split_end = train_category_size + test_category_size
    for indices in indices_by_class.values():
        random.shuffle(indices)
        train.extend(indices[:train_category_size])
        test.extend(indices[train_category_size:split_end])
    # Store the results as sets (which have slightly faster look-up times
    # than lists). The files are opened in binary mode, as pickle expects,
    # and inside with-blocks so they are flushed and closed even if
    # dumping fails.
    with open(train_picklename, 'wb') as train_file:
        pickle.dump(set(train), train_file)
    with open(test_picklename, 'wb') as test_file:
        pickle.dump(set(test), test_file)

######################################################################

def implicit_split():
    """
    Write the purely Implicit train and test index pickles by calling
    balanced_train_test_split with its default sizes and relations.
    """
    balanced_train_test_split(
        'pdtb-competition-implicit-train-indices.pickle',
        'pdtb-competition-implicit-test-indices.pickle'
    )

def mixed_split():
    """
    Creates the mixed Implicit/Explicit set.

    Writes the train and test index pickles using the default category
    sizes of balanced_train_test_split.
    """
    mixed_train_picklename = 'pdtb-competition-mixed-train-indices.pickle'
    mixed_test_picklename = 'pdtb-competition-mixed-test-indices.pickle'
    # The relations must be passed explicitly: the default is ('Implicit',),
    # which would just produce another Implicit-only split under the
    # "mixed" filenames.
    balanced_train_test_split(
        mixed_train_picklename,
        mixed_test_picklename,
        relations=('Implicit', 'Explicit')
    )

