#!/usr/bin/python3

import argparse
import logging
import os
import sys
from collections import deque

from platypush import RedisBus
from platypush.message.event.custom import CustomEvent

from micmon.audio import AudioDevice
from micmon.model import Model

# Module-wide logger for this script; main() lowers its level to DEBUG when
# -v/--verbose is passed.
logger = logging.getLogger(name='micmon')


def get_args(argv=None):
    """Parse the command-line options of the microphone monitor.

    :param argv: Argument list to parse, excluding the program name.
        Defaults to ``sys.argv[1:]`` (the original behaviour).
    :return: An ``argparse.Namespace`` holding the parsed options.
        Unrecognized arguments are silently discarded (``parse_known_args``)
        rather than rejected.
    :raises SystemExit: If a required option is missing or a numeric option
        is out of range (reported through ``ArgumentParser.error``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('model_path', help='Path to the file/directory containing the saved Tensorflow model')
    parser.add_argument('-i', help='Input sound device (e.g. hw:0,1 or default)', required=True, dest='sound_device')
    parser.add_argument('-e', help='Name of the event that should be raised when a positive event occurs', required=True, dest='event_type')
    parser.add_argument('-s', '--sound-server', help='Sound server to be used (available: alsa, pulse)', required=False, default='alsa', dest='sound_server')
    parser.add_argument('-P', '--positive-label', help='Model output label name/index to indicate a positive sample (default: positive)', required=False, default='positive', dest='positive_label')
    parser.add_argument('-N', '--negative-label', help='Model output label name/index to indicate a negative sample (default: negative)', required=False, default='negative', dest='negative_label')
    parser.add_argument('-l', '--sample-duration', help='Length of the FFT audio samples (default: 2 seconds)', required=False, type=float, default=2., dest='sample_duration')
    parser.add_argument('-r', '--sample-rate', help='Sample rate (default: 44100 Hz)', required=False, type=int, default=44100, dest='sample_rate')
    parser.add_argument('-c', '--channels', help='Number of audio recording channels (default: 1)', required=False, type=int, default=1, dest='channels')
    parser.add_argument('-f', '--ffmpeg-bin', help='FFmpeg executable path (default: ffmpeg)', required=False, default='ffmpeg', dest='ffmpeg_bin')
    parser.add_argument('-v', '--verbose', help='Verbose/debug mode', required=False, action='store_true', dest='debug')
    parser.add_argument('-w', '--window-duration', help='Duration of the look-back window (default: 10 seconds)', required=False, type=float, default=10., dest='window_length')
    parser.add_argument('-n', '--positive-samples', help='Number of positive samples detected over the window duration to trigger the event (default: 1)', required=False, type=int, default=1, dest='positive_samples')

    # Unknown arguments are tolerated (and dropped) instead of aborting.
    opts, _ = parser.parse_known_args(sys.argv[1:] if argv is None else argv)

    # Reject values that make the sampling/windowing meaningless: durations,
    # rate and channel count must be positive, and a positive-sample threshold
    # below 1 would make the "back to negative" condition unreachable.
    for value, flag in ((opts.sample_duration, '-l/--sample-duration'),
                        (opts.window_length, '-w/--window-duration'),
                        (opts.sample_rate, '-r/--sample-rate'),
                        (opts.channels, '-c/--channels')):
        if value <= 0:
            parser.error(f'{flag} must be positive (got {value})')
    if opts.positive_samples < 1:
        parser.error(f'-n/--positive-samples must be at least 1 (got {opts.positive_samples})')

    return opts


def _window_size(window_duration, sample_duration):
    """Return how many samples fit in a look-back window of ``window_duration`` seconds.

    Each sample covers ``sample_duration`` seconds, so the window holds
    ``window_duration / sample_duration`` samples, rounded to the nearest
    integer and never fewer than one.

    :raises ValueError: If ``sample_duration`` is not positive.
    """
    if sample_duration <= 0:
        raise ValueError(f'sample_duration must be positive, got {sample_duration}')
    return max(1, round(window_duration / sample_duration))


def main():
    """Classify microphone samples and post an event whenever the detected state flips.

    Each sample read from the audio device is classified by the model loaded
    from ``model_path``. The predictions are kept in a sliding window that spans
    ``--window-duration`` seconds of audio. The state switches to the positive
    label when the current prediction is positive and the window holds at least
    ``--positive-samples`` positive predictions. It switches back to the
    negative label when the current prediction is negative and the window holds
    fewer than that. On every switch a ``CustomEvent`` with ``subtype`` set to
    the ``-e`` event name and ``state`` set to the new label is posted on the
    Redis bus. Runs until the audio source stops yielding samples.
    """
    args = get_args()

    # Without any configured handler, records below WARNING only reach
    # logging's last-resort handler and are dropped, so -v would print nothing.
    logging.basicConfig()
    if args.debug:
        logger.setLevel(logging.DEBUG)

    model_dir = os.path.abspath(os.path.expanduser(args.model_path))
    model = Model.load(model_dir)
    # --window-duration is expressed in seconds; convert it to a sample count.
    # A bounded deque drops the oldest prediction on every append once full.
    window = deque(maxlen=_window_size(args.window_length, args.sample_duration))
    cur_prediction = args.negative_label
    bus = RedisBus()

    with AudioDevice(system=args.sound_server,
                     device=args.sound_device,
                     sample_duration=args.sample_duration,
                     sample_rate=args.sample_rate,
                     channels=args.channels,
                     ffmpeg_bin=args.ffmpeg_bin,
                     debug=args.debug) as source:
        for sample in source:
            source.pause()  # Pause recording while we process the frame
            prediction = model.predict(sample)
            logger.debug(f'Sample prediction: {prediction}')
            window.append(prediction)
            has_change = False

            positive_samples = sum(1 for pred in window if pred == args.positive_label)
            if args.positive_samples <= positive_samples and \
                    prediction == args.positive_label and \
                    cur_prediction != args.positive_label:
                cur_prediction = args.positive_label
                has_change = True
                logger.info(f'Positive sample threshold detected ({positive_samples}/{len(window)})')
            elif args.positive_samples > positive_samples and \
                    prediction == args.negative_label and \
                    cur_prediction != args.negative_label:
                cur_prediction = args.negative_label
                has_change = True
                logger.info(f'Negative sample threshold detected ({len(window)-positive_samples}/{len(window)})')

            if has_change:
                # Publish the new state (only transitions are posted, not every sample).
                evt = CustomEvent(subtype=args.event_type, state=cur_prediction)
                bus.post(evt)

            source.resume()  # Resume recording


if __name__ == "__main__":
    # Start monitoring only when executed as a script, not when imported.
    main()