#!/bin/sh
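# omr-weight-sync: sync UCI multipath_weight into the BPF endpoint_weights map
# and into uci network.<iface>.weight (ip route semantics).
# Only acts when network.globals.mptcp_scheduler selects a BPF weight scheduler.
# When omr-metrics.settings.enable_decision_weights is unset or not 0 (the
# default), it first queries GET /metrics/decision (on the custom omr-metrics
# server if one is configured, otherwise on each configured VPS until one
# responds) and writes the model-assigned weights into UCI before the BPF sync.
# Runs every MPTCP_WEIGHT_INTERVAL seconds (default: 5).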

# Print $1 if it is a usable sleep interval (digits with at most one '.', not
# all zeros), otherwise the default of 5.  A non-numeric or zero interval
# would make `sleep` fail or return at once, turning the main loop into a
# busy loop that hammers the VPS API.
_sanitize_interval() {
	case "$1" in
		''|*[!0-9.]*|*.*.*)
			printf '5\n'
			return 0
			;;
	esac
	case "$1" in
		*[1-9]*) printf '%s\n' "$1" ;;
		*) printf '5\n' ;;
	esac
}

POLL_INTERVAL=$(_sanitize_interval "${MPTCP_WEIGHT_INTERVAL:-5}")
WEIGHT_SCRIPT="/usr/sbin/mptcp-scheduler-weight.sh"
MAP_PATH="/sys/fs/bpf/endpoint_weights"

# Succeed when $1 contains a colon, i.e. treat it as an IPv6 literal
# (such addresses must be bracketed in URLs).
_is_ip6() {
	[ "${1#*:}" != "$1" ]
}

# Succeed when network.globals.mptcp_scheduler names one of the BPF weight
# schedulers (either by short name or by its .o object file); fail otherwise.
_is_weight_scheduler() {
	local sched candidate
	sched=$(uci -q get network.globals.mptcp_scheduler 2>/dev/null)
	for candidate in bpf_weight bpf_burstweight \
		mptcp_bpf_weight.o mptcp_bpf_burstweight.o; do
		[ "$sched" = "$candidate" ] && return 0
	done
	return 1
}

# Obtain a Bearer token via POST /token and cache it in UCI.
# $1 = full UCI key to store the token (e.g. "omr-metrics.settings.token"
#      or "openmptcprouter.<name>.token"); the config named by its first
#      component is committed after a successful login.
# $2 = server address (IPv6 literals are bracketed and fetched with -6)
# $3 = server port
# $4 = username, $5 = password (both sent form-urlencoded)
# Prints the token on stdout; prints nothing on failure.
_login() {
	local uci_token_key="$1" server="$2" serverport="$3" username="$4" password="$5"
	local auth token

	# Put the target (plus -6 for IPv6 literals) in the positional
	# parameters so the curl invocation below is written only once.
	if _is_ip6 "$server"; then
		set -- -6 "https://[${server}]:${serverport}/token"
	else
		set -- "https://${server}:${serverport}/token"
	fi

	# --data-urlencode percent-encodes each value, so credentials containing
	# '&', '=', '+', '%' or spaces reach the server intact; a raw -d string
	# would split or mangle them.
	auth=$(curl --max-time 10 -s -k \
		-H "accept: application/json" \
		-H "Content-Type: application/x-www-form-urlencoded" \
		-X POST \
		--data-urlencode "username=${username}" \
		--data-urlencode "password=${password}" \
		"$@" 2>/dev/null)

	token=$(printf '%s' "$auth" | jsonfilter -q -e '@.access_token' 2>/dev/null)
	if [ -n "$token" ]; then
		uci -q set "${uci_token_key}=${token}" 2>/dev/null
		# Commit so the cached token is persisted rather than left behind as
		# an uncommitted (pending) UCI change.
		uci -q commit "${uci_token_key%%.*}" >/dev/null 2>&1
	fi
	printf '%s' "$token"
}

# GET /metrics/decision from one VPS server.
# Args: servername server serverport predict(true|false) horizon(seconds)
# Credentials come from omr-metrics.settings when a username is set there,
# otherwise from openmptcprouter.<servername>.  The cached token is tried
# first; on HTTP 401 one fresh login is attempted and the request repeated.
# Prints the raw JSON response on stdout; returns 1 on failure.
_get_decision() {
	local servername="$1" server="$2" serverport="$3" predict="$4" horizon="$5"
	local section token username password url http_code tmpfile uci_token_key
	local host attempt

	section="omr-metrics.settings"
	username=$(uci -q get "${section}.username" 2>/dev/null)
	if [ -z "$username" ]; then
		section="openmptcprouter.${servername}"
		username=$(uci -q get "${section}.username" 2>/dev/null)
	fi
	uci_token_key="${section}.token"
	token=$(uci -q get "$uci_token_key" 2>/dev/null)
	password=$(uci -q get "${section}.password" 2>/dev/null)

	if [ -z "$token" ]; then
		token=$(_login "$uci_token_key" "$server" "$serverport" "$username" "$password")
		[ -n "$token" ] || return 1
	fi

	host="$server"
	_is_ip6 "$server" && host="[${server}]"
	url="https://${host}:${serverport}/metrics/decision?predict=${predict}&horizon=${horizon}"

	tmpfile=$(mktemp /tmp/omr-decision.XXXXXX) || return 1

	# At most two requests: one with the cached token and, only after a 401
	# followed by a successful re-login, one with the fresh token.
	for attempt in cached refreshed; do
		http_code=$(curl --max-time 10 -s -k \
			-o "$tmpfile" -w "%{http_code}" \
			-H "accept: application/json" \
			-H "Authorization: Bearer ${token}" \
			"$url" 2>/dev/null)
		[ "$http_code" = "401" ] && [ "$attempt" = "cached" ] || break
		token=$(_login "$uci_token_key" "$server" "$serverport" "$username" "$password")
		[ -n "$token" ] || break
	done

	[ "$http_code" = "200" ] && cat "$tmpfile"
	rm -f "$tmpfile"
	[ "$http_code" = "200" ]
}

# Write the weights from one /metrics/decision response into UCI.
# $1 = raw JSON response, expected shape {"<iface>": <weight>, ...}
# $2 = whitespace-separated UCI interface section names to look up
# Entries that are missing, non-integer or < 1 are skipped.  Each accepted
# weight is written to both network.<iface>.multipath_weight and
# network.<iface>.weight; network is committed only when something changed.
_store_decision_weights() {
	local _response="$1" _ifaces="$2"
	local _changed=0 _iface _weight _cur _opt

	for _iface in $_ifaces; do
		_weight=$(jsonfilter -s "$_response" -e "@['${_iface}']" 2>/dev/null)
		case "$_weight" in
			''|*[!0-9]*) continue ;;
		esac
		[ "$_weight" -lt 1 ] && continue

		for _opt in multipath_weight weight; do
			_cur=$(uci -q get "network.${_iface}.${_opt}" 2>/dev/null)
			if [ "$_cur" != "$_weight" ]; then
				uci -q set "network.${_iface}.${_opt}=$_weight"
				_changed=1
			fi
		done
	done

	[ "$_changed" -eq 1 ] && uci -q commit network 2>/dev/null
	return 0
}

# Fetch model-assigned weights and write them into UCI multipath_weight /
# weight.  Only runs when a BPF weight scheduler is selected.  When
# omr-metrics.settings.use_custom_server=1 and a server is set, only that
# server is queried; otherwise each enabled openmptcprouter "server" section
# is tried, address by address, until one answers.  Returns 0 as soon as a
# decision response is received (even if it held no usable weights), and
# 1 otherwise, in which case the existing UCI weights are left untouched.
_fetch_decision_weights() {
	_is_weight_scheduler || return 1

	local custom_server custom_serverport server_names use_custom_server
	local iface_list _predict _horizon _v response
	local servername disabled serverport server_ips server

	use_custom_server=$(uci -q get "omr-metrics.settings.use_custom_server" 2>/dev/null)
	if [ "${use_custom_server:-0}" = "1" ]; then
		custom_server=$(uci -q get "omr-metrics.settings.server" 2>/dev/null)
		custom_serverport=$(uci -q get "omr-metrics.settings.serverport" 2>/dev/null)
		[ -z "$custom_serverport" ] && custom_serverport="65500"
	fi

	server_names=$(uci -q show openmptcprouter 2>/dev/null | \
		sed -n 's/^openmptcprouter\.\([^.=][^.=]*\)=server$/\1/p')
	[ -z "$custom_server" ] && [ -z "$server_names" ] && return 1

	iface_list=$(uci show network 2>/dev/null \
		| grep -E "multipath='(on|master|backup)'" \
		| sed "s/network\\.\\(.*\\)\\.multipath.*/\\1/")
	[ -z "$iface_list" ] && return 1

	_v=$(uci -q get omr-metrics.settings.decision_predict 2>/dev/null)
	[ "${_v:-0}" = "1" ] && _predict="true" || _predict="false"
	_horizon=$(uci -q get omr-metrics.settings.decision_horizon 2>/dev/null)
	[ -z "$_horizon" ] && _horizon="300"

	# If a custom omr-metrics server is configured, use it exclusively.
	if [ -n "$custom_server" ]; then
		response=$(_get_decision "omr_metrics_custom" "$custom_server" "$custom_serverport" "$_predict" "$_horizon") || return 1
		_store_decision_weights "$response" "$iface_list"
		return 0
	fi

	for servername in $server_names; do
		disabled=$(uci -q get "openmptcprouter.${servername}.disabled" 2>/dev/null)
		[ "$disabled" = "1" ] && continue

		serverport=$(uci -q get "openmptcprouter.${servername}.port" 2>/dev/null)
		[ -z "$serverport" ] && serverport="65500"

		# `uci get` prints the values of a list option space-separated and
		# unquoted; `uci show` would print 'a' 'b' and leave quotes attached
		# to each word after splitting.
		server_ips=$(uci -q get "openmptcprouter.${servername}.ip" 2>/dev/null)

		for server in $server_ips; do
			[ -z "$server" ] && continue
			response=$(_get_decision "$servername" "$server" "$serverport" "$_predict" "$_horizon") || continue
			_store_decision_weights "$response" "$iface_list"
			return 0
		done
	done

	return 1
}

# Push each multipath-enabled interface's weight into the BPF endpoint_weights
# map via $WEIGHT_SCRIPT and keep network.<iface>.weight and
# network.<iface>.multipath_weight equal to it.  Returns 0 without doing
# anything unless a BPF weight scheduler is selected, $MAP_PATH exists and
# $WEIGHT_SCRIPT is executable.
_apply_weights() {
	_is_weight_scheduler || return 0
	[ -e "$MAP_PATH" ] || return 0
	[ -x "$WEIGHT_SCRIPT" ] || return 0

	local _ifaces _changed _iface _dev _weight _cur
	_changed=0

	# Multipath-enabled interface sections, one per line.  Kept in a variable
	# rather than an mktemp file, so nothing is left behind in /tmp if the
	# daemon is stopped mid-iteration.
	_ifaces=$(uci show network 2>/dev/null \
		| grep -E "multipath='(on|master|backup)'" \
		| sed "s/network\\.\\(.*\\)\\.multipath.*/\\1/")

	# The here-document keeps the loop in the current shell so _changed
	# survives it (piping into `while` would run the loop in a subshell).
	while IFS= read -r _iface; do
		[ -z "$_iface" ] && continue

		_dev=$(uci -q get "network.${_iface}.device" 2>/dev/null)
		[ -z "$_dev" ] && _dev=$(uci -q get "network.${_iface}.ifname" 2>/dev/null)
		[ -z "$_dev" ] && continue

		# Source of truth: multipath_weight, then weight, default 100
		_weight=$(uci -q get "network.${_iface}.multipath_weight" 2>/dev/null)
		[ -z "$_weight" ] && _weight=$(uci -q get "network.${_iface}.weight" 2>/dev/null)
		[ -z "$_weight" ] && _weight=100

		# Apply to BPF map (best effort; failures are ignored)
		"$WEIGHT_SCRIPT" set "$_dev" "$_weight" >/dev/null 2>&1 || true

		# Sync network.<iface>.weight (ip route semantics)
		_cur=$(uci -q get "network.${_iface}.weight" 2>/dev/null)
		if [ "$_cur" != "$_weight" ]; then
			uci -q set "network.${_iface}.weight=$_weight"
			_changed=1
		fi

		# Sync network.<iface>.multipath_weight
		_cur=$(uci -q get "network.${_iface}.multipath_weight" 2>/dev/null)
		if [ "$_cur" != "$_weight" ]; then
			uci -q set "network.${_iface}.multipath_weight=$_weight"
			_changed=1
		fi
	done <<EOF
$_ifaces
EOF

	[ "$_changed" -eq 1 ] && uci -q commit network 2>/dev/null || true
}

# Succeed unless omr-metrics.settings.enable_decision_weights is explicitly
# "0"; an unset option counts as enabled.
_decision_enabled() {
	local flag
	flag=$(uci -q get omr-metrics.settings.enable_decision_weights 2>/dev/null)
	case "${flag:-1}" in
		0) return 1 ;;
	esac
	return 0
}

# Main loop: when enabled, refresh the UCI weights from the decision endpoint,
# then sync the UCI weights into the BPF map; repeat every POLL_INTERVAL
# seconds.
while :; do
	_decision_enabled && _fetch_decision_weights
	_apply_weights
	sleep "$POLL_INTERVAL"
done
