#!python3
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sb
import matplotlib.colors as colors
from matplotlib import animation

np.seterr(divide='ignore', invalid='ignore')    # silence divide-by-zero / invalid warnings: the 1/r field formulas are singular on each conductor

# Colour ramp for the signed, normalised field.  With the symmetric [-1, 1] norms used for the heatmaps it runs
# -1 -> white, dark red, 0 -> black, blue, +1 -> white.
blue_red_gradient = ["#FFFFFF", "#AF0000", "#000000", "#0000FF", "#FFFFFF"]
mfield_cmap = colors.LinearSegmentedColormap.from_list('mycmap', blue_red_gradient)

# -- Use 345kV spacing = 20ft, lowest phase to ground = 60ft
Spacing = 20.0                  # ft: vertical distance from phase A up to phase C; phase B sits halfway between
Steps = 600
step = 120.0 / Steps            # ft per grid cell (0.2 ft), matching the X/Y grids below

# Horizontal offset (ft) of each conductor from the pole at x = 0 (A and C at -x_spacing, B at +x_spacing).
# sqrt(3)/4 * Spacing would make A, B and C an equilateral triangle with side Spacing; the value is rounded to an
# integer (9 ft), presumably so the conductors fall exactly on grid points.
x_spacing = int(round(np.sqrt(3) * Spacing / 4, 0))

# Estimate of the peak field (x and y components, then magnitude) intended for normalising.
# NOTE(review): Field_Base is not referenced anywhere else in this file -- getBdata() divides by a hard-coded
# 5.07627293341579 rather than this value (about 5.037).  Confirm which normaliser is intended.
Field_Base_x = -1.0 / step - (Spacing - 2.0 * step) / ((4 * x_spacing) ** 2 + (Spacing - 2 * step) ** 2) - 1.0 / (2.0 * (Spacing - step))
Field_Base_y = x_spacing / ((2 * x_spacing) ** 2 + (Spacing / 2.0 - step) ** 2)
Field_Base = np.sqrt(Field_Base_x ** 2 + Field_Base_y ** 2)

X = np.linspace(-60, 60, 601)   # x from -60 to 60 ft in 0.2 ft steps (601 columns)
Y = np.linspace(-70, 60, 651)   # y from -70 to 60 ft in 0.2 ft steps (651 rows)
X, Y = np.meshgrid(X, Y)        # 2-D coordinate arrays, shape (651, 601)

fig, ax = plt.subplots()                        # the single figure/axes every frame is drawn into
ax.set_aspect('equal', adjustable='box')        # equal data-unit scaling on x and y
# matplotlib_logo = plt.imread('matplotlib_logo_sm_black.png')                                        # read the matplotlib logo
fig.tight_layout(pad=1.5)
# fig.figimage(matplotlib_logo, 456, 400, zorder=1)                                                   # place the matplotlib logo
ax.set_title('Single Pole Magnetic Fields')


def getBdata(x, y, angle):
    """Return the signed, normalised magnetic-field magnitude over a grid for one phase angle.

    Each phase is modelled as a line current perpendicular to the XY plane.  At offset
    (dx, dy) from its conductor it contributes the field
    (-dy, dx) * cos(phase) / (dx**2 + dy**2).  The three phase fields are summed.  The
    result is the magnitude of that sum divided by a fixed normalising constant.  Its sign
    is the sign of the per-phase field magnitudes added up, each weighted by the sign of
    its phase's cosine.

    Conductor positions (ft), using the module-level x_spacing and Spacing:
    A at (-x_spacing, 0), B at (x_spacing, Spacing / 2), C at (-x_spacing, Spacing).

    Args:
        x, y: grid coordinates (ft); arrays that broadcast together, e.g. from np.meshgrid.
        angle: phase-A angle in degrees.  B lags A by 120 degrees, C leads A by 120 degrees.

    Returns:
        ndarray with the broadcast shape of x and y.  Any point where the field is not
        finite (a grid point lying exactly on a conductor, where the 1/r formula divides
        0 by 0) is set to 1.0, the top of the colour scale.
    """
    b_norm = 5.07627293341579   # hard-coded normaliser; presumably the peak on-grid field -- TODO confirm

    angle_a = angle * np.pi / 180.0             # phase A is the reference
    angle_b = angle_a - 2.0 * np.pi / 3.0       # phase B lags by 120
    angle_c = angle_a + 2.0 * np.pi / 3.0       # phase C leads by 120
    conductors = (
        (-x_spacing, 0.0, angle_a),
        (x_spacing, Spacing / 2, angle_b),
        (-x_spacing, Spacing, angle_c),
    )

    Bx = 0.0
    By = 0.0
    signed_sum = 0.0
    # Points exactly on a conductor give 0/0; silence that locally instead of relying on a global np.seterr.
    with np.errstate(divide='ignore', invalid='ignore'):
        for cx, cy, phase in conductors:
            dx = x - cx
            dy = y - cy
            r2 = dx ** 2 + dy ** 2
            current = np.cos(phase)
            bx = -dy * current / r2
            by = dx * current / r2
            Bx = Bx + bx
            By = By + by
            # this phase's magnitude, signed by the direction of its current
            signed_sum = signed_sum + np.hypot(bx, by) * np.sign(current)
        B = np.sign(signed_sum) * np.hypot(Bx, By) / b_norm

    # Pin the singular conductor points to the maximum instead of relying on hard-coded indices,
    # which only matched one particular grid and Spacing.
    return np.where(np.isfinite(B), B, 1.0)


def init_heatmap():
    """Draw the starting heatmap (phase-A angle 0); used as FuncAnimation's init_func.

    Rebinds the module-level ``ax`` to the axes seaborn drew on.  Returns None, which
    FuncAnimation accepts because blitting is not enabled in this script.
    """
    global ax
    z_init = getBdata(X, Y, 0.0)
    # The [-1, 1] range is set on the norm itself: seaborn's heatmap does not forward its own
    # vmin/vmax to matplotlib when a norm is supplied (recent versions), so without this the
    # SymLogNorm would autoscale to the data.
    ax = sb.heatmap(z_init, norm=colors.SymLogNorm(linthresh=0.005, linscale=1.0, vmin=-1.0, vmax=1.0),
                    vmin=-1.0, vmax=1.0, xticklabels=False, yticklabels=False, cmap=mfield_cmap,
                    cbar=False, ax=ax)
    return


def update_heatmap(angle):
    """Redraw the figure for one animation frame.

    FuncAnimation passes the frame number, which is scaled to a phase-A angle of 6 degrees
    per frame.  Each frame is drawn from scratch.  Artists left by the previous frame (or
    by init_heatmap) are removed first.  Then the following are drawn: the field heatmap,
    the pole and support-arm sketch, the ground bands, the dimension text, and y-tick
    labels showing each phase's angle.

    Args:
        angle: frame number supplied by FuncAnimation.
    """
    global ax                                   # rebound to the axes returned by sb.heatmap
    angle = 6 * angle                           # 6 degrees per frame (360 would have 60 steps)
    print(angle, end=' ')                       # rendering progress

    # Angles shown in the tick labels, wrapped into [0, 360) for any frame number.
    angle_a = angle % 360                       # phase A is the reference
    angle_b = (angle + 240) % 360               # phase B lags by 120
    angle_c = (angle + 120) % 360               # phase C leads by 120

    # Remove the previous frame's drawing.  Otherwise every frame would stack another
    # full-size mesh and another copy of every line and label onto the same axes.  Render
    # time would grow each frame, and the alpha=0.6 overlay lines would build up to opaque.
    for artist in [*ax.collections, *ax.lines, *ax.texts]:
        artist.remove()

    z = getBdata(X, Y, angle)                   # field matrix for this angle
    z[0][0] = -1.0                              # pin two row-0 corner cells to the ends of the colour range
    z[0][-1] = 1.0
    # fig.clear()                                                                                   # use this if showing the color bar
    # The [-1, 1] range is set on the norm itself: seaborn's heatmap does not forward its own
    # vmin/vmax to matplotlib when a norm is supplied (recent versions).
    ax = sb.heatmap(z, norm=colors.SymLogNorm(linthresh=0.02, linscale=1.0, vmin=-1.0, vmax=1.0),
                    vmin=-1.0, vmax=1.0, xticklabels=False, yticklabels=False,
                    cmap=mfield_cmap, cbar=False, ax=ax)
    ax.invert_yaxis()                           # make Y increase from bottom to top
    ax.set_xlabel('B = \u03BC\u2080i/2\u03C0r')

    # Pole and support arms, drawn in heatmap cell coordinates (column 300 is x = 0).
    ax.plot([300, 300], [50, 500], color='dimgrey', linewidth=4.0, alpha=0.6)      # vertical pole
    ax.plot([297, 250], [370, 370], color='dimgrey', linewidth=1.5, alpha=0.6)     # phase A support
    ax.plot([303, 350], [420, 420], color='dimgrey', linewidth=1.5, alpha=0.6)     # phase B support
    ax.plot([297, 250], [470, 470], color='dimgrey', linewidth=1.5, alpha=0.6)     # phase C support

    ax.axhline(y=44, color='darkgreen', linewidth=6.0)                              # grass band
    ax.axhline(y=18, color='#52361B', linewidth=19.0)                               # earth band
    # Dimension marks in the earth band under columns 255 and 345 (x = -9 and x = +9 ft).
    ax.text(255, 22, '|', fontsize=6, horizontalalignment='center')
    ax.text(345, 22, '|', fontsize=6, horizontalalignment='center')
    ax.text(300, 22, '17.3 ft', fontsize=6, horizontalalignment='center')
    ax.text(300, 6, 'Log scale used near conductors for visual effect', fontsize=5,
            color='#999999', horizontalalignment='center')
    # Rows 350/400/450 are the grid rows at y = 0, 10 and 20 ft (the A, B and C conductor heights).
    plt.yticks(ticks=[350, 400, 450], labels=['{0:d}\u00B0 A'.format(angle_a),
                                              '{0:d}\u00B0 B'.format(angle_b),
                                              '{0:d}\u00B0 C'.format(angle_c)])
    return


# =================================================================================================
# -- MAIN ---------------- MAIN ---------------- MAIN ---------------- MAIN -----------------------
# =================================================================================================
if __name__ == '__main__':
    # True: preview the animation in a window while developing.
    # False: render it to a video file (MP4; the GIF alternative is left commented out).
    show_progress = False
    mp4_writer = animation.FFMpegWriter(
        fps=6,
        metadata={'artist': '3Phaseee.com'},
        bitrate=1800,
    )
    # gif_writer = animation.ImageMagickWriter(fps=6, metadata=dict(artist='3Phaseee.com'), bitrate=1800)
    ani = animation.FuncAnimation(
        fig,
        update_heatmap,
        init_func=init_heatmap,
        frames=61,          # update_heatmap scales each frame by 6, so frames 0..60 cover 0..360 degrees
        interval=10,        # delay between frames (ms) when previewing
        repeat=False,
    )
    if not show_progress:
        ani.save('Mfield2.mp4', writer=mp4_writer)
        # ani.save('Mfield2.gif', writer=gif_writer)
    else:
        plt.show()
