#!/usr/bin/python3
import json
import logging
import os
import ssl
import subprocess
import sys
import time
import urllib
from urllib.parse import urljoin, quote
from urllib.request import Request, URLError, urlopen

logging.basicConfig()
logger = logging.getLogger("cockpit-auth-ovirt")

# Engine host used for REST API calls; DEFAULT_API_HOST is the fallback
# when the OVIRT_FQDN environment variable is not set.
DEFAULT_API_HOST = "localhost"
OVIRT_FQDN = os.getenv("OVIRT_FQDN", DEFAULT_API_HOST)

# Prefix put in front of every absolute path below:
# "/host" when running in a container, empty otherwise.
ROOT_DIR_PREFIX = os.getenv("ROOT_DIR_PREFIX", "")

# Pre-loaded content of
# https://[FQDN]/ovirt-engine/services/pki-resource?resource=ca-certificate&format=X509-PEM-CA
OVIRT_CA_FILE = ROOT_DIR_PREFIX + "/usr/share/ovirt-cockpit-sso/ca.pem"
COCKPIT_SSH_COMMAND = ROOT_DIR_PREFIX + "/usr/libexec/cockpit-ssh"

# Engine keystore, its password, and the helper script that extracts the
# private key from it.
ENGINE_P12 = ROOT_DIR_PREFIX + "/etc/pki/ovirt-engine/keys/engine.p12"
# TODO: read the password from
# /etc/ovirt-engine/engine.conf.d/10-setup-pki.conf
# (ENGINE_PKI_ENGINE_STORE_PASSWORD) instead of hard-coding it.
PKI_PASSWORD = "mypass"
KEY_GEN_SH = ROOT_DIR_PREFIX + "/usr/share/ovirt-cockpit-sso/keygen.sh"

# Remote account used when the engine reports no username for a host.
DEFAULT_USER = "root"


def usage():
    """Write the command-line synopsis to stderr and exit with EX_USAGE."""
    synopsis = "usage {} [user@]host[:port]\n".format(sys.argv[0])
    sys.stderr.write(synopsis)
    raise SystemExit(os.EX_USAGE)


def send_auth_command(challenge, response):
    """Write an "authorize" control message to file descriptor 1.

    The frame consists of the decimal payload length, a newline, an empty
    line and the JSON text; the announced length covers the empty line's
    newline plus the JSON text.

    challenge -- when truthy, sent as "challenge" together with a freshly
                 generated "cookie" built from the pid and current time
    response  -- when truthy, sent as "response"
    """
    message = {"command": "authorize"}

    if challenge:
        message.update(
            cookie="session{}{}".format(os.getpid(), time.time()),
            challenge=challenge
        )
    if response:
        message["response"] = response

    payload = json.dumps(message)
    frame = "{}\n\n{}".format(len(payload) + 1, payload)
    os.write(1, frame.encode())


def send_problem_init(problem, message, auth_methods):
    """Write an "init" control message reporting a problem to fd 1.

    The frame layout is: decimal payload length, newline, empty line,
    JSON text; the announced length is the JSON length plus one.

    problem      -- problem code, always included
    message      -- optional human-readable text, included when truthy
    auth_methods -- optional mapping sent as "auth-method-results",
                    included when truthy
    """
    init = {"command": "init", "problem": problem}

    optional_fields = (
        ("message", message),
        ("auth-method-results", auth_methods),
    )
    for key, value in optional_fields:
        if value:
            init[key] = value

    encoded = json.dumps(init)
    os.write(1, "{}\n\n{}".format(len(encoded) + 1, encoded).encode())


def read_size(fd):
    """Read the decimal length prefix of a frame from *fd*.

    Bytes are consumed one at a time up to and including the newline
    that terminates the prefix.

    Returns the parsed length, or 0 if end of file is reached before the
    newline.  Raises ValueError if the prefix contains a non-digit or
    more than seven digits.
    """
    digits = 0
    total = 0

    while True:
        char = os.read(fd, 1).decode()

        if char == "":
            # End of file before the terminating newline.
            return 0
        if char == "\n":
            return total

        total = total * 10 + int(char)
        digits += 1
        if digits > 7:
            raise ValueError("Invalid frame: size too long")


def read_frame(fd):
    """Read one length-prefixed frame from *fd* and return it as text.

    The size obtained from read_size() is handed to os.read() as a byte
    count, so the payload is collected as bytes and decoded once at the
    end.  Decoding chunk by chunk would subtract a character count from
    a byte count (wrong for multi-byte UTF-8) and would fail on a
    character split across two reads.

    Returns the decoded payload, or "" when the announced size is 0.
    Raises ValueError if end of file is reached before the whole payload
    has arrived, or if the payload is not valid UTF-8
    (UnicodeDecodeError is a ValueError).
    """
    remaining = read_size(fd)

    chunks = []
    while remaining > 0:
        chunk = os.read(fd, remaining)
        if not chunk:
            # os.read() returns b"" at end of file; without this check
            # the loop would never terminate on a truncated frame.
            raise ValueError("Invalid frame: truncated payload")
        remaining -= len(chunk)
        chunks.append(chunk)

    return b"".join(chunks).decode()


def read_auth_reply():
    """Read the reply to an "authorize" challenge and return its response.

    The frame is read from file descriptor 1, the same descriptor the
    outgoing commands are written to.

    Returns the non-empty "response" value of the reply.
    Raises ValueError if the payload is not valid JSON, is not a JSON
    object, is not an "authorize" command, or lacks a cookie or a
    response.
    """
    reply = json.loads(read_frame(1))

    # json.loads() may also yield a list, string or number; calling
    # .get() on those would raise AttributeError, which main() does not
    # handle, so anything that is not an object is rejected here.
    valid = (
        isinstance(reply, dict)
        and reply.get("command") == "authorize"
        and reply.get("cookie")
        and reply.get("response")
    )
    if not valid:
        raise ValueError("Did not receive a valid authorize command")

    return reply["response"]


def read_key():
    """Return the private key produced from the engine's .p12 keystore.

    Runs KEY_GEN_SH with the keystore path and its password as arguments
    and returns whatever the script prints on stdout, decoded to text.
    The script's exit status is not checked.

    Only the RSA header is supported; unsupported leading content
    (fingerprints etc.) is removed from the beginning of the key
    (presumably by the helper script).
    """
    keygen = subprocess.Popen(
        [KEY_GEN_SH, ENGINE_P12, PKI_PASSWORD],
        stdout=subprocess.PIPE
    )
    stdout_data, _ = keygen.communicate()
    return stdout_data.decode()


def send_key():
    """Send the engine private key as an "authorize" response.

    The key text from read_key() is prefixed with "private-key " and
    written with send_auth_command(); no challenge is included.
    """
    logger.debug("send_key() starts")
    key_response = "private-key {}".format(read_key())
    send_auth_command(None, key_response)
    logger.debug("send_auth_command() finished")


def get_host(host_id, auth_data):
    """Look up a host in the oVirt engine and build its SSH target.

    Requests /ovirt-engine/api/hosts/<host_id> from OVIRT_FQDN over
    HTTPS, verifying the server certificate against OVIRT_CA_FILE.

    host_id   -- identifier of the host in the engine API; it is
                 percent-encoded as a single path segment
    auth_data -- value sent verbatim as the Authorization header

    Returns a (code, host) tuple.  code is the HTTP status, or 0 when no
    HTTP status was obtained.  host is "user@address[:port]", or None
    when the lookup failed or the reply carried no address.

    URLError and ValueError are reported on stderr and reflected in the
    return value.  Other exceptions (for example an OSError for an
    unreadable CA file) propagate to the caller.
    """
    # safe="" makes quote() encode "/" as well, so the id cannot add
    # path segments (e.g. "../") and point the request elsewhere.
    url = urljoin(
        "https://{0}".format(OVIRT_FQDN),
        "/ovirt-engine/api/hosts/{}".format(quote(host_id, safe=""))
    )

    q = Request(url)
    q.add_header("Authorization", auth_data)
    q.add_header("Accept", "application/json")

    code = 0
    host = None
    try:
        # urlopen()'s cafile argument was deprecated in Python 3.6 and
        # removed in 3.13; an explicit SSL context is the replacement.
        context = ssl.create_default_context(cafile=OVIRT_CA_FILE)

        logger.debug("oVirt: requesting host: %s", url)
        # The with-block closes the HTTP response once it has been read.
        with urlopen(q, context=context) as r:
            logger.debug("Data received")
            code = r.getcode()
            logger.debug("code: %s", code)
            obj = json.loads(r.read())

        logger.debug("object parsed: %s", json.dumps(obj))
        host = obj.get("address")
        # "ssh" may be missing or null; both mean "no SSH details".
        ssh = obj.get("ssh") or {}
        port = ssh.get("port")
        user = obj.get("username")
        if host:
            if port and port != "0":
                host = "{}:{}".format(host, port)
            if not user:
                user = DEFAULT_USER
            host = "{}@{}".format(user, host)
        else:
            # The field checked above is "address", so name it here.
            sys.stderr.write("Got invalid data: no address field")
    except URLError as err:
        # HTTP errors carry a status code; other URL errors do not.
        if hasattr(err, "code"):
            code = err.code

        # 401 and 404 are expected outcomes that the caller reports.
        if code != 404 and code != 401:
            sys.stderr.write("Got unexpected error {}".format(err))
    except ValueError:
        sys.stderr.write("Got invalid data: invalid json")

    return code, host


def main(args):
    """Authenticate against oVirt and hand the session over to cockpit-ssh.

    Flow:
      1. send an "authorize" challenge ("*") and read the reply;
      2. use the reply as the Authorization header to look up the host
         given in args[1] through the engine API;
      3. on failure, report exactly one problem and exit with status 1;
      4. on success, send the engine private key and replace this
         process with COCKPIT_SSH_COMMAND for the resolved host.

    args -- full argument vector: program name plus one host argument;
            any other length prints the usage text and exits

    Raises ValueError (after reporting "internal-error") when the
    authorize reply is invalid.
    """
    if len(args) != 2:
        usage()

    logger.debug("cockpit-ovirt-auth starts")
    send_auth_command("*", None)
    logger.debug("auth command sent")
    try:
        resp = read_auth_reply()
    except ValueError as e:
        send_problem_init("internal-error", str(e), {})
        raise

    logger.debug("oVirt response received")
    code, host = get_host(args[1], resp)
    if not host:
        # The branches must be chained.  With two independent ifs a 401
        # also fell into the final else and emitted a second,
        # contradictory "internal-error" message.
        if code == 401:
            send_problem_init("authentication-failed", "Token was not valid",
                              {"password": "not-tried", "token": "denied"})
        elif code == 404:
            send_problem_init("unknown-host", None, None)
        else:
            send_problem_init(
                "internal-error",
                "Error validating auth token",
                None
            )
        sys.exit(1)

    send_key()

    # Passed to cockpit-ssh through the environment; presumably lets it
    # accept hosts whose key is not yet known -- confirm in cockpit-ssh.
    os.environ["COCKPIT_SSH_ALLOW_UNKNOWN"] = "1"
    os.execlpe(COCKPIT_SSH_COMMAND, COCKPIT_SSH_COMMAND, host, os.environ)


# Run the authentication flow only when executed as a script; the full
# argument vector (including the program name) is passed to main().
if __name__ == "__main__":
    main(sys.argv)
