import mpmath as mp
import matplotlib.pyplot as plt
import time

mp.mp.dps = 60  # High precision for massive exponents

# ==========================================
# CONFIGURATION SWEEP SETTINGS
# ==========================================

# Sweep window selector: False = coarse survey window,
# True = fine window around the previously located minimum.
FINE_MODE = False

if FINE_MODE:
    START_A = mp.mpf("137.0359921000")
    END_A   = mp.mpf("137.0359930000")

    START_OMEGA = mp.mpf("2.0071349535")
    END_OMEGA   = mp.mpf("2.0071349565")
else:
    START_A = mp.mpf("137.035990000")
    END_A   = mp.mpf("137.060000000")

    START_OMEGA = mp.mpf("2.007070")
    END_OMEGA   = mp.mpf("2.007170")

STEP_A  = mp.mpf("1e-10")     # Alpha⁻¹ grid spacing
RES_OMEGA   = 149             # Number of Omega intervals (RES_OMEGA + 1 samples)
QUICK_MODE = True

# Reference Omega for visual comparison
REF_OMEGA = mp.mpf("2.00713495433")    #calculated

# Quick mode: reduce resolution for testing
if QUICK_MODE:
    STEP_A = mp.mpf("1e-7")
    RES_OMEGA = 99
    print("⚡ QUICK MODE: Reduced resolution for faster testing")

# Fail fast on settings that would otherwise divide by zero below, or make
# the scan loops in run_fit_2d never terminate / scan nothing.
if RES_OMEGA < 1:
    raise ValueError(f"RES_OMEGA must be >= 1, got {RES_OMEGA}")
if STEP_A <= 0:
    raise ValueError(f"STEP_A must be positive, got {STEP_A}")
if END_A < START_A:
    raise ValueError("END_A must not be below START_A")
if END_OMEGA <= START_OMEGA:
    raise ValueError("END_OMEGA must be above START_OMEGA")

# Derived once, after QUICK_MODE has settled the final RES_OMEGA.
STEP_OMEGA  = (END_OMEGA - START_OMEGA) / RES_OMEGA

# ==========================================
# LOAD CONSTANTS BY CODATA VERSION
# ==========================================
def load_constants(codata_2014):
    """Return the constants for one CODATA adjustment as a dict of mpf values.

    Args:
        codata_2014: True selects the CODATA 2014 values, False CODATA 2022.

    Returns:
        dict with dataset-specific keys 'e' (elementary charge, C), 'h'
        (Planck constant, J s), 'lambda_e' (electron Compton wavelength, m),
        'R' (Rydberg constant, 1/m), 'ge' (|electron g-factor|), 'ye'
        (presumably gamma_e / 2pi in Hz/T), 'me' (electron mass, kg), 'mu0'
        (vacuum permeability), followed by the shared keys 'c', 'two',
        'three', 'alpha_denom' and the string 'tag' ('2014' or '2022').
        A new dict is built on every call, at the current mpmath precision.
    """
    if codata_2014:
        tag = '2014'
        table = {
            'e': mp.mpf("1.6021766208e-19"),
            'h': mp.mpf("6.626070040e-34"),
            'lambda_e': mp.mpf("2.4263102367e-12"),
            'R': mp.mpf("10973731.568508"),
            'ge': mp.mpf("2.00231930436182"),
            'ye': mp.mpf("28024951640"),
            'me': mp.mpf("9.10938356e-31"),
            # Pre-2019 SI: mu0 was exactly 4*pi*1e-7.
            'mu0': 4 * mp.pi / mp.mpf("1e7"),
        }
    else:
        tag = '2022'
        table = {
            'e': mp.mpf("1.602176634e-19"),
            'h': mp.mpf("6.62607015e-34"),
            'lambda_e': mp.mpf("2.42631023538e-12"),
            'R': mp.mpf("10973731.568157"),
            'ge': mp.mpf("2.00231930436092"),
            'ye': mp.mpf("28024951386.1"),
            'me': mp.mpf("9.1093837139e-31"),
            # Post-2019 SI: mu0 is a measured quantity.
            'mu0': mp.mpf("1.25663706127e-6"),
        }

    # Quantities shared by both datasets (appended after the dataset-specific
    # entries, so the key order is the same for either branch).
    table.update(
        c=mp.mpf("299792458"),
        two=mp.mpf(2),
        three=mp.mpf(3),
        # NOTE(review): 137.0359991390 is the CODATA 2014 value of 1/alpha and
        # is returned for BOTH datasets (CODATA 2022 lists 137.035999177).
        # It is not read by the code in this file -- confirm this is intended.
        alpha_denom=mp.mpf("137.0359991390"),
        tag=tag,
    )
    return table

# ==========================================
# CORE FITTING FUNCTION WITH PRECOMPUTATION
# ==========================================
def run_fit_2d(codata_2014, dimensioned):
    """Brute-force scan of the (alpha⁻¹, Omega) grid for one CODATA dataset.

    For every Omega sample, alpha⁻¹ is scanned over [START_A, END_A] in steps
    of STEP_A. At each grid point six model ratios a1..a6 are evaluated and
    scored by ``residual = 1e7 * sum(|a_i - 1|)``. The best alpha⁻¹ of each
    Omega slice is recorded, as is the global optimum.

    Grid bounds and steps come from the module-level configuration
    (START_A, END_A, STEP_A, START_OMEGA, END_OMEGA, STEP_OMEGA).

    Args:
        codata_2014: True selects the CODATA 2014 constants, False CODATA 2022
            (forwarded to ``load_constants``).
        dimensioned: True uses the formulas with the explicit velocity term
            ``v = c / (2*pi*Omega**2)``. False uses the formulas with c folded
            into the coefficients.

    Returns:
        dict with keys
          'omega'   -- list[float], the scanned Omega values,
          'log_res' -- list[float], log10 of the best residual per Omega,
          'opt_a'   -- list[float], the best alpha⁻¹ per Omega,
          'best_a', 'best_omega', 'best_res' -- mpf global optimum,
          'tag'     -- the CODATA tag string from ``load_constants``.

    Raises:
        ValueError: if a step is not positive or a scan range is empty.
            Previously a zero step hung forever and an empty range crashed
            on ``float(None)``.
    """
    C = load_constants(codata_2014)

    # --- PRECOMPUTE CONSTANT COEFFICIENTS (DONE ONCE PER DATASET) ---
    if dimensioned:
        K1 = (C['R']**7) * (C['two']**260) * (mp.pi**122) * (C['mu0']**9) * (C['three']**21)
        K2 = (C['e']**7) / ((C['two']**82) * (mp.pi**43) * (C['mu0']**3))
        K3 = (C['h']**7) / ((C['two']**164) * (mp.pi**93) * (C['mu0']**13))
        K4 = (C['lambda_e']**7) / ((C['two']**253) * (mp.pi**122) * (C['mu0']**9) * (C['three']**21))
        K5 = (C['me']**7) * (C['two']**96) * (mp.pi**36) * (C['three']**21) / (C['mu0']**4)
        K6 = ((C['ye']/C['ge'])**7) * C['mu0'] / ((C['two']**164) * (mp.pi**72) * (C['three']**21))
    else:
        K1 = (C['R']**7) * (C['mu0']**9) / (C['c']**35) * (C['two']**295) * (mp.pi**157) * (C['three']**21)
        K2 = (C['e']**7) * (C['c']**21) / (C['mu0']**3) / ((C['two']**103) * (mp.pi**64))
        K3 = (C['h']**7) * (C['c']**35) / (C['mu0']**13) / ((C['two']**199) * (mp.pi**128))
        K4 = (C['lambda_e']**7) * (C['c']**35) / (C['mu0']**9) / ((C['two']**288) * (mp.pi**157) * (C['three']**21))
        K5 = (C['me']**7) * (C['c']**7) / (C['mu0']**4) * (C['two']**89) * (mp.pi**29) * (C['three']**21)
        K6 = ((C['ye']/C['ge'])**7) * C['mu0'] * (C['c']**14) / ((C['two']**178) * (mp.pi**86) * (C['three']**21))

    # --- GRID GEOMETRY ---
    # The original loops ran "while x <= END + STEP/2: x += STEP", which gives
    # floor(span/step + 1/2) + 1 samples. The same count is reproduced here,
    # but each point is computed as START + i*STEP. This avoids accumulated
    # drift and cannot loop forever on a zero step.
    if STEP_A <= 0 or STEP_OMEGA <= 0:
        raise ValueError("STEP_A and STEP_OMEGA must be positive")
    half = mp.mpf(1) / 2
    n_a = int(mp.floor((END_A - START_A) / STEP_A + half)) + 1
    n_omega = int(mp.floor((END_OMEGA - START_OMEGA) / STEP_OMEGA + half)) + 1
    if n_a < 1 or n_omega < 1:
        raise ValueError("Empty scan range: END values must not be below START values")

    scale = mp.mpf("1e7")  # residual scale factor, hoisted out of the hot loop

    # Global tracking
    global_best_res = mp.inf
    global_best_a = None
    global_best_omega = None

    # Per-Omega storage. Residuals are kept as mpf so the log below is taken
    # before any float conversion can underflow/overflow them.
    omega_vals = []
    slice_res = []
    best_a_per_omega = []

    print(f"    ▶ Scanning {n_omega} Omega values (each ~{n_a} alpha steps)...")

    start_time = time.time()

    for i_omega in range(n_omega):
        slice_count = i_omega + 1
        omega = START_OMEGA + i_omega * STEP_OMEGA

        # Precompute only the Omega (and v) powers this mode actually uses.
        if dimensioned:
            o33, o80, o89, o122, o155 = (
                omega**33, omega**80, omega**89, omega**122, omega**155
            )
            v = C['c'] / (2 * mp.pi * omega**2)
            v7, v14, v21, v35 = v**7, v**14, v**21, v**35
        else:
            o75, o150, o225 = omega**75, omega**150, omega**225

        # Reset per-Omega tracking
        best_res_slice = mp.inf
        best_a_slice = None

        for i_a in range(n_a):
            a = START_A + i_a * STEP_A
            a10, a12, a13, a15, a25, a26 = a**10, a**12, a**13, a**15, a**25, a**26

            if dimensioned:
                a1 = K1 * a26 * o155 / v35
                a2 = K2 * a10 * v21 / o33
                a3 = K3 * a13 * v35 / o80
                a4 = K4 * v35 / (o155 * a12)
                a5 = K5 * o89 * a25 * v7
                a6 = K6 * v14 / (o122 * a15)
            else:
                a1 = K1 * a26 * o225
                a2 = K2 * a10 / o75
                a3 = K3 * a13 / o150
                a4 = K4 / (a12 * o225)
                a5 = K5 * a25 * o75
                a6 = K6 / (a15 * o150)

            # a0 fixed to 1.0 → abs(a0-1) = 0, omitted for speed
            res = scale * (abs(a1-1) + abs(a2-1) + abs(a3-1) + abs(a4-1) + abs(a5-1) + abs(a6-1))

            # Strict "<" keeps the first alpha on ties.
            if res < best_res_slice:
                best_res_slice = res
                best_a_slice = a

        # Store slice results
        omega_vals.append(float(omega))
        slice_res.append(best_res_slice)
        best_a_per_omega.append(float(best_a_slice))

        # Update global optimum
        if best_res_slice < global_best_res:
            global_best_res = best_res_slice
            global_best_a = best_a_slice
            global_best_omega = omega

        # Progress update every 20 slices
        if slice_count % 20 == 0:
            elapsed = time.time() - start_time
            eta = elapsed * n_omega / slice_count - elapsed
            print(f"      ✓ {slice_count}/{n_omega} Omega slices | ETA: {eta/60:.1f} min")

    # Log scaling for plotting. An exact-zero residual (log10 = -inf) is
    # replaced by a value 10 decades below the smallest nonzero one.
    nonzero = [r for r in slice_res if r > 0]
    eps = min(nonzero) * mp.mpf("1e-10") if nonzero else mp.mpf("1e-300")
    log_res = [float(mp.log10(r if r > 0 else eps)) for r in slice_res]

    return {
        'omega': omega_vals,
        'log_res': log_res,
        'opt_a': best_a_per_omega,
        'best_a': global_best_a,
        'best_omega': global_best_omega,
        'best_res': global_best_res,
        'tag': C['tag']
    }

# ==========================================
# PLOTTING FUNCTION (2-panel layout)
# ==========================================
def plot_comparison(results, mode_name, annotate_best=False):
    """Generate 2-panel comparison plot for a given mode (dimensioned/dimensionless)

    Left panel: best log10 residual vs Omega. Right panel: optimal alpha⁻¹
    vs Omega. Both CODATA datasets are overlaid, with REF_OMEGA marked.
    A best-fit summary table is printed after the figure is shown.

    Args:
        results: dict holding the ``run_fit_2d`` outputs under the keys
            ``f"{mode_name}_2014"`` and ``f"{mode_name}_2022"``.
        mode_name: key prefix, e.g. 'nondim' or 'dim'. Also used in the titles.
        annotate_best: if True, annotate each dataset's best-fit point on the
            left panel and draw a dotted line at its best Omega. Off by
            default, matching the previous (disabled) behaviour.

    Raises:
        KeyError: if either dataset for ``mode_name`` is missing from results.
    """
    r1, r2 = results[f"{mode_name}_2014"], results[f"{mode_name}_2022"]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

    # --- Left panel: Omega vs Minimum Residual ---
    ax1.plot(r1['omega'], r1['log_res'], label=f"CODATA {r1['tag']}", linewidth=1.5, color='navy')
    ax1.plot(r2['omega'], r2['log_res'], label=f"CODATA {r2['tag']}", linewidth=1.5, color='crimson', linestyle='--')

    # Bold full-height reference line
    ax1.axvline(float(REF_OMEGA), ymin=0, ymax=1, color='black', linestyle='-',
                linewidth=3, alpha=0.8, label=f'Reference Ω = {float(REF_OMEGA):.11f}', zorder=5)

    if annotate_best:
        for r, color in ((r1, 'navy'), (r2, 'crimson')):
            # mpf values are converted to float: format specs on mpf are not
            # supported by all mpmath versions.
            ax1.annotate(f"Best α⁻¹: {float(r['best_a']):.10f}\nBest Ω: {float(r['best_omega']):.12f}",
                         xy=(float(r['best_omega']), min(r['log_res'])),
                         xytext=(10, -30), textcoords='offset points',
                         bbox=dict(boxstyle='round,pad=0.3', facecolor=color, alpha=0.1),
                         color=color, fontsize=8, ha='left')
            # Subtle marker at best Omega
            ax1.axvline(float(r['best_omega']), color=color, linestyle=':', alpha=0.3)

    ax1.set_xlabel('Omega', fontsize=10)
    ax1.set_ylabel('Minimum log₁₀(Residual)', fontsize=10)
    ax1.set_title('Omega vs Best-Fit Residual', fontsize=11, fontweight='semibold')
    ax1.legend(fontsize=9)
    ax1.grid(True, alpha=0.3, linestyle=':')

    # --- Right panel: Omega vs Optimal α⁻¹ (uniqueness check) ---
    ax2.plot(r1['omega'], r1['opt_a'], label=f"CODATA {r1['tag']}", linewidth=1.5, color='navy', marker='o', markersize=2)
    ax2.plot(r2['omega'], r2['opt_a'], label=f"CODATA {r2['tag']}", linewidth=1.5, color='crimson', linestyle='--', marker='s', markersize=2)

    ax2.axvline(float(REF_OMEGA), ymin=0, ymax=1, color='black', linestyle='-',
                linewidth=2, alpha=0.6, zorder=5)

    ax2.axhline(float(r1['best_a']), color='navy', linestyle=':', alpha=0.4)
    ax2.axhline(float(r2['best_a']), color='crimson', linestyle=':', alpha=0.4)

    ax2.set_xlabel('Omega', fontsize=10)
    ax2.set_ylabel('Optimal α⁻¹', fontsize=10)
    ax2.set_title('Omega vs Optimal α⁻¹ (uniqueness check)', fontsize=11, fontweight='semibold')
    ax2.legend(fontsize=9)
    ax2.grid(True, alpha=0.3, linestyle=':')

    # Deviation of each best Omega from the reference value
    dev_2014 = float(r1['best_omega'] - REF_OMEGA)
    dev_2022 = float(r2['best_omega'] - REF_OMEGA)
    ax2.text(0.02, 0.98, f"ΔΩ₂₀₁₄ = {dev_2014:+.2e}\nΔΩ₂₀₂₂ = {dev_2022:+.2e}",
             transform=ax2.transAxes, fontsize=8,
             bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.5),
             verticalalignment='top')

    fig.suptitle(f'{mode_name.capitalize()} Formulas: CODATA 2014 vs 2022',
                 fontsize=13, fontweight='bold', y=1.02)
    fig.tight_layout()
    plt.show()

    # Print summary. The last column is the raw best residual, not its log.
    # Values are converted to float because mpf format specs are not
    # supported by every mpmath version.
    print(f"\n📊 {mode_name.upper()} BEST-FIT SUMMARY:")
    print(f"{'CODATA':<8} {'α⁻¹ (best)':<20} {'Ω (best)':<20} {'Residual':<12}")
    print("-"*60)
    for tag in ['2014', '2022']:
        r = results[f"{mode_name}_{tag}"]
        print(f"{r['tag']:<8} {float(r['best_a']):<20.12f} {float(r['best_omega']):<20.12f} {float(r['best_res']):<12.3e}")
    print()

# ==========================================
# MAIN EXECUTION: SEQUENTIAL MODES
# ==========================================
print("🚀 Starting sequential alpha/Omega optimization")
print(f"   Alpha range: {START_A} → {END_A} (step={STEP_A})")
print(f"   Omega range: {START_OMEGA} → {END_OMEGA} ({RES_OMEGA+1} steps)")
print(f"   Reference Ω: {REF_OMEGA}")
print(f"   QUICK_MODE:  {QUICK_MODE}")
print()

# Row labels for the final cross-reference table.
MODE_LABELS = {'nondim': "Dimensionless", 'dim': "Dimensioned"}


def _run_phase(results, mode, dimensioned):
    """Fit both CODATA datasets for one formula family and store them in *results*.

    The entries are stored under ``f"{mode}_2014"`` and ``f"{mode}_2022"``,
    the keys that ``plot_comparison`` and the final summary look up.
    """
    for codata_flag in (True, False):
        key = f"{mode}_{'2014' if codata_flag else '2022'}"
        print(f"  • Processing {key}...")
        start = time.time()
        r = run_fit_2d(codata_flag, dimensioned=dimensioned)
        results[key] = r
        elapsed = time.time() - start
        # float(): format specs on mpf are not supported by all mpmath versions.
        print(f"    ✅ {key} complete | Time: {elapsed/60:.1f} min | Best res: {float(r['best_res']):.3e}")


# --- PHASE 1: DIMENSIONLESS FORMULAS (hours at full resolution) ---
print("🔹 PHASE 1/2: Running DIMENSIONLESS formulas...")
results = {}
_run_phase(results, 'nondim', dimensioned=False)

# Plot dimensionless results immediately
plot_comparison(results, mode_name='nondim')
print("💡 Dimensionless results ready! Cross-reference with dimensioned run next.\n")

# --- PHASE 2: DIMENSIONED FORMULAS (hours at full resolution) ---
print("🔹 PHASE 2/2: Running DIMENSIONED formulas...")
_run_phase(results, 'dim', dimensioned=True)

# Plot dimensioned results for cross-reference
plot_comparison(results, mode_name='dim')

# --- FINAL COMPARISON SUMMARY ---
# Mode column is 14 wide so "Dimensionless" (13 chars) does not shift the
# row. The last column is the raw best residual, not its log.
print("="*70)
print("🎯 FINAL CROSS-REFERENCE SUMMARY")
print("="*70)
print(f"{'Mode':<14} {'CODATA':<8} {'α⁻¹ (best)':<20} {'Ω (best)':<20} {'Residual':<12}")
print("-"*70)
for mode in ('nondim', 'dim'):
    for tag in ('2014', '2022'):
        r = results[f"{mode}_{tag}"]
        print(f"{MODE_LABELS[mode]:<14} {r['tag']:<8} {float(r['best_a']):<20.12f} "
              f"{float(r['best_omega']):<20.12f} {float(r['best_res']):<12.3e}")
print("="*70)

# Refinement guidance
print("\n💡 To refine further:")
print("   1. Narrow Omega range around best Ω found above")
print("   2. Increase RES_OMEGA for finer sampling")
print("   3. Reduce STEP_A to 1e-13 for ultra-fine alpha tuning")
print("   4. Re-run to zoom into the global minimum")
