"""video.py

General purpose image and video generation functions for the Pausch Bridge.
"""
#================================================================
# Dependencies:
#
# This module assumes the availability of the OpenCV and numpy libraries. A
# recommended method for installing these in Python 3 follows:
#
#   pip3 install opencv-contrib-python
#
# General OpenCV information:   https://opencv.org/
# General NumPy information:    https://numpy.org/

#================================================================
# Import standard Python modules.
import logging

# Import the numpy and OpenCV modules.
import numpy as np
import cv2 as cv

#================================================================
# Canonical video format of the Pausch Bridge lighting system: the frame rate
# in frames per second, followed by the frame size in pixels (width, height).
frame_rate, frame_width, frame_height = 30, 228, 8

# Output codec and container format.  Which codec/container combinations are
# actually available depends on the platform's OpenCV build.
codec_code = cv.VideoWriter.fourcc('p', 'n', 'g', ' ')  # lossless PNG frames keep block edges clean
video_file_extension = 'avi'

#================================================================
# Identification colors for the decimal digits 0-9 as (B,G,R) triples of
# unsigned 8-bit integers, each paired with a three-letter abbreviation of the
# color's name.  zip() splits the paired table into two parallel tuples.
digit_colors, digit_names = zip(
    ((  0,   0,   0), 'Blk'),   # digit 0       black
    ((  0,  75, 150), 'Brn'),   # digit 1       brown
    ((  0,   0, 255), 'Red'),   # digit 2       red
    ((  0, 165, 255), 'Org'),   # digit 3       orange
    ((  0, 255, 255), 'Yel'),   # digit 4       yellow
    ((  0, 255,   0), 'Grn'),   # digit 5       green
    ((255,   0,   0), 'Blu'),   # digit 6       blue
    ((211,  50, 147), 'Vio'),   # digit 7       violet
    ((160, 160, 160), 'Gry'),   # digit 8       grey
    ((255, 255, 255), 'Wht'),   # digit 9       white
)

margin_color = (80, 80, 80)  # dark gray used to fill the margins

#================================================================
def color_code_image(value):
    """Generate a single video frame for the Pausch Bridge representing the given
    integer code.  Each digit is encoded as a color block spanning two fixture
    groups (8 pixels), separated by a single group (4 pixels) in the margin
    color.  The code is preceded and trailed by margin color blocks to
    distinguish black code blocks from the margins.  There are 35 contiguous
    CG4 fixtures in the main span.  With allowances for margins, this may
    encode as many as 11 digits; longer codes are truncated with a warning.
    If zero-padding is desired, pass a string value including leading zeros.

    :param value: integer or string code; characters that are not decimal
        digits are dropped with a warning
    :return: tuple (frame, code_name) where frame is a
        (frame_height, frame_width, 3) uint8 BGR image and code_name is the
        concatenation of the three-letter color names of the encoded digits
    """
    max_digits = 11  # the most digits that fit in the main span with margins

    # Cast the input value to a string as needed.
    num_str = str(value)
    if not num_str.isdecimal():
        logging.warning("Slate code generator passed an illegal value '%s', some characters will be omitted.", value)

    # Keep only characters that int() accepts as a single decimal digit.
    # str.isdigit() is too permissive here: it also accepts characters such as
    # superscripts ('²'), for which int() raises ValueError.
    digits = [int(c) for c in num_str if c.isdecimal()]

    if len(digits) > max_digits:
        logging.warning("Slate code generator passed an excessively long value '%s', some characters will be omitted.", value)
        digits = digits[:max_digits]

    code_colors = [digit_colors[d] for d in digits]
    code_names = [digit_names[d] for d in digits]

    # Create a black image.
    frame = np.zeros((frame_height, frame_width, 3), dtype=np.uint8)

    # Fill the main span section (columns 44 through 191) with the margin color.
    frame[:, 44:192, :] = margin_color

    # Approximately center the code within the main span, keeping every block
    # aligned to a 4-pixel group.  NOTE(review): for an even number of digits
    # this places the code one group left of true center; kept unchanged so
    # generated slates stay identical to earlier output.
    num_digits = len(code_colors)
    first_column = 116 - (12 * (num_digits // 2))

    # Draw each digit as a two-group (8-pixel) block on a 12-pixel pitch,
    # leaving a single-group margin between successive digits.
    for i, color in enumerate(code_colors):
        column = first_column + 12 * i
        frame[:, column:column + 8, :] = color

    # Assemble a string designating the color sequence.
    code_name = "".join(code_names)

    return (frame, code_name)

#================================================================
def image_keyframes(frame, beats=5):
    """Generator function to produce a sequence of keyframes for presenting a static
    image including a fade-in and fade-out.  The keyframes are assumed to occur
    at constant rate, so the first keyframe is black, then the image is repeated
    for a specified number of beats, then the final two keyframes are black.

    :param frame: image array to present; the same object is yielded for
        every image beat
    :param beats: number of keyframes on which the image is shown
    """
    # Match the image's shape AND dtype so that a downstream cross-fade
    # (cv.addWeighted requires inputs of identical type) works for non-uint8
    # frames as well; for uint8 frames this equals the previous behavior.
    black = np.zeros_like(frame)
    yield black         # first black keyframe from which to fade in
    for _ in range(beats):
        yield frame     # image keyframes to show and hold the image
    yield black         # black keyframe to which to fade
    yield black         # black keyframe on which to hold

#================================================================
def keyframe_interpolator(keyframe_generator, tempo=60):
    """Generator function to produce successive frames of a video sequence by linear
    interpolation between keyframes at a constant tempo.  Yields one video image
    frame per output video frame and simply stops once the keyframes are
    exhausted; it never yields None.

    :param keyframe_generator: iterable (typically a generator) of keyframe
        images; all keyframes must share the same size and type, since they
        are blended with cv.addWeighted
    :param tempo: keyframe rate in beats per minute; must be positive
    """
    # Accept any iterable, not only an already-started generator.
    keyframes = iter(keyframe_generator)

    keyframe_phase = 0.0  # unit phase for the cross-fade, cycles over 0 to 1

    keyframe_interval = 60.0 / tempo                        # seconds between key frames
    keyframe_rate = 1.0 / (frame_rate * keyframe_interval)  # phase / frame

    # Generate the first keyframe or end sequence.
    try:
        frame0 = next(keyframes)
    except StopIteration:
        return

    # Generate the second keyframe.  If there is none, repeat the first
    # keyframe for a single beat.
    frame1 = next(keyframes, frame0)

    while True:
        # Cross-fade between successive key frames at the given tempo.  This
        # returns a new frame of the same type as the keyframes.
        yield cv.addWeighted(frame0, (1.0 - keyframe_phase), frame1, keyframe_phase, 0.0)

        # Advance the cross-fade phase.
        keyframe_phase += keyframe_rate

        # Once the second keyframe is reached, reset the fade and generate the
        # successor.  More than one boundary is crossed per video frame only
        # when keyframes arrive faster than video frames; looping skips the
        # intermediate keyframes instead of extrapolating the blend weights
        # beyond [0, 1].
        while keyframe_phase > 1.0:
            keyframe_phase -= 1.0
            frame0 = frame1
            # Generate the next keyframe or end sequence if done.
            try:
                frame1 = next(keyframes)
            except StopIteration:
                return

#================================================================
def write_video_file(filepath, frame_generator):
    """Write a video file using frames produced by a generator (or any iterable).

    Frames are written until the iterable is exhausted or yields None, which
    is treated as an end-of-sequence marker.  The frame size is taken from the
    first frame.  The codec is the module-level codec_code; the container
    format is determined from the extension in the path.

    :param filepath: output video file path
    :param frame_generator: iterable of equally-sized image frames
    """
    frames = iter(frame_generator)

    # Collect the first video frame to determine the output size.
    first = next(frames, None)
    if first is None:
        return  # nothing to do

    height, width = first.shape[0], first.shape[1]

    # Open the writer with a path, format, frame rate, and size.
    out = cv.VideoWriter(filepath, codec_code, frame_rate, (width, height))
    try:
        if not out.isOpened():
            # Writing to an unopened writer silently does nothing, so report it.
            logging.warning("Failed to open video writer for %s", filepath)
            return

        # Write the first frame, then all remaining frames up to the end marker.
        out.write(first)
        for frame in frames:
            if frame is None:
                break
            out.write(frame)
    finally:
        # Release the writer even if a frame or the generator raises.
        out.release()
            
#================================================================
def read_image_file(path):
    """Read an image file and preprocess into a 228-pixel (frame_width) wide BGR
    image in which each row represents a keyframe.  This subsamples vertically
    every eight pixels (an image shorter than eight rows yields a single
    keyframe) and deletes any alpha channel.

    :param path: path of the image file to read
    :return: None on failure, or the keyframe image array
    """
    # The default imread flag (IMREAD_COLOR) returns a 3-channel BGR image,
    # which discards any alpha channel.
    source = cv.imread(path)
    if source is None:
        logging.warning("Failed to read %s", path)
        return None

    # Check the layout explicitly so a 2-D (single-plane) array is reported
    # rather than failing on tuple unpacking.
    planes = source.shape[2] if source.ndim == 3 else 1
    if planes != 3:
        logging.warning("Image %s read with %d planes, unsupported.", path, planes)
        return None

    rows = source.shape[0]
    # cv.resize rejects a zero output height, so always keep at least one keyframe.
    keyframes = max(1, rows // 8)
    output = cv.resize(source, dst=None, dsize=(frame_width, keyframes), interpolation=cv.INTER_NEAREST)
    return output

#================================================================
def image_row_keyframes(source):
    """Yield one full-size keyframe for each row of a source image.

    Every row is resized to the canonical frame size using nearest-neighbour
    interpolation, so a single pixel row becomes an entire bridge frame.

    :param source: a NumPy array representing a 228-pixel-wide image with each row a keyframe
    """
    # One-row slices keep the row axis, so each piece is still a 3-D image.
    single_rows = (source[k:k + 1, :, :] for k in range(source.shape[0]))
    yield from (
        cv.resize(strip, dst=None, dsize=(frame_width, frame_height), interpolation=cv.INTER_NEAREST)
        for strip in single_rows
    )

#================================================================
def validate_animation_timing(source, tempo, max_duration = 60, min_duration = 15):
    """Check the size and rate of a source animation image, adjusting values to stay
    within policy limits.  This will apply a rate policy, maximum duration
    policy, and minimum duration policy.  Overly long animations may be
    truncated, overly short animations will be looped.  An image with no rows
    is returned unchanged (with the clamped tempo).

    :param source: source image for which each row is a keyframe
    :param tempo: keyframe rate in beats per minute
    :param max_duration: longest allowed animation, in seconds
    :param min_duration: shortest allowed animation, in seconds
    :return: tuple (image, tempo)
    """

    # maximum tempo in BPM is the number of frames in a minute
    max_tempo = int(60 * frame_rate)

    if tempo > max_tempo:
        logging.warning("Limiting tempo to %f BPM (was %f)", max_tempo, tempo)
        tempo = max_tempo
    elif tempo < 1:
        # One placeholder, one argument: the original passed an extra value,
        # which made the logging call fail to format.
        logging.warning("Limiting tempo to 1 BPM (was %f)", tempo)
        tempo = 1

    # Validate the source size.
    keyframe_interval = 60.0 / tempo       # seconds between keyframes
    num_keyframes = source.shape[0]

    if num_keyframes == 0:
        # Nothing to truncate or loop; also avoids dividing by a zero duration.
        logging.warning("Animation image has no keyframes.")
        return (source, tempo)

    total_duration = num_keyframes * keyframe_interval

    if total_duration > max_duration:
        logging.warning("Truncating image to clip to max duration.")
        max_keyframes = int(max_duration / keyframe_interval)
        return (source[0:max_keyframes], tempo)

    elif total_duration < min_duration:
        # Loop to roughly twice the minimum; the result is always at least min_duration.
        iterations = int(2*min_duration / total_duration)
        logging.info("Looping short clip with %f duration %d time(s).", total_duration, iterations)
        # repeat all rows in the source matrix to loop it
        return (np.tile(source, (iterations, 1, 1)), tempo)

    else:
        return (source, tempo)

