#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Sep 13 13:10:07 2018
Modified on Mon Apr 29 13:45:23 2019

@author: Omid Sadjadi <omid.sadjadi@nist.gov>
"""

import sys
from argparse import ArgumentParser
from operator import itemgetter


def split_line(line, delimiter='\t'):
    """Return the fields of *line* after trimming and splitting.

    Surrounding whitespace (including the trailing newline and any leading or
    trailing TABs) is removed first, then the remainder is split on
    *delimiter*. An empty or whitespace-only line therefore yields ``['']``.
    """
    trimmed = line.strip()
    fields = trimmed.split(delimiter)
    return fields


def is_float(astr):
    """Tell whether ``float(astr)`` succeeds.

    Returns True when the conversion works and False when it raises
    ValueError; any other exception raised by ``float`` propagates.
    """
    try:
        float(astr)
    except ValueError:
        return False
    else:
        return True


def validate_me(ref_file, sysout_file):
    """Validate a SAD system output file against the trial key.

    Both files are read as TAB-delimited text. Line 1 of the system output is
    treated as the header; every other line is a data line whose first field
    is the segment id. For each run of key lines sharing a segment id, the
    matching run of system output lines is checked for:

    * exactly one more column than the key (the trailing "confidence"),
    * leading fields equal to the key (5 for the header, 2 for data lines),
    * float values in the "start", "end" and "confidence" columns,
    * time marks that start at 0 and are contiguous within the segment,
    * a "segment_type" of "S" or "NS" that alternates from line to line,
    * a last "end" value equal to the last "end" of the key for that id.

    The two files must also cover the same segments: missing or extra
    segments in the system output, and blank lines that interrupt either
    file, are reported. Every problem found is printed to stdout.

    Args:
        ref_file: path to the key (reference) file.
        sysout_file: path to the system output file.

    Returns:
        True if the system output is invalid, False if no problem was found.

    Raises:
        ValueError: if a key line has fewer than two fields (for example an
            empty key file) or a non-numeric "end" value.
    """
    # "with" guarantees both files are closed even if validation raises.
    with open(ref_file) as fid1, open(sysout_file) as fid2:
        errors = _collect_errors(fid1, fid2)
    if errors:
        print("\n" + "".join(errors))
    return bool(errors)


def _collect_errors(fid1, fid2):
    """Compare the open key (fid1) with the open system output (fid2).

    Returns the list of newline-terminated error messages; the list is empty
    when the system output is valid.
    """
    errors = []
    line_no = 0  # number of system output lines consumed so far
    ref_en = sys_en = None
    ref_list = split_line(fid1.readline())
    sys_list = split_line(fid2.readline())
    while True:
        if len(ref_list) < 2:
            raise ValueError('Malformed key: expected at least 2 '
                             'TAB-delimited fields, got "{}".'
                             .format('\t'.join(ref_list)))
        segmentid = ref_list[0]
        if sys_list[0] != segmentid:
            # line_no counts consumed lines, so the offending line is the
            # next one.
            errors.append('Line {}: Incorrect segmentid. Expected "{}", got '
                          '"{}" instead. Segments should appear in the same '
                          'order as in the key.\n'
                          .format(line_no + 1, segmentid, sys_list[0]))
            # The files are out of step, so comparing what is left in each
            # of them would only produce misleading messages.
            return errors
        n_errors_before = len(errors)
        first_line = True
        prev_s_type = 'X'
        prev_sys_en = 0
        # consume every system output line that belongs to this segment
        while sys_list[0] == segmentid:
            line_no += 1
            if len(sys_list) != len(ref_list) + 1:
                errors.append('Line {}: Incorrect number of columns/fields. '
                              'Expected {}, got {} instead. TAB (\\t) '
                              'delimiter should be used.\n'
                              .format(line_no, len(ref_list) + 1,
                                      len(sys_list)))
            else:
                # checking if the fields match the reference
                n_cols = 5 if line_no == 1 else 2
                if sys_list[:n_cols] != ref_list[:n_cols]:
                    errors.append('Line {}: Incorrect field(s). Expected '
                                  '"{}", got "{}" instead.\n'
                                  .format(line_no,
                                          '\t'.join(ref_list[:n_cols]),
                                          '\t'.join(sys_list[:n_cols])))
                if line_no == 1:
                    # checking if "confidence" is in the header
                    if sys_list[-1] != 'confidence':
                        errors.append('Line {}: Expected "confidence" in the '
                                      'header, got "{}" instead.\n'
                                      .format(line_no, sys_list[-1]))
                else:
                    # checking if the time marks and scores are floats
                    float_cols = itemgetter(2, 3, 5)(sys_list)
                    if not all(map(is_float, float_cols)):
                        errors.append('Line {}: Expected float in "start", '
                                      '"end", and "confidence" columns, got '
                                      '"{}" instead.\n'
                                      .format(line_no, float_cols))
                    else:
                        sys_st = float(sys_list[2])
                        sys_en = float(sys_list[3])
                        if first_line and sys_st != 0:
                            # the first line of a segment must start at 0
                            errors.append('Line {}: Expected 0 in "start" '
                                          'column for the first line of '
                                          '"{}", got "{}" instead.\n'
                                          .format(line_no, segmentid,
                                                  sys_st))
                        elif sys_st != prev_sys_en:
                            # the time marks must be contiguous; "elif"
                            # avoids reporting a bad first start twice
                            errors.append('Line {}: Expected {} in "start" '
                                          'column, got "{}" instead. Segment '
                                          'time marks should be '
                                          'contiguous.\n'
                                          .format(line_no, prev_sys_en,
                                                  sys_st))
                        prev_sys_en = sys_en
                    s_ns_type = sys_list[4]
                    # checking if the segment_type is in ['S', 'NS']
                    if s_ns_type not in ('S', 'NS'):
                        errors.append('Line {}: Expected "S" or "NS" in the '
                                      '"segment_type" column, got "{}" '
                                      'instead.\n'
                                      .format(line_no, s_ns_type))
                    else:
                        # checking if the S/NS labels are alternating
                        if prev_s_type == s_ns_type:
                            errors.append('Line {}: Expected a different '
                                          '"segment_type" on each line, got '
                                          'two "{}" in a row instead.\n'
                                          .format(line_no, s_ns_type))
                        prev_s_type = s_ns_type
            first_line = False
            sys_list = split_line(fid2.readline())
        # skip the key lines of this segment, remembering its last "end"
        while ref_list[0] == segmentid:
            if line_no > 1:
                ref_en = float(ref_list[3])
            ref_list = split_line(fid1.readline())
        # checking if the last time mark in the "end" column matches that of
        # the reference/trial definition; only meaningful when every line of
        # this segment parsed cleanly (otherwise sys_en may be stale)
        segment_clean = len(errors) == n_errors_before
        if segment_clean and line_no > 1 and sys_en != ref_en:
            errors.append('Line {}: Expected {} in "end" column for the last'
                          ' line of "{}", got {} instead.\n'
                          .format(line_no, ref_en, segmentid, sys_en))
        if sys_list == [''] or ref_list == ['']:
            break
    # checking if the two files cover the same segments; a pending non-blank
    # line counts as remaining content and must not be discarded
    sys_has_more = (sys_list != [''] or
                    any(rest.strip() for rest in fid2))
    ref_has_more = (ref_list != [''] or
                    any(rest.strip() for rest in fid1))
    if sys_has_more and ref_has_more:
        # the loop stopped on a blank line, not at the end of a file
        if sys_list == ['']:
            errors.append('Line {}: Unexpected blank line in the system '
                          'output. The remaining lines were not validated.\n'
                          .format(line_no + 1))
        else:
            errors.append('Unexpected blank line in the key. The remaining '
                          'lines were not validated.\n')
    elif sys_has_more:
        errors.append('The system output has more segments than the key.\n')
    elif ref_has_more:
        errors.append('The system output has fewer segments than the key.\n')
    return errors


def main():
    """Command-line entry point.

    Parses the system output path ("-o"), the trial key path ("-r") and the
    "-n/--lines" option, runs the validator and exits with status 1 when the
    submission is invalid and 0 otherwise.
    """
    cli = ArgumentParser(description='OpenSAT SAD Submission Validator.')
    cli.add_argument("-o", "--output", type=str, required=True,
                     help="path to system output file")
    cli.add_argument("-r", "--trials", type=str, required=True,
                     help="path to the SAD trial key")
    cli.add_argument("-n", "--lines", type=int, default=20,
                     help="Number of lines to print")
    options = cli.parse_args()
    # NOTE: "--lines" is accepted on the command line but its value is not
    # forwarded to the validator.
    failed = validate_me(options.trials, options.output)
    sys.exit(1 if failed else 0)
    

# Run the command-line validator only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
#    validate_me('opensat_sad_key.tsv', 'system_output/opensat_sad_sysout.tsv')
