#!python3
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sb
import matplotlib.colors as colors
from matplotlib import animation

np.seterr(divide='ignore', invalid='ignore')                                                        # conductor centres produce 0/0; silence the warnings (those cells become NaN)

# Colour ramp for the signed field: white -> green -> black -> yellow -> white.  The five colours
# are evenly spaced, so black sits in the middle of the colormap.  All entries are '#RRGGBB' hex
# strings (a bare "000000" is only accepted by matplotlib because it parses as a grey level).
green_yellow_gradient = ["#FFFFFF", "#70AD47", "#000000", "#FFFF00", "#FFFFFF"]
efield_cmap = colors.LinearSegmentedColormap.from_list('mycmap', green_yellow_gradient)

# -- Use 345kV spacing = 20ft, lowest phase to ground = 60ft
Spacing = 20.0                                                                                      # phase spacing (ft): phases A, B, C at x = -Spacing, 0, +Spacing
Steps = 600                                                                                         # number of grid intervals across the 120 ft wide x range
Height = 60                                                                                         # conductor height above ground (ft)
step = 120.0 / Steps                                                                                # grid resolution (ft per cell)
Field_Base = 1.0 / step + 1.0 / (2 * (Spacing - step)) + 1.0 / (2 * (2 * Spacing - step))           # estimated field one step from A phase, ignoring images (reference only; getEdata normalizes by a tuned constant)

X = np.linspace(-60, 60, Steps + 1)                                                                 # x coordinates, one per grid step (601 points)
Y = np.linspace(-70, 60, 651)                                                                       # y coordinates at the same resolution; conductors lie at y = 0
X, Y = np.meshgrid(X, Y)                                                                            # 2-D coordinate grids (rows follow y, columns follow x)

fig, ax = plt.subplots()                                                                            # initialize the figure and its single axes
ax.set_aspect('equal', adjustable='box')                                                            # equal x/y scaling so distances are not distorted
# matplotlib_logo = plt.imread('matplotlib_logo_sm_black.png')                                        # read the matplotlib logo
fig.tight_layout(pad=1.5)                                                                           # give the figure a tight layout
# fig.figimage(matplotlib_logo, 456, 400, zorder=1)                                                   # place the matplotlib logo
plt.title('H-Frame  Electric Fields')                                                               # plot title


def _phase_field(dx, y, cos_phi, height):
    """Return (Ex, Ey) produced by one phase conductor plus its ground-plane image.

    dx, y   -- horizontal / vertical offsets of the field points from the conductor
    cos_phi -- cosine of this phase's angle (the conductor's instantaneous charge factor)
    height  -- conductor height above ground; the image charge has the opposite sign
               and sits 2 * height below the conductor
    Each charge contributes offset / r**2 (an unscaled 2-D line-charge field).
    """
    y_img = y + 2 * height                                                                          # vertical offset from the image conductor
    r2 = dx ** 2 + y ** 2                                                                           # squared distance to the overhead conductor
    r2_img = dx ** 2 + y_img ** 2                                                                   # squared distance to the image conductor
    ex = dx * cos_phi / r2 - dx * cos_phi / r2_img                                                  # overhead minus (opposite-charge) image, x
    ey = y * cos_phi / r2 - y_img * cos_phi / r2_img                                                # overhead minus (opposite-charge) image, y
    return ex, ey


def getEdata(x, y, d, angle, height=None):
    """Return the normalized, signed electric-field map for one phase angle.

    x, y   -- field-point coordinates (e.g. the meshgrid X, Y); the conductors sit at
              (-d, 0), (0, 0) and (d, 0) for phases A, B and C
    d      -- horizontal phase spacing
    angle  -- phase-A angle in degrees; phase B lags and phase C leads A by 120 degrees
    height -- conductor height above ground; defaults to the module-level Height

    Returns an array shaped like the broadcast of x and y.  Each value is the total field
    magnitude divided by a fixed normalization constant. Its sign is the sign of the
    charge-weighted sum of the per-phase magnitudes.  Grid points that coincide with a conductor (where the
    field is singular) are set to 1.0, wherever they fall in the grid.
    """
    if height is None:
        height = Height                                                                             # module-level conductor height
    angle_a = angle * np.pi / 180.0                                                                 # phase A is the reference (radians)
    angle_b = angle_a - 2.0 * np.pi / 3.0                                                           # phase B lags by 120
    angle_c = angle_a + 2.0 * np.pi / 3.0                                                           # phase C leads by 120
    cos_a, cos_b, cos_c = np.cos(angle_a), np.cos(angle_b), np.cos(angle_c)

    # 0/0 at the conductor centres is expected; keep it quiet independent of global np.seterr.
    with np.errstate(divide='ignore', invalid='ignore'):
        eax, eay = _phase_field(x + d, y, cos_a, height)                                            # phase A at x = -d
        ebx, eby = _phase_field(x, y, cos_b, height)                                                # phase B at x = 0
        ecx, ecy = _phase_field(x - d, y, cos_c, height)                                            # phase C at x = +d

    # Polarity heuristic: weight each phase's field magnitude by the sign of its charge
    # and take the sign of the sum.
    contribution = (np.sign(cos_a) * np.sqrt(eax ** 2 + eay ** 2)
                    + np.sign(cos_b) * np.sqrt(ebx ** 2 + eby ** 2)
                    + np.sign(cos_c) * np.sqrt(ecx ** 2 + ecy ** 2))
    polarity = np.sign(contribution)

    ex = eax + ebx + ecx                                                                            # total field, x component
    ey = eay + eby + ecy                                                                            # total field, y component
    mag_e = np.sqrt(ex ** 2 + ey ** 2)                                                              # total field magnitude

    # Tuned normalization constant, slightly smaller than the analytic Field_Base estimate
    # because of the images (presumably the peak found on the default grid -- confirm).
    E = polarity * mag_e / 5.035887206902947

    # Force the conductor centres to the maximum, located by coordinate, not by array index.
    on_line = np.isclose(y, 0.0)
    on_conductor = on_line & (np.isclose(x, -d) | np.isclose(x, 0.0) | np.isclose(x, d))
    return np.where(on_conductor, 1.0, E)                                                           # matrix of electric fields for all coordinates


def init_heatmap():
    """Draw the angle = 0 field map as the animation's initial frame.

    Rebinds the module-level ``ax`` to the axes returned by seaborn.
    """
    global ax
    field_at_zero = getEdata(X, Y, Spacing, 0.0)                                                    # fields with phase A at 0 degrees
    log_norm = colors.SymLogNorm(linthresh=0.005, linscale=1.0)                                     # linear near zero, logarithmic beyond
    ax = sb.heatmap(
        field_at_zero,
        norm=log_norm,
        vmin=-1.0,
        vmax=1.0,
        xticklabels=False,
        yticklabels=False,
        cmap=efield_cmap,
        cbar=False,
    )


def update_heatmap(angle):
    """Draw one animation frame.

    angle -- frame number supplied by FuncAnimation; phase A advances 6 degrees per frame.

    The previous frame's artists (heatmap mesh, pole/cross-arm/ground lines and texts) are
    removed first.  Otherwise every frame stacks on the earlier ones: the alpha=0.6 pole
    lines turn opaque and every old mesh is re-rendered on each new frame.
    """
    global ax                                                                                       # declare global
    angle = 6 * angle                                                                               # step angle by 6 degrees (360 would have 60 steps)
    print(angle, end=' ')                                                                           # print angle progress
    angle_a = angle % 360                                                                           # phase a is the reference angle
    angle_b = (angle + 240) % 360                                                                   # phase b lags by 120
    angle_c = (angle + 120) % 360                                                                   # phase c leads by 120

    for artist in [*ax.collections, *ax.lines, *ax.texts]:                                         # drop the previous frame's drawing
        artist.remove()

    z = getEdata(X, Y, Spacing, angle)                                                              # get field matrix
    z[0][0] = -1.0                                                                                  # pin top-left cell to max negative so the colour range spans -1..1
    z[0][-1] = 1.0                                                                                  # pin top-right cell to max positive
    # fig.clear()                                                                                   # use this if showing the color bar
    ax = sb.heatmap(z, norm=colors.SymLogNorm(linthresh=0.02, linscale=1.0), vmin=-1.0, vmax=1.0,
                    xticklabels=False, yticklabels=False, cmap=efield_cmap, cbar=False)             # generate heatmap for this loops angle
    ax.invert_yaxis()                                                                               # make Y increase from bottom to top
    plt.ylabel('E = 60i/r')                                                                         # this plot y axis title
    ax.plot([250, 250], [50, 390], color='dimgrey', linewidth=2.5, alpha=0.6)                       # left vertical pole (use line to illustrate)
    ax.plot([350, 350], [50, 390], color='dimgrey', linewidth=2.5, alpha=0.6)                       # right vertical pole (use line to illustrate)
    ax.plot([190, 410], [365, 365], color='dimgrey', linewidth=2.0, alpha=0.6)                      # horizontal cross-arm (use line to illustrate)
    plt.axhline(y=44, color='darkgreen', linewidth=6.0)                                             # grass area (use wide line to illustrate)
    plt.axhline(y=18, color='#52361B', linewidth=19.0)                                              # earth area (use very wide line to illustrate)
    ax.text(250, 22, '|', fontsize=6, horizontalalignment='center')                                 # put dimension text in the earth area
    ax.text(350, 22, '|', fontsize=6, horizontalalignment='center')                                 # put dimension text in the earth area
    ax.text(300, 22, '20 ft', fontsize=6, horizontalalignment='center')                             # put dimension text in the earth area
    ax.text(300, 6, 'Log scale used near conductors for visual effect', fontsize=5,
            color='#999999', horizontalalignment='center')                                          # put explanatory note in the earth area
    plt.xticks(ticks=[200, 300, 400], labels=['A\n{0:3d}\u00B0'.format(angle_a),
                                              'B\n{0:3d}\u00B0'.format(angle_b),
                                              'C\n{0:3d}\u00B0'.format(angle_c)])                   # x axis dynamic update
    return


# =================================================================================================
# -- MAIN ---------------- MAIN ---------------- MAIN ---------------- MAIN -----------------------
# =================================================================================================
if __name__ == '__main__':
    # True while developing (live window); False to render the animation to a video file.
    show_progress = False
    ani = animation.FuncAnimation(fig, update_heatmap, init_func=init_heatmap,
                                  frames=61, interval=10, repeat=False)
    if not show_progress:
        mp4_writer = animation.FFMpegWriter(fps=6, metadata={'artist': '3Phaseee.com'}, bitrate=1800)
        ani.save('Efield1.mp4', writer=mp4_writer)
        # GIF alternative:
        # gif_writer = animation.ImageMagickWriter(fps=6, metadata={'artist': '3Phaseee.com'}, bitrate=1800)
        # ani.save('Efield1.gif', writer=gif_writer)
    else:
        plt.show()
