import numpy as np
import autograd.numpy as anp                 # autograd-traced numpy, used inside the objective

import tidy3d as td
import tidy3d.web as web
from tidy3d.components.mode_spec import ModeSortSpec
from tidy3d.plugins.autograd import rescale, value_and_grad, adam, apply_updates
from tidy3d.plugins.autograd.invdes import (
    make_filter_and_project,        # density -> conic filter -> tanh projection
    make_erosion_dilation_penalty,  # penalizes features that vanish under erode/dilate
)


# ============================================================================
# 1. MATERIALS  (silicon-on-insulator at ~1.55 um)
# ============================================================================
# Refractive indices of the silicon core and the oxide cladding.
n_si = 3.48
n_sio2 = 1.444

# Relative permittivities are the squared indices.
eps_si = n_si**2
eps_sio2 = n_sio2**2

# Constant-permittivity media used for the static structures and the background.
mat_si = td.Medium(permittivity=eps_si)
mat_sio2 = td.Medium(permittivity=eps_sio2)


# ============================================================================
# 2. WAVELENGTH BAND  (100 nm wide, sampled at 11 points for the broadband FOM)
# ============================================================================
# Reference wavelength and the frequency it corresponds to.
wl0 = 1.55
freq0 = td.C_0 / wl0

# Frequency distance between the 1.50 and 1.60 band edges; used as the source pulse
# width so the excitation spans the whole band.
fwidth = td.C_0 / 1.50 - td.C_0 / 1.60

# The 11 sample points the objective averages over: equally spaced in wavelength
# from 1.50 to 1.60, converted to frequency.
freqs = td.C_0 / np.linspace(start=1.50, stop=1.60, num=11)


# ============================================================================
# 3. GEOMETRY  (propagation along +x; chip plane is x-y; SOI thickness along z)
# ============================================================================
# --- Layer stack and waveguide cross-sections ---
h_si = 0.22                 # SOI device-layer thickness (core centered at z = 0)
w_in = 0.45                 # input waveguide width (intended to guide TE0 + TM0 only)
w_bus = 1.6                 # output bus width (intended to support TE0, TE1, TM0, TM1)
Lx_des = 6.0                # design-region extent along x
Ly_des = 6.0                # design-region extent along y

# --- Input ports ---
n_ports = 4
pitch = 1.4                 # center-to-center spacing of the input ports along y
# Port centers ordered top (+y) to bottom (-y): [+2.1, +0.7, -0.7, -2.1].
port_y = np.linspace(
    (n_ports - 1) / 2 * pitch,
    -(n_ports - 1) / 2 * pitch,
    n_ports,
)

# (polarization, mode order) each input port should excite on the output bus,
# listed in the same top -> bottom order as port_y:
#   I1 -> TM1,  I2 -> TM0,  I3 -> TE1,  I4 -> TE0
TARGETS = [
    ("TM", 1),
    ("TM", 0),
    ("TE", 1),
    ("TE", 0),
]

# Simulation domain: the design region plus straight-waveguide leads along x and an
# SiO2 buffer along y and z in front of the PML.
x_des_edge = Lx_des / 2     # 3.0   |x| of the design-region faces
Lx = Lx_des + 2 * 2.0       # 10.0  (2.0 um lead on each side)
Ly = Ly_des + 2 * 1.5       # 9.0   (1.5 um cladding buffer on each side)
Lz = h_si + 2 * 0.89        # ~2.0  (0.89 um cladding buffer above and below)

# Source and monitor planes sit 0.7 um outside the design-region faces.
x_src = -x_des_edge - 0.7   # input mode-source plane
x_out = x_des_edge + 0.7    # output bus mode-monitor plane

# Design resolution: dl_design sets both the parameter grid and the in-plane FDTD mesh.
dl_design = 0.02                              # 20 nm
# Pixel counts come from each axis' own extent, so a non-square design region still
# gets a consistent parameter grid. For the 6.0 x 6.0 region: 300 x 300 = 90,000 pixels.
nx = int(round(Lx_des / dl_design))           # pixels along x
ny = int(round(Ly_des / dl_design))           # pixels along y
filter_radius = 0.08                          # conic-filter radius (80 nm)

# Volume occupied by the optimized permittivity: centered at the origin, one device
# layer thick.
design_box = td.Box(center=(0, 0, 0), size=(Lx_des, Ly_des, h_si))


# ============================================================================
# 4. STATIC (NON-DESIGN) STRUCTURES  (input waveguides + output bus)
# ============================================================================
# Input leads: one silicon strip per port, from x = -Lx/2 - 1.0 (beyond the -x domain
# face, i.e. through the PML) up to the design-region face at x = -x_des_edge.
static_structures = [
    td.Structure(
        geometry=td.Box.from_bounds(
            rmin=(-Lx / 2 - 1.0, y - w_in / 2, -h_si / 2),
            rmax=(-x_des_edge, y + w_in / 2, +h_si / 2),
        ),
        medium=mat_si,
    )
    for y in port_y
]

# Output bus: a single silicon strip centered on y = 0, from the design-region face at
# x = +x_des_edge out to x = Lx/2 + 1.0 (through the +x PML).
static_structures.append(
    td.Structure(
        geometry=td.Box.from_bounds(
            rmin=(x_des_edge, -w_bus / 2, -h_si / 2),
            rmax=(Lx / 2 + 1.0, +w_bus / 2, +h_si / 2),
        ),
        medium=mat_si,
    )
)


# ============================================================================
# 5. DESIGN PARAMETERIZATION  (params -> density -> permittivity Structure)
# ============================================================================
# Differentiable map used by every density evaluation in this file:
#   raw params in [0, 1] --conic filter--> --tanh projection(beta, eta)--> density.
#   beta : projection sharpness. beta=1 ~ continuous (grayscale); large beta -> binary {0,1}.
#   eta  : projection threshold. eta=0.5 is nominal; shifting eta erodes/dilates every edge
#          uniformly, which is how an etch bias is emulated (see ETAS below).
# The callable is built once from the filter radius and the pixel size; beta and eta are
# supplied per call (see get_density).
_filter_project = make_filter_and_project(filter_radius, dl_design)


def get_density(params, beta, eta=0.5):
    """Filter and project the raw parameters into a material density.

    Args:
        params: raw design parameters, handed unchanged to `_filter_project`.
        beta: projection sharpness, forwarded as ``beta``.
        eta: projection threshold, forwarded as ``eta`` (0.5 when omitted).

    Returns:
        The output of `_filter_project`. By this file's convention it is a density in
        [0, 1] with 1 = silicon and 0 = oxide; eta < 0.5 dilates (under-etch, more Si)
        and eta > 0.5 erodes (over-etch, less Si).
    """
    density = _filter_project(params, beta=beta, eta=eta)
    return density


def make_design_structure(params, beta, eta=0.5):
    """Build the design-region Structure from the raw parameters.

    The density from `get_density` is rescaled linearly so that 0 maps to the oxide
    permittivity and 1 to the silicon permittivity, reshaped to ``(nx, ny, 1)``, and
    passed to `td.Structure.from_permittivity_array` with `design_box` as geometry.

    Args:
        params: raw design parameters.
        beta: projection sharpness passed to `get_density`.
        eta: projection threshold passed to `get_density` (0.5 when omitted).

    Returns:
        The Structure produced by `td.Structure.from_permittivity_array`.
    """
    density = get_density(params, beta, eta)
    eps_map = rescale(density, eps_sio2, eps_si)           # 0 -> SiO2, 1 -> Si
    eps_volume = eps_map.reshape((nx, ny, 1))              # add a singleton z axis
    return td.Structure.from_permittivity_array(
        geometry=design_box,
        eps_data=eps_volume,
    )


# ============================================================================
# 6. GRID, MONITORS, SOURCES
# ============================================================================
# Automatic mesh with an enforced override inside the design box: the x and y cell sizes
# are pinned to the pixel size dl_design, while the z entry is None and therefore left
# to the auto mesher.
grid_spec = td.GridSpec.auto(
    wavelength=wl0,
    min_steps_per_wvl=20,
    override_structures=[
        td.MeshOverrideStructure(
            geometry=design_box,
            dl=(dl_design, dl_design, None),
            enforce=True,
        ),
    ],
)

# Output-plane mode specs, one per polarization. Each solves 8 raw modes, then the
# ModeSortSpec filters on "<pol>_fraction" above 0.5, keeps 2 modes and orders them by
# descending n_eff. The intent is that mode_index 0, 1 reliably mean {TE0, TE1}
# (or {TM0, TM1}) regardless of the raw solve order.
_mode_plane_out = (0, w_bus + 2.0, h_si + 1.2)   # zero extent in x; margins around the bus
_mode_spec_out = {
    pol: td.ModeSpec(
        num_modes=8,
        target_neff=n_si,
        sort_spec=ModeSortSpec(
            filter_key=f"{pol}_fraction",
            filter_reference=0.5,
            filter_order="over",
            keep_modes=2,
            sort_key="n_eff",
            sort_order="descending",
        ),
    )
    for pol in ("TE", "TM")
}


def output_monitors():
    """Return the two mode monitors placed on the bus output plane.

    One monitor is created per polarization ("TE", then "TM"), named ``out_te`` and
    ``out_tm``. Both are centered at ``(x_out, 0, 0)``, share the `_mode_plane_out`
    size and the `freqs` sample points, and differ only in their mode spec.

    Returns:
        A list ``[out_te monitor, out_tm monitor]``.
    """
    monitors = []
    for pol in ("TE", "TM"):
        monitors.append(
            td.ModeMonitor(
                name=f"out_{pol.lower()}",
                center=(x_out, 0, 0),
                size=_mode_plane_out,
                freqs=list(freqs),
                mode_spec=_mode_spec_out[pol],
            )
        )
    return monitors


# Size of each input mode-source plane: zero extent in x, at most 1.3 wide in y and never
# wider than (pitch - 0.1) so it stays narrower than the port spacing, and the device
# layer plus 1.2 in z.
_mode_plane_in = (
    0,
    min(1.3, pitch - 0.1),
    h_si + 1.2,
)


def input_source(port_index):
    """Build the ModeSource that excites input port `port_index`.

    The launched polarization is the one listed for this port in `TARGETS`. The code
    assumes that in the 0.45 um lead TE0 is solver mode 0 and TM0 is solver mode 1
    (TE0 having the higher n_eff): a TE port solves one mode and launches index 0, a
    TM port solves two modes and launches index 1.

    Args:
        port_index: index into `TARGETS` / `port_y` (0 = top port).

    Returns:
        A `td.ModeSource` named ``src_<port_index>`` on the plane x = x_src,
        propagating in the "+" direction.
    """
    polarization, _order = TARGETS[port_index]
    if polarization == "TE":
        modes_to_solve, launch_index = 1, 0
    else:
        modes_to_solve, launch_index = 2, 1

    pulse = td.GaussianPulse(freq0=freq0, fwidth=fwidth)
    return td.ModeSource(
        name=f"src_{port_index}",
        center=(x_src, port_y[port_index], 0),
        size=_mode_plane_in,
        source_time=pulse,
        direction="+",
        mode_spec=td.ModeSpec(num_modes=modes_to_solve, target_neff=n_si),
        mode_index=launch_index,
    )


def get_sim(params, beta, port_index, eta=0.5):
    """Assemble the Simulation that excites one input port.

    Args:
        params: raw design parameters for the design region.
        beta: projection sharpness used to build the design structure.
        port_index: which input port carries the source.
        eta: projection threshold used to build the design structure (0.5 when omitted).

    Returns:
        A `td.Simulation` with oxide background, the static waveguides followed by the
        design structure, a single mode source, both output monitors, the shared
        `grid_spec`, and PML on all sides.
    """
    design = make_design_structure(params, beta, eta)
    pml_all_sides = td.BoundarySpec.all_sides(boundary=td.PML())
    return td.Simulation(
        center=(0, 0, 0),
        size=(Lx, Ly, Lz),
        medium=mat_sio2,                                   # background = oxide cladding
        structures=[*static_structures, design],           # design goes last
        sources=[input_source(port_index)],
        monitors=output_monitors(),
        grid_spec=grid_spec,
        boundary_spec=pml_all_sides,
        run_time=td.RunTimeSpec(quality_factor=3.0),
    )


# ============================================================================
# 7. OBJECTIVE  (multi-port + multi-mode + multi-freq, softmin, fab penalty)
# ============================================================================
# Monitor name that holds the amplitudes for each polarization (see output_monitors).
_MON = {"TE": "out_te", "TM": "out_tm"}

# Softmin temperature (small -> sharper worst-case).
TAU = 0.05

# Three etch states per channel, expressed as projection thresholds. The original author
# quotes these eta shifts as emulating a ~+-8 nm edge bias; that figure is not derived
# in this file.
ETAS = {
    "under": 0.4,
    "nominal": 0.5,
    "over": 0.6,
}

# Fabrication penalty built from the same filter radius and pixel size as the design map.
_penalty = make_erosion_dilation_penalty(filter_radius, dl_design)


def objective(params, beta, penalty_weight, etas):
    """Scalar figure of merit to MAXIMIZE: softmin of efficiencies minus fab penalty.

    Args:
        params: raw design parameters (the argument differentiated by `val_grad`).
        beta: projection sharpness used for every simulation and for the penalty.
        penalty_weight: multiplier applied to the fabrication penalty.
        etas: mapping ``variant name -> eta``; one simulation is built per
            (variant, input port) pair.

    Returns:
        ``softmin(efficiencies) - penalty_weight * penalty``, where the penalty is
        evaluated on the nominal (eta = 0.5) density.
    """
    # One simulation per (etch variant, input port), submitted together via web.run.
    tasks = {}
    for variant, eta in etas.items():
        for port in range(n_ports):
            tasks[f"{variant}_p{port}"] = get_sim(params, beta, port, eta=eta)
    results = web.run(tasks, verbose=False)

    def _efficiency(variant, port, pol, order):
        """Mean over the sampled frequencies of |amp|^2 in the target forward mode."""
        amps = results[f"{variant}_p{port}"][_MON[pol]].amps
        target = amps.sel(direction="+", mode_index=order).values
        return anp.mean(anp.abs(target) ** 2)

    # Same ordering as the task dict: variants outermost, ports innermost.
    effs = anp.array([
        _efficiency(variant, port, pol, order)
        for variant in etas
        for port, (pol, order) in enumerate(TARGETS)
    ])

    # Softmin: softmax(-eff / TAU) weights concentrate on the LOWEST efficiencies, so
    # the worst channel/etch combination dominates both the value and the gradient.
    weights = anp.exp(-effs / TAU)
    weights = weights / anp.sum(weights)
    worst_case = anp.sum(weights * effs)

    # Fabrication penalty on the nominal design.
    fab_penalty = _penalty(get_density(params, beta, eta=0.5))
    return worst_case - penalty_weight * fab_penalty


# value_and_grad wraps `objective` so that one call returns BOTH the FOM and its gradient
# with respect to `params` (the first positional argument); the adjoint simulations are
# dispatched automatically. Call it with the same arguments as `objective`:
#     value, grad = val_grad(params, beta, penalty_weight, etas)
val_grad = value_and_grad(objective)


# ============================================================================
# 8. OPTIMIZATION LOOP  (Adam + two-phase beta ramp + phased-in robustness)
# ============================================================================
def run_optimization(num_steps=150, learning_rate=0.1, beta_max=100.0):
    """Run the Adam topology-optimization loop and return the final design. REMOTE / PAID.

    Every step evaluates `val_grad`, which runs the simulations built by `objective`.

    Schedule, with ``frac = step / (num_steps - 1)`` as progress in [0, 1]:
      * Continuous phase (frac < 0.4): beta = 1, nominal etch only (``n_ports`` sims per
        step). The optimizer finds good routing/conversion on a near-grayscale design.
      * Binarization phase (frac >= 0.4): beta ramps linearly 1 -> ``beta_max`` to force
        a binary layout, and all `ETAS` variants switch on (``len(ETAS) * n_ports`` sims
        per step). Deferring the variants is deliberate: at low beta the three etch
        states are nearly identical.
    The penalty weight is ``min(1, beta / 25)``, so it ramps in together with beta.

    Args:
        num_steps: number of optimizer steps. With ``num_steps == 1`` the single step
            runs in the continuous phase; with ``num_steps <= 0`` no step runs and the
            uniform initial parameters are returned.
        learning_rate: Adam learning rate.
        beta_max: projection sharpness reached at the last step.

    Returns:
        ``(params, history)``: the final ``(nx, ny)`` parameter array clipped to
        [0, 1], and the list of FOM values, one per step.
    """
    optimizer = adam(learning_rate=learning_rate)
    params = 0.5 * np.ones((nx, ny))                       # uniform gray initialization
    opt_state = optimizer.init(params)
    history = []

    # Guard the progress denominator: with a single step, (num_steps - 1) is zero and
    # the original division raised ZeroDivisionError.
    last_step = max(num_steps - 1, 1)

    for i in range(num_steps):
        frac = i / last_step                               # progress in [0, 1]

        if frac < 0.4:                                     # continuous phase
            beta, etas = 1.0, {"nominal": 0.5}
        else:                                              # binarization phase
            beta = 1.0 + (beta_max - 1.0) * (frac - 0.4) / 0.6
            etas = ETAS
        penalty_weight = min(1.0, beta / 25.0)             # penalty ramps in with beta

        value, grad = val_grad(params, beta, penalty_weight, etas)

        # Adam minimizes, so pass -grad to MAXIMIZE the FOM; clip keeps params in [0, 1].
        updates, opt_state = optimizer.update(-grad, opt_state, params)
        params = np.clip(apply_updates(params, updates), 0.0, 1.0)

        history.append(float(value))

        # Fraction of pixels that are (nearly) binary after this update; the density is
        # evaluated once and reused for both thresholds.
        density = np.asarray(get_density(params, beta))
        binz = float(((density > 0.95) | (density < 0.05)).mean())
        print(f"step {i+1:3d}/{num_steps}  FOM={value:+.4f}  beta={beta:6.1f}  "
              f"|grad|={np.linalg.norm(grad):.2e}  binarized={binz:.0%}  "
              f"sims={len(etas) * n_ports}")

    return params, history


# ============================================================================
# 9. RUN  (optimize, then save a FOM plot and export the design to GDS)
# ============================================================================
if __name__ == "__main__":
    # Run the full optimization first (default schedule and hyper-parameters).
    params, history = run_optimization()

    # Plot the FOM history with the non-interactive Agg backend and save it to disk.
    import matplotlib

    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    plt.figure(figsize=(6, 4))
    plt.plot(history)
    plt.xlabel("iteration")
    plt.ylabel("FOM (softmin worst-case efficiency)")
    plt.grid(True)
    plt.tight_layout()
    plt.savefig("optimization_progress.png", dpi=130)
    print("wrote optimization_progress.png")

    # Rebuild the port-0 simulation at beta = 100 (the default beta_max used above), then
    # threshold its permittivity at the Si/oxide midpoint on the z = 0 plane and export
    # the result to GDS.
    final_sim = get_sim(params, beta=100.0, port_index=0)
    midpoint_eps = (eps_sio2 + eps_si) / 2
    final_sim.to_gds_file(
        fname="mux4_design.gds",
        z=0.0,
        permittivity_threshold=midpoint_eps,
        frequency=freq0,
    )
    print("wrote mux4_design.gds")
