#!/usr/bin/env python3
"""
Get a client certificate for Polytechnique Montréal's eduroam wifi

Assuming you can log in to Polytechnique Montréal's CAS server
(Central Authentication Service), this custom script will create a
new private key and have it signed by SecureW2's public key
infrastructure (PKI), so that you can connect to the eduroam wifi
network using EAP-TLS.

The generated key and certificate will be saved in PKCS#12 format
in a file named "{device}-{fingerprint}.p12".

It will also save the appropriate CA certificate in PEM format to
the file "entrust-root-ca.crt", if it doesn't already exist.

Both files are written to the current working directory.

You should configure your wifi connection to use the identity:

  anonymous854623@polymtl.ca

This script requires openssl >=3.0.0 and a working internet connection.
"""

from base64 import b64encode, b64decode
from hashlib import sha1
from http.server import HTTPServer, BaseHTTPRequestHandler
import json
import os
import subprocess
from typing import Optional
from urllib.parse import urlencode, parse_qs
from urllib.request import Request, urlopen
from uuid import uuid4


# CA certificate saved next to the client bundle. Per its subject this is
# "Entrust Root Certification Authority - G2", valid until 2030-12-07.
_ENTRUST_ROOT_CA_PEM = """\
-----BEGIN CERTIFICATE-----
MIIEPjCCAyagAwIBAgIESlOMKDANBgkqhkiG9w0BAQsFADCBvjELMAkGA1UEBhMC
VVMxFjAUBgNVBAoTDUVudHJ1c3QsIEluYy4xKDAmBgNVBAsTH1NlZSB3d3cuZW50
cnVzdC5uZXQvbGVnYWwtdGVybXMxOTA3BgNVBAsTMChjKSAyMDA5IEVudHJ1c3Qs
IEluYy4gLSBmb3IgYXV0aG9yaXplZCB1c2Ugb25seTEyMDAGA1UEAxMpRW50cnVz
dCBSb290IENlcnRpZmljYXRpb24gQXV0aG9yaXR5IC0gRzIwHhcNMDkwNzA3MTcy
NTU0WhcNMzAxMjA3MTc1NTU0WjCBvjELMAkGA1UEBhMCVVMxFjAUBgNVBAoTDUVu
dHJ1c3QsIEluYy4xKDAmBgNVBAsTH1NlZSB3d3cuZW50cnVzdC5uZXQvbGVnYWwt
dGVybXMxOTA3BgNVBAsTMChjKSAyMDA5IEVudHJ1c3QsIEluYy4gLSBmb3IgYXV0
aG9yaXplZCB1c2Ugb25seTEyMDAGA1UEAxMpRW50cnVzdCBSb290IENlcnRpZmlj
YXRpb24gQXV0aG9yaXR5IC0gRzIwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEK
AoIBAQC6hLZy254Ma+KZ6TABp3bqMriVQRrJ2mFOWHLP/vaCeb9zYQYKpSfYs1/T
RU4cctZOMvJyig/3gxnQaoCAAEUesMfnmr8SVycco2gvCoe9amsOXmXzHHfV1IWN
cCG0szLni6LVhjkCsbjSR87kyUnEO6fe+1R9V77w6G7CebI6C1XiUJgWMhNcL3hW
wcKUs/Ja5CeanyTXxuzQmyWC48zCxEXFjJd6BmsqEZ+pCm5IO2/b1BEZQvePB7/1
U1+cPvQXLOZprE4yTGJ36rfo5bs0vBmLrpxR57d+tVOxMyLlbc9wPBr64ptntoP0
jaWvYkxN4FisZDQSA/i2jZRjJKRxAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAP
BgNVHRMBAf8EBTADAQH/MB0GA1UdDgQWBBRqciZ60B7vfec7aVHUbI2fkBJmqzAN
BgkqhkiG9w0BAQsFAAOCAQEAeZ8dlsa2eT8ijYfThwMEYGprmi5ZiXMRrEPR9RP/
jTkrwPK9T3CMqS/qF8QLVJ7UG5aYMzyorWKiAHarWWluBh1+xLlEjZivEtRh2woZ
Rkfz6/djwUAFQKXSt/S1mja/qYh2iARVBCuch38aNzx+LaUa2NSJXsq9rD1s2G2v
1fN2D807iDginWyTmsQ9v4IbZT+mD12q/OWyFcq1rca8PdCE6OoGcrBNOTJ4vz4R
nAuknZoh8/CbCzB428Hch0P+vGOaysXCHMnHjf87ElgI5rY97HosTvuDls4MPGmH
VHOkc8KT/1EQrBVUAdj8BbGJoX90g5pJ19xOe4pIb4tF9g==
-----END CERTIFICATE-----
"""


def _save_client_bundle(
    pkey: bytes, cert: bytes, device: str, fingerprint: str
) -> None:
    """Export the key and certificate as a password-protected PKCS#12 file.

    The file is named "{device}-{fingerprint}.p12" in the current working
    directory; its name and the random password protecting it are printed.

    Args:
        pkey: the PEM-encoded private key.
        cert: the DER-encoded client certificate.
        device: the device ID, used as the file name prefix.
        fingerprint: the certificate's hex SHA-1 fingerprint.
    """
    filename = f"{device}-{fingerprint}.p12"
    password = str(uuid4())
    print(f"\n  {filename = }\n  {password = }\n")
    # We specify PBE-SHA1-3DES and SHA1 for compatibility with Android.
    # No -out is given, so openssl writes the bundle to standard output.
    bundle = openssl(
        command="pkcs12",
        args=["-export"]
        + ["-name", f"Polyroam {fingerprint[:8]}..."]
        + ["-keypbe", "PBE-SHA1-3DES"]
        + ["-certpbe", "PBE-SHA1-3DES"]
        + ["-macalg", "SHA1"],
        passwords={"-passout": password},
        stdin=pkey + pem_from_der("CERTIFICATE", cert).encode("utf-8"),
    )
    # The bundle contains the private key, so create it readable by the
    # owner only (0o600, further reduced by the umask). A pre-existing file
    # is truncated but keeps its current permissions.
    fd = os.open(filename, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "wb") as file:
        file.write(bundle)


def _save_ca_certificate(filename: str = "entrust-root-ca.crt") -> None:
    """Write the CA certificate to *filename* unless that path already exists."""
    try:
        # "x" (exclusive create) fails if the path exists, even as a dangling
        # symlink, so there is no check-then-write race and no symlink follow
        with open(filename, "x", encoding="ascii") as file:
            file.write(_ENTRUST_ROOT_CA_PEM)
    except FileExistsError:
        print("\n  Already exists, skipping.\n")
    else:
        print(f"\n  {filename = }\n")


def main(device: Optional[str]):
    """Enroll a new eduroam client certificate and save it to disk.

    Generates an RSA-2048 key and a DER CSR with openssl, gets an auth code
    through the browser-based login, has SecureW2 sign the CSR, then writes
    the PKCS#12 bundle and the CA certificate to the current directory.

    Args:
        device: the device ID sent to SecureW2 as "clientId". When None, a
            random 40-digit hexadecimal ID is generated and printed.
    """
    if device is None:
        print("Generating a new device id...")
        device = sha1(os.urandom(20)).hexdigest()
        print(f"\n  {device = }\n")

    print("Generating a new private key...")
    pkey = openssl(
        command="genpkey",
        args=["-algorithm", "rsa"]
        + ["-pkeyopt", "rsa_keygen_bits:2048"]
        + ["-outform", "pem"],
    )

    print("Generating a certificate signing request...")
    csr = openssl(
        command="req",
        args=["-batch", "-new", "-sha256"]
        + ["-key", "/dev/stdin"]
        + ["-subj", "/CN=anonymous"]
        + ["-addext", "keyUsage=digitalSignature,nonRepudiation"]
        + ["-addext", "extendedKeyUsage=clientAuth"]
        + ["-outform", "der"],
        stdin=pkey,
    )

    print("Getting an auth code...")
    auth = get_password_auth()

    print("Getting a transaction id...")
    transaction_id = pki_challenge_request()

    print("Getting a signed certificate...")
    cert, revoked = pki_enroll(auth, csr, device, transaction_id)
    fingerprint = sha1(cert).hexdigest()
    if revoked:
        print("\n  Revoked certificate fingerprints:")
        for r in revoked:
            print(f"    {r}")
        print()

    print("Saving the client certificate bundle...")
    _save_client_bundle(pkey, cert, device, fingerprint)

    print("Saving the CA certificate...")
    _save_ca_certificate("entrust-root-ca.crt")

    print(
        """\
Done!

  You should configure your wifi connection to use the identity:

  anonymous854623@polymtl.ca
"""
    )


def openssl(
    command: str,
    args: list[str],
    passwords: Optional[dict[str, str]] = None,
    stdin: bytes = b"",
) -> bytes:
    """Run an openssl command, sending all passwords through pipes.

    Passwords never appear on the command line. Each one is written to an
    anonymous pipe whose read end is inherited by openssl and passed as
    "fd:N".

    Args:
        command: the openssl subcommand, e.g. "req" or "pkcs12".
        args: the remaining command-line arguments.
        passwords: maps a password option (e.g. "-passout") to its value;
            None means no passwords.
        stdin: bytes fed to openssl's standard input.

    Returns:
        Everything openssl wrote to its standard output.

    Raises:
        subprocess.CalledProcessError: if openssl exits with a non-zero
            status; the arguments, stdout and stderr are printed first.
        ExceptionGroup: if closing any password pipe fails (Python 3.11+).
    """
    if passwords is None:
        passwords = {}
    password_args: list[str] = []
    pass_fds: list[int] = []  # Read ends of the password pipes
    try:
        for option, password in passwords.items():
            r, w = os.pipe()
            pass_fds.append(r)
            # The password is written before openssl starts, so it must fit
            # in the pipe buffer, otherwise this write would block
            with os.fdopen(w, "wb") as pipe:
                pipe.write(password.encode("utf-8"))
            # Tell openssl which file descriptor contains this password
            password_args.extend([option, f"fd:{r}"])
        # The command argument has to be first, but the other args could
        # end with "--" and filename-only arguments, so we insert the
        # password arguments in between
        process = subprocess.run(
            args=["openssl", command] + password_args + args,
            input=stdin,
            capture_output=True,
            pass_fds=pass_fds,
        )
        process.check_returncode()
    except subprocess.CalledProcessError:
        print("=== args ===")
        print(process.args)
        print("=== stdout ===")
        print(process.stdout.decode("utf-8", errors="backslashreplace"))
        print("=== stderr ===")
        print(process.stderr.decode("utf-8", errors="backslashreplace"))
        raise
    finally:
        # Close our copies of the read ends (the child got its own via
        # pass_fds); try every descriptor even if one of them fails
        exceptions = []
        for fd in pass_fds:
            try:
                os.close(fd)
            except OSError as e:
                exceptions.append(e)
        if exceptions:
            raise ExceptionGroup("error(s) while closing pipes", exceptions)
    return process.stdout


def get_password_auth() -> dict:
    """Obtain a SecureW2 auth code through Polytechnique's single sign-on.

    Starts an HTTP listener on an ephemeral localhost port and prints a
    SecureW2 login URL whose redirect_uri points back at that listener. It
    then serves requests one at a time until a valid auth code arrives.

    Returns:
        {"type": 0, "value": <auth code>}, the entry pki_enroll() encodes
        into its "challenge".
    """
    with HTTPServer(("localhost", 0), AuthCodeListener) as server:
        # Random token sent as "state"; the listener only accepts a redirect
        # that echoes it back, so codes for other processes are ignored
        server.token = str(uuid4())
        query = urlencode(
            {
                "response_type": "code",
                "state": server.token,
                "redirect_uri": f"http://localhost:{server.server_port}/",
            }
        )
        url = (
            "https://polytechniquemontreal-auth.securew2.com/auth/"
            "3d30841e-8311-4a11-ad7b-a3fd2b1b1b52/"
            "AFACEB48-1D41-4E18-A90D-1ED8CC17A0B1?" + query
        )
        print(
            f"""
  Please visit the following URL in your browser
  and login to the Polytechnique Montréal CAS server
  to get the auth code from SecureW2:

  {url}
"""
        )
        server.auth_code = None
        # A successful GET stores the code on the server, ending this loop
        while server.auth_code is None:
            server.handle_request()
        auth_code = server.auth_code
    return {"type": 0, "value": auth_code}


class AuthCodeListener(BaseHTTPRequestHandler):
    """Handle the browser redirect that carries SecureW2's auth code.

    The attached server must provide a ``token`` attribute (the expected
    "state" value). After a valid GET, the code is stored in the server's
    ``auth_code`` attribute.
    """

    def log_message(self, format, *args):
        """Discard log output instead of writing it to stderr."""

    def send_head(self) -> Optional[str]:
        """Validate the redirect and send the response status and headers.

        Shared by GET and HEAD. Returns the auth code for a valid request;
        otherwise sends an error response and returns None.
        """
        prefix = "/?"
        # Only "/" followed by a query string is accepted
        if not self.path.startswith(prefix):
            self.send_error(404)  # Not found
            return None

        params = parse_qs(self.path[len(prefix):])
        # A wrong "state" may be an auth code meant for another process
        if params.get("state", []) != [self.server.token]:
            self.send_error(404)  # Not found
            return None

        codes = params.get("code", [])
        if len(codes) != 1:
            self.send_error(400)  # Bad request
            return None

        self.send_response(200)  # Ok
        headers = (
            ("Connection", "close"),
            ("Content-Type", "text/html;charset=utf-8"),
            ("Content-Length", str(len(SUCCESS_HTML))),
        )
        for name, value in headers:
            self.send_header(name, value)
        self.end_headers()
        return codes[0]

    def do_HEAD(self):
        """Answer HEAD requests with the same status and headers as GET."""
        self.send_head()

    def do_GET(self):
        """Answer GET requests; on success record the code and send the page."""
        code = self.send_head()
        if code is None:
            return
        # Recording the code lets the serving loop stop
        self.server.auth_code = code
        self.wfile.write(SUCCESS_HTML)


# Body of the successful GET response shown in the user's browser
# (pure ASCII, so also valid UTF-8 as the Content-Type header states).
SUCCESS_HTML = b"""\
<!DOCTYPE html>
<html lang="en">
<meta charset="utf-8">
<title>Success</title>
<h1>Success</h1>
<p>The script has successfully gotten an auth code from SecureW2.
<p>You may close this browser tab.
"""


def pki_request_response(
    req_data: dict, *, timeout: Optional[float] = 60.0
) -> dict:
    """Send a request to the SecureW2 PKI server and get a response.

    Args:
        req_data: the JSON request; its "version" and "type" fields are
            checked against the reply ("version" is not checked for
            "getVersion" requests).
        timeout: seconds to wait on the network before giving up; None
            waits forever.

    Returns:
        The decoded JSON reply object.

    Raises:
        urllib.error.URLError / OSError: on network or HTTP failures,
            including timeouts.
        ValueError: if the reply is not UTF-8 JSON, is not a JSON object,
            reports a non-zero "error", or has an unexpected "version" or
            "type". The request and reply are printed first for debugging.
    """
    req_version = req_data.get("version", None)
    req_type = req_data.get("type", None)
    with urlopen(
        Request(
            method="POST",
            url=(
                "https://pki-services.securew2.com/enroll/"
                "3d30841e-8311-4a11-ad7b-a3fd2b1b1b52"
            ),
            headers={"Content-Type": "application/json"},
            data=json.dumps(req_data).encode("utf-8"),
        ),
        # Without a timeout a stalled server would hang the script forever
        timeout=timeout,
    ) as response:
        body = response.read()

    # What to print if the reply is bad: the raw bytes until they parse
    shown: object = body
    try:
        # JSONDecodeError and UnicodeDecodeError are both ValueErrors
        resp_data = json.loads(body.decode("utf-8"))
        shown = resp_data
        if not isinstance(resp_data, dict):
            raise ValueError(
                f"expected a JSON object, got {type(resp_data).__name__}"
            )

        error = resp_data.get("error", None)
        if error != 0:
            detailedError = resp_data.get("detailedError", None)
            raise ValueError(f"PKI error, {error=}, {detailedError=}")

        response_version = resp_data.get("version", None)
        if response_version != req_version and req_type != "getVersion":
            raise ValueError(f"unexpected {response_version=}")

        response_type = resp_data.get("type", None)
        if response_type != req_type:
            raise ValueError(f"unexpected {response_type=}")
    except Exception:
        print("=== request ===")
        print(req_data)
        print("=== response ===")
        print(shown)
        raise

    return resp_data


def pki_challenge_request() -> str:
    """Ask the PKI server to start a new transaction.

    Returns:
        The "transaction-id" string from the server's challengeRequest reply.

    Raises:
        ValueError: if the reply lacks a string "transaction-id" (the request
            and reply are printed first), in addition to anything
            pki_request_response() raises.
    """
    request = {
        "version": "1.4",
        "type": "challengeRequest",
        "requests": [],
    }
    reply = pki_request_response(request)
    transaction_id = reply.get("transaction-id", None)
    if isinstance(transaction_id, str):
        return transaction_id

    for heading, payload in (
        ("=== request ===", request),
        ("=== response ===", reply),
    ):
        print(heading)
        print(payload)
    raise ValueError(f"unexpected {transaction_id=}")


def pki_enroll(
    auth: dict,
    csr: bytes,
    device: str,
    transaction_id: str,
) -> tuple[bytes, list[str]]:
    """Submit a CSR to the PKI server and return the signed certificate.

    Args:
        auth: the auth entry from get_password_auth(); it is wrapped in a
            list, JSON-encoded and base64-encoded to form the "challenge".
        csr: the DER-encoded certificate signing request.
        device: the device ID, sent as deviceAttributes.clientId.
        transaction_id: the id obtained from pki_challenge_request().

    Returns:
        (cert, revoked): the DER-encoded signed certificate, and the hex
        fingerprints the server lists under "Revoked-Certificates"
        (possibly empty).

    Raises:
        ValueError: if the reply is malformed; the request and reply are
            printed first.
    """
    challenge = b64encode(json.dumps([auth]).encode("utf-8")).decode("utf-8")
    encoded_csr = b64encode(csr).decode("utf-8")
    req_data = {
        "version": "1.4",
        "type": "enroll",
        "challenge": challenge,
        "certificateRequests": [encoded_csr],
        "clientCertificate": [],
        "configInfo": {
            "profileId": "AFACEB48-1D41-4E18-A90D-1ED8CC17A0B1",
            "name": "eduroam_eaptls",
            "UID": "99021",
        },
        "deviceAttributes": {"clientId": device},
        "identity": "",
        "transaction-id": transaction_id,
    }
    resp_data = pki_request_response(req_data)
    try:
        # Exactly one base64 certificate string is expected back
        signed_certificates = resp_data.get("signedCertificates", None)
        single_string = (
            isinstance(signed_certificates, list)
            and len(signed_certificates) == 1
            and isinstance(signed_certificates[0], str)
        )
        if not single_string:
            raise ValueError(f"unexpected {signed_certificates=}")
        (encoded_cert,) = signed_certificates
        cert = b64decode(encoded_cert)

        # Enrolling a device that already had a certificate may revoke it
        attributes = resp_data.get("enrollAttributes", None)
        enroll_attributes = {} if attributes is None else attributes
        if not isinstance(enroll_attributes, dict):
            raise ValueError(f"unexpected {enroll_attributes=}")
        listed = enroll_attributes.get("Revoked-Certificates", None)
        raw_revoked = [] if listed is None else listed
        if not isinstance(raw_revoked, list) or not all(
            isinstance(item, str) for item in raw_revoked
        ):
            raise ValueError(f"unexpected revoked={raw_revoked!r}")
        revoked = [b64decode(item).hex() for item in raw_revoked]
    except Exception:
        print("=== request ===")
        print(req_data)
        print("=== response ===")
        print(resp_data)
        raise
    return cert, revoked


def pem_from_der(object_type: str, data: bytes) -> str:
    """Wrap binary (typically DER) data in PEM armor.

    The payload is base64-encoded and split into 64-character lines between
    "-----BEGIN <object_type>-----" and "-----END <object_type>-----"; the
    result ends with a newline.
    """
    encoded = b64encode(data).decode("utf-8")
    body = [encoded[start:start + 64] for start in range(0, len(encoded), 64)]
    armored = [f"-----BEGIN {object_type}-----", *body, f"-----END {object_type}-----"]
    return "\n".join(armored) + "\n"


if __name__ == "__main__":
    import argparse

    # The module docstring doubles as the --help epilog, shown verbatim
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    device_help = (
        "the device ID to use. Should be a sha1 hash, in other words, a "
        "40-digit hexadecimal string. If not given, a random one will be "
        "generated"
    )
    parser.add_argument("--device", help=device_help)
    main(parser.parse_args().device)