"""
ACSNSQIPUtil.py
===============
Reusable utilities for the ACS NSQIP benign lung resection study.

Extracted from 03_benign_resection_nsqip.py.  Import this module in the
main analysis script (or any downstream notebook) rather than duplicating
constants and helpers.

Contents
--------
CONFIGURATION
    START_YEAR, END_YEAR, NSQIP_DIR, OUT_ROOT

CPT CODE DEFINITIONS
    RESECTION_META, ALL_CPT

ICD CODE PREFIX SETS
    BENIGN_PFX, CANCER_PFX, STRUCTURAL_PFX

NSQIP FIELD CANDIDATE LISTS
    *_CANDIDATES constants for every NSQIP concept
    NSQIP_COMPLICATION_FIELDS, NSQIP_THORACIC_FIELDS
    BASE_COVARIATES_NSQIP, CONTINUOUS_COVARIATES_NSQIP, BINARY_COVARIATES_NSQIP

COLUMN / DATA HELPERS
    match_pfx()             prefix-match ICD codes
    detect_col()            find the first present candidate column
    safe_col()              safe column accessor with default
    year_from_path()        infer 4-digit year from a PUF filename
    read_nsqip_file()       read a single PUF (.txt/.csv/.sas7bdat)
    normalise_yesno()       NSQIP Yes/No → 0/1
    normalise_asa()         ASA string → numeric 1–5
    bmi_from_height_weight()  compute BMI from inches/pounds

ICD CLASSIFICATION
    classify_dx_primary()   classify a single normalised ICD string
    add_classification()    vectorised classification for a DataFrame

DEMOGRAPHICS / COHORT PREP
    add_age_group()
    add_sex_label()
    add_bmi()
    add_asa()
    add_comorbidities()
    add_svc_year()

30-DAY OUTCOMES
    add_outcomes()          attach all NSQIP outcome flags to cohort

ANALYSIS HELPERS
    rate_row()              single-group benign-rate dict
    rate_table()            stratified benign-rate DataFrame
    outcome_table()         stratified 30-day outcome DataFrame
    run_trend_test()        Cochran-Armitage trend test → file
    prepare_regression_df() encode dummy variables for logistic models
    impute_and_prepare()    impute missing covariates, drop incomplete rows
    run_logistic_regression()  fit is_benign ~ covariates, write odds ratios
    run_risk_adjusted_rates()  year-standardised benign rates + trend test
"""

from __future__ import annotations

import os
import re
import glob
import warnings
from typing import Optional

import numpy as np
import pandas as pd
from scipy import stats

# Suppress pandas' mixed-type column DtypeWarning.  This changes the
# process-wide warnings filter at import time, so it also affects any code
# that merely imports this module.
warnings.filterwarnings("ignore", category=pd.errors.DtypeWarning)

# ---------------------------------------------------------------------------
# Optional dependencies
# ---------------------------------------------------------------------------
# Each try/except records availability in a HAS_* flag and prints an install
# hint instead of failing the import.  HAS_SM gates the statsmodels-based
# regression helpers; HAS_PYREADSTAT gates .sas7bdat reading.

try:
    import statsmodels.formula.api as smf
    import statsmodels.api as sm
    HAS_SM = True
except ImportError:
    HAS_SM = False
    print("[WARN] statsmodels not installed — logistic regression will be skipped.")
    print("       Install with: pip install statsmodels")

# NOTE(review): scipy is already imported unconditionally at the top of this
# module (`from scipy import stats`), so a missing scipy fails the module
# import before this guard runs -- HAS_SCIPY can only ever be True here.
# Either make the top-level import optional as well or drop this guard.
try:
    import scipy
    HAS_SCIPY = True
except ImportError:
    HAS_SCIPY = False
    print("[WARN] scipy not installed — some statistical functions will be skipped.")
    print("       Install with: pip install scipy")

try:
    import pyreadstat
    HAS_PYREADSTAT = True
except ImportError:
    HAS_PYREADSTAT = False
    print("[INFO] pyreadstat not installed — SAS7BDAT files will not be readable.")
    print("       Install with: pip install pyreadstat")


# =============================================================================
# CONFIGURATION
# =============================================================================

# Input / output locations.  Relative paths resolve against the current
# working directory of the running script, not against this module's file.
# NOTE(review): presumably NSQIP_DIR is the folder scanned for PUF files and
# OUT_ROOT the root for all study outputs -- confirm in the calling script.
NSQIP_DIR  = "./Extracted/"
OUT_ROOT   = "./study_output_nsqip"
# Study window bounds (presumably inclusive operation years -- confirm how
# the caller applies them).
START_YEAR = 2010
END_YEAR   = 2024


# =============================================================================
# CPT CODE DEFINITIONS
# =============================================================================

# In-scope lung-resection CPT codes mapped to
#   (extent_detail, extent_group, approach).
# extent_group is one of "Wedge", "Segmentectomy", "Lobectomy",
# "Pneumonectomy"; approach is "Open" or "VATS".  Lobectomy and Open get no
# indicator column in prepare_regression_df(), so they are the regression
# reference categories.
# NOTE(review): the year comments on 32096 / 32505 look doubtful -- both codes
# reportedly date from the 2011 CPT revision, and 32096 is described as a
# thoracotomy *diagnostic* biopsy code.  32667 is, as far as I know, an
# add-on ("each additional wedge") code, and VATS bilobectomy (32670) has no
# entry.  Confirm these choices are intentional.
RESECTION_META: dict[str, tuple[str, str, str]] = {
    # code     : (extent_detail,         extent_group,    approach)
    "32096"    : ("wedge_open",          "Wedge",         "Open"),   # 2010–2013
    "32505"    : ("wedge_open",          "Wedge",         "Open"),   # 2014+
    "32484"    : ("segmentectomy_open",  "Segmentectomy", "Open"),
    "32480"    : ("lobectomy_open",      "Lobectomy",     "Open"),
    "32482"    : ("bilobectomy_open",    "Lobectomy",     "Open"),
    "32440"    : ("pneumonectomy_open",  "Pneumonectomy", "Open"),
    "32442"    : ("pneumonectomy_open",  "Pneumonectomy", "Open"),
    "32486"    : ("sleeve_lobectomy",    "Lobectomy",     "Open"),
    "32488"    : ("completion_pnx",      "Pneumonectomy", "Open"),
    "32666"    : ("wedge_vats",          "Wedge",         "VATS"),
    "32667"    : ("wedge_vats",          "Wedge",         "VATS"),
    "32669"    : ("segmentectomy_vats",  "Segmentectomy", "VATS"),
    "32663"    : ("lobectomy_vats",      "Lobectomy",     "VATS"),
    "32671"    : ("pneumonectomy_vats",  "Pneumonectomy", "VATS"),
}

# Every CPT code above, in dict insertion order.
ALL_CPT: list[str] = list(RESECTION_META.keys())


# =============================================================================
# ICD CODE PREFIX SETS
# =============================================================================

# ICD-9 and ICD-10 codes are stored dot-free and upper-case and are matched
# by PREFIX (see match_pfx), so e.g. "J67" covers every J67.x code.  Both code
# systems share one set and matching does not know which system a code comes
# from.  Classification checks CANCER_PFX first, then BENIGN_PFX, then
# STRUCTURAL_PFX, so an overlap resolves in that order.
#
# NOTE(review): "E854" (ICD-10 organ-limited amyloidosis) is, I believe, also
# an ICD-9 external-cause (E-code) prefix; with code-system-agnostic matching
# an ICD-9 E854.x would be classed benign.  "2124" is, as far as I know,
# ICD-9 212.4 (pleura), not lung/bronchus.  For COP, ICD-9 516.8 is a broad
# code; 516.36 may be the specific one.  Verify.
BENIGN_PFX: set[str] = {
    # Benign neoplasms of lung/bronchus
    "2123", "2124",
    "D1430", "D1431", "D1432",
    # Solitary pulmonary nodule / healed granuloma
    "79311", "R911",
    # Nonspecific lung abnormality
    "79319", "R918",
    # Sarcoidosis
    "135", "D860",
    # Granulomatosis with polyangiitis
    "4464", "M3130", "M3131",
    # Rheumatoid lung nodule
    "71481",
    "M0510", "M0511", "M0512", "M0513", "M0514",
    "M0515", "M0516", "M0517", "M0519",
    # Organising pneumonia (COP/BOOP)
    "5168", "J84116",
    # Hypersensitivity pneumonitis
    "495", "J67",
    # Pulmonary amyloidosis
    "27739", "E854",
    # Pulmonary fibrosis NOS
    "515", "J8410",
    # Idiopathic pulmonary fibrosis
    "51631", "J84112",
    # ILD unspecified
    "5169", "J849",
    # Bronchogenic cyst
    "7484", "Q330",
}

# Diagnoses treated as malignant (or as an indication where resection is
# standard of care).  Checked before BENIGN_PFX.
# NOTE(review): the ICD-9 group comments look mislabelled -- 209.2x is, I
# believe, "malignant carcinoid of other/unspecified sites", 209.6x "benign
# carcinoid" and 209.7x "secondary neuroendocrine tumours".  Confirm.
# "20960"/"20961" appear twice (harmless in a set).  The broad prefixes "C7A"
# and "D3A09" already cover "C7A09" and "D3A090".
CANCER_PFX: set[str] = {
    # Primary lung cancer
    "162", "C34",
    # Carcinoma in situ
    "2312", "D022",
    # Typical carcinoid (ICD-9)
    "20920", "20921", "20922", "20923", "20924",
    "20925", "20926", "20927", "20928", "20929",
    "C7A09",
    # Pulmonary NET / atypical carcinoid
    "20960", "20961", "20962", "20963", "20964", "20965",
    "20970", "20971", "20972", "20973", "20974", "20975",
    "20976", "20979",
    "C7A",
    # Secondary NET
    "C7B",
    # Benign / other carcinoid where resection IS standard of care
    "20961", "20960", "20930",
    "D3A090", "D3A09", "D3A00",
    # Metastasis to lung
    "1970", "C7800", "C7801", "C7802",
    # Uncertain behaviour
    "2357", "D381",
    # Unspecified behaviour
    "2391", "D491",
}

# Structural / infectious indications.  Checked after BENIGN_PFX; the
# classifier folds a structural result into "malignant".
# NOTE(review): "R911" is also in BENIGN_PFX, and because benign is checked
# first it can never produce "structural".  Likewise "Q33" also covers
# "Q330", which BENIGN_PFX claims first.
STRUCTURAL_PFX: set[str] = {
    # Lung abscess
    "5130", "J851", "J852",
    # Tuberculosis
    "011", "A15",
    # Pulmonary aspergillosis
    "B44",
    # Bronchiectasis
    "494", "J47",
    # Empyema
    "J86",
    # CPAM / bronchogenic cyst / sequestration / congenital cystic lung
    "Q338", "Q332", "Q33",
    # Pulmonary AVM
    "Q2572",
    # Massive haemoptysis
    "R042",
    # Pneumothorax
    "J93",
    # Traumatic laceration
    "S273",
    # Pulmonary infarction
    "I2699",
    # Solitary nodule (structural fallback)
    "R911",
}


# =============================================================================
# NSQIP FIELD CANDIDATE LISTS
# =============================================================================

# Candidate column names for each NSQIP concept.  detect_col() returns the
# first candidate present in a file (case-insensitive), so list order is the
# preference order when several spellings co-exist.
CASEID_CANDIDATES        = ["caseid", "case_id"]
OPERYR_CANDIDATES        = ["operyr", "oper_yr", "year"]
CPT_CANDIDATES           = ["cpt", "cpt_cd", "principalcpt"]
# Concurrent-CPT columns in both spellings, slots 1..10.
CONCPT_CANDIDATES        = [f"concpt{i}" for i in range(1, 11)] + \
                            [f"con_cpt{i}" for i in range(1, 11)]
PODIAG_CANDIDATES        = ["podiag", "po_diag", "prncptx_podiag"]
PODIAG10_CANDIDATES      = ["podiag10", "po_diag10", "prncptx_podiag10"]
# NOTE(review): the NSQIP PUF reportedly top-codes AGE as the string "90+";
# consumers must handle it rather than coercing it to NaN.
AGE_CANDIDATES           = ["age", "age_yrs"]
SEX_CANDIDATES           = ["sex", "gender"]
# NOTE(review): detect_col returns a single column, so when "race_new" and
# "race" are both absent the first ethnicity_* indicator present becomes the
# "race" field -- confirm that is intended.
RACE_CANDIDATES          = [
    "race_new", "race", "ethnicity_american_indian",
    "ethnicity_asian", "ethnicity_black",
    "ethnicity_hispanic", "ethnicity_white",
]
HEIGHT_CANDIDATES        = ["height", "heightin"]
WEIGHT_CANDIDATES        = ["weight", "weightlb"]
BMI_CANDIDATES           = ["bmi"]
ASA_CANDIDATES           = ["asaclas", "asa_class"]
HXCOPD_CANDIDATES        = ["hxcopd", "hx_copd", "copd"]
HXCHF_CANDIDATES         = ["hxchf", "hx_chf", "chf"]
# NOTE(review): NSQIP DIABETES is reportedly coded by treatment
# (NO / NON-INSULIN / INSULIN) rather than Yes/No.
DIABETES_CANDIDATES      = ["diabetes", "diabetes_mellitus"]
SMOKE_CANDIDATES         = ["smoke", "smoker", "smoking"]
HYPERMED_CANDIDATES      = ["hypermed", "hyper_med", "antihypertensive"]
# NOTE(review): "dopertod" is presumably days-from-operation-to-death
# (with a -99 "not applicable" code), not a Yes/No flag -- decode accordingly.
MORT_CANDIDATES          = ["dopertod", "died30", "mortality", "death30"]

# Standard NSQIP 30-day outcome columns converted to 0/1 by add_outcomes();
# absent columns become a constant 0 there.
NSQIP_COMPLICATION_FIELDS: list[str] = [
    "supinfec", "wndinfd", "orgspcssi", "dehis",
    "oupneumo", "reintub", "pulembol", "failwean",
    "renainsf", "oprenafl", "urninfec", "cnscva",
    "cdarrest", "cdmi", "nnevent", "othbleed",
    "othdvt", "returnor", "readmission1",
]

# Thoracic-specific outcome columns.
# NOTE(review): these do not look like core NSQIP PUF variable names;
# presumably they exist only in merged / procedure-targeted extracts.
NSQIP_THORACIC_FIELDS: list[str] = [
    "air_leak", "bronchopleural", "chylothorax",
    "atrial_fib", "pneumonia_pul", "empyema_pul", "atel",
]

# Covariates used in logistic regression (BMI available in NSQIP)
# List order is the term order of the regression formula.  wedge,
# segmentectomy, pneumonectomy, vats and female are indicator columns built by
# prepare_regression_df().
BASE_COVARIATES_NSQIP: list[str] = [
    "age_num", "female", "bmi", "asa_num",
    "wedge", "segmentectomy", "pneumonectomy", "vats",
    "cc_copd", "cc_chf", "cc_diabetes", "cc_smoking", "antihypertensive",
]
# Median-imputed in impute_and_prepare(); every other covariate is treated as
# binary and mode-imputed.
CONTINUOUS_COVARIATES_NSQIP: set[str] = {"age_num", "bmi"}
BINARY_COVARIATES_NSQIP: set[str] = (
    set(BASE_COVARIATES_NSQIP) - CONTINUOUS_COVARIATES_NSQIP
)


# =============================================================================
# COLUMN / DATA HELPERS
# =============================================================================

def match_pfx(series: pd.Series, prefixes: set) -> pd.Series:
    """
    Test each ICD code in *series* against a set of code prefixes.

    Each value is normalised before matching:
      * missing values become "";
      * every value is coerced to ``str``, so numeric codes work (e.g. from
        .sas7bdat files);
      * surrounding whitespace is stripped;
      * letters are upper-cased and dots removed.
    For example, " c34.1" matches the prefix "C34".

    Parameters
    ----------
    series : Series
        ICD-9 / ICD-10 codes, possibly dotted, mixed case or missing.
    prefixes : set
        Dot-free, upper-case code prefixes.

    Returns
    -------
    Series of bool aligned to ``series.index``; never contains NaN.  An empty
    *prefixes* set matches nothing.
    """
    if not prefixes:
        return pd.Series(False, index=series.index)
    clean = (
        series.fillna("")
        .astype(str)
        .str.strip()
        .str.upper()
        .str.replace(".", "", regex=False)
    )
    pattern = "^(?:" + "|".join(
        re.escape(p) for p in sorted(prefixes, key=len, reverse=True)
    ) + ")"
    # na=False + astype(bool): guarantee a plain boolean mask even for an
    # empty input, so callers can use it directly in np.select / indexing.
    return clean.str.match(pattern, na=False).astype(bool)


def detect_col(cols: list[str], candidates: list[str]) -> Optional[str]:
    """
    Find which of *candidates* exists among *cols*, ignoring case.

    Candidates are tried in order; the first one present wins.  The
    original spelling from *cols* is returned.  If several columns differ
    only by case, the earliest one in *cols* is returned.  Returns None when
    no candidate is present.
    """
    by_lower: dict[str, str] = {}
    for column in cols:
        by_lower.setdefault(column.lower(), column)
    return next(
        (by_lower[name.lower()] for name in candidates if name.lower() in by_lower),
        None,
    )


def safe_col(df: pd.DataFrame, col: Optional[str], default=0) -> pd.Series:
    """
    Fetch column *col* from *df*, or a constant stand-in.

    If *col* is falsy (None / "") or not a column of *df*, a Series filled
    with *default* and indexed like *df* is returned instead.
    """
    if not col or col not in df.columns:
        return pd.Series(default, index=df.index)
    return df[col]


def year_from_path(path: str) -> Optional[int]:
    """
    Infer a 4-digit calendar year from a NSQIP PUF filename.

    Only the basename of *path* is inspected.  Rules are tried in order:

    1. A standalone 19xx/20xx run not embedded in a longer digit run
       (e.g. "ACS_NSQIP_PUF_2015_TXT.txt" -> 2015).  This stops unrelated
       digits such as batch IDs ("batch0001_puf_2017") from winning.
    2. Backward-compatible fallback: the first run of four digits anywhere.
    3. A two-digit "puf" suffix, with an optional "_" or "-" separator
       ("puf16", "PUF_19"), mapped to 20yy for yy < 50 and 19yy otherwise.

    Returns None when no rule matches.
    """
    name = os.path.basename(path)

    m = re.search(r"(?<!\d)((?:19|20)\d{2})(?!\d)", name)
    if m:
        return int(m.group(1))

    m = re.search(r"(\d{4})", name)
    if m:
        return int(m.group(1))

    m = re.search(r"puf[_\-]?(\d{2})", name.lower())
    if m:
        yy = int(m.group(1))
        return 2000 + yy if yy < 50 else 1900 + yy
    return None


def read_nsqip_file(path: str) -> pd.DataFrame:
    """
    Load one NSQIP PUF file into a DataFrame.

    The reader is chosen from the file extension:
      * ``.sas7bdat`` -- via pyreadstat; a RuntimeError is raised (and then
        reported) when pyreadstat is unavailable.
      * ``.csv`` -- comma-delimited text.
      * anything else -- tab-delimited text, re-read as comma-delimited if
        the tab parse produced a single column.
    Text files are read with ``dtype=str`` and latin-1 encoding.  Column names
    are stripped and lower-cased.

    Any Exception raised while reading is printed and an empty DataFrame is
    returned instead.
    """
    filename = os.path.basename(path)
    suffix = os.path.splitext(path)[1].lower()
    print(f"  Reading: {filename} ({suffix})")
    text_opts = {"dtype": str, "low_memory": False, "encoding": "latin1"}
    try:
        if suffix == ".sas7bdat":
            if not HAS_PYREADSTAT:
                raise RuntimeError("pyreadstat required for .sas7bdat files")
            frame = pyreadstat.read_sas7bdat(path)[0]
        elif suffix == ".csv":
            frame = pd.read_csv(path, **text_opts)
        else:
            frame = pd.read_csv(path, sep="\t", **text_opts)
            if frame.shape[1] == 1:  # no tabs found: fall back to commas
                frame = pd.read_csv(path, **text_opts)
        frame.columns = [name.strip().lower() for name in frame.columns]
        print(f"    {len(frame):,} rows × {len(frame.columns)} cols")
        return frame
    except Exception as exc:
        print(f"  [ERROR] Could not read {path}: {exc}")
        return pd.DataFrame()


def normalise_yesno(series: pd.Series) -> pd.Series:
    """
    Decode NSQIP Yes/No-style flags into 0/1 integers.

    Values are stringified, stripped and upper-cased, then looked up:
    YES / Y / 1 / 1.0 -> 1 and NO / N / 2 / 0 / 2.0 / 0.0 -> 0.  This covers
    both text and 1/2 numeric coding.  Any other value, including missing,
    becomes 0.
    """
    positive_tokens = ("YES", "Y", "1", "1.0")
    negative_tokens = ("NO", "N", "2", "0", "2.0", "0.0")
    codebook = {token: 1 for token in positive_tokens}
    codebook.update({token: 0 for token in negative_tokens})

    tokens = series.astype(str).str.strip().str.upper()
    decoded = tokens.map(codebook)
    return decoded.fillna(0).astype(int)


def normalise_asa(series: pd.Series) -> pd.Series:
    """
    Map ASA physical-status strings to numbers 1-5.

    For each class k, three spellings are accepted after strip and
    upper-casing: "k", "k-<DESCRIPTION>" (e.g. "3-SEVERE DISTURB") and
    "ASA k".  Anything else maps to NaN.
    """
    descriptions = {
        1: "NO DISTURB",
        2: "MILD DISTURB",
        3: "SEVERE DISTURB",
        4: "LIFE THREAT",
        5: "MORIBUND",
    }
    lookup: dict[str, int] = {}
    for level, text in descriptions.items():
        lookup[str(level)] = level
        lookup[f"{level}-{text}"] = level
        lookup[f"ASA {level}"] = level

    cleaned = series.astype(str).str.strip().str.upper()
    return cleaned.map(lookup)


def bmi_from_height_weight(
    height_in: pd.Series,
    weight_lb: pd.Series,
) -> pd.Series:
    """
    Compute BMI from height (inches) and weight (pounds).

    BMI = 703 * weight / height**2.

    Inputs are coerced to numeric (unparseable -> NaN).  Non-positive heights
    or weights are treated as missing: sentinel codes such as -99 would
    otherwise square to a positive denominator and give a plausible-looking
    but bogus BMI (e.g. height -99, weight 200 -> 14.3).  Results outside
    10-80 are set to NaN.

    Returns a float Series aligned to the inputs.
    """
    h = pd.to_numeric(height_in, errors="coerce")
    w = pd.to_numeric(weight_lb, errors="coerce")
    h = h.where(h > 0)
    w = w.where(w > 0)
    bmi = (w * 703) / (h ** 2)
    return bmi.where(bmi.between(10, 80))


# =============================================================================
# ICD CLASSIFICATION
# =============================================================================

# Lower-cased placeholder strings that mean "no diagnosis recorded".
_NULL_DX_VALUES = {"", ".", "nan", "none", "na", "n/a", "unknown"}

# Rank used when the ICD-9 and ICD-10 fields disagree: lower wins, so the
# more conservative label is kept.
_CONFLICT_PRIORITY = {"malignant": 0, "structural": 1, "benign": 2, "unclassified": 3}


def _clean_dx(series: pd.Series) -> pd.Series:
    """
    Normalise raw diagnosis values to stripped strings.

    Missing values and any value whose lower-cased form is in _NULL_DX_VALUES
    become "".
    """
    text = series.fillna("").astype(str).str.strip()
    is_placeholder = text.str.lower().isin(_NULL_DX_VALUES)
    return text.mask(is_placeholder, "")


def _looks_like_icd10(series: pd.Series) -> pd.Series:
    """Flag strings whose first character is an ASCII letter (ICD-10 style)."""
    return series.str.contains(r"^[A-Za-z]", regex=True)


def _looks_like_icd9(series: pd.Series) -> pd.Series:
    """Flag strings starting with a digit or an E/V letter (ICD-9 numeric, E- or V-code)."""
    return series.str.contains(r"^[0-9EeVv]", regex=True)


def classify_dx_primary(dx: str) -> str:
    """
    Label one ICD code string.

    Returns 'malignant', 'benign', 'structural' or 'unclassified'.  The prefix
    sets are tried in the fixed order CANCER_PFX, BENIGN_PFX, STRUCTURAL_PFX,
    and the first hit decides.  A falsy *dx* (e.g. "") is 'unclassified'.
    Structural codes are reported as 'structural'; folding them into
    'malignant' is left to the caller.
    """
    if not dx:
        return "unclassified"
    probe = pd.Series([dx])
    ordered_sets = (
        ("malignant", CANCER_PFX),
        ("benign", BENIGN_PFX),
        ("structural", STRUCTURAL_PFX),
    )
    for label, prefixes in ordered_sets:
        if match_pfx(probe, prefixes).any():
            return label
    return "unclassified"


def _classify_codes(codes: pd.Series) -> pd.Series:
    """
    Vectorised equivalent of classify_dx_primary() for a Series of codes.

    Returns an object Series (same index) of 'malignant' | 'benign' |
    'structural' | 'unclassified'.  CANCER_PFX is checked first, then
    BENIGN_PFX, then STRUCTURAL_PFX; a code matching none is 'unclassified'.
    """
    labels = np.select(
        [
            match_pfx(codes, CANCER_PFX).eq(True).to_numpy(),
            match_pfx(codes, BENIGN_PFX).eq(True).to_numpy(),
            match_pfx(codes, STRUCTURAL_PFX).eq(True).to_numpy(),
        ],
        ["malignant", "benign", "structural"],
        default="unclassified",
    )
    return pd.Series(labels, index=codes.index, dtype=object)


def add_classification(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of *df* with ICD-era-aware diagnosis classification columns.

    The input frame is not modified.

    Requires:
        df["dx_podiag"]   — cleaned ICD-9 diagnosis (may be empty string)
        df["dx_podiag10"] — cleaned ICD-10 diagnosis (may be empty string)
        df["operyr_num"]  — operation year (optional; coerced to numeric, and
                            treated as unknown when absent or unparseable)

    dx_primary is chosen by operation year:
        <= 2014  -> dx_podiag
        2015     -> dx_podiag10 if it looks like ICD-10, else dx_podiag
        >= 2016  -> dx_podiag10
        unknown  -> dx_podiag10 if it looks like ICD-10, else dx_podiag if it
                    looks like ICD-9, else ""

    Dual-code rule: when BOTH fields look valid, each is classified on its own
    and the more conservative label is kept (malignant > structural > benign >
    unclassified).  That label overrides the dx_primary flags unless it is
    'unclassified'.  Structural results are folded into 'malignant'.

    Adds:
        dx_primary          — era-appropriate authoritative code
        _dual_resolved_cls  — diagnostic: dual-code label (NaN elsewhere)
        dx_is_benign        — bool
        dx_is_cancer        — bool
        dx_is_structural    — bool
        classification      — 'benign' | 'malignant' | 'unclassified'
        is_benign           — int (0/1)
        is_malignant        — int (0/1)
    """
    df = df.copy()
    # Coerce to plain float so nullable Int64 or string years still give pure
    # boolean masks (np.select rejects masks containing pd.NA).
    yr = pd.to_numeric(
        df.get("operyr_num", pd.Series(np.nan, index=df.index)), errors="coerce"
    ).astype(float)

    podiag, podiag10 = df["dx_podiag"], df["dx_podiag10"]
    has_icd9 = ((podiag != "") & _looks_like_icd9(podiag)).eq(True)
    has_icd10 = ((podiag10 != "") & _looks_like_icd10(podiag10)).eq(True)

    pre2015 = yr.le(2014).to_numpy()
    yr2015 = yr.eq(2015).to_numpy()
    post2015 = yr.ge(2016).to_numpy()
    yr_unknown = yr.isna().to_numpy()
    icd9_ok = has_icd9.to_numpy()
    icd10_ok = has_icd10.to_numpy()

    df["dx_primary"] = np.select(
        [
            pre2015,
            yr2015 & icd10_ok,
            yr2015 & ~icd10_ok,
            post2015,
            yr_unknown & icd10_ok,
            yr_unknown & ~icd10_ok & icd9_ok,
        ],
        [podiag, podiag10, podiag, podiag10, podiag10, podiag],
        default="",
    )

    # Dual-field conflict resolution (conservative: malignant wins).
    # The column is object dtype from the start: writing strings into a float
    # NaN column is deprecated in pandas >= 2.1 and slated to become an error.
    dual_mask = has_icd9 & has_icd10
    df["_dual_resolved_cls"] = pd.Series(np.nan, index=df.index, dtype=object)
    if dual_mask.any():
        cls9 = _classify_codes(podiag[dual_mask])
        cls10 = _classify_codes(podiag10[dual_mask])
        keep9 = cls9.map(_CONFLICT_PRIORITY) <= cls10.map(_CONFLICT_PRIORITY)
        # Assign positionally so duplicate index labels cannot misalign.
        df.loc[dual_mask, "_dual_resolved_cls"] = cls9.where(keep9, cls10).to_numpy()

    df["dx_is_benign"] = match_pfx(df["dx_primary"], BENIGN_PFX)
    df["dx_is_cancer"] = match_pfx(df["dx_primary"], CANCER_PFX)
    df["dx_is_structural"] = match_pfx(df["dx_primary"], STRUCTURAL_PFX)

    # Vectorised decision list (also safe on an empty frame, where a
    # row-wise apply would return a DataFrame instead of a Series).
    dual = df["_dual_resolved_cls"]
    df["classification"] = np.select(
        [
            dual.eq("malignant").to_numpy(),
            dual.eq("benign").to_numpy(),
            dual.eq("structural").to_numpy(),
            df["dx_is_cancer"].eq(True).to_numpy(),
            df["dx_is_benign"].eq(True).to_numpy(),
            df["dx_is_structural"].eq(True).to_numpy(),
        ],
        ["malignant", "benign", "malignant", "malignant", "benign", "malignant"],
        default="unclassified",
    )
    df["is_benign"] = (df["classification"] == "benign").astype(int)
    df["is_malignant"] = (df["classification"] == "malignant").astype(int)
    return df


# =============================================================================
# DEMOGRAPHICS / COHORT PREP
# =============================================================================

def add_age_group(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of *df* with *age_num* (float) and *age_group* (categorical).

    age_num comes from the first AGE_CANDIDATES column.  A trailing "+" is
    stripped before numeric coercion, so top-coded ages such as "90+" become
    90 instead of NaN.  The NSQIP PUF reportedly top-codes patients over 89
    this way; confirm for your files.  Otherwise unparseable values become
    NaN, and the column is all-NaN when no age column exists.

    age_group uses right-closed bins (0,49], (49,59], (59,69], (69,79],
    (79,120], labelled "<50", "50-59", "60-69", "70-79", "80+".
    """
    df = df.copy()
    age_col = detect_col(df.columns.tolist(), AGE_CANDIDATES)
    if age_col:
        raw = df[age_col].astype(str).str.strip().str.rstrip("+")
        df["age_num"] = pd.to_numeric(raw, errors="coerce")
    else:
        df["age_num"] = np.nan
    df["age_group"] = pd.cut(
        df["age_num"],
        bins=[0, 49, 59, 69, 79, 120],
        labels=["<50", "50-59", "60-69", "70-79", "80+"],
    )
    return df


def add_sex_label(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of *df* with a *sex_label* column.

    The label is 'Male', 'Female' or 'Unknown'.  It is decoded
    case-insensitively from the first SEX_CANDIDATES column: MALE/M/1 ->
    Male and FEMALE/F/2 -> Female.  Anything else, or a missing column, gives
    'Unknown'.
    """
    df = df.copy()
    sex_col = detect_col(df.columns.tolist(), SEX_CANDIDATES)
    if not sex_col:
        df["sex_label"] = "Unknown"
        return df

    lookup = {
        "MALE": "Male", "M": "Male", "1": "Male",
        "FEMALE": "Female", "F": "Female", "2": "Female",
    }
    codes = df[sex_col].astype(str).str.strip().str.upper()
    df["sex_label"] = codes.map(lookup).fillna("Unknown")
    return df


def add_bmi(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of *df* with *bmi* (float) and *bmi_cat* (categorical).

    bmi uses the first BMI_CANDIDATES column when present; values outside
    10-80 become NaN.  Otherwise it is derived from height (inches) and
    weight (pounds) via bmi_from_height_weight().  Otherwise it is all-NaN.

    bmi_cat uses LEFT-closed bins [0,18.5), [18.5,25), [25,30), [30,40),
    [40,200), so a BMI of exactly 25 / 30 / 40 falls in the higher category,
    matching the labels.  The previous right-closed bins put e.g. 25.0 under
    "Normal (18.5–24.9)".
    """
    df = df.copy()
    cols = df.columns.tolist()
    bmi_col    = detect_col(cols, BMI_CANDIDATES)
    height_col = detect_col(cols, HEIGHT_CANDIDATES)
    weight_col = detect_col(cols, WEIGHT_CANDIDATES)

    if bmi_col:
        df["bmi"] = pd.to_numeric(df[bmi_col], errors="coerce")
        df.loc[~df["bmi"].between(10, 80), "bmi"] = np.nan
    elif height_col and weight_col:
        df["bmi"] = bmi_from_height_weight(df[height_col], df[weight_col])
    else:
        df["bmi"] = np.nan

    df["bmi_cat"] = pd.cut(
        df["bmi"],
        bins=[0, 18.5, 25, 30, 40, 200],
        labels=[
            "Underweight (<18.5)", "Normal (18.5–24.9)",
            "Overweight (25–29.9)", "Obese (30–39.9)", "Severely Obese (≥40)",
        ],
        right=False,
    )
    return df


def add_asa(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of *df* with *asa_num* (1-5 or NaN).

    The value is decoded by normalise_asa() from the first ASA_CANDIDATES
    column; the column is all-NaN when no ASA column exists.
    """
    df = df.copy()
    source = detect_col(list(df.columns), ASA_CANDIDATES)
    if source:
        df["asa_num"] = normalise_asa(df[source])
    else:
        df["asa_num"] = np.nan
    return df


def add_comorbidities(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of *df* with binary (0/1) comorbidity columns:
        cc_copd, cc_chf, cc_diabetes, cc_smoking, antihypertensive

    Each is decoded from the first matching *_CANDIDATES column with
    normalise_yesno(), so unrecognised values become 0.  Missing source
    columns give a constant 0.

    cc_diabetes additionally counts the treatment-category tokens "INSULIN",
    "NON-INSULIN" and "ORAL" (case-insensitive, stripped) as 1.  NSQIP's
    DIABETES field is reportedly coded this way rather than Yes/No; with
    normalise_yesno alone every diabetic patient would map to 0.
    """
    diabetes_positive = {"INSULIN", "NON-INSULIN", "ORAL"}
    df = df.copy()
    cols = df.columns.tolist()
    mapping = {
        "cc_copd":          HXCOPD_CANDIDATES,
        "cc_chf":           HXCHF_CANDIDATES,
        "cc_diabetes":      DIABETES_CANDIDATES,
        "cc_smoking":       SMOKE_CANDIDATES,
        "antihypertensive": HYPERMED_CANDIDATES,
    }
    for field, candidates in mapping.items():
        src = detect_col(cols, candidates)
        if not src:
            df[field] = 0
            continue
        flag = normalise_yesno(df[src])
        if field == "cc_diabetes":
            tokens = df[src].astype(str).str.strip().str.upper()
            flag = (flag.eq(1) | tokens.isin(diabetes_positive)).astype(int)
        df[field] = flag
    return df


def add_svc_year(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of *df* with *svc_year* as nullable Int64.

    The source is ``operyr_num`` when present, otherwise the first
    OPERYR_CANDIDATES column.  Values are coerced to numeric (unparseable ->
    <NA>).  The column is all-<NA> when neither source exists.
    """
    df = df.copy()
    if "operyr_num" in df.columns:
        source = "operyr_num"
    else:
        source = detect_col(df.columns.tolist(), OPERYR_CANDIDATES)

    if source is None:
        df["svc_year"] = pd.NA
    else:
        df["svc_year"] = pd.to_numeric(df[source], errors="coerce").astype("Int64")
    return df


# =============================================================================
# 30-DAY OUTCOMES
# =============================================================================

# Stripped, upper-cased tokens meaning "no event" or "no data" in outcome
# columns that use descriptive-text coding.
_NO_EVENT_TOKENS = frozenset({
    "NO COMPLICATION", "NO", "N", "0", "2", "0.0", "2.0",
    "", "NAN", "NONE", "NULL", "NA", "N/A", "UNKNOWN", "-99", "-99.0", "<NA>",
})


def _event_flag(series: pd.Series) -> pd.Series:
    """
    Convert one outcome column to 0/1 integers.

    A column containing the token "No Complication" is treated as
    descriptive-text coded: every value outside _NO_EVENT_TOKENS counts as an
    event (1).  NSQIP PUF complication fields are reportedly coded as the
    complication name vs "No Complication"; confirm for your files.  Any other
    column falls back to normalise_yesno().
    """
    tokens = series.astype(str).str.strip().str.upper()
    if tokens.eq("NO COMPLICATION").any():
        return (~tokens.isin(_NO_EVENT_TOKENS)).astype(int)
    return normalise_yesno(series)


def add_outcomes(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of *df* with NSQIP 30-day outcome flags attached.

    Adds:
        died30           — 0/1; NaN when no MORT_CANDIDATES column exists.
                           When the column found is ``dopertod`` it is read as
                           days from operation to death: 0-30 -> 1, anything
                           else (e.g. a -99 "not applicable" code, or missing)
                           -> 0.  Other mortality columns use normalise_yesno().
        <complication>   — 0/1 for each field in NSQIP_COMPLICATION_FIELDS
                           and NSQIP_THORACIC_FIELDS (0 if column absent),
                           decoded by _event_flag()
        major_morbidity  — composite 0/1: max over the major-complication
                           fields listed below
    """
    df = df.copy()
    mort_col = detect_col(df.columns.tolist(), MORT_CANDIDATES)
    if mort_col is None:
        df["died30"] = np.nan
    elif mort_col.lower() == "dopertod":
        # Days-to-death, not a Yes/No flag: normalise_yesno would map "1" -> 1
        # but "2" and every other day count -> 0.
        days = pd.to_numeric(df[mort_col], errors="coerce")
        df["died30"] = days.between(0, 30).astype(int)
    else:
        df["died30"] = normalise_yesno(df[mort_col])

    for comp in NSQIP_COMPLICATION_FIELDS + NSQIP_THORACIC_FIELDS:
        df[comp] = _event_flag(df[comp]) if comp in df.columns else 0

    # Every field below was just created by the loop above, so all are present.
    major_fields = [
        "oupneumo", "reintub", "pulembol", "failwean",
        "renainsf", "oprenafl", "urninfec", "cnscva",
        "cdarrest", "cdmi", "othbleed", "othdvt",
        "returnor", "readmission1",
        "air_leak", "bronchopleural", "chylothorax",
        "atrial_fib", "pneumonia_pul", "empyema_pul", "atel",
    ]
    df["major_morbidity"] = df[major_fields].max(axis=1).astype(int)
    return df


# =============================================================================
# ANALYSIS HELPERS
# =============================================================================

def rate_row(label: str, df: pd.DataFrame) -> dict:
    """
    Summarise the benign resection rate of one group as a dict.

    Keys: group, N_classified (row count), N_benign (sum of is_benign),
    N_malignant (rows minus benign), benign_rate (rounded to 6 dp, or None
    when empty) and benign_pct (e.g. "12.34%", or "N/A" when empty).
    """
    total = len(df)
    if total == 0:
        benign, rate, pct = 0, None, "N/A"
    else:
        benign = int(df["is_benign"].sum())
        share = benign / total
        rate, pct = round(share, 6), f"{share:.2%}"
    return {
        "group": label,
        "N_classified": total,
        "N_benign": benign,
        "N_malignant": total - benign,
        "benign_rate": rate,
        "benign_pct": pct,
    }


def rate_table(df: pd.DataFrame, group_col: Optional[str] = None) -> pd.DataFrame:
    """
    Build a benign-rate table from rate_row() summaries.

    With a valid *group_col* there is one row per observed group (label
    str(group)), followed by a "TOTAL" row.  Otherwise there is a single
    "Overall" row.
    """
    if not group_col or group_col not in df.columns:
        return pd.DataFrame([rate_row("Overall", df)])

    rows = []
    for key, part in df.groupby(group_col, observed=True):
        rows.append(rate_row(str(key), part))
    rows.append(rate_row("TOTAL", df))
    return pd.DataFrame(rows)


def outcome_table(
    df: pd.DataFrame,
    group_col: Optional[str] = None,
) -> pd.DataFrame:
    """
    Tabulate 30-day outcomes, optionally stratified by *group_col*.

    Covers mortality, major morbidity, return to OR and readmission.  For
    each outcome column that exists and has at least one non-null value in
    the group, "<Label>_N" is the integer sum and "<Label>_pct" the mean as a
    percentage string.  Otherwise both are "N/A".  Without a valid
    *group_col* a single "Overall" row is produced.
    """
    labels = {
        "died30":          "Mortality",
        "major_morbidity": "Major morbidity",
        "returnor":        "Return to OR",
        "readmission1":    "Readmission",
    }
    if group_col and group_col in df.columns:
        groups = df.groupby(group_col, observed=True)
    else:
        groups = [("Overall", df)]

    def _summarise(name, sub: pd.DataFrame) -> dict:
        summary = {"group": str(name), "N": len(sub)}
        for field, label in labels.items():
            present = field in sub.columns and sub[field].notna().any()
            summary[f"{label}_N"] = int(sub[field].sum()) if present else "N/A"
            summary[f"{label}_pct"] = f"{sub[field].mean():.2%}" if present else "N/A"
        return summary

    return pd.DataFrame([_summarise(name, sub) for name, sub in groups])


def run_trend_test(df: pd.DataFrame, outdir: str, label: str) -> None:
    """
    Cochran-Armitage test for a linear trend in annual benign resection rates.

    Calendar year is the score.  For year x_i with n_i cases, r_i of them
    benign, define:
        N     = Σ n_i
        p     = Σ r_i / N                (pooled rate)
        x̄     = Σ n_i x_i / N            (case-weighted mean year)
        S_xx  = Σ n_i (x_i - x̄)²
    Then:
        slope = Σ (x_i - x̄)(r_i - n_i p) / S_xx   (weighted LS slope of rate on year)
        SE    = sqrt(p (1 - p) / S_xx)
        Z     = slope / SE                 (two-sided normal p-value)
    This is the usual statistic Z = Σ (x_i - x̄)(r_i - n_i p) / sqrt(p(1-p) S_xx).

    Parameters
    ----------
    df : DataFrame with ``svc_year`` and 0/1 ``is_benign`` columns.
    outdir : directory for trend_test.txt (created if missing).
    label : title used in the report.

    Writes *outdir*/trend_test.txt (UTF-8) and prints a one-line summary.
    Nothing is written when fewer than 3 years have data, there are no
    cases, or all years coincide.
    """
    yr = (
        df.groupby("svc_year", observed=True)
        .agg(N=("is_benign", "count"), N_benign=("is_benign", "sum"))
        .reset_index()
    )
    yr["svc_year"] = pd.to_numeric(yr["svc_year"], errors="coerce")
    yr = yr.dropna().sort_values("svc_year")
    yr["rate"] = yr["N_benign"] / yr["N"]
    if len(yr) < 3:
        return

    x = yr["svc_year"].to_numpy(dtype=float)
    n = yr["N"].to_numpy(dtype=float)
    r = yr["N_benign"].to_numpy(dtype=float)

    n_total = n.sum()
    if n_total <= 0:
        return
    # Centre on the CASE-WEIGHTED mean year so that Σ n_i (x_i - x̄) = 0.
    x_c = x - np.sum(n * x) / n_total
    s_xx = np.sum(n * x_c ** 2)
    if s_xx <= 0:
        return

    p_pool = r.sum() / n_total  # pooled rate, not the mean of yearly rates
    slope = np.sum(x_c * (r - n * p_pool)) / s_xx
    se = np.sqrt(p_pool * (1 - p_pool) / s_xx)
    z = slope / se if se > 0 else 0.0
    p_value = 2 * stats.norm.sf(abs(z))

    txt = (
        f"Cochran-Armitage Trend Test — {label}\n{'='*60}\n"
        + yr[["svc_year", "N", "N_benign", "rate"]].to_string(index=False)
        + f"\n\nSlope : {slope:+.6f}/yr ({slope*100:+.4f} pct pts/yr)\n"
        f"Z     : {z:.3f}\np     : {p_value:.4f}\n"
        f"Result: {'SIGNIFICANT' if p_value < 0.05 else 'Not significant'} at p < 0.05\n"
    )
    print(f"    Trend: slope={slope*100:+.4f} pct-pts/yr  Z={z:.3f}  p={p_value:.4f}")
    os.makedirs(outdir, exist_ok=True)
    with open(os.path.join(outdir, "trend_test.txt"), "w", encoding="utf-8") as fh:
        fh.write(txt)


def prepare_regression_df(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of *df* with the 0/1 indicators used by the logistic models.

    New columns, in this order:
        female        <- sex_label == "Female"
        wedge         <- extent_group == "Wedge"
        segmentectomy <- extent_group == "Segmentectomy"
        pneumonectomy <- extent_group == "Pneumonectomy"
        vats          <- approach == "VATS"
    Lobectomy and the Open approach get no indicator, so they act as the
    reference categories.
    """
    indicator_spec = (
        ("female", "sex_label", "Female"),
        ("wedge", "extent_group", "Wedge"),
        ("segmentectomy", "extent_group", "Segmentectomy"),
        ("pneumonectomy", "extent_group", "Pneumonectomy"),
        ("vats", "approach", "VATS"),
    )
    out = df.copy()
    for new_col, source_col, level in indicator_spec:
        out[new_col] = out[source_col].eq(level).astype(int)
    return out


def impute_and_prepare(
    reg: pd.DataFrame,
    extra_cols: Optional[list[str]] = None,
) -> tuple[pd.DataFrame, list[str]]:
    """
    Impute missing covariates and return an analysis-ready DataFrame.

    Imputation strategy:
        Continuous (age_num, bmi) → median imputation
        Binary (comorbidities, procedure flags) → mode imputation (0 when
        the column has no mode, i.e. is all-missing)

    Side effects
    ------------
    *reg* is modified IN PLACE.  Every required column is coerced to numeric
    (or created as all-NaN when absent) and then imputed.  Pass a copy if the
    original values are needed afterwards.

    Parameters
    ----------
    reg : DataFrame
        Output of :func:`prepare_regression_df`.
    extra_cols : list[str], optional
        Additional columns (e.g. ['svc_year']) to include and require
        non-null in the returned DataFrame.

    Returns
    -------
    complete : DataFrame
        Independent copy holding only the required columns, restricted to
        rows where is_benign and every *extra_cols* column are non-null.
        Callers may add or assign columns without a SettingWithCopyWarning.
    covars : list[str]
        Subset of BASE_COVARIATES_NSQIP (original order) with at least 2
        distinct non-null values; singletons are dropped to prevent rank
        deficiency.
    """
    cols = ["is_benign"] + BASE_COVARIATES_NSQIP
    if extra_cols:
        cols = cols + [c for c in extra_cols if c not in cols]

    for c in cols:
        if c in reg.columns:
            reg[c] = pd.to_numeric(reg[c], errors="coerce")
        else:
            reg[c] = np.nan

    for c in CONTINUOUS_COVARIATES_NSQIP:
        if c in reg.columns and reg[c].isna().any():
            reg[c] = reg[c].fillna(reg[c].median())

    for c in BINARY_COVARIATES_NSQIP:
        if c in reg.columns and reg[c].isna().any():
            modes = reg[c].mode()
            reg[c] = reg[c].fillna(modes.iloc[0] if not modes.empty else 0)

    must_have = ["is_benign"] + list(extra_cols or [])
    # .copy(): detach from *reg* so downstream column assignments act on a
    # standalone frame.
    complete = reg[cols].dropna(subset=must_have).copy()

    covars = [c for c in BASE_COVARIATES_NSQIP if complete[c].nunique() >= 2]
    return complete, covars


def run_logistic_regression(
    df: pd.DataFrame,
    outdir: str,
    label: str,
) -> None:
    """
    Fit a logistic regression (is_benign ~ covariates) and write:
        outdir/odds_ratios.csv          — OR, 95% CI and p-value per term
        outdir/logistic_regression.txt  — header notes + statsmodels summary

    Covariates come from impute_and_prepare(), so constant covariates are
    dropped.  The optimiser is BFGS with up to 500 iterations.  When it
    reports non-convergence a warning is printed and written into the text
    report.  The report's BMI note reflects whether bmi actually survived
    into the model.

    Skipped (with a printed message) if statsmodels is unavailable, the
    sample size is < 100, is_benign has a single level, or no covariates
    remain.  Model-fitting errors are printed with a traceback and not
    re-raised.
    """
    if not HAS_SM:
        print("    [SKIP logistic] statsmodels not installed")
        return

    reg = prepare_regression_df(df)
    complete, covars = impute_and_prepare(reg)
    n_comp = len(complete)
    print(f"    Logistic N={n_comp:,} (after imputation, of {len(df):,})")

    if n_comp < 100 or complete["is_benign"].nunique() < 2 or not covars:
        print("    [SKIP logistic] insufficient data or no valid covariates")
        return

    formula = "is_benign ~ " + " + ".join(covars)
    print(f"    Formula: {formula}")

    try:
        model = smf.logit(formula, data=complete).fit(
            method="bfgs", maxiter=500, disp=False
        )
        # BFGS stops silently at maxiter; surface non-convergence so the
        # odds ratios are not reported as if they were final estimates.
        retvals = getattr(model, "mle_retvals", None) or {}
        converged = bool(retvals.get("converged", True))
        if not converged:
            print("    [WARN] Logistic optimiser did not converge — estimates may be unreliable")

        ci = model.conf_int()
        or_df = pd.DataFrame({
            "covariate": model.params.index,
            "OR":        np.exp(model.params.values).round(4),
            "CI_low":    np.exp(ci[0].values).round(4),
            "CI_high":   np.exp(ci[1].values).round(4),
            "p_value":   model.pvalues.values.round(4),
        })
        bmi_note = (
            "NOTE: BMI is INCLUDED (available in NSQIP, unlike Medicare)\n"
            if "bmi" in covars
            else "NOTE: BMI was DROPPED (fewer than 2 distinct values after imputation)\n"
        )
        conv_note = "" if converged else "WARNING: optimiser did not converge\n"

        os.makedirs(outdir, exist_ok=True)
        or_df.to_csv(os.path.join(outdir, "odds_ratios.csv"), index=False)
        with open(os.path.join(outdir, "logistic_regression.txt"), "w", encoding="utf-8") as fh:
            fh.write(
                f"Logistic Regression — {label}\n"
                + bmi_note
                + conv_note
                + f"Formula: {formula}\nN = {n_comp:,}\n\n"
                + model.summary().as_text()
            )
        print(f"    OR table:\n{or_df.to_string(index=False)}")
    except Exception as exc:
        # Best-effort boundary: one failed model must not abort the pipeline.
        import traceback
        print(f"    [WARN] Regression failed: {exc}")
        traceback.print_exc()


def _marginal_rate_delta_ci(model, X: pd.DataFrame, yr_cols: list, target_col: str):
    """
    Return the standardised predicted rate for one counterfactual year, with its delta-method CI.

    Every row of the float design matrix *X* has all year-dummy columns in
    *yr_cols* set to 0 and, when present, *target_col* set to 1.  The
    reference year has no dummy column, so for it all dummies stay at 0.
    All other covariates are unchanged.

    Returns ``(rate, se, ci_low, ci_high)``:
      * *rate* is the mean predicted probability over all rows;
      * *se* comes from the delta method, using the gradient
        mean_i[p_i (1 - p_i) x_i] and ``model.cov_params()``;
      * the 95% interval (rate +/- 1.96 * se) is clipped to [0, 1].
    SE and CI are NaN if the variance computation raises ``ValueError`` or
    ``LinAlgError``.

    *model* is a fitted statsmodels Logit results object (only ``predict``
    and ``cov_params`` are used).
    """
    X_cf = X.copy()
    for col in yr_cols:
        X_cf[col] = 0.0
    if target_col in X_cf.columns:
        X_cf[target_col] = 1.0

    # np.asarray: predict may return a Series or an ndarray depending on the
    # statsmodels version / exog type; the math below needs a plain array.
    pred = np.asarray(model.predict(X_cf), dtype=float)
    rate = float(pred.mean())

    try:
        grad = ((pred * (1.0 - pred))[:, None] * X_cf.to_numpy(dtype=float)).mean(axis=0)
        var = float(grad @ np.asarray(model.cov_params(), dtype=float) @ grad)
        se = float(np.sqrt(max(var, 0.0)))
    except (ValueError, np.linalg.LinAlgError):
        return rate, np.nan, np.nan, np.nan

    return rate, se, max(rate - 1.96 * se, 0.0), min(rate + 1.96 * se, 1.0)


def run_risk_adjusted_rates(
    df: pd.DataFrame,
    outdir: str,
    label: str,
) -> None:
    """
    Risk-adjusted benign resection rates by year via marginal standardisation.

    Method
    ------
    1. Fit logistic: is_benign ~ year dummies (earliest year = reference) +
       case-mix covariates returned by ``impute_and_prepare``.
    2. For each year k, clone the full study population with svc_year=k
       (all other covariates unchanged) and compute mean predicted P(benign).
    3. 95% CIs via the delta method, clipped to [0, 1].
    4. Unweighted linear regression (``scipy.stats.linregress``) of the
       adjusted rates on year; per-year SEs are not used as weights.

    Skipped (console message only) when statsmodels is unavailable, when
    ``svc_year`` is missing, or when there are fewer than 3 years, fewer than
    200 rows, a single outcome class, or no covariates.  A model-fitting
    failure is printed with a traceback and the function returns early.
    A non-converged fit is flagged on the console and in the trend report.

    Writes
    ------
        outdir/risk_adjusted_rate_by_year.csv
        outdir/risk_adjusted_trend_test.txt
    """
    if not HAS_SM:
        print("    [SKIP risk-adjusted] statsmodels not installed")
        return

    from scipy.stats import linregress  # local import to keep top-level clean

    reg = prepare_regression_df(df)
    if "svc_year" not in reg.columns:
        print("    [SKIP risk-adjusted] svc_year column missing")
        return

    reg["svc_year"] = pd.to_numeric(reg["svc_year"], errors="coerce")
    complete, covars = impute_and_prepare(reg, extra_cols=["svc_year"])
    # Drop rows without a usable year and cast in one chain, producing a fresh
    # frame rather than assigning into a filtered one.
    complete = complete.dropna(subset=["svc_year"]).astype({"svc_year": int})
    n_comp = len(complete)

    years = sorted(complete["svc_year"].unique())
    if len(years) < 3 or n_comp < 200 or complete["is_benign"].nunique() < 2 or not covars:
        print("    [SKIP risk-adjusted] insufficient data")
        return

    ref_year = years[0]
    year_dummies = pd.get_dummies(complete["svc_year"], prefix="yr", drop_first=False, dtype=float)
    # The reference year gets no column; it is represented by all dummies = 0.
    year_dummies = year_dummies.drop(columns=[f"yr_{ref_year}"], errors="ignore")
    yr_cols = list(year_dummies.columns)

    # Cast once: X is reused for fitting and for every counterfactual
    # prediction and delta-method gradient, which all need a float matrix.
    X = sm.add_constant(pd.concat([year_dummies, complete[covars]], axis=1)).astype(float)
    y = complete["is_benign"].values

    formula_str = (
        f"is_benign ~ C(svc_year, Treatment(reference={ref_year})) + "
        + " + ".join(covars)
    )
    print(f"    Risk-adjusted formula: {formula_str}")
    print(f"    N={n_comp:,}  years={years[0]}–{years[-1]}  ref_year={ref_year}")

    try:
        model = sm.Logit(y, X).fit(method="bfgs", maxiter=500, disp=False)
    except Exception as exc:
        import traceback
        print(f"    [WARN] Risk-adjusted model failed: {exc}")
        traceback.print_exc()
        return

    # With disp=False, BFGS stopping at maxiter is silent; surface it.
    retvals = getattr(model, "mle_retvals", None) or {}
    converged = bool(retvals.get("converged", True))
    if not converged:
        print("    [WARN] Risk-adjusted model did not converge; "
              "adjusted rates may be unreliable")

    results = []
    for yr_val in years:
        adj_rate, se_adj, ci_lo, ci_hi = _marginal_rate_delta_ci(
            model, X, yr_cols, f"yr_{yr_val}"
        )

        yr_mask = complete["svc_year"] == yr_val
        n_yr    = int(yr_mask.sum())
        n_ben   = int(complete.loc[yr_mask, "is_benign"].sum())
        crude   = n_ben / n_yr if n_yr > 0 else np.nan

        results.append({
            "year":               yr_val,
            "N":                  n_yr,
            "N_benign":           n_ben,
            "crude_rate":         round(crude, 6),
            "risk_adjusted_rate": round(adj_rate, 6),
            "ra_CI_low":          round(ci_lo, 6),
            "ra_CI_high":         round(ci_hi, 6),
            "ra_SE":              round(se_adj, 6),
        })

    ra_df = pd.DataFrame(results)
    os.makedirs(outdir, exist_ok=True)
    ra_df.to_csv(f"{outdir}/risk_adjusted_rate_by_year.csv", index=False)
    print(f"    Saved risk_adjusted_rate_by_year.csv ({len(ra_df)} rows)")
    for _, r in ra_df.iterrows():
        print(f"      {int(r['year'])}  N={int(r['N']):>6,}  crude={r['crude_rate']:.4f}  "
              f"adj={r['risk_adjusted_rate']:.4f} ({r['ra_CI_low']:.4f}–{r['ra_CI_high']:.4f})")

    if len(ra_df) >= 3:
        slope, intercept, r_val, p_val, _ = linregress(
            ra_df["year"].values.astype(float),
            ra_df["risk_adjusted_rate"].values,
        )
        header = (
            f"Risk-Adjusted Trend Test — {label}\n{'='*60}\n"
            f"Method: Marginal standardisation via logistic regression\n"
            f"  Model: is_benign ~ C(svc_year) + {' + '.join(covars)}\n"
            f"  Reference year: {ref_year}\n"
            f"  N (complete cases): {n_comp:,}\n"
            f"  Years: {years[0]}–{years[-1]}\n"
        )
        if not converged:
            header += "  WARNING: optimiser did not converge; estimates may be unreliable\n"
        trend_txt = (
            header + "\n"
            + ra_df.to_string(index=False) + "\n\n"
            f"Linear trend on risk-adjusted rates:\n"
            f"  Slope     : {slope:+.6f}/yr ({slope*100:+.4f} pct pts/yr)\n"
            f"  Intercept : {intercept:.6f}\n"
            f"  R²        : {r_val**2:.4f}\n"
            f"  p-value   : {p_val:.4f}\n"
            f"  Result    : {'SIGNIFICANT' if p_val < 0.05 else 'Not significant'} at p < 0.05\n"
        )
        # Explicit UTF-8: the report contains non-ASCII characters (—, –, ²).
        with open(f"{outdir}/risk_adjusted_trend_test.txt", "w", encoding="utf-8") as fh:
            fh.write(trend_txt)
        print(f"    Risk-adj trend: slope={slope*100:+.4f} pct-pts/yr  "
              f"R²={r_val**2:.4f}  p={p_val:.4f}")
