from __future__ import annotations

import argparse
import csv
import json
from dataclasses import dataclass
from pathlib import Path

from openpyxl import Workbook, load_workbook
from openpyxl.styles import PatternFill


# Traditional Abjad letter-value mapping.
# Membership in this table is what letters_only() treats as a "letter", so it
# determines M1. The values feed M3, M4, M5, and M6 directly, and M7
# indirectly because M7 is derived from the M5 digit string.
ABJAD_VALUE = {
    # Units (1-9)
    "ا": 1,
    "ب": 2,
    "ج": 3,
    "د": 4,
    "ه": 5,
    "و": 6,
    "ز": 7,
    "ح": 8,
    "ط": 9,
    # Tens (10-90)
    "ي": 10,
    "ك": 20,
    "ل": 30,
    "م": 40,
    "ن": 50,
    "س": 60,
    "ع": 70,
    "ف": 80,
    "ص": 90,
    # Hundreds (100-900) and one thousand
    "ق": 100,
    "ر": 200,
    "ش": 300,
    "ت": 400,
    "ث": 500,
    "خ": 600,
    "ذ": 700,
    "ض": 800,
    "ظ": 900,
    "غ": 1000,
}


# Translation table that folds common Arabic variant forms onto the canonical
# letters that appear as keys in ABJAD_VALUE: the alef variants become bare
# alef, hamza-carrying waw/ya become waw/ya, alef maqsura becomes ya, and
# ta marbuta becomes ha.
# Characters that are neither listed here nor present in ABJAD_VALUE are left
# untouched by this table; letters_only() is what later filters them out.
NORMALIZE_MAP = str.maketrans(
    {
        "أ": "ا",
        "إ": "ا",
        "آ": "ا",
        "ٱ": "ا",
        "ؤ": "و",
        "ئ": "ي",
        "ى": "ي",
        "ة": "ه",
    }
)


@dataclass
class MetricRow:
    """One verse's source fields together with its seven computed metrics.

    Every field is stored as a string. The concatenation metrics (M4-M7) are
    digit strings of arbitrary length, and the numeric metrics (M1-M3) are
    stored as the decimal text of their integer value.
    """

    sura: str  # chapter identifier as read from the input
    aya: str  # verse identifier as read from the input
    text: str  # verse text exactly as read from the input (not normalized)
    m1_letter_count: str  # number of characters that are ABJAD_VALUE letters
    m2_word_count: str  # number of whitespace-separated tokens
    m3_abjad_sum: str  # sum of the letter values
    m4_concat_letters: str  # letter values concatenated in reading order
    m5_concat_words: str  # per-word value sums concatenated in reading order
    m6_reverse_concat_letters: str  # letter values concatenated in reverse letter order
    m7_reverse_concat_words: str  # the M5 digit string reversed character by character


def normalize_text(text: str) -> str:
    """Return *text* with variant letter forms folded through ``NORMALIZE_MAP``.

    Characters without an entry in the table are passed through unchanged.
    """
    canonical = text.translate(NORMALIZE_MAP)
    return canonical


def letters_only(text: str) -> list[str]:
    """Return the normalized characters of *text* that have an Abjad value.

    The text is normalized first; every character that is not a key of
    ``ABJAD_VALUE`` (spaces, marks, punctuation, ...) is discarded. Order is
    preserved.
    """
    kept: list[str] = []
    for char in normalize_text(text):
        if char in ABJAD_VALUE:
            kept.append(char)
    return kept


def concat_letter_values(letters: list[str]) -> str:
    """Concatenate the decimal Abjad values of *letters* into one digit string.

    Each letter must be a key of ``ABJAD_VALUE``; an unknown letter raises
    ``KeyError``. An empty list yields an empty string.
    """
    digit_chunks = [str(ABJAD_VALUE[letter]) for letter in letters]
    return "".join(digit_chunks)


def is_multiple_of_19_numeric_string(value: str) -> bool:
    """Return True when the decimal digit string *value* is divisible by 19.

    The remainder is accumulated digit by digit instead of calling
    ``int(value)``: the concatenation metrics can be tens of thousands of
    digits long, and recent CPython versions refuse to convert digit strings
    longer than a configurable limit (4300 digits by default).

    Args:
        value: Text expected to consist solely of the ASCII digits 0-9.

    Returns:
        True if *value* is a non-empty ASCII digit string whose numeric value
        is a multiple of 19 (``"0"`` counts). False for the empty string and
        for any input containing a character other than 0-9, including
        non-ASCII digits that ``str.isdigit()`` would accept.
    """
    # An empty string is not a number, and non-ASCII "digits" would otherwise
    # be fed into the arithmetic below with meaningless values.
    if not value or not (value.isascii() and value.isdigit()):
        return False

    remainder = 0
    for digit in value:
        remainder = (remainder * 10 + int(digit)) % 19
    return remainder == 0


def metric_multiple_counts(rows: list[MetricRow]) -> dict[str, int]:
    """Count, per metric, how many rows hold a value divisible by 19.

    Args:
        rows: Computed metric rows.

    Returns:
        A dict keyed ``"M1"`` .. ``"M7"`` (in that order) whose values are the
        number of rows in which that metric is an all-digit string that is a
        multiple of 19. Values that are not all digits, such as an empty
        concatenation, are never counted.
    """
    metric_fields = (
        ("M1", "m1_letter_count"),
        ("M2", "m2_word_count"),
        ("M3", "m3_abjad_sum"),
        ("M4", "m4_concat_letters"),
        ("M5", "m5_concat_words"),
        ("M6", "m6_reverse_concat_letters"),
        ("M7", "m7_reverse_concat_words"),
    )
    tally = {label: 0 for label, _ in metric_fields}

    for record in rows:
        for label, attribute in metric_fields:
            digits = getattr(record, attribute)
            if not digits.isdigit():
                continue
            if is_multiple_of_19_numeric_string(digits):
                tally[label] += 1

    return tally


def write_metrics_summary_txt(
    build_summaries: list[dict[str, object]],
    output_txt: Path,
) -> None:
    """Write the per-file multiples-of-19 report as UTF-8 text.

    Args:
        build_summaries: One dict per processed file with the keys
            ``file_name``, ``verse_rows`` and ``multiple_counts``; the latter
            maps ``"M1"`` .. ``"M7"`` to a count.
        output_txt: Destination file; it is overwritten.

    The report is a fixed three-line header followed by one stanza per file.
    Stanzas are separated by a blank line, and the file ends with exactly one
    newline.
    """
    report = [
        "7 Metrics Summary (M1-M7)",
        "Generated from *_normalized.xlsx files in workspace root.",
        "",
    ]

    for entry in build_summaries:
        report.append(f"File: {entry['file_name']}")
        report.append(f"Verse rows: {entry['verse_rows']}")
        for index in range(1, 8):
            label = f"M{index}"
            report.append(
                f"{label} multiples-of-19 count: {entry['multiple_counts'][label]}"
            )
        report.append("")

    # Trailing blank lines are trimmed so the file ends with a single newline.
    content = "\n".join(report).rstrip() + "\n"
    output_txt.write_text(content, encoding="utf-8")


def compute_metrics(sura: str, aya: str, text: str) -> MetricRow:
    """Compute the seven metrics for one verse.

    Args:
        sura: Chapter identifier; stored as ``str(sura)``.
        aya: Verse identifier; stored as ``str(aya)``.
        text: Verse text. It is normalized for the calculations but stored
            unchanged in the returned row.

    Returns:
        A ``MetricRow`` in which
        M1 is the number of Abjad letters,
        M2 is the number of whitespace-separated tokens,
        M3 is the sum of the letter values,
        M4 is the letter values concatenated in reading order,
        M5 is the per-token value sums concatenated in reading order,
        M6 is the letter values concatenated in reverse letter order, and
        M7 is the M5 digit string reversed character by character.
    """
    tokens = normalize_text(text).split()
    letters = letters_only(text)

    # A token without any Abjad letter still counts towards M2 and
    # contributes a "0" to M5.
    token_totals = [
        sum(ABJAD_VALUE[ch] for ch in token if ch in ABJAD_VALUE)
        for token in tokens
    ]
    words_concat = "".join(map(str, token_totals))

    # NOTE(review): M6 reverses the order of the letters and keeps each
    # letter's value intact, whereas M7 mirrors the digits of M5 (so a
    # multi-digit word sum is itself reversed). Confirm the asymmetry is
    # intended.
    return MetricRow(
        sura=str(sura),
        aya=str(aya),
        text=text,
        m1_letter_count=str(len(letters)),
        m2_word_count=str(len(tokens)),
        m3_abjad_sum=str(sum(ABJAD_VALUE[ch] for ch in letters)),
        m4_concat_letters=concat_letter_values(letters),
        m5_concat_words=words_concat,
        m6_reverse_concat_letters=concat_letter_values(letters[::-1]),
        m7_reverse_concat_words=words_concat[::-1],
    )


def parse_input(input_txt: Path) -> list[tuple[str, str, str]]:
    """Read ``(sura, aya, text)`` triples from a source file.

    A path ending in ``.xlsx`` (case-insensitive) is delegated to
    ``parse_input_xlsx``. Any other path is read as UTF-8 text with one verse
    per line in the form ``<sura>:<aya><TAB><text>``. Blank lines are skipped.

    Args:
        input_txt: Path of the source file.

    Returns:
        The parsed triples in file order. ``sura`` and ``aya`` are stripped of
        surrounding whitespace; ``text`` is everything after the first tab,
        unmodified.

    Raises:
        ValueError: If a non-blank line has no tab, or its prefix has no colon.
    """
    if input_txt.suffix.lower() == ".xlsx":
        return parse_input_xlsx(input_txt)

    rows: list[tuple[str, str, str]] = []
    # "utf-8-sig" drops a leading byte-order mark if one is present and reads
    # BOM-less files unchanged. With plain "utf-8" the BOM would become part of
    # the first sura id, which then cannot be converted to an integer later.
    content = input_txt.read_text(encoding="utf-8-sig")
    for line in content.splitlines():
        if not line.strip():
            continue
        if "\t" not in line:
            raise ValueError(f"Expected tab-separated line, got: {line}")
        left, text = line.split("\t", 1)
        if ":" not in left:
            raise ValueError(f"Expected Sura:Aya prefix, got: {left}")
        sura, aya = left.split(":", 1)
        rows.append((sura.strip(), aya.strip(), text))
    return rows


def parse_input_xlsx(input_xlsx: Path) -> list[tuple[str, str, str]]:
    """Read ``(sura, aya, text)`` triples from the active sheet of a workbook.

    The first row must be a header containing the columns ``sura``, ``verse``
    and ``text`` (matched case-insensitively after stripping whitespace). Data
    rows whose sura or verse cell is empty are skipped.

    Args:
        input_xlsx: Path of the workbook to read.

    Returns:
        The parsed triples in sheet order. An empty text cell yields ``""``.

    Raises:
        ValueError: If the header row is missing or lacks a required column.
    """

    def cell_to_id(value: object) -> str:
        # An empty cell must become "" so the skip check below can see it;
        # str(None) would produce the truthy text "None".
        if value is None:
            return ""
        # A whole number stored as a float (1.0) is rendered as "1" so the id
        # can still be converted with int() by later stages.
        if isinstance(value, float) and value.is_integer():
            return str(int(value))
        return str(value).strip()

    wb = load_workbook(input_xlsx, read_only=True, data_only=True)
    try:
        ws = wb.active
        header_row = next(ws.iter_rows(min_row=1, max_row=1, values_only=True), None)
        if not header_row:
            raise ValueError(f"Missing header row in: {input_xlsx}")

        headers = [str(cell).strip().lower() if cell is not None else "" for cell in header_row]
        try:
            sura_idx = headers.index("sura")
            aya_idx = headers.index("verse")
            text_idx = headers.index("text")
        except ValueError as exc:
            raise ValueError(
                f"Expected xlsx headers including sura, verse, text in: {input_xlsx}"
            ) from exc

        rows: list[tuple[str, str, str]] = []
        for row in ws.iter_rows(min_row=2, values_only=True):
            # Rows can be shorter than the header, hence the length checks.
            sura = cell_to_id(row[sura_idx] if sura_idx < len(row) else None)
            aya = cell_to_id(row[aya_idx] if aya_idx < len(row) else None)
            if not sura or not aya:
                continue

            text_val = row[text_idx] if text_idx < len(row) else None
            text = "" if text_val is None else str(text_val)
            rows.append((sura, aya, text))
        return rows
    finally:
        # Read-only workbooks keep the file handle open until closed.
        wb.close()


def build_summary(rows: list[MetricRow]) -> list[dict[str, object]]:
    """Aggregate verse metrics into one summary dict per chapter.

    Args:
        rows: Computed metric rows; ``sura`` must be convertible with ``int``.

    Returns:
        A list that starts with a fixed placeholder entry for chapter 0,
        followed by one entry per chapter in ascending chapter number. Each
        chapter entry holds the summed M1-M3 values, the chapter-wide
        M4-M7 concatenations, and a divisible-by-19 flag for each of the
        three sums.
    """
    grouped: dict[int, list[MetricRow]] = {}
    for metric_row in rows:
        chapter = int(metric_row.sura)
        if chapter not in grouped:
            grouped[chapter] = []
        grouped[chapter].append(metric_row)

    # Fixed first entry; its values are constants, not derived from the rows.
    placeholder: dict[str, object] = {
        "chapter": 0,
        "Total_M1_Letters": 0,
        "Total_M2_Words": 0,
        "Total_M3_Abjad": 0,
        "Total_M4_ConcatLetters": "",
        "Total_M5_ConcatWords": "00",
        "Total_M6_ReverseConcatLetters": "",
        "Total_M7_ReverseConcatWords": "00",
        "M1_SuraIsMultipleOf19": False,
        "M2_SuraIsMultipleOf19": False,
        "M3_SuraIsMultipleOf19": False,
    }
    result: list[dict[str, object]] = [placeholder]

    for chapter in sorted(grouped):
        letter_total = 0
        word_total = 0
        abjad_total = 0
        m4_parts: list[str] = []
        m5_parts: list[str] = []
        m6_parts: list[str] = []
        for verse in grouped[chapter]:
            letter_total += int(verse.m1_letter_count)
            word_total += int(verse.m2_word_count)
            abjad_total += int(verse.m3_abjad_sum)
            m4_parts.append(verse.m4_concat_letters)
            m5_parts.append(verse.m5_concat_words)
            m6_parts.append(verse.m6_reverse_concat_letters)

        words_concat = "".join(m5_parts)
        # Each verse's M6 is already letter-reversed; joining the verses
        # last-to-first gives the reversal of the whole chapter.
        reversed_letters_concat = "".join(reversed(m6_parts))

        result.append(
            {
                "chapter": chapter,
                "Total_M1_Letters": letter_total,
                "Total_M2_Words": word_total,
                "Total_M3_Abjad": abjad_total,
                "Total_M4_ConcatLetters": "".join(m4_parts),
                "Total_M5_ConcatWords": words_concat,
                "Total_M6_ReverseConcatLetters": reversed_letters_concat,
                "Total_M7_ReverseConcatWords": words_concat[::-1],
                "M1_SuraIsMultipleOf19": letter_total % 19 == 0,
                "M2_SuraIsMultipleOf19": word_total % 19 == 0,
                "M3_SuraIsMultipleOf19": abjad_total % 19 == 0,
            }
        )

    return result


def write_json(rows: list[MetricRow], output_json: Path) -> None:
    """Serialize *rows* as a JSON array of objects and write it as UTF-8.

    Args:
        rows: Computed metric rows.
        output_json: Destination file; it is overwritten.

    Each object carries the ten columns in a fixed order. Non-ASCII text is
    written literally rather than escaped, and the output is indented by two
    spaces.
    """
    columns = (
        ("Sura", "sura"),
        ("Aya", "aya"),
        ("Text", "text"),
        ("M1_LetterCount", "m1_letter_count"),
        ("M2_WordCount", "m2_word_count"),
        ("M3_AbjadSum", "m3_abjad_sum"),
        ("M4_ConcatLetters", "m4_concat_letters"),
        ("M5_ConcatWords", "m5_concat_words"),
        ("M6_ReverseConcatLetters", "m6_reverse_concat_letters"),
        ("M7_ReverseConcatWords", "m7_reverse_concat_words"),
    )
    records = []
    for record in rows:
        records.append({key: getattr(record, attribute) for key, attribute in columns})

    serialized = json.dumps(records, ensure_ascii=False, indent=2)
    output_json.write_text(serialized, encoding="utf-8")


def write_txt(rows: list[MetricRow], output_txt: Path) -> None:
    """Write *rows* as a tab-separated UTF-8 table with a header line.

    Args:
        rows: Computed metric rows.
        output_txt: Destination file; it is overwritten.

    Fields are joined with tabs as-is, with no quoting or escaping, and every
    line (including the last) ends with a newline.
    """
    header = (
        "Sura",
        "Aya",
        "Text",
        "M1_LetterCount",
        "M2_WordCount",
        "M3_AbjadSum",
        "M4_ConcatLetters",
        "M5_ConcatWords",
        "M6_ReverseConcatLetters",
        "M7_ReverseConcatWords",
    )
    attribute_order = (
        "sura",
        "aya",
        "text",
        "m1_letter_count",
        "m2_word_count",
        "m3_abjad_sum",
        "m4_concat_letters",
        "m5_concat_words",
        "m6_reverse_concat_letters",
        "m7_reverse_concat_words",
    )

    table = ["\t".join(header)]
    for record in rows:
        fields = [getattr(record, name) for name in attribute_order]
        table.append("\t".join(fields))

    output_txt.write_text("\n".join(table) + "\n", encoding="utf-8")


def write_csv(rows: list[MetricRow], output_csv: Path) -> None:
    """Write *rows* as a UTF-8 CSV file with a header line.

    Args:
        rows: Computed metric rows.
        output_csv: Destination file; it is overwritten.

    The file is opened with ``newline=""`` so that the ``csv`` writer alone
    controls line endings.
    """
    header = (
        "Sura",
        "Aya",
        "Text",
        "M1_LetterCount",
        "M2_WordCount",
        "M3_AbjadSum",
        "M4_ConcatLetters",
        "M5_ConcatWords",
        "M6_ReverseConcatLetters",
        "M7_ReverseConcatWords",
    )
    attribute_order = (
        "sura",
        "aya",
        "text",
        "m1_letter_count",
        "m2_word_count",
        "m3_abjad_sum",
        "m4_concat_letters",
        "m5_concat_words",
        "m6_reverse_concat_letters",
        "m7_reverse_concat_words",
    )

    with output_csv.open("w", encoding="utf-8", newline="") as handle:
        sheet = csv.writer(handle)
        sheet.writerow(header)
        sheet.writerows(
            [getattr(record, name) for name in attribute_order] for record in rows
        )


def discover_inputs(root_dir: Path) -> list[Path]:
    """Find the source files to process directly inside *root_dir*.

    Workbooks matching ``*_normalized.xlsx`` take precedence. Text files
    matching ``*Normalized.txt`` are returned only when no such workbook
    exists. The result is sorted, and empty when neither pattern matches.
    """
    for pattern in ("*_normalized.xlsx", "*Normalized.txt"):
        matches = sorted(root_dir.glob(pattern))
        if matches:
            return matches
    return []


def output_base_name(input_path: Path) -> str:
    """Return the file stem of *input_path* without its "normalized" marker.

    The first of ``_normalized``, ``_Normalized`` or `` Normalized`` that the
    stem ends with is removed, once. A stem ending with none of them is
    returned unchanged.
    """
    stem = input_path.stem
    marker = next(
        (
            candidate
            for candidate in ("_normalized", "_Normalized", " Normalized")
            if stem.endswith(candidate)
        ),
        None,
    )
    if marker is None:
        return stem
    return stem[: -len(marker)]


def write_xlsx(rows: list[MetricRow], summary_rows: list[dict[str, object]], output_xlsx: Path) -> None:
    """Write the per-verse metrics and the per-chapter summary to a workbook.

    The workbook has two sheets. ``Metrics`` holds one row per verse with
    M1-M7 in columns 4-10. ``Summary`` holds one row per summary entry with
    the chapter totals in columns 2-8, followed by three boolean flag columns.
    In both sheets a metric cell is filled yellow when its value is a
    multiple of 19.

    Args:
        rows: Computed metric rows; ``sura``, ``aya`` and M1-M3 must be
            convertible with ``int``.
        summary_rows: Summary dicts as produced by ``build_summary``.
        output_xlsx: Destination file; it is overwritten.
    """
    highlight_fill = PatternFill(fill_type="solid", fgColor="FFFFFF00")

    def is_multiple_of_19(value: object) -> bool:
        # bool is a subclass of int and False % 19 == 0, so flag values must
        # be rejected before the integer test.
        if isinstance(value, bool):
            return False
        if isinstance(value, int):
            return value % 19 == 0
        if isinstance(value, str):
            candidate = value.strip()
            return candidate.isdigit() and is_multiple_of_19_numeric_string(candidate)
        return False

    def append_highlighted(worksheet, row_number: int, values: list, first_col: int, last_col: int) -> None:
        # Append one data row and fill its metric cells that are multiples of 19.
        # The decision uses the value being written, not the value read back
        # from the cell, so it does not depend on how the workbook layer
        # stores it (Excel caps cell text at 32,767 characters, which a
        # whole-chapter concatenation can exceed).
        worksheet.append(values)
        for column in range(first_col, last_col + 1):
            if is_multiple_of_19(values[column - 1]):
                worksheet.cell(row=row_number, column=column).fill = highlight_fill

    wb = Workbook()
    ws_metrics = wb.active
    ws_metrics.title = "Metrics"
    ws_metrics.append(
        [
            "sura",
            "verse",
            "text",
            "M1_LetterCount",
            "M2_WordCount",
            "M3_AbjadSum",
            "M4_ConcatLetters",
            "M5_ConcatWords",
            "M6_ReverseConcatLetters",
            "M7_ReverseConcatWords",
        ]
    )
    # The header occupies row 1, so data rows start at row 2.
    for row_number, r in enumerate(rows, start=2):
        append_highlighted(
            ws_metrics,
            row_number,
            [
                int(r.sura),
                int(r.aya),
                r.text,
                int(r.m1_letter_count),
                int(r.m2_word_count),
                int(r.m3_abjad_sum),
                r.m4_concat_letters,
                r.m5_concat_words,
                r.m6_reverse_concat_letters,
                r.m7_reverse_concat_words,
            ],
            first_col=4,
            last_col=10,
        )

    ws_summary = wb.create_sheet("Summary")
    ws_summary.append(
        [
            "chapter",
            "Total_M1_Letters",
            "Total_M2_Words",
            "Total_M3_Abjad",
            "Total_M4_ConcatLetters",
            "Total_M5_ConcatWords",
            "Total_M6_ReverseConcatLetters",
            "Total_M7_ReverseConcatWords",
            "M1_SuraIsMultipleOf19",
            "M2_SuraIsMultipleOf19",
            "M3_SuraIsMultipleOf19",
        ]
    )
    # In this sheet the totals sit in columns 2-8; columns 9-11 are boolean
    # flags and are deliberately left out of the highlighted range.
    for row_number, row in enumerate(summary_rows, start=2):
        append_highlighted(
            ws_summary,
            row_number,
            [
                row["chapter"],
                row["Total_M1_Letters"],
                row["Total_M2_Words"],
                row["Total_M3_Abjad"],
                row["Total_M4_ConcatLetters"],
                row["Total_M5_ConcatWords"],
                row["Total_M6_ReverseConcatLetters"],
                row["Total_M7_ReverseConcatWords"],
                row["M1_SuraIsMultipleOf19"],
                row["M2_SuraIsMultipleOf19"],
                row["M3_SuraIsMultipleOf19"],
            ],
            first_col=2,
            last_col=8,
        )

    wb.save(output_xlsx)


def main() -> None:
    """Command-line entry point: build all metric outputs for the input files.

    With ``--input`` a single file is processed; otherwise the directory that
    contains this script is searched via ``discover_inputs``. For every input
    a JSON, TXT, CSV and XLSX file is written below ``--output-dir``, and one
    aggregate summary text file is written at the end.

    Raises:
        ValueError: If no input file is given and none is discovered.
    """
    parser = argparse.ArgumentParser(description="Build 7 Quranic metrics from normalized source files.")
    parser.add_argument(
        "--input",
        default=None,
        help="Optional path to one source file (.xlsx or .txt). If omitted, all normalized files in root are processed.",
    )
    parser.add_argument(
        "--output-dir",
        default="organized_outputs",
        help="Output root directory containing json/txt/xlsx folders.",
    )
    args = parser.parse_args()

    root_dir = Path(__file__).resolve().parent
    output_dir = Path(args.output_dir)

    source_files = [Path(args.input)] if args.input else discover_inputs(root_dir)
    if not source_files:
        raise ValueError("No normalized input files were found in the project root.")

    # One sub-directory per output format, created up front.
    format_dirs = {kind: output_dir / kind for kind in ("json", "txt", "csv", "xlsx")}
    for directory in format_dirs.values():
        directory.mkdir(parents=True, exist_ok=True)

    per_file_summaries: list[dict[str, object]] = []

    for source in source_files:
        metric_rows = [
            compute_metrics(sura, aya, text) for sura, aya, text in parse_input(source)
        ]
        chapter_summary = build_summary(metric_rows)

        label = f"{output_base_name(source)}_7Metrics"
        json_out = format_dirs["json"] / f"{label}.json"
        txt_out = format_dirs["txt"] / f"{label}.txt"
        csv_out = format_dirs["csv"] / f"{label}.csv"
        xlsx_out = format_dirs["xlsx"] / f"{label}.xlsx"

        write_json(metric_rows, json_out)
        write_txt(metric_rows, txt_out)
        write_csv(metric_rows, csv_out)
        write_xlsx(metric_rows, chapter_summary, xlsx_out)

        per_file_summaries.append(
            {
                "file_name": label,
                "verse_rows": len(metric_rows),
                "multiple_counts": metric_multiple_counts(metric_rows),
            }
        )

        print(f"Wrote {len(metric_rows)} rows from {source.name} to {json_out}")
        print(f"Wrote text table to {txt_out}")
        print(f"Wrote csv table to {csv_out}")
        print(f"Wrote summary + metrics workbook to {xlsx_out}")

    summary_txt_out = format_dirs["txt"] / "7Metrics_summary.txt"
    write_metrics_summary_txt(per_file_summaries, summary_txt_out)
    print(f"Wrote aggregate summary to {summary_txt_out}")


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
