from __future__ import annotations

import json
from pathlib import Path
from typing import Dict, List, Optional, Tuple

from openpyxl import load_workbook

from _pipeline_utils import _generated_timestamp, m4_m6_abjad_digits

# Lift CPython's cap on int<->str conversion length (0 means "unlimited") so
# that very long digit strings can pass through int()/str(). This is
# best-effort: interpreters without the hook are left alone, and any failure
# while configuring it is deliberately ignored.
try:
    import sys

    if getattr(sys, "set_int_max_str_digits", None) is not None:
        sys.set_int_max_str_digits(0)
except Exception:
    pass

# Divisor that every measurement is tested against.
MODULUS = 19

# Canonical metric order: worksheets may place the metric columns anywhere,
# but results are always reported as M1..M7.
METRIC_ORDER = [f"M{n}" for n in range(1, 8)]

# Every "<reader>(<place>)" combination plus the two standalone editions.
# Sorting case-insensitively yields the canonical dataset order: "Kufa"
# follows "himsi", "Medina I" precedes "Medina II", and "Submission" /
# "The_Criterion" fall between the Soosi and Warsh groups. Datasets are
# collected and numbered in exactly this order.
EXPECTED_MUSHAFS = sorted(
    [
        f"{reader}({place})"
        for reader in (
            "Bazzi", "Doori", "Hafs", "Qaloon", "Qumball", "Shouba", "Soosi", "Warsh",
        )
        for place in (
            "Basra", "damascus", "himsi", "Kufa", "Mecca", "Medina I", "Medina II", "VERSE 0",
        )
    ]
    + ["Submission", "The_Criterion"],
    key=str.lower,
)


def _to_digits(value: object) -> Optional[str]:
    """Return a digit-only string for numeric-like cell values, else None."""
    if value is None or isinstance(value, bool):
        return None

    if isinstance(value, int):
        return str(value)

    if isinstance(value, float):
        if value.is_integer():
            return str(int(value))
        return None

    if isinstance(value, str):
        text = value.strip()
        if text.isdigit():
            return text

    return None


def _digits_mod(digits: str, modulus: int = MODULUS) -> int:
    """Return ``int(digits) % modulus`` without building the full integer.

    The remainder is folded in one digit at a time, so memory stays constant
    however long the concatenated number is. An empty string yields ``0``.

    Raises:
        ValueError: if ``digits`` contains anything other than ASCII ``0``-``9``.
    """
    # The per-character ``ord`` arithmetic below is only valid for ASCII
    # digits; a "-" or a non-ASCII decimal would silently produce a wrong
    # remainder, so reject such input up front (C-speed checks).
    if digits and not (digits.isascii() and digits.isdigit()):
        raise ValueError(
            f"Expected only ASCII digits 0-9 (got {len(digits)} chars starting {digits[:20]!r})"
        )
    rem = 0
    for ch in digits:
        rem = (rem * 10 + (ord(ch) - ord("0"))) % modulus
    return rem


def _short_metric_name(metric_name: str) -> str:
    return metric_name.split("_", 1)[0]


def _mushaf_name_from_file(file_path: Path) -> str:
    return file_path.stem.replace("_7Metrics", "")


def _collect_excel_files(input_dir: Path) -> List[Path]:
    """Return one ``*_7Metrics.xlsx`` workbook per expected mushaf.

    Only *input_dir* itself is searched (no recursion). The returned paths are
    ordered like ``EXPECTED_MUSHAFS``.

    Raises:
        FileNotFoundError: if no file in *input_dir* matches the pattern.
        ValueError: if the mushaf names found differ from ``EXPECTED_MUSHAFS``;
            the message lists the missing names and the unexpected extras.
    """
    candidates = list(input_dir.glob("*_7Metrics.xlsx"))
    if not candidates:
        raise FileNotFoundError("No files matching '*_7Metrics.xlsx' were found.")

    # Later glob results win if two files map to the same mushaf name.
    path_for: Dict[str, Path] = {}
    for candidate in candidates:
        path_for[_mushaf_name_from_file(candidate)] = candidate

    missing = [name for name in EXPECTED_MUSHAFS if name not in path_for]
    extras = sorted(set(path_for) - set(EXPECTED_MUSHAFS))

    if not missing and not extras:
        return [path_for[name] for name in EXPECTED_MUSHAFS]

    raise ValueError(
        "Dataset set mismatch; "
        f"missing={missing or 'None'}, extras={extras or 'None'}"
    )


def _discover_metrics(headers: List[str]) -> List[Tuple[int, str, str]]:
    """Return ``(column_index, full_header, short_name)`` for each metric column.

    A column counts as a metric when its header is non-empty and is not one
    of the bookkeeping columns ``sura`` / ``verse`` / ``text``
    (case-insensitive). Columns are reported in left-to-right order.
    """
    bookkeeping = {"sura", "verse", "text"}
    metrics: List[Tuple[int, str, str]] = []
    for column, header in enumerate(headers):
        if not header or header.lower() in bookkeeping:
            continue
        metrics.append((column, header, _short_metric_name(header)))
    return metrics


def _resolve_worksheet(wb, file_path: Path):
    """Return the single worksheet of *wb* that looks like a 7-metric table.

    A sheet qualifies when its first row contains ``sura`` and ``verse``
    headers (case-insensitive, whitespace-stripped) and exactly seven metric
    columns whose short names are exactly M1..M7. Sheets without any rows
    are ignored.

    Raises:
        ValueError: unless exactly one sheet qualifies. *file_path* is used
            only in the error message.
    """
    required_metrics = set(METRIC_ORDER)
    matches = []
    for name in wb.sheetnames:
        sheet = wb[name]
        first_row = next(sheet.iter_rows(min_row=1, max_row=1, values_only=True), None)
        if first_row is None:
            continue

        header_texts = ["" if cell is None else str(cell).strip() for cell in first_row]
        lowered = {text.lower() for text in header_texts}
        if not {"sura", "verse"} <= lowered:
            continue

        metrics = _discover_metrics(header_texts)
        if len(metrics) != 7:
            continue
        if {short for _, _, short in metrics} != required_metrics:
            continue

        matches.append(sheet)

    if len(matches) == 1:
        return matches[0]
    raise ValueError(
        f"Expected exactly one valid worksheet in {file_path.name}; found {len(matches)}"
    )


def _pass_fail(mod_value: int) -> str:
    return "PASS" if mod_value == 0 else "FAIL"


def _is_blank_row(row: Optional[Tuple[object, ...]]) -> bool:
    """Return True when *row* is ``None`` or every cell in it is empty.

    A cell is empty when it is ``None`` or a whitespace-only string.
    Spreadsheets can contain formatted-but-empty rows that openpyxl may
    still yield.
    """
    if row is None:
        return True
    return all(
        cell is None or (isinstance(cell, str) and not cell.strip())
        for cell in row
    )


def analyze_dataset(file_path: Path) -> Dict[str, object]:
    """Compute measurements A-G for one ``*_7Metrics.xlsx`` workbook.

    For each metric M1..M7 (canonical order, whatever the physical column
    order):

    * A: sum of the metric over all rows (verse 0 rows included);
    * B / C: per-row values concatenated as decimal strings in ascending /
      descending (sura, verse) order;
    * D: sum of the 114 per-sura totals (checked to equal A);
    * E / F: the 114 per-sura totals concatenated in forward / reverse
      sura order.

    G concatenates the seven D values in M1..M7 order. Concatenations are
    kept as strings and reduced with ``_digits_mod``.

    M4 and M6 are not read from their columns; they are recomputed from the
    row's ``text`` cell via ``m4_m6_abjad_digits``. Completely empty rows are
    skipped. Every other row must have numeric sura/verse/metric cells.

    Returns:
        dict with ``mushaf``, ``metric_keys`` (M1..M7), ``metric_results``
        (per-metric A-F values, digit lengths and remainders) and the
        ``G_composite_number`` / ``G_composite_len`` / ``G_composite_mod``
        entries.

    Raises:
        ValueError: on a missing/ambiguous worksheet or column, a
            non-numeric or out-of-range cell, or a failed A == D invariant.
    """
    wb = load_workbook(file_path, read_only=True, data_only=True)
    row_tokens: List[Tuple[int, int, List[str]]] = []
    verse_sums: Dict[str, int] = {}
    sura_totals: Dict[str, Dict[int, int]] = {}
    metric_keys: List[str] = []
    try:
        ws = _resolve_worksheet(wb, file_path)

        header_cells = next(ws.iter_rows(min_row=1, max_row=1, values_only=True))
        headers = [str(h).strip() if h is not None else "" for h in header_cells]
        header_to_idx = {h.lower(): i for i, h in enumerate(headers)}

        if "sura" not in header_to_idx or "verse" not in header_to_idx:
            raise ValueError(f"{file_path.name} missing required columns: sura, verse")

        sura_idx = header_to_idx["sura"]
        verse_idx = header_to_idx["verse"]
        if "text" not in header_to_idx:
            raise ValueError(f"{file_path.name} missing required column: text")
        text_idx = header_to_idx["text"]

        discovered_metrics: List[Tuple[int, str, str]] = _discover_metrics(headers)

        if len(discovered_metrics) != 7:
            raise ValueError(f"Expected 7 metrics, found {len(discovered_metrics)} in {file_path.name}")

        metric_by_short = {short: (idx, full, short) for idx, full, short in discovered_metrics}
        missing = [m for m in METRIC_ORDER if m not in metric_by_short]
        extras = [m for m in metric_by_short if m not in METRIC_ORDER]
        if missing or extras:
            raise ValueError(
                f"Metric headers mismatch in {file_path.name}; missing={missing or 'None'}, extras={extras or 'None'}"
            )

        # Enforce canonical metric order M1..M7 regardless of physical header order.
        metric_columns: List[Tuple[int, str, str]] = [metric_by_short[m] for m in METRIC_ORDER]

        metric_keys = [short_name for _, _, short_name in metric_columns]

        # Verse sums for measurement A.
        verse_sums = {k: 0 for k in metric_keys}

        # Sura totals for measurements D/E/F; include verse 0 by including all rows as-is.
        sura_totals = {k: {s: 0 for s in range(1, 115)} for k in metric_keys}

        for row_number, row in enumerate(ws.iter_rows(min_row=2, values_only=True), start=2):
            # Skip rows with no content at all (e.g. trailing formatted rows).
            # A partially filled row still goes through strict validation below.
            if _is_blank_row(row):
                continue

            sura_digits = _to_digits(row[sura_idx] if sura_idx < len(row) else None)
            verse_digits = _to_digits(row[verse_idx] if verse_idx < len(row) else None)
            if sura_digits is None or verse_digits is None:
                raise ValueError(
                    f"Invalid sura/verse at {file_path.name} row {row_number}; both must be numeric"
                )

            sura = int(sura_digits)
            verse = int(verse_digits)
            if not 1 <= sura <= 114:
                raise ValueError(f"Invalid sura at {file_path.name} row {row_number}; got {sura}")

            row_metric_digits: List[str] = [""] * len(metric_columns)
            row_text = row[text_idx] if text_idx < len(row) else None
            # M4/M6 are derived from the verse text, not read from the sheet.
            m4_digits, m6_digits = m4_m6_abjad_digits(row_text)

            for metric_pos, (col_idx, _full_name, metric_key) in enumerate(metric_columns):
                if metric_key == "M4":
                    digits = m4_digits
                elif metric_key == "M6":
                    digits = m6_digits
                else:
                    if col_idx >= len(row):
                        raise ValueError(
                            f"Missing metric cell for {metric_key} at {file_path.name} row {row_number}"
                        )

                    digits = _to_digits(row[col_idx])
                    if digits is None:
                        raise ValueError(
                            f"Invalid metric value for {metric_key} at {file_path.name} row {row_number}; must be numeric"
                        )

                value = int(digits)
                row_metric_digits[metric_pos] = digits
                verse_sums[metric_key] += value

                sura_totals[metric_key][sura] += value

            row_tokens.append((sura, verse, row_metric_digits))
    finally:
        # Read-only workbooks keep the file handle open until closed.
        wb.close()

    row_tokens.sort(key=lambda x: (x[0], x[1]))

    metric_results: Dict[str, Dict[str, object]] = {}

    for metric_pos, metric_key in enumerate(metric_keys):
        # B and C: verse concatenations
        verse_forward_values = [token_row[metric_pos] for _, _, token_row in row_tokens]
        verse_reverse_values = list(reversed(verse_forward_values))

        verse_forward_number = "".join(verse_forward_values)
        verse_reverse_number = "".join(verse_reverse_values)
        verse_forward_len = len(verse_forward_number)
        verse_reverse_len = len(verse_reverse_number)
        verse_forward_mod = _digits_mod(verse_forward_number)
        verse_reverse_mod = _digits_mod(verse_reverse_number)

        # D/E/F: sura-derived measurements across 114 suras
        ordered_sura_totals = [sura_totals[metric_key][s] for s in range(1, 115)]

        global_sura_sum_value = sum(ordered_sura_totals)
        global_sura_sum_mod = global_sura_sum_value % MODULUS

        sura_forward_parts = [str(v) for v in ordered_sura_totals]
        sura_reverse_parts = list(reversed(sura_forward_parts))
        sura_forward_number = "".join(sura_forward_parts)
        sura_reverse_number = "".join(sura_reverse_parts)
        sura_forward_len = len(sura_forward_number)
        sura_reverse_len = len(sura_reverse_number)
        sura_forward_mod = _digits_mod(sura_forward_number)
        sura_reverse_mod = _digits_mod(sura_reverse_number)

        verse_sum_value = verse_sums[metric_key]
        verse_sum_mod = verse_sum_value % MODULUS
        # Sanity check: every row feeds both totals, so A must equal D.
        if verse_sum_value != global_sura_sum_value:
            raise ValueError(
                f"Invariant failed in {file_path.name} for {metric_key}: verse sum != sura sum"
            )

        metric_results[metric_key] = {
            "A_verse_sum_value": verse_sum_value,
            "A_verse_sum_mod": verse_sum_mod,
            "A_verse_sum_digit_length": len(str(abs(verse_sum_value))),
            "B_verse_forward_number": verse_forward_number,
            "B_verse_forward_len": verse_forward_len,
            "B_verse_forward_mod": verse_forward_mod,
            "C_verse_reverse_number": verse_reverse_number,
            "C_verse_reverse_len": verse_reverse_len,
            "C_verse_reverse_mod": verse_reverse_mod,
            "D_sura_sum_value": global_sura_sum_value,
            "D_sura_sum_mod": global_sura_sum_mod,
            "D_sura_sum_digit_length": len(str(abs(global_sura_sum_value))),
            "E_sura_forward_number": sura_forward_number,
            "E_sura_forward_len": sura_forward_len,
            "E_sura_forward_mod": sura_forward_mod,
            "F_sura_reverse_number": sura_reverse_number,
            "F_sura_reverse_len": sura_reverse_len,
            "F_sura_reverse_mod": sura_reverse_mod,
        }

    # G: final composite chain of seven global sura sums in M1..M7 order
    composite_sum_strings = [str(metric_results[k]["D_sura_sum_value"]) for k in metric_keys]
    composite_number = "".join(composite_sum_strings)
    composite_len = len(composite_number)
    composite_mod = _digits_mod(composite_number)

    return {
        "mushaf": _mushaf_name_from_file(file_path),
        "metric_keys": metric_keys,
        "metric_results": metric_results,
        "G_composite_number": composite_number,
        "G_composite_len": composite_len,
        "G_composite_mod": composite_mod,
    }


def build_summary(input_dir: Path, output_file: Path) -> None:
    """Analyze every expected workbook in *input_dir* and write three reports.

    Outputs (all UTF-8):
      * *output_file*: human-readable report with values / digit lengths,
        mod-19 remainders and PASS/FAIL for each measurement;
      * ``<stem>_bool.txt`` beside it: the same measurements as True/False;
      * ``json/<stem>.json`` under *output_file*'s grandparent directory:
        machine-readable payload whose ``number`` fields are strings.

    Each dataset contributes 7 metrics x 6 measurements (A-F) + G = 43
    checks. Every dataset is analyzed before any file is written.
    """
    datasets = _collect_excel_files(input_dir)
    stamp = _generated_timestamp()

    # (letter, label, result-key prefix, is_sum). Sums (A, D) are shown by
    # value and report "<prefix>_digit_length"; concatenations (B, C, E, F)
    # are shown by digit count "<prefix>_len" and keep "<prefix>_number".
    measurements = (
        ("A", "Verse Sum", "A_verse_sum", True),
        ("B", "Verse Forward Concat", "B_verse_forward", False),
        ("C", "Verse Reverse Concat", "C_verse_reverse", False),
        ("D", "Sura Sum", "D_sura_sum", True),
        ("E", "Sura Forward Concat", "E_sura_forward", False),
        ("F", "Sura Reverse Concat", "F_sura_reverse", False),
    )
    per_dataset_total = len(METRIC_ORDER) * len(measurements) + 1  # + G

    text_out: List[str] = [
        "GLOBAL LOCK SUMMARY",
        f"Generated: {stamp}",
        "Specification: Phase 3 Global Lock Engine",
        "",
    ]
    bool_out: List[str] = [
        "GLOBAL LOCK SUMMARY (BOOLEAN)",
        f"Generated: {stamp}",
        "Specification: Phase 3 Global Lock Engine",
        "",
    ]

    dataset_payloads: List[Dict[str, object]] = []
    payload: Dict[str, object] = {
        "schema_version": "1.2",
        "summary_type": "global_lock",
        "generated": stamp,
        "modulus": MODULUS,
        "metric_order": METRIC_ORDER,
        "total_measurements_per_dataset": per_dataset_total,
        "datasets": dataset_payloads,
    }

    last_index = len(datasets)
    for index, workbook in enumerate(datasets, start=1):
        result = analyze_dataset(workbook)
        name = result["mushaf"]
        keys = result["metric_keys"]
        results = result["metric_results"]
        g_len = result["G_composite_len"]
        g_mod = result["G_composite_mod"]
        g_ok = g_mod == 0

        passed = sum(
            1
            for key in keys
            for _, _, prefix, _ in measurements
            if results[key][f"{prefix}_mod"] == 0
        )
        if g_ok:
            passed += 1
        all_ok = passed == per_dataset_total

        text_out += [
            f"DATASET {index}: {name}",
            f"Total global measurements: {per_dataset_total}",
            "",
            "Global Verse + Global Sura Measurements (A-F per metric):",
        ]
        bool_out += [
            f"DATASET {index}: {name}",
            f"Total global measurements: {per_dataset_total}",
            f"Total passed: {passed}/{per_dataset_total}",
            f"All passed: {all_ok}",
            "",
            "Global Verse + Global Sura Measurements (A-F per metric):",
        ]

        metrics_json: Dict[str, object] = {}
        for key in keys:
            res = results[key]
            text_out.append(f"{key}:")
            bool_out.append(f"{key}:")
            entry: Dict[str, object] = {}
            for letter, label, prefix, is_sum in measurements:
                mod = res[f"{prefix}_mod"]
                if is_sum:
                    value = res[f"{prefix}_value"]
                    shown = f"value={value}"
                    number = str(value)
                    length = res[f"{prefix}_digit_length"]
                else:
                    length = res[f"{prefix}_len"]
                    shown = f"digits={length}"
                    number = res[f"{prefix}_number"]
                text_out.append(f"  {letter} {label} -> {shown}, mod19={mod}, {_pass_fail(mod)}")
                bool_out.append(f"  {letter} {label} -> {mod == 0}")
                entry[letter] = {
                    "number": number,
                    "digit_length": length,
                    "mod19": mod,
                    "is_multiple_of_19": mod == 0,
                }
            metrics_json[key] = entry

        text_out.append("Final Composite Global Lock (G):")
        text_out.append(f"  G Composite Sura-Sum Chain -> digits={g_len}, mod19={g_mod}, {_pass_fail(g_mod)}")
        bool_out.append("Final Composite Global Lock (G):")
        bool_out.append(f"  G Composite Sura-Sum Chain -> {g_ok}")

        dataset_payloads.append(
            {
                "dataset_index": index,
                "mushaf": name,
                "total_global_measurements": per_dataset_total,
                "pass_count": passed,
                "fail_count": per_dataset_total - passed,
                "all_measurements_passed": all_ok,
                "metrics": metrics_json,
                "G": {
                    "number": result["G_composite_number"],
                    "digit_length": g_len,
                    "mod19": g_mod,
                    "is_multiple_of_19": g_ok,
                },
            }
        )

        # Blank separator between datasets, but not after the last one.
        if index < last_index:
            text_out.append("")
            bool_out.append("")

    report_dir = output_file.parent
    report_dir.mkdir(parents=True, exist_ok=True)
    output_file.write_text("\n".join(text_out) + "\n", encoding="utf-8")
    (report_dir / f"{output_file.stem}_bool.txt").write_text("\n".join(bool_out) + "\n", encoding="utf-8")
    json_dir = report_dir.parent / "json"
    json_dir.mkdir(parents=True, exist_ok=True)
    with (json_dir / f"{output_file.stem}.json").open("w", encoding="utf-8") as handle:
        json.dump(payload, handle, ensure_ascii=True, indent=2)


def main() -> None:
    """Write the global lock summary for the workbooks beside this script.

    Reads ``*_7Metrics.xlsx`` from the script's own directory. Writes the
    text report to ``summary/txt/global_lock_summary.txt`` under that
    directory; ``build_summary`` places the companion files.
    """
    script_dir = Path(__file__).resolve().parent
    report_path = script_dir.joinpath("summary", "txt", "global_lock_summary.txt")
    build_summary(script_dir, report_path)
    print(f"Summary written to: {report_path}")


if __name__ == "__main__":
    main()
