import scipy, numpy as np

def ellipse_k_m(a, b):
    """Return ``(k, m)`` describing arclength on the ellipse with semi-axes ``a`` and ``b``.

    For the parametrisation ``(a*cos(t), b*sin(t))`` the speed along the curve is
    ``sqrt(a**2*sin(t)**2 + b**2*cos(t)**2) == b*sqrt(1 - m*sin(t)**2)`` with
    ``m = 1 - (a/b)**2``. Arclength from ``t = 0`` is therefore ``k * E(t, m)``
    with ``k = b``, where ``E`` is ``scipy.special.ellipeinc``.
    """
    k = b
    m = 1 - (a / b) ** 2
    return k, m

def ellipse_arclength(k, m, theta1, theta2):
    """Return the ellipse arclength between parameter angles ``theta1`` and ``theta2``.

    The result is ``k * (E(theta2, m) - E(theta1, m))``, where ``E`` is the
    incomplete elliptic integral of the second kind (``scipy.special.ellipeinc``).
    ``k`` and ``m`` are a scale and parameter such as those returned by
    ``ellipse_k_m``. The thetas may be scalars or numpy arrays.
    """
    # Background on E(phi, m), including its definition in terms of other functions:
    #  https://functions.wolfram.com/EllipticIntegrals/EllipticE2/introductions/IncompleteEllipticIntegrals/ShowAll.html
    # theta1 == 0 is deliberately not special-cased: E(0, m) is 0, and ellipeinc
    # is trusted to handle that at least as cheaply as a Python-level check would.
    end = scipy.special.ellipeinc(theta2, m)
    start = scipy.special.ellipeinc(theta1, m)
    return k * (end - start)

def ellipse_arclength_per_theta(k, m, theta):
    """Return d/dtheta of the ellipse arclength ``k * ellipeinc(theta, m)``.

    ``scipy.special.ellipeinc(phi, m)`` is the integral of
    ``sqrt(1 - m*sin(t)**2)`` from 0 to ``phi``, so its derivative is that
    integrand evaluated at ``phi``. (``1/sqrt(1 - m*sin(phi)**2)`` would be the
    derivative of the *first*-kind integral ``ellipkinc``, which is not what
    ``ellipse_arclength`` uses.)

    ``theta`` may be a scalar or a numpy array; the result is elementwise.
    """
    return k * np.sqrt(1 - m * np.sin(theta)**2)

def linspace_interval(min, max, interval, endpoint=True):
    """Return evenly spaced samples from ``min`` to ``max`` at most ``interval`` apart.

    The range is divided into the smallest number of equal steps whose length
    does not exceed ``interval``. With ``endpoint=False`` the final sample
    (``max``) is dropped, so results for adjacent ranges can be concatenated
    without duplicating the shared boundary. A zero-length range yields
    ``[min]``, or an empty array when ``endpoint=False``.
    """
    # NOTE: ``min``/``max`` shadow the builtins; the names are kept so existing
    # keyword callers keep working. Don't call the builtins in this body.
    steps = int(np.ceil((max - min) / interval))
    # np.linspace's ``num`` counts samples, not steps: n steps need n + 1 samples.
    samples = np.linspace(min, max, steps + 1)
    return samples if endpoint else samples[:-1]

def solve_eqn(variable, min, max, equation, **variables):
    """Solve ``equation`` for ``variable`` inside the bracket between ``min`` and ``max``.

    ``equation`` is a string ``"lhs = rhs"`` containing a single ``=``. Both
    sides are Python expressions evaluated with this module's globals plus
    ``variables``, with the trial value bound to the name ``variable``. The
    root of ``lhs - rhs`` is located by false position with the Illinois
    modification.

    Returns the value of ``variable`` at which ``|lhs - rhs| <= 1e-10``. If
    floating point stops the bracket from shrinking, or after 1000 iterations,
    it returns the best estimate reached.

    Raises ValueError if ``lhs - rhs`` has the same sign at both ends of the
    bracket (no root is bracketed).
    """
    # NOTE(review): eval() executes arbitrary code -- only pass trusted,
    # program-authored equations here, never external input.
    expr1, expr2 = equation.split('=')

    def eval_(val):
        # ``variables`` is this call's own kwargs dict, so mutating it is local.
        variables[variable] = val
        return eval(expr1, globals(), variables) - eval(expr2, globals(), variables)

    # Kept as locals rather than parameters: any new keyword would be captured
    # by ``**variables`` and could collide with callers' equation variables.
    tolerance = 1e-10
    max_iterations = 1000

    lo, hi = min, max
    lo_v, hi_v = eval_(lo), eval_(hi)
    if abs(lo_v) <= tolerance:
        return lo
    if abs(hi_v) <= tolerance:
        return hi
    if (lo_v > 0) == (hi_v > 0):
        raise ValueError(
            f'{equation!r} has no sign change for {variable} between {min} and {max}'
        )
    # Orient the bracket so the residual is negative at lo and positive at hi.
    if lo_v > hi_v:
        lo, hi = hi, lo
        lo_v, hi_v = hi_v, lo_v

    inp = lo
    last_moved = 0  # +1 if hi was replaced on the previous step, -1 if lo was
    for _ in range(max_iterations):
        inp = lo + (hi - lo) * (0 - lo_v) / (hi_v - lo_v)
        v = eval_(inp)
        # Stop on a small residual, or when the estimate lands on a bracket end
        # (the bracket can no longer shrink in floating point).
        if abs(v) <= tolerance or inp == lo or inp == hi:
            return inp
        # Plain false position keeps one endpoint forever on a convex/concave
        # residual, so the endpoint residuals never both approach zero. The
        # Illinois step halves the stale endpoint's residual when the same side
        # moves twice in a row, forcing the other end to move too.
        if v > 0:
            hi, hi_v = inp, v
            if last_moved == 1:
                lo_v /= 2
            last_moved = 1
        else:
            lo, lo_v = inp, v
            if last_moved == -1:
                hi_v /= 2
            last_moved = -1
    return inp


def sewing_points(spheroid_rx, spheroid_ry, panel_count, offset, resolution):
    """Return ``(X, Y)`` outline points for a flat pattern of ``panel_count`` spheroid panels.

    The spheroid has equatorial radius ``spheroid_rx`` and polar radius
    ``spheroid_ry``. The equatorial radius sets the circumference
    ``2*pi*spheroid_rx*cos(theta)`` at parametric latitude ``theta``.

    - Each panel spans one ``1/panel_count`` slice of that circumference.
      Panels are laid side by side, with panel ``i`` centred at
      ``x = i * 2*pi*spheroid_rx / panel_count``.
    - ``Y`` is the meridian arclength from the equator and ``X`` is the panel
      half-width at that latitude.
    - Every edge is pushed outward along its normal by ``offset``
      (presumably a seam allowance -- confirm).
    - Where the offset edges of neighbouring panels overlap near the equator,
      only the outer outline is kept.
    - ``resolution`` is the approximate point spacing along the edges.

    Returns two 1-D arrays. They trace the half for ``theta`` in
    ``[0, pi/2]`` panel by panel, then the same pieces mirrored about
    ``Y = 0`` in reverse order.
    """
    panel_ellipse_width = spheroid_rx * 2 * np.pi / panel_count
    panel_ellipse_rx = panel_ellipse_width / 2
    spheroid_k, spheroid_m = ellipse_k_m(spheroid_rx, spheroid_ry)

    # Latitude at which neighbouring panels' offset edges meet: the offset
    # edge's X reaches the panel half-width panel_ellipse_rx (the midpoint
    # between panel centres). The meridian slope comes from
    # ellipse_arclength_per_theta, so this solve always uses the same normal
    # as the offset applied to the outline below. The search bracket ends
    # just short of pi.
    panel_intersection_sphere_theta = solve_eqn(
        'theta', 0, 3.14,
        '''offset * ellipse_arclength_per_theta(spheroid_k, spheroid_m, theta) / np.sqrt(
            ellipse_arclength_per_theta(spheroid_k, spheroid_m, theta)**2 +
            (np.pi*spheroid_rx*np.sin(theta)/panel_count)**2
        ) = spheroid_rx*np.pi*(1 - np.cos(theta)) / panel_count''',
        offset=offset,
        spheroid_k=spheroid_k,
        spheroid_m=spheroid_m,
        spheroid_rx=spheroid_rx,
        panel_count=panel_count,
    )
    # NOTE(review): if the edges only meet beyond pi/2 (large offset),
    # sphere_thetas_main below gets an inverted range -- confirm intended use.

    # Rather than inverting ellipeinc, the angular step assumes roughly constant
    # curvature with the mean radius: s = theta * r  =>  theta = s / r.
    d_sphere_theta = resolution * 2 / (spheroid_rx + spheroid_ry)
    sphere_thetas_edge = linspace_interval(0, panel_intersection_sphere_theta, d_sphere_theta, endpoint=False)
    sphere_thetas_main = linspace_interval(panel_intersection_sphere_theta, np.pi/2, d_sphere_theta)
    sphere_thetas = np.concatenate([sphere_thetas_edge, sphere_thetas_main])
    panel_intersection_idx = len(sphere_thetas_edge)

    # Y: meridian arclength from the equator, and its slope w.r.t. theta.
    Y = ellipse_arclength(spheroid_k, spheroid_m, 0, sphere_thetas)
    dY = ellipse_arclength_per_theta(spheroid_k, spheroid_m, sphere_thetas)
    # X: half of this panel's share of the circumference at latitude theta.
    # theta is the parametric angle, so the circle formula scaled by rx applies.
    X = np.cos(sphere_thetas) * spheroid_rx * np.pi / panel_count
    dX = -np.sin(sphere_thetas) * spheroid_rx * np.pi / panel_count

    # Push the edge outward by `offset` along the unit normal (dY, -dX) / |(dX, dY)|.
    d = np.sqrt(dY**2 + dX**2)
    Y += -dX * offset / d
    X += dY * offset / d

    Xs = []
    Ys = []
    for panel_idx in range(panel_count):
        centre = panel_ellipse_width * panel_idx
        # Below the intersection, neighbouring panels overlap. Only the first
        # panel keeps its full left edge and only the last its full right edge;
        # interior edges start/stop at the intersection index. The
        # panel_intersection_idx == 0 case needs None, because a stop of -1
        # would turn the reversed slice into an empty one.
        min_idx_left = 0 if panel_idx == 0 else panel_intersection_idx
        min_idx_right = None if panel_idx + 1 == panel_count or panel_intersection_idx == 0 else panel_intersection_idx - 1
        Ys.extend([Y[min_idx_left:], Y[:min_idx_right:-1]])
        Xs.extend([centre - X[min_idx_left:], centre + X[:min_idx_right:-1]])
    # Other half: mirror about Y = 0, walking the existing pieces backwards.
    # The comprehensions are fully built before extend() mutates the lists.
    Xs.extend([piece[::-1] for piece in reversed(Xs)])
    Ys.extend([-piece[::-1] for piece in reversed(Ys)])
    return np.concatenate(Xs), np.concatenate(Ys)

#import matplotlib.pyplot as plt
#plt.gca().set_aspect('equal', adjustable='box')
#sewing_points(1,1,3,0,0.1)
#sewing_points(1,1,3,0.5,0.1)
