Skip to content

Commit

Permalink
grpc/bgp listen in netns
Browse files Browse the repository at this point in the history
  • Loading branch information
Maxime Peim committed Oct 29, 2024
1 parent 06110fa commit c62f854
Show file tree
Hide file tree
Showing 8 changed files with 608 additions and 325 deletions.
2 changes: 1 addition & 1 deletion api/attribute.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion api/capability.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

650 changes: 330 additions & 320 deletions api/gobgp.pb.go

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions api/gobgp.proto
Original file line number Diff line number Diff line change
Expand Up @@ -1108,6 +1108,7 @@ message Global {
GracefulRestart graceful_restart = 10;
ApplyPolicy apply_policy = 11;
string bind_to_device = 12;
string netns = 13;
}

message Confederation {
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ require (
github.com/spf13/viper v1.16.0
github.com/stretchr/testify v1.8.4
github.com/vishvananda/netlink v1.2.1
github.com/vishvananda/netns v0.0.4
golang.org/x/sys v0.25.0
golang.org/x/text v0.14.0
golang.org/x/time v0.3.0
Expand Down Expand Up @@ -49,7 +50,6 @@ require (
github.com/spf13/jwalterweatherman v1.1.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/subosito/gotenv v1.4.2 // indirect
github.com/vishvananda/netns v0.0.4 // indirect
golang.org/x/net v0.23.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234030-28d5490b6b19 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
Expand Down
244 changes: 244 additions & 0 deletions pkg/server/netns_linux.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
// Copyright (C) 2022 Cisco Systems Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
// implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//go:build linux

package server

import (
stderrors "errors"
"fmt"
"os"
"path"
"runtime"
"strconv"
"strings"
"syscall"

"github.com/vishvananda/netns"
"golang.org/x/sys/unix"
)

// Sentinel network namespace names given special treatment by NsEnter.
const (
// DefaultNetns, like the empty string, makes NsEnter return immediately
// without locking the thread or switching namespaces.
DefaultNetns = "##defaultNetns##"
// UnnamedNetns makes NsEnter create a fresh namespace with netns.New()
// instead of looking one up by name or pid.
UnnamedNetns = "##unnamedNetns##"
)

// getNsRunDir returns the directory under which named network namespaces
// are bind-mounted.
//
// The result is "/var/run/netns" unless XDG_RUNTIME_DIR is set and /var/run
// is owned by a user other than the effective one. A differing owner most
// likely means we are running inside a user namespace, in which case
// $XDG_RUNTIME_DIR/netns is used instead.
func getNsRunDir() string {
	const systemDir = "/var/run/netns"

	runtimeDir := os.Getenv("XDG_RUNTIME_DIR")
	if runtimeDir == "" {
		return systemDir
	}

	info, err := os.Stat("/var/run")
	if err != nil {
		return systemDir
	}

	// Only the raw stat structure carries the owner uid.
	stat, ok := info.Sys().(*syscall.Stat_t)
	if !ok || int(stat.Uid) == os.Geteuid() {
		return systemDir
	}

	return path.Join(runtimeDir, "netns")
}

// GetCurrentThreadNetNSPath returns the procfs path of the network namespace
// of the OS thread the calling goroutine is running on at the time of the
// call.
//
// Copied from containernetworking/plugins/pkg/ns.
func GetCurrentThreadNetNSPath() string {
	// /proc/self/ns/net resolves to the namespace of the main thread, not to
	// the one of the thread executing this goroutine. Threads switch
	// namespaces independently, so the path has to be built from the
	// per-task entry.
	pid, tid := os.Getpid(), unix.Gettid()
	return fmt.Sprintf("/proc/%d/task/%d/ns/net", pid, tid)
}

// MountNs makes the network namespace of the calling OS thread persistent by
// bind-mounting it onto <netns run dir>/<nsName>, where the run dir is the
// one returned by getNsRunDir. The calling thread's namespace is not changed.
//
// The ns argument is currently unused: the namespace that gets mounted is
// always the one reported by GetCurrentThreadNetNSPath, so the caller is
// expected to already be in the namespace it wants to persist.
//
// This function was forked from containernetworking/plugins [0], as we
// depend on pkg/ns but it does not support netns creation.
//
// [0] https://github.com/containernetworking/plugins/blob/main/pkg/testutils/netns_linux.go
func MountNs(ns netns.NsHandle, nsName string) error {
	nsRunDir := getNsRunDir()

	// Create the directory for mounting network namespaces. This needs to be
	// a shared mountpoint in case it is mounted in to other namespaces
	// (containers).
	if err := os.MkdirAll(nsRunDir, 0755); err != nil {
		return err
	}

	// Remount the namespace directory shared. This fails with EINVAL if it
	// is not already a mountpoint, in which case it is bind-mounted on to
	// itself to "upgrade" it to a mountpoint first.
	if err := unix.Mount("", nsRunDir, "none", unix.MS_SHARED|unix.MS_REC, ""); err != nil {
		if err != unix.EINVAL {
			return fmt.Errorf("mount --make-rshared %s failed: %q", nsRunDir, err)
		}

		// Recursively remount the directory on itself. The recursive flag
		// is so that any existing netns bind mounts are carried over.
		if err := unix.Mount(nsRunDir, nsRunDir, "none", unix.MS_BIND|unix.MS_REC, ""); err != nil {
			return fmt.Errorf("mount --rbind %s %s failed: %q", nsRunDir, nsRunDir, err)
		}

		// Now we can make it shared.
		if err := unix.Mount("", nsRunDir, "none", unix.MS_SHARED|unix.MS_REC, ""); err != nil {
			return fmt.Errorf("mount --make-rshared %s failed: %q", nsRunDir, err)
		}
	}

	// Create an empty file to serve as the mount point.
	nsPath := path.Join(nsRunDir, nsName)
	mountPoint, err := os.Create(nsPath)
	if err != nil {
		return err
	}
	// Nothing was written to the file, so a close failure is not actionable.
	_ = mountPoint.Close()

	// Bind mount the netns of the current thread (from /proc) onto the mount
	// point. This causes the namespace to persist, even when there are no
	// threads in the ns.
	if err := unix.Mount(GetCurrentThreadNetNSPath(), nsPath, "none", unix.MS_BIND, ""); err != nil {
		// Best-effort removal of the placeholder. This is done only on
		// failure so that a successfully mounted namespace is never touched.
		_ = os.RemoveAll(nsPath)
		return fmt.Errorf("failed to create namespace: failed to bind mount ns at %s: %v", nsPath, err)
	}

	return nil
}

// NetNsExec runs cb with the calling goroutine switched into the network
// namespace designated by netnsName (see NsEnter for the accepted values)
// and undoes the switch before returning. It returns the NsEnter error if
// the switch failed, otherwise whatever cb returned.
func NetNsExec(netnsName string, cb func() error) (err error) {
	leave, err := NsEnter(netnsName)
	// NsEnter always hands back a callable cleanup, even on failure, so it
	// is deferred before the error is inspected.
	defer leave()
	if err == nil {
		err = cb()
	}
	return err
}

// NsEnter locks the calling goroutine to its OS thread and switches that
// thread into the network namespace designated by netnsName.
//
// Accepted values for netnsName:
//   - "" or DefaultNetns: nothing is done and cleanup is a no-op;
//   - UnnamedNetns: a new namespace is created with netns.New();
//   - "pid:<n>": the namespace of process <n>;
//   - anything else: a named namespace, looked up with netns.GetFromName.
//
// On success the returned cleanup restores the original namespace, closes
// the namespace handles and unlocks the thread; it must be called exactly
// once. On failure everything done so far has already been undone and the
// returned cleanup is a no-op, so callers may call it or skip it.
func NsEnter(netnsName string) (cleanup func(), err error) {
	noop := func() {}
	if netnsName == "" || netnsName == DefaultNetns {
		return noop, nil
	}

	// undo holds the teardown steps, run in reverse registration order.
	var undo []func()
	unwind := func() {
		for i := len(undo) - 1; i >= 0; i-- {
			undo[i]()
		}
	}
	// fail releases whatever was acquired so far, so that a caller which
	// returns early on error does not leave the thread locked.
	fail := func(cause error) (func(), error) {
		unwind()
		return noop, cause
	}

	runtime.LockOSThread()
	undo = append(undo, runtime.UnlockOSThread)

	origns, err := netns.Get()
	if err != nil {
		// Without a handle on the current namespace there is no way back.
		return fail(fmt.Errorf("Cannot get current netns: %w", err))
	}
	undo = append(undo, func() {
		if err := origns.Close(); err != nil {
			fmt.Printf("Cannot close initial netns fd %s", err)
		}
	})

	var targetns netns.NsHandle
	switch {
	case netnsName == UnnamedNetns:
		// netns.New both creates the namespace and moves the thread into it.
		targetns, err = netns.New()
		if err != nil {
			// The thread may already have been moved when the failure
			// occurred, so try to put it back before it is unlocked.
			_ = netns.Set(origns)
			return fail(fmt.Errorf("Cannot create new netns: %w", err))
		}
	case strings.HasPrefix(netnsName, "pid:"):
		pid, perr := strconv.ParseInt(strings.TrimPrefix(netnsName, "pid:"), 10, 64)
		if perr != nil {
			return fail(perr)
		}
		targetns, err = netns.GetFromPid(int(pid))
		if err != nil {
			return fail(fmt.Errorf("Cannot get %s netns from pid: %w", netnsName, err))
		}
	default:
		targetns, err = netns.GetFromName(netnsName)
		if err != nil {
			return fail(fmt.Errorf("Cannot get %s netns: %w", netnsName, err))
		}
	}
	// Registered only once a real handle exists.
	undo = append(undo, func() {
		if err := targetns.Close(); err != nil {
			fmt.Printf("Cannot close target netns fd %s", err)
		}
	})

	if netnsName != UnnamedNetns {
		if err = netns.Set(targetns); err != nil {
			return fail(fmt.Errorf("Cannot nsenter %s: %w", netnsName, err))
		}
	}
	undo = append(undo, func() {
		if err := netns.Set(origns); err != nil {
			fmt.Printf("Cannot nsenter initial netns %s", err)
		}
	})

	return unwind, nil
}

func EnsureNetnsExists(netnsName string) (err error) {
if netnsName == "" || strings.HasPrefix(netnsName, "pid:") {
return nil
}

runtime.LockOSThread()
defer runtime.UnlockOSThread()

origns, _ := netns.Get()
defer origns.Close()

ns, err := netns.GetFromName(netnsName)
if err != nil {
if stderrors.Is(os.ErrNotExist, err) {
ns, err := netns.New()
if err != nil {
return fmt.Errorf("Could not create netns for %s: %v", netnsName, err)
}
defer ns.Close()
err = MountNs(ns, netnsName)
if err != nil {
return fmt.Errorf("Could not mount netns to %s: %v", netnsName, err)
}
} else {
return fmt.Errorf("Cannot get %s netns: %v", netnsName, err)
}
} else {
ns.Close()
}

return netns.Set(origns)
}
29 changes: 27 additions & 2 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (l *tcpListener) Close() error {
}

// avoid mapped IPv6 address
func newTCPListener(logger log.Logger, address string, port uint32, bindToDev string, ch chan *net.TCPConn) (*tcpListener, error) {
func newTCPListener(logger log.Logger, address string, port uint32, bindToDev string, netNs string, ch chan *net.TCPConn) (*tcpListener, error) {
proto := "tcp4"
family := syscall.AF_INET
if ip := net.ParseIP(address); ip == nil {
Expand All @@ -71,6 +71,14 @@ func newTCPListener(logger log.Logger, address string, port uint32, bindToDev st
}
addr := net.JoinHostPort(address, strconv.Itoa(int(port)))

if netNs != "" {
clean, err := NsEnter(netNs)
if err != nil {
return nil, err
}
defer clean()
}

var lc net.ListenConfig
lc.Control = func(network, address string, c syscall.RawConn) error {
if bindToDev != "" {
Expand Down Expand Up @@ -140,6 +148,7 @@ func newTCPListener(logger log.Logger, address string, port uint32, bindToDev st
// options collects the settings that ServerOption functions write into
// before NewBgpServer builds the server.
type options struct {
// grpcAddress is passed to newAPIserver; presumably the gRPC listen
// address — confirm against newAPIserver.
grpcAddress string
// grpcOption is forwarded to grpc.NewServer.
grpcOption []grpc.ServerOption
// netNs is the network namespace name set by the NetNs option. When
// non-empty, NewBgpServer enters it with NsEnter before serving gRPC and
// copies it into BgpServer.netNs.
netNs string
// logger is the logger set by LoggerOption.
logger log.Logger
}

Expand All @@ -157,6 +166,12 @@ func GrpcOption(opt []grpc.ServerOption) ServerOption {
}
}

// NetNs returns a ServerOption that stores ns as the server's network
// namespace name. When it is non-empty, the server passes it to NsEnter
// before serving gRPC and before creating its BGP TCP listeners.
func NetNs(ns string) ServerOption {
	return func(opts *options) {
		opts.netNs = ns
	}
}

func LoggerOption(logger log.Logger) ServerOption {
return func(o *options) {
o.logger = logger
Expand All @@ -183,6 +198,7 @@ type BgpServer struct {
mrtManager *mrtManager
roaTable *table.ROATable
uuidMap map[string]uuid.UUID
netNs string
logger log.Logger
}

Expand All @@ -206,6 +222,7 @@ func NewBgpServer(opt ...ServerOption) *BgpServer {
uuidMap: make(map[string]uuid.UUID),
roaManager: newROAManager(roaTable, logger),
roaTable: roaTable,
netNs: opts.netNs,
logger: logger,
}
s.bmpManager = newBmpClientManager(s)
Expand All @@ -214,6 +231,14 @@ func NewBgpServer(opt ...ServerOption) *BgpServer {
grpc.EnableTracing = false
s.apiServer = newAPIserver(s, grpc.NewServer(opts.grpcOption...), opts.grpcAddress)
go func() {
if opts.netNs != "" {
clean, err := NsEnter(opts.netNs)
if err != nil {
logger.Fatal("failed to enter network namespace",
log.Fields{"Error": err})
}
defer clean()
}
if err := s.apiServer.serve(); err != nil {
logger.Fatal("failed to listen grpc port",
log.Fields{"Err": err})
Expand Down Expand Up @@ -2383,7 +2408,7 @@ func (s *BgpServer) StartBgp(ctx context.Context, r *api.StartBgpRequest) error
if c.Config.Port > 0 {
acceptCh := make(chan *net.TCPConn, 32)
for _, addr := range c.Config.LocalAddressList {
l, err := newTCPListener(s.logger, addr, uint32(c.Config.Port), g.BindToDevice, acceptCh)
l, err := newTCPListener(s.logger, addr, uint32(c.Config.Port), g.BindToDevice, s.netNs, acceptCh)
if err != nil {
return err
}
Expand Down
3 changes: 3 additions & 0 deletions tools/pyang_plugins/gobgp.yang
Original file line number Diff line number Diff line change
Expand Up @@ -1346,6 +1346,9 @@ module gobgp {
leaf-list local-address {
type string;
}
leaf netns {
type string;
}
}

augment "/bgp:bgp/bgp:global/bgp:config" {
Expand Down

0 comments on commit c62f854

Please sign in to comment.