#!/usr/local/bin/python3
import argparse
import os
import signal
import sys
import time
import threading
import queue
from dataclasses import dataclass
from typing import Optional, Dict, Tuple

from ldap3 import Server, Connection, SUBTREE, core
from ldap3.utils.conv import escape_filter_chars


def eprint(*a):
    """Print the given values to stderr and flush immediately.

    Values are formatted exactly as ``print`` would format them
    (space-separated, newline-terminated).
    """
    print(*a, flush=True, file=sys.stderr)


def smtp_reply(sessionid: str, token: str, decision: str, msg: str = ""):
    """Write one ``filter-result`` line to stdout and flush it.

    The line is ``filter-result|<sessionid>|<token>|<decision>``, with
    ``|<msg>`` appended only when *msg* is non-empty.

    If stdout turns out to be a broken pipe, the whole process is terminated
    immediately with exit status 0 via ``os._exit``.
    """
    suffix = f"|{msg}" if msg else ""
    line = f"filter-result|{sessionid}|{token}|{decision}{suffix}\n"
    try:
        sys.stdout.write(line)
        sys.stdout.flush()
    except BrokenPipeError:
        # The reader on the other end of stdout is gone; nothing left to do.
        os._exit(0)


def norm_addr(raw: str, strip_subaddr: bool) -> str:
    """Normalize a raw recipient address into a lookup/cache key.

    Steps, in order:
      1. strip surrounding whitespace;
      2. remove one pair of enclosing angle brackets (only when there is
         something between them) and strip again;
      3. lower-case the whole address;
      4. when *strip_subaddr* is true and the address contains ``@``, cut the
         part before the first ``@`` at its first ``+``.

    Returns the normalized address string.
    """
    addr = raw.strip()
    bracketed = len(addr) > 2 and addr[0] == "<" and addr[-1] == ">"
    if bracketed:
        addr = addr[1:-1].strip()
    addr = addr.lower()

    if not strip_subaddr:
        return addr

    local, at_sign, domain = addr.partition("@")
    if not at_sign:
        # No domain part: leave the string alone, even if it contains "+".
        return addr

    base_local = local.partition("+")[0]
    return f"{base_local}@{domain}"


class TTLCache:
    """Small bounded in-memory cache of lookup results with per-result TTLs.

    Entries map an address to ``(result, expires_at)``. The result ``"found"``
    is kept for ``ttl_pos`` seconds; every other result is kept for
    ``ttl_neg`` seconds. A TTL of 0 disables caching for that kind of result.
    All access to the table is serialized by an internal lock.

    Expiry is measured with ``time.monotonic()`` so that wall-clock
    adjustments can neither expire entries early nor keep them alive longer
    than their TTL.
    """

    def __init__(self, ttl_pos: int = 300, ttl_neg: int = 30, max_items: int = 2000):
        # Negative TTLs are clamped to 0 (= do not cache); the capacity has a
        # floor of 100 entries.
        self.ttl_pos = max(0, int(ttl_pos))
        self.ttl_neg = max(0, int(ttl_neg))
        self.max_items = max(100, int(max_items))
        self._lock = threading.Lock()
        self._data: Dict[str, Tuple[str, float]] = {}  # addr -> (result, expires_at)

    def get(self, addr: str) -> Optional[str]:
        """Return the cached result for *addr*, or None if absent or expired.

        An expired entry is removed as a side effect.
        """
        now = time.monotonic()
        with self._lock:
            entry = self._data.get(addr)
            if entry is None:
                return None
            res, exp = entry
            if exp < now:
                self._data.pop(addr, None)
                return None
            return res

    def set(self, addr: str, res: str):
        """Store *res* for *addr* using the TTL that matches the result.

        Nothing is stored when the applicable TTL is 0. Overwriting an
        existing key never triggers eviction, because it does not grow the
        table; it also moves the key to the newest position so a freshly
        refreshed entry is not the first eviction candidate.
        """
        ttl = self.ttl_pos if res == "found" else self.ttl_neg
        if ttl <= 0:
            return
        now = time.monotonic()
        with self._lock:
            # Drop any previous entry first: the size check below then only
            # fires when this insert would really exceed the capacity.
            self._data.pop(addr, None)
            if len(self._data) >= self.max_items:
                self._prune_locked(now, aggressive=True)
            self._data[addr] = (res, now + ttl)

    def _prune_locked(self, now: float, aggressive: bool = False):
        """Remove expired entries; the caller must already hold the lock.

        With *aggressive* set, if the table is still at capacity afterwards,
        additionally drop the oldest-inserted tenth of ``max_items`` (at least
        one entry) to make room.
        """
        expired = [k for k, (_r, exp) in self._data.items() if exp < now]
        for k in expired:
            self._data.pop(k, None)
        if aggressive and len(self._data) >= self.max_items:
            drop = max(1, self.max_items // 10)
            # dicts iterate in insertion order, so the head holds the oldest keys.
            for k in list(self._data)[:drop]:
                self._data.pop(k, None)


class LDAPBackoff:
    """Shared failure back-off gate with a multiplicatively growing delay.

    After a reported failure, ``should_try()`` returns False until the
    current delay has elapsed (measured with ``time.monotonic()``). Each
    failure multiplies the next delay by ``factor``, capped at ``maximum``;
    a reported success clears the gate and resets the delay to ``initial``.
    State changes are guarded by an internal lock.
    """

    def __init__(self, initial: int = 10, maximum: int = 120, factor: float = 2.0):
        # Clamp the configuration: initial >= 1, maximum >= initial, factor >= 1.0.
        first_delay = max(1, int(initial))
        self.initial = first_delay
        self.maximum = max(first_delay, int(maximum))
        self.factor = max(1.0, float(factor))
        self._lock = threading.Lock()
        self._dead_until = 0.0
        self._current = float(first_delay)

    def should_try(self) -> bool:
        """Return True when no back-off period is currently in effect."""
        stamp = time.monotonic()
        with self._lock:
            blocked = stamp < self._dead_until
        return not blocked

    def report_failure(self):
        """Record a failure: extend the blocked period and grow the next delay."""
        stamp = time.monotonic()
        with self._lock:
            # Never shorten a deadline that is already further in the future.
            proposed = stamp + self._current
            if proposed > self._dead_until:
                self._dead_until = proposed
            grown = self._current * self.factor
            ceiling = float(self.maximum)
            self._current = grown if grown < ceiling else ceiling

    def report_success(self):
        """Record a success: lift the block and reset the delay to its initial value."""
        with self._lock:
            self._current = float(self.initial)
            self._dead_until = 0.0


class LDAPChecker:
    """Recipient existence check against an LDAP directory over one connection.

    Each instance owns at most one ``ldap3`` connection, opened on first use
    and reused between lookups unless ``close_each`` is set. Nothing in this
    class locks the connection, so an instance is meant to be used by a single
    thread (each worker creates its own).
    """

    def __init__(self, host: str, port: int, binddn: str, password: str, base: str, timeout: int, close_each: bool = False):
        # The same timeout (seconds) is used for connecting, for receiving,
        # and as the search time limit.
        timeout_s = int(timeout)
        self.server = Server(host, port=port, connect_timeout=timeout_s, get_info=None)
        self.binddn = binddn
        self.password = password
        self.base = base
        self.timeout = timeout_s
        self.close_each = bool(close_each)
        self.conn: Optional[Connection] = None

    def _hard_close(self):
        """
        Drop the current connection, if any, ignoring every error on the way.

        Two steps are attempted, each independently:
          - conn.unbind() (graceful)
          - close the underlying socket if present (prevents lingering CLOSE_WAIT)
        """
        # Detach first so the instance never keeps a half-closed connection.
        stale, self.conn = self.conn, None
        if not stale:
            return

        try:
            stale.unbind()
        except Exception:
            pass

        try:
            sock = getattr(stale, "socket", None)
            if sock:
                sock.close()
        except Exception:
            pass

    def _connect(self) -> bool:
        """Open and bind a fresh connection; return True on success.

        On failure the error is logged to stderr, ``self.conn`` is left as
        None and False is returned.
        """
        try:
            fresh = Connection(
                self.server,
                user=self.binddn,
                password=self.password,
                auto_bind=True,
                raise_exceptions=True,
                receive_timeout=self.timeout,
            )
        except Exception as ex:
            self.conn = None
            eprint(f"rcpt_ad: connect/bind failed: {ex!r}")
            return False
        self.conn = fresh
        return True

    def _ensure_bound(self) -> bool:
        """Make sure a bound connection is available; return True if it is.

        A missing connection is created; one that reports itself as not bound
        is torn down and replaced.
        """
        if self.conn is None:
            return self._connect()
        try:
            if self.conn.bound:
                return True
            self._hard_close()
            return self._connect()
        except Exception as ex:
            eprint(f"rcpt_ad: ensure_bound failed: {ex!r}")
            self._hard_close()
            return False

    def exists_rcpt(self, addr: str) -> str:
        """
        Look up *addr* in the directory.

        Returns one of "found" / "notfound" / "tempfail". "tempfail" covers
        every case where the directory could not give a definite answer
        (no connection, non-zero search result code, any exception).
        """
        if not self._ensure_bound():
            return "tempfail"

        # Escape the address before embedding it in the search filter.
        a = escape_filter_chars(addr)
        filt = (
            "(&(|(objectClass=user)(objectClass=group))"
            f"(|(mail={a})(proxyAddresses=SMTP:{a})(proxyAddresses=smtp:{a})))"
        )

        try:
            matched = self.conn.search(
                search_base=self.base,
                search_filter=filt,
                search_scope=SUBTREE,
                attributes=["dn"],
                size_limit=1,
                time_limit=self.timeout,
            )

            if matched:
                result = "found" if self.conn.entries else "notfound"
            else:
                # A falsy search outcome is only a clean "no match" when the
                # reported result code is 0; anything else is a lookup error.
                res = getattr(self.conn, "result", {}) or {}
                if res.get("result", 1) != 0:
                    eprint(f"rcpt_ad: search error: {res}")
                    return "tempfail"
                result = "notfound"

            if self.close_each:
                self._hard_close()

            return result

        except core.exceptions.LDAPExceptionError as ex:
            eprint(f"rcpt_ad: LDAPException: {ex!r}")
            self._hard_close()
            return "tempfail"
        except Exception as ex:
            eprint(f"rcpt_ad: unexpected error: {ex!r}")
            self._hard_close()
            return "tempfail"


@dataclass(frozen=True)
class WorkItem:
    """One recipient lookup queued by the stdin reader for an LDAP worker."""

    # Session identifier taken from the incoming filter line; echoed in the reply.
    sessionid: str
    # Token taken from the incoming filter line; echoed in the reply.
    token: str
    # Recipient address, already normalized by norm_addr().
    addr: str


@dataclass(frozen=True)
class OutItem:
    """One reply queued for the writer thread, which passes it to smtp_reply()."""

    # Session identifier and token identifying the request being answered.
    sessionid: str
    token: str
    # Decision keyword written into the reply line ("proceed" or "reject" in this file).
    decision: str
    # Optional text appended to the reply (an SMTP status line for rejects); "" for none.
    msg: str


def writer_main(out_q: "queue.Queue[OutItem]"):
    """Writer thread body: emit every queued reply through ``smtp_reply``.

    Blocks on *out_q* and writes replies in queue order. A ``None`` item
    acts as a stop sentinel and makes the function return (nothing in this
    file currently enqueues it).
    """
    pending = out_q.get()
    while pending is not None:
        smtp_reply(pending.sessionid, pending.token, pending.decision, pending.msg)
        pending = out_q.get()


def worker_main(
    work_q: "queue.Queue[WorkItem]",
    out_q: "queue.Queue[OutItem]",
    cache: TTLCache,
    breaker: LDAPBackoff,
    ldap_host: str,
    ldap_port: int,
    binddn: str,
    password: str,
    base: str,
    timeout: int,
    close_each: bool,
):
    """Worker thread body: resolve queued recipients against LDAP.

    Creates a private ``LDAPChecker`` and then, for every item taken from
    *work_q*, puts exactly one ``OutItem`` on *out_q*:

      - ``reject`` with a 451 text when the shared *breaker* is currently
        blocking lookups, or when the lookup itself reports "tempfail"
        (which is also reported to the breaker as a failure);
      - ``proceed`` when the recipient was found;
      - ``reject`` with a 550 text when it was not found.

    Definite answers are reported to the breaker as a success and stored in
    *cache*. A ``None`` item on *work_q* makes the worker return.
    """
    checker = LDAPChecker(ldap_host, ldap_port, binddn, password, base, timeout, close_each=close_each)
    tempfail_text = "451 4.3.0 temporary lookup failure"

    def answer(job, decision, msg=""):
        # Queue a reply for the writer thread, tagged with the job's session/token.
        out_q.put(OutItem(job.sessionid, job.token, decision, msg))

    job = work_q.get()
    while job is not None:
        if not breaker.should_try():
            # circuit breaker: recent failures => instant tempfail
            answer(job, "reject", tempfail_text)
        else:
            outcome = checker.exists_rcpt(job.addr)
            if outcome == "tempfail":
                breaker.report_failure()
                answer(job, "reject", tempfail_text)
            else:
                breaker.report_success()
                cache.set(job.addr, outcome)
                if outcome == "found":
                    answer(job, "proceed")
                else:
                    answer(job, "reject", "550 5.1.1 invalid recipient")
        job = work_q.get()


def main():
    """Entry point: parse options, start the threads and serve filter requests.

    The line protocol spoken on stdin/stdout appears to be the OpenSMTPD
    smtpd-filters protocol (``config|ready`` handshake, ``register|...``
    lines, ``filter|...`` requests, ``filter-result|...`` replies).

    Flow:
      1. Parse command-line options and build the shared cache and back-off.
      2. Start one writer thread and ``--workers`` LDAP worker threads
         (all daemon threads).
      3. Wait for ``config|ready`` on stdin, then register for the
         ``smtp-in`` / ``rcpt-to`` filter phase.
      4. For each ``filter|`` line: answer from the cache when possible,
         otherwise queue the lookup for a worker; if the work queue is full,
         answer with a temporary failure instead of blocking.

    Returns when stdin reaches end-of-file. EOF on the input pipe is
    permanent, so waiting and retrying would only leave the process running
    forever after the peer has gone away.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--ldap-host", default="127.0.0.1")
    ap.add_argument("--ldap-port", type=int, default=1389)
    ap.add_argument("--binddn", required=True)
    ap.add_argument("--password", required=True)
    ap.add_argument("--base", required=True)
    ap.add_argument("--timeout", type=int, default=3)

    ap.add_argument("--strip-subaddr", action="store_true")

    # NEW: maximum stability: do not reuse LDAP connection
    ap.add_argument("--ldap-close-each", action="store_true",
                    help="Unbind/close LDAP connection after each lookup (max stability, more overhead).")

    # tuned defaults for 1 vCPU / 1 GB
    ap.add_argument("--workers", type=int, default=2)
    ap.add_argument("--queue-size", type=int, default=200)

    ap.add_argument("--cache-ttl", type=int, default=300)
    ap.add_argument("--neg-cache-ttl", type=int, default=30)
    ap.add_argument("--cache-max", type=int, default=2000)

    ap.add_argument("--ldap-backoff-initial", type=int, default=10)
    ap.add_argument("--ldap-backoff-max", type=int, default=120)
    ap.add_argument("--ldap-backoff-factor", type=float, default=2.0)

    args = ap.parse_args()

    # Best effort: ignore SIGPIPE so a vanished reader surfaces as
    # BrokenPipeError on write (handled in smtp_reply) instead of a signal.
    try:
        signal.signal(signal.SIGPIPE, signal.SIG_IGN)
    except Exception:
        pass

    cache = TTLCache(args.cache_ttl, args.neg_cache_ttl, args.cache_max)
    breaker = LDAPBackoff(args.ldap_backoff_initial, args.ldap_backoff_max, args.ldap_backoff_factor)

    # The work queue is bounded (at least 10 slots) so a slow directory cannot
    # pile up unbounded work; the reply queue is unbounded.
    work_q: "queue.Queue[WorkItem]" = queue.Queue(maxsize=max(10, int(args.queue_size)))
    out_q: "queue.Queue[OutItem]" = queue.Queue()

    threading.Thread(target=writer_main, args=(out_q,), name="writer", daemon=True).start()

    workers = max(1, int(args.workers))
    for _ in range(workers):
        threading.Thread(
            target=worker_main,
            args=(
                work_q, out_q, cache, breaker,
                args.ldap_host, args.ldap_port, args.binddn, args.password, args.base, args.timeout,
                args.ldap_close_each,
            ),
            daemon=True,
        ).start()

    # handshake
    while True:
        line = sys.stdin.readline()
        if not line:
            # EOF before the handshake completed: the peer is gone.
            eprint("rcpt_ad: stdin closed before config|ready, exiting")
            return
        if line.rstrip("\n") == "config|ready":
            sys.stdout.write("register|filter|smtp-in|rcpt-to\n")
            sys.stdout.write("register|ready\n")
            sys.stdout.flush()
            break

    # main loop: never block on LDAP
    while True:
        line = sys.stdin.readline()
        if not line:
            # EOF: no further requests can ever arrive, so stop serving.
            return
        line = line.rstrip("\n")

        if not line.startswith("filter|"):
            continue

        # Expect exactly 8 "|"-separated fields; the last one (the address)
        # keeps any further "|" characters it may contain.
        parts = line.split("|", 7)
        if len(parts) != 8:
            continue

        _tag, _ver, _ts, subsystem, phase, sessionid, token, address = parts

        if subsystem != "smtp-in" or phase != "rcpt-to":
            # Not a request this filter handles: let it through unchanged.
            out_q.put(OutItem(sessionid, token, "proceed", ""))
            continue

        addr = norm_addr(address, args.strip_subaddr)

        cached = cache.get(addr)
        if cached == "found":
            out_q.put(OutItem(sessionid, token, "proceed", ""))
            continue
        if cached == "notfound":
            out_q.put(OutItem(sessionid, token, "reject", "550 5.1.1 invalid recipient"))
            continue

        try:
            work_q.put_nowait(WorkItem(sessionid, token, addr))
        except queue.Full:
            # Workers are saturated: fail temporarily rather than stall the reader.
            out_q.put(OutItem(sessionid, token, "reject", "451 4.3.0 temporary lookup failure"))


if __name__ == "__main__":
    # Script entry point: any exception escaping main() is logged to stderr
    # and turned into exit status 1.
    try:
        main()
    except Exception as ex:
        eprint(f"rcpt_ad: fatal top-level exception: {ex!r}")
        raise SystemExit(1)
