#!/bin/bash
# shellcheck disable=SC2029  # Client-side expansion in ssh commands is intentional
#
# Provision the users listed in USERS below:
#   1. create/verify each user (posix gid, ssh public keys) in kanidm on the auth
#      host, which is reached through the jumphost,
#   2. generate a "cluster key" (~/.ssh/id_rsa) for the user on the jumphost,
#   3. register the user in SLURM under default_acct with a per-user QoS.
#
# Usage: bash <this script> [--debug] jumphost.public.ip.address
# Run it from your laptop; it connects to the jumphost as "ubuntu" and uses sudo there.
set -eo pipefail

########################
# User definitions, one entry per user: username|display_name|gidnumber|qos|ssh_keys
#   username:  must be unique across all entries
#   gidnumber: must be unique across all entries
#   qos:       "unlimited" or "limited" (limited = max 8 GPUs, 1 job at a time)
#   ssh_keys:  semicolon-separated list of "key_name:pubkey" pairs; key names
#              must be unique across ALL entries, not just within one user
#   pubkey:    "<type> <base64>" only, e.g. "ssh-ed25519 AAAA..."; drop the
#              trailing comment (e.g. "user@host") that .pub files usually end with
USERS=(
    "jsmithg1|John Smith|79109|limited|josmith1_key_1:ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIM8YTFl1qSxBUW3zk0R7QoQK64/47wnG6v32eXZXGYVB"
    "jsmithh1|John Smith|79110|unlimited|josmith2_key_1:ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIM8YTFl1qSxBUW3zk0R7QoQK64/47wnG6v32eXZXGYVB"
    # Example of an entry with two keys (disabled):
    #"josmith0|John Smith|79010|limited|josmith0_key_1:ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAAAAAAAAAAAAAAAAAA;jsmith0_key_0:ssh-ed25519 EXAMPLE_MULTI_KEYAAAAIBxxREPLACExxWITHxxREALxxKEYxxxxxxxxxx"
)

#########################

DEBUG=false
JUMPHOST=""

# usage [STATUS] - print usage help and exit with STATUS (default 2).
usage() {
    echo "Usage: bash $0 [--debug] jumphost.public.ip.address"
    echo "Requirements: Run this from your laptop"
    exit "${1:-2}"
}

# is_ipv4 ADDR - succeed when ADDR is a dotted-quad IPv4 address whose four
# octets are all in the range 0-255.
is_ipv4() {
    local octet
    local -a octets
    [[ "$1" =~ ^[0-9]{1,3}(\.[0-9]{1,3}){3}$ ]] || return 1
    IFS=. read -ra octets <<<"$1"
    for octet in "${octets[@]}"; do
        # 10# forces base 10 so an octet such as "08" is not parsed as octal.
        ((10#${octet} <= 255)) || return 1
    done
    return 0
}

# Options and the jumphost may come in any order. Unknown options and a second
# positional argument are rejected instead of silently becoming the jumphost.
for arg in "$@"; do
    case "${arg}" in
    --debug) DEBUG=true ;;
    -h | --help) usage 0 ;;
    -*)
        echo "Error: unknown option '${arg}'"
        usage
        ;;
    *)
        if [ -n "${JUMPHOST}" ]; then
            echo "Error: more than one jumphost given ('${JUMPHOST}' and '${arg}')"
            usage
        fi
        JUMPHOST="${arg}"
        ;;
    esac
done

if [ -z "${JUMPHOST}" ]; then
    usage
fi
if ! is_ipv4 "${JUMPHOST}"; then
    echo "Error: '${JUMPHOST}' is not a valid IPv4 address"
    exit 2
fi

# debug MESSAGE... - print "[DEBUG] MESSAGE..." when --debug was given.
# Returns success when debugging is off, so bare calls are safe under `set -e`.
debug() {
    [[ "${DEBUG}" == true ]] || return 0
    printf '[DEBUG] %s\n' "$*"
}

AUTH_HOST="root@auth.cluster.verda.internal"

SSH_OPTS=(-o LogLevel=ERROR -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no)

# ssh helper functions
# Both helpers use SSH_OPTS, so host keys are neither verified nor recorded.

# ssh_jump CMD... - run CMD as "ubuntu" on the jumphost.
ssh_jump() {
    ssh "${SSH_OPTS[@]}" "ubuntu@${JUMPHOST}" "$@"
}

# ssh_auth CMD... - run CMD on AUTH_HOST, tunnelling through the jumphost with
# a ProxyCommand that applies the same SSH_OPTS to the hop.
ssh_auth() {
    local proxy_cmd="ssh ${SSH_OPTS[*]} -W %h:%p ubuntu@${JUMPHOST}"
    ssh "${SSH_OPTS[@]}" -o "ProxyCommand=${proxy_cmd}" "${AUTH_HOST}" "$@"
}

## Validate USERS map before doing anything

declare -A seen_usernames seen_gids seen_keynames

# validate_user_entry ENTRY
#   Check one "username|display_name|gidnumber|qos|ssh_keys" entry and record
#   its username, gidnumber and key names in the global seen_* maps, so that
#   duplicates across entries are caught. Prints an error and exits 2 on the
#   first problem found.
#   Username, gidnumber, display name and key names are later interpolated into
#   remote shell commands (some inside "..."), so characters that would break
#   out of those commands are rejected here.
validate_user_entry() {
    local username display_name gidnumber qos ssh_keys
    local key_entry ssh_key_name ssh_pubkey key_type key_data
    local -a key_entries
    local username_re='^[a-z_][a-z0-9._-]*$'
    local gid_re='^[0-9]+$'
    local key_type_re='^(ssh-(ed25519|rsa)|ecdsa-sha2-nistp(256|384|521))$'
    # Characters that would break out of a double-quoted remote argument.
    local unsafe_re='["$`\\]'

    IFS='|' read -r username display_name gidnumber qos ssh_keys <<<"$1"

    # Catch old 4-field format (missing qos field)
    if [ -z "${ssh_keys}" ]; then
        echo "Error: USERS entry for '${username}' has too few fields (expected 5: username|display_name|gidnumber|qos|ssh_keys)"
        echo "Did you copy an old USERS map? Add a qos field (unlimited or limited) before ssh_keys."
        exit 2
    fi

    if ! [[ "${username}" =~ ${username_re} ]]; then
        echo "Error: Invalid username '${username}' (use lowercase letters, digits, '.', '_' or '-'; start with a letter or '_')"
        exit 2
    fi

    if [ -z "${display_name}" ] || [[ "${display_name}" =~ ${unsafe_re} ]]; then
        echo "Error: Invalid display_name '${display_name}' for ${username} (must be non-empty and must not contain \", \$, \` or \\)"
        exit 2
    fi

    if ! [[ "${gidnumber}" =~ ${gid_re} ]]; then
        echo "Error: Invalid gidnumber '${gidnumber}' for ${username} (must be a decimal number)"
        exit 2
    fi

    if [ "${qos}" != "unlimited" ] && [ "${qos}" != "limited" ]; then
        echo "Error: Invalid qos '${qos}' for ${username} (must be 'unlimited' or 'limited')"
        exit 2
    fi

    if [ -n "${seen_usernames[${username}]+x}" ]; then
        echo "Error: Duplicate username '${username}'"
        exit 2
    fi
    seen_usernames[${username}]=1

    if [ -n "${seen_gids[${gidnumber}]+x}" ]; then
        echo "Error: Duplicate gidnumber '${gidnumber}' (user ${username})"
        exit 2
    fi
    seen_gids[${gidnumber}]=1

    IFS=';' read -ra key_entries <<<"${ssh_keys}"
    for key_entry in "${key_entries[@]}"; do
        ssh_key_name="${key_entry%%:*}"
        ssh_pubkey="${key_entry#*:}"
        # Without a ':' both expansions above return the whole entry, so the
        # missing key name has to be detected explicitly.
        if [[ "${key_entry}" != *:* ]] || [ -z "${ssh_key_name}" ] || [ -z "${ssh_pubkey}" ]; then
            echo "Error: Invalid ssh key entry for ${username}: '${key_entry}'"
            exit 2
        fi
        if [[ "${ssh_key_name}" =~ ${unsafe_re} ]]; then
            echo "Error: Invalid ssh key name '${ssh_key_name}' for ${username} (must not contain \", \$, \` or \\)"
            exit 2
        fi
        if [ -n "${seen_keynames[${ssh_key_name}]+x}" ]; then
            echo "Error: Duplicate ssh key name '${ssh_key_name}' (user ${username})"
            exit 2
        fi
        seen_keynames[${ssh_key_name}]=1

        key_type="${ssh_pubkey%% *}"
        key_data="${ssh_pubkey#* }"
        # Anchored at both ends so e.g. "ssh-rsafoo" is not accepted.
        if [[ ! "${key_type}" =~ ${key_type_re} ]]; then
            echo "Error: Invalid ssh key type '${key_type}' for ${username} key ${ssh_key_name}"
            exit 2
        fi
        if [[ "${key_data}" == *" "* ]]; then
            echo "Error: ssh key for ${username} key ${ssh_key_name} has extra text after the key data (remove the trailing comment)"
            exit 2
        fi
        # An empty body would otherwise pass: base64 -d of nothing succeeds.
        if [ -z "${key_data}" ] || ! echo "${key_data}" | base64 -d &>/dev/null; then
            echo "Error: Invalid base64 in ssh key for ${username} key ${ssh_key_name}"
            exit 2
        fi
    done
    return 0
}

for user_entry in "${USERS[@]}"; do
    validate_user_entry "${user_entry}"
done

SUCCEEDED_USERS=()
FAILED_USERS=()
# NOTE(review): reported in the summary, but nothing in this script appends to it.
SKIPPED_USERS=()

## Kanidm setup

# Reset the idm_admin password inside the kanidm container and scrape the new
# value from the recover-account output. "|| true" keeps a missing match (or an
# ssh failure) from killing the script silently under set -e/pipefail; the
# emptiness check below reports it instead.
PASS=$(ssh_auth 'docker exec -i kanidm kanidmd recover-account idm_admin 2>/dev/null' | grep -i new_password | cut -d ":" -f3-6 | tr -d '"' | xargs || true)
if [ -z "${PASS}" ]; then
    echo "ERROR: could not read a new idm_admin password from 'kanidmd recover-account' on ${AUTH_HOST}"
    exit 1
fi

# Hand the password over on stdin rather than embedding it in the ssh command
# line. This keeps it out of the local process list and stops the remote shell
# from re-parsing it.
ssh_auth 'IFS= read -r KANIDM_PASSWORD && export KANIDM_PASSWORD && kanidm login --name idm_admin' <<<"${PASS}"

# Create group if it doesn't exist
if ! ssh_auth "kanidm group get cluster_users" &>/dev/null; then
    echo "Creating kanidm group cluster_users"
    ssh_auth "kanidm group create cluster_users"
    ssh_auth "kanidm group posix set cluster_users --gidnumber 70000"
fi

# Per user: ensure the kanidm person exists with the expected posix gid, group
# membership and ssh keys, then make sure a cluster key exists on the jumphost.
# Users whose cluster-key step fails go to FAILED_USERS; any other failing
# remote step aborts the whole script (set -e).
for user_entry in "${USERS[@]}"; do
    IFS='|' read -r username display_name gidnumber _ ssh_keys <<<"${user_entry}"

    echo "--- ${username} ---"

    # Capture remote output before grepping it. Piping ssh straight into
    # `grep -q` lets grep exit on the first match, and an ssh SIGPIPE would
    # then make the pipeline fail under pipefail.
    person_output=$(ssh_auth "kanidm person get ${username}" 2>/dev/null || true)
    if ! grep -q "spn:" <<<"${person_output}"; then
        echo "Creating kanidm user ${username}"
        ssh_auth "kanidm person create ${username} \"${display_name}\""
    fi

    posix_output=$(ssh_auth "kanidm person posix show ${username}" 2>/dev/null || true)
    if ! grep -q gidnumber <<<"${posix_output}"; then
        echo "Setting posix attributes for ${username}"
        ssh_auth "kanidm person posix set ${username} --shell /bin/bash"
        ssh_auth "kanidm person posix set ${username} --gidnumber ${gidnumber}"
    else
        # Verify gidnumber matches what we expect. A single awk (instead of
        # grep | awk) yields an empty value rather than aborting the script
        # under pipefail when no line starts with "gidnumber:".
        remote_gid=$(awk '/^gidnumber:/ { print $2; exit }' <<<"${posix_output}")
        if [ "${remote_gid}" != "${gidnumber}" ]; then
            debug "kanidm person posix show ${username}:"
            debug "${posix_output}"
            echo "ERROR: User ${username} has gidnumber ${remote_gid} in kanidm but ${gidnumber} in USERS map"
            echo "Fix the USERS map or manually update kanidm before re-running"
            exit 1
        fi
        debug "User ${username} gidnumber ${gidnumber} matches kanidm"
    fi

    # Errors are deliberately ignored here. Presumably this is so re-runs do not
    # abort on an existing membership or key; confirm.
    ssh_auth "kanidm group add-members cluster_users ${username}" 2>/dev/null || true
    IFS=';' read -ra key_entries <<<"${ssh_keys}"
    for key_entry in "${key_entries[@]}"; do
        ssh_key_name="${key_entry%%:*}"
        ssh_pubkey="${key_entry#*:}"
        ssh_auth "kanidm person ssh add-publickey ${username} \"${ssh_key_name}\" \"${ssh_pubkey}\"" 2>/dev/null || true
    done

    if [ "${DEBUG}" = true ]; then
        debug "kanidm person posix show ${username}:"
        ssh_auth "kanidm person posix show ${username}" 2>/dev/null || true
    fi

    # Generate cluster ssh key if not present
    # Use sudo on the jumphost since the user may not have SSH key access yet
    if ssh_jump "sudo test -f /home/${username}/.ssh/id_rsa"; then
        debug "SSH key already exists for ${username}"
        SUCCEEDED_USERS+=("${username}")
    else
        # "rsa" is the key type name documented in ssh-keygen(1).
        # umask 0133 makes a newly created authorized_keys mode 0644.
        if ssh_jump "sudo -u ${username} ssh-keygen -q -t rsa -N \"\" -C \"cluster key\" -f /home/${username}/.ssh/id_rsa && sudo -u ${username} bash -c 'umask 0133 && cat /home/${username}/.ssh/id_rsa.pub >> /home/${username}/.ssh/authorized_keys'"; then
            echo "Generated SSH key for ${username}"
            SUCCEEDED_USERS+=("${username}")
        else
            echo "ERROR: Failed to generate SSH key for ${username}"
            FAILED_USERS+=("${username}")
        fi
    fi
done

## SLURM sacctmgr
echo "Starting slurm sacctmgr changes & validation"

# Read ClusterName from slurm.conf and parse it locally with awk:
#   - comments (whole-line or trailing) are stripped first,
#   - the key is matched case-insensitively, with whitespace allowed around "=",
#   - only the first definition is used, so the value is one clean token
#     (a plain grep also matched "#ClusterName=..." and could return several lines).
slurm_conf=$(ssh_jump 'cat /etc/slurm/slurm.conf' 2>/dev/null || true)
CLUSTER_NAME=$(awk -F= '
    { sub(/#.*/, "") }
    tolower($1) ~ /^[ \t]*clustername[ \t]*$/ {
        value = $2
        gsub(/[ \t\r]/, "", value)
        print value
        exit
    }' <<<"${slurm_conf}")
if [ -z "${CLUSTER_NAME}" ]; then
    echo "ERROR: could not read ClusterName from /etc/slurm/slurm.conf on ${JUMPHOST}"
    exit 1
fi

# sacctmgr_has QUERY NAME
#   Succeed when `sacctmgr -n -P show QUERY`, run on the jumphost, prints a row
#   whose first '|'-separated field is exactly NAME. An exact match keeps e.g.
#   "limited_qos" from matching an "unlimited_qos" row.
#   Assumes the first parsable field is the entity name for the account/qos/user
#   queries used in this script; confirm if Slurm's output format changes.
#   Output is captured before matching, so an ssh or sacctmgr failure simply
#   means "not found" instead of tripping pipefail.
sacctmgr_has() {
    local output
    output=$(ssh_jump "sudo sacctmgr -n -P show $1" 2>/dev/null || true)
    awk -F'|' -v want="$2" '$1 == want { found = 1 } END { exit !found }' <<<"${output}"
}

# ensure_slurm_entity KIND NAME ANNOUNCE ADD_CMD
#   If sacctmgr does not list NAME for KIND, print ANNOUNCE and run ADD_CMD on
#   the jumphost; a failure of ADD_CMD is tolerated. Re-check afterwards and
#   abort the script with exit 1 if NAME is still missing.
ensure_slurm_entity() {
    local kind="$1" name="$2" announce="$3" add_cmd="$4"
    local query="${kind} ${name}"

    if ! sacctmgr_has "${query}" "${name}"; then
        echo "${announce}"
        ssh_jump "${add_cmd}" || true
    fi
    sacctmgr_has "${query}" "${name}" && return 0
    echo "ERROR: SLURM ${kind} ${name} not found after creation attempt"
    exit 1
}

# Create the SLURM account and both QoS levels if they don't exist yet.
ensure_slurm_entity account default_acct "Setting up SLURM account default_acct" \
    'sudo sacctmgr add account default_acct Cluster='"${CLUSTER_NAME}"' Description="Default account" Organization="verda" <<<yes'
ensure_slurm_entity qos limited_qos "SLURM add qos limited_qos" \
    'sudo sacctmgr add qos limited_qos MaxTRESPerUser=gres/gpu=8 MaxJobsPerUser=1 <<<yes'
ensure_slurm_entity qos unlimited_qos "SLURM add qos unlimited_qos" \
    'sudo sacctmgr add qos unlimited_qos <<<yes'

# Make limited_qos the default for the account, allow both QoS
ssh_jump 'sudo sacctmgr modify account default_acct set DefaultQOS=limited_qos QOS=limited_qos,unlimited_qos <<<yes' >/dev/null || true

# Add users to SLURM and assign per-user QoS (only succeeded users)
# Build a map of username -> qos from USERS
declare -A user_qos_map
for user_entry in "${USERS[@]}"; do
    IFS='|' read -r u _ _ q _ <<<"${user_entry}"
    user_qos_map[${u}]="${q}_qos"
done

for username in "${SUCCEEDED_USERS[@]}"; do
    if ! sacctmgr_has "user ${username}" "${username}"; then
        echo "SLURM add user ${username}"
        ssh_jump "sudo sacctmgr add user ${username} Account=default_acct <<<yes" || true
    fi
    if ! sacctmgr_has "user ${username}" "${username}"; then
        echo "ERROR: SLURM user ${username} not found after creation attempt"
        FAILED_USERS+=("${username}")
        continue
    fi
    # Set per-user QoS
    target_qos="${user_qos_map[${username}]}"
    debug "SLURM set ${username} DefaultQOS=${target_qos} QOS=${target_qos}"
    ssh_jump "sudo sacctmgr modify user ${username} set DefaultQOS=${target_qos} QOS=${target_qos} <<<yes" >/dev/null || true
    # Verify QoS was applied. `show assoc` prints one row per association, and a
    # user may have several (e.g. more than one cluster or account). Identical
    # rows are collapsed, so only differing values make the comparison below fail.
    actual_qos=$(ssh_jump "sudo sacctmgr show assoc user=${username} format=DefaultQOS -n -P" 2>/dev/null | sort -u || true)
    if [ "${actual_qos}" != "${target_qos}" ]; then
        echo "ERROR: ${username} DefaultQOS is '${actual_qos//$'\n'/,}', expected '${target_qos}'"
        FAILED_USERS+=("${username}")
    else
        debug "${username} QoS verified: ${actual_qos}"
    fi
done

# Record whether slurm.conf's AccountingStorageEnforce line(s) mention all of
# associations, limits and qos; the summary reports the result.
ENFORCE=$(ssh_jump 'grep -E "^AccountingStorageEnforce" /etc/slurm/slurm.conf' 2>/dev/null || true)
ENFORCE_OK=false
if [[ "${ENFORCE}" == *associations* && "${ENFORCE}" == *limits* && "${ENFORCE}" == *qos* ]]; then
    ENFORCE_OK=true
fi

## Summary

# print_summary - report per-user results, the cluster name and the
# AccountingStorageEnforce check, plus a banner in --debug mode when nothing
# failed. The SLURM section can add a user to FAILED_USERS after that user was
# already put in SUCCEEDED_USERS, so the "OK" line lists only users that never
# failed.
print_summary() {
    local user
    local -a ok_users=()
    local -A failed_lookup=()

    for user in "${FAILED_USERS[@]}"; do
        failed_lookup[${user}]=1
    done
    for user in "${SUCCEEDED_USERS[@]}"; do
        if [ -z "${failed_lookup[${user}]+x}" ]; then
            ok_users+=("${user}")
        fi
    done

    echo ""
    echo "=== Summary ==="
    echo "Cluster: ${CLUSTER_NAME}"
    if [ ${#ok_users[@]} -gt 0 ]; then
        echo "OK: ${ok_users[*]}"
    fi
    if [ ${#FAILED_USERS[@]} -gt 0 ]; then
        echo "FAILED: ${FAILED_USERS[*]}"
    fi
    if [ ${#SKIPPED_USERS[@]} -gt 0 ]; then
        echo "SKIPPED (not processed): ${SKIPPED_USERS[*]}"
    fi
    if [ "${ENFORCE_OK}" = true ]; then
        echo "AccountingStorageEnforce: OK"
    else
        echo "AccountingStorageEnforce: MISSING - add 'AccountingStorageEnforce=associations,limits,qos' to /etc/slurm/slurm.conf"
    fi

    if [ "${DEBUG}" = true ] && [ ${#FAILED_USERS[@]} -eq 0 ] && [ ${#SKIPPED_USERS[@]} -eq 0 ]; then
        cat <<'SLURM'
        _____
       /     \
      | SLURM |
      | ~~~~~ |
      | ~~~~~ |
      | ~~~~~ |
      |  _    |
       \| |__/
        |    |
        |    |
        |    |
        |____|
    It's highly addictive!
SLURM
    fi
}

print_summary

# Report failures to the caller (wrapper script, CI job) through the exit status
# as well, not only in the printed summary.
if [ ${#FAILED_USERS[@]} -gt 0 ]; then
    exit 1
fi
