#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Aug 29 10:24:58 2018

@author: Omid Sadjadi <omid.sadjadi@nist.gov>
"""

import sys
import argparse
import numpy as np
from opensat_sad_validator import validate_me


def read_tsv_file(tsv_file, delimiter='\t', encoding='ascii'):
    """Load a delimited text file with a header row into a structured array.

    Column names are taken from the first line (``names=True``) and each
    column's dtype is inferred by ``np.genfromtxt`` (``dtype=None``).

    Args:
        tsv_file: path (or file-like object) of the table to read.
        delimiter: column separator; tab by default.
        encoding: text encoding used to decode the file.

    Returns:
        numpy structured array with one record per data row. The result is
        always at least 1-D, so a file containing a single data row can be
        sliced and boolean-indexed exactly like a multi-row one.
    """
    table = np.genfromtxt(tsv_file, names=True, delimiter=delimiter,
                          dtype=None, encoding=encoding)
    # genfromtxt returns a 0-d array when the file holds exactly one data
    # row; positional and boolean-mask indexing (as done in score_me) fail
    # on 0-d arrays, so promote it to shape (1,).
    return np.atleast_1d(table)


def get_binary_labels(start_times, end_times, segment_types, frame_rate=100.,
                      collar=0):
    """Convert time-stamped segments into a frame-level label array.

    Frames inside a speech segment (type ``'S'``) are labelled 1. Frames
    inside the no-score collar on either side of a speech segment are
    labelled -1, unless they are themselves speech. All other frames are 0.

    Args:
        start_times: segment start times in seconds (array-like of numbers).
        end_times: segment end times in seconds (array-like of numbers).
        segment_types: per-segment labels; only entries equal to ``'S'``
            are treated as speech.
        frame_rate: number of frames per second.
        collar: length, in seconds, of the no-score region marked on each
            side of every speech segment; 0 disables it.

    Returns:
        1-D float numpy array with one entry per frame, covering frame 0 up
        to the latest segment end. Empty when no segments are given.
    """
    # Round rather than truncate: e.g. 0.57 * 100 evaluates to
    # 56.999999999999993 in binary floating point and would otherwise
    # land one frame short.
    start_frames = np.round(
        np.asarray(start_times, dtype=float) * frame_rate).astype(int)
    end_frames = np.round(
        np.asarray(end_times, dtype=float) * frame_rate).astype(int)
    collar_frames = int(round(collar * frame_rate))
    if end_frames.size == 0:
        return np.zeros(0)
    # Size by the largest end time, not the last row: rows are not
    # guaranteed to be sorted, and a too-short array would silently drop
    # the tail of later speech segments.
    sad_len = int(end_frames.max())
    sad = np.zeros(sad_len)
    if collar_frames > 0:
        for st, en, seg in zip(start_frames, end_frames, segment_types):
            if seg == 'S':
                sad[max(st - collar_frames, 0):st] = -1
                sad[en:min(en + collar_frames, sad_len)] = -1
    # Speech is written after the collars so it wins where they overlap.
    for st, en, seg in zip(start_frames, end_frames, segment_types):
        if seg == 'S':
            sad[st:en] = 1
    return sad


def _fit_length(labels, length):
    """Return ``labels`` truncated or zero-padded to exactly ``length`` frames.

    Padding uses 0 (non-speech); frames beyond ``length`` are dropped.
    """
    labels = np.asarray(labels)
    if labels.shape[0] >= length:
        return labels[:length]
    return np.concatenate([labels, np.zeros(length - labels.shape[0])])


def score_me(sys_out_file, ref_file, collar, prior):
    """Score a SAD system output against a reference key.

    For every segmentid in the key, in order of first appearance, prints a
    tab-separated row with the miss rate, false-alarm rate and DCF. When
    more than one segmentid is scored, also prints a pooled "All" row.

    Args:
        sys_out_file: path to the system output TSV file.
        ref_file: path to the reference key TSV file.
        collar: no-score collar length, in seconds, applied around each
            reference speech segment.
        prior: weight of P_miss in the DCF; P_fa is weighted by
            ``1 - prior``.

    Returns:
        Tuple ``(P_miss, P_fa, dcf)`` pooled over all scored frames of all
        segmentids (0 for a rate whose denominator is empty).
    """
    sys_out = read_tsv_file(sys_out_file)
    ref = read_tsv_file(ref_file)
    segmentids = ref['segmentid']
    # np.unique sorts its output; re-sorting the first-occurrence indices
    # restores the order in which segmentids appear in the key.
    idx = np.unique(segmentids, return_index=True)[1]
    segmentids = segmentids[np.sort(idx)]
    s_time_tot, ns_time_tot = 0., 0.
    fn_time_tot, fp_time_tot = 0., 0.
    print('\nSegmentid\tP_miss\tP_fa\tDCF')
    cnt = 0
    for segid in segmentids:
        sys_out_i = sys_out[sys_out['segmentid'] == segid]
        ref_i = ref[ref['segmentid'] == segid]
        sad_out = get_binary_labels(sys_out_i['segment_begin'],
                                    sys_out_i['segment_end'],
                                    sys_out_i['segment_type'])
        sad_ref = get_binary_labels(ref_i['segment_begin'],
                                    ref_i['segment_end'],
                                    ref_i['segment_type'], collar=collar)
        # Each label array is sized from its own file's segment end times,
        # so the two can differ in length and the boolean masks below would
        # not line up. The key defines the scored region: system frames past
        # its end are ignored, and frames the system did not cover count as
        # non-speech.
        sad_out = _fit_length(sad_out, len(sad_ref))
        # Reference labels: 1 = speech, 0 = non-speech, -1 = collar (unscored).
        speech = sad_ref == 1
        nonspeech = sad_ref == 0
        s_time_i = np.sum(speech)
        ns_time_i = np.sum(nonspeech)
        fn_time_i = np.sum(1 - sad_out[speech])
        fp_time_i = np.sum(sad_out[nonspeech])
        s_time_tot += s_time_i
        ns_time_tot += ns_time_i
        fn_time_tot += fn_time_i
        fp_time_tot += fp_time_i
        p_miss = fn_time_i/s_time_i if s_time_i != 0 else 0
        p_fa = fp_time_i/ns_time_i if ns_time_i != 0 else 0
        dcf_i = prior * p_miss + (1-prior) * p_fa
        print('{}\t{:.4f}\t{:.4f}\t{:.4f}'.format(segid, p_miss, p_fa, dcf_i))
        cnt += 1
    P_miss = fn_time_tot/s_time_tot if s_time_tot != 0 else 0
    P_fa = fp_time_tot/ns_time_tot if ns_time_tot != 0 else 0
    dcf = prior * P_miss + (1-prior) * P_fa
    # A single segmentid's row already equals the pooled result.
    if cnt > 1:
        print('All\t{:.4f}\t{:.4f}\t{:.4f}\n'.format(P_miss, P_fa, dcf))
    return P_miss, P_fa, dcf


def main():
    """Command-line entry point: validate the system output, then score it.

    If ``validate_me`` returns a truthy value for the key/system-output
    pair, a failure message is printed and the process exits with status 1;
    otherwise the output is scored with ``score_me``.
    """
    cli = argparse.ArgumentParser(description='OpenSAT SAD Scorer.')
    cli.add_argument("-o", "--output", type=str, required=True,
                     help="path to system output file")
    cli.add_argument("-r", "--key", type=str, required=True,
                     help="path to the SAD trial key")
    cli.add_argument("-p", "--ptarget", type=float, default=0.75,
                     help="prior probability of speech")
    cli.add_argument("-c", "--collar", type=float, default=0.50,
                     help="Length of no-score collar in seconds")
    opts = cli.parse_args()

    validation_failed = validate_me(opts.key, opts.output)
    if validation_failed:
        print("System output failed the validation step. Exiting...\n")
        sys.exit(1)

    score_me(opts.output, opts.key, opts.collar, opts.ptarget)


if __name__ == '__main__':
    # Run the scorer only when executed as a script, not when imported.
    main()
