diff --git a/net/mptcp/bpf.c b/net/mptcp/bpf.c index e9db856972cbed..02038db59956c8 100644 --- a/net/mptcp/bpf.c +++ b/net/mptcp/bpf.c @@ -218,18 +218,27 @@ __bpf_kfunc_start_defs(); __bpf_kfunc static struct mptcp_sock *bpf_mptcp_sk(struct sock *sk) { + if (!sk || sk->sk_protocol != IPPROTO_MPTCP) + return NULL; + return mptcp_sk(sk); } __bpf_kfunc static struct mptcp_subflow_context * bpf_mptcp_subflow_ctx(const struct sock *sk) { + if (!sk) + return NULL; + return mptcp_subflow_ctx(sk); } __bpf_kfunc static struct sock * bpf_mptcp_subflow_tcp_sock(const struct mptcp_subflow_context *subflow) { + if (!subflow) + return NULL; + return mptcp_subflow_tcp_sock(subflow); } @@ -299,9 +308,9 @@ __bpf_kfunc static bool bpf_mptcp_subflow_queues_empty(struct sock *sk) __bpf_kfunc_end_defs(); BTF_KFUNCS_START(bpf_mptcp_common_kfunc_ids) -BTF_ID_FLAGS(func, bpf_mptcp_sk) -BTF_ID_FLAGS(func, bpf_mptcp_subflow_ctx) -BTF_ID_FLAGS(func, bpf_mptcp_subflow_tcp_sock) +BTF_ID_FLAGS(func, bpf_mptcp_sk, KF_TRUSTED_ARGS | KF_RET_NULL) +BTF_ID_FLAGS(func, bpf_mptcp_subflow_ctx, KF_RET_NULL) +BTF_ID_FLAGS(func, bpf_mptcp_subflow_tcp_sock, KF_RET_NULL) BTF_ID_FLAGS(func, bpf_iter_mptcp_subflow_new, KF_ITER_NEW | KF_TRUSTED_ARGS) BTF_ID_FLAGS(func, bpf_iter_mptcp_subflow_next, KF_ITER_NEXT | KF_RET_NULL) BTF_ID_FLAGS(func, bpf_iter_mptcp_subflow_destroy, KF_ITER_DESTROY) @@ -335,7 +344,7 @@ static int __init bpf_mptcp_kfunc_init(void) int ret; ret = register_btf_fmodret_id_set(&bpf_mptcp_fmodret_set); - ret = ret ?: register_btf_kfunc_id_set(BPF_PROG_TYPE_UNSPEC, + ret = ret ?: register_btf_kfunc_id_set(BPF_PROG_TYPE_CGROUP_SOCKOPT, &bpf_mptcp_common_kfunc_set); ret = ret ?: register_btf_kfunc_id_set(BPF_PROG_TYPE_STRUCT_OPS, &bpf_mptcp_sched_kfunc_set);