"""
Fractal Design System — Matplotlib Color Utilities

Provides color palettes, sequential colormaps, and helpers that match
the Fractal data-visualization guide.

Usage:
    from fractal_colors import apply_fractal_style, PALETTES, register_colormaps

    # Full setup: load .mplstyle + register colormaps
    apply_fractal_style()

    # Or use palettes directly
    from fractal_colors import PINK, BLUE, ORANGE
    ax.plot(x, y, color=PINK[500])

    # Sequential colormaps (after register_colormaps())
    ax.imshow(data, cmap='fractal_blue')
"""

from pathlib import Path
from collections import OrderedDict

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np

# =========================================
#  Color Scales (50–900 stops; GREY also has a 0 stop)
# =========================================
# Each scale is an OrderedDict keyed by stop number.

# Neutral scale. Unlike the hue scales it also defines a 0 stop.
GREY = OrderedDict({
    0:   "#FEFEFE",
    50:  "#FAFAFA",
    100: "#F2F2F3",
    200: "#E6E6E7",
    300: "#D4D4D6",
    400: "#929295",
    500: "#6F6F72",
    600: "#565659",
    700: "#404042",
    800: "#2D2D2F",
    900: "#1C1C1E",
})

# Blue scale, stops 50–900 in ascending order.
BLUE = OrderedDict({
    50:  "#F1FEFF",
    100: "#E0F1F4",
    200: "#BFDFE6",
    300: "#9FCEDB",
    400: "#6BAFC4",
    500: "#047897",
    600: "#025F78",
    700: "#054F6A",
    800: "#044055",
    900: "#03303F",
})

# Pink scale, stops 50–900 in ascending order.
PINK = OrderedDict({
    50:  "#FFF9FF",
    100: "#FAD0F5",
    200: "#EAC3E4",
    300: "#DC9ED3",
    400: "#CB62BB",
    500: "#A81696",
    600: "#8A1079",
    700: "#6C0B5E",
    800: "#55094B",
    900: "#3E0638",
})

# Orange scale, stops 50–900 in ascending order.
ORANGE = OrderedDict({
    50:  "#FFF9F5",
    100: "#FAEADE",
    200: "#EED4BF",
    300: "#DBB397",
    400: "#CC8F66",
    500: "#9A5E2A",
    600: "#74431A",
    700: "#563010",
    800: "#3D2009",
    900: "#281204",
})

# Green scale, stops 50–900 in ascending order.
GREEN = OrderedDict({
    50:  "#F2FEF8",
    100: "#E0F4EA",
    200: "#C0E5D2",
    300: "#98D4B6",
    400: "#3EA676",
    500: "#0A7B54",
    600: "#066141",
    700: "#044F34",
    800: "#033F29",
    900: "#022F1E",
})

# Red scale, stops 50–900 in ascending order.
RED = OrderedDict({
    50:  "#FFF5F5",
    100: "#FBE0E0",
    200: "#F0C4C4",
    300: "#E09C9C",
    400: "#CC6464",
    500: "#A11B1B",
    600: "#811212",
    700: "#660D0D",
    800: "#500A0A",
    900: "#3B0707",
})

# Yellow scale, stops 50–900 in ascending order.
YELLOW = OrderedDict({
    50:  "#FFFDF5",
    100: "#FBF2DA",
    200: "#F0E2B8",
    300: "#DBCA8A",
    400: "#C0A64E",
    500: "#866B0E",
    600: "#6A5308",
    700: "#534005",
    800: "#403103",
    900: "#2E2302",
})

# Purple scale, stops 50–900 in ascending order.
PURPLE = OrderedDict({
    50:  "#FAF7FF",
    100: "#F0E8FE",
    200: "#DFD0FC",
    300: "#C9B0F8",
    400: "#A882F2",
    500: "#8B5CF6",
    600: "#7240D9",
    700: "#5B2FB5",
    800: "#462493",
    900: "#331A6E",
})

# Lookup table: lowercase palette name -> its color scale.
PALETTES = dict(
    grey=GREY,
    blue=BLUE,
    pink=PINK,
    orange=ORANGE,
    green=GREEN,
    red=RED,
    yellow=YELLOW,
    purple=PURPLE,
)

# =========================================
#  Semantic Tokens
# =========================================

# Role-based aliases; every token resolves to a stop of the GREY scale.
SEMANTIC = dict(
    # Text
    text_primary=GREY[900],     # #1C1C1E
    text_secondary=GREY[500],   # #6F6F72
    text_tertiary=GREY[400],    # #929295
    # Borders
    border_strong=GREY[400],    # #929295
    border_default=GREY[300],   # #D4D4D6
    border_subtle=GREY[200],    # #E6E6E7
    # Backgrounds
    bg_canvas=GREY[0],          # #FEFEFE
    bg_surface_1=GREY[50],      # #FAFAFA
    bg_surface_2=GREY[200],     # #E6E6E7
)


# =========================================
#  Pre-built Color Cycles
# =========================================

def _stops(stop: int):
    """Collect the hex color at ``stop`` from every hue scale.

    The result follows the priority order pink, blue, orange, green,
    purple, red, yellow; the GREY scale is not included.

    Raises
    ------
    KeyError
        If any of the scales has no entry for ``stop``.
    """
    return [
        scale[stop]
        for scale in (PINK, BLUE, ORANGE, GREEN, PURPLE, RED, YELLOW)
    ]


# Each cycle below is the list returned by _stops(): one hex color per hue
# scale at the given stop, in priority order.

# Default cycle — 500 stops (anchor tones, good for lines and general use)
CYCLE_500 = _stops(500)

# Line cycle — 300 stops (per data-vis guide: line stroke = 300)
CYCLE_300 = _stops(300)

# Fill cycle — 100 stops (per data-vis guide: shape fill = 100)
CYCLE_100 = _stops(100)

# Stroke cycle — 300 stops (same as lines; used for shape borders)
# NOTE: this is an alias of CYCLE_300, not a copy. Both names refer to the
# same list object, so an in-place change to one is visible through the other.
CYCLE_STROKE = CYCLE_300


# =========================================
#  Fill + Stroke Pairs
# =========================================

def fill_stroke_pair(palette_name: str):
    """Look up the fill and stroke colors of one named palette.

    The design guide pairs a 100-stop fill with a 300-stop stroke.

    Parameters
    ----------
    palette_name : str
        A key of ``PALETTES``, e.g. ``"pink"``.

    Returns
    -------
    tuple of str
        ``(fill_hex, stroke_hex)``.

    Raises
    ------
    KeyError
        If ``palette_name`` is not a key of ``PALETTES``.
    """
    fill_stop, stroke_stop = 100, 300
    palette = PALETTES[palette_name]
    return palette[fill_stop], palette[stroke_stop]


def fill_stroke_pairs(n: int = 3):
    """Return the first ``n`` (fill, stroke) tuples in priority order.

    Palettes are taken in the order pink, blue, orange, green, purple,
    red, yellow, wrapping back to pink when ``n`` exceeds seven. An
    ``n`` of zero or less yields an empty list.

    >>> fills, strokes = zip(*fill_stroke_pairs(3))
    """
    priority = ("pink", "blue", "orange", "green", "purple", "red", "yellow")
    # Modulo indexing repeats the palette order once it is exhausted.
    names = (priority[k % len(priority)] for k in range(n))
    return [fill_stroke_pair(name) for name in names]


# =========================================
#  Sequential Colormaps
# =========================================

def _make_sequential_cmap(name: str, scale: OrderedDict):
    """Turn a color scale into a 256-level LinearSegmentedColormap.

    Parameters
    ----------
    name : str
        Name given to the colormap object.
    scale : OrderedDict
        Mapping of stop number to hex color. Stops are used in ascending
        numeric order, whatever the mapping's own ordering.

    Returns
    -------
    matplotlib.colors.LinearSegmentedColormap
    """
    ordered_stops = sorted(scale)
    ramp = [scale[key] for key in ordered_stops]
    # A plain color list (no explicit positions) is handed to from_list.
    return mcolors.LinearSegmentedColormap.from_list(name, ramp, N=256)


def register_colormaps():
    """Register fractal_<name> sequential colormaps with matplotlib.

    For every entry of ``PALETTES`` two colormaps are registered:
    ``fractal_<name>`` and its reverse ``fractal_<name>_r``. Names that
    are already present in the registry are skipped, so the function can
    be called repeatedly.
    """
    for name, scale in PALETTES.items():
        cmap_name = f"fractal_{name}"
        cmap = _make_sequential_cmap(cmap_name, scale)
        variants = (
            (cmap_name, cmap),
            (f"{cmap_name}_r", cmap.reversed()),
        )
        for variant_name, variant in variants:
            # Each name gets its own try: an already-registered forward
            # map must not stop the reversed map from being registered.
            try:
                plt.colormaps.register(variant, name=variant_name)
            except ValueError:
                continue  # already registered under this name


# =========================================
#  Diverging Colormaps
# =========================================

def make_diverging_cmap(
    name: str = "fractal_div",
    low_scale: OrderedDict = BLUE,
    high_scale: OrderedDict = PINK,
):
    """Build a diverging colormap: low_scale (reversed) → white → high_scale.

    The map starts at the highest-numbered stop of ``low_scale``, passes
    through the neutral ``#FEFEFE`` and ends at the highest-numbered stop
    of ``high_scale``. The neutral color is pinned at exactly 0.5, so the
    midpoint stays centred even when the two scales have a different
    number of stops (``GREY``, for instance, has one more than the hue
    scales).

    Parameters
    ----------
    name : str
        Name of the colormap, also used as its registry key.
    low_scale, high_scale : OrderedDict
        Mappings of stop number to hex color for the low and high halves.

    Returns
    -------
    matplotlib.colors.LinearSegmentedColormap
        The newly built colormap. If ``name`` was already registered, the
        registry keeps its existing entry and the returned object is not
        the one a lookup by name will produce.
    """
    neutral = "#FEFEFE"
    low_colors = [low_scale[s] for s in sorted(low_scale, reverse=True)]
    high_colors = [high_scale[s] for s in sorted(high_scale)]

    if low_colors and high_colors:
        n_low, n_high = len(low_colors), len(high_colors)
        # Explicit (position, color) anchors: the low half is spread over
        # [0, 0.5), the neutral sits at 0.5, the high half over (0.5, 1].
        anchors = [(0.5 * i / n_low, c) for i, c in enumerate(low_colors)]
        anchors.append((0.5, neutral))
        anchors += [
            (0.5 + 0.5 * (i + 1) / n_high, c)
            for i, c in enumerate(high_colors)
        ]
    else:
        # One-sided ramp: there is no midpoint to pin, keep even spacing.
        anchors = low_colors + [neutral] + high_colors

    cmap = mcolors.LinearSegmentedColormap.from_list(name, anchors, N=256)
    try:
        plt.colormaps.register(cmap, name=name)
    except ValueError:
        pass  # name already registered; the existing entry is left as is
    return cmap


# =========================================
#  Qualitative Colormap (categorical)
# =========================================

def make_qualitative_cmap(name: str = "fractal_qual", stop: int = 500):
    """Build and register a categorical ListedColormap.

    The colormap holds the seven priority-ordered hue colors at ``stop``.

    Parameters
    ----------
    name : str
        Name of the colormap, also used as its registry key.
    stop : int
        Scale stop to sample from every hue scale.

    Returns
    -------
    matplotlib.colors.ListedColormap
        The colormap that was built, whether or not registration succeeded.
    """
    qualitative = mcolors.ListedColormap(_stops(stop), name=name)
    try:
        plt.colormaps.register(qualitative, name=name)
    except ValueError:
        # Registration was refused; the colormap is still returned.
        pass
    return qualitative


# =========================================
#  Style Application
# =========================================

def apply_fractal_style(style_path: str = None):
    """Activate the fractal .mplstyle and register every fractal colormap.

    After the style sheet is applied, the sequential colormaps, the
    default diverging colormap and the default qualitative colormap are
    registered, in that order.

    Parameters
    ----------
    style_path : str, optional
        Path to fractal.mplstyle. If None, the file is expected in the
        same directory as this module.
    """
    if style_path is not None:
        chosen_style = style_path
    else:
        chosen_style = str(Path(__file__).parent / "fractal.mplstyle")

    plt.style.use(chosen_style)

    for setup in (register_colormaps, make_diverging_cmap, make_qualitative_cmap):
        setup()


# =========================================
#  Convenience: bar chart with fill/stroke
# =========================================

def styled_bars(ax, x, heights, palette_names=None, **bar_kwargs):
    """Draw grouped bars using the design system's fill/stroke convention.

    The series are drawn side by side, with each group of bars centred
    on its ``x`` position.

    Parameters
    ----------
    ax : matplotlib Axes
    x : array-like
        Bar positions (converted to a float array).
    heights : list of array-like
        One array per series. An empty list draws nothing.
    palette_names : list of str, optional
        Palette names in order, reused cyclically when there are more
        series than names. Defaults to priority order; an empty list is
        treated like None.
    **bar_kwargs
        Extra kwargs forwarded to ax.bar(). ``width`` sets the width of a
        single bar (default ``0.8 / len(heights)``). ``color``,
        ``edgecolor`` and ``linewidth`` override the design-system
        defaults when given.

    Returns
    -------
    list of BarContainer
    """
    if palette_names is None or len(palette_names) == 0:
        palette_names = ["pink", "blue", "orange", "green", "purple", "red", "yellow"]

    n = len(heights)
    if n == 0:
        # Nothing to draw; returning here also keeps the default width
        # below from dividing by zero.
        return []

    x = np.asarray(x, dtype=float)
    width = bar_kwargs.pop("width", 0.8 / n)

    containers = []
    for i, h in enumerate(heights):
        name = palette_names[i % len(palette_names)]
        fill, stroke = fill_stroke_pair(name)
        # Shift series i so that the n bars of a group are centred on x.
        offset = (i - (n - 1) / 2) * width
        # Caller kwargs are unpacked last so they win over the defaults
        # instead of raising a duplicate-keyword TypeError.
        style = {"color": fill, "edgecolor": stroke, "linewidth": 1, **bar_kwargs}
        containers.append(ax.bar(x + offset, h, width=width, **style))
    return containers
