Skip to content

Commit

Permalink
Add SUPFILE_DIR as default environment variable
Browse files Browse the repository at this point in the history
This will allow Supfiles to be run different directories and
consistently resolve files relative to the Supfile.

This CL contains refactors the resolution order of environment
variables such that it's a bit easier to reason about.

Fixes pressly#99
  • Loading branch information
stengaard committed Feb 9, 2020
1 parent be6dff4 commit 0788a67
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 58 deletions.
47 changes: 23 additions & 24 deletions cmd/sup/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,21 @@ func cmdUsage(conf *sup.Supfile) {
fmt.Fprintln(w)
}

func options() []sup.SupfileOption {
return []sup.SupfileOption{}

}

// parseArgs parses args and returns network and commands to be run.
// On error, it prints usage and exits.
func parseArgs(conf *sup.Supfile) (*sup.Network, []*sup.Command, error) {
var commands []*sup.Command

// In case of the network.Env needs an initialization
if conf.Env == nil {
conf.Env = make(sup.EnvList, 0)
}

args := flag.Args()
if len(args) < 1 {
networkUsage(conf)
Expand All @@ -130,14 +140,14 @@ func parseArgs(conf *sup.Supfile) (*sup.Network, []*sup.Command, error) {
i := strings.Index(env, "=")
if i < 0 {
if len(env) > 0 {
network.Env.Set(env, "")
conf.Env.Set(env, "")
}
continue
}
network.Env.Set(env[:i], env[i+1:])
conf.Env.Set(env[:i], env[i+1:])
}

hosts, err := network.ParseInventory()
hosts, err := network.ParseInventory(conf.Env)
if err != nil {
return nil, nil, err
}
Expand All @@ -155,26 +165,8 @@ func parseArgs(conf *sup.Supfile) (*sup.Network, []*sup.Command, error) {
return nil, nil, ErrUsage
}

// In case of the network.Env needs an initialization
if network.Env == nil {
network.Env = make(sup.EnvList, 0)
}

// Add default env variable with current network
network.Env.Set("SUP_NETWORK", args[0])

// Add default nonce
network.Env.Set("SUP_TIME", time.Now().UTC().Format(time.RFC3339))
if os.Getenv("SUP_TIME") != "" {
network.Env.Set("SUP_TIME", os.Getenv("SUP_TIME"))
}

// Add user
if os.Getenv("SUP_USER") != "" {
network.Env.Set("SUP_USER", os.Getenv("SUP_USER"))
} else {
network.Env.Set("SUP_USER", os.Getenv("USER"))
}
conf.Env.Set("SUP_NETWORK", args[0])

for _, cmd := range args[1:] {
// Target?
Expand Down Expand Up @@ -248,7 +240,14 @@ func main() {
os.Exit(1)
}
}
conf, err := sup.NewSupfile(data)
conf, err := sup.NewSupfile(data,
// SUPFILE_DIR might change as sup invocations are chained.
sup.WithEnv("SUPFILE_DIR", filepath.Dir(supfile)),
// Add default nonce, but inherit from previous invocation.
sup.WithInheritEnv("SUP_TIME", time.Now().UTC().Format(time.RFC3339)),
// Add user, but inherit from previous invocation.
sup.WithInheritEnv("SUP_USER", os.Getenv("USER")),
)
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
Expand Down Expand Up @@ -332,7 +331,7 @@ func main() {
}

var vars sup.EnvList
for _, val := range append(conf.Env, network.Env...) {
for _, val := range conf.Env {
vars.Set(val.Key, val.Value)
}
if err := vars.ResolveValues(); err != nil {
Expand Down
8 changes: 7 additions & 1 deletion example/Supfile
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ env:
networks:
# Groups of hosts
local:
env:
SUP_LOCAL: yessir
hosts:
- localhost

Expand Down Expand Up @@ -53,7 +55,7 @@ commands:
upload:
- src: ./
dst: /tmp/$IMAGE
script: ./scripts/docker-build.sh
script: $SUPFILE_DIR/scripts/docker-build.sh
once: true

pull:
Expand Down Expand Up @@ -126,6 +128,10 @@ commands:
curl -X POST --data-urlencode 'payload={"channel": "#_team_", "text": "['$SUP_NETWORK'] '$SUP_USER' deployed '$NAME'"}' \
https://hooks.slack.com/services/X/Y/Z

env:
desc: Print environment
local: env

bash:
desc: Interactive shell on all hosts
stdin: true
Expand Down
16 changes: 1 addition & 15 deletions localhost.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,9 @@ import (
"os"
"os/exec"
"os/user"

"github.com/pkg/errors"
)

// Client is a wrapper over the SSH connection/sessions.
// LocalhostClient is a wrapper over the SSH connection/sessions.
type LocalhostClient struct {
cmd *exec.Cmd
user string
Expand Down Expand Up @@ -105,15 +103,3 @@ func (c *LocalhostClient) WriteClose() error {
func (c *LocalhostClient) Signal(sig os.Signal) error {
return c.cmd.Process.Signal(sig)
}

func ResolveLocalPath(cwd, path, env string) (string, error) {
// Check if file exists first. Use bash to resolve $ENV_VARs.
cmd := exec.Command("bash", "-c", env+"echo -n "+path)
cmd.Dir = cwd
resolvedFilename, err := cmd.Output()
if err != nil {
return "", errors.Wrap(err, "resolving path failed")
}

return string(resolvedFilename), nil
}
10 changes: 5 additions & 5 deletions ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
"golang.org/x/crypto/ssh/agent"
)

// Client is a wrapper over the SSH connection/sessions.
// SSHClient is a wrapper over the SSH connection/sessions.
type SSHClient struct {
conn *ssh.Client
sess *ssh.Session
Expand Down Expand Up @@ -219,16 +219,16 @@ func (c *SSHClient) Wait() error {
}

// DialThrough will create a new connection from the ssh server sc is connected to. DialThrough is an SSHDialer.
func (sc *SSHClient) DialThrough(net, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
conn, err := sc.conn.Dial(net, addr)
func (c *SSHClient) DialThrough(net, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
conn, err := c.conn.Dial(net, addr)
if err != nil {
return nil, err
}
c, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
sc, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
if err != nil {
return nil, err
}
return ssh.NewClient(c, chans, reqs), nil
return ssh.NewClient(sc, chans, reqs), nil

}

Expand Down
13 changes: 7 additions & 6 deletions sup.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
"golang.org/x/crypto/ssh"
)

const VERSION = "0.5"
const VERSION = "0.6"

type Stackup struct {
conf *Supfile
Expand All @@ -30,12 +30,13 @@ func New(conf *Supfile) (*Stackup, error) {
// Run runs set of commands on multiple hosts defined by network sequentially.
// TODO: This megamoth method needs a big refactor and should be split
// to multiple smaller methods.
func (sup *Stackup) Run(network *Network, envVars EnvList, commands ...*Command) error {
func (sup *Stackup) Run(network *Network, cliVars EnvList, commands ...*Command) error {
if len(commands) == 0 {
return errors.New("no commands to be run")
}

env := envVars.AsExport()
env := append(sup.conf.Env, network.Env...)
env = append(env, cliVars...)

// Create clients for every host (either SSH or Localhost).
var bastion *SSHClient
Expand All @@ -58,7 +59,7 @@ func (sup *Stackup) Run(network *Network, envVars EnvList, commands ...*Command)
// Localhost client.
if host == "localhost" {
local := &LocalhostClient{
env: env + `export SUP_HOST="` + host + `";`,
env: env.AsExport() + `export SUP_HOST="` + host + `";`,
}
if err := local.Connect(host); err != nil {
errCh <- errors.Wrap(err, "connecting to localhost failed")
Expand All @@ -70,7 +71,7 @@ func (sup *Stackup) Run(network *Network, envVars EnvList, commands ...*Command)

// SSH client.
remote := &SSHClient{
env: env + `export SUP_HOST="` + host + `";`,
env: env.AsExport() + `export SUP_HOST="` + host + `";`,
user: network.User,
color: Colors[i%len(Colors)],
}
Expand Down Expand Up @@ -112,7 +113,7 @@ func (sup *Stackup) Run(network *Network, envVars EnvList, commands ...*Command)
// Run command or run multiple commands defined by target sequentially.
for _, cmd := range commands {
// Translate command into task(s).
tasks, err := sup.createTasks(cmd, clients, env)
tasks, err := sup.createTasks(cmd, clients)
if err != nil {
return errors.Wrap(err, "creating task failed")
}
Expand Down
51 changes: 48 additions & 3 deletions supfile.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,14 +262,47 @@ func (e ErrUnsupportedSupfileVersion) Error() string {
return fmt.Sprintf("%v\n\nCheck your Supfile version (available latest version: v0.5)", e.Msg)
}

// SupfileOption modifies a Supfile in some way and maybe returns an error.
type SupfileOption func(*Supfile) error

// WithInheritEnv exports the environment variable env as value, val, in the context of a supfile.
// env will be read from the process runtime and the values here have precedence. This allows users
// to chain Supfile invocation and only have the top-level value set the SUP_* env vars.
func WithInheritEnv(env, val string) SupfileOption {
if envVal := os.Getenv(env); envVal != "" {
val = envVal
}
return WithEnv(env, val)
}

// WithEnv forces the environment variable env to val in the context of a Supfile.
func WithEnv(env, val string) SupfileOption {
return func(s *Supfile) error {
for network := range s.Networks.nets {
n := s.Networks.nets[network]
n.Env.Set(env, val)
s.Networks.nets[network] = n
}
s.Env.Set(env, val)
return nil
}
}

// NewSupfile parses configuration file and returns Supfile or error.
func NewSupfile(data []byte) (*Supfile, error) {
func NewSupfile(data []byte, opts ...SupfileOption) (*Supfile, error) {
var conf Supfile

if err := yaml.Unmarshal(data, &conf); err != nil {
return nil, err
}

for _, opt := range opts {
err := opt(&conf)
if err != nil {
return nil, err
}
}

// API backward compatibility. Will be deprecated in v1.0.
switch conf.Version {
case "":
Expand Down Expand Up @@ -327,16 +360,28 @@ func NewSupfile(data []byte) (*Supfile, error) {
return &conf, nil
}

func (s *Supfile) ResolveLocalPath(cwd, path string) (string, error) {
// Check if file exists first. Use bash to resolve $ENV_VARs.
cmd := exec.Command("bash", "-c", s.Env.AsExport()+" echo -n "+path)
cmd.Dir = cwd
resolvedFilename, err := cmd.Output()
if err != nil {
return "", errors.Wrap(err, "resolving path failed")
}

return string(resolvedFilename), nil
}

// ParseInventory runs the inventory command, if provided, and appends
// the command's output lines to the manually defined list of hosts.
func (n Network) ParseInventory() ([]string, error) {
func (n Network) ParseInventory(ctx EnvList) ([]string, error) {
if n.Inventory == "" {
return nil, nil
}

cmd := exec.Command("/bin/sh", "-c", n.Inventory)
cmd.Env = os.Environ()
cmd.Env = append(cmd.Env, n.Env.Slice()...)
cmd.Env = append(cmd.Env, append(ctx.Slice(), n.Env.Slice()...)...)
cmd.Stderr = os.Stderr
output, err := cmd.Output()
if err != nil {
Expand Down
12 changes: 8 additions & 4 deletions task.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type Task struct {
TTY bool
}

func (sup *Stackup) createTasks(cmd *Command, clients []Client, env string) ([]*Task, error) {
func (sup *Stackup) createTasks(cmd *Command, clients []Client) ([]*Task, error) {
var tasks []*Task

cwd, err := os.Getwd()
Expand All @@ -27,7 +27,7 @@ func (sup *Stackup) createTasks(cmd *Command, clients []Client, env string) ([]*

// Anything to upload?
for _, upload := range cmd.Upload {
uploadFile, err := ResolveLocalPath(cwd, upload.Src, env)
uploadFile, err := sup.conf.ResolveLocalPath(cwd, upload.Src)
if err != nil {
return nil, errors.Wrap(err, "upload: "+upload.Src)
}
Expand Down Expand Up @@ -64,7 +64,11 @@ func (sup *Stackup) createTasks(cmd *Command, clients []Client, env string) ([]*

// Script. Read the file as a multiline input command.
if cmd.Script != "" {
f, err := os.Open(cmd.Script)
script, err := sup.conf.ResolveLocalPath(cwd, cmd.Script)
if err != nil {
return nil, errors.Wrap(err, "can't resolve script path")
}
f, err := os.Open(script)
if err != nil {
return nil, errors.Wrap(err, "can't open script")
}
Expand Down Expand Up @@ -106,7 +110,7 @@ func (sup *Stackup) createTasks(cmd *Command, clients []Client, env string) ([]*
// Local command.
if cmd.Local != "" {
local := &LocalhostClient{
env: env + `export SUP_HOST="localhost";`,
env: sup.conf.Env.AsExport() + `export SUP_HOST="localhost";`,
}
local.Connect("localhost")
task := &Task{
Expand Down

0 comments on commit 0788a67

Please sign in to comment.