diff --git a/.gitignore b/.gitignore index cd08347..ab209e2 100644 --- a/.gitignore +++ b/.gitignore @@ -26,3 +26,4 @@ _testmain.go dockyard conf/runtime.conf log/* +data/* diff --git a/Godeps/Godeps.json b/Godeps/Godeps.json index c3c158f..52d3ee3 100644 --- a/Godeps/Godeps.json +++ b/Godeps/Godeps.json @@ -2,32 +2,97 @@ "ImportPath": "github.com/containerops/dockyard", "GoVersion": "go1.4.2", "Deps": [ + { + "ImportPath": "github.com/Sirupsen/logrus", + "Comment": "v0.8.6-1-g8bca266", + "Rev": "8bca2664072173a3c71db4c28ca8d304079b1787" + }, { "ImportPath": "github.com/Unknwon/com", - "Rev": "188d690b1aea28bf00ba76983aa285b1d322e3dd" + "Rev": "47d7d2b81a44157600669037e11e9ddfbf16745f" }, { "ImportPath": "github.com/Unknwon/macaron", - "Rev": "93de4f3fad97bf246b838f828e2348f46f21f20a" + "Rev": "635c89ac7410dd20df967e37e8f7f173e126cedd" }, { "ImportPath": "github.com/astaxie/beego/config", - "Comment": "v1.4.3-6-gaf71289", - "Rev": "af71289c25f64e35b688376d5115dbf8b93d87ab" + "Comment": "v1.5.0-3-ga89f14d", + "Rev": "a89f14d80dab442927bdc11d641aa1250ab8ad71" }, { "ImportPath": "github.com/astaxie/beego/logs", - "Comment": "v1.4.3-6-gaf71289", - "Rev": "af71289c25f64e35b688376d5115dbf8b93d87ab" + "Comment": "v1.5.0-3-ga89f14d", + "Rev": "a89f14d80dab442927bdc11d641aa1250ab8ad71" }, { "ImportPath": "github.com/codegangsta/cli", - "Comment": "1.2.0-107-g942282e", - "Rev": "942282e931e8286aa802a30b01fa7e16befb50f3" + "Comment": "1.2.0-137-gbca61c4", + "Rev": "bca61c476e3c752594983e4c9bcd5f62fb09f157" + }, + { + "ImportPath": "github.com/containerops/wrench/db", + "Rev": "ce091398c02a0711bd39fd0a8de2c1eb2fafc0c2" + }, + { + "ImportPath": "github.com/containerops/wrench/setting", + "Rev": "ce091398c02a0711bd39fd0a8de2c1eb2fafc0c2" + }, + { + "ImportPath": "github.com/containerops/wrench/utils", + "Rev": "ce091398c02a0711bd39fd0a8de2c1eb2fafc0c2" + }, + { + "ImportPath": "github.com/docker/docker/pkg/archive", + "Comment": "v1.4.1-5132-gc19a00d", + "Rev": "c19a00d4cbc52adc7da76f2f5f70bb38190c2271" + }, + { + "ImportPath": "github.com/docker/docker/pkg/fileutils", + "Comment": "v1.4.1-5132-gc19a00d", + "Rev": "c19a00d4cbc52adc7da76f2f5f70bb38190c2271" + }, + { + "ImportPath": "github.com/docker/docker/pkg/ioutils", + "Comment": "v1.4.1-5132-gc19a00d", + "Rev": "c19a00d4cbc52adc7da76f2f5f70bb38190c2271" + }, + { + "ImportPath": "github.com/docker/docker/pkg/pools", + "Comment": "v1.4.1-5132-gc19a00d", + "Rev": "c19a00d4cbc52adc7da76f2f5f70bb38190c2271" + }, + { + "ImportPath": "github.com/docker/docker/pkg/promise", + "Comment": "v1.4.1-5132-gc19a00d", + "Rev": "c19a00d4cbc52adc7da76f2f5f70bb38190c2271" + }, + { + "ImportPath": "github.com/docker/docker/pkg/stdcopy", + "Comment": "v1.4.1-5132-gc19a00d", + "Rev": "c19a00d4cbc52adc7da76f2f5f70bb38190c2271" + }, + { + "ImportPath": "github.com/docker/docker/pkg/system", + "Comment": "v1.4.1-5132-gc19a00d", + "Rev": "c19a00d4cbc52adc7da76f2f5f70bb38190c2271" + }, + { + "ImportPath": "github.com/docker/docker/pkg/units", + "Comment": "v1.4.1-5132-gc19a00d", + "Rev": "c19a00d4cbc52adc7da76f2f5f70bb38190c2271" + }, + { + "ImportPath": "github.com/docker/libtrust", + "Rev": "9cbd2a1374f46905c68a4eb3694a130610adc62a" + }, + { + "ImportPath": "github.com/satori/go.uuid", + "Rev": "6b8e5b55d20d01ad47ecfe98e5171688397c61e9" }, { - "ImportPath": "github.com/macaron-contrib/session", - "Rev": "31e841d95c7302b9ac456c830ea2d6dfcef4f84a" + "ImportPath": "gopkg.in/bsm/ratelimit.v1", + "Rev": "bda20d5067a03094fc6762f7ead53027afac5f28" }, { "ImportPath": "gopkg.in/bufio.v1", @@ -36,13 +101,13 @@ }, { "ImportPath": "gopkg.in/ini.v1", - "Comment": "v0-16-g1772191", - "Rev": "177219109c97e7920c933e21c9b25f874357b237" + "Comment": "v0-26-gcaf3f03", + "Rev": "caf3f03ad90a464ba8fc4ccfece68d72433f9861" }, { - "ImportPath": "gopkg.in/redis.v2", - "Comment": "v2.3.2", - "Rev": "e6179049628164864e6e84e973cfb56335748dea" + "ImportPath": "gopkg.in/redis.v3", + "Comment": "v3.2.3", + "Rev": "fc28d0fa245616be3aa97e162086e71c75c92f6b" } ] } diff --git a/Godeps/_workspace/src/github.com/Sirupsen/logrus/.gitignore b/Godeps/_workspace/src/github.com/Sirupsen/logrus/.gitignore new file mode 100644 index 0000000..66be63a --- /dev/null +++ b/Godeps/_workspace/src/github.com/Sirupsen/logrus/.gitignore @@ -0,0 +1 @@ +logrus diff --git a/Godeps/_workspace/src/github.com/Sirupsen/logrus/.travis.yml b/Godeps/_workspace/src/github.com/Sirupsen/logrus/.travis.yml new file mode 100644 index 0000000..2d8c086 --- /dev/null +++ b/Godeps/_workspace/src/github.com/Sirupsen/logrus/.travis.yml @@ -0,0 +1,8 @@ +language: go +go: + - 1.2 + - 1.3 + - 1.4 + - tip +install: + - go get -t ./... diff --git a/Godeps/_workspace/src/github.com/Sirupsen/logrus/CHANGELOG.md b/Godeps/_workspace/src/github.com/Sirupsen/logrus/CHANGELOG.md new file mode 100644 index 0000000..b1fe4b6 --- /dev/null +++ b/Godeps/_workspace/src/github.com/Sirupsen/logrus/CHANGELOG.md @@ -0,0 +1,41 @@ +# 0.8.6 + +* hooks/raven: allow passing an initialized client + +# 0.8.5 + +* logrus/core: revert #208 + +# 0.8.4 + +* formatter/text: fix data race (#218) + +# 0.8.3 + +* logrus/core: fix entry log level (#208) +* logrus/core: improve performance of text formatter by 40% +* logrus/core: expose `LevelHooks` type +* logrus/core: add support for DragonflyBSD and NetBSD +* formatter/text: print structs more verbosely + +# 0.8.2 + +* logrus: fix more Fatal family functions + +# 0.8.1 + +* logrus: fix not exiting on `Fatalf` and `Fatalln` + +# 0.8.0 + +* logrus: defaults to stderr instead of stdout +* hooks/sentry: add special field for `*http.Request` +* formatter/text: ignore Windows for colors + +# 0.7.3 + +* formatter/\*: allow configuration of timestamp layout + +# 0.7.2 + +* formatter/text: Add configuration option for time format (#158) diff --git a/Godeps/_workspace/src/github.com/Sirupsen/logrus/LICENSE b/Godeps/_workspace/src/github.com/Sirupsen/logrus/LICENSE new file mode 100644 index 0000000..f090cb4 --- /dev/null +++ b/Godeps/_workspace/src/github.com/Sirupsen/logrus/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2014 Simon Eskildsen + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/Godeps/_workspace/src/github.com/Sirupsen/logrus/README.md b/Godeps/_workspace/src/github.com/Sirupsen/logrus/README.md new file mode 100644 index 0000000..bd9ffb6 --- /dev/null +++ b/Godeps/_workspace/src/github.com/Sirupsen/logrus/README.md @@ -0,0 +1,356 @@ +# Logrus :walrus: [![Build Status](https://travis-ci.org/Sirupsen/logrus.svg?branch=master)](https://travis-ci.org/Sirupsen/logrus) [![godoc reference](https://godoc.org/github.com/Sirupsen/logrus?status.png)][godoc] + +Logrus is a structured logger for Go (golang), completely API compatible with +the standard library logger. [Godoc][godoc]. **Please note the Logrus API is not +yet stable (pre 1.0). Logrus itself is completely stable and has been used in +many large deployments. The core API is unlikely to change much but please +version control your Logrus to make sure you aren't fetching latest `master` on +every build.** + +Nicely color-coded in development (when a TTY is attached, otherwise just +plain text): + +![Colored](http://i.imgur.com/PY7qMwd.png) + +With `log.Formatter = new(logrus.JSONFormatter)`, for easy parsing by logstash +or Splunk: + +```json +{"animal":"walrus","level":"info","msg":"A group of walrus emerges from the +ocean","size":10,"time":"2014-03-10 19:57:38.562264131 -0400 EDT"} + +{"level":"warning","msg":"The group's number increased tremendously!", +"number":122,"omg":true,"time":"2014-03-10 19:57:38.562471297 -0400 EDT"} + +{"animal":"walrus","level":"info","msg":"A giant walrus appears!", +"size":10,"time":"2014-03-10 19:57:38.562500591 -0400 EDT"} + +{"animal":"walrus","level":"info","msg":"Tremendously sized cow enters the ocean.", +"size":9,"time":"2014-03-10 19:57:38.562527896 -0400 EDT"} + +{"level":"fatal","msg":"The ice breaks!","number":100,"omg":true, +"time":"2014-03-10 19:57:38.562543128 -0400 EDT"} +``` + +With the default `log.Formatter = new(&log.TextFormatter{})` when a TTY is not +attached, the output is compatible with the +[logfmt](http://godoc.org/github.com/kr/logfmt) format: + +```text +time="2015-03-26T01:27:38-04:00" level=debug msg="Started observing beach" animal=walrus number=8 +time="2015-03-26T01:27:38-04:00" level=info msg="A group of walrus emerges from the ocean" animal=walrus size=10 +time="2015-03-26T01:27:38-04:00" level=warning msg="The group's number increased tremendously!" number=122 omg=true +time="2015-03-26T01:27:38-04:00" level=debug msg="Temperature changes" temperature=-4 +time="2015-03-26T01:27:38-04:00" level=panic msg="It's over 9000!" animal=orca size=9009 +time="2015-03-26T01:27:38-04:00" level=fatal msg="The ice breaks!" err=&{0x2082280c0 map[animal:orca size:9009] 2015-03-26 01:27:38.441574009 -0400 EDT panic It's over 9000!} number=100 omg=true +exit status 1 +``` + +#### Example + +The simplest way to use Logrus is simply the package-level exported logger: + +```go +package main + +import ( + log "github.com/Sirupsen/logrus" +) + +func main() { + log.WithFields(log.Fields{ + "animal": "walrus", + }).Info("A walrus appears") +} +``` + +Note that it's completely api-compatible with the stdlib logger, so you can +replace your `log` imports everywhere with `log "github.com/Sirupsen/logrus"` +and you'll now have the flexibility of Logrus. You can customize it all you +want: + +```go +package main + +import ( + "os" + log "github.com/Sirupsen/logrus" + "github.com/Sirupsen/logrus/hooks/airbrake" +) + +func init() { + // Log as JSON instead of the default ASCII formatter. + log.SetFormatter(&log.JSONFormatter{}) + + // Use the Airbrake hook to report errors that have Error severity or above to + // an exception tracker. You can create custom hooks, see the Hooks section. + log.AddHook(airbrake.NewHook("https://example.com", "xyz", "development")) + + // Output to stderr instead of stdout, could also be a file. + log.SetOutput(os.Stderr) + + // Only log the warning severity or above. + log.SetLevel(log.WarnLevel) +} + +func main() { + log.WithFields(log.Fields{ + "animal": "walrus", + "size": 10, + }).Info("A group of walrus emerges from the ocean") + + log.WithFields(log.Fields{ + "omg": true, + "number": 122, + }).Warn("The group's number increased tremendously!") + + log.WithFields(log.Fields{ + "omg": true, + "number": 100, + }).Fatal("The ice breaks!") + + // A common pattern is to re-use fields between logging statements by re-using + // the logrus.Entry returned from WithFields() + contextLogger := log.WithFields(log.Fields{ + "common": "this is a common field", + "other": "I also should be logged always", + }) + + contextLogger.Info("I'll be logged with common and other field") + contextLogger.Info("Me too") +} +``` + +For more advanced usage such as logging to multiple locations from the same +application, you can also create an instance of the `logrus` Logger: + +```go +package main + +import ( + "github.com/Sirupsen/logrus" +) + +// Create a new instance of the logger. You can have any number of instances. +var log = logrus.New() + +func main() { + // The API for setting attributes is a little different than the package level + // exported logger. See Godoc. + log.Out = os.Stderr + + log.WithFields(logrus.Fields{ + "animal": "walrus", + "size": 10, + }).Info("A group of walrus emerges from the ocean") +} +``` + +#### Fields + +Logrus encourages careful, structured logging though logging fields instead of +long, unparseable error messages. For example, instead of: `log.Fatalf("Failed +to send event %s to topic %s with key %d")`, you should log the much more +discoverable: + +```go +log.WithFields(log.Fields{ + "event": event, + "topic": topic, + "key": key, +}).Fatal("Failed to send event") +``` + +We've found this API forces you to think about logging in a way that produces +much more useful logging messages. We've been in countless situations where just +a single added field to a log statement that was already there would've saved us +hours. The `WithFields` call is optional. + +In general, with Logrus using any of the `printf`-family functions should be +seen as a hint you should add a field, however, you can still use the +`printf`-family functions with Logrus. + +#### Hooks + +You can add hooks for logging levels. For example to send errors to an exception +tracking service on `Error`, `Fatal` and `Panic`, info to StatsD or log to +multiple places simultaneously, e.g. syslog. + +Logrus comes with [built-in hooks](hooks/). Add those, or your custom hook, in +`init`: + +```go +import ( + log "github.com/Sirupsen/logrus" + "github.com/Sirupsen/logrus/hooks/airbrake" + logrus_syslog "github.com/Sirupsen/logrus/hooks/syslog" + "log/syslog" +) + +func init() { + log.AddHook(airbrake.NewHook("https://example.com", "xyz", "development")) + + hook, err := logrus_syslog.NewSyslogHook("udp", "localhost:514", syslog.LOG_INFO, "") + if err != nil { + log.Error("Unable to connect to local syslog daemon") + } else { + log.AddHook(hook) + } +} +``` + + +| Hook | Description | +| ----- | ----------- | +| [Airbrake](https://github.com/Sirupsen/logrus/blob/master/hooks/airbrake/airbrake.go) | Send errors to an exception tracking service compatible with the Airbrake API. Uses [`airbrake-go`](https://github.com/tobi/airbrake-go) behind the scenes. | +| [Papertrail](https://github.com/Sirupsen/logrus/blob/master/hooks/papertrail/papertrail.go) | Send errors to the Papertrail hosted logging service via UDP. | +| [Syslog](https://github.com/Sirupsen/logrus/blob/master/hooks/syslog/syslog.go) | Send errors to remote syslog server. Uses standard library `log/syslog` behind the scenes. | +| [BugSnag](https://github.com/Sirupsen/logrus/blob/master/hooks/bugsnag/bugsnag.go) | Send errors to the Bugsnag exception tracking service. | +| [Sentry](https://github.com/Sirupsen/logrus/blob/master/hooks/sentry/sentry.go) | Send errors to the Sentry error logging and aggregation service. | +| [Hiprus](https://github.com/nubo/hiprus) | Send errors to a channel in hipchat. | +| [Logrusly](https://github.com/sebest/logrusly) | Send logs to [Loggly](https://www.loggly.com/) | +| [Slackrus](https://github.com/johntdyer/slackrus) | Hook for Slack chat. | +| [Journalhook](https://github.com/wercker/journalhook) | Hook for logging to `systemd-journald` | +| [Graylog](https://github.com/gemnasium/logrus-hooks/tree/master/graylog) | Hook for logging to [Graylog](http://graylog2.org/) | +| [Raygun](https://github.com/squirkle/logrus-raygun-hook) | Hook for logging to [Raygun.io](http://raygun.io/) | +| [LFShook](https://github.com/rifflock/lfshook) | Hook for logging to the local filesystem | +| [Honeybadger](https://github.com/agonzalezro/logrus_honeybadger) | Hook for sending exceptions to Honeybadger | +| [Mail](https://github.com/zbindenren/logrus_mail) | Hook for sending exceptions via mail | +| [Rollrus](https://github.com/heroku/rollrus) | Hook for sending errors to rollbar | +| [Fluentd](https://github.com/evalphobia/logrus_fluent) | Hook for logging to fluentd | + +#### Level logging + +Logrus has six logging levels: Debug, Info, Warning, Error, Fatal and Panic. + +```go +log.Debug("Useful debugging information.") +log.Info("Something noteworthy happened!") +log.Warn("You should probably take a look at this.") +log.Error("Something failed but I'm not quitting.") +// Calls os.Exit(1) after logging +log.Fatal("Bye.") +// Calls panic() after logging +log.Panic("I'm bailing.") +``` + +You can set the logging level on a `Logger`, then it will only log entries with +that severity or anything above it: + +```go +// Will log anything that is info or above (warn, error, fatal, panic). Default. +log.SetLevel(log.InfoLevel) +``` + +It may be useful to set `log.Level = logrus.DebugLevel` in a debug or verbose +environment if your application has that. + +#### Entries + +Besides the fields added with `WithField` or `WithFields` some fields are +automatically added to all logging events: + +1. `time`. The timestamp when the entry was created. +2. `msg`. The logging message passed to `{Info,Warn,Error,Fatal,Panic}` after + the `AddFields` call. E.g. `Failed to send event.` +3. `level`. The logging level. E.g. `info`. + +#### Environments + +Logrus has no notion of environment. + +If you wish for hooks and formatters to only be used in specific environments, +you should handle that yourself. For example, if your application has a global +variable `Environment`, which is a string representation of the environment you +could do: + +```go +import ( + log "github.com/Sirupsen/logrus" +) + +init() { + // do something here to set environment depending on an environment variable + // or command-line flag + if Environment == "production" { + log.SetFormatter(&logrus.JSONFormatter{}) + } else { + // The TextFormatter is default, you don't actually have to do this. + log.SetFormatter(&log.TextFormatter{}) + } +} +``` + +This configuration is how `logrus` was intended to be used, but JSON in +production is mostly only useful if you do log aggregation with tools like +Splunk or Logstash. + +#### Formatters + +The built-in logging formatters are: + +* `logrus.TextFormatter`. Logs the event in colors if stdout is a tty, otherwise + without colors. + * *Note:* to force colored output when there is no TTY, set the `ForceColors` + field to `true`. To force no colored output even if there is a TTY set the + `DisableColors` field to `true` +* `logrus.JSONFormatter`. Logs fields as JSON. +* `logrus_logstash.LogstashFormatter`. Logs fields as Logstash Events (http://logstash.net). + + ```go + logrus.SetFormatter(&logrus_logstash.LogstashFormatter{Type: “application_name"}) + ``` + +Third party logging formatters: + +* [`zalgo`](https://github.com/aybabtme/logzalgo): invoking the P͉̫o̳̼̊w̖͈̰͎e̬͔̭͂r͚̼̹̲ ̫͓͉̳͈ō̠͕͖̚f̝͍̠ ͕̲̞͖͑Z̖̫̤̫ͪa͉̬͈̗l͖͎g̳̥o̰̥̅!̣͔̲̻͊̄ ̙̘̦̹̦. + +You can define your formatter by implementing the `Formatter` interface, +requiring a `Format` method. `Format` takes an `*Entry`. `entry.Data` is a +`Fields` type (`map[string]interface{}`) with all your fields as well as the +default ones (see Entries section above): + +```go +type MyJSONFormatter struct { +} + +log.SetFormatter(new(MyJSONFormatter)) + +func (f *JSONFormatter) Format(entry *Entry) ([]byte, error) { + // Note this doesn't include Time, Level and Message which are available on + // the Entry. Consult `godoc` on information about those fields or read the + // source of the official loggers. + serialized, err := json.Marshal(entry.Data) + if err != nil { + return nil, fmt.Errorf("Failed to marshal fields to JSON, %v", err) + } + return append(serialized, '\n'), nil +} +``` + +#### Logger as an `io.Writer` + +Logrus can be transformed into an `io.Writer`. That writer is the end of an `io.Pipe` and it is your responsibility to close it. + +```go +w := logger.Writer() +defer w.Close() + +srv := http.Server{ + // create a stdlib log.Logger that writes to + // logrus.Logger. + ErrorLog: log.New(w, "", 0), +} +``` + +Each line written to that writer will be printed the usual way, using formatters +and hooks. The level for those entries is `info`. + +#### Rotation + +Log rotation is not provided with Logrus. Log rotation should be done by an +external program (like `logrotate(8)`) that can compress and delete old log +entries. It should not be a feature of the application-level logger. + + +[godoc]: https://godoc.org/github.com/Sirupsen/logrus diff --git a/Godeps/_workspace/src/github.com/Sirupsen/logrus/entry.go b/Godeps/_workspace/src/github.com/Sirupsen/logrus/entry.go new file mode 100644 index 0000000..699ea03 --- /dev/null +++ b/Godeps/_workspace/src/github.com/Sirupsen/logrus/entry.go @@ -0,0 +1,254 @@ +package logrus + +import ( + "bytes" + "fmt" + "io" + "os" + "time" +) + +// An entry is the final or intermediate Logrus logging entry. It contains all +// the fields passed with WithField{,s}. It's finally logged when Debug, Info, +// Warn, Error, Fatal or Panic is called on it. These objects can be reused and +// passed around as much as you wish to avoid field duplication. +type Entry struct { + Logger *Logger + + // Contains all the fields set by the user. + Data Fields + + // Time at which the log entry was created + Time time.Time + + // Level the log entry was logged at: Debug, Info, Warn, Error, Fatal or Panic + Level Level + + // Message passed to Debug, Info, Warn, Error, Fatal or Panic + Message string +} + +func NewEntry(logger *Logger) *Entry { + return &Entry{ + Logger: logger, + // Default is three fields, give a little extra room + Data: make(Fields, 5), + } +} + +// Returns a reader for the entry, which is a proxy to the formatter. +func (entry *Entry) Reader() (*bytes.Buffer, error) { + serialized, err := entry.Logger.Formatter.Format(entry) + return bytes.NewBuffer(serialized), err +} + +// Returns the string representation from the reader and ultimately the +// formatter. +func (entry *Entry) String() (string, error) { + reader, err := entry.Reader() + if err != nil { + return "", err + } + + return reader.String(), err +} + +// Add a single field to the Entry. +func (entry *Entry) WithField(key string, value interface{}) *Entry { + return entry.WithFields(Fields{key: value}) +} + +// Add a map of fields to the Entry. +func (entry *Entry) WithFields(fields Fields) *Entry { + data := Fields{} + for k, v := range entry.Data { + data[k] = v + } + for k, v := range fields { + data[k] = v + } + return &Entry{Logger: entry.Logger, Data: data} +} + +func (entry *Entry) log(level Level, msg string) { + entry.Time = time.Now() + entry.Level = level + entry.Message = msg + + if err := entry.Logger.Hooks.Fire(level, entry); err != nil { + entry.Logger.mu.Lock() + fmt.Fprintf(os.Stderr, "Failed to fire hook: %v\n", err) + entry.Logger.mu.Unlock() + } + + reader, err := entry.Reader() + if err != nil { + entry.Logger.mu.Lock() + fmt.Fprintf(os.Stderr, "Failed to obtain reader, %v\n", err) + entry.Logger.mu.Unlock() + } + + entry.Logger.mu.Lock() + defer entry.Logger.mu.Unlock() + + _, err = io.Copy(entry.Logger.Out, reader) + if err != nil { + fmt.Fprintf(os.Stderr, "Failed to write to log, %v\n", err) + } + + // To avoid Entry#log() returning a value that only would make sense for + // panic() to use in Entry#Panic(), we avoid the allocation by checking + // directly here. + if level <= PanicLevel { + panic(entry) + } +} + +func (entry *Entry) Debug(args ...interface{}) { + if entry.Logger.Level >= DebugLevel { + entry.log(DebugLevel, fmt.Sprint(args...)) + } +} + +func (entry *Entry) Print(args ...interface{}) { + entry.Info(args...) +} + +func (entry *Entry) Info(args ...interface{}) { + if entry.Logger.Level >= InfoLevel { + entry.log(InfoLevel, fmt.Sprint(args...)) + } +} + +func (entry *Entry) Warn(args ...interface{}) { + if entry.Logger.Level >= WarnLevel { + entry.log(WarnLevel, fmt.Sprint(args...)) + } +} + +func (entry *Entry) Warning(args ...interface{}) { + entry.Warn(args...) +} + +func (entry *Entry) Error(args ...interface{}) { + if entry.Logger.Level >= ErrorLevel { + entry.log(ErrorLevel, fmt.Sprint(args...)) + } +} + +func (entry *Entry) Fatal(args ...interface{}) { + if entry.Logger.Level >= FatalLevel { + entry.log(FatalLevel, fmt.Sprint(args...)) + } + os.Exit(1) +} + +func (entry *Entry) Panic(args ...interface{}) { + if entry.Logger.Level >= PanicLevel { + entry.log(PanicLevel, fmt.Sprint(args...)) + } + panic(fmt.Sprint(args...)) +} + +// Entry Printf family functions + +func (entry *Entry) Debugf(format string, args ...interface{}) { + if entry.Logger.Level >= DebugLevel { + entry.Debug(fmt.Sprintf(format, args...)) + } +} + +func (entry *Entry) Infof(format string, args ...interface{}) { + if entry.Logger.Level >= InfoLevel { + entry.Info(fmt.Sprintf(format, args...)) + } +} + +func (entry *Entry) Printf(format string, args ...interface{}) { + entry.Infof(format, args...) +} + +func (entry *Entry) Warnf(format string, args ...interface{}) { + if entry.Logger.Level >= WarnLevel { + entry.Warn(fmt.Sprintf(format, args...)) + } +} + +func (entry *Entry) Warningf(format string, args ...interface{}) { + entry.Warnf(format, args...) +} + +func (entry *Entry) Errorf(format string, args ...interface{}) { + if entry.Logger.Level >= ErrorLevel { + entry.Error(fmt.Sprintf(format, args...)) + } +} + +func (entry *Entry) Fatalf(format string, args ...interface{}) { + if entry.Logger.Level >= FatalLevel { + entry.Fatal(fmt.Sprintf(format, args...)) + } + os.Exit(1) +} + +func (entry *Entry) Panicf(format string, args ...interface{}) { + if entry.Logger.Level >= PanicLevel { + entry.Panic(fmt.Sprintf(format, args...)) + } +} + +// Entry Println family functions + +func (entry *Entry) Debugln(args ...interface{}) { + if entry.Logger.Level >= DebugLevel { + entry.Debug(entry.sprintlnn(args...)) + } +} + +func (entry *Entry) Infoln(args ...interface{}) { + if entry.Logger.Level >= InfoLevel { + entry.Info(entry.sprintlnn(args...)) + } +} + +func (entry *Entry) Println(args ...interface{}) { + entry.Infoln(args...) +} + +func (entry *Entry) Warnln(args ...interface{}) { + if entry.Logger.Level >= WarnLevel { + entry.Warn(entry.sprintlnn(args...)) + } +} + +func (entry *Entry) Warningln(args ...interface{}) { + entry.Warnln(args...) +} + +func (entry *Entry) Errorln(args ...interface{}) { + if entry.Logger.Level >= ErrorLevel { + entry.Error(entry.sprintlnn(args...)) + } +} + +func (entry *Entry) Fatalln(args ...interface{}) { + if entry.Logger.Level >= FatalLevel { + entry.Fatal(entry.sprintlnn(args...)) + } + os.Exit(1) +} + +func (entry *Entry) Panicln(args ...interface{}) { + if entry.Logger.Level >= PanicLevel { + entry.Panic(entry.sprintlnn(args...)) + } +} + +// Sprintlnn => Sprint no newline. This is to get the behavior of how +// fmt.Sprintln where spaces are always added between operands, regardless of +// their type. Instead of vendoring the Sprintln implementation to spare a +// string allocation, we do the simplest thing. +func (entry *Entry) sprintlnn(args ...interface{}) string { + msg := fmt.Sprintln(args...) + return msg[:len(msg)-1] +} diff --git a/Godeps/_workspace/src/github.com/Sirupsen/logrus/entry_test.go b/Godeps/_workspace/src/github.com/Sirupsen/logrus/entry_test.go new file mode 100644 index 0000000..98717df --- /dev/null +++ b/Godeps/_workspace/src/github.com/Sirupsen/logrus/entry_test.go @@ -0,0 +1,53 @@ +package logrus + +import ( + "bytes" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestEntryPanicln(t *testing.T) { + errBoom := fmt.Errorf("boom time") + + defer func() { + p := recover() + assert.NotNil(t, p) + + switch pVal := p.(type) { + case *Entry: + assert.Equal(t, "kaboom", pVal.Message) + assert.Equal(t, errBoom, pVal.Data["err"]) + default: + t.Fatalf("want type *Entry, got %T: %#v", pVal, pVal) + } + }() + + logger := New() + logger.Out = &bytes.Buffer{} + entry := NewEntry(logger) + entry.WithField("err", errBoom).Panicln("kaboom") +} + +func TestEntryPanicf(t *testing.T) { + errBoom := fmt.Errorf("boom again") + + defer func() { + p := recover() + assert.NotNil(t, p) + + switch pVal := p.(type) { + case *Entry: + assert.Equal(t, "kaboom true", pVal.Message) + assert.Equal(t, errBoom, pVal.Data["err"]) + default: + t.Fatalf("want type *Entry, got %T: %#v", pVal, pVal) + } + }() + + logger := New() + logger.Out = &bytes.Buffer{} + entry := NewEntry(logger) + entry.WithField("err", errBoom).Panicf("kaboom %v", true) +} diff --git a/Godeps/_workspace/src/github.com/Sirupsen/logrus/examples/basic/basic.go b/Godeps/_workspace/src/github.com/Sirupsen/logrus/examples/basic/basic.go new file mode 100644 index 0000000..a1623ec --- /dev/null +++ b/Godeps/_workspace/src/github.com/Sirupsen/logrus/examples/basic/basic.go @@ -0,0 +1,50 @@ +package main + +import ( + "github.com/Sirupsen/logrus" +) + +var log = logrus.New() + +func init() { + log.Formatter = new(logrus.JSONFormatter) + log.Formatter = new(logrus.TextFormatter) // default + log.Level = logrus.DebugLevel +} + +func main() { + defer func() { + err := recover() + if err != nil { + log.WithFields(logrus.Fields{ + "omg": true, + "err": err, + "number": 100, + }).Fatal("The ice breaks!") + } + }() + + log.WithFields(logrus.Fields{ + "animal": "walrus", + "number": 8, + }).Debug("Started observing beach") + + log.WithFields(logrus.Fields{ + "animal": "walrus", + "size": 10, + }).Info("A group of walrus emerges from the ocean") + + log.WithFields(logrus.Fields{ + "omg": true, + "number": 122, + }).Warn("The group's number increased tremendously!") + + log.WithFields(logrus.Fields{ + "temperature": -4, + }).Debug("Temperature changes") + + log.WithFields(logrus.Fields{ + "animal": "orca", + "size": 9009, + }).Panic("It's over 9000!") +} diff --git a/Godeps/_workspace/src/github.com/Sirupsen/logrus/examples/hook/hook.go b/Godeps/_workspace/src/github.com/Sirupsen/logrus/examples/hook/hook.go new file mode 100644 index 0000000..cb5759a --- /dev/null +++ b/Godeps/_workspace/src/github.com/Sirupsen/logrus/examples/hook/hook.go @@ -0,0 +1,30 @@ +package main + +import ( + "github.com/Sirupsen/logrus" + "github.com/Sirupsen/logrus/hooks/airbrake" +) + +var log = logrus.New() + +func init() { + log.Formatter = new(logrus.TextFormatter) // default + log.Hooks.Add(airbrake.NewHook("https://example.com", "xyz", "development")) +} + +func main() { + log.WithFields(logrus.Fields{ + "animal": "walrus", + "size": 10, + }).Info("A group of walrus emerges from the ocean") + + log.WithFields(logrus.Fields{ + "omg": true, + "number": 122, + }).Warn("The group's number increased tremendously!") + + log.WithFields(logrus.Fields{ + "omg": true, + "number": 100, + }).Fatal("The ice breaks!") +} diff --git a/Godeps/_workspace/src/github.com/Sirupsen/logrus/exported.go b/Godeps/_workspace/src/github.com/Sirupsen/logrus/exported.go new file mode 100644 index 0000000..a67e1b8 --- /dev/null +++ b/Godeps/_workspace/src/github.com/Sirupsen/logrus/exported.go @@ -0,0 +1,188 @@ +package logrus + +import ( + "io" +) + +var ( + // std is the name of the standard logger in stdlib `log` + std = New() +) + +func StandardLogger() *Logger { + return std +} + +// SetOutput sets the standard logger output. +func SetOutput(out io.Writer) { + std.mu.Lock() + defer std.mu.Unlock() + std.Out = out +} + +// SetFormatter sets the standard logger formatter. +func SetFormatter(formatter Formatter) { + std.mu.Lock() + defer std.mu.Unlock() + std.Formatter = formatter +} + +// SetLevel sets the standard logger level. +func SetLevel(level Level) { + std.mu.Lock() + defer std.mu.Unlock() + std.Level = level +} + +// GetLevel returns the standard logger level. +func GetLevel() Level { + std.mu.Lock() + defer std.mu.Unlock() + return std.Level +} + +// AddHook adds a hook to the standard logger hooks. +func AddHook(hook Hook) { + std.mu.Lock() + defer std.mu.Unlock() + std.Hooks.Add(hook) +} + +// WithField creates an entry from the standard logger and adds a field to +// it. If you want multiple fields, use `WithFields`. +// +// Note that it doesn't log until you call Debug, Print, Info, Warn, Fatal +// or Panic on the Entry it returns. +func WithField(key string, value interface{}) *Entry { + return std.WithField(key, value) +} + +// WithFields creates an entry from the standard logger and adds multiple +// fields to it. This is simply a helper for `WithField`, invoking it +// once for each field. +// +// Note that it doesn't log until you call Debug, Print, Info, Warn, Fatal +// or Panic on the Entry it returns. +func WithFields(fields Fields) *Entry { + return std.WithFields(fields) +} + +// Debug logs a message at level Debug on the standard logger. +func Debug(args ...interface{}) { + std.Debug(args...) +} + +// Print logs a message at level Info on the standard logger. +func Print(args ...interface{}) { + std.Print(args...) +} + +// Info logs a message at level Info on the standard logger. +func Info(args ...interface{}) { + std.Info(args...) +} + +// Warn logs a message at level Warn on the standard logger. +func Warn(args ...interface{}) { + std.Warn(args...) +} + +// Warning logs a message at level Warn on the standard logger. +func Warning(args ...interface{}) { + std.Warning(args...) +} + +// Error logs a message at level Error on the standard logger. +func Error(args ...interface{}) { + std.Error(args...) +} + +// Panic logs a message at level Panic on the standard logger. +func Panic(args ...interface{}) { + std.Panic(args...) +} + +// Fatal logs a message at level Fatal on the standard logger. +func Fatal(args ...interface{}) { + std.Fatal(args...) +} + +// Debugf logs a message at level Debug on the standard logger. +func Debugf(format string, args ...interface{}) { + std.Debugf(format, args...) +} + +// Printf logs a message at level Info on the standard logger. +func Printf(format string, args ...interface{}) { + std.Printf(format, args...) +} + +// Infof logs a message at level Info on the standard logger. +func Infof(format string, args ...interface{}) { + std.Infof(format, args...) +} + +// Warnf logs a message at level Warn on the standard logger. +func Warnf(format string, args ...interface{}) { + std.Warnf(format, args...) +} + +// Warningf logs a message at level Warn on the standard logger. +func Warningf(format string, args ...interface{}) { + std.Warningf(format, args...) +} + +// Errorf logs a message at level Error on the standard logger. +func Errorf(format string, args ...interface{}) { + std.Errorf(format, args...) +} + +// Panicf logs a message at level Panic on the standard logger. +func Panicf(format string, args ...interface{}) { + std.Panicf(format, args...) +} + +// Fatalf logs a message at level Fatal on the standard logger. +func Fatalf(format string, args ...interface{}) { + std.Fatalf(format, args...) +} + +// Debugln logs a message at level Debug on the standard logger. +func Debugln(args ...interface{}) { + std.Debugln(args...) +} + +// Println logs a message at level Info on the standard logger. +func Println(args ...interface{}) { + std.Println(args...) +} + +// Infoln logs a message at level Info on the standard logger. +func Infoln(args ...interface{}) { + std.Infoln(args...) +} + +// Warnln logs a message at level Warn on the standard logger. +func Warnln(args ...interface{}) { + std.Warnln(args...) +} + +// Warningln logs a message at level Warn on the standard logger. +func Warningln(args ...interface{}) { + std.Warningln(args...) +} + +// Errorln logs a message at level Error on the standard logger. +func Errorln(args ...interface{}) { + std.Errorln(args...) +} + +// Panicln logs a message at level Panic on the standard logger. +func Panicln(args ...interface{}) { + std.Panicln(args...) +} + +// Fatalln logs a message at level Fatal on the standard logger. +func Fatalln(args ...interface{}) { + std.Fatalln(args...) +} diff --git a/Godeps/_workspace/src/github.com/Sirupsen/logrus/formatter.go b/Godeps/_workspace/src/github.com/Sirupsen/logrus/formatter.go new file mode 100644 index 0000000..104d689 --- /dev/null +++ b/Godeps/_workspace/src/github.com/Sirupsen/logrus/formatter.go @@ -0,0 +1,48 @@ +package logrus + +import "time" + +const DefaultTimestampFormat = time.RFC3339 + +// The Formatter interface is used to implement a custom Formatter. It takes an +// `Entry`. It exposes all the fields, including the default ones: +// +// * `entry.Data["msg"]`. The message passed from Info, Warn, Error .. +// * `entry.Data["time"]`. The timestamp. +// * `entry.Data["level"]. The level the entry was logged at. +// +// Any additional fields added with `WithField` or `WithFields` are also in +// `entry.Data`. Format is expected to return an array of bytes which are then +// logged to `logger.Out`. +type Formatter interface { + Format(*Entry) ([]byte, error) +} + +// This is to not silently overwrite `time`, `msg` and `level` fields when +// dumping it. If this code wasn't there doing: +// +// logrus.WithField("level", 1).Info("hello") +// +// Would just silently drop the user provided level. Instead with this code +// it'll logged as: +// +// {"level": "info", "fields.level": 1, "msg": "hello", "time": "..."} +// +// It's not exported because it's still using Data in an opinionated way. It's to +// avoid code duplication between the two default formatters. +func prefixFieldClashes(data Fields) { + _, ok := data["time"] + if ok { + data["fields.time"] = data["time"] + } + + _, ok = data["msg"] + if ok { + data["fields.msg"] = data["msg"] + } + + _, ok = data["level"] + if ok { + data["fields.level"] = data["level"] + } +} diff --git a/Godeps/_workspace/src/github.com/Sirupsen/logrus/formatter_bench_test.go b/Godeps/_workspace/src/github.com/Sirupsen/logrus/formatter_bench_test.go new file mode 100644 index 0000000..c6d290c --- /dev/null +++ b/Godeps/_workspace/src/github.com/Sirupsen/logrus/formatter_bench_test.go @@ -0,0 +1,98 @@ +package logrus + +import ( + "fmt" + "testing" + "time" +) + +// smallFields is a small size data set for benchmarking +var smallFields = Fields{ + "foo": "bar", + "baz": "qux", + "one": "two", + "three": "four", +} + +// largeFields is a large size data set for benchmarking +var largeFields = Fields{ + "foo": "bar", + "baz": "qux", + "one": "two", + "three": "four", + "five": "six", + "seven": "eight", + "nine": "ten", + "eleven": "twelve", + "thirteen": "fourteen", + "fifteen": "sixteen", + "seventeen": "eighteen", + "nineteen": "twenty", + "a": "b", + "c": "d", + "e": "f", + "g": "h", + "i": "j", + "k": "l", + "m": "n", + "o": "p", + "q": "r", + "s": "t", + "u": "v", + "w": "x", + "y": "z", + "this": "will", + "make": "thirty", + "entries": "yeah", +} + +var errorFields = Fields{ + "foo": fmt.Errorf("bar"), + "baz": fmt.Errorf("qux"), +} + +func BenchmarkErrorTextFormatter(b *testing.B) { + doBenchmark(b, &TextFormatter{DisableColors: true}, errorFields) +} + +func BenchmarkSmallTextFormatter(b *testing.B) { + doBenchmark(b, &TextFormatter{DisableColors: true}, smallFields) +} + +func BenchmarkLargeTextFormatter(b *testing.B) { + doBenchmark(b, &TextFormatter{DisableColors: true}, largeFields) +} + +func BenchmarkSmallColoredTextFormatter(b *testing.B) { + doBenchmark(b, &TextFormatter{ForceColors: true}, smallFields) +} + +func BenchmarkLargeColoredTextFormatter(b *testing.B) { + doBenchmark(b, &TextFormatter{ForceColors: true}, largeFields) +} + +func BenchmarkSmallJSONFormatter(b *testing.B) { + doBenchmark(b, &JSONFormatter{}, smallFields) +} + +func BenchmarkLargeJSONFormatter(b *testing.B) { + doBenchmark(b, &JSONFormatter{}, largeFields) +} + +func doBenchmark(b *testing.B, formatter Formatter, fields Fields) { + entry := &Entry{ + Time: time.Time{}, + Level: InfoLevel, + Message: "message", + Data: fields, + } + var d []byte + var err error + for i := 0; i < b.N; i++ { + d, err = formatter.Format(entry) + if err != nil { + b.Fatal(err) + } + b.SetBytes(int64(len(d))) + } +} diff --git a/Godeps/_workspace/src/github.com/Sirupsen/logrus/formatters/logstash/logstash.go b/Godeps/_workspace/src/github.com/Sirupsen/logrus/formatters/logstash/logstash.go new file mode 100644 index 0000000..8ea93dd --- /dev/null +++ b/Godeps/_workspace/src/github.com/Sirupsen/logrus/formatters/logstash/logstash.go @@ -0,0 +1,56 @@ +package logstash + +import ( + "encoding/json" + "fmt" + + "github.com/Sirupsen/logrus" +) + +// Formatter generates json in logstash format. +// Logstash site: http://logstash.net/ +type LogstashFormatter struct { + Type string // if not empty use for logstash type field. + + // TimestampFormat sets the format used for timestamps. + TimestampFormat string +} + +func (f *LogstashFormatter) Format(entry *logrus.Entry) ([]byte, error) { + entry.Data["@version"] = 1 + + if f.TimestampFormat == "" { + f.TimestampFormat = logrus.DefaultTimestampFormat + } + + entry.Data["@timestamp"] = entry.Time.Format(f.TimestampFormat) + + // set message field + v, ok := entry.Data["message"] + if ok { + entry.Data["fields.message"] = v + } + entry.Data["message"] = entry.Message + + // set level field + v, ok = entry.Data["level"] + if ok { + entry.Data["fields.level"] = v + } + entry.Data["level"] = entry.Level.String() + + // set type field + if f.Type != "" { + v, ok = entry.Data["type"] + if ok { + entry.Data["fields.type"] = v + } + entry.Data["type"] = f.Type + } + + serialized, err := json.Marshal(entry.Data) + if err != nil { + return nil, fmt.Errorf("Failed to marshal fields to JSON, %v", err) + } + return append(serialized, '\n'), nil +} diff --git a/Godeps/_workspace/src/github.com/Sirupsen/logrus/formatters/logstash/logstash_test.go b/Godeps/_workspace/src/github.com/Sirupsen/logrus/formatters/logstash/logstash_test.go new file mode 100644 index 0000000..d8814a0 --- /dev/null +++ b/Godeps/_workspace/src/github.com/Sirupsen/logrus/formatters/logstash/logstash_test.go @@ -0,0 +1,52 @@ +package logstash + +import ( + "bytes" + "encoding/json" + "github.com/Sirupsen/logrus" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestLogstashFormatter(t *testing.T) { + assert := assert.New(t) + + lf := LogstashFormatter{Type: "abc"} + + fields := logrus.Fields{ + "message": "def", + "level": "ijk", + "type": "lmn", + "one": 1, + "pi": 3.14, + "bool": true, + } + + entry := logrus.WithFields(fields) + entry.Message = "msg" + entry.Level = logrus.InfoLevel + + b, _ := lf.Format(entry) + + var data map[string]interface{} + dec := json.NewDecoder(bytes.NewReader(b)) + dec.UseNumber() + dec.Decode(&data) + + // base fields + assert.Equal(json.Number("1"), data["@version"]) + assert.NotEmpty(data["@timestamp"]) + assert.Equal("abc", data["type"]) + assert.Equal("msg", data["message"]) + assert.Equal("info", data["level"]) + + // substituted fields + assert.Equal("def", data["fields.message"]) + assert.Equal("ijk", data["fields.level"]) + assert.Equal("lmn", data["fields.type"]) + + // formats + assert.Equal(json.Number("1"), data["one"]) + assert.Equal(json.Number("3.14"), data["pi"]) + assert.Equal(true, data["bool"]) +} diff --git a/Godeps/_workspace/src/github.com/Sirupsen/logrus/hook_test.go b/Godeps/_workspace/src/github.com/Sirupsen/logrus/hook_test.go new file mode 100644 index 0000000..13f34cb --- /dev/null +++ b/Godeps/_workspace/src/github.com/Sirupsen/logrus/hook_test.go @@ -0,0 +1,122 @@ +package logrus + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +type TestHook struct { + Fired bool +} + +func (hook *TestHook) Fire(entry *Entry) error { + hook.Fired = true + return nil +} + +func (hook *TestHook) Levels() []Level { + return []Level{ + DebugLevel, + InfoLevel, + WarnLevel, + ErrorLevel, + FatalLevel, + PanicLevel, + } +} + +func TestHookFires(t *testing.T) { + hook := new(TestHook) + + LogAndAssertJSON(t, func(log *Logger) { + log.Hooks.Add(hook) + assert.Equal(t, hook.Fired, false) + + log.Print("test") + }, func(fields Fields) { + assert.Equal(t, hook.Fired, true) + }) +} + +type ModifyHook struct { +} + +func (hook *ModifyHook) Fire(entry *Entry) error { + entry.Data["wow"] = "whale" + return nil +} + +func (hook *ModifyHook) Levels() []Level { + return []Level{ + DebugLevel, + InfoLevel, + WarnLevel, + ErrorLevel, + FatalLevel, + PanicLevel, + } +} + +func TestHookCanModifyEntry(t *testing.T) { + hook := new(ModifyHook) + + LogAndAssertJSON(t, func(log *Logger) { + log.Hooks.Add(hook) + log.WithField("wow", "elephant").Print("test") + }, func(fields Fields) { + assert.Equal(t, fields["wow"], "whale") + }) +} + +func TestCanFireMultipleHooks(t *testing.T) { + hook1 := new(ModifyHook) + hook2 := new(TestHook) + + LogAndAssertJSON(t, func(log *Logger) { + log.Hooks.Add(hook1) + log.Hooks.Add(hook2) + + log.WithField("wow", "elephant").Print("test") + }, func(fields Fields) { + assert.Equal(t, fields["wow"], "whale") + assert.Equal(t, hook2.Fired, true) + }) +} + +type ErrorHook struct { + Fired bool +} + +func (hook *ErrorHook) Fire(entry *Entry) error { + hook.Fired = true + return nil +} + +func (hook *ErrorHook) Levels() []Level { + return []Level{ + ErrorLevel, + } +} + +func TestErrorHookShouldntFireOnInfo(t *testing.T) { + hook := new(ErrorHook) + + LogAndAssertJSON(t, func(log *Logger) { + log.Hooks.Add(hook) + log.Info("test") + }, func(fields Fields) { + assert.Equal(t, hook.Fired, false) + }) +} + +func TestErrorHookShouldFireOnError(t *testing.T) { + hook := new(ErrorHook) + + LogAndAssertJSON(t, func(log *Logger) { + log.Hooks.Add(hook) + log.Error("test") + }, func(fields Fields) { + assert.Equal(t, hook.Fired, true) + }) +} diff --git a/Godeps/_workspace/src/github.com/Sirupsen/logrus/hooks.go b/Godeps/_workspace/src/github.com/Sirupsen/logrus/hooks.go new file mode 100644 index 0000000..3f151cd --- /dev/null +++ b/Godeps/_workspace/src/github.com/Sirupsen/logrus/hooks.go @@ -0,0 +1,34 @@ +package logrus + +// A hook to be fired when logging on the logging levels returned from +// `Levels()` on your implementation of the interface. Note that this is not +// fired in a goroutine or a channel with workers, you should handle such +// functionality yourself if your call is non-blocking and you don't wish for +// the logging calls for levels returned from `Levels()` to block. +type Hook interface { + Levels() []Level + Fire(*Entry) error +} + +// Internal type for storing the hooks on a logger instance. +type LevelHooks map[Level][]Hook + +// Add a hook to an instance of logger. This is called with +// `log.Hooks.Add(new(MyHook))` where `MyHook` implements the `Hook` interface. +func (hooks LevelHooks) Add(hook Hook) { + for _, level := range hook.Levels() { + hooks[level] = append(hooks[level], hook) + } +} + +// Fire all the hooks for the passed level. Used by `entry.log` to fire +// appropriate hooks for a log entry. +func (hooks LevelHooks) Fire(level Level, entry *Entry) error { + for _, hook := range hooks[level] { + if err := hook.Fire(entry); err != nil { + return err + } + } + + return nil +} diff --git a/Godeps/_workspace/src/github.com/Sirupsen/logrus/hooks/airbrake/airbrake.go b/Godeps/_workspace/src/github.com/Sirupsen/logrus/hooks/airbrake/airbrake.go new file mode 100644 index 0000000..b0502c3 --- /dev/null +++ b/Godeps/_workspace/src/github.com/Sirupsen/logrus/hooks/airbrake/airbrake.go @@ -0,0 +1,54 @@ +package airbrake + +import ( + "errors" + "fmt" + + "github.com/Sirupsen/logrus" + "github.com/tobi/airbrake-go" +) + +// AirbrakeHook to send exceptions to an exception-tracking service compatible +// with the Airbrake API. +type airbrakeHook struct { + APIKey string + Endpoint string + Environment string +} + +func NewHook(endpoint, apiKey, env string) *airbrakeHook { + return &airbrakeHook{ + APIKey: apiKey, + Endpoint: endpoint, + Environment: env, + } +} + +func (hook *airbrakeHook) Fire(entry *logrus.Entry) error { + airbrake.ApiKey = hook.APIKey + airbrake.Endpoint = hook.Endpoint + airbrake.Environment = hook.Environment + + var notifyErr error + err, ok := entry.Data["error"].(error) + if ok { + notifyErr = err + } else { + notifyErr = errors.New(entry.Message) + } + + airErr := airbrake.Notify(notifyErr) + if airErr != nil { + return fmt.Errorf("Failed to send error to Airbrake: %s", airErr) + } + + return nil +} + +func (hook *airbrakeHook) Levels() []logrus.Level { + return []logrus.Level{ + logrus.ErrorLevel, + logrus.FatalLevel, + logrus.PanicLevel, + } +} diff --git a/Godeps/_workspace/src/github.com/Sirupsen/logrus/hooks/airbrake/airbrake_test.go b/Godeps/_workspace/src/github.com/Sirupsen/logrus/hooks/airbrake/airbrake_test.go new file mode 100644 index 0000000..058a91e --- /dev/null +++ b/Godeps/_workspace/src/github.com/Sirupsen/logrus/hooks/airbrake/airbrake_test.go @@ -0,0 +1,133 @@ +package airbrake + +import ( + "encoding/xml" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Sirupsen/logrus" +) + +type notice struct { + Error NoticeError `xml:"error"` +} +type NoticeError struct { + Class string `xml:"class"` + Message string `xml:"message"` +} + +type customErr struct { + msg string +} + +func (e *customErr) Error() string { + return e.msg +} + +const ( + testAPIKey = "abcxyz" + testEnv = "development" + expectedClass = "*airbrake.customErr" + expectedMsg = "foo" + unintendedMsg = "Airbrake will not see this string" +) + +var ( + noticeError = make(chan NoticeError, 1) +) + +// TestLogEntryMessageReceived checks if invoking Logrus' log.Error +// method causes an XML payload containing the log entry message is received +// by a HTTP server emulating an Airbrake-compatible endpoint. +func TestLogEntryMessageReceived(t *testing.T) { + log := logrus.New() + ts := startAirbrakeServer(t) + defer ts.Close() + + hook := NewHook(ts.URL, testAPIKey, "production") + log.Hooks.Add(hook) + + log.Error(expectedMsg) + + select { + case received := <-noticeError: + if received.Message != expectedMsg { + t.Errorf("Unexpected message received: %s", received.Message) + } + case <-time.After(time.Second): + t.Error("Timed out; no notice received by Airbrake API") + } +} + +// TestLogEntryMessageReceived confirms that, when passing an error type using +// logrus.Fields, a HTTP server emulating an Airbrake endpoint receives the +// error message returned by the Error() method on the error interface +// rather than the logrus.Entry.Message string. +func TestLogEntryWithErrorReceived(t *testing.T) { + log := logrus.New() + ts := startAirbrakeServer(t) + defer ts.Close() + + hook := NewHook(ts.URL, testAPIKey, "production") + log.Hooks.Add(hook) + + log.WithFields(logrus.Fields{ + "error": &customErr{expectedMsg}, + }).Error(unintendedMsg) + + select { + case received := <-noticeError: + if received.Message != expectedMsg { + t.Errorf("Unexpected message received: %s", received.Message) + } + if received.Class != expectedClass { + t.Errorf("Unexpected error class: %s", received.Class) + } + case <-time.After(time.Second): + t.Error("Timed out; no notice received by Airbrake API") + } +} + +// TestLogEntryWithNonErrorTypeNotReceived confirms that, when passing a +// non-error type using logrus.Fields, a HTTP server emulating an Airbrake +// endpoint receives the logrus.Entry.Message string. +// +// Only error types are supported when setting the 'error' field using +// logrus.WithFields(). +func TestLogEntryWithNonErrorTypeNotReceived(t *testing.T) { + log := logrus.New() + ts := startAirbrakeServer(t) + defer ts.Close() + + hook := NewHook(ts.URL, testAPIKey, "production") + log.Hooks.Add(hook) + + log.WithFields(logrus.Fields{ + "error": expectedMsg, + }).Error(unintendedMsg) + + select { + case received := <-noticeError: + if received.Message != unintendedMsg { + t.Errorf("Unexpected message received: %s", received.Message) + } + case <-time.After(time.Second): + t.Error("Timed out; no notice received by Airbrake API") + } +} + +func startAirbrakeServer(t *testing.T) *httptest.Server { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var notice notice + if err := xml.NewDecoder(r.Body).Decode(¬ice); err != nil { + t.Error(err) + } + r.Body.Close() + + noticeError <- notice.Error + })) + + return ts +} diff --git a/Godeps/_workspace/src/github.com/Sirupsen/logrus/hooks/bugsnag/bugsnag.go b/Godeps/_workspace/src/github.com/Sirupsen/logrus/hooks/bugsnag/bugsnag.go new file mode 100644 index 0000000..d20a0f5 --- /dev/null +++ b/Godeps/_workspace/src/github.com/Sirupsen/logrus/hooks/bugsnag/bugsnag.go @@ -0,0 +1,68 @@ +package logrus_bugsnag + +import ( + "errors" + + "github.com/Sirupsen/logrus" + "github.com/bugsnag/bugsnag-go" +) + +type bugsnagHook struct{} + +// ErrBugsnagUnconfigured is returned if NewBugsnagHook is called before +// bugsnag.Configure. Bugsnag must be configured before the hook. +var ErrBugsnagUnconfigured = errors.New("bugsnag must be configured before installing this logrus hook") + +// ErrBugsnagSendFailed indicates that the hook failed to submit an error to +// bugsnag. The error was successfully generated, but `bugsnag.Notify()` +// failed. +type ErrBugsnagSendFailed struct { + err error +} + +func (e ErrBugsnagSendFailed) Error() string { + return "failed to send error to Bugsnag: " + e.err.Error() +} + +// NewBugsnagHook initializes a logrus hook which sends exceptions to an +// exception-tracking service compatible with the Bugsnag API. Before using +// this hook, you must call bugsnag.Configure(). The returned object should be +// registered with a log via `AddHook()` +// +// Entries that trigger an Error, Fatal or Panic should now include an "error" +// field to send to Bugsnag. +func NewBugsnagHook() (*bugsnagHook, error) { + if bugsnag.Config.APIKey == "" { + return nil, ErrBugsnagUnconfigured + } + return &bugsnagHook{}, nil +} + +// Fire forwards an error to Bugsnag. Given a logrus.Entry, it extracts the +// "error" field (or the Message if the error isn't present) and sends it off. +func (hook *bugsnagHook) Fire(entry *logrus.Entry) error { + var notifyErr error + err, ok := entry.Data["error"].(error) + if ok { + notifyErr = err + } else { + notifyErr = errors.New(entry.Message) + } + + bugsnagErr := bugsnag.Notify(notifyErr) + if bugsnagErr != nil { + return ErrBugsnagSendFailed{bugsnagErr} + } + + return nil +} + +// Levels enumerates the log levels on which the error should be forwarded to +// bugsnag: everything at or above the "Error" level. +func (hook *bugsnagHook) Levels() []logrus.Level { + return []logrus.Level{ + logrus.ErrorLevel, + logrus.FatalLevel, + logrus.PanicLevel, + } +} diff --git a/Godeps/_workspace/src/github.com/Sirupsen/logrus/hooks/bugsnag/bugsnag_test.go b/Godeps/_workspace/src/github.com/Sirupsen/logrus/hooks/bugsnag/bugsnag_test.go new file mode 100644 index 0000000..e9ea298 --- /dev/null +++ b/Godeps/_workspace/src/github.com/Sirupsen/logrus/hooks/bugsnag/bugsnag_test.go @@ -0,0 +1,64 @@ +package logrus_bugsnag + +import ( + "encoding/json" + "errors" + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Sirupsen/logrus" + "github.com/bugsnag/bugsnag-go" +) + +type notice struct { + Events []struct { + Exceptions []struct { + Message string `json:"message"` + } `json:"exceptions"` + } `json:"events"` +} + +func TestNoticeReceived(t *testing.T) { + msg := make(chan string, 1) + expectedMsg := "foo" + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var notice notice + data, _ := ioutil.ReadAll(r.Body) + if err := json.Unmarshal(data, ¬ice); err != nil { + t.Error(err) + } + _ = r.Body.Close() + + msg <- notice.Events[0].Exceptions[0].Message + })) + defer ts.Close() + + hook := &bugsnagHook{} + + bugsnag.Configure(bugsnag.Configuration{ + Endpoint: ts.URL, + ReleaseStage: "production", + APIKey: "12345678901234567890123456789012", + Synchronous: true, + }) + + log := logrus.New() + log.Hooks.Add(hook) + + log.WithFields(logrus.Fields{ + "error": errors.New(expectedMsg), + }).Error("Bugsnag will not see this string") + + select { + case received := <-msg: + if received != expectedMsg { + t.Errorf("Unexpected message received: %s", received) + } + case <-time.After(time.Second): + t.Error("Timed out; no notice received by Bugsnag API") + } +} diff --git a/Godeps/_workspace/src/github.com/Sirupsen/logrus/hooks/papertrail/README.md b/Godeps/_workspace/src/github.com/Sirupsen/logrus/hooks/papertrail/README.md new file mode 100644 index 0000000..ae61e92 --- /dev/null +++ b/Godeps/_workspace/src/github.com/Sirupsen/logrus/hooks/papertrail/README.md @@ -0,0 +1,28 @@ +# Papertrail Hook for Logrus :walrus: + +[Papertrail](https://papertrailapp.com) provides hosted log management. Once stored in Papertrail, you can [group](http://help.papertrailapp.com/kb/how-it-works/groups/) your logs on various dimensions, [search](http://help.papertrailapp.com/kb/how-it-works/search-syntax) them, and trigger [alerts](http://help.papertrailapp.com/kb/how-it-works/alerts). + +In most deployments, you'll want to send logs to Papertrail via their [remote_syslog](http://help.papertrailapp.com/kb/configuration/configuring-centralized-logging-from-text-log-files-in-unix/) daemon, which requires no application-specific configuration. This hook is intended for relatively low-volume logging, likely in managed cloud hosting deployments where installing `remote_syslog` is not possible. + +## Usage + +You can find your Papertrail UDP port on your [Papertrail account page](https://papertrailapp.com/account/destinations). Substitute it below for `YOUR_PAPERTRAIL_UDP_PORT`. + +For `YOUR_APP_NAME`, substitute a short string that will readily identify your application or service in the logs. + +```go +import ( + "log/syslog" + "github.com/Sirupsen/logrus" + "github.com/Sirupsen/logrus/hooks/papertrail" +) + +func main() { + log := logrus.New() + hook, err := logrus_papertrail.NewPapertrailHook("logs.papertrailapp.com", YOUR_PAPERTRAIL_UDP_PORT, YOUR_APP_NAME) + + if err == nil { + log.Hooks.Add(hook) + } +} +``` diff --git a/Godeps/_workspace/src/github.com/Sirupsen/logrus/hooks/papertrail/papertrail.go b/Godeps/_workspace/src/github.com/Sirupsen/logrus/hooks/papertrail/papertrail.go new file mode 100644 index 0000000..c0f10c1 --- /dev/null +++ b/Godeps/_workspace/src/github.com/Sirupsen/logrus/hooks/papertrail/papertrail.go @@ -0,0 +1,55 @@ +package logrus_papertrail + +import ( + "fmt" + "net" + "os" + "time" + + "github.com/Sirupsen/logrus" +) + +const ( + format = "Jan 2 15:04:05" +) + +// PapertrailHook to send logs to a logging service compatible with the Papertrail API. +type PapertrailHook struct { + Host string + Port int + AppName string + UDPConn net.Conn +} + +// NewPapertrailHook creates a hook to be added to an instance of logger. +func NewPapertrailHook(host string, port int, appName string) (*PapertrailHook, error) { + conn, err := net.Dial("udp", fmt.Sprintf("%s:%d", host, port)) + return &PapertrailHook{host, port, appName, conn}, err +} + +// Fire is called when a log event is fired. +func (hook *PapertrailHook) Fire(entry *logrus.Entry) error { + date := time.Now().Format(format) + msg, _ := entry.String() + payload := fmt.Sprintf("<22> %s %s: %s", date, hook.AppName, msg) + + bytesWritten, err := hook.UDPConn.Write([]byte(payload)) + if err != nil { + fmt.Fprintf(os.Stderr, "Unable to send log line to Papertrail via UDP. Wrote %d bytes before error: %v", bytesWritten, err) + return err + } + + return nil +} + +// Levels returns the available logging levels. +func (hook *PapertrailHook) Levels() []logrus.Level { + return []logrus.Level{ + logrus.PanicLevel, + logrus.FatalLevel, + logrus.ErrorLevel, + logrus.WarnLevel, + logrus.InfoLevel, + logrus.DebugLevel, + } +} diff --git a/Godeps/_workspace/src/github.com/Sirupsen/logrus/hooks/papertrail/papertrail_test.go b/Godeps/_workspace/src/github.com/Sirupsen/logrus/hooks/papertrail/papertrail_test.go new file mode 100644 index 0000000..96318d0 --- /dev/null +++ b/Godeps/_workspace/src/github.com/Sirupsen/logrus/hooks/papertrail/papertrail_test.go @@ -0,0 +1,26 @@ +package logrus_papertrail + +import ( + "fmt" + "testing" + + "github.com/Sirupsen/logrus" + "github.com/stvp/go-udp-testing" +) + +func TestWritingToUDP(t *testing.T) { + port := 16661 + udp.SetAddr(fmt.Sprintf(":%d", port)) + + hook, err := NewPapertrailHook("localhost", port, "test") + if err != nil { + t.Errorf("Unable to connect to local UDP server.") + } + + log := logrus.New() + log.Hooks.Add(hook) + + udp.ShouldReceive(t, "foo", func() { + log.Info("foo") + }) +} diff --git a/Godeps/_workspace/src/github.com/Sirupsen/logrus/hooks/sentry/README.md b/Godeps/_workspace/src/github.com/Sirupsen/logrus/hooks/sentry/README.md new file mode 100644 index 0000000..31de654 --- /dev/null +++ b/Godeps/_workspace/src/github.com/Sirupsen/logrus/hooks/sentry/README.md @@ -0,0 +1,111 @@ +# Sentry Hook for Logrus :walrus: + +[Sentry](https://getsentry.com) provides both self-hosted and hosted +solutions for exception tracking. +Both client and server are +[open source](https://github.com/getsentry/sentry). + +## Usage + +Every sentry application defined on the server gets a different +[DSN](https://www.getsentry.com/docs/). In the example below replace +`YOUR_DSN` with the one created for your application. + +```go +import ( + "github.com/Sirupsen/logrus" + "github.com/Sirupsen/logrus/hooks/sentry" +) + +func main() { + log := logrus.New() + hook, err := logrus_sentry.NewSentryHook(YOUR_DSN, []logrus.Level{ + logrus.PanicLevel, + logrus.FatalLevel, + logrus.ErrorLevel, + }) + + if err == nil { + log.Hooks.Add(hook) + } +} +``` + +If you wish to initialize a SentryHook with tags, you can use the `NewWithTagsSentryHook` constructor to provide default tags: + +```go +tags := map[string]string{ + "site": "example.com", +} +levels := []logrus.Level{ + logrus.PanicLevel, + logrus.FatalLevel, + logrus.ErrorLevel, +} +hook, err := logrus_sentry.NewWithTagsSentryHook(YOUR_DSN, tags, levels) + +``` + +If you wish to initialize a SentryHook with an already initialized raven client, you can use +the `NewWithClientSentryHook` constructor: + +```go +import ( + "github.com/Sirupsen/logrus" + "github.com/Sirupsen/logrus/hooks/sentry" + "github.com/getsentry/raven-go" +) + +func main() { + log := logrus.New() + + client, err := raven.New(YOUR_DSN) + if err != nil { + log.Fatal(err) + } + + hook, err := logrus_sentry.NewWithClientSentryHook(client, []logrus.Level{ + logrus.PanicLevel, + logrus.FatalLevel, + logrus.ErrorLevel, + }) + + if err == nil { + log.Hooks.Add(hook) + } +} + +hook, err := NewWithClientSentryHook(client, []logrus.Level{ + logrus.ErrorLevel, +}) +``` + +## Special fields + +Some logrus fields have a special meaning in this hook, +these are `server_name`, `logger` and `http_request`. +When logs are sent to sentry these fields are treated differently. +- `server_name` (also known as hostname) is the name of the server which +is logging the event (hostname.example.com) +- `logger` is the part of the application which is logging the event. +In go this usually means setting it to the name of the package. +- `http_request` is the in-coming request(*http.Request). The detailed request data are sent to Sentry. + +## Timeout + +`Timeout` is the time the sentry hook will wait for a response +from the sentry server. + +If this time elapses with no response from +the server an error will be returned. + +If `Timeout` is set to 0 the SentryHook will not wait for a reply +and will assume a correct delivery. + +The SentryHook has a default timeout of `100 milliseconds` when created +with a call to `NewSentryHook`. This can be changed by assigning a value to the `Timeout` field: + +```go +hook, _ := logrus_sentry.NewSentryHook(...) +hook.Timeout = 20*time.Second +``` diff --git a/Godeps/_workspace/src/github.com/Sirupsen/logrus/hooks/sentry/sentry.go b/Godeps/_workspace/src/github.com/Sirupsen/logrus/hooks/sentry/sentry.go new file mode 100644 index 0000000..cf88098 --- /dev/null +++ b/Godeps/_workspace/src/github.com/Sirupsen/logrus/hooks/sentry/sentry.go @@ -0,0 +1,137 @@ +package logrus_sentry + +import ( + "fmt" + "net/http" + "time" + + "github.com/Sirupsen/logrus" + "github.com/getsentry/raven-go" +) + +var ( + severityMap = map[logrus.Level]raven.Severity{ + logrus.DebugLevel: raven.DEBUG, + logrus.InfoLevel: raven.INFO, + logrus.WarnLevel: raven.WARNING, + logrus.ErrorLevel: raven.ERROR, + logrus.FatalLevel: raven.FATAL, + logrus.PanicLevel: raven.FATAL, + } +) + +func getAndDel(d logrus.Fields, key string) (string, bool) { + var ( + ok bool + v interface{} + val string + ) + if v, ok = d[key]; !ok { + return "", false + } + + if val, ok = v.(string); !ok { + return "", false + } + delete(d, key) + return val, true +} + +func getAndDelRequest(d logrus.Fields, key string) (*http.Request, bool) { + var ( + ok bool + v interface{} + req *http.Request + ) + if v, ok = d[key]; !ok { + return nil, false + } + if req, ok = v.(*http.Request); !ok || req == nil { + return nil, false + } + delete(d, key) + return req, true +} + +// SentryHook delivers logs to a sentry server. +type SentryHook struct { + // Timeout sets the time to wait for a delivery error from the sentry server. + // If this is set to zero the server will not wait for any response and will + // consider the message correctly sent + Timeout time.Duration + + client *raven.Client + levels []logrus.Level +} + +// NewSentryHook creates a hook to be added to an instance of logger +// and initializes the raven client. +// This method sets the timeout to 100 milliseconds. +func NewSentryHook(DSN string, levels []logrus.Level) (*SentryHook, error) { + client, err := raven.New(DSN) + if err != nil { + return nil, err + } + return &SentryHook{100 * time.Millisecond, client, levels}, nil +} + +// NewWithTagsSentryHook creates a hook with tags to be added to an instance +// of logger and initializes the raven client. This method sets the timeout to +// 100 milliseconds. +func NewWithTagsSentryHook(DSN string, tags map[string]string, levels []logrus.Level) (*SentryHook, error) { + client, err := raven.NewWithTags(DSN, tags) + if err != nil { + return nil, err + } + return &SentryHook{100 * time.Millisecond, client, levels}, nil +} + +// NewWithClientSentryHook creates a hook using an initialized raven client. +// This method sets the timeout to 100 milliseconds. +func NewWithClientSentryHook(client *raven.Client, levels []logrus.Level) (*SentryHook, error) { + return &SentryHook{100 * time.Millisecond, client, levels}, nil +} + +// Called when an event should be sent to sentry +// Special fields that sentry uses to give more information to the server +// are extracted from entry.Data (if they are found) +// These fields are: logger, server_name and http_request +func (hook *SentryHook) Fire(entry *logrus.Entry) error { + packet := &raven.Packet{ + Message: entry.Message, + Timestamp: raven.Timestamp(entry.Time), + Level: severityMap[entry.Level], + Platform: "go", + } + + d := entry.Data + + if logger, ok := getAndDel(d, "logger"); ok { + packet.Logger = logger + } + if serverName, ok := getAndDel(d, "server_name"); ok { + packet.ServerName = serverName + } + if req, ok := getAndDelRequest(d, "http_request"); ok { + packet.Interfaces = append(packet.Interfaces, raven.NewHttp(req)) + } + packet.Extra = map[string]interface{}(d) + + _, errCh := hook.client.Capture(packet, nil) + timeout := hook.Timeout + if timeout != 0 { + timeoutCh := time.After(timeout) + select { + case err := <-errCh: + return err + case <-timeoutCh: + return fmt.Errorf("no response from sentry server in %s", timeout) + } + } + return nil +} + +// Levels returns the available logging levels. +func (hook *SentryHook) Levels() []logrus.Level { + return hook.levels +} diff --git a/Godeps/_workspace/src/github.com/Sirupsen/logrus/hooks/sentry/sentry_test.go b/Godeps/_workspace/src/github.com/Sirupsen/logrus/hooks/sentry/sentry_test.go new file mode 100644 index 0000000..4a97bc6 --- /dev/null +++ b/Godeps/_workspace/src/github.com/Sirupsen/logrus/hooks/sentry/sentry_test.go @@ -0,0 +1,154 @@ +package logrus_sentry + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "reflect" + "strings" + "testing" + + "github.com/Sirupsen/logrus" + "github.com/getsentry/raven-go" +) + +const ( + message = "error message" + server_name = "testserver.internal" + logger_name = "test.logger" +) + +func getTestLogger() *logrus.Logger { + l := logrus.New() + l.Out = ioutil.Discard + return l +} + +func WithTestDSN(t *testing.T, tf func(string, <-chan *raven.Packet)) { + pch := make(chan *raven.Packet, 1) + s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + defer req.Body.Close() + d := json.NewDecoder(req.Body) + p := &raven.Packet{} + err := d.Decode(p) + if err != nil { + t.Fatal(err.Error()) + } + + pch <- p + })) + defer s.Close() + + fragments := strings.SplitN(s.URL, "://", 2) + dsn := fmt.Sprintf( + "%s://public:secret@%s/sentry/project-id", + fragments[0], + fragments[1], + ) + tf(dsn, pch) +} + +func TestSpecialFields(t *testing.T) { + WithTestDSN(t, func(dsn string, pch <-chan *raven.Packet) { + logger := getTestLogger() + + hook, err := NewSentryHook(dsn, []logrus.Level{ + logrus.ErrorLevel, + }) + + if err != nil { + t.Fatal(err.Error()) + } + logger.Hooks.Add(hook) + + req, _ := http.NewRequest("GET", "url", nil) + logger.WithFields(logrus.Fields{ + "server_name": server_name, + "logger": logger_name, + "http_request": req, + }).Error(message) + + packet := <-pch + if packet.Logger != logger_name { + t.Errorf("logger should have been %s, was %s", logger_name, packet.Logger) + } + + if packet.ServerName != server_name { + t.Errorf("server_name should have been %s, was %s", server_name, packet.ServerName) + } + }) +} + +func TestSentryHandler(t *testing.T) { + WithTestDSN(t, func(dsn string, pch <-chan *raven.Packet) { + logger := getTestLogger() + hook, err := NewSentryHook(dsn, []logrus.Level{ + logrus.ErrorLevel, + }) + if err != nil { + t.Fatal(err.Error()) + } + logger.Hooks.Add(hook) + + logger.Error(message) + packet := <-pch + if packet.Message != message { + t.Errorf("message should have been %s, was %s", message, packet.Message) + } + }) +} + +func TestSentryWithClient(t *testing.T) { + WithTestDSN(t, func(dsn string, pch <-chan *raven.Packet) { + logger := getTestLogger() + + client, _ := raven.New(dsn) + + hook, err := NewWithClientSentryHook(client, []logrus.Level{ + logrus.ErrorLevel, + }) + if err != nil { + t.Fatal(err.Error()) + } + logger.Hooks.Add(hook) + + logger.Error(message) + packet := <-pch + if packet.Message != message { + t.Errorf("message should have been %s, was %s", message, packet.Message) + } + }) +} + +func TestSentryTags(t *testing.T) { + WithTestDSN(t, func(dsn string, pch <-chan *raven.Packet) { + logger := getTestLogger() + tags := map[string]string{ + "site": "test", + } + levels := []logrus.Level{ + logrus.ErrorLevel, + } + + hook, err := NewWithTagsSentryHook(dsn, tags, levels) + if err != nil { + t.Fatal(err.Error()) + } + + logger.Hooks.Add(hook) + + logger.Error(message) + packet := <-pch + expected := raven.Tags{ + raven.Tag{ + Key: "site", + Value: "test", + }, + } + if !reflect.DeepEqual(packet.Tags, expected) { + t.Errorf("message should have been %s, was %s", message, packet.Message) + } + }) +} diff --git a/Godeps/_workspace/src/github.com/Sirupsen/logrus/hooks/syslog/README.md b/Godeps/_workspace/src/github.com/Sirupsen/logrus/hooks/syslog/README.md new file mode 100644 index 0000000..4dbb8e7 --- /dev/null +++ b/Godeps/_workspace/src/github.com/Sirupsen/logrus/hooks/syslog/README.md @@ -0,0 +1,20 @@ +# Syslog Hooks for Logrus :walrus: + +## Usage + +```go +import ( + "log/syslog" + "github.com/Sirupsen/logrus" + logrus_syslog "github.com/Sirupsen/logrus/hooks/syslog" +) + +func main() { + log := logrus.New() + hook, err := logrus_syslog.NewSyslogHook("udp", "localhost:514", syslog.LOG_INFO, "") + + if err == nil { + log.Hooks.Add(hook) + } +} +``` diff --git a/Godeps/_workspace/src/github.com/Sirupsen/logrus/hooks/syslog/syslog.go b/Godeps/_workspace/src/github.com/Sirupsen/logrus/hooks/syslog/syslog.go new file mode 100644 index 0000000..b6fa374 --- /dev/null +++ b/Godeps/_workspace/src/github.com/Sirupsen/logrus/hooks/syslog/syslog.go @@ -0,0 +1,59 @@ +package logrus_syslog + +import ( + "fmt" + "github.com/Sirupsen/logrus" + "log/syslog" + "os" +) + +// SyslogHook to send logs via syslog. +type SyslogHook struct { + Writer *syslog.Writer + SyslogNetwork string + SyslogRaddr string +} + +// Creates a hook to be added to an instance of logger. This is called with +// `hook, err := NewSyslogHook("udp", "localhost:514", syslog.LOG_DEBUG, "")` +// `if err == nil { log.Hooks.Add(hook) }` +func NewSyslogHook(network, raddr string, priority syslog.Priority, tag string) (*SyslogHook, error) { + w, err := syslog.Dial(network, raddr, priority, tag) + return &SyslogHook{w, network, raddr}, err +} + +func (hook *SyslogHook) Fire(entry *logrus.Entry) error { + line, err := entry.String() + if err != nil { + fmt.Fprintf(os.Stderr, "Unable to read entry, %v", err) + return err + } + + switch entry.Level { + case logrus.PanicLevel: + return hook.Writer.Crit(line) + case logrus.FatalLevel: + return hook.Writer.Crit(line) + case logrus.ErrorLevel: + return hook.Writer.Err(line) + case logrus.WarnLevel: + return hook.Writer.Warning(line) + case logrus.InfoLevel: + return hook.Writer.Info(line) + case logrus.DebugLevel: + return hook.Writer.Debug(line) + default: + return nil + } +} + +func (hook *SyslogHook) Levels() []logrus.Level { + return []logrus.Level{ + logrus.PanicLevel, + logrus.FatalLevel, + logrus.ErrorLevel, + logrus.WarnLevel, + logrus.InfoLevel, + logrus.DebugLevel, + } +} diff --git a/Godeps/_workspace/src/github.com/Sirupsen/logrus/hooks/syslog/syslog_test.go b/Godeps/_workspace/src/github.com/Sirupsen/logrus/hooks/syslog/syslog_test.go new file mode 100644 index 0000000..42762dc --- /dev/null +++ b/Godeps/_workspace/src/github.com/Sirupsen/logrus/hooks/syslog/syslog_test.go @@ -0,0 +1,26 @@ +package logrus_syslog + +import ( + "github.com/Sirupsen/logrus" + "log/syslog" + "testing" +) + +func TestLocalhostAddAndPrint(t *testing.T) { + log := logrus.New() + hook, err := NewSyslogHook("udp", "localhost:514", syslog.LOG_INFO, "") + + if err != nil { + t.Errorf("Unable to connect to local syslog.") + } + + log.Hooks.Add(hook) + + for _, level := range hook.Levels() { + if len(log.Hooks[level]) != 1 { + t.Errorf("SyslogHook was not added. The length of log.Hooks[%v]: %v", level, len(log.Hooks[level])) + } + } + + log.Info("Congratulations!") +} diff --git a/Godeps/_workspace/src/github.com/Sirupsen/logrus/json_formatter.go b/Godeps/_workspace/src/github.com/Sirupsen/logrus/json_formatter.go new file mode 100644 index 0000000..2ad6dc5 --- /dev/null +++ b/Godeps/_workspace/src/github.com/Sirupsen/logrus/json_formatter.go @@ -0,0 +1,41 @@ +package logrus + +import ( + "encoding/json" + "fmt" +) + +type JSONFormatter struct { + // TimestampFormat sets the format used for marshaling timestamps. + TimestampFormat string +} + +func (f *JSONFormatter) Format(entry *Entry) ([]byte, error) { + data := make(Fields, len(entry.Data)+3) + for k, v := range entry.Data { + switch v := v.(type) { + case error: + // Otherwise errors are ignored by `encoding/json` + // https://github.com/Sirupsen/logrus/issues/137 + data[k] = v.Error() + default: + data[k] = v + } + } + prefixFieldClashes(data) + + timestampFormat := f.TimestampFormat + if timestampFormat == "" { + timestampFormat = DefaultTimestampFormat + } + + data["time"] = entry.Time.Format(timestampFormat) + data["msg"] = entry.Message + data["level"] = entry.Level.String() + + serialized, err := json.Marshal(data) + if err != nil { + return nil, fmt.Errorf("Failed to marshal fields to JSON, %v", err) + } + return append(serialized, '\n'), nil +} diff --git a/Godeps/_workspace/src/github.com/Sirupsen/logrus/json_formatter_test.go b/Godeps/_workspace/src/github.com/Sirupsen/logrus/json_formatter_test.go new file mode 100644 index 0000000..1d70873 --- /dev/null +++ b/Godeps/_workspace/src/github.com/Sirupsen/logrus/json_formatter_test.go @@ -0,0 +1,120 @@ +package logrus + +import ( + "encoding/json" + "errors" + + "testing" +) + +func TestErrorNotLost(t *testing.T) { + formatter := &JSONFormatter{} + + b, err := formatter.Format(WithField("error", errors.New("wild walrus"))) + if err != nil { + t.Fatal("Unable to format entry: ", err) + } + + entry := make(map[string]interface{}) + err = json.Unmarshal(b, &entry) + if err != nil { + t.Fatal("Unable to unmarshal formatted entry: ", err) + } + + if entry["error"] != "wild walrus" { + t.Fatal("Error field not set") + } +} + +func TestErrorNotLostOnFieldNotNamedError(t *testing.T) { + formatter := &JSONFormatter{} + + b, err := formatter.Format(WithField("omg", errors.New("wild walrus"))) + if err != nil { + t.Fatal("Unable to format entry: ", err) + } + + entry := make(map[string]interface{}) + err = json.Unmarshal(b, &entry) + if err != nil { + t.Fatal("Unable to unmarshal formatted entry: ", err) + } + + if entry["omg"] != "wild walrus" { + t.Fatal("Error field not set") + } +} + +func TestFieldClashWithTime(t *testing.T) { + formatter := &JSONFormatter{} + + b, err := formatter.Format(WithField("time", "right now!")) + if err != nil { + t.Fatal("Unable to format entry: ", err) + } + + entry := make(map[string]interface{}) + err = json.Unmarshal(b, &entry) + if err != nil { + t.Fatal("Unable to unmarshal formatted entry: ", err) + } + + if entry["fields.time"] != "right now!" { + t.Fatal("fields.time not set to original time field") + } + + if entry["time"] != "0001-01-01T00:00:00Z" { + t.Fatal("time field not set to current time, was: ", entry["time"]) + } +} + +func TestFieldClashWithMsg(t *testing.T) { + formatter := &JSONFormatter{} + + b, err := formatter.Format(WithField("msg", "something")) + if err != nil { + t.Fatal("Unable to format entry: ", err) + } + + entry := make(map[string]interface{}) + err = json.Unmarshal(b, &entry) + if err != nil { + t.Fatal("Unable to unmarshal formatted entry: ", err) + } + + if entry["fields.msg"] != "something" { + t.Fatal("fields.msg not set to original msg field") + } +} + +func TestFieldClashWithLevel(t *testing.T) { + formatter := &JSONFormatter{} + + b, err := formatter.Format(WithField("level", "something")) + if err != nil { + t.Fatal("Unable to format entry: ", err) + } + + entry := make(map[string]interface{}) + err = json.Unmarshal(b, &entry) + if err != nil { + t.Fatal("Unable to unmarshal formatted entry: ", err) + } + + if entry["fields.level"] != "something" { + t.Fatal("fields.level not set to original level field") + } +} + +func TestJSONEntryEndsWithNewline(t *testing.T) { + formatter := &JSONFormatter{} + + b, err := formatter.Format(WithField("level", "something")) + if err != nil { + t.Fatal("Unable to format entry: ", err) + } + + if b[len(b)-1] != '\n' { + t.Fatal("Expected JSON log entry to end with a newline") + } +} diff --git a/Godeps/_workspace/src/github.com/Sirupsen/logrus/logger.go b/Godeps/_workspace/src/github.com/Sirupsen/logrus/logger.go new file mode 100644 index 0000000..e4974bf --- /dev/null +++ b/Godeps/_workspace/src/github.com/Sirupsen/logrus/logger.go @@ -0,0 +1,206 @@ +package logrus + +import ( + "io" + "os" + "sync" +) + +type Logger struct { + // The logs are `io.Copy`'d to this in a mutex. It's common to set this to a + // file, or leave it default which is `os.Stdout`. You can also set this to + // something more adventorous, such as logging to Kafka. + Out io.Writer + // Hooks for the logger instance. These allow firing events based on logging + // levels and log entries. For example, to send errors to an error tracking + // service, log to StatsD or dump the core on fatal errors. + Hooks LevelHooks + // All log entries pass through the formatter before logged to Out. The + // included formatters are `TextFormatter` and `JSONFormatter` for which + // TextFormatter is the default. In development (when a TTY is attached) it + // logs with colors, but to a file it wouldn't. You can easily implement your + // own that implements the `Formatter` interface, see the `README` or included + // formatters for examples. + Formatter Formatter + // The logging level the logger should log at. This is typically (and defaults + // to) `logrus.Info`, which allows Info(), Warn(), Error() and Fatal() to be + // logged. `logrus.Debug` is useful in + Level Level + // Used to sync writing to the log. + mu sync.Mutex +} + +// Creates a new logger. Configuration should be set by changing `Formatter`, +// `Out` and `Hooks` directly on the default logger instance. You can also just +// instantiate your own: +// +// var log = &Logger{ +// Out: os.Stderr, +// Formatter: new(JSONFormatter), +// Hooks: make(LevelHooks), +// Level: logrus.DebugLevel, +// } +// +// It's recommended to make this a global instance called `log`. +func New() *Logger { + return &Logger{ + Out: os.Stderr, + Formatter: new(TextFormatter), + Hooks: make(LevelHooks), + Level: InfoLevel, + } +} + +// Adds a field to the log entry, note that you it doesn't log until you call +// Debug, Print, Info, Warn, Fatal or Panic. It only creates a log entry. +// Ff you want multiple fields, use `WithFields`. +func (logger *Logger) WithField(key string, value interface{}) *Entry { + return NewEntry(logger).WithField(key, value) +} + +// Adds a struct of fields to the log entry. All it does is call `WithField` for +// each `Field`. +func (logger *Logger) WithFields(fields Fields) *Entry { + return NewEntry(logger).WithFields(fields) +} + +func (logger *Logger) Debugf(format string, args ...interface{}) { + if logger.Level >= DebugLevel { + NewEntry(logger).Debugf(format, args...) + } +} + +func (logger *Logger) Infof(format string, args ...interface{}) { + if logger.Level >= InfoLevel { + NewEntry(logger).Infof(format, args...) + } +} + +func (logger *Logger) Printf(format string, args ...interface{}) { + NewEntry(logger).Printf(format, args...) +} + +func (logger *Logger) Warnf(format string, args ...interface{}) { + if logger.Level >= WarnLevel { + NewEntry(logger).Warnf(format, args...) + } +} + +func (logger *Logger) Warningf(format string, args ...interface{}) { + if logger.Level >= WarnLevel { + NewEntry(logger).Warnf(format, args...) + } +} + +func (logger *Logger) Errorf(format string, args ...interface{}) { + if logger.Level >= ErrorLevel { + NewEntry(logger).Errorf(format, args...) + } +} + +func (logger *Logger) Fatalf(format string, args ...interface{}) { + if logger.Level >= FatalLevel { + NewEntry(logger).Fatalf(format, args...) + } + os.Exit(1) +} + +func (logger *Logger) Panicf(format string, args ...interface{}) { + if logger.Level >= PanicLevel { + NewEntry(logger).Panicf(format, args...) + } +} + +func (logger *Logger) Debug(args ...interface{}) { + if logger.Level >= DebugLevel { + NewEntry(logger).Debug(args...) + } +} + +func (logger *Logger) Info(args ...interface{}) { + if logger.Level >= InfoLevel { + NewEntry(logger).Info(args...) + } +} + +func (logger *Logger) Print(args ...interface{}) { + NewEntry(logger).Info(args...) +} + +func (logger *Logger) Warn(args ...interface{}) { + if logger.Level >= WarnLevel { + NewEntry(logger).Warn(args...) + } +} + +func (logger *Logger) Warning(args ...interface{}) { + if logger.Level >= WarnLevel { + NewEntry(logger).Warn(args...) + } +} + +func (logger *Logger) Error(args ...interface{}) { + if logger.Level >= ErrorLevel { + NewEntry(logger).Error(args...) + } +} + +func (logger *Logger) Fatal(args ...interface{}) { + if logger.Level >= FatalLevel { + NewEntry(logger).Fatal(args...) + } + os.Exit(1) +} + +func (logger *Logger) Panic(args ...interface{}) { + if logger.Level >= PanicLevel { + NewEntry(logger).Panic(args...) + } +} + +func (logger *Logger) Debugln(args ...interface{}) { + if logger.Level >= DebugLevel { + NewEntry(logger).Debugln(args...) + } +} + +func (logger *Logger) Infoln(args ...interface{}) { + if logger.Level >= InfoLevel { + NewEntry(logger).Infoln(args...) + } +} + +func (logger *Logger) Println(args ...interface{}) { + NewEntry(logger).Println(args...) +} + +func (logger *Logger) Warnln(args ...interface{}) { + if logger.Level >= WarnLevel { + NewEntry(logger).Warnln(args...) + } +} + +func (logger *Logger) Warningln(args ...interface{}) { + if logger.Level >= WarnLevel { + NewEntry(logger).Warnln(args...) + } +} + +func (logger *Logger) Errorln(args ...interface{}) { + if logger.Level >= ErrorLevel { + NewEntry(logger).Errorln(args...) + } +} + +func (logger *Logger) Fatalln(args ...interface{}) { + if logger.Level >= FatalLevel { + NewEntry(logger).Fatalln(args...) + } + os.Exit(1) +} + +func (logger *Logger) Panicln(args ...interface{}) { + if logger.Level >= PanicLevel { + NewEntry(logger).Panicln(args...) + } +} diff --git a/Godeps/_workspace/src/github.com/Sirupsen/logrus/logrus.go b/Godeps/_workspace/src/github.com/Sirupsen/logrus/logrus.go new file mode 100644 index 0000000..43ee12e --- /dev/null +++ b/Godeps/_workspace/src/github.com/Sirupsen/logrus/logrus.go @@ -0,0 +1,94 @@ +package logrus + +import ( + "fmt" + "log" +) + +// Fields type, used to pass to `WithFields`. +type Fields map[string]interface{} + +// Level type +type Level uint8 + +// Convert the Level to a string. E.g. PanicLevel becomes "panic". +func (level Level) String() string { + switch level { + case DebugLevel: + return "debug" + case InfoLevel: + return "info" + case WarnLevel: + return "warning" + case ErrorLevel: + return "error" + case FatalLevel: + return "fatal" + case PanicLevel: + return "panic" + } + + return "unknown" +} + +// ParseLevel takes a string level and returns the Logrus log level constant. +func ParseLevel(lvl string) (Level, error) { + switch lvl { + case "panic": + return PanicLevel, nil + case "fatal": + return FatalLevel, nil + case "error": + return ErrorLevel, nil + case "warn", "warning": + return WarnLevel, nil + case "info": + return InfoLevel, nil + case "debug": + return DebugLevel, nil + } + + var l Level + return l, fmt.Errorf("not a valid logrus Level: %q", lvl) +} + +// These are the different logging levels. You can set the logging level to log +// on your instance of logger, obtained with `logrus.New()`. +const ( + // PanicLevel level, highest level of severity. Logs and then calls panic with the + // message passed to Debug, Info, ... + PanicLevel Level = iota + // FatalLevel level. Logs and then calls `os.Exit(1)`. It will exit even if the + // logging level is set to Panic. + FatalLevel + // ErrorLevel level. Logs. Used for errors that should definitely be noted. + // Commonly used for hooks to send errors to an error tracking service. + ErrorLevel + // WarnLevel level. Non-critical entries that deserve eyes. + WarnLevel + // InfoLevel level. General operational entries about what's going on inside the + // application. + InfoLevel + // DebugLevel level. Usually only enabled when debugging. Very verbose logging. + DebugLevel +) + +// Won't compile if StdLogger can't be realized by a log.Logger +var _ StdLogger = &log.Logger{} + +// StdLogger is what your logrus-enabled library should take, that way +// it'll accept a stdlib logger and a logrus logger. There's no standard +// interface, this is the closest we get, unfortunately. +type StdLogger interface { + Print(...interface{}) + Printf(string, ...interface{}) + Println(...interface{}) + + Fatal(...interface{}) + Fatalf(string, ...interface{}) + Fatalln(...interface{}) + + Panic(...interface{}) + Panicf(string, ...interface{}) + Panicln(...interface{}) +} diff --git a/Godeps/_workspace/src/github.com/Sirupsen/logrus/logrus_test.go b/Godeps/_workspace/src/github.com/Sirupsen/logrus/logrus_test.go new file mode 100644 index 0000000..efaacea --- /dev/null +++ b/Godeps/_workspace/src/github.com/Sirupsen/logrus/logrus_test.go @@ -0,0 +1,301 @@ +package logrus + +import ( + "bytes" + "encoding/json" + "strconv" + "strings" + "sync" + "testing" + + "github.com/stretchr/testify/assert" +) + +func LogAndAssertJSON(t *testing.T, log func(*Logger), assertions func(fields Fields)) { + var buffer bytes.Buffer + var fields Fields + + logger := New() + logger.Out = &buffer + logger.Formatter = new(JSONFormatter) + + log(logger) + + err := json.Unmarshal(buffer.Bytes(), &fields) + assert.Nil(t, err) + + assertions(fields) +} + +func LogAndAssertText(t *testing.T, log func(*Logger), assertions func(fields map[string]string)) { + var buffer bytes.Buffer + + logger := New() + logger.Out = &buffer + logger.Formatter = &TextFormatter{ + DisableColors: true, + } + + log(logger) + + fields := make(map[string]string) + for _, kv := range strings.Split(buffer.String(), " ") { + if !strings.Contains(kv, "=") { + continue + } + kvArr := strings.Split(kv, "=") + key := strings.TrimSpace(kvArr[0]) + val := kvArr[1] + if kvArr[1][0] == '"' { + var err error + val, err = strconv.Unquote(val) + assert.NoError(t, err) + } + fields[key] = val + } + assertions(fields) +} + +func TestPrint(t *testing.T) { + LogAndAssertJSON(t, func(log *Logger) { + log.Print("test") + }, func(fields Fields) { + assert.Equal(t, fields["msg"], "test") + assert.Equal(t, fields["level"], "info") + }) +} + +func TestInfo(t *testing.T) { + LogAndAssertJSON(t, func(log *Logger) { + log.Info("test") + }, func(fields Fields) { + assert.Equal(t, fields["msg"], "test") + assert.Equal(t, fields["level"], "info") + }) +} + +func TestWarn(t *testing.T) { + LogAndAssertJSON(t, func(log *Logger) { + log.Warn("test") + }, func(fields Fields) { + assert.Equal(t, fields["msg"], "test") + assert.Equal(t, fields["level"], "warning") + }) +} + +func TestInfolnShouldAddSpacesBetweenStrings(t *testing.T) { + LogAndAssertJSON(t, func(log *Logger) { + log.Infoln("test", "test") + }, func(fields Fields) { + assert.Equal(t, fields["msg"], "test test") + }) +} + +func TestInfolnShouldAddSpacesBetweenStringAndNonstring(t *testing.T) { + LogAndAssertJSON(t, func(log *Logger) { + log.Infoln("test", 10) + }, func(fields Fields) { + assert.Equal(t, fields["msg"], "test 10") + }) +} + +func TestInfolnShouldAddSpacesBetweenTwoNonStrings(t *testing.T) { + LogAndAssertJSON(t, func(log *Logger) { + log.Infoln(10, 10) + }, func(fields Fields) { + assert.Equal(t, fields["msg"], "10 10") + }) +} + +func TestInfoShouldAddSpacesBetweenTwoNonStrings(t *testing.T) { + LogAndAssertJSON(t, func(log *Logger) { + log.Infoln(10, 10) + }, func(fields Fields) { + assert.Equal(t, fields["msg"], "10 10") + }) +} + +func TestInfoShouldNotAddSpacesBetweenStringAndNonstring(t *testing.T) { + LogAndAssertJSON(t, func(log *Logger) { + log.Info("test", 10) + }, func(fields Fields) { + assert.Equal(t, fields["msg"], "test10") + }) +} + +func TestInfoShouldNotAddSpacesBetweenStrings(t *testing.T) { + LogAndAssertJSON(t, func(log *Logger) { + log.Info("test", "test") + }, func(fields Fields) { + assert.Equal(t, fields["msg"], "testtest") + }) +} + +func TestWithFieldsShouldAllowAssignments(t *testing.T) { + var buffer bytes.Buffer + var fields Fields + + logger := New() + logger.Out = &buffer + logger.Formatter = new(JSONFormatter) + + localLog := logger.WithFields(Fields{ + "key1": "value1", + }) + + localLog.WithField("key2", "value2").Info("test") + err := json.Unmarshal(buffer.Bytes(), &fields) + assert.Nil(t, err) + + assert.Equal(t, "value2", fields["key2"]) + assert.Equal(t, "value1", fields["key1"]) + + buffer = bytes.Buffer{} + fields = Fields{} + localLog.Info("test") + err = json.Unmarshal(buffer.Bytes(), &fields) + assert.Nil(t, err) + + _, ok := fields["key2"] + assert.Equal(t, false, ok) + assert.Equal(t, "value1", fields["key1"]) +} + +func TestUserSuppliedFieldDoesNotOverwriteDefaults(t *testing.T) { + LogAndAssertJSON(t, func(log *Logger) { + log.WithField("msg", "hello").Info("test") + }, func(fields Fields) { + assert.Equal(t, fields["msg"], "test") + }) +} + +func TestUserSuppliedMsgFieldHasPrefix(t *testing.T) { + LogAndAssertJSON(t, func(log *Logger) { + log.WithField("msg", "hello").Info("test") + }, func(fields Fields) { + assert.Equal(t, fields["msg"], "test") + assert.Equal(t, fields["fields.msg"], "hello") + }) +} + +func TestUserSuppliedTimeFieldHasPrefix(t *testing.T) { + LogAndAssertJSON(t, func(log *Logger) { + log.WithField("time", "hello").Info("test") + }, func(fields Fields) { + assert.Equal(t, fields["fields.time"], "hello") + }) +} + +func TestUserSuppliedLevelFieldHasPrefix(t *testing.T) { + LogAndAssertJSON(t, func(log *Logger) { + log.WithField("level", 1).Info("test") + }, func(fields Fields) { + assert.Equal(t, fields["level"], "info") + assert.Equal(t, fields["fields.level"], 1.0) // JSON has floats only + }) +} + +func TestDefaultFieldsAreNotPrefixed(t *testing.T) { + LogAndAssertText(t, func(log *Logger) { + ll := log.WithField("herp", "derp") + ll.Info("hello") + ll.Info("bye") + }, func(fields map[string]string) { + for _, fieldName := range []string{"fields.level", "fields.time", "fields.msg"} { + if _, ok := fields[fieldName]; ok { + t.Fatalf("should not have prefixed %q: %v", fieldName, fields) + } + } + }) +} + +func TestDoubleLoggingDoesntPrefixPreviousFields(t *testing.T) { + + var buffer bytes.Buffer + var fields Fields + + logger := New() + logger.Out = &buffer + logger.Formatter = new(JSONFormatter) + + llog := logger.WithField("context", "eating raw fish") + + llog.Info("looks delicious") + + err := json.Unmarshal(buffer.Bytes(), &fields) + assert.NoError(t, err, "should have decoded first message") + assert.Equal(t, len(fields), 4, "should only have msg/time/level/context fields") + assert.Equal(t, fields["msg"], "looks delicious") + assert.Equal(t, fields["context"], "eating raw fish") + + buffer.Reset() + + llog.Warn("omg it is!") + + err = json.Unmarshal(buffer.Bytes(), &fields) + assert.NoError(t, err, "should have decoded second message") + assert.Equal(t, len(fields), 4, "should only have msg/time/level/context fields") + assert.Equal(t, fields["msg"], "omg it is!") + assert.Equal(t, fields["context"], "eating raw fish") + assert.Nil(t, fields["fields.msg"], "should not have prefixed previous `msg` entry") + +} + +func TestConvertLevelToString(t *testing.T) { + assert.Equal(t, "debug", DebugLevel.String()) + assert.Equal(t, "info", InfoLevel.String()) + assert.Equal(t, "warning", WarnLevel.String()) + assert.Equal(t, "error", ErrorLevel.String()) + assert.Equal(t, "fatal", FatalLevel.String()) + assert.Equal(t, "panic", PanicLevel.String()) +} + +func TestParseLevel(t *testing.T) { + l, err := ParseLevel("panic") + assert.Nil(t, err) + assert.Equal(t, PanicLevel, l) + + l, err = ParseLevel("fatal") + assert.Nil(t, err) + assert.Equal(t, FatalLevel, l) + + l, err = ParseLevel("error") + assert.Nil(t, err) + assert.Equal(t, ErrorLevel, l) + + l, err = ParseLevel("warn") + assert.Nil(t, err) + assert.Equal(t, WarnLevel, l) + + l, err = ParseLevel("warning") + assert.Nil(t, err) + assert.Equal(t, WarnLevel, l) + + l, err = ParseLevel("info") + assert.Nil(t, err) + assert.Equal(t, InfoLevel, l) + + l, err = ParseLevel("debug") + assert.Nil(t, err) + assert.Equal(t, DebugLevel, l) + + l, err = ParseLevel("invalid") + assert.Equal(t, "not a valid logrus Level: \"invalid\"", err.Error()) +} + +func TestGetSetLevelRace(t *testing.T) { + wg := sync.WaitGroup{} + for i := 0; i < 100; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + if i%2 == 0 { + SetLevel(InfoLevel) + } else { + GetLevel() + } + }(i) + + } + wg.Wait() +} diff --git a/Godeps/_workspace/src/github.com/Sirupsen/logrus/terminal_bsd.go b/Godeps/_workspace/src/github.com/Sirupsen/logrus/terminal_bsd.go new file mode 100644 index 0000000..71f8d67 --- /dev/null +++ b/Godeps/_workspace/src/github.com/Sirupsen/logrus/terminal_bsd.go @@ -0,0 +1,9 @@ +// +build darwin freebsd openbsd netbsd dragonfly + +package logrus + +import "syscall" + +const ioctlReadTermios = syscall.TIOCGETA + +type Termios syscall.Termios diff --git a/Godeps/_workspace/src/github.com/Sirupsen/logrus/terminal_linux.go b/Godeps/_workspace/src/github.com/Sirupsen/logrus/terminal_linux.go new file mode 100644 index 0000000..a2c0b40 --- /dev/null +++ b/Godeps/_workspace/src/github.com/Sirupsen/logrus/terminal_linux.go @@ -0,0 +1,12 @@ +// Based on ssh/terminal: +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package logrus + +import "syscall" + +const ioctlReadTermios = syscall.TCGETS + +type Termios syscall.Termios diff --git a/Godeps/_workspace/src/github.com/Sirupsen/logrus/terminal_notwindows.go b/Godeps/_workspace/src/github.com/Sirupsen/logrus/terminal_notwindows.go new file mode 100644 index 0000000..4bb5376 --- /dev/null +++ b/Godeps/_workspace/src/github.com/Sirupsen/logrus/terminal_notwindows.go @@ -0,0 +1,21 @@ +// Based on ssh/terminal: +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build linux darwin freebsd openbsd netbsd dragonfly + +package logrus + +import ( + "syscall" + "unsafe" +) + +// IsTerminal returns true if the given file descriptor is a terminal. +func IsTerminal() bool { + fd := syscall.Stdout + var termios Termios + _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&termios)), 0, 0, 0) + return err == 0 +} diff --git a/Godeps/_workspace/src/github.com/Sirupsen/logrus/terminal_windows.go b/Godeps/_workspace/src/github.com/Sirupsen/logrus/terminal_windows.go new file mode 100644 index 0000000..2e09f6f --- /dev/null +++ b/Godeps/_workspace/src/github.com/Sirupsen/logrus/terminal_windows.go @@ -0,0 +1,27 @@ +// Based on ssh/terminal: +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build windows + +package logrus + +import ( + "syscall" + "unsafe" +) + +var kernel32 = syscall.NewLazyDLL("kernel32.dll") + +var ( + procGetConsoleMode = kernel32.NewProc("GetConsoleMode") +) + +// IsTerminal returns true if the given file descriptor is a terminal. +func IsTerminal() bool { + fd := syscall.Stdout + var st uint32 + r, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0) + return r != 0 && e == 0 +} diff --git a/Godeps/_workspace/src/github.com/Sirupsen/logrus/text_formatter.go b/Godeps/_workspace/src/github.com/Sirupsen/logrus/text_formatter.go new file mode 100644 index 0000000..17cc298 --- /dev/null +++ b/Godeps/_workspace/src/github.com/Sirupsen/logrus/text_formatter.go @@ -0,0 +1,159 @@ +package logrus + +import ( + "bytes" + "fmt" + "runtime" + "sort" + "strings" + "time" +) + +const ( + nocolor = 0 + red = 31 + green = 32 + yellow = 33 + blue = 34 + gray = 37 +) + +var ( + baseTimestamp time.Time + isTerminal bool +) + +func init() { + baseTimestamp = time.Now() + isTerminal = IsTerminal() +} + +func miniTS() int { + return int(time.Since(baseTimestamp) / time.Second) +} + +type TextFormatter struct { + // Set to true to bypass checking for a TTY before outputting colors. + ForceColors bool + + // Force disabling colors. + DisableColors bool + + // Disable timestamp logging. useful when output is redirected to logging + // system that already adds timestamps. + DisableTimestamp bool + + // Enable logging the full timestamp when a TTY is attached instead of just + // the time passed since beginning of execution. + FullTimestamp bool + + // TimestampFormat to use for display when a full timestamp is printed + TimestampFormat string + + // The fields are sorted by default for a consistent output. For applications + // that log extremely frequently and don't use the JSON formatter this may not + // be desired. + DisableSorting bool +} + +func (f *TextFormatter) Format(entry *Entry) ([]byte, error) { + var keys []string = make([]string, 0, len(entry.Data)) + for k := range entry.Data { + keys = append(keys, k) + } + + if !f.DisableSorting { + sort.Strings(keys) + } + + b := &bytes.Buffer{} + + prefixFieldClashes(entry.Data) + + isColorTerminal := isTerminal && (runtime.GOOS != "windows") + isColored := (f.ForceColors || isColorTerminal) && !f.DisableColors + + timestampFormat := f.TimestampFormat + if timestampFormat == "" { + timestampFormat = DefaultTimestampFormat + } + if isColored { + f.printColored(b, entry, keys, timestampFormat) + } else { + if !f.DisableTimestamp { + f.appendKeyValue(b, "time", entry.Time.Format(timestampFormat)) + } + f.appendKeyValue(b, "level", entry.Level.String()) + f.appendKeyValue(b, "msg", entry.Message) + for _, key := range keys { + f.appendKeyValue(b, key, entry.Data[key]) + } + } + + b.WriteByte('\n') + return b.Bytes(), nil +} + +func (f *TextFormatter) printColored(b *bytes.Buffer, entry *Entry, keys []string, timestampFormat string) { + var levelColor int + switch entry.Level { + case DebugLevel: + levelColor = gray + case WarnLevel: + levelColor = yellow + case ErrorLevel, FatalLevel, PanicLevel: + levelColor = red + default: + levelColor = blue + } + + levelText := strings.ToUpper(entry.Level.String())[0:4] + + if !f.FullTimestamp { + fmt.Fprintf(b, "\x1b[%dm%s\x1b[0m[%04d] %-44s ", levelColor, levelText, miniTS(), entry.Message) + } else { + fmt.Fprintf(b, "\x1b[%dm%s\x1b[0m[%s] %-44s ", levelColor, levelText, entry.Time.Format(timestampFormat), entry.Message) + } + for _, k := range keys { + v := entry.Data[k] + fmt.Fprintf(b, " \x1b[%dm%s\x1b[0m=%+v", levelColor, k, v) + } +} + +func needsQuoting(text string) bool { + for _, ch := range text { + if !((ch >= 'a' && ch <= 'z') || + (ch >= 'A' && ch <= 'Z') || + (ch >= '0' && ch <= '9') || + ch == '-' || ch == '.') { + return false + } + } + return true +} + +func (f *TextFormatter) appendKeyValue(b *bytes.Buffer, key string, value interface{}) { + + b.WriteString(key) + b.WriteByte('=') + + switch value := value.(type) { + case string: + if needsQuoting(value) { + b.WriteString(value) + } else { + fmt.Fprintf(b, "%q", value) + } + case error: + errmsg := value.Error() + if needsQuoting(errmsg) { + b.WriteString(errmsg) + } else { + fmt.Fprintf(b, "%q", value) + } + default: + fmt.Fprint(b, value) + } + + b.WriteByte(' ') +} diff --git a/Godeps/_workspace/src/github.com/Sirupsen/logrus/text_formatter_test.go b/Godeps/_workspace/src/github.com/Sirupsen/logrus/text_formatter_test.go new file mode 100644 index 0000000..e25a44f --- /dev/null +++ b/Godeps/_workspace/src/github.com/Sirupsen/logrus/text_formatter_test.go @@ -0,0 +1,61 @@ +package logrus + +import ( + "bytes" + "errors" + "testing" + "time" +) + +func TestQuoting(t *testing.T) { + tf := &TextFormatter{DisableColors: true} + + checkQuoting := func(q bool, value interface{}) { + b, _ := tf.Format(WithField("test", value)) + idx := bytes.Index(b, ([]byte)("test=")) + cont := bytes.Contains(b[idx+5:], []byte{'"'}) + if cont != q { + if q { + t.Errorf("quoting expected for: %#v", value) + } else { + t.Errorf("quoting not expected for: %#v", value) + } + } + } + + checkQuoting(false, "abcd") + checkQuoting(false, "v1.0") + checkQuoting(false, "1234567890") + checkQuoting(true, "/foobar") + checkQuoting(true, "x y") + checkQuoting(true, "x,y") + checkQuoting(false, errors.New("invalid")) + checkQuoting(true, errors.New("invalid argument")) +} + +func TestTimestampFormat(t *testing.T) { + checkTimeStr := func(format string) { + customFormatter := &TextFormatter{DisableColors: true, TimestampFormat: format} + customStr, _ := customFormatter.Format(WithField("test", "test")) + timeStart := bytes.Index(customStr, ([]byte)("time=")) + timeEnd := bytes.Index(customStr, ([]byte)("level=")) + timeStr := customStr[timeStart+5 : timeEnd-1] + if timeStr[0] == '"' && timeStr[len(timeStr)-1] == '"' { + timeStr = timeStr[1 : len(timeStr)-1] + } + if format == "" { + format = time.RFC3339 + } + _, e := time.Parse(format, (string)(timeStr)) + if e != nil { + t.Errorf("time string \"%s\" did not match provided time format \"%s\": %s", timeStr, format, e) + } + } + + checkTimeStr("2006-01-02T15:04:05.000000000Z07:00") + checkTimeStr("Mon Jan _2 15:04:05 2006") + checkTimeStr("") +} + +// TODO add tests for sorting etc., this requires a parser for the text +// formatter output. diff --git a/Godeps/_workspace/src/github.com/Sirupsen/logrus/writer.go b/Godeps/_workspace/src/github.com/Sirupsen/logrus/writer.go new file mode 100644 index 0000000..1e30b1c --- /dev/null +++ b/Godeps/_workspace/src/github.com/Sirupsen/logrus/writer.go @@ -0,0 +1,31 @@ +package logrus + +import ( + "bufio" + "io" + "runtime" +) + +func (logger *Logger) Writer() *io.PipeWriter { + reader, writer := io.Pipe() + + go logger.writerScanner(reader) + runtime.SetFinalizer(writer, writerFinalizer) + + return writer +} + +func (logger *Logger) writerScanner(reader *io.PipeReader) { + scanner := bufio.NewScanner(reader) + for scanner.Scan() { + logger.Print(scanner.Text()) + } + if err := scanner.Err(); err != nil { + logger.Errorf("Error while reading from Writer: %s", err) + } + reader.Close() +} + +func writerFinalizer(writer *io.PipeWriter) { + writer.Close() +} diff --git a/Godeps/_workspace/src/github.com/Unknwon/com/path.go b/Godeps/_workspace/src/github.com/Unknwon/com/path.go index a501c85..b1e860d 100644 --- a/Godeps/_workspace/src/github.com/Unknwon/com/path.go +++ b/Godeps/_workspace/src/github.com/Unknwon/com/path.go @@ -64,9 +64,9 @@ func GetSrcPath(importPath string) (appPath string, err error) { // it returns error when the variable does not exist. func HomeDir() (home string, err error) { if runtime.GOOS == "windows" { - home = os.Getenv("HOMEDRIVE") + os.Getenv("HOMEPATH") - if home == "" { - home = os.Getenv("USERPROFILE") + home = os.Getenv("USERPROFILE") + if len(home) == 0 { + home = os.Getenv("HOMEDRIVE") + os.Getenv("HOMEPATH") } } else { home = os.Getenv("HOME") diff --git a/Godeps/_workspace/src/github.com/Unknwon/macaron/README.md b/Godeps/_workspace/src/github.com/Unknwon/macaron/README.md index 8b20162..31a9f35 100644 --- a/Godeps/_workspace/src/github.com/Unknwon/macaron/README.md +++ b/Godeps/_workspace/src/github.com/Unknwon/macaron/README.md @@ -5,7 +5,7 @@ Macaron [![Build Status](https://drone.io/github.com/Unknwon/macaron/status.png) Package macaron is a high productive and modular design web framework in Go. -##### Current version: 0.5.4 +##### Current version: 0.6.6 ## Getting Started @@ -86,7 +86,6 @@ There are already many [middlewares](https://github.com/macaron-contrib) to simp ## Credits - Basic design of [Martini](https://github.com/go-martini/martini). -- Router layer of [beego](https://github.com/astaxie/beego). - Logo is modified by [@insionng](https://github.com/insionng) based on [Tribal Dragon](http://xtremeyamazaki.deviantart.com/art/Tribal-Dragon-27005087). ## License diff --git a/Godeps/_workspace/src/github.com/Unknwon/macaron/context.go b/Godeps/_workspace/src/github.com/Unknwon/macaron/context.go index abf7ac9..b012aaf 100644 --- a/Godeps/_workspace/src/github.com/Unknwon/macaron/context.go +++ b/Godeps/_workspace/src/github.com/Unknwon/macaron/context.go @@ -176,11 +176,27 @@ func (ctx *Context) Redirect(location string, status ...int) { http.Redirect(ctx.Resp, ctx.Req.Request, location, code) } -// Query querys form parameter. -func (ctx *Context) Query(name string) string { - if ctx.Req.Form == nil { +// Maximum amount of memory to use when parsing a multipart form. +// Set this to whatever value you prefer; default is 10 MB. +var MaxMemory = int64(1024 * 1024 * 10) + +func (ctx *Context) parseForm() { + if ctx.Req.Form != nil { + return + } + + contentType := ctx.Req.Header.Get("Content-Type") + if (ctx.Req.Method == "POST" || ctx.Req.Method == "PUT") && + len(contentType) > 0 && strings.Contains(contentType, "multipart/form-data") { + ctx.Req.ParseMultipartForm(MaxMemory) + } else { ctx.Req.ParseForm() } +} + +// Query querys form parameter. +func (ctx *Context) Query(name string) string { + ctx.parseForm() return ctx.Req.Form.Get(name) } @@ -191,9 +207,7 @@ func (ctx *Context) QueryTrim(name string) string { // QueryStrings returns a list of results by given query name. func (ctx *Context) QueryStrings(name string) []string { - if ctx.Req.Form == nil { - ctx.Req.ParseForm() - } + ctx.parseForm() vals, ok := ctx.Req.Form[name] if !ok { @@ -229,7 +243,7 @@ func (ctx *Context) Params(name string) string { if len(name) == 0 { return "" } - if name[0] != '*' && name[0] != ':' { + if len(name) > 1 && name[0] != ':' { name = ":" + name } return ctx.params[name] diff --git a/Godeps/_workspace/src/github.com/Unknwon/macaron/context_test.go b/Godeps/_workspace/src/github.com/Unknwon/macaron/context_test.go index c4b4752..9e5a161 100644 --- a/Godeps/_workspace/src/github.com/Unknwon/macaron/context_test.go +++ b/Godeps/_workspace/src/github.com/Unknwon/macaron/context_test.go @@ -166,13 +166,15 @@ func Test_Context(t *testing.T) { }) Convey("Get file", func() { - m.Get("/getfile", func(ctx *Context) { + m.Post("/getfile", func(ctx *Context) { + ctx.Query("") ctx.GetFile("hi") }) resp := httptest.NewRecorder() - req, err := http.NewRequest("GET", "/getfile", nil) + req, err := http.NewRequest("POST", "/getfile", nil) So(err, ShouldBeNil) + req.Header.Set("Content-Type", "multipart/form-data") m.ServeHTTP(resp, req) }) diff --git a/Godeps/_workspace/src/github.com/Unknwon/macaron/gzip.go b/Godeps/_workspace/src/github.com/Unknwon/macaron/gzip.go index 2e935f3..44bafd0 100644 --- a/Godeps/_workspace/src/github.com/Unknwon/macaron/gzip.go +++ b/Godeps/_workspace/src/github.com/Unknwon/macaron/gzip.go @@ -32,10 +32,35 @@ const ( HeaderVary = "Vary" ) +// GzipOptions represents a struct for specifying configuration options for the GZip middleware. +type GzipOptions struct { + // Compression level. Can be DefaultCompression(-1) or any integer value between BestSpeed(1) and BestCompression(9) inclusive. + CompressionLevel int +} + +func isCompressionLevelValid(level int) bool { + return level == gzip.DefaultCompression || + (level >= gzip.BestSpeed && level <= gzip.BestCompression) +} + +func prepareGzipOptions(options []GzipOptions) GzipOptions { + var opt GzipOptions + if len(options) > 0 { + opt = options[0] + } + + if !isCompressionLevelValid(opt.CompressionLevel) { + opt.CompressionLevel = gzip.DefaultCompression + } + return opt +} + // Gziper returns a Handler that adds gzip compression to all requests. // Make sure to include the Gzip middleware above other middleware // that alter the response body (like the render middleware). -func Gziper() Handler { +func Gziper(options ...GzipOptions) Handler { + opt := prepareGzipOptions(options) + return func(ctx *Context) { if !strings.Contains(ctx.Req.Header.Get(HeaderAcceptEncoding), "gzip") { return @@ -45,12 +70,20 @@ func Gziper() Handler { headers.Set(HeaderContentEncoding, "gzip") headers.Set(HeaderVary, HeaderAcceptEncoding) - gz := gzip.NewWriter(ctx.Resp) + // We've made sure compression level is valid in prepareGzipOptions, + // no need to check same error again. + gz, err := gzip.NewWriterLevel(ctx.Resp, opt.CompressionLevel) + if err != nil { + panic(err.Error()) + } defer gz.Close() gzw := gzipResponseWriter{gz, ctx.Resp} ctx.Resp = gzw ctx.MapTo(gzw, (*http.ResponseWriter)(nil)) + if ctx.Render != nil { + ctx.Render.SetResponseWriter(gzw) + } ctx.Next() @@ -68,7 +101,6 @@ func (grw gzipResponseWriter) Write(p []byte) (int, error) { if len(grw.Header().Get(HeaderContentType)) == 0 { grw.Header().Set(HeaderContentType, http.DetectContentType(p)) } - return grw.w.Write(p) } diff --git a/Godeps/_workspace/src/github.com/Unknwon/macaron/gzip_test.go b/Godeps/_workspace/src/github.com/Unknwon/macaron/gzip_test.go index 565eed6..1edf0e5 100644 --- a/Godeps/_workspace/src/github.com/Unknwon/macaron/gzip_test.go +++ b/Godeps/_workspace/src/github.com/Unknwon/macaron/gzip_test.go @@ -29,7 +29,7 @@ func Test_Gzip(t *testing.T) { before := false m := New() - m.Use(Gziper()) + m.Use(Gziper(GzipOptions{-10})) m.Use(func(r http.ResponseWriter) { r.(ResponseWriter).Before(func(rw ResponseWriter) { before = true diff --git a/Godeps/_workspace/src/github.com/Unknwon/macaron/macaron.go b/Godeps/_workspace/src/github.com/Unknwon/macaron/macaron.go index adbe9e3..5f12ad7 100644 --- a/Godeps/_workspace/src/github.com/Unknwon/macaron/macaron.go +++ b/Godeps/_workspace/src/github.com/Unknwon/macaron/macaron.go @@ -29,7 +29,7 @@ import ( "github.com/Unknwon/macaron/inject" ) -const _VERSION = "0.5.4.0318" +const _VERSION = "0.6.6.0728" func Version() string { return _VERSION @@ -83,11 +83,10 @@ func NewWithLogger(out io.Writer) *Macaron { m.Router.m = m m.Map(m.logger) m.Map(defaultReturnHandler()) - m.notFound = func(resp http.ResponseWriter, req *http.Request) { - c := m.createContext(resp, req) - c.handlers = append(c.handlers, http.NotFound) - c.run() - } + m.NotFound(http.NotFound) + m.InternalServerError(func(rw http.ResponseWriter, err error) { + http.Error(rw, err.Error(), 500) + }) return m } diff --git a/Godeps/_workspace/src/github.com/Unknwon/macaron/render.go b/Godeps/_workspace/src/github.com/Unknwon/macaron/render.go index b0558c9..53735dc 100644 --- a/Godeps/_workspace/src/github.com/Unknwon/macaron/render.go +++ b/Godeps/_workspace/src/github.com/Unknwon/macaron/render.go @@ -149,6 +149,7 @@ type ( Render interface { http.ResponseWriter + SetResponseWriter(http.ResponseWriter) RW() http.ResponseWriter JSON(int, interface{}) @@ -321,7 +322,7 @@ func (ts *templateSet) GetDir(name string) string { return ts.dirs[name] } -func prepareOptions(options []RenderOptions) RenderOptions { +func prepareRenderOptions(options []RenderOptions) RenderOptions { var opt RenderOptions if len(options) > 0 { opt = options[0] @@ -401,11 +402,11 @@ func renderHandler(opt RenderOptions, tplSets []string) Handler { // If MACARON_ENV is set to "" or "development" then templates will be recompiled on every request. For more performance, set the // MACARON_ENV environment variable to "production". func Renderer(options ...RenderOptions) Handler { - return renderHandler(prepareOptions(options), []string{}) + return renderHandler(prepareRenderOptions(options), []string{}) } func Renderers(options RenderOptions, tplSets ...string) Handler { - return renderHandler(prepareOptions([]RenderOptions{options}), tplSets) + return renderHandler(prepareRenderOptions([]RenderOptions{options}), tplSets) } type TplRender struct { @@ -417,6 +418,10 @@ type TplRender struct { startTime time.Time } +func (r *TplRender) SetResponseWriter(rw http.ResponseWriter) { + r.ResponseWriter = rw +} + func (r *TplRender) RW() http.ResponseWriter { return r.ResponseWriter } diff --git a/Godeps/_workspace/src/github.com/Unknwon/macaron/render_test.go b/Godeps/_workspace/src/github.com/Unknwon/macaron/render_test.go index 1631873..7eb9747 100644 --- a/Godeps/_workspace/src/github.com/Unknwon/macaron/render_test.go +++ b/Godeps/_workspace/src/github.com/Unknwon/macaron/render_test.go @@ -197,6 +197,7 @@ func Test_Render_HTML(t *testing.T) { Directory: "fixtures/basic", }, "fixtures/basic2")) m.Get("/foobar", func(r Render) { + r.SetResponseWriter(r.RW()) r.HTML(200, "hello", "jeremy") r.SetTemplatePath("", "fixtures/basic2") }) diff --git a/Godeps/_workspace/src/github.com/Unknwon/macaron/return_handler.go b/Godeps/_workspace/src/github.com/Unknwon/macaron/return_handler.go index ea1e044..91e9035 100644 --- a/Godeps/_workspace/src/github.com/Unknwon/macaron/return_handler.go +++ b/Godeps/_workspace/src/github.com/Unknwon/macaron/return_handler.go @@ -32,6 +32,11 @@ func canDeref(val reflect.Value) bool { return val.Kind() == reflect.Interface || val.Kind() == reflect.Ptr } +func isError(val reflect.Value) bool { + _, ok := val.Interface().(error) + return ok +} + func isByteSlice(val reflect.Value) bool { return val.Kind() == reflect.Slice && val.Type().Elem().Kind() == reflect.Uint8 } @@ -39,21 +44,33 @@ func isByteSlice(val reflect.Value) bool { func defaultReturnHandler() ReturnHandler { return func(ctx *Context, vals []reflect.Value) { rv := ctx.GetVal(inject.InterfaceOf((*http.ResponseWriter)(nil))) - res := rv.Interface().(http.ResponseWriter) + resp := rv.Interface().(http.ResponseWriter) var respVal reflect.Value if len(vals) > 1 && vals[0].Kind() == reflect.Int { - res.WriteHeader(int(vals[0].Int())) + resp.WriteHeader(int(vals[0].Int())) respVal = vals[1] } else if len(vals) > 0 { respVal = vals[0] + + if isError(respVal) { + err := respVal.Interface().(error) + if err != nil { + ctx.internalServerError(ctx, err) + } + return + } else if canDeref(respVal) { + if respVal.IsNil() { + return // Ignore nil error + } + } } if canDeref(respVal) { respVal = respVal.Elem() } if isByteSlice(respVal) { - res.Write(respVal.Bytes()) + resp.Write(respVal.Bytes()) } else { - res.Write([]byte(respVal.String())) + resp.Write([]byte(respVal.String())) } } } diff --git a/Godeps/_workspace/src/github.com/Unknwon/macaron/return_handler_test.go b/Godeps/_workspace/src/github.com/Unknwon/macaron/return_handler_test.go index 02325b2..1ee6778 100644 --- a/Godeps/_workspace/src/github.com/Unknwon/macaron/return_handler_test.go +++ b/Godeps/_workspace/src/github.com/Unknwon/macaron/return_handler_test.go @@ -15,6 +15,7 @@ package macaron import ( + "errors" "net/http" "net/http/httptest" "testing" @@ -24,7 +25,7 @@ import ( func Test_Return_Handler(t *testing.T) { Convey("Return with status and body", t, func() { - m := Classic() + m := New() m.Get("/", func() (int, string) { return 418, "i'm a teapot" }) @@ -38,8 +39,40 @@ func Test_Return_Handler(t *testing.T) { So(resp.Body.String(), ShouldEqual, "i'm a teapot") }) + Convey("Return with error", t, func() { + m := New() + m.Get("/", func() error { + return errors.New("what the hell!!!") + }) + + resp := httptest.NewRecorder() + req, err := http.NewRequest("GET", "/", nil) + So(err, ShouldBeNil) + m.ServeHTTP(resp, req) + + So(resp.Code, ShouldEqual, http.StatusInternalServerError) + So(resp.Body.String(), ShouldEqual, "what the hell!!!\n") + + Convey("Return with nil error", func() { + m := New() + m.Get("/", func() error { + return nil + }, func() (int, string) { + return 200, "Awesome" + }) + + resp := httptest.NewRecorder() + req, err := http.NewRequest("GET", "/", nil) + So(err, ShouldBeNil) + m.ServeHTTP(resp, req) + + So(resp.Code, ShouldEqual, http.StatusOK) + So(resp.Body.String(), ShouldEqual, "Awesome") + }) + }) + Convey("Return with pointer", t, func() { - m := Classic() + m := New() m.Get("/", func() *string { str := "hello world" return &str @@ -54,7 +87,7 @@ func Test_Return_Handler(t *testing.T) { }) Convey("Return with byte slice", t, func() { - m := Classic() + m := New() m.Get("/", func() []byte { return []byte("hello world") }) diff --git a/Godeps/_workspace/src/github.com/Unknwon/macaron/router.go b/Godeps/_workspace/src/github.com/Unknwon/macaron/router.go index d2b3945..2d05669 100644 --- a/Godeps/_workspace/src/github.com/Unknwon/macaron/router.go +++ b/Godeps/_workspace/src/github.com/Unknwon/macaron/router.go @@ -18,8 +18,6 @@ import ( "net/http" "strings" "sync" - - "github.com/Unknwon/com" ) var ( @@ -38,22 +36,22 @@ var ( // routeMap represents a thread-safe map for route tree. type routeMap struct { lock sync.RWMutex - routes map[string]map[string]bool + routes map[string]map[string]*Leaf } // NewRouteMap initializes and returns a new routeMap. func NewRouteMap() *routeMap { rm := &routeMap{ - routes: make(map[string]map[string]bool), + routes: make(map[string]map[string]*Leaf), } for m := range _HTTP_METHODS { - rm.routes[m] = make(map[string]bool) + rm.routes[m] = make(map[string]*Leaf) } return rm } -// isExist returns true if a route has been registered. -func (rm *routeMap) isExist(method, pattern string) bool { +// getLeaf returns Leaf object if a route has been registered. +func (rm *routeMap) getLeaf(method, pattern string) *Leaf { rm.lock.RLock() defer rm.lock.RUnlock() @@ -61,11 +59,11 @@ func (rm *routeMap) isExist(method, pattern string) bool { } // add adds new route to route tree map. -func (rm *routeMap) add(method, pattern string) { +func (rm *routeMap) add(method, pattern string, leaf *Leaf) { rm.lock.Lock() defer rm.lock.Unlock() - rm.routes[method][pattern] = true + rm.routes[method][pattern] = leaf } type group struct { @@ -75,34 +73,61 @@ type group struct { // Router represents a Macaron router layer. type Router struct { - m *Macaron - routers map[string]*Tree + m *Macaron + autoHead bool + routers map[string]*Tree *routeMap + namedRoutes map[string]*Leaf - groups []group - notFound http.HandlerFunc + groups []group + notFound http.HandlerFunc + internalServerError func(*Context, error) } func NewRouter() *Router { return &Router{ - routers: make(map[string]*Tree), - routeMap: NewRouteMap(), + routers: make(map[string]*Tree), + routeMap: NewRouteMap(), + namedRoutes: make(map[string]*Leaf), } } +// SetAutoHead sets the value who determines whether add HEAD method automatically +// when GET method is added. Combo router will not be affected by this value. +func (r *Router) SetAutoHead(v bool) { + r.autoHead = v +} + type Params map[string]string // Handle is a function that can be registered to a route to handle HTTP requests. // Like http.HandlerFunc, but has a third parameter for the values of wildcards (variables). type Handle func(http.ResponseWriter, *http.Request, Params) +// Route represents a wrapper of leaf route and upper level router. +type Route struct { + router *Router + leaf *Leaf +} + +// Name sets name of route. +func (r *Route) Name(name string) { + if len(name) == 0 { + panic("route name cannot be empty") + } else if r.router.namedRoutes[name] != nil { + panic("route with given name already exists") + } + r.router.namedRoutes[name] = r.leaf +} + // handle adds new route to the router tree. -func (r *Router) handle(method, pattern string, handle Handle) { +func (r *Router) handle(method, pattern string, handle Handle) *Route { method = strings.ToUpper(method) + var leaf *Leaf // Prevent duplicate routes. - if r.isExist(method, pattern) { - return + if leaf = r.getLeaf(method, pattern); leaf != nil { + return &Route{r, leaf} } // Validate HTTP methods. @@ -123,18 +148,19 @@ func (r *Router) handle(method, pattern string, handle Handle) { // Add to router tree. for m := range methods { if t, ok := r.routers[m]; ok { - t.AddRouter(pattern, handle) + leaf = t.Add(pattern, handle) } else { t := NewTree() - t.AddRouter(pattern, handle) + leaf = t.Add(pattern, handle) r.routers[m] = t } - r.add(m, pattern) + r.add(m, pattern, leaf) } + return &Route{r, leaf} } // Handle registers a new request handle with the given pattern, method and handlers. -func (r *Router) Handle(method string, pattern string, handlers []Handler) { +func (r *Router) Handle(method string, pattern string, handlers []Handler) *Route { if len(r.groups) > 0 { groupPattern := "" h := make([]Handler, 0) @@ -149,7 +175,7 @@ func (r *Router) Handle(method string, pattern string, handlers []Handler) { } validateHandlers(handlers) - r.handle(method, pattern, func(resp http.ResponseWriter, req *http.Request, params Params) { + return r.handle(method, pattern, func(resp http.ResponseWriter, req *http.Request, params Params) { c := r.m.createContext(resp, req) c.params = params c.handlers = make([]Handler, 0, len(r.m.handlers)+len(handlers)) @@ -166,64 +192,70 @@ func (r *Router) Group(pattern string, fn func(), h ...Handler) { } // Get is a shortcut for r.Handle("GET", pattern, handlers) -func (r *Router) Get(pattern string, h ...Handler) { - r.Handle("GET", pattern, h) +func (r *Router) Get(pattern string, h ...Handler) (leaf *Route) { + leaf = r.Handle("GET", pattern, h) + if r.autoHead { + r.Head(pattern, h...) + } + return leaf } // Patch is a shortcut for r.Handle("PATCH", pattern, handlers) -func (r *Router) Patch(pattern string, h ...Handler) { - r.Handle("PATCH", pattern, h) +func (r *Router) Patch(pattern string, h ...Handler) *Route { + return r.Handle("PATCH", pattern, h) } // Post is a shortcut for r.Handle("POST", pattern, handlers) -func (r *Router) Post(pattern string, h ...Handler) { - r.Handle("POST", pattern, h) +func (r *Router) Post(pattern string, h ...Handler) *Route { + return r.Handle("POST", pattern, h) } // Put is a shortcut for r.Handle("PUT", pattern, handlers) -func (r *Router) Put(pattern string, h ...Handler) { - r.Handle("PUT", pattern, h) +func (r *Router) Put(pattern string, h ...Handler) *Route { + return r.Handle("PUT", pattern, h) } // Delete is a shortcut for r.Handle("DELETE", pattern, handlers) -func (r *Router) Delete(pattern string, h ...Handler) { - r.Handle("DELETE", pattern, h) +func (r *Router) Delete(pattern string, h ...Handler) *Route { + return r.Handle("DELETE", pattern, h) } // Options is a shortcut for r.Handle("OPTIONS", pattern, handlers) -func (r *Router) Options(pattern string, h ...Handler) { - r.Handle("OPTIONS", pattern, h) +func (r *Router) Options(pattern string, h ...Handler) *Route { + return r.Handle("OPTIONS", pattern, h) } // Head is a shortcut for r.Handle("HEAD", pattern, handlers) -func (r *Router) Head(pattern string, h ...Handler) { - r.Handle("HEAD", pattern, h) +func (r *Router) Head(pattern string, h ...Handler) *Route { + return r.Handle("HEAD", pattern, h) } // Any is a shortcut for r.Handle("*", pattern, handlers) -func (r *Router) Any(pattern string, h ...Handler) { - r.Handle("*", pattern, h) +func (r *Router) Any(pattern string, h ...Handler) *Route { + return r.Handle("*", pattern, h) } // Route is a shortcut for same handlers but different HTTP methods. // // Example: // m.Route("/", "GET,POST", h) -func (r *Router) Route(pattern, methods string, h ...Handler) { +func (r *Router) Route(pattern, methods string, h ...Handler) (route *Route) { for _, m := range strings.Split(methods, ",") { - r.Handle(strings.TrimSpace(m), pattern, h) + route = r.Handle(strings.TrimSpace(m), pattern, h) } + return route } // Combo returns a combo router. func (r *Router) Combo(pattern string, h ...Handler) *ComboRouter { - return &ComboRouter{r, pattern, h, map[string]bool{}} + return &ComboRouter{r, pattern, h, map[string]bool{}, nil} } // Configurable http.HandlerFunc which is called when no matching route is // found. If it is not set, http.NotFound is used. // Be sure to set 404 response code in your handler. func (r *Router) NotFound(handlers ...Handler) { + validateHandlers(handlers) r.notFound = func(rw http.ResponseWriter, req *http.Request) { c := r.m.createContext(rw, req) c.handlers = append(r.m.handlers, handlers...) @@ -231,16 +263,25 @@ func (r *Router) NotFound(handlers ...Handler) { } } +// Configurable handler which is called when route handler returns +// error. If it is not set, default handler is used. +// Be sure to set 500 response code in your handler. +func (r *Router) InternalServerError(handlers ...Handler) { + validateHandlers(handlers) + r.internalServerError = func(c *Context, err error) { + c.index = 0 + c.handlers = handlers + c.Map(err) + c.run() + } +} + func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) { if t, ok := r.routers[req.Method]; ok { - h, p := t.Match(req.URL.Path) - if h != nil { - if splat, ok := p[":splat"]; ok { - p["*"] = p[":splat"] // Better name. - splatlist := strings.Split(splat, "/") - for k, v := range splatlist { - p[com.ToStr(k)] = v - } + h, p, ok := t.Match(req.URL.Path) + if ok { + if splat, ok := p["*0"]; ok { + p["*"] = splat // Easy name. } h(rw, req, p) return @@ -250,12 +291,23 @@ func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) { r.notFound(rw, req) } +// URLFor builds path part of URL by given pair values. +func (r *Router) URLFor(name string, pairs ...string) string { + leaf, ok := r.namedRoutes[name] + if !ok { + panic("route with given name does not exists: " + name) + } + return leaf.URLPath(pairs...) +} + // ComboRouter represents a combo router. type ComboRouter struct { router *Router pattern string handlers []Handler methods map[string]bool // Registered methods. + + lastRoute *Route } func (cr *ComboRouter) checkMethod(name string) { @@ -265,9 +317,9 @@ func (cr *ComboRouter) checkMethod(name string) { cr.methods[name] = true } -func (cr *ComboRouter) route(fn func(string, ...Handler), method string, h ...Handler) *ComboRouter { +func (cr *ComboRouter) route(fn func(string, ...Handler) *Route, method string, h ...Handler) *ComboRouter { cr.checkMethod(method) - fn(cr.pattern, append(cr.handlers, h...)...) + cr.lastRoute = fn(cr.pattern, append(cr.handlers, h...)...) return cr } @@ -298,3 +350,11 @@ func (cr *ComboRouter) Options(h ...Handler) *ComboRouter { func (cr *ComboRouter) Head(h ...Handler) *ComboRouter { return cr.route(cr.router.Head, "HEAD", h...) } + +// Name sets name of ComboRouter route. +func (cr *ComboRouter) Name(name string) { + if cr.lastRoute == nil { + panic("no corresponding route to be named") + } + cr.lastRoute.Name(name) +} diff --git a/Godeps/_workspace/src/github.com/Unknwon/macaron/router_test.go b/Godeps/_workspace/src/github.com/Unknwon/macaron/router_test.go index f404496..ef9bdab 100644 --- a/Godeps/_workspace/src/github.com/Unknwon/macaron/router_test.go +++ b/Godeps/_workspace/src/github.com/Unknwon/macaron/router_test.go @@ -15,6 +15,7 @@ package macaron import ( + "errors" "net/http" "net/http/httptest" "testing" @@ -24,7 +25,7 @@ import ( func Test_Router_Handle(t *testing.T) { Convey("Register all HTTP methods routes", t, func() { - m := Classic() + m := New() m.Get("/get", func() string { return "GET" }) @@ -107,8 +108,35 @@ func Test_Router_Handle(t *testing.T) { So(resp.Body.String(), ShouldEqual, "ROUTE") }) + Convey("Register with or without auto head", t, func() { + Convey("Without auto head", func() { + m := New() + m.Get("/", func() string { + return "GET" + }) + resp := httptest.NewRecorder() + req, err := http.NewRequest("HEAD", "/", nil) + So(err, ShouldBeNil) + m.ServeHTTP(resp, req) + So(resp.Code, ShouldEqual, 404) + }) + + Convey("With auto head", func() { + m := New() + m.SetAutoHead(true) + m.Get("/", func() string { + return "GET" + }) + resp := httptest.NewRecorder() + req, err := http.NewRequest("HEAD", "/", nil) + So(err, ShouldBeNil) + m.ServeHTTP(resp, req) + So(resp.Code, ShouldEqual, 200) + }) + }) + Convey("Register all HTTP methods routes with combo", t, func() { - m := Classic() + m := New() m.SetURLPrefix("/prefix") m.Use(Renderer()) m.Combo("/", func(ctx *Context) { @@ -151,9 +179,72 @@ func Test_Router_Handle(t *testing.T) { }) } +func Test_Route_Name(t *testing.T) { + Convey("Set route name", t, func() { + m := New() + m.Get("/", func() {}).Name("home") + + defer func() { + So(recover(), ShouldNotBeNil) + }() + m.Get("/", func() {}).Name("home") + }) + + Convey("Set combo router name", t, func() { + m := New() + m.Combo("/").Get(func() {}).Name("home") + + defer func() { + So(recover(), ShouldNotBeNil) + }() + m.Combo("/").Name("home") + }) +} + +func Test_Router_URLFor(t *testing.T) { + Convey("Build URL path", t, func() { + m := New() + m.Get("/user/:id", func() {}).Name("user_id") + m.Get("/user/:id/:name", func() {}).Name("user_id_name") + m.Get("cms_:id_:page.html", func() {}).Name("id_page") + + So(m.URLFor("user_id", "id", "12"), ShouldEqual, "/user/12") + So(m.URLFor("user_id_name", "id", "12", "name", "unknwon"), ShouldEqual, "/user/12/unknwon") + So(m.URLFor("id_page", "id", "12", "page", "profile"), ShouldEqual, "/cms_12_profile.html") + + Convey("Number of pair values does not match", func() { + defer func() { + So(recover(), ShouldNotBeNil) + }() + m.URLFor("user_id", "id") + }) + + Convey("Empty pair value", func() { + defer func() { + So(recover(), ShouldNotBeNil) + }() + m.URLFor("user_id", "", "") + }) + + Convey("Empty route name", func() { + defer func() { + So(recover(), ShouldNotBeNil) + }() + m.Get("/user/:id", func() {}).Name("") + }) + + Convey("Invalid route name", func() { + defer func() { + So(recover(), ShouldNotBeNil) + }() + m.URLFor("404") + }) + }) +} + func Test_Router_Group(t *testing.T) { Convey("Register route group", t, func() { - m := Classic() + m := New() m.Group("/api", func() { m.Group("/v1", func() { m.Get("/list", func() string { @@ -171,7 +262,7 @@ func Test_Router_Group(t *testing.T) { func Test_Router_NotFound(t *testing.T) { Convey("Custom not found handler", t, func() { - m := Classic() + m := New() m.Get("/", func() {}) m.NotFound(func() string { return "Custom not found" @@ -184,9 +275,28 @@ func Test_Router_NotFound(t *testing.T) { }) } +func Test_Router_InternalServerError(t *testing.T) { + Convey("Custom internal server error handler", t, func() { + m := New() + m.Get("/", func() error { + return errors.New("Custom internal server error") + }) + m.InternalServerError(func(rw http.ResponseWriter, err error) { + rw.WriteHeader(500) + rw.Write([]byte(err.Error())) + }) + resp := httptest.NewRecorder() + req, err := http.NewRequest("GET", "/", nil) + So(err, ShouldBeNil) + m.ServeHTTP(resp, req) + So(resp.Code, ShouldEqual, 500) + So(resp.Body.String(), ShouldEqual, "Custom internal server error") + }) +} + func Test_Router_splat(t *testing.T) { Convey("Register router with glob", t, func() { - m := Classic() + m := New() m.Get("/*", func(ctx *Context) string { return ctx.Params("*") }) diff --git a/Godeps/_workspace/src/github.com/Unknwon/macaron/tree.go b/Godeps/_workspace/src/github.com/Unknwon/macaron/tree.go index 7bde5ad..82b21d2 100644 --- a/Godeps/_workspace/src/github.com/Unknwon/macaron/tree.go +++ b/Godeps/_workspace/src/github.com/Unknwon/macaron/tree.go @@ -1,5 +1,4 @@ -// Copyright 2013 Beego Authors -// Copyright 2014 Unknwon +// Copyright 2015 Unknwon // // Licensed under the Apache License, Version 2.0 (the "License"): you may // not use this file except in compliance with the License. You may obtain @@ -15,407 +14,366 @@ package macaron -// NOTE: last sync 0c93364 on Dec 19, 2014. - import ( - "path" "regexp" "strings" "github.com/Unknwon/com" ) -type leafInfo struct { - // Names of wildcards that lead to this leaf. - // eg, ["id" "name"] for the wildcard ":id" and ":name". - wildcards []string - // Not nil if the leaf is regexp. - regexps *regexp.Regexp - handle Handle +type patternType int8 + +const ( + _PATTERN_STATIC patternType = iota // /home + _PATTERN_REGEXP // /:id([0-9]+) + _PATTERN_PATH_EXT // /*.* + _PATTERN_HOLDER // /:user + _PATTERN_MATCH_ALL // /* +) + +// Leaf represents a leaf route information. +type Leaf struct { + parent *Tree + + typ patternType + pattern string + rawPattern string // Contains wildcard instead of regexp + wildcards []string + reg *regexp.Regexp + optional bool + + handle Handle } -func (leaf *leafInfo) match(wildcardValues []string) (ok bool, params Params) { - if leaf.regexps == nil { - if len(wildcardValues) == 0 && len(leaf.wildcards) > 0 { - if com.IsSliceContainsStr(leaf.wildcards, ":") { - params = make(map[string]string) - j := 0 - for _, v := range leaf.wildcards { - if v == ":" { - continue - } - params[v] = "" - j += 1 - } - return true, params - } - return false, nil - } else if len(wildcardValues) == 0 { - return true, nil // Static path. - } +var wildcardPattern = regexp.MustCompile(`:[a-zA-Z0-9]+`) - // Match * - if len(leaf.wildcards) == 1 && leaf.wildcards[0] == ":splat" { - params = make(map[string]string) - params[":splat"] = path.Join(wildcardValues...) - return true, params - } +func isSpecialRegexp(pattern, regStr string, pos []int) bool { + return len(pattern) >= pos[1]+len(regStr) && pattern[pos[1]:pos[1]+len(regStr)] == regStr +} - // Match *.* - if len(leaf.wildcards) == 3 && leaf.wildcards[0] == "." { - params = make(map[string]string) - lastone := wildcardValues[len(wildcardValues)-1] - strs := strings.SplitN(lastone, ".", 2) - if len(strs) == 2 { - params[":ext"] = strs[1] - } else { - params[":ext"] = "" - } - params[":path"] = path.Join(wildcardValues[:len(wildcardValues)-1]...) + "/" + strs[0] - return true, params - } +// getNextWildcard tries to find next wildcard and update pattern with corresponding regexp. +func getNextWildcard(pattern string) (wildcard, _ string) { + pos := wildcardPattern.FindStringIndex(pattern) + if pos == nil { + return "", pattern + } + wildcard = pattern[pos[0]:pos[1]] - // Match :id - params = make(map[string]string) - j := 0 - for _, v := range leaf.wildcards { - if v == ":" { - continue - } - if v == "." { - lastone := wildcardValues[len(wildcardValues)-1] - strs := strings.SplitN(lastone, ".", 2) - if len(strs) == 2 { - params[":ext"] = strs[1] - } else { - params[":ext"] = "" - } - if len(wildcardValues[j:]) == 1 { - params[":path"] = strs[0] - } else { - params[":path"] = path.Join(wildcardValues[j:]...) + "/" + strs[0] - } - return true, params - } - if len(wildcardValues) <= j { - return false, nil - } - params[v] = wildcardValues[j] - j++ - } - if len(params) != len(wildcardValues) { - return false, nil + // Reach last character or no regexp is given. + if len(pattern) == pos[1] { + return wildcard, strings.Replace(pattern, wildcard, `(.+)`, 1) + } else if pattern[pos[1]] != '(' { + switch { + case isSpecialRegexp(pattern, ":int", pos): + pattern = strings.Replace(pattern, ":int", "([0-9]+)", 1) + case isSpecialRegexp(pattern, ":string", pos): + pattern = strings.Replace(pattern, ":string", "([\\w]+)", 1) + default: + return wildcard, strings.Replace(pattern, wildcard, `(.+)`, 1) } - return true, params } - if !leaf.regexps.MatchString(path.Join(wildcardValues...)) { - return false, nil + // Cut out placeholder directly. + return wildcard, pattern[:pos[0]] + pattern[pos[1]:] +} + +func getWildcards(pattern string) (string, []string) { + wildcards := make([]string, 0, 2) + + // Keep getting next wildcard until nothing is left. + var wildcard string + for { + wildcard, pattern = getNextWildcard(pattern) + if len(wildcard) > 0 { + wildcards = append(wildcards, wildcard) + } else { + break + } } - params = make(map[string]string) - matches := leaf.regexps.FindStringSubmatch(path.Join(wildcardValues...)) - for i, match := range matches[1:] { - params[leaf.wildcards[i]] = match + + return pattern, wildcards +} + +// getRawPattern removes all regexp but keeps wildcards for building URL path. +func getRawPattern(rawPattern string) string { + rawPattern = strings.Replace(rawPattern, ":int", "", -1) + rawPattern = strings.Replace(rawPattern, ":string", "", -1) + + for { + startIdx := strings.Index(rawPattern, "(") + if startIdx == -1 { + break + } + + closeIdx := strings.Index(rawPattern, ")") + if closeIdx > -1 { + rawPattern = rawPattern[:startIdx] + rawPattern[closeIdx+1:] + } } - return true, params + return rawPattern } -// Tree represents a router tree for Macaron instance. -type Tree struct { - fixroutes map[string]*Tree - wildcard *Tree - leaves []*leafInfo +func checkPattern(pattern string) (typ patternType, rawPattern string, wildcards []string, reg *regexp.Regexp) { + pattern = strings.TrimLeft(pattern, "?") + rawPattern = getRawPattern(pattern) + + if pattern == "*" { + typ = _PATTERN_MATCH_ALL + } else if pattern == "*.*" { + typ = _PATTERN_PATH_EXT + } else if strings.Contains(pattern, ":") { + typ = _PATTERN_REGEXP + pattern, wildcards = getWildcards(pattern) + if pattern == "(.+)" { + typ = _PATTERN_HOLDER + } else { + reg = regexp.MustCompile(pattern) + } + } + return typ, rawPattern, wildcards, reg } -// NewTree initializes and returns a router tree. -func NewTree() *Tree { - return &Tree{ - fixroutes: make(map[string]*Tree), +func NewLeaf(parent *Tree, pattern string, handle Handle) *Leaf { + typ, rawPattern, wildcards, reg := checkPattern(pattern) + optional := false + if len(pattern) > 0 && pattern[0] == '?' { + optional = true } + return &Leaf{parent, typ, pattern, rawPattern, wildcards, reg, optional, handle} } -// splitPath splites patthen into parts. -// -// Examples: -// "/" -> [] -// "/admin" -> ["admin"] -// "/admin/" -> ["admin"] -// "/admin/users" -> ["admin", "users"] -func splitPath(pattern string) []string { - if len(pattern) == 0 { - return []string{} +// URLPath build path part of URL by given pair values. +func (l *Leaf) URLPath(pairs ...string) string { + if len(pairs)%2 != 0 { + panic("number of pairs does not match") } - elements := strings.Split(pattern, "/") - if elements[0] == "" { - elements = elements[1:] + urlPath := l.rawPattern + parent := l.parent + for parent != nil { + urlPath = parent.rawPattern + "/" + urlPath + parent = parent.parent } - if elements[len(elements)-1] == "" { - elements = elements[:len(elements)-1] + for i := 0; i < len(pairs); i += 2 { + if len(pairs[i]) == 0 { + panic("pair value cannot be empty: " + com.ToStr(i)) + } else if pairs[i][0] != ':' && pairs[i] != "*" && pairs[i] != "*.*" { + pairs[i] = ":" + pairs[i] + } + urlPath = strings.Replace(urlPath, pairs[i], pairs[i+1], 1) } - return elements + return urlPath } -// AddRouter adds a new route to router tree. -func (t *Tree) AddRouter(pattern string, handle Handle) { - t.addSegments(splitPath(pattern), handle, nil, "") +// Tree represents a router tree in Macaron. +type Tree struct { + parent *Tree + + typ patternType + pattern string + rawPattern string + wildcards []string + reg *regexp.Regexp + + subtrees []*Tree + leaves []*Leaf } -// splitSegment splits segment into parts. -// -// Examples: -// "admin" -> false, nil, "" -// ":id" -> true, [:id], "" -// "?:id" -> true, [: :id], "" : meaning can empty -// ":id:int" -> true, [:id], ([0-9]+) -// ":name:string" -> true, [:name], ([\w]+) -// ":id([0-9]+)" -> true, [:id], ([0-9]+) -// ":id([0-9]+)_:name" -> true, [:id :name], ([0-9]+)_(.+) -// "cms_:id_:page.html" -> true, [:id :page], cms_(.+)_(.+).html -// "*" -> true, [:splat], "" -// "*.*" -> true,[. :path :ext], "" . meaning separator -func splitSegment(key string) (bool, []string, string) { - if strings.HasPrefix(key, "*") { - if key == "*.*" { - return true, []string{".", ":path", ":ext"}, "" - } else { - return true, []string{":splat"}, "" +func NewSubtree(parent *Tree, pattern string) *Tree { + typ, rawPattern, wildcards, reg := checkPattern(pattern) + return &Tree{parent, typ, pattern, rawPattern, wildcards, reg, make([]*Tree, 0, 5), make([]*Leaf, 0, 5)} +} + +func NewTree() *Tree { + return NewSubtree(nil, "") +} + +func (t *Tree) addLeaf(pattern string, handle Handle) *Leaf { + for i := 0; i < len(t.leaves); i++ { + if t.leaves[i].pattern == pattern { + return t.leaves[i] } } - if strings.ContainsAny(key, ":") { - var paramsNum int - var out []rune - var start bool - var startexp bool - var param []rune - var expt []rune - var skipnum int - params := []string{} - reg := regexp.MustCompile(`[a-zA-Z0-9]+`) - for i, v := range key { - if skipnum > 0 { - skipnum -= 1 - continue - } - if start { - //:id:int and :name:string - if v == ':' { - if len(key) >= i+4 { - if key[i+1:i+4] == "int" { - out = append(out, []rune("([0-9]+)")...) - params = append(params, ":"+string(param)) - start = false - startexp = false - skipnum = 3 - param = make([]rune, 0) - paramsNum += 1 - continue - } - } - if len(key) >= i+7 { - if key[i+1:i+7] == "string" { - out = append(out, []rune(`([\w]+)`)...) - params = append(params, ":"+string(param)) - paramsNum += 1 - start = false - startexp = false - skipnum = 6 - param = make([]rune, 0) - continue - } - } - } - // params only support a-zA-Z0-9 - if reg.MatchString(string(v)) { - param = append(param, v) - continue - } - if v != '(' { - out = append(out, []rune(`(.+)`)...) - params = append(params, ":"+string(param)) - param = make([]rune, 0) - paramsNum += 1 - start = false - startexp = false - } - } - if startexp { - if v != ')' { - expt = append(expt, v) - continue - } - } - if v == ':' { - param = make([]rune, 0) - start = true - } else if v == '(' { - startexp = true - start = false - params = append(params, ":"+string(param)) - paramsNum += 1 - expt = make([]rune, 0) - expt = append(expt, '(') - } else if v == ')' { - startexp = false - expt = append(expt, ')') - out = append(out, expt...) - param = make([]rune, 0) - } else if v == '?' { - params = append(params, ":") - } else { - out = append(out, v) - } + + leaf := NewLeaf(t, pattern, handle) + + // Add exact same leaf to grandparent/parent level without optional. + if leaf.optional { + parent := leaf.parent + if parent.parent != nil { + parent.parent.addLeaf(parent.pattern, handle) + } else { + parent.addLeaf("", handle) // Root tree can add as empty pattern. } - if len(param) > 0 { - if paramsNum > 0 { - out = append(out, []rune(`(.+)`)...) - } - params = append(params, ":"+string(param)) + } + + i := 0 + for ; i < len(t.leaves); i++ { + if leaf.typ < t.leaves[i].typ { + break } - return true, params, string(out) + } + + if i == len(t.leaves) { + t.leaves = append(t.leaves, leaf) } else { - return false, nil, "" + t.leaves = append(t.leaves[:i], append([]*Leaf{leaf}, t.leaves[i:]...)...) } + return leaf } -// addSegments add segments to the router tree. -func (t *Tree) addSegments(segments []string, handle Handle, wildcards []string, reg string) { - // Fixed root route. - if len(segments) == 0 { - if reg != "" { - filterCards := make([]string, 0, len(wildcards)) - for _, v := range wildcards { - if v == ":" || v == "." { - continue - } - filterCards = append(filterCards, v) - } - t.leaves = append(t.leaves, &leafInfo{ - handle: handle, - wildcards: filterCards, - regexps: regexp.MustCompile("^" + reg + "$"), - }) - } else { - t.leaves = append(t.leaves, &leafInfo{ - handle: handle, - wildcards: wildcards, - }) +func (t *Tree) addSubtree(segment, pattern string, handle Handle) *Leaf { + for i := 0; i < len(t.subtrees); i++ { + if t.subtrees[i].pattern == segment { + return t.subtrees[i].addNextSegment(pattern, handle) } - return } - seg := segments[0] - iswild, params, regexpStr := splitSegment(seg) - //for the router /login/*/access match /login/2009/11/access - if !iswild && com.IsSliceContainsStr(wildcards, ":splat") { - iswild = true - regexpStr = seg - } - if seg == "*" && len(wildcards) > 0 && reg == "" { - iswild = true - regexpStr = "(.+)" - } - if iswild { - if t.wildcard == nil { - t.wildcard = NewTree() + subtree := NewSubtree(t, segment) + i := 0 + for ; i < len(t.subtrees); i++ { + if subtree.typ < t.subtrees[i].typ { + break } - if regexpStr != "" { - if reg == "" { - rr := "" - for _, w := range wildcards { - if w == "." || w == ":" { - continue - } - if w == ":splat" { - rr = rr + "(.+)/" - } else { - rr = rr + "([^/]+)/" - } - } - regexpStr = rr + regexpStr - } else { - regexpStr = "/" + regexpStr - } - } else if reg != "" { - if seg == "*.*" { - regexpStr = "/([^.]+).(.+)" - } else { - for _, w := range params { - if w == "." || w == ":" { - continue - } - regexpStr = "/([^/]+)" + regexpStr - } - } - } - t.wildcard.addSegments(segments[1:], handle, append(wildcards, params...), reg+regexpStr) + } + + if i == len(t.subtrees) { + t.subtrees = append(t.subtrees, subtree) } else { - subTree, ok := t.fixroutes[seg] - if !ok { - subTree = NewTree() - t.fixroutes[seg] = subTree - } - subTree.addSegments(segments[1:], handle, wildcards, reg) + t.subtrees = append(t.subtrees[:i], append([]*Tree{subtree}, t.subtrees[i:]...)...) } + return subtree.addNextSegment(pattern, handle) } -func (t *Tree) match(segments []string, wildcardValues []string) (handle Handle, params Params) { - // Handle leaf nodes. - if len(segments) == 0 { - for _, l := range t.leaves { - if ok, pa := l.match(wildcardValues); ok { - return l.handle, pa +func (t *Tree) addNextSegment(pattern string, handle Handle) *Leaf { + pattern = strings.TrimPrefix(pattern, "/") + + i := strings.Index(pattern, "/") + if i == -1 { + return t.addLeaf(pattern, handle) + } + return t.addSubtree(pattern[:i], pattern[i+1:], handle) +} + +func (t *Tree) Add(pattern string, handle Handle) *Leaf { + pattern = strings.TrimSuffix(pattern, "/") + return t.addNextSegment(pattern, handle) +} + +func (t *Tree) matchLeaf(globLevel int, url string, params Params) (Handle, bool) { + for i := 0; i < len(t.leaves); i++ { + switch t.leaves[i].typ { + case _PATTERN_STATIC: + if t.leaves[i].pattern == url { + return t.leaves[i].handle, true } - } - if t.wildcard != nil { - for _, l := range t.wildcard.leaves { - if ok, pa := l.match(wildcardValues); ok { - return l.handle, pa - } + case _PATTERN_REGEXP: + results := t.leaves[i].reg.FindStringSubmatch(url) + // Number of results and wildcasrd should be exact same. + if len(results)-1 != len(t.leaves[i].wildcards) { + break } + for j := 0; j < len(t.leaves[i].wildcards); j++ { + params[t.leaves[i].wildcards[j]] = results[j+1] + } + return t.leaves[i].handle, true + case _PATTERN_PATH_EXT: + j := strings.LastIndex(url, ".") + if j > -1 { + params[":path"] = url[:j] + params[":ext"] = url[j+1:] + } else { + params[":path"] = url + } + return t.leaves[i].handle, true + case _PATTERN_HOLDER: + params[t.leaves[i].wildcards[0]] = url + return t.leaves[i].handle, true + case _PATTERN_MATCH_ALL: + params["*"] = url + params["*"+com.ToStr(globLevel)] = url + return t.leaves[i].handle, true } - return nil, nil } + return nil, false +} - seg, segs := segments[0], segments[1:] - - subTree, ok := t.fixroutes[seg] - if ok { - handle, params = subTree.match(segs, wildcardValues) - } else if len(segs) == 0 { //.json .xml - if subindex := strings.LastIndex(seg, "."); subindex != -1 { - subTree, ok = t.fixroutes[seg[:subindex]] - if ok { - handle, params = subTree.match(segs, wildcardValues) - if handle != nil { - if params == nil { - params = make(map[string]string) - } - params[":ext"] = seg[subindex+1:] - return handle, params +func (t *Tree) matchSubtree(globLevel int, segment, url string, params Params) (Handle, bool) { + for i := 0; i < len(t.subtrees); i++ { + switch t.subtrees[i].typ { + case _PATTERN_STATIC: + if t.subtrees[i].pattern == segment { + if handle, ok := t.subtrees[i].matchNextSegment(globLevel, url, params); ok { + return handle, true } } + case _PATTERN_REGEXP: + results := t.subtrees[i].reg.FindStringSubmatch(segment) + if len(results)-1 != len(t.subtrees[i].wildcards) { + break + } + + for j := 0; j < len(t.subtrees[i].wildcards); j++ { + params[t.subtrees[i].wildcards[j]] = results[j+1] + } + if handle, ok := t.subtrees[i].matchNextSegment(globLevel, url, params); ok { + return handle, true + } + case _PATTERN_HOLDER: + if handle, ok := t.subtrees[i].matchNextSegment(globLevel+1, url, params); ok { + params[t.subtrees[i].wildcards[0]] = segment + return handle, true + } + case _PATTERN_MATCH_ALL: + if handle, ok := t.subtrees[i].matchNextSegment(globLevel+1, url, params); ok { + params["*"+com.ToStr(globLevel)] = segment + return handle, true + } } } - if handle == nil && t.wildcard != nil { - handle, params = t.wildcard.match(segs, append(wildcardValues, seg)) - } - if handle == nil { - for _, l := range t.leaves { - if ok, pa := l.match(append(wildcardValues, segments...)); ok { - return l.handle, pa + + if len(t.leaves) > 0 { + leaf := t.leaves[len(t.leaves)-1] + if leaf.typ == _PATTERN_PATH_EXT { + url = segment + "/" + url + j := strings.LastIndex(url, ".") + if j > -1 { + params[":path"] = url[:j] + params[":ext"] = url[j+1:] + } else { + params[":path"] = url } + return leaf.handle, true + } else if leaf.typ == _PATTERN_MATCH_ALL { + params["*"] = segment + "/" + url + params["*"+com.ToStr(globLevel)] = segment + "/" + url + return leaf.handle, true } } - return handle, params + return nil, false } -// Match returns Handle and params if any route is matched. -func (t *Tree) Match(pattern string) (Handle, Params) { - if len(pattern) == 0 || pattern[0] != '/' { - return nil, nil +func (t *Tree) matchNextSegment(globLevel int, url string, params Params) (Handle, bool) { + i := strings.Index(url, "/") + if i == -1 { + return t.matchLeaf(globLevel, url, params) } + return t.matchSubtree(globLevel, url[:i], url[i+1:], params) +} + +func (t *Tree) Match(url string) (Handle, Params, bool) { + url = strings.TrimPrefix(url, "/") + url = strings.TrimSuffix(url, "/") + params := make(Params) + handle, ok := t.matchNextSegment(0, url, params) + return handle, params, ok +} - return t.match(splitPath(pattern), nil) +// MatchTest returns true if given URL is matched by given pattern. +func MatchTest(pattern, url string) bool { + t := NewTree() + t.Add(pattern, nil) + _, _, ok := t.Match(url) + return ok } diff --git a/Godeps/_workspace/src/github.com/Unknwon/macaron/tree_test.go b/Godeps/_workspace/src/github.com/Unknwon/macaron/tree_test.go index c814416..0d01db0 100644 --- a/Godeps/_workspace/src/github.com/Unknwon/macaron/tree_test.go +++ b/Godeps/_workspace/src/github.com/Unknwon/macaron/tree_test.go @@ -1,4 +1,4 @@ -// Copyright 2014 Unknwon +// Copyright 2015 Unknwon // // Licensed under the Apache License, Version 2.0 (the "License"): you may // not use this file except in compliance with the License. You may obtain @@ -15,98 +15,223 @@ package macaron import ( - // "net/http" "strings" "testing" . "github.com/smartystreets/goconvey/convey" ) -func Test_splitSegment(t *testing.T) { +func Test_getWildcards(t *testing.T) { type result struct { - Ok bool - Parts []string - Regex string + pattern string + wildcards string } cases := map[string]result{ - "admin": result{false, nil, ""}, - ":id": result{true, []string{":id"}, ""}, - "?:id": result{true, []string{":", ":id"}, ""}, - ":id:int": result{true, []string{":id"}, "([0-9]+)"}, - ":name:string": result{true, []string{":name"}, `([\w]+)`}, - ":id([0-9]+)": result{true, []string{":id"}, "([0-9]+)"}, - ":id([0-9]+)_:name": result{true, []string{":id", ":name"}, "([0-9]+)_(.+)"}, - "cms_:id_:page.html": result{true, []string{":id", ":page"}, "cms_(.+)_(.+).html"}, - "*": result{true, []string{":splat"}, ""}, - "*.*": result{true, []string{".", ":path", ":ext"}, ""}, + "admin": result{"admin", ""}, + ":id": result{"(.+)", ":id"}, + ":id:int": result{"([0-9]+)", ":id"}, + ":id([0-9]+)": result{"([0-9]+)", ":id"}, + ":id([0-9]+)_:name": result{"([0-9]+)_(.+)", ":id :name"}, + "cms_:id_:page.html": result{"cms_(.+)_(.+).html", ":id :page"}, + "cms_:id:int_:page:string.html": result{"cms_([0-9]+)_([\\w]+).html", ":id :page"}, + "*": result{"*", ""}, + "*.*": result{"*.*", ""}, } - Convey("Splits segment into parts", t, func() { + Convey("Get wildcards", t, func() { for key, result := range cases { - ok, parts, regex := splitSegment(key) - So(ok, ShouldEqual, result.Ok) - if result.Parts == nil { - So(parts, ShouldBeNil) - } else { - So(parts, ShouldNotBeNil) - So(strings.Join(parts, " "), ShouldEqual, strings.Join(result.Parts, " ")) - } - So(regex, ShouldEqual, result.Regex) + pattern, wildcards := getWildcards(key) + So(pattern, ShouldEqual, result.pattern) + So(strings.Join(wildcards, " "), ShouldEqual, result.wildcards) } }) } -func Test_Tree_Match(t *testing.T) { - type result struct { - pattern string - reqUrl string - params map[string]string +func Test_getRawPattern(t *testing.T) { + cases := map[string]string{ + "admin": "admin", + ":id": ":id", + ":id:int": ":id", + ":id([0-9]+)": ":id", + ":id([0-9]+)_:name": ":id_:name", + "cms_:id_:page.html": "cms_:id_:page.html", + "cms_:id:int_:page:string.html": "cms_:id_:page.html", + "cms_:id([0-9]+)_:page([\\w]+).html": "cms_:id_:page.html", + "*": "*", + "*.*": "*.*", } + Convey("Get raw pattern", t, func() { + for k, v := range cases { + So(getRawPattern(k), ShouldEqual, v) + } + }) +} - cases := []result{ - {"/:id", "/123", map[string]string{":id": "123"}}, - {"/hello/?:id", "/hello", map[string]string{":id": ""}}, - {"/", "/", nil}, - {"", "", nil}, - {"/customer/login", "/customer/login", nil}, - {"/customer/login", "/customer/login.json", map[string]string{":ext": "json"}}, - {"/*", "/customer/123", map[string]string{":splat": "customer/123"}}, - {"/*", "/customer/2009/12/11", map[string]string{":splat": "customer/2009/12/11"}}, - {"/aa/*/bb", "/aa/2009/bb", map[string]string{":splat": "2009"}}, - {"/cc/*/dd", "/cc/2009/11/dd", map[string]string{":splat": "2009/11"}}, - {"/ee/:year/*/ff", "/ee/2009/11/ff", map[string]string{":year": "2009", ":splat": "11"}}, - {"/thumbnail/:size/uploads/*", "/thumbnail/100x100/uploads/items/2014/04/20/dPRCdChkUd651t1Hvs18.jpg", - map[string]string{":size": "100x100", ":splat": "items/2014/04/20/dPRCdChkUd651t1Hvs18.jpg"}}, - {"/*.*", "/nice/api.json", map[string]string{":path": "nice/api", ":ext": "json"}}, - {"/:name/*.*", "/nice/api.json", map[string]string{":name": "nice", ":path": "api", ":ext": "json"}}, - {"/:name/test/*.*", "/nice/test/api.json", map[string]string{":name": "nice", ":path": "api", ":ext": "json"}}, - {"/dl/:width:int/:height:int/*.*", "/dl/48/48/05ac66d9bda00a3acf948c43e306fc9a.jpg", - map[string]string{":width": "48", ":height": "48", ":ext": "jpg", ":path": "05ac66d9bda00a3acf948c43e306fc9a"}}, - {"/v1/shop/:id:int", "/v1/shop/123", map[string]string{":id": "123"}}, - {"/:year:int/:month:int/:id/:endid", "/1111/111/aaa/aaa", map[string]string{":year": "1111", ":month": "111", ":id": "aaa", ":endid": "aaa"}}, - {"/v1/shop/:id/:name", "/v1/shop/123/nike", map[string]string{":id": "123", ":name": "nike"}}, - {"/v1/shop/:id/account", "/v1/shop/123/account", map[string]string{":id": "123"}}, - {"/v1/shop/:name:string", "/v1/shop/nike", map[string]string{":name": "nike"}}, - {"/v1/shop/:id([0-9]+)", "/v1/shop//123", map[string]string{":id": "123"}}, - {"/v1/shop/:id([0-9]+)_:name", "/v1/shop/123_nike", map[string]string{":id": "123", ":name": "nike"}}, - {"/v1/shop/:id(.+)_cms.html", "/v1/shop/123_cms.html", map[string]string{":id": "123"}}, - {"/v1/shop/cms_:id(.+)_:page(.+).html", "/v1/shop/cms_123_1.html", map[string]string{":id": "123", ":page": "1"}}, - {"/v1/:v/cms/aaa_:id(.+)_:page(.+).html", "/v1/2/cms/aaa_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}}, - {"/v1/:v/cms_:id(.+)_:page(.+).html", "/v1/2/cms_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}}, - {"/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", "/v1/2_cms/ttt_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}}, - } +func Test_Tree_Match(t *testing.T) { + Convey("Match route in tree", t, func() { + Convey("Match static routes", func() { + t := NewTree() + So(t.Add("/", nil), ShouldNotBeNil) + So(t.Add("/user", nil), ShouldNotBeNil) + So(t.Add("/user/unknwon", nil), ShouldNotBeNil) + So(t.Add("/user/unknwon/profile", nil), ShouldNotBeNil) + + So(t.Add("/", nil), ShouldNotBeNil) + + _, _, ok := t.Match("/") + So(ok, ShouldBeTrue) + _, _, ok = t.Match("/user") + So(ok, ShouldBeTrue) + _, _, ok = t.Match("/user/unknwon") + So(ok, ShouldBeTrue) + _, _, ok = t.Match("/user/unknwon/profile") + So(ok, ShouldBeTrue) - Convey("Match routers in tree", t, func() { - for _, c := range cases { + _, _, ok = t.Match("/404") + So(ok, ShouldBeFalse) + }) + + Convey("Match optional routes", func() { t := NewTree() - t.AddRouter(c.pattern, nil) - _, params := t.Match(c.reqUrl) - if params != nil { - for k, v := range c.params { - vv, ok := params[k] - So(ok, ShouldBeTrue) - So(vv, ShouldEqual, v) - } - } - } + So(t.Add("/?:user", nil), ShouldNotBeNil) + So(t.Add("/user/?:name", nil), ShouldNotBeNil) + So(t.Add("/user/list/?:page:int", nil), ShouldNotBeNil) + + _, params, ok := t.Match("/") + So(ok, ShouldBeTrue) + So(params[":user"], ShouldBeEmpty) + _, params, ok = t.Match("/unknwon") + So(ok, ShouldBeTrue) + So(params[":user"], ShouldEqual, "unknwon") + + _, params, ok = t.Match("/user") + So(ok, ShouldBeTrue) + So(params[":name"], ShouldBeEmpty) + _, params, ok = t.Match("/user/unknwon") + So(ok, ShouldBeTrue) + So(params[":name"], ShouldEqual, "unknwon") + + _, params, ok = t.Match("/user/list/") + So(ok, ShouldBeTrue) + So(params[":page"], ShouldBeEmpty) + _, params, ok = t.Match("/user/list/123") + So(ok, ShouldBeTrue) + So(params[":page"], ShouldEqual, "123") + }) + + Convey("Match with regexp", func() { + t := NewTree() + So(t.Add("/v1/:year:int/6/23", nil), ShouldNotBeNil) + So(t.Add("/v2/2015/:month:int/23", nil), ShouldNotBeNil) + So(t.Add("/v3/2015/6/:day:int", nil), ShouldNotBeNil) + + _, params, ok := t.Match("/v1/2015/6/23") + So(ok, ShouldBeTrue) + So(MatchTest("/v1/:year:int/6/23", "/v1/2015/6/23"), ShouldBeTrue) + So(params[":year"], ShouldEqual, "2015") + _, _, ok = t.Match("/v1/year/6/23") + So(ok, ShouldBeFalse) + So(MatchTest("/v1/:year:int/6/23", "/v1/year/6/23"), ShouldBeFalse) + + _, params, ok = t.Match("/v2/2015/6/23") + So(ok, ShouldBeTrue) + So(params[":month"], ShouldEqual, "6") + _, _, ok = t.Match("/v2/2015/month/23") + So(ok, ShouldBeFalse) + + _, params, ok = t.Match("/v3/2015/6/23") + So(ok, ShouldBeTrue) + So(params[":day"], ShouldEqual, "23") + _, _, ok = t.Match("/v2/2015/6/day") + So(ok, ShouldBeFalse) + + So(t.Add("/v1/shop/cms_:id(.+)_:page(.+).html", nil), ShouldNotBeNil) + So(t.Add("/v1/:v/cms/aaa_:id(.+)_:page(.+).html", nil), ShouldNotBeNil) + So(t.Add("/v1/:v/cms_:id(.+)_:page(.+).html", nil), ShouldNotBeNil) + So(t.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page:string.html", nil), ShouldNotBeNil) + + _, params, ok = t.Match("/v1/shop/cms_123_1.html") + So(ok, ShouldBeTrue) + So(params[":id"], ShouldEqual, "123") + So(params[":page"], ShouldEqual, "1") + + _, params, ok = t.Match("/v1/2/cms/aaa_124_2.html") + So(ok, ShouldBeTrue) + So(params[":v"], ShouldEqual, "2") + So(params[":id"], ShouldEqual, "124") + So(params[":page"], ShouldEqual, "2") + + _, params, ok = t.Match("/v1/3/cms_125_3.html") + So(ok, ShouldBeTrue) + So(params[":v"], ShouldEqual, "3") + So(params[":id"], ShouldEqual, "125") + So(params[":page"], ShouldEqual, "3") + + _, params, ok = t.Match("/v1/4_cms/ttt_126_4.html") + So(ok, ShouldBeTrue) + So(params[":v"], ShouldEqual, "4") + So(params[":id"], ShouldEqual, "126") + So(params[":page"], ShouldEqual, "4") + }) + + Convey("Match with path and extension", func() { + t := NewTree() + So(t.Add("/*.*", nil), ShouldNotBeNil) + So(t.Add("/docs/*.*", nil), ShouldNotBeNil) + + _, params, ok := t.Match("/profile.html") + So(ok, ShouldBeTrue) + So(params[":path"], ShouldEqual, "profile") + So(params[":ext"], ShouldEqual, "html") + + _, params, ok = t.Match("/profile") + So(ok, ShouldBeTrue) + So(params[":path"], ShouldEqual, "profile") + So(params[":ext"], ShouldBeEmpty) + + _, params, ok = t.Match("/docs/framework/manual.html") + So(ok, ShouldBeTrue) + So(params[":path"], ShouldEqual, "framework/manual") + So(params[":ext"], ShouldEqual, "html") + + _, params, ok = t.Match("/docs/framework/manual") + So(ok, ShouldBeTrue) + So(params[":path"], ShouldEqual, "framework/manual") + So(params[":ext"], ShouldBeEmpty) + }) + + Convey("Match all", func() { + t := NewTree() + So(t.Add("/*", nil), ShouldNotBeNil) + So(t.Add("/*/123", nil), ShouldNotBeNil) + So(t.Add("/*/123/*", nil), ShouldNotBeNil) + So(t.Add("/*/*/123", nil), ShouldNotBeNil) + + _, params, ok := t.Match("/1/2/3") + So(ok, ShouldBeTrue) + So(params["*0"], ShouldEqual, "1/2/3") + + _, params, ok = t.Match("/4/123") + So(ok, ShouldBeTrue) + So(params["*0"], ShouldEqual, "4") + + _, params, ok = t.Match("/5/123/6") + So(ok, ShouldBeTrue) + So(params["*0"], ShouldEqual, "5") + So(params["*1"], ShouldEqual, "6") + + _, params, ok = t.Match("/7/8/123") + So(ok, ShouldBeTrue) + So(params["*0"], ShouldEqual, "7") + So(params["*1"], ShouldEqual, "8") + }) + + Convey("Complex tests", func() { + t := NewTree() + So(t.Add("/:username/:reponame/commit/*", nil), ShouldNotBeNil) + + _, params, ok := t.Match("/unknwon/com/commit/d855b6c9dea98c619925b7b112f3c4e64b17bfa8") + So(ok, ShouldBeTrue) + So(params["*"], ShouldEqual, "d855b6c9dea98c619925b7b112f3c4e64b17bfa8") + }) }) } diff --git a/Godeps/_workspace/src/github.com/astaxie/beego/config/ini.go b/Godeps/_workspace/src/github.com/astaxie/beego/config/ini.go index 837c9ff..31fe9b5 100644 --- a/Godeps/_workspace/src/github.com/astaxie/beego/config/ini.go +++ b/Godeps/_workspace/src/github.com/astaxie/beego/config/ini.go @@ -300,21 +300,8 @@ func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) { defer f.Close() buf := bytes.NewBuffer(nil) - for section, dt := range c.data { - // Write section comments. - if v, ok := c.sectionComment[section]; ok { - if _, err = buf.WriteString(string(bNumComment) + v + lineBreak); err != nil { - return err - } - } - - if section != DEFAULT_SECTION { - // Write section name. - if _, err = buf.WriteString(string(sectionStart) + section + string(sectionEnd) + lineBreak); err != nil { - return err - } - } - + // Save default section at first place + if dt, ok := c.data[DEFAULT_SECTION]; ok { for key, val := range dt { if key != " " { // Write key comments. @@ -336,6 +323,43 @@ func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) { return err } } + // Save named sections + for section, dt := range c.data { + if section != DEFAULT_SECTION { + // Write section comments. + if v, ok := c.sectionComment[section]; ok { + if _, err = buf.WriteString(string(bNumComment) + v + lineBreak); err != nil { + return err + } + } + + // Write section name. + if _, err = buf.WriteString(string(sectionStart) + section + string(sectionEnd) + lineBreak); err != nil { + return err + } + + for key, val := range dt { + if key != " " { + // Write key comments. + if v, ok := c.keyComment[key]; ok { + if _, err = buf.WriteString(string(bNumComment) + v + lineBreak); err != nil { + return err + } + } + + // Write key and value. + if _, err = buf.WriteString(key + string(bEqual) + val + lineBreak); err != nil { + return err + } + } + } + + // Put a line between sections. + if _, err = buf.WriteString(lineBreak); err != nil { + return err + } + } + } if _, err = buf.WriteTo(f); err != nil { return err diff --git a/Godeps/_workspace/src/github.com/astaxie/beego/logs/es/es.go b/Godeps/_workspace/src/github.com/astaxie/beego/logs/es/es.go new file mode 100644 index 0000000..3a73d4d --- /dev/null +++ b/Godeps/_workspace/src/github.com/astaxie/beego/logs/es/es.go @@ -0,0 +1,76 @@ +package es + +import ( + "encoding/json" + "errors" + "fmt" + "net" + "net/url" + "time" + + "github.com/astaxie/beego/logs" + "github.com/belogik/goes" +) + +func NewES() logs.LoggerInterface { + cw := &esLogger{ + Level: logs.LevelDebug, + } + return cw +} + +type esLogger struct { + *goes.Connection + DSN string `json:"dsn"` + Level int `json:"level"` +} + +// {"dsn":"http://localhost:9200/","level":1} +func (el *esLogger) Init(jsonconfig string) error { + err := json.Unmarshal([]byte(jsonconfig), el) + if err != nil { + return err + } + if el.DSN == "" { + return errors.New("empty dsn") + } else if u, err := url.Parse(el.DSN); err != nil { + return err + } else if u.Path == "" { + return errors.New("missing prefix") + } else if host, port, err := net.SplitHostPort(u.Host); err != nil { + return err + } else { + conn := goes.NewConnection(host, port) + el.Connection = conn + } + return nil +} + +func (el *esLogger) WriteMsg(msg string, level int) error { + if level > el.Level { + return nil + } + t := time.Now() + vals := make(map[string]interface{}) + vals["@timestamp"] = t.Format(time.RFC3339) + vals["@msg"] = msg + d := goes.Document{ + Index: fmt.Sprintf("%04d.%02d.%02d", t.Year(), t.Month(), t.Day()), + Type: "logs", + Fields: vals, + } + _, err := el.Index(d, nil) + return err +} + +func (el *esLogger) Destroy() { + +} + +func (el *esLogger) Flush() { + +} + +func init() { + logs.Register("es", NewES) +} diff --git a/Godeps/_workspace/src/github.com/astaxie/beego/logs/log.go b/Godeps/_workspace/src/github.com/astaxie/beego/logs/log.go index 32e0187..cebbc73 100644 --- a/Godeps/_workspace/src/github.com/astaxie/beego/logs/log.go +++ b/Godeps/_workspace/src/github.com/astaxie/beego/logs/log.go @@ -92,6 +92,7 @@ type BeeLogger struct { level int enableFuncCallDepth bool loggerFuncCallDepth int + asynchronous bool msg chan *logMsg outputs map[string]LoggerInterface } @@ -110,7 +111,11 @@ func NewLogger(channellen int64) *BeeLogger { bl.loggerFuncCallDepth = 2 bl.msg = make(chan *logMsg, channellen) bl.outputs = make(map[string]LoggerInterface) - //bl.SetLogger("console", "") // default output to console + return bl +} + +func (bl *BeeLogger) Async() *BeeLogger { + bl.asynchronous = true go bl.startLogger() return bl } @@ -148,26 +153,30 @@ func (bl *BeeLogger) DelLogger(adaptername string) error { } func (bl *BeeLogger) writerMsg(loglevel int, msg string) error { - if loglevel > bl.level { - return nil - } lm := new(logMsg) lm.level = loglevel if bl.enableFuncCallDepth { _, file, line, ok := runtime.Caller(bl.loggerFuncCallDepth) - if _, filename := path.Split(file); filename == "log.go" && (line == 97 || line == 83) { - _, file, line, ok = runtime.Caller(bl.loggerFuncCallDepth + 1) - } - if ok { - _, filename := path.Split(file) - lm.msg = fmt.Sprintf("[%s:%d] %s", filename, line, msg) - } else { - lm.msg = msg + if !ok { + file = "???" + line = 0 } + _, filename := path.Split(file) + lm.msg = fmt.Sprintf("[%s:%d] %s", filename, line, msg) } else { lm.msg = msg } - bl.msg <- lm + if bl.asynchronous { + bl.msg <- lm + } else { + for name, l := range bl.outputs { + err := l.WriteMsg(lm.msg, lm.level) + if err != nil { + fmt.Println("unable to WriteMsg to adapter:", name, err) + return err + } + } + } return nil } @@ -184,6 +193,11 @@ func (bl *BeeLogger) SetLogFuncCallDepth(d int) { bl.loggerFuncCallDepth = d } +// get log funcCallDepth for wrapper +func (bl *BeeLogger) GetLogFuncCallDepth() int { + return bl.loggerFuncCallDepth +} + // enable log funcCallDepth func (bl *BeeLogger) EnableFuncCallDepth(b bool) { bl.enableFuncCallDepth = b @@ -207,71 +221,104 @@ func (bl *BeeLogger) startLogger() { // Log EMERGENCY level message. func (bl *BeeLogger) Emergency(format string, v ...interface{}) { + if LevelEmergency > bl.level { + return + } msg := fmt.Sprintf("[M] "+format, v...) bl.writerMsg(LevelEmergency, msg) } // Log ALERT level message. func (bl *BeeLogger) Alert(format string, v ...interface{}) { + if LevelAlert > bl.level { + return + } msg := fmt.Sprintf("[A] "+format, v...) bl.writerMsg(LevelAlert, msg) } // Log CRITICAL level message. func (bl *BeeLogger) Critical(format string, v ...interface{}) { + if LevelCritical > bl.level { + return + } msg := fmt.Sprintf("[C] "+format, v...) bl.writerMsg(LevelCritical, msg) } // Log ERROR level message. func (bl *BeeLogger) Error(format string, v ...interface{}) { + if LevelError > bl.level { + return + } msg := fmt.Sprintf("[E] "+format, v...) bl.writerMsg(LevelError, msg) } // Log WARNING level message. func (bl *BeeLogger) Warning(format string, v ...interface{}) { + if LevelWarning > bl.level { + return + } msg := fmt.Sprintf("[W] "+format, v...) bl.writerMsg(LevelWarning, msg) } // Log NOTICE level message. func (bl *BeeLogger) Notice(format string, v ...interface{}) { + if LevelNotice > bl.level { + return + } msg := fmt.Sprintf("[N] "+format, v...) bl.writerMsg(LevelNotice, msg) } // Log INFORMATIONAL level message. func (bl *BeeLogger) Informational(format string, v ...interface{}) { + if LevelInformational > bl.level { + return + } msg := fmt.Sprintf("[I] "+format, v...) bl.writerMsg(LevelInformational, msg) } // Log DEBUG level message. func (bl *BeeLogger) Debug(format string, v ...interface{}) { + if LevelDebug > bl.level { + return + } msg := fmt.Sprintf("[D] "+format, v...) bl.writerMsg(LevelDebug, msg) } // Log WARN level message. -// -// Deprecated: compatibility alias for Warning(), Will be removed in 1.5.0. +// compatibility alias for Warning() func (bl *BeeLogger) Warn(format string, v ...interface{}) { - bl.Warning(format, v...) + if LevelWarning > bl.level { + return + } + msg := fmt.Sprintf("[W] "+format, v...) + bl.writerMsg(LevelWarning, msg) } // Log INFO level message. -// -// Deprecated: compatibility alias for Informational(), Will be removed in 1.5.0. +// compatibility alias for Informational() func (bl *BeeLogger) Info(format string, v ...interface{}) { - bl.Informational(format, v...) + if LevelInformational > bl.level { + return + } + msg := fmt.Sprintf("[I] "+format, v...) + bl.writerMsg(LevelInformational, msg) } // Log TRACE level message. -// -// Deprecated: compatibility alias for Debug(), Will be removed in 1.5.0. +// compatibility alias for Debug() func (bl *BeeLogger) Trace(format string, v ...interface{}) { - bl.Debug(format, v...) + if LevelDebug > bl.level { + return + } + msg := fmt.Sprintf("[D] "+format, v...) + bl.writerMsg(LevelDebug, msg) } // flush all chan data. diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/.travis.yml b/Godeps/_workspace/src/github.com/codegangsta/cli/.travis.yml index baf46ab..34d39c8 100644 --- a/Godeps/_workspace/src/github.com/codegangsta/cli/.travis.yml +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/.travis.yml @@ -1,5 +1,12 @@ language: go -go: 1.1 +sudo: false + +go: +- 1.0.3 +- 1.1.2 +- 1.2.2 +- 1.3.3 +- 1.4.2 script: - go vet ./... diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/README.md b/Godeps/_workspace/src/github.com/codegangsta/cli/README.md index cd980fd..85b9cda 100644 --- a/Godeps/_workspace/src/github.com/codegangsta/cli/README.md +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/README.md @@ -158,6 +158,8 @@ app.Action = func(c *cli.Context) { ... ``` +See full list of flags at http://godoc.org/github.com/codegangsta/cli + #### Alternate Names You can set alternate (or short) names for flags by providing a comma-delimited list for the `Name`. e.g. diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/app.go b/Godeps/_workspace/src/github.com/codegangsta/cli/app.go index 891416d..e7caec9 100644 --- a/Godeps/_workspace/src/github.com/codegangsta/cli/app.go +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/app.go @@ -9,7 +9,7 @@ import ( ) // App is the main structure of a cli application. It is recomended that -// and app be created with the cli.NewApp() function +// an app be created with the cli.NewApp() function type App struct { // The name of the program. Defaults to os.Args[0] Name string @@ -43,6 +43,8 @@ type App struct { Compiled time.Time // List of all authors who contributed Authors []Author + // Copyright of the binary if any + Copyright string // Name of Author (Note: Use App.Authors, this is deprecated) Author string // Email of Author (Note: Use App.Authors, this is deprecated) @@ -104,17 +106,16 @@ func (a *App) Run(arguments []string) (err error) { nerr := normalizeFlags(a.Flags, set) if nerr != nil { fmt.Fprintln(a.Writer, nerr) - context := NewContext(a, set, set) + context := NewContext(a, set, nil) ShowAppHelp(context) - fmt.Fprintln(a.Writer) return nerr } - context := NewContext(a, set, set) + context := NewContext(a, set, nil) if err != nil { - fmt.Fprintf(a.Writer, "Incorrect Usage.\n\n") - ShowAppHelp(context) + fmt.Fprintln(a.Writer, "Incorrect Usage.") fmt.Fprintln(a.Writer) + ShowAppHelp(context) return err } @@ -132,10 +133,14 @@ func (a *App) Run(arguments []string) (err error) { if a.After != nil { defer func() { - // err is always nil here. - // There is a check to see if it is non-nil - // just few lines before. - err = a.After(context) + afterErr := a.After(context) + if afterErr != nil { + if err != nil { + err = NewMultiError(err, afterErr) + } else { + err = afterErr + } + } }() } @@ -190,21 +195,22 @@ func (a *App) RunAsSubcommand(ctx *Context) (err error) { set.SetOutput(ioutil.Discard) err = set.Parse(ctx.Args().Tail()) nerr := normalizeFlags(a.Flags, set) - context := NewContext(a, set, ctx.globalSet) + context := NewContext(a, set, ctx) if nerr != nil { fmt.Fprintln(a.Writer, nerr) + fmt.Fprintln(a.Writer) if len(a.Commands) > 0 { ShowSubcommandHelp(context) } else { ShowCommandHelp(ctx, context.Args().First()) } - fmt.Fprintln(a.Writer) return nerr } if err != nil { - fmt.Fprintf(a.Writer, "Incorrect Usage.\n\n") + fmt.Fprintln(a.Writer, "Incorrect Usage.") + fmt.Fprintln(a.Writer) ShowSubcommandHelp(context) return err } @@ -225,10 +231,14 @@ func (a *App) RunAsSubcommand(ctx *Context) (err error) { if a.After != nil { defer func() { - // err is always nil here. - // There is a check to see if it is non-nil - // just few lines before. - err = a.After(context) + afterErr := a.After(context) + if afterErr != nil { + if err != nil { + err = NewMultiError(err, afterErr) + } else { + err = afterErr + } + } }() } diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/app_test.go b/Godeps/_workspace/src/github.com/codegangsta/cli/app_test.go index ae8bb0f..2d52e88 100644 --- a/Godeps/_workspace/src/github.com/codegangsta/cli/app_test.go +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/app_test.go @@ -342,6 +342,38 @@ func TestApp_ParseSliceFlags(t *testing.T) { } } +func TestApp_ParseSliceFlagsWithMissingValue(t *testing.T) { + var parsedIntSlice []int + var parsedStringSlice []string + + app := cli.NewApp() + command := cli.Command{ + Name: "cmd", + Flags: []cli.Flag{ + cli.IntSliceFlag{Name: "a", Usage: "set numbers"}, + cli.StringSliceFlag{Name: "str", Usage: "set strings"}, + }, + Action: func(c *cli.Context) { + parsedIntSlice = c.IntSlice("a") + parsedStringSlice = c.StringSlice("str") + }, + } + app.Commands = []cli.Command{command} + + app.Run([]string{"", "cmd", "my-arg", "-a", "2", "-str", "A"}) + + var expectedIntSlice = []int{2} + var expectedStringSlice = []string{"A"} + + if parsedIntSlice[0] != expectedIntSlice[0] { + t.Errorf("%v does not match %v", parsedIntSlice[0], expectedIntSlice[0]) + } + + if parsedStringSlice[0] != expectedStringSlice[0] { + t.Errorf("%v does not match %v", parsedIntSlice[0], expectedIntSlice[0]) + } +} + func TestApp_DefaultStdout(t *testing.T) { app := cli.NewApp() @@ -595,8 +627,26 @@ func TestAppCommandNotFound(t *testing.T) { expect(t, subcommandRun, false) } +func TestGlobalFlag(t *testing.T) { + var globalFlag string + var globalFlagSet bool + app := cli.NewApp() + app.Flags = []cli.Flag{ + cli.StringFlag{Name: "global, g", Usage: "global"}, + } + app.Action = func(c *cli.Context) { + globalFlag = c.GlobalString("global") + globalFlagSet = c.GlobalIsSet("global") + } + app.Run([]string{"command", "-g", "foo"}) + expect(t, globalFlag, "foo") + expect(t, globalFlagSet, true) + +} + func TestGlobalFlagsInSubcommands(t *testing.T) { subcommandRun := false + parentFlag := false app := cli.NewApp() app.Flags = []cli.Flag{ @@ -606,6 +656,9 @@ func TestGlobalFlagsInSubcommands(t *testing.T) { app.Commands = []cli.Command{ cli.Command{ Name: "foo", + Flags: []cli.Flag{ + cli.BoolFlag{Name: "parent, p", Usage: "Parent flag"}, + }, Subcommands: []cli.Command{ { Name: "bar", @@ -613,15 +666,19 @@ func TestGlobalFlagsInSubcommands(t *testing.T) { if c.GlobalBool("debug") { subcommandRun = true } + if c.GlobalBool("parent") { + parentFlag = true + } }, }, }, }, } - app.Run([]string{"command", "-d", "foo", "bar"}) + app.Run([]string{"command", "-d", "foo", "-p", "bar"}) expect(t, subcommandRun, true) + expect(t, parentFlag, true) } func TestApp_Run_CommandWithSubcommandHasHelpTopic(t *testing.T) { @@ -677,3 +734,136 @@ func TestApp_Run_CommandWithSubcommandHasHelpTopic(t *testing.T) { } } } + +func TestApp_Run_SubcommandFullPath(t *testing.T) { + app := cli.NewApp() + buf := new(bytes.Buffer) + app.Writer = buf + + subCmd := cli.Command{ + Name: "bar", + Usage: "does bar things", + } + cmd := cli.Command{ + Name: "foo", + Description: "foo commands", + Subcommands: []cli.Command{subCmd}, + } + app.Commands = []cli.Command{cmd} + + err := app.Run([]string{"command", "foo", "bar", "--help"}) + if err != nil { + t.Error(err) + } + + output := buf.String() + if !strings.Contains(output, "foo bar - does bar things") { + t.Errorf("expected full path to subcommand: %s", output) + } + if !strings.Contains(output, "command foo bar [arguments...]") { + t.Errorf("expected full path to subcommand: %s", output) + } +} + +func TestApp_Run_Help(t *testing.T) { + var helpArguments = [][]string{{"boom", "--help"}, {"boom", "-h"}, {"boom", "help"}} + + for _, args := range helpArguments { + buf := new(bytes.Buffer) + + t.Logf("==> checking with arguments %v", args) + + app := cli.NewApp() + app.Name = "boom" + app.Usage = "make an explosive entrance" + app.Writer = buf + app.Action = func(c *cli.Context) { + buf.WriteString("boom I say!") + } + + err := app.Run(args) + if err != nil { + t.Error(err) + } + + output := buf.String() + t.Logf("output: %q\n", buf.Bytes()) + + if !strings.Contains(output, "boom - make an explosive entrance") { + t.Errorf("want help to contain %q, did not: \n%q", "boom - make an explosive entrance", output) + } + } +} + +func TestApp_Run_Version(t *testing.T) { + var versionArguments = [][]string{{"boom", "--version"}, {"boom", "-v"}} + + for _, args := range versionArguments { + buf := new(bytes.Buffer) + + t.Logf("==> checking with arguments %v", args) + + app := cli.NewApp() + app.Name = "boom" + app.Usage = "make an explosive entrance" + app.Version = "0.1.0" + app.Writer = buf + app.Action = func(c *cli.Context) { + buf.WriteString("boom I say!") + } + + err := app.Run(args) + if err != nil { + t.Error(err) + } + + output := buf.String() + t.Logf("output: %q\n", buf.Bytes()) + + if !strings.Contains(output, "0.1.0") { + t.Errorf("want version to contain %q, did not: \n%q", "0.1.0", output) + } + } +} + +func TestApp_Run_DoesNotOverwriteErrorFromBefore(t *testing.T) { + app := cli.NewApp() + app.Action = func(c *cli.Context) {} + app.Before = func(c *cli.Context) error { return fmt.Errorf("before error") } + app.After = func(c *cli.Context) error { return fmt.Errorf("after error") } + + err := app.Run([]string{"foo"}) + if err == nil { + t.Fatalf("expected to recieve error from Run, got none") + } + + if !strings.Contains(err.Error(), "before error") { + t.Errorf("expected text of error from Before method, but got none in \"%v\"", err) + } + if !strings.Contains(err.Error(), "after error") { + t.Errorf("expected text of error from After method, but got none in \"%v\"", err) + } +} + +func TestApp_Run_SubcommandDoesNotOverwriteErrorFromBefore(t *testing.T) { + app := cli.NewApp() + app.Commands = []cli.Command{ + cli.Command{ + Name: "bar", + Before: func(c *cli.Context) error { return fmt.Errorf("before error") }, + After: func(c *cli.Context) error { return fmt.Errorf("after error") }, + }, + } + + err := app.Run([]string{"foo", "bar"}) + if err == nil { + t.Fatalf("expected to recieve error from Run, got none") + } + + if !strings.Contains(err.Error(), "before error") { + t.Errorf("expected text of error from Before method, but got none in \"%v\"", err) + } + if !strings.Contains(err.Error(), "after error") { + t.Errorf("expected text of error from After method, but got none in \"%v\"", err) + } +} diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/cli.go b/Godeps/_workspace/src/github.com/codegangsta/cli/cli.go index b742545..31dc912 100644 --- a/Godeps/_workspace/src/github.com/codegangsta/cli/cli.go +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/cli.go @@ -17,3 +17,24 @@ // app.Run(os.Args) // } package cli + +import ( + "strings" +) + +type MultiError struct { + Errors []error +} + +func NewMultiError(err ...error) MultiError { + return MultiError{Errors: err} +} + +func (m MultiError) Error() string { + errs := make([]string, len(m.Errors)) + for i, err := range m.Errors { + errs[i] = err.Error() + } + + return strings.Join(errs, "\n") +} diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/command.go b/Godeps/_workspace/src/github.com/codegangsta/cli/command.go index d0bbd0c..54617af 100644 --- a/Godeps/_workspace/src/github.com/codegangsta/cli/command.go +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/command.go @@ -36,11 +36,21 @@ type Command struct { SkipFlagParsing bool // Boolean to hide built-in help command HideHelp bool + + commandNamePath []string +} + +// Returns the full name of the command. +// For subcommands this ensures that parent commands are part of the command path +func (c Command) FullName() string { + if c.commandNamePath == nil { + return c.Name + } + return strings.Join(c.commandNamePath, " ") } // Invokes the command given the context, parses ctx.Args() to generate command-specific flags func (c Command) Run(ctx *Context) error { - if len(c.Subcommands) > 0 || c.Before != nil || c.After != nil { return c.startApp(ctx) } @@ -91,9 +101,9 @@ func (c Command) Run(ctx *Context) error { } if err != nil { - fmt.Fprint(ctx.App.Writer, "Incorrect Usage.\n\n") - ShowCommandHelp(ctx, c.Name) + fmt.Fprintln(ctx.App.Writer, "Incorrect Usage.") fmt.Fprintln(ctx.App.Writer) + ShowCommandHelp(ctx, c.Name) return err } @@ -102,10 +112,9 @@ func (c Command) Run(ctx *Context) error { fmt.Fprintln(ctx.App.Writer, nerr) fmt.Fprintln(ctx.App.Writer) ShowCommandHelp(ctx, c.Name) - fmt.Fprintln(ctx.App.Writer) return nerr } - context := NewContext(ctx.App, set, ctx.globalSet) + context := NewContext(ctx.App, set, ctx) if checkCommandCompletions(context, c.Name) { return nil @@ -180,5 +189,12 @@ func (c Command) startApp(ctx *Context) error { app.Action = helpSubcommand.Action } + var newCmds []Command + for _, cc := range app.Commands { + cc.commandNamePath = []string{c.Name, cc.Name} + newCmds = append(newCmds, cc) + } + app.Commands = newCmds + return app.RunAsSubcommand(ctx) } diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/command_test.go b/Godeps/_workspace/src/github.com/codegangsta/cli/command_test.go index 4125b0c..db81db2 100644 --- a/Godeps/_workspace/src/github.com/codegangsta/cli/command_test.go +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/command_test.go @@ -13,7 +13,7 @@ func TestCommandDoNotIgnoreFlags(t *testing.T) { test := []string{"blah", "blah", "-break"} set.Parse(test) - c := cli.NewContext(app, set, set) + c := cli.NewContext(app, set, nil) command := cli.Command{ Name: "test-cmd", @@ -33,7 +33,7 @@ func TestCommandIgnoreFlags(t *testing.T) { test := []string{"blah", "blah"} set.Parse(test) - c := cli.NewContext(app, set, set) + c := cli.NewContext(app, set, nil) command := cli.Command{ Name: "test-cmd", diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/context.go b/Godeps/_workspace/src/github.com/codegangsta/cli/context.go index 37221bd..f541f41 100644 --- a/Godeps/_workspace/src/github.com/codegangsta/cli/context.go +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/context.go @@ -16,14 +16,14 @@ type Context struct { App *App Command Command flagSet *flag.FlagSet - globalSet *flag.FlagSet setFlags map[string]bool globalSetFlags map[string]bool + parentContext *Context } // Creates a new context. For use in when invoking an App or Command action. -func NewContext(app *App, set *flag.FlagSet, globalSet *flag.FlagSet) *Context { - return &Context{App: app, flagSet: set, globalSet: globalSet} +func NewContext(app *App, set *flag.FlagSet, parentCtx *Context) *Context { + return &Context{App: app, flagSet: set, parentContext: parentCtx} } // Looks up the value of a local int flag, returns 0 if no int flag exists @@ -73,37 +73,58 @@ func (c *Context) Generic(name string) interface{} { // Looks up the value of a global int flag, returns 0 if no int flag exists func (c *Context) GlobalInt(name string) int { - return lookupInt(name, c.globalSet) + if fs := lookupGlobalFlagSet(name, c); fs != nil { + return lookupInt(name, fs) + } + return 0 } // Looks up the value of a global time.Duration flag, returns 0 if no time.Duration flag exists func (c *Context) GlobalDuration(name string) time.Duration { - return lookupDuration(name, c.globalSet) + if fs := lookupGlobalFlagSet(name, c); fs != nil { + return lookupDuration(name, fs) + } + return 0 } // Looks up the value of a global bool flag, returns false if no bool flag exists func (c *Context) GlobalBool(name string) bool { - return lookupBool(name, c.globalSet) + if fs := lookupGlobalFlagSet(name, c); fs != nil { + return lookupBool(name, fs) + } + return false } // Looks up the value of a global string flag, returns "" if no string flag exists func (c *Context) GlobalString(name string) string { - return lookupString(name, c.globalSet) + if fs := lookupGlobalFlagSet(name, c); fs != nil { + return lookupString(name, fs) + } + return "" } // Looks up the value of a global string slice flag, returns nil if no string slice flag exists func (c *Context) GlobalStringSlice(name string) []string { - return lookupStringSlice(name, c.globalSet) + if fs := lookupGlobalFlagSet(name, c); fs != nil { + return lookupStringSlice(name, fs) + } + return nil } // Looks up the value of a global int slice flag, returns nil if no int slice flag exists func (c *Context) GlobalIntSlice(name string) []int { - return lookupIntSlice(name, c.globalSet) + if fs := lookupGlobalFlagSet(name, c); fs != nil { + return lookupIntSlice(name, fs) + } + return nil } // Looks up the value of a global generic flag, returns nil if no generic flag exists func (c *Context) GlobalGeneric(name string) interface{} { - return lookupGeneric(name, c.globalSet) + if fs := lookupGlobalFlagSet(name, c); fs != nil { + return lookupGeneric(name, fs) + } + return nil } // Returns the number of flags set @@ -126,11 +147,17 @@ func (c *Context) IsSet(name string) bool { func (c *Context) GlobalIsSet(name string) bool { if c.globalSetFlags == nil { c.globalSetFlags = make(map[string]bool) - c.globalSet.Visit(func(f *flag.Flag) { - c.globalSetFlags[f.Name] = true - }) + ctx := c + if ctx.parentContext != nil { + ctx = ctx.parentContext + } + for ; ctx != nil && c.globalSetFlags[name] == false; ctx = ctx.parentContext { + ctx.flagSet.Visit(func(f *flag.Flag) { + c.globalSetFlags[f.Name] = true + }) + } } - return c.globalSetFlags[name] == true + return c.globalSetFlags[name] } // Returns a slice of flag names used in this context. @@ -157,6 +184,11 @@ func (c *Context) GlobalFlagNames() (names []string) { return } +// Returns the parent context, if any +func (c *Context) Parent() *Context { + return c.parentContext +} + type Args []string // Returns the command line arguments associated with the context. @@ -201,6 +233,18 @@ func (a Args) Swap(from, to int) error { return nil } +func lookupGlobalFlagSet(name string, ctx *Context) *flag.FlagSet { + if ctx.parentContext != nil { + ctx = ctx.parentContext + } + for ; ctx != nil; ctx = ctx.parentContext { + if f := ctx.flagSet.Lookup(name); f != nil { + return ctx.flagSet + } + } + return nil +} + func lookupInt(name string, set *flag.FlagSet) int { f := set.Lookup(name) if f != nil { diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/context_test.go b/Godeps/_workspace/src/github.com/codegangsta/cli/context_test.go index d4a1877..6c27d06 100644 --- a/Godeps/_workspace/src/github.com/codegangsta/cli/context_test.go +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/context_test.go @@ -13,8 +13,9 @@ func TestNewContext(t *testing.T) { set.Int("myflag", 12, "doc") globalSet := flag.NewFlagSet("test", 0) globalSet.Int("myflag", 42, "doc") + globalCtx := cli.NewContext(nil, globalSet, nil) command := cli.Command{Name: "mycommand"} - c := cli.NewContext(nil, set, globalSet) + c := cli.NewContext(nil, set, globalCtx) c.Command = command expect(t, c.Int("myflag"), 12) expect(t, c.GlobalInt("myflag"), 42) @@ -24,42 +25,42 @@ func TestNewContext(t *testing.T) { func TestContext_Int(t *testing.T) { set := flag.NewFlagSet("test", 0) set.Int("myflag", 12, "doc") - c := cli.NewContext(nil, set, set) + c := cli.NewContext(nil, set, nil) expect(t, c.Int("myflag"), 12) } func TestContext_Duration(t *testing.T) { set := flag.NewFlagSet("test", 0) set.Duration("myflag", time.Duration(12*time.Second), "doc") - c := cli.NewContext(nil, set, set) + c := cli.NewContext(nil, set, nil) expect(t, c.Duration("myflag"), time.Duration(12*time.Second)) } func TestContext_String(t *testing.T) { set := flag.NewFlagSet("test", 0) set.String("myflag", "hello world", "doc") - c := cli.NewContext(nil, set, set) + c := cli.NewContext(nil, set, nil) expect(t, c.String("myflag"), "hello world") } func TestContext_Bool(t *testing.T) { set := flag.NewFlagSet("test", 0) set.Bool("myflag", false, "doc") - c := cli.NewContext(nil, set, set) + c := cli.NewContext(nil, set, nil) expect(t, c.Bool("myflag"), false) } func TestContext_BoolT(t *testing.T) { set := flag.NewFlagSet("test", 0) set.Bool("myflag", true, "doc") - c := cli.NewContext(nil, set, set) + c := cli.NewContext(nil, set, nil) expect(t, c.BoolT("myflag"), true) } func TestContext_Args(t *testing.T) { set := flag.NewFlagSet("test", 0) set.Bool("myflag", false, "doc") - c := cli.NewContext(nil, set, set) + c := cli.NewContext(nil, set, nil) set.Parse([]string{"--myflag", "bat", "baz"}) expect(t, len(c.Args()), 2) expect(t, c.Bool("myflag"), true) @@ -71,7 +72,8 @@ func TestContext_IsSet(t *testing.T) { set.String("otherflag", "hello world", "doc") globalSet := flag.NewFlagSet("test", 0) globalSet.Bool("myflagGlobal", true, "doc") - c := cli.NewContext(nil, set, globalSet) + globalCtx := cli.NewContext(nil, globalSet, nil) + c := cli.NewContext(nil, set, globalCtx) set.Parse([]string{"--myflag", "bat", "baz"}) globalSet.Parse([]string{"--myflagGlobal", "bat", "baz"}) expect(t, c.IsSet("myflag"), true) @@ -87,7 +89,8 @@ func TestContext_GlobalIsSet(t *testing.T) { globalSet := flag.NewFlagSet("test", 0) globalSet.Bool("myflagGlobal", true, "doc") globalSet.Bool("myflagGlobalUnset", true, "doc") - c := cli.NewContext(nil, set, globalSet) + globalCtx := cli.NewContext(nil, globalSet, nil) + c := cli.NewContext(nil, set, globalCtx) set.Parse([]string{"--myflag", "bat", "baz"}) globalSet.Parse([]string{"--myflagGlobal", "bat", "baz"}) expect(t, c.GlobalIsSet("myflag"), false) @@ -104,7 +107,8 @@ func TestContext_NumFlags(t *testing.T) { set.String("otherflag", "hello world", "doc") globalSet := flag.NewFlagSet("test", 0) globalSet.Bool("myflagGlobal", true, "doc") - c := cli.NewContext(nil, set, globalSet) + globalCtx := cli.NewContext(nil, globalSet, nil) + c := cli.NewContext(nil, set, globalCtx) set.Parse([]string{"--myflag", "--otherflag=foo"}) globalSet.Parse([]string{"--myflagGlobal"}) expect(t, c.NumFlags(), 2) diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/flag.go b/Godeps/_workspace/src/github.com/codegangsta/cli/flag.go index 2511586..531b091 100644 --- a/Godeps/_workspace/src/github.com/codegangsta/cli/flag.go +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/flag.go @@ -99,21 +99,27 @@ func (f GenericFlag) getName() string { return f.Name } +// StringSlice is an opaque type for []string to satisfy flag.Value type StringSlice []string +// Set appends the string value to the list of values func (f *StringSlice) Set(value string) error { *f = append(*f, value) return nil } +// String returns a readable representation of this value (for usage defaults) func (f *StringSlice) String() string { return fmt.Sprintf("%s", *f) } +// Value returns the slice of strings set by this flag func (f *StringSlice) Value() []string { return *f } +// StringSlice is a string flag that can be specified multiple times on the +// command-line type StringSliceFlag struct { Name string Value *StringSlice @@ -121,12 +127,14 @@ type StringSliceFlag struct { EnvVar string } +// String returns the usage func (f StringSliceFlag) String() string { firstName := strings.Trim(strings.Split(f.Name, ",")[0], " ") pref := prefixFor(firstName) return withEnvHint(f.EnvVar, fmt.Sprintf("%s [%v]\t%v", prefixedNames(f.Name), pref+firstName+" option "+pref+firstName+" option", f.Usage)) } +// Apply populates the flag given the flag set and environment func (f StringSliceFlag) Apply(set *flag.FlagSet) { if f.EnvVar != "" { for _, envVar := range strings.Split(f.EnvVar, ",") { @@ -144,6 +152,9 @@ func (f StringSliceFlag) Apply(set *flag.FlagSet) { } eachName(f.Name, func(name string) { + if f.Value == nil { + f.Value = &StringSlice{} + } set.Var(f.Value, name, f.Usage) }) } @@ -152,10 +163,11 @@ func (f StringSliceFlag) getName() string { return f.Name } +// StringSlice is an opaque type for []int to satisfy flag.Value type IntSlice []int +// Set parses the value into an integer and appends it to the list of values func (f *IntSlice) Set(value string) error { - tmp, err := strconv.Atoi(value) if err != nil { return err @@ -165,14 +177,18 @@ func (f *IntSlice) Set(value string) error { return nil } +// String returns a readable representation of this value (for usage defaults) func (f *IntSlice) String() string { return fmt.Sprintf("%d", *f) } +// Value returns the slice of ints set by this flag func (f *IntSlice) Value() []int { return *f } +// IntSliceFlag is an int flag that can be specified multiple times on the +// command-line type IntSliceFlag struct { Name string Value *IntSlice @@ -180,12 +196,14 @@ type IntSliceFlag struct { EnvVar string } +// String returns the usage func (f IntSliceFlag) String() string { firstName := strings.Trim(strings.Split(f.Name, ",")[0], " ") pref := prefixFor(firstName) return withEnvHint(f.EnvVar, fmt.Sprintf("%s [%v]\t%v", prefixedNames(f.Name), pref+firstName+" option "+pref+firstName+" option", f.Usage)) } +// Apply populates the flag given the flag set and environment func (f IntSliceFlag) Apply(set *flag.FlagSet) { if f.EnvVar != "" { for _, envVar := range strings.Split(f.EnvVar, ",") { @@ -206,6 +224,9 @@ func (f IntSliceFlag) Apply(set *flag.FlagSet) { } eachName(f.Name, func(name string) { + if f.Value == nil { + f.Value = &IntSlice{} + } set.Var(f.Value, name, f.Usage) }) } @@ -214,16 +235,19 @@ func (f IntSliceFlag) getName() string { return f.Name } +// BoolFlag is a switch that defaults to false type BoolFlag struct { Name string Usage string EnvVar string } +// String returns a readable representation of this value (for usage defaults) func (f BoolFlag) String() string { return withEnvHint(f.EnvVar, fmt.Sprintf("%s\t%v", prefixedNames(f.Name), f.Usage)) } +// Apply populates the flag given the flag set and environment func (f BoolFlag) Apply(set *flag.FlagSet) { val := false if f.EnvVar != "" { @@ -248,16 +272,20 @@ func (f BoolFlag) getName() string { return f.Name } +// BoolTFlag this represents a boolean flag that is true by default, but can +// still be set to false by --some-flag=false type BoolTFlag struct { Name string Usage string EnvVar string } +// String returns a readable representation of this value (for usage defaults) func (f BoolTFlag) String() string { return withEnvHint(f.EnvVar, fmt.Sprintf("%s\t%v", prefixedNames(f.Name), f.Usage)) } +// Apply populates the flag given the flag set and environment func (f BoolTFlag) Apply(set *flag.FlagSet) { val := true if f.EnvVar != "" { @@ -282,6 +310,7 @@ func (f BoolTFlag) getName() string { return f.Name } +// StringFlag represents a flag that takes as string value type StringFlag struct { Name string Value string @@ -289,6 +318,7 @@ type StringFlag struct { EnvVar string } +// String returns the usage func (f StringFlag) String() string { var fmtString string fmtString = "%s %v\t%v" @@ -302,6 +332,7 @@ func (f StringFlag) String() string { return withEnvHint(f.EnvVar, fmt.Sprintf(fmtString, prefixedNames(f.Name), f.Value, f.Usage)) } +// Apply populates the flag given the flag set and environment func (f StringFlag) Apply(set *flag.FlagSet) { if f.EnvVar != "" { for _, envVar := range strings.Split(f.EnvVar, ",") { @@ -322,6 +353,8 @@ func (f StringFlag) getName() string { return f.Name } +// IntFlag is a flag that takes an integer +// Errors if the value provided cannot be parsed type IntFlag struct { Name string Value int @@ -329,10 +362,12 @@ type IntFlag struct { EnvVar string } +// String returns the usage func (f IntFlag) String() string { return withEnvHint(f.EnvVar, fmt.Sprintf("%s \"%v\"\t%v", prefixedNames(f.Name), f.Value, f.Usage)) } +// Apply populates the flag given the flag set and environment func (f IntFlag) Apply(set *flag.FlagSet) { if f.EnvVar != "" { for _, envVar := range strings.Split(f.EnvVar, ",") { @@ -356,6 +391,8 @@ func (f IntFlag) getName() string { return f.Name } +// DurationFlag is a flag that takes a duration specified in Go's duration +// format: https://golang.org/pkg/time/#ParseDuration type DurationFlag struct { Name string Value time.Duration @@ -363,10 +400,12 @@ type DurationFlag struct { EnvVar string } +// String returns a readable representation of this value (for usage defaults) func (f DurationFlag) String() string { return withEnvHint(f.EnvVar, fmt.Sprintf("%s \"%v\"\t%v", prefixedNames(f.Name), f.Value, f.Usage)) } +// Apply populates the flag given the flag set and environment func (f DurationFlag) Apply(set *flag.FlagSet) { if f.EnvVar != "" { for _, envVar := range strings.Split(f.EnvVar, ",") { @@ -390,6 +429,8 @@ func (f DurationFlag) getName() string { return f.Name } +// Float64Flag is a flag that takes an float value +// Errors if the value provided cannot be parsed type Float64Flag struct { Name string Value float64 @@ -397,10 +438,12 @@ type Float64Flag struct { EnvVar string } +// String returns the usage func (f Float64Flag) String() string { return withEnvHint(f.EnvVar, fmt.Sprintf("%s \"%v\"\t%v", prefixedNames(f.Name), f.Value, f.Usage)) } +// Apply populates the flag given the flag set and environment func (f Float64Flag) Apply(set *flag.FlagSet) { if f.EnvVar != "" { for _, envVar := range strings.Split(f.EnvVar, ",") { diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/help.go b/Godeps/_workspace/src/github.com/codegangsta/cli/help.go index 1117945..66ef2fb 100644 --- a/Godeps/_workspace/src/github.com/codegangsta/cli/help.go +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/help.go @@ -15,30 +15,33 @@ var AppHelpTemplate = `NAME: {{.Name}} - {{.Usage}} USAGE: - {{.Name}} {{if .Flags}}[global options] {{end}}command{{if .Flags}} [command options]{{end}} [arguments...] - + {{.Name}} {{if .Flags}}[global options]{{end}}{{if .Commands}} command [command options]{{end}} [arguments...] + {{if .Version}} VERSION: - {{.Version}}{{if len .Authors}} - -AUTHOR(S): - {{range .Authors}}{{ . }}{{end}}{{end}} - + {{.Version}} + {{end}}{{if len .Authors}} +AUTHOR(S): + {{range .Authors}}{{ . }}{{end}} + {{end}}{{if .Commands}} COMMANDS: {{range .Commands}}{{join .Names ", "}}{{ "\t" }}{{.Usage}} - {{end}}{{if .Flags}} + {{end}}{{end}}{{if .Flags}} GLOBAL OPTIONS: {{range .Flags}}{{.}} - {{end}}{{end}} + {{end}}{{end}}{{if .Copyright }} +COPYRIGHT: + {{.Copyright}} + {{end}} ` // The text template for the command help topic. // cli.go uses text/template to render templates. You can // render custom help text by setting this variable. var CommandHelpTemplate = `NAME: - {{.Name}} - {{.Usage}} + {{.FullName}} - {{.Usage}} USAGE: - command {{.Name}}{{if .Flags}} [command options]{{end}} [arguments...]{{if .Description}} + command {{.FullName}}{{if .Flags}} [command options]{{end}} [arguments...]{{if .Description}} DESCRIPTION: {{.Description}}{{end}}{{if .Flags}} @@ -181,7 +184,7 @@ func printHelp(out io.Writer, templ string, data interface{}) { } func checkVersion(c *Context) bool { - if c.GlobalBool("version") { + if c.GlobalBool("version") || c.GlobalBool("v") || c.Bool("version") || c.Bool("v") { ShowVersion(c) return true } @@ -190,7 +193,7 @@ func checkVersion(c *Context) bool { } func checkHelp(c *Context) bool { - if c.GlobalBool("h") || c.GlobalBool("help") { + if c.GlobalBool("h") || c.GlobalBool("help") || c.Bool("h") || c.Bool("help") { ShowAppHelp(c) return true } diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/help_test.go b/Godeps/_workspace/src/github.com/codegangsta/cli/help_test.go index b3c1fda..c85f957 100644 --- a/Godeps/_workspace/src/github.com/codegangsta/cli/help_test.go +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/help_test.go @@ -20,3 +20,19 @@ func Test_ShowAppHelp_NoAuthor(t *testing.T) { t.Errorf("expected\n%snot to include %s", output.String(), "AUTHOR(S):") } } + +func Test_ShowAppHelp_NoVersion(t *testing.T) { + output := new(bytes.Buffer) + app := cli.NewApp() + app.Writer = output + + app.Version = "" + + c := cli.NewContext(app, nil, nil) + + cli.ShowAppHelp(c) + + if bytes.Index(output.Bytes(), []byte("VERSION:")) != -1 { + t.Errorf("expected\n%snot to include %s", output.String(), "VERSION:") + } +} diff --git a/Godeps/_workspace/src/github.com/containerops/wrench/db/db.go b/Godeps/_workspace/src/github.com/containerops/wrench/db/db.go new file mode 100644 index 0000000..74d9923 --- /dev/null +++ b/Godeps/_workspace/src/github.com/containerops/wrench/db/db.go @@ -0,0 +1,133 @@ +package db + +import ( + "encoding/json" + "fmt" + + "gopkg.in/redis.v3" +) + +const ( + //Dockyard Data Index + GLOBAL_REPOSITORY_INDEX = "GLOBAL_REPOSITORY_INDEX" + GLOBAL_IMAGE_INDEX = "GLOBAL_IMAGE_INDEX" + GLOBAL_TARSUM_INDEX = "GLOBAL_TARSUM_INDEX" + GLOBAL_TAG_INDEX = "GLOBAL_TAG_INDEX" + GLOBAL_COMPOSE_INDEX = "GLOBAL_COMPOSE_INDEX" + GLOBAL_LIBRARY_INDEX = "GLOBAL_LIBRARY_INDEX" + //Sail Data Index + GLOBAL_USER_INDEX = "GLOBAL_USER_INDEX" + GLOBAL_ORGANIZATION_INDEX = "GLOBAL_ORGANIZATION_INDEX" + GLOBAL_TEAM_INDEX = "GLOBAL_TEAM_INDEX" + //Wharf Data Index + GLOBAL_ADMIN_INDEX = "GLOBAL_ADMIN_INDEX" + GLOBAL_LOG_INDEX = "GLOBAL_LOG_INDEX" +) + +/* + [user] : USER-(username) + [organization] : ORG-(org) + [team] : TEAM-(org)-(team) + [repository] : REPO-(namespace)-(repo) + [image] : IMAGE-(imageId) + [tag] : TAG-(namespace)-(repo)-(tag) + [compose] : COMPOSE-(namespace)-(compose) + [admin] : ADMIN-(username) + [log] : LOG-(object) + [lock] : LOCK-(object) +*/ + +var ( + Client *redis.Client +) + +func Key(object string, keys ...string) (result string) { + switch object { + case "USER": + case "user": + result = fmt.Sprintf("USER-%s", keys[0]) + case "ORG": + case "ORGANIZATION": + case "org": + case "organization": + result = fmt.Sprintf("ORG-%s", keys[0]) + case "TEAM": + case "team": + result = fmt.Sprintf("ORG-%s-%s", keys[0], keys[1]) + case "REPO": + case "REPOSITORY": + case "repo": + case "repository": + result = fmt.Sprintf("REPO-%s-%s", keys[0], keys[1]) + case "IMAGE": + case "image": + result = fmt.Sprintf("IMAGE-%s", keys[0]) + case "TARSUM": + case "tarsum": + result = fmt.Sprintf("TARSUM-%s", keys[0]) + case "TAG": + case "tag": + result = fmt.Sprintf("TAG-%s-%s-%s", keys[0], keys[1], keys[2]) + case "COMPOSE": + case "compose": + result = fmt.Sprintf("COMPOSE-%s-%s", keys[0], keys[1]) + case "LIBRARY": + case "library": + result = fmt.Sprintf("LIBRARY-%s", keys[0]) + case "ADMIN": + case "admin": + result = fmt.Sprintf("ADMIN-%s", keys[0]) + case "LOG": + case "log": + result = fmt.Sprintf("LOG-%s", keys[0]) + case "LOCK": + case "lock": + result = fmt.Sprintf("LOCK-%s", keys[0]) + default: + result = "" + } + + return result +} + +func InitDB(addr, passwd string, db int64) error { + Client = redis.NewClient(&redis.Options{ + Addr: addr, + Password: passwd, + DB: db, + }) + + if _, err := Client.Ping().Result(); err != nil { + return err + } else { + return nil + } +} + +func Save(obj interface{}, key string) (err error) { + result, err := json.Marshal(&obj) + + if err != nil { + return err + } + + if _, err := Client.Set(key, string(result), 0).Result(); err != nil { + return err + } + + return nil +} + +func Get(obj interface{}, key string) (err error) { + result, err := Client.Get(key).Result() + + if err != nil { + return err + } + + if err = json.Unmarshal([]byte(result), &obj); err != nil { + return err + } + + return nil +} diff --git a/Godeps/_workspace/src/github.com/containerops/wrench/db/lock.go b/Godeps/_workspace/src/github.com/containerops/wrench/db/lock.go new file mode 100644 index 0000000..d459c11 --- /dev/null +++ b/Godeps/_workspace/src/github.com/containerops/wrench/db/lock.go @@ -0,0 +1,140 @@ +package db + +import ( + "crypto/rand" + "encoding/base64" + "strconv" + "sync" + "time" + + "gopkg.in/redis.v3" +) + +const luaRefresh = `if redis.call("get", KEYS[1]) == ARGV[1] then return redis.call("pexpire", KEYS[1], ARGV[2]) else return 0 end` +const luaRelease = `if redis.call("get", KEYS[1]) == ARGV[1] then return redis.call("del", KEYS[1]) else return 0 end` + +type Lock struct { + client *redis.Client + key string + ttl string + opts *LockOptions + + token string + mutex sync.Mutex +} + +// ObtainLock is a shortcut for NewLock().Lock() +func ObtainLock(client *redis.Client, key string, opts *LockOptions) (*Lock, error) { + lock := NewLock(client, key, opts) + if ok, err := lock.Lock(); err != nil || !ok { + return nil, err + } + return lock, nil +} + +// NewLock creates a new distributed lock on key +func NewLock(client *redis.Client, key string, opts *LockOptions) *Lock { + opts = opts.normalize() + ttl := strconv.FormatInt(int64(opts.LockTimeout/time.Millisecond), 10) + return &Lock{client: client, key: key, ttl: ttl, opts: opts} +} + +// IsLocked returns true if a lock is acquired +func (l *Lock) IsLocked() bool { + l.mutex.Lock() + defer l.mutex.Unlock() + + return l.token != "" +} + +// Lock applies the lock, don't forget to defer the Unlock() function to release the lock after usage +func (l *Lock) Lock() (bool, error) { + l.mutex.Lock() + defer l.mutex.Unlock() + + if l.token != "" { + return l.refresh() + } + return l.create() +} + +// Unlock releases the lock +func (l *Lock) Unlock() error { + l.mutex.Lock() + defer l.mutex.Unlock() + + return l.release() +} + +// Helpers +func (l *Lock) create() (bool, error) { + l.reset() + + // Create a random token + token, err := randomToken() + if err != nil { + return false, err + } + + // Calculate the timestamp we are willing to wait for + stop := time.Now().Add(l.opts.WaitTimeout) + for { + // Try to obtain a lock + ok, err := l.obtain(token) + if err != nil { + return false, err + } else if ok { + l.token = token + return true, nil + } + + if time.Now().Add(l.opts.WaitRetry).After(stop) { + break + } + time.Sleep(l.opts.WaitRetry) + } + return false, nil +} + +func (l *Lock) refresh() (bool, error) { + status, err := l.client.Eval(luaRefresh, []string{l.key}, []string{l.token, l.ttl}).Result() + if err != nil { + return false, err + } else if status == int64(1) { + return true, nil + } + return l.create() +} + +func (l *Lock) obtain(token string) (bool, error) { + cmd := redis.NewStringCmd("set", l.key, token, "nx", "px", l.ttl) + l.client.Process(cmd) + + str, err := cmd.Result() + if err == redis.Nil { + err = nil + } + return str == "OK", err +} + +func (l *Lock) release() error { + defer l.reset() + + err := l.client.Eval(luaRelease, []string{l.key}, []string{l.token}).Err() + if err == redis.Nil { + err = nil + } + return err +} + +func (l *Lock) reset() { + l.token = "" +} + +func randomToken() (string, error) { + buf := make([]byte, 16) + if _, err := rand.Read(buf); err != nil { + return "", err + } + return base64.URLEncoding.EncodeToString(buf), nil +} diff --git a/Godeps/_workspace/src/github.com/containerops/wrench/db/options.go b/Godeps/_workspace/src/github.com/containerops/wrench/db/options.go new file mode 100644 index 0000000..1a30434 --- /dev/null +++ b/Godeps/_workspace/src/github.com/containerops/wrench/db/options.go @@ -0,0 +1,38 @@ +package db + +import ( + "time" +) + +const minWaitRetry = 10 * time.Millisecond + +type LockOptions struct { + // The maximum duration to lock a key for + // Default: 5s + LockTimeout time.Duration + + // The maximum amount of time you are willing to wait to obtain that lock + // Default: 0 = do not wait + WaitTimeout time.Duration + + // In case WaitTimeout is activated, this it the amount of time you are willing + // to wait between retries. + // Default: 100ms, must be at least 10ms + WaitRetry time.Duration +} + +func (o *LockOptions) normalize() *LockOptions { + if o == nil { + o = new(LockOptions) + } + if o.LockTimeout < 1 { + o.LockTimeout = 5 * time.Second + } + if o.WaitTimeout < 0 { + o.WaitTimeout = 0 + } + if o.WaitRetry < minWaitRetry { + o.WaitRetry = minWaitRetry + } + return o +} diff --git a/Godeps/_workspace/src/github.com/containerops/wrench/setting/setting.go b/Godeps/_workspace/src/github.com/containerops/wrench/setting/setting.go new file mode 100644 index 0000000..522e9e1 --- /dev/null +++ b/Godeps/_workspace/src/github.com/containerops/wrench/setting/setting.go @@ -0,0 +1,160 @@ +package setting + +import ( + "fmt" + + "github.com/astaxie/beego/config" +) + +const ( + APIVERSION_V1 = iota + APIVERSION_V2 +) + +var ( + conf config.ConfigContainer +) + +var ( + //Global + AppName string + Usage string + Version string + Author string + Email string + RunMode string + ListenMode string + HttpsCertFile string + HttpsKeyFile string + LogPath string + DBURI string + DBPasswd string + DBDB int64 + //Dockyard + BackendDriver string + ImagePath string + Domains string + RegistryVersion string + DistributionVersion string + Standalone string +) + +func SetConfig(path string) error { + var err error + + conf, err = config.NewConfig("ini", path) + if err != nil { + fmt.Errorf("Read %s error: %v", path, err.Error()) + } + + if appname := conf.String("appname"); appname != "" { + AppName = appname + } else if appname == "" { + err = fmt.Errorf("AppName config value is null") + } + + if usage := conf.String("usage"); usage != "" { + Usage = usage + } else if usage == "" { + err = fmt.Errorf("Usage config value is null") + } + + if version := conf.String("version"); version != "" { + Version = version + } else if version == "" { + err = fmt.Errorf("Version config value is null") + } + + if author := conf.String("author"); author != "" { + Author = author + } else if author == "" { + err = fmt.Errorf("Author config value is null") + } + + if email := conf.String("email"); email != "" { + Email = email + } else if email == "" { + err = fmt.Errorf("Email config value is null") + } + + if runmode := conf.String("runmode"); runmode != "" { + RunMode = runmode + } else if runmode == "" { + err = fmt.Errorf("RunMode config value is null") + } + + if listenmode := conf.String("listenmode"); listenmode != "" { + ListenMode = listenmode + } else if listenmode == "" { + err = fmt.Errorf("ListenMode config value is null") + } + + if httpscertfile := conf.String("httpscertfile"); httpscertfile != "" { + HttpsCertFile = httpscertfile + } else if httpscertfile == "" { + err = fmt.Errorf("HttpsCertFile config value is null") + } + + if httpskeyfile := conf.String("httpskeyfile"); httpskeyfile != "" { + HttpsKeyFile = httpskeyfile + } else if httpskeyfile == "" { + err = fmt.Errorf("HttpsKeyFile config value is null") + } + + if logpath := conf.String("log::filepath"); logpath != "" { + LogPath = logpath + } else if logpath == "" { + err = fmt.Errorf("LogPath config value is null") + } + + if dburi := conf.String("db::uri"); dburi != "" { + DBURI = dburi + } else if dburi == "" { + err = fmt.Errorf("DBURI config value is null") + } + + if dbpass := conf.String("db::passwd"); dbpass != "" { + DBPasswd = dbpass + } + + DBDB, err = conf.Int64("db::db") + + //Dockyard + if backenddriver := conf.String("dockyard::driver"); backenddriver != "" { + BackendDriver = backenddriver + } else if backenddriver == "" { + err = fmt.Errorf("Backend driver config value is null") + } + + if imagepath := conf.String("dockyard::path"); imagepath != "" { + ImagePath = imagepath + } else if imagepath == "" { + err = fmt.Errorf("Image path config value is null") + } + + if domains := conf.String("dockyard::domains"); domains != "" { + Domains = domains + } else if domains == "" { + err = fmt.Errorf("Domains value is null") + } + + if registryVersion := conf.String("dockyard::registry"); registryVersion != "" { + RegistryVersion = registryVersion + } else if registryVersion == "" { + err = fmt.Errorf("Registry version value is null") + } + + if distributionVersion := conf.String("dockyard::distribution"); distributionVersion != "" { + DistributionVersion = distributionVersion + } else if distributionVersion == "" { + err = fmt.Errorf("Distribution version value is null") + } + + if standalone := conf.String("dockyard::standalone"); standalone != "" { + Standalone = standalone + } else if standalone == "" { + err = fmt.Errorf("Standalone version value is null") + } + + return err +} diff --git a/Godeps/_workspace/src/github.com/containerops/wrench/utils/digest.go b/Godeps/_workspace/src/github.com/containerops/wrench/utils/digest.go new file mode 100644 index 0000000..75aaf52 --- /dev/null +++ b/Godeps/_workspace/src/github.com/containerops/wrench/utils/digest.go @@ -0,0 +1,130 @@ +package utils + +import ( + "bytes" + "crypto" + "fmt" + "hash" + "io" + "strings" + + "github.com/docker/libtrust" +) + +//port from distribution in order to match the digest that generated by docker client + +// Algorithm identifies and implementation of a digester by an identifier. +// Note the that this defines both the hash algorithm used and the string +// encoding. +type Algorithm string + +// supported digest types +const ( + SHA256 Algorithm = "sha256" // sha256 with hex encoding + SHA384 Algorithm = "sha384" // sha384 with hex encoding + SHA512 Algorithm = "sha512" // sha512 with hex encoding + TarsumV1SHA256 Algorithm = "tarsum+v1+sha256" // supported tarsum version, verification only + + // Canonical is the primary digest algorithm used with the distribution + // project. Other digests may be used but this one is the primary storage + // digest. + Canonical = SHA256 +) + +var ( + // TODO(stevvooe): Follow the pattern of the standard crypto package for + // registration of digests. Effectively, we are a registerable set and + // common symbol access. + + // algorithms maps values to hash.Hash implementations. Other algorithms + // may be available but they cannot be calculated by the digest package. + algorithms = map[Algorithm]crypto.Hash{ + SHA256: crypto.SHA256, + SHA384: crypto.SHA384, + SHA512: crypto.SHA512, + } +) + +// Available returns true if the digest type is available for use. If this +// returns false, New and Hash will return nil. +func (a Algorithm) Available() bool { + h, ok := algorithms[a] + if !ok { + return false + } + + // check availability of the hash, as well + return h.Available() +} + +func (a Algorithm) New() Digester { + return &digester{ + alg: a, + hash: a.Hash(), + } +} + +func (a Algorithm) Hash() hash.Hash { + if !a.Available() { + return nil + } + + return algorithms[a].New() +} + +type Digester interface { + Hash() hash.Hash // provides direct access to underlying hash instance. + Digest() string +} + +// digester provides a simple digester definition that embeds a hasher. +type digester struct { + alg Algorithm + hash hash.Hash +} + +func (d *digester) Hash() hash.Hash { + return d.hash +} + +func (d *digester) Digest() string { + return string(fmt.Sprintf("%s:%x", d.alg, d.hash.Sum(nil))) +} + +func FromReader(rd io.Reader) (string, error) { + digester := Canonical.New() + + if _, err := io.Copy(digester.Hash(), rd); err != nil { + return "", err + } + + return digester.Digest(), nil +} + +func Payload(data []byte) ([]byte, error) { + jsig, err := libtrust.ParsePrettySignature(data, "signatures") + if err != nil { + return nil, err + } + + // Resolve the payload in the manifest. + return jsig.Payload() +} + +func DigestManifest(data []byte) (string, error) { + p, err := Payload(data) + if err != nil { + if !strings.Contains(err.Error(), "missing signature key") { + return "", err + } + + p = data + } + + digest, err := FromReader(bytes.NewReader(p)) + if err != nil { + return "", err + } + + return digest, err +} diff --git a/Godeps/_workspace/src/github.com/containerops/wrench/utils/generator.go b/Godeps/_workspace/src/github.com/containerops/wrench/utils/generator.go new file mode 100644 index 0000000..174c8ec --- /dev/null +++ b/Godeps/_workspace/src/github.com/containerops/wrench/utils/generator.go @@ -0,0 +1,3432 @@ +package utils + +import ( + "bufio" + "bytes" + "crypto/tls" + "crypto/x509" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "io/ioutil" + "math" + "net" + "net/http" + "net/http/httputil" + "net/url" + "os" + "path" + "path/filepath" + "reflect" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/docker/docker/pkg/archive" + "github.com/docker/docker/pkg/fileutils" + "github.com/docker/docker/pkg/stdcopy" +) + +//****************************************************************// +//auth use type +//****************************************************************// + +// AuthConfiguration represents authentication options to use in the PushImage +// method. It represents the authentication in the Docker index server. +type AuthConfiguration struct { + Username string `json:"username,omitempty"` + Password string `json:"password,omitempty"` + Email string `json:"email,omitempty"` + ServerAddress string `json:"serveraddress,omitempty"` +} + +// AuthConfigurations represents authentication options to use for the +// PushImage method accommodating the new X-Registry-Config header +type AuthConfigurations struct { + Configs map[string]AuthConfiguration `json:"configs"` +} + +// dockerConfig represents a registry authentation configuration from the +// .dockercfg file. +type dockerConfig struct { + Auth string `json:"auth"` + Email string `json:"email"` +} + +//****************************************************************// +//change use type +//****************************************************************// + +// ChangeType is a type for constants indicating the type of change +// in a container +type ChangeType int + +// Change represents a change in a container. +// +// See https://docs.docker.com/reference/api/docker_remote_api_v1.19/#inspect-changes-on-a-container-s-filesystem for more details. +type Change struct { + Path string + Kind ChangeType +} + +//****************************************************************// +//clent use type +//****************************************************************// +// APIVersion is an internal representation of a version of the Remote API. +type APIVersion []int + +// Client is the basic type of this package. It provides methods for +// interaction with the API. +type Client struct { + SkipServerVersionCheck bool + HTTPClient *http.Client + transport *http.Transport + TLSConfig *tls.Config + + endpoint string + endpointURL *url.URL + eventMonitor *eventMonitoringState + requestedAPIVersion APIVersion + serverAPIVersion APIVersion + expectedAPIVersion APIVersion +} + +type doOptions struct { + data interface{} + forceJSON bool +} + +type streamOptions struct { + setRawTerminal bool + rawJSONStream bool + useJSONDecoder bool + headers map[string]string + in io.Reader + stdout io.Writer + stderr io.Writer +} + +type hijackOptions struct { + success chan struct{} + setRawTerminal bool + in io.Reader + stdout io.Writer + stderr io.Writer + data interface{} +} + +type jsonMessage struct { + Status string `json:"status,omitempty"` + Progress string `json:"progress,omitempty"` + Error string `json:"error,omitempty"` + Stream string `json:"stream,omitempty"` +} + +// Error represents failures in the API. It represents a failure from the API. +type Error struct { + Status int + Message string +} + +//****************************************************************// +//Container use type +//****************************************************************// + +// ListContainersOptions specify parameters to the ListContainers function. +// +// See https://docs.docker.com/reference/api/docker_remote_api_v1.19/#list-containers for more details. +/* +List containers +GET /containers/json +Example request: + GET /containers/json?all=1&before=8dfafdbc3a40&size=1 HTTP/1.1 +Example response: + HTTP/1.1 200 OK + Content-Type: application/json +[ + { + "Id": "8dfafdbc3a40", + "Image": "ubuntu:latest", + "Command": "echo 1", + "Created": 1367854155, + "Status": "Exit 0", + "Ports": [{"PrivatePort": 2222, "PublicPort": 3333, "Type": "tcp"}], + "SizeRw": 12288, + "SizeRootFs": 0 + }, + ...... +] +Query Parameters: + all – 1/True/true or 0/False/false, Show all containers. Only running containers are shown by default (i.e., this defaults to false) + limit – Show limit last created containers, include non-running ones. + since – Show only containers created since Id, include non-running ones. + before – Show only containers created before Id, include non-running ones. + size – 1/True/true or 0/False/false, Show the containers sizes + filters - a JSON encoded value of the filters (a map[string][]string) to process on the containers list. Available filters: + exited=; – containers with exit code of ; + status=(restarting|running|paused|exited) + label=key or key=value of a container label +Status Codes: + 200 – no error + 400 – bad parameter + 500 – server error +*/ +type ListContainersOptions struct { + All bool + Limit int + Since string + Before string + Size bool + Filters map[string][]string + Exited int + Status string + Label string + Key string +} + +// APIPort is a type that represents a port mapping returned by the Docker API +type APIPort struct { + PrivatePort int64 `json:"PrivatePort,omitempty" yaml:"PrivatePort,omitempty"` + PublicPort int64 `json:"PublicPort,omitempty" yaml:"PublicPort,omitempty"` + Type string `json:"Type,omitempty" yaml:"Type,omitempty"` + IP string `json:"IP,omitempty" yaml:"IP,omitempty"` +} + +// APIContainers represents a container. +// +// See https://docs.docker.com/reference/api/docker_remote_api_v1.19/#list-containers for more details. +type APIContainers struct { + ID string `json:"Id" yaml:"Id"` + Image string `json:"Image,omitempty" yaml:"Image,omitempty"` + Command string `json:"Command,omitempty" yaml:"Command,omitempty"` + Created int64 `json:"Created,omitempty" yaml:"Created,omitempty"` + Status string `json:"Status,omitempty" yaml:"Status,omitempty"` + Ports []APIPort `json:"Ports,omitempty" yaml:"Ports,omitempty"` + SizeRw int64 `json:"SizeRw,omitempty" yaml:"SizeRw,omitempty"` + SizeRootFs int64 `json:"SizeRootFs,omitempty" yaml:"SizeRootFs,omitempty"` + Names []string `json:"Names,omitempty" yaml:"Names,omitempty"` //is exist??? +} + +// Port represents the port number and the protocol, in the form +// /. For example: 80/tcp. +type Port string + +// State represents the state of a container. +type State struct { + Running bool `json:"Running,omitempty" yaml:"Running,omitempty"` + Paused bool `json:"Paused,omitempty" yaml:"Paused,omitempty"` + Restarting bool `json:"Restarting,omitempty" yaml:"Restarting,omitempty"` + OOMKilled bool `json:"OOMKilled,omitempty" yaml:"OOMKilled,omitempty"` + Pid int `json:"Pid,omitempty" yaml:"Pid,omitempty"` + ExitCode int `json:"ExitCode,omitempty" yaml:"ExitCode,omitempty"` + Error string `json:"Error,omitempty" yaml:"Error,omitempty"` + StartedAt time.Time `json:"StartedAt,omitempty" yaml:"StartedAt,omitempty"` + FinishedAt time.Time `json:"FinishedAt,omitempty" yaml:"FinishedAt,omitempty"` +} + +// PortBinding represents the host/container port mapping as returned in the +// `docker inspect` json +type PortBinding struct { + HostIP string `json:"HostIP,omitempty" yaml:"HostIP,omitempty"` + HostPort string `json:"HostPort,omitempty" yaml:"HostPort,omitempty"` +} + +// PortMapping represents a deprecated field in the `docker inspect` output, +// and its value as found in NetworkSettings should always be nil +type PortMapping map[string]string + +// NetworkSettings contains network-related information about a container +type NetworkSettings struct { + IPAddress string `json:"IPAddress,omitempty" yaml:"IPAddress,omitempty"` + IPPrefixLen int `json:"IPPrefixLen,omitempty" yaml:"IPPrefixLen,omitempty"` + Gateway string `json:"Gateway,omitempty" yaml:"Gateway,omitempty"` + Bridge string `json:"Bridge,omitempty" yaml:"Bridge,omitempty"` + PortMapping map[string]PortMapping `json:"PortMapping,omitempty" yaml:"PortMapping,omitempty"` + Ports map[Port][]PortBinding `json:"Ports,omitempty" yaml:"Ports,omitempty"` +} + +// Config is the list of configuration options used when creating a container. +// Config does not contain the options that are specific to starting a container on a +// given host. Those are contained in HostConfig +type Config struct { + Hostname string `json:"Hostname,omitempty" yaml:"Hostname,omitempty"` + Domainname string `json:"Domainname,omitempty" yaml:"Domainname,omitempty"` + User string `json:"User,omitempty" yaml:"User,omitempty"` + Memory int64 `json:"Memory,omitempty" yaml:"Memory,omitempty"` + MemorySwap int64 `json:"MemorySwap,omitempty" yaml:"MemorySwap,omitempty"` + CPUShares int64 `json:"CpuShares,omitempty" yaml:"CpuShares,omitempty"` + CPUSet string `json:"Cpuset,omitempty" yaml:"Cpuset,omitempty"` + AttachStdin bool `json:"AttachStdin,omitempty" yaml:"AttachStdin,omitempty"` + AttachStdout bool `json:"AttachStdout,omitempty" yaml:"AttachStdout,omitempty"` + AttachStderr bool `json:"AttachStderr,omitempty" yaml:"AttachStderr,omitempty"` + PortSpecs []string `json:"PortSpecs,omitempty" yaml:"PortSpecs,omitempty"` + ExposedPorts map[Port]struct{} `json:"ExposedPorts,omitempty" yaml:"ExposedPorts,omitempty"` + Tty bool `json:"Tty,omitempty" yaml:"Tty,omitempty"` + OpenStdin bool `json:"OpenStdin,omitempty" yaml:"OpenStdin,omitempty"` + StdinOnce bool `json:"StdinOnce,omitempty" yaml:"StdinOnce,omitempty"` + Env []string `json:"Env,omitempty" yaml:"Env,omitempty"` + Cmd []string `json:"Cmd,omitempty" yaml:"Cmd,omitempty"` + DNS []string `json:"Dns,omitempty" yaml:"Dns,omitempty"` // For Docker API v1.9 and below only + Image string `json:"Image,omitempty" yaml:"Image,omitempty"` + Volumes map[string]struct{} `json:"Volumes,omitempty" yaml:"Volumes,omitempty"` + VolumesFrom string `json:"VolumesFrom,omitempty" yaml:"VolumesFrom,omitempty"` + WorkingDir string `json:"WorkingDir,omitempty" yaml:"WorkingDir,omitempty"` + MacAddress string `json:"MacAddress,omitempty" yaml:"MacAddress,omitempty"` + Entrypoint []string `json:"Entrypoint,omitempty" yaml:"Entrypoint,omitempty"` + NetworkDisabled bool `json:"NetworkDisabled,omitempty" yaml:"NetworkDisabled,omitempty"` + SecurityOpts []string `json:"SecurityOpts,omitempty" yaml:"SecurityOpts,omitempty"` + OnBuild []string `json:"OnBuild,omitempty" yaml:"OnBuild,omitempty"` + Labels map[string]string `json:"Labels,omitempty" yaml:"Labels,omitempty"` +} + +// LogConfig defines the log driver type and the configuration for it. +type LogConfig struct { + Type string `json:"Type,omitempty" yaml:"Type,omitempty"` + Config map[string]string `json:"Config,omitempty" yaml:"Config,omitempty"` +} + +// ULimit defines system-wide resource limitations +// This can help a lot in system administration, e.g. when a user starts too many processes and therefore makes the system unresponsive for other users. +type ULimit struct { + Name string `json:"Name,omitempty" yaml:"Name,omitempty"` + Soft int64 `json:"Soft,omitempty" yaml:"Soft,omitempty"` + Hard int64 `json:"Hard,omitempty" yaml:"Hard,omitempty"` +} + +// SwarmNode containers information about which Swarm node the container is on +type SwarmNode struct { + ID string `json:"ID,omitempty" yaml:"ID,omitempty"` + IP string `json:"IP,omitempty" yaml:"IP,omitempty"` + Addr string `json:"Addr,omitempty" yaml:"Addr,omitempty"` + Name string `json:"Name,omitempty" yaml:"Name,omitempty"` + CPUs int64 `json:"CPUs,omitempty" yaml:"CPUs,omitempty"` + Memory int64 `json:"Memory,omitempty" yaml:"Memory,omitempty"` + Labels map[string]string `json:"Labels,omitempty" yaml:"Labels,omitempty"` +} + +// Container is the type encompasing everything about a container - its config, +// hostconfig, etc. +type Container struct { + ID string `json:"Id" yaml:"Id"` + + Created time.Time `json:"Created,omitempty" yaml:"Created,omitempty"` + + Path string `json:"Path,omitempty" yaml:"Path,omitempty"` + Args []string `json:"Args,omitempty" yaml:"Args,omitempty"` + + Config *Config `json:"Config,omitempty" yaml:"Config,omitempty"` + State State `json:"State,omitempty" yaml:"State,omitempty"` + Image string `json:"Image,omitempty" yaml:"Image,omitempty"` + + Node *SwarmNode `json:"Node,omitempty" yaml:"Node,omitempty"` + + NetworkSettings *NetworkSettings `json:"NetworkSettings,omitempty" yaml:"NetworkSettings,omitempty"` + + SysInitPath string `json:"SysInitPath,omitempty" yaml:"SysInitPath,omitempty"` + ResolvConfPath string `json:"ResolvConfPath,omitempty" yaml:"ResolvConfPath,omitempty"` + HostnamePath string `json:"HostnamePath,omitempty" yaml:"HostnamePath,omitempty"` + HostsPath string `json:"HostsPath,omitempty" yaml:"HostsPath,omitempty"` + Name string `json:"Name,omitempty" yaml:"Name,omitempty"` + Driver string `json:"Driver,omitempty" yaml:"Driver,omitempty"` + + Volumes map[string]string `json:"Volumes,omitempty" yaml:"Volumes,omitempty"` + VolumesRW map[string]bool `json:"VolumesRW,omitempty" yaml:"VolumesRW,omitempty"` + HostConfig *HostConfig `json:"HostConfig,omitempty" yaml:"HostConfig,omitempty"` + ExecIDs []string `json:"ExecIDs,omitempty" yaml:"ExecIDs,omitempty"` + + RestartCount int `json:"RestartCount,omitempty" yaml:"RestartCount,omitempty"` + + AppArmorProfile string `json:"AppArmorProfile,omitempty" yaml:"AppArmorProfile,omitempty"` +} + +// RenameContainerOptions specify parameters to the RenameContainer function. +// +// See https://docs.docker.com/reference/api/docker_remote_api_v1.19/#rename-a-container for more details. +/* +------------------------------------------------------------------------------------- +Rename a container +POST /containers/(id)/rename +Rename the container id to a new_name +Example request: + POST /containers/e90e34656806/rename?name=new_name HTTP/1.1 +Example response: + HTTP/1.1 204 No Content + +Query Parameters: + name – new name for the container + Status Codes: + 204 – no error + 404 – no such container + 409 - conflict name already assigned + 500 – server error +------------------------------------------------------------------------------------- +*/ + +type RenameContainerOptions struct { + // ID of container to rename + ID string `qs:"-"` + + // New name + Name string `json:"name,omitempty" yaml:"name,omitempty"` +} + +// CreateContainerOptions specify parameters to the CreateContainer function. +// +// See https://docs.docker.com/reference/api/docker_remote_api_v1.19/#create-a-container for more details. +type CreateContainerOptions struct { + Name string + Config *Config `qs:"-"` + HostConfig *HostConfig +} + +// KeyValuePair is a type for generic key/value pairs as used in the Lxc +// configuration +type KeyValuePair struct { + Key string `json:"Key,omitempty" yaml:"Key,omitempty"` + Value string `json:"Value,omitempty" yaml:"Value,omitempty"` +} + +// RestartPolicy represents the policy for automatically restarting a container. +// +// Possible values are: +// +// - always: the docker daemon will always restart the container +// - on-failure: the docker daemon will restart the container on failures, at +// most MaximumRetryCount times +// - no: the docker daemon will not restart the container automatically +type RestartPolicy struct { + Name string `json:"Name,omitempty" yaml:"Name,omitempty"` + MaximumRetryCount int `json:"MaximumRetryCount,omitempty" yaml:"MaximumRetryCount,omitempty"` +} + +// Device represents a device mapping between the Docker host and the +// container. +type Device struct { + PathOnHost string `json:"PathOnHost,omitempty" yaml:"PathOnHost,omitempty"` + PathInContainer string `json:"PathInContainer,omitempty" yaml:"PathInContainer,omitempty"` + CgroupPermissions string `json:"CgroupPermissions,omitempty" yaml:"CgroupPermissions,omitempty"` +} + +// HostConfig contains the container options related to starting a container on +// a given host +type HostConfig struct { + Binds []string `json:"Binds,omitempty" yaml:"Binds,omitempty"` + CapAdd []string `json:"CapAdd,omitempty" yaml:"CapAdd,omitempty"` + CapDrop []string `json:"CapDrop,omitempty" yaml:"CapDrop,omitempty"` + ContainerIDFile string `json:"ContainerIDFile,omitempty" yaml:"ContainerIDFile,omitempty"` + LxcConf []KeyValuePair `json:"LxcConf,omitempty" yaml:"LxcConf,omitempty"` + Privileged bool `json:"Privileged,omitempty" yaml:"Privileged,omitempty"` + PortBindings map[Port][]PortBinding `json:"PortBindings,omitempty" yaml:"PortBindings,omitempty"` + Links []string `json:"Links,omitempty" yaml:"Links,omitempty"` + PublishAllPorts bool `json:"PublishAllPorts,omitempty" yaml:"PublishAllPorts,omitempty"` + DNS []string `json:"Dns,omitempty" yaml:"Dns,omitempty"` // For Docker API v1.10 and above only + DNSSearch []string `json:"DnsSearch,omitempty" yaml:"DnsSearch,omitempty"` + ExtraHosts []string `json:"ExtraHosts,omitempty" yaml:"ExtraHosts,omitempty"` + VolumesFrom []string `json:"VolumesFrom,omitempty" yaml:"VolumesFrom,omitempty"` + NetworkMode string `json:"NetworkMode,omitempty" yaml:"NetworkMode,omitempty"` + IpcMode string `json:"IpcMode,omitempty" yaml:"IpcMode,omitempty"` + PidMode string `json:"PidMode,omitempty" yaml:"PidMode,omitempty"` + RestartPolicy RestartPolicy `json:"RestartPolicy,omitempty" yaml:"RestartPolicy,omitempty"` + Devices []Device `json:"Devices,omitempty" yaml:"Devices,omitempty"` + LogConfig LogConfig `json:"LogConfig,omitempty" yaml:"LogConfig,omitempty"` + ReadonlyRootfs bool `json:"ReadonlyRootfs,omitempty" yaml:"ReadonlyRootfs,omitempty"` + SecurityOpt []string `json:"SecurityOpt,omitempty" yaml:"SecurityOpt,omitempty"` + CgroupParent string `json:"CgroupParent,omitempty" yaml:"CgroupParent,omitempty"` + Memory int64 `json:"Memory,omitempty" yaml:"Memory,omitempty"` + MemorySwap int64 `json:"MemorySwap,omitempty" yaml:"MemorySwap,omitempty"` + CPUShares int64 `json:"CpuShares,omitempty" yaml:"CpuShares,omitempty"` + CPUSet string `json:"Cpuset,omitempty" yaml:"Cpuset,omitempty"` + CPUQuota int64 `json:"CpuQuota,omitempty" yaml:"CpuQuota,omitempty"` + CPUPeriod int64 `json:"CpuPeriod,omitempty" yaml:"CpuPeriod,omitempty"` + Ulimits []ULimit `json:"Ulimits,omitempty" yaml:"Ulimits,omitempty"` +} + +// TopResult represents the list of processes running in a container, as +// returned by /containers//top. +// +// See https://docs.docker.com/reference/api/docker_remote_api_v1.19/#list-processes-running-inside-a-container for more details. +type TopResult struct { + Titles []string + Processes [][]string +} + +// Stats represents container statistics, returned by /containers//stats. +// +// See https://docs.docker.com/reference/api/docker_remote_api_v1.19/#get-container-stats-based-on-resource-usage for more details. +type Stats struct { + Read time.Time `json:"read,omitempty" yaml:"read,omitempty"` + Network struct { + RxDropped uint64 `json:"rx_dropped,omitempty" yaml:"rx_dropped,omitempty"` + RxBytes uint64 `json:"rx_bytes,omitempty" yaml:"rx_bytes,omitempty"` + RxErrors uint64 `json:"rx_errors,omitempty" yaml:"rx_errors,omitempty"` + TxPackets uint64 `json:"tx_packets,omitempty" yaml:"tx_packets,omitempty"` + TxDropped uint64 `json:"tx_dropped,omitempty" yaml:"tx_dropped,omitempty"` + RxPackets uint64 `json:"rx_packets,omitempty" yaml:"rx_packets,omitempty"` + TxErrors uint64 `json:"tx_errors,omitempty" yaml:"tx_errors,omitempty"` + TxBytes uint64 `json:"tx_bytes,omitempty" yaml:"tx_bytes,omitempty"` + } `json:"network,omitempty" yaml:"network,omitempty"` + MemoryStats struct { + Stats struct { + TotalPgmafault uint64 `json:"total_pgmafault,omitempty" yaml:"total_pgmafault,omitempty"` + Cache uint64 `json:"cache,omitempty" yaml:"cache,omitempty"` + MappedFile uint64 `json:"mapped_file,omitempty" yaml:"mapped_file,omitempty"` + TotalInactiveFile uint64 `json:"total_inactive_file,omitempty" yaml:"total_inactive_file,omitempty"` + Pgpgout uint64 `json:"pgpgout,omitempty" yaml:"pgpgout,omitempty"` + Rss uint64 `json:"rss,omitempty" yaml:"rss,omitempty"` + TotalMappedFile uint64 `json:"total_mapped_file,omitempty" yaml:"total_mapped_file,omitempty"` + Writeback uint64 `json:"writeback,omitempty" yaml:"writeback,omitempty"` + Unevictable uint64 `json:"unevictable,omitempty" yaml:"unevictable,omitempty"` + Pgpgin uint64 `json:"pgpgin,omitempty" yaml:"pgpgin,omitempty"` + TotalUnevictable uint64 `json:"total_unevictable,omitempty" yaml:"total_unevictable,omitempty"` + Pgmajfault uint64 `json:"pgmajfault,omitempty" yaml:"pgmajfault,omitempty"` + TotalRss uint64 `json:"total_rss,omitempty" yaml:"total_rss,omitempty"` + TotalRssHuge uint64 `json:"total_rss_huge,omitempty" yaml:"total_rss_huge,omitempty"` + TotalWriteback uint64 `json:"total_writeback,omitempty" yaml:"total_writeback,omitempty"` + TotalInactiveAnon uint64 `json:"total_inactive_anon,omitempty" yaml:"total_inactive_anon,omitempty"` + RssHuge uint64 `json:"rss_huge,omitempty" yaml:"rss_huge,omitempty"` + HierarchicalMemoryLimit uint64 `json:"hierarchical_memory_limit,omitempty" yaml:"hierarchical_memory_limit,omitempty"` + TotalPgfault uint64 `json:"total_pgfault,omitempty" yaml:"total_pgfault,omitempty"` + TotalActiveFile uint64 `json:"total_active_file,omitempty" yaml:"total_active_file,omitempty"` + ActiveAnon uint64 `json:"active_anon,omitempty" yaml:"active_anon,omitempty"` + TotalActiveAnon uint64 `json:"total_active_anon,omitempty" yaml:"total_active_anon,omitempty"` + TotalPgpgout uint64 `json:"total_pgpgout,omitempty" yaml:"total_pgpgout,omitempty"` + TotalCache uint64 `json:"total_cache,omitempty" yaml:"total_cache,omitempty"` + InactiveAnon uint64 `json:"inactive_anon,omitempty" yaml:"inactive_anon,omitempty"` + ActiveFile uint64 `json:"active_file,omitempty" yaml:"active_file,omitempty"` + Pgfault uint64 `json:"pgfault,omitempty" yaml:"pgfault,omitempty"` + InactiveFile uint64 `json:"inactive_file,omitempty" yaml:"inactive_file,omitempty"` + TotalPgpgin uint64 `json:"total_pgpgin,omitempty" yaml:"total_pgpgin,omitempty"` + } `json:"stats,omitempty" yaml:"stats,omitempty"` + MaxUsage uint64 `json:"max_usage,omitempty" yaml:"max_usage,omitempty"` + Usage uint64 `json:"usage,omitempty" yaml:"usage,omitempty"` + Failcnt uint64 `json:"failcnt,omitempty" yaml:"failcnt,omitempty"` + Limit uint64 `json:"limit,omitempty" yaml:"limit,omitempty"` + } `json:"memory_stats,omitempty" yaml:"memory_stats,omitempty"` + BlkioStats struct { + IOServiceBytesRecursive []BlkioStatsEntry `json:"io_service_bytes_recursive,omitempty" yaml:"io_service_bytes_recursive,omitempty"` + IOServicedRecursive []BlkioStatsEntry `json:"io_serviced_recursive,omitempty" yaml:"io_serviced_recursive,omitempty"` + IOQueueRecursive []BlkioStatsEntry `json:"io_queue_recursive,omitempty" yaml:"io_queue_recursive,omitempty"` + IOServiceTimeRecursive []BlkioStatsEntry `json:"io_service_time_recursive,omitempty" yaml:"io_service_time_recursive,omitempty"` + IOWaitTimeRecursive []BlkioStatsEntry `json:"io_wait_time_recursive,omitempty" yaml:"io_wait_time_recursive,omitempty"` + IOMergedRecursive []BlkioStatsEntry `json:"io_merged_recursive,omitempty" yaml:"io_merged_recursive,omitempty"` + IOTimeRecursive []BlkioStatsEntry `json:"io_time_recursive,omitempty" yaml:"io_time_recursive,omitempty"` + SectorsRecursive []BlkioStatsEntry `json:"sectors_recursive,omitempty" yaml:"sectors_recursive,omitempty"` + } `json:"blkio_stats,omitempty" yaml:"blkio_stats,omitempty"` + CPUStats struct { + CPUUsage struct { + PercpuUsage []uint64 `json:"percpu_usage,omitempty" yaml:"percpu_usage,omitempty"` + UsageInUsermode uint64 `json:"usage_in_usermode,omitempty" yaml:"usage_in_usermode,omitempty"` + TotalUsage uint64 `json:"total_usage,omitempty" yaml:"total_usage,omitempty"` + UsageInKernelmode uint64 `json:"usage_in_kernelmode,omitempty" yaml:"usage_in_kernelmode,omitempty"` + } `json:"cpu_usage,omitempty" yaml:"cpu_usage,omitempty"` + SystemCPUUsage uint64 `json:"system_cpu_usage,omitempty" yaml:"system_cpu_usage,omitempty"` + ThrottlingData struct { + Periods uint64 `json:"periods,omitempty"` + ThrottledPeriods uint64 `json:"throttled_periods,omitempty"` + ThrottledTime uint64 `json:"throttled_time,omitempty"` + } `json:"throttling_data,omitempty" yaml:"throttling_data,omitempty"` + } `json:"cpu_stats,omitempty" yaml:"cpu_stats,omitempty"` +} + +// BlkioStatsEntry is a stats entry for blkio_stats +type BlkioStatsEntry struct { + Major uint64 `json:"major,omitempty" yaml:"major,omitempty"` + Minor uint64 `json:"minor,omitempty" yaml:"minor,omitempty"` + Op string `json:"op,omitempty" yaml:"op,omitempty"` + Value uint64 `json:"value,omitempty" yaml:"value,omitempty"` +} + +// StatsOptions specify parameters to the Stats function. +// +// See http://goo.gl/DFMiYD for more details. +type StatsOptions struct { + ID string + Stats chan<- *Stats + Stream bool +} + +// KillContainerOptions represents the set of options that can be used in a +// call to KillContainer. +// +// See https://docs.docker.com/reference/api/docker_remote_api_v1.19/#kill-a-container for more details. +type KillContainerOptions struct { + // The ID of the container. + ID string `qs:"-"` + + // The signal to send to the container. When omitted, Docker server + // will assume SIGKILL. + Signal Signal +} + +// RemoveContainerOptions encapsulates options to remove a container. +// +// See https://docs.docker.com/reference/api/docker_remote_api_v1.19/#remove-a-container for more details. +type RemoveContainerOptions struct { + // The ID of the container. + ID string `qs:"-"` + + // A flag that indicates whether Docker should remove the volumes + // associated to the container. + RemoveVolumes bool `qs:"v"` + + // A flag that indicates whether Docker should remove the container + // even if it is currently running. + Force bool +} + +// CopyFromContainerOptions is the set of options that can be used when copying +// files or folders from a container. +// +// See https://docs.docker.com/reference/api/docker_remote_api_v1.19/#copy-files-or-folders-from-a-container for more details. +type CopyFromContainerOptions struct { + OutputStream io.Writer `json:"-"` + Container string `json:"-"` + Resource string +} + +// CommitContainerOptions aggregates parameters to the CommitContainer method. +// +// See http://goo.gl/Jn8pe8 for more details. +type CommitContainerOptions struct { + Container string + Repository string `qs:"repo"` + Tag string + Message string `qs:"m"` + Author string + Run *Config `qs:"-"` +} + +// AttachToContainerOptions is the set of options that can be used when +// attaching to a container. +// +// See http://goo.gl/RRAhws for more details. +type AttachToContainerOptions struct { + Container string `qs:"-"` + InputStream io.Reader `qs:"-"` + OutputStream io.Writer `qs:"-"` + ErrorStream io.Writer `qs:"-"` + + // Get container logs, sending it to OutputStream. + Logs bool + + // Stream the response? + Stream bool + + // Attach to stdin, and use InputStream. + Stdin bool + + // Attach to stdout, and use OutputStream. + Stdout bool + + // Attach to stderr, and use ErrorStream. + Stderr bool + + // If set, after a successful connect, a sentinel will be sent and then the + // client will block on receive before continuing. + // + // It must be an unbuffered channel. Using a buffered channel can lead + // to unexpected behavior. + Success chan struct{} + + // Use raw terminal? Usually true when the container contains a TTY. + RawTerminal bool `qs:"-"` +} + +// LogsOptions represents the set of options used when getting logs from a +// container. +// +// See http://goo.gl/rLhKSU for more details. +type LogsOptions struct { + Container string `qs:"-"` + OutputStream io.Writer `qs:"-"` + ErrorStream io.Writer `qs:"-"` + Follow bool + Stdout bool + Stderr bool + Timestamps bool + Tail string + + // Use raw terminal? Usually true when the container contains a TTY. + RawTerminal bool `qs:"-"` +} + +// ExportContainerOptions is the set of parameters to the ExportContainer +// method. +// +// See http://goo.gl/hnzE62 for more details. +type ExportContainerOptions struct { + ID string + OutputStream io.Writer +} + +// NoSuchContainer is the error returned when a given container does not exist. +type NoSuchContainer struct { + ID string + Err error +} + +// ContainerAlreadyRunning is the error returned when a given container is +// already running. +type ContainerAlreadyRunning struct { + ID string +} + +// ContainerNotRunning is the error returned when a given container is not +// running. +type ContainerNotRunning struct { + ID string +} + +//****************************************************************// +//env use type +//****************************************************************// + +// Env represents a list of key-pair represented in the form KEY=VALUE. +type Env []string + +//****************************************************************// +//enent use type +//****************************************************************// + +// APIEvents represents an event returned by the API. +type APIEvents struct { + Status string `json:"Status,omitempty" yaml:"Status,omitempty"` + ID string `json:"ID,omitempty" yaml:"ID,omitempty"` + From string `json:"From,omitempty" yaml:"From,omitempty"` + Time int64 `json:"Time,omitempty" yaml:"Time,omitempty"` +} + +type eventMonitoringState struct { + sync.RWMutex + sync.WaitGroup + enabled bool + lastSeen *int64 + C chan *APIEvents + errC chan error + listeners []chan<- *APIEvents +} + +//****************************************************************// +//exec use type +//****************************************************************// + +// CreateExecOptions specify parameters to the CreateExecContainer function. +// +// See http://goo.gl/8izrzI for more details +type CreateExecOptions struct { + AttachStdin bool `json:"AttachStdin,omitempty" yaml:"AttachStdin,omitempty"` + AttachStdout bool `json:"AttachStdout,omitempty" yaml:"AttachStdout,omitempty"` + AttachStderr bool `json:"AttachStderr,omitempty" yaml:"AttachStderr,omitempty"` + Tty bool `json:"Tty,omitempty" yaml:"Tty,omitempty"` + Cmd []string `json:"Cmd,omitempty" yaml:"Cmd,omitempty"` + Container string `json:"Container,omitempty" yaml:"Container,omitempty"` + User string `json:"User,omitempty" yaml:"User,omitempty"` +} + +// StartExecOptions specify parameters to the StartExecContainer function. +// +// See http://goo.gl/JW8Lxl for more details +type StartExecOptions struct { + Detach bool `json:"Detach,omitempty" yaml:"Detach,omitempty"` + + Tty bool `json:"Tty,omitempty" yaml:"Tty,omitempty"` + + InputStream io.Reader `qs:"-"` + OutputStream io.Writer `qs:"-"` + ErrorStream io.Writer `qs:"-"` + + // Use raw terminal? Usually true when the container contains a TTY. + RawTerminal bool `qs:"-"` + + // If set, after a successful connect, a sentinel will be sent and then the + // client will block on receive before continuing. + // + // It must be an unbuffered channel. Using a buffered channel can lead + // to unexpected behavior. + Success chan struct{} `json:"-"` +} + +// Exec is the type representing a `docker exec` instance and containing the +// instance ID +type Exec struct { + ID string `json:"Id,omitempty" yaml:"Id,omitempty"` +} + +// ExecProcessConfig is a type describing the command associated to a Exec +// instance. It's used in the ExecInspect type. +// +// See http://goo.gl/ypQULN for more details +type ExecProcessConfig struct { + Privileged bool `json:"privileged,omitempty" yaml:"privileged,omitempty"` + User string `json:"user,omitempty" yaml:"user,omitempty"` + Tty bool `json:"tty,omitempty" yaml:"tty,omitempty"` + EntryPoint string `json:"entrypoint,omitempty" yaml:"entrypoint,omitempty"` + Arguments []string `json:"arguments,omitempty" yaml:"arguments,omitempty"` +} + +// ExecInspect is a type with details about a exec instance, including the +// exit code if the command has finished running. It's returned by a api +// call to /exec/(id)/json +// +// See http://goo.gl/ypQULN for more details +type ExecInspect struct { + ID string `json:"ID,omitempty" yaml:"ID,omitempty"` + Running bool `json:"Running,omitempty" yaml:"Running,omitempty"` + ExitCode int `json:"ExitCode,omitempty" yaml:"ExitCode,omitempty"` + OpenStdin bool `json:"OpenStdin,omitempty" yaml:"OpenStdin,omitempty"` + OpenStderr bool `json:"OpenStderr,omitempty" yaml:"OpenStderr,omitempty"` + OpenStdout bool `json:"OpenStdout,omitempty" yaml:"OpenStdout,omitempty"` + ProcessConfig ExecProcessConfig `json:"ProcessConfig,omitempty" yaml:"ProcessConfig,omitempty"` + Container Container `json:"Container,omitempty" yaml:"Container,omitempty"` +} + +// NoSuchExec is the error returned when a given exec instance does not exist. +type NoSuchExec struct { + ID string +} + +//****************************************************************// +//image use type +//****************************************************************// + +// APIImages represent an image returned in the ListImages call. +type APIImages struct { + ID string `json:"Id" yaml:"Id"` + RepoTags []string `json:"RepoTags,omitempty" yaml:"RepoTags,omitempty"` + Created int64 `json:"Created,omitempty" yaml:"Created,omitempty"` + Size int64 `json:"Size,omitempty" yaml:"Size,omitempty"` + VirtualSize int64 `json:"VirtualSize,omitempty" yaml:"VirtualSize,omitempty"` + ParentID string `json:"ParentId,omitempty" yaml:"ParentId,omitempty"` + RepoDigests []string `json:"RepoDigests,omitempty" yaml:"RepoDigests,omitempty"` +} + +// Image is the type representing a docker image and its various properties +type Image struct { + ID string `json:"Id" yaml:"Id"` + Parent string `json:"Parent,omitempty" yaml:"Parent,omitempty"` + Comment string `json:"Comment,omitempty" yaml:"Comment,omitempty"` + Created time.Time `json:"Created,omitempty" yaml:"Created,omitempty"` + Container string `json:"Container,omitempty" yaml:"Container,omitempty"` + ContainerConfig Config `json:"ContainerConfig,omitempty" yaml:"ContainerConfig,omitempty"` + DockerVersion string `json:"DockerVersion,omitempty" yaml:"DockerVersion,omitempty"` + Author string `json:"Author,omitempty" yaml:"Author,omitempty"` + Config *Config `json:"Config,omitempty" yaml:"Config,omitempty"` + Architecture string `json:"Architecture,omitempty" yaml:"Architecture,omitempty"` + Size int64 `json:"Size,omitempty" yaml:"Size,omitempty"` + VirtualSize int64 `json:"VirtualSize,omitempty" yaml:"VirtualSize,omitempty"` +} + +// ImageHistory represent a layer in an image's history returned by the +// ImageHistory call. +type ImageHistory struct { + ID string `json:"Id" yaml:"Id"` + Tags []string `json:"Tags,omitempty" yaml:"Tags,omitempty"` + Created int64 `json:"Created,omitempty" yaml:"Created,omitempty"` + CreatedBy string `json:"CreatedBy,omitempty" yaml:"CreatedBy,omitempty"` + Size int64 `json:"Size,omitempty" yaml:"Size,omitempty"` +} + +// ImagePre012 serves the same purpose as the Image type except that it is for +// earlier versions of the Docker API (pre-012 to be specific) +type ImagePre012 struct { + ID string `json:"id"` + Parent string `json:"parent,omitempty"` + Comment string `json:"comment,omitempty"` + Created time.Time `json:"created"` + Container string `json:"container,omitempty"` + ContainerConfig Config `json:"container_config,omitempty"` + DockerVersion string `json:"docker_version,omitempty"` + Author string `json:"author,omitempty"` + Config *Config `json:"config,omitempty"` + Architecture string `json:"architecture,omitempty"` + Size int64 `json:"size,omitempty"` +} + +// ListImagesOptions specify parameters to the ListImages function. +// +// See http://goo.gl/HRVN1Z for more details. +type ListImagesOptions struct { + All bool + Filters map[string][]string + Digests bool +} + +// RemoveImageOptions present the set of options available for removing an image +// from a registry. +// +// See http://goo.gl/6V48bF for more details. +type RemoveImageOptions struct { + Force bool `qs:"force"` + NoPrune bool `qs:"noprune"` +} + +// PushImageOptions represents options to use in the PushImage method. +// +// See http://goo.gl/pN8A3P for more details. +type PushImageOptions struct { + // Name of the image + Name string + + // Tag of the image + Tag string + + // Registry server to push the image + Registry string + + OutputStream io.Writer `qs:"-"` + RawJSONStream bool `qs:"-"` +} + +// PullImageOptions present the set of options available for pulling an image +// from a registry. +// +// See http://goo.gl/ACyYNS for more details. +type PullImageOptions struct { + Repository string `qs:"fromImage"` + Registry string + Tag string + OutputStream io.Writer `qs:"-"` + RawJSONStream bool `qs:"-"` +} + +// LoadImageOptions represents the options for LoadImage Docker API Call +// +// See http://goo.gl/Y8NNCq for more details. +type LoadImageOptions struct { + InputStream io.Reader +} + +// ExportImageOptions represent the options for ExportImage Docker API call +// +// See http://goo.gl/mi6kvk for more details. +type ExportImageOptions struct { + Name string + OutputStream io.Writer +} + +// ExportImagesOptions represent the options for ExportImages Docker API call +// +// See http://goo.gl/YeZzQK for more details. +type ExportImagesOptions struct { + Names []string + OutputStream io.Writer `qs:"-"` +} + +// ImportImageOptions present the set of informations available for importing +// an image from a source file or the stdin. +// +// See http://goo.gl/PhBKnS for more details. +type ImportImageOptions struct { + Repository string `qs:"repo"` + Source string `qs:"fromSrc"` + Tag string `qs:"tag"` + + InputStream io.Reader `qs:"-"` + OutputStream io.Writer `qs:"-"` + RawJSONStream bool `qs:"-"` +} + +// BuildImageOptions present the set of informations available for building an +// image from a tarfile with a Dockerfile in it. +// +// For more details about the Docker building process, see +// http://goo.gl/tlPXPu. +type BuildImageOptions struct { + Name string `qs:"t"` + Dockerfile string `qs:"dockerfile"` + NoCache bool `qs:"nocache"` + SuppressOutput bool `qs:"q"` + Pull bool `qs:"pull"` + RmTmpContainer bool `qs:"rm"` + ForceRmTmpContainer bool `qs:"forcerm"` + Memory int64 `qs:"memory"` + Memswap int64 `qs:"memswap"` + CPUShares int64 `qs:"cpushares"` + CPUSetCPUs string `qs:"cpusetcpus"` + InputStream io.Reader `qs:"-"` + OutputStream io.Writer `qs:"-"` + RawJSONStream bool `qs:"-"` + Remote string `qs:"remote"` + Auth AuthConfiguration `qs:"-"` // for older docker X-Registry-Auth header + AuthConfigs AuthConfigurations `qs:"-"` // for newer docker X-Registry-Config header + ContextDir string `qs:"-"` +} + +// TagImageOptions present the set of options to tag an image. +// +// See http://goo.gl/5g6qFy for more details. +type TagImageOptions struct { + Repo string + Tag string + Force bool +} + +// APIImageSearch reflect the result of a search on the dockerHub +// +// See http://goo.gl/xI5lLZ for more details. +type APIImageSearch struct { + Description string `json:"description,omitempty" yaml:"description,omitempty"` + IsOfficial bool `json:"is_official,omitempty" yaml:"is_official,omitempty"` + IsAutomated bool `json:"is_automated,omitempty" yaml:"is_automated,omitempty"` + Name string `json:"name,omitempty" yaml:"name,omitempty"` + StarCount int `json:"star_count,omitempty" yaml:"star_count,omitempty"` +} + +//****************************************************************// +//signal use type +//****************************************************************// + +// Signal represents a signal that can be send to the container on +// KillContainer call. +type Signal int + +//****************************************************************// +//tls use type +//****************************************************************// + +type tlsClientCon struct { + *tls.Conn + rawConn net.Conn +} + +//****************************************************************// +//change use const +//****************************************************************// + +const ( + // ChangeModify is the ChangeType for container modifications + ChangeModify ChangeType = iota + + // ChangeAdd is the ChangeType for additions to a container + ChangeAdd + + // ChangeDelete is the ChangeType for deletions from a container + ChangeDelete +) + +//****************************************************************// +//client use const +//****************************************************************// + +const generatorAgent = "go-generator-client" + +//****************************************************************// +//event use const +//****************************************************************// + +const ( + maxMonitorConnRetries = 5 + retryInitialWaitTime = 10. +) + +//****************************************************************// +//signal use const +//****************************************************************// + +// These values represent all signals available on Linux, where containers will +// be running. +const ( + SIGABRT = Signal(0x6) + SIGALRM = Signal(0xe) + SIGBUS = Signal(0x7) + SIGCHLD = Signal(0x11) + SIGCLD = Signal(0x11) + SIGCONT = Signal(0x12) + SIGFPE = Signal(0x8) + SIGHUP = Signal(0x1) + SIGILL = Signal(0x4) + SIGINT = Signal(0x2) + SIGIO = Signal(0x1d) + SIGIOT = Signal(0x6) + SIGKILL = Signal(0x9) + SIGPIPE = Signal(0xd) + SIGPOLL = Signal(0x1d) + SIGPROF = Signal(0x1b) + SIGPWR = Signal(0x1e) + SIGQUIT = Signal(0x3) + SIGSEGV = Signal(0xb) + SIGSTKFLT = Signal(0x10) + SIGSTOP = Signal(0x13) + SIGSYS = Signal(0x1f) + SIGTERM = Signal(0xf) + SIGTRAP = Signal(0x5) + SIGTSTP = Signal(0x14) + SIGTTIN = Signal(0x15) + SIGTTOU = Signal(0x16) + SIGUNUSED = Signal(0x1f) + SIGURG = Signal(0x17) + SIGUSR1 = Signal(0xa) + SIGUSR2 = Signal(0xc) + SIGVTALRM = Signal(0x1a) + SIGWINCH = Signal(0x1c) + SIGXCPU = Signal(0x18) + SIGXFSZ = Signal(0x19) +) + +//****************************************************************// +//client use var +//****************************************************************// + +var ( + // ErrInvalidEndpoint is returned when the endpoint is not a valid HTTP URL. + ErrInvalidEndpoint = errors.New("invalid endpoint") + + // ErrConnectionRefused is returned when the client cannot connect to the given endpoint. + ErrConnectionRefused = errors.New("cannot connect to Docker endpoint") + + apiVersion112, _ = NewAPIVersion("1.12") +) + +//****************************************************************// +//Container use var +//****************************************************************// + +// ErrContainerAlreadyExists is the error returned by CreateContainer when the container already exists. +var ErrContainerAlreadyExists = errors.New("container already exists") + +//****************************************************************// +//event use var +//****************************************************************// + +var ( + // ErrNoListeners is the error returned when no listeners are available + // to receive an event. + ErrNoListeners = errors.New("no listeners present to receive event") + + // ErrListenerAlreadyExists is the error returned when the listerner already + // exists. + ErrListenerAlreadyExists = errors.New("listener already exists for docker events") + + // EOFEvent is sent when the event listener receives an EOF error. + EOFEvent = &APIEvents{ + Status: "EOF", + } +) + +//****************************************************************// +//image use var +//****************************************************************// + +var ( + // ErrNoSuchImage is the error returned when the image does not exist. + ErrNoSuchImage = errors.New("no such image") + + // ErrMissingRepo is the error returned when the remote repository is + // missing. + ErrMissingRepo = errors.New("missing remote repository e.g. 'github.com/user/repo'") + + // ErrMissingOutputStream is the error returned when no output stream + // is provided to some calls, like BuildImage. + ErrMissingOutputStream = errors.New("missing output stream") + + // ErrMultipleContexts is the error returned when both a ContextDir and + // InputStream are provided in BuildImageOptions + ErrMultipleContexts = errors.New("image build may not be provided BOTH context dir and input stream") + + // ErrMustSpecifyNames is the error rreturned when the Names field on + // ExportImagesOptions is nil or empty + ErrMustSpecifyNames = errors.New("must specify at least one name to export") +) + +//****************************************************************// +//auth need func +//****************************************************************// + +// NewAuthConfigurationsFromDockerCfg returns AuthConfigurations from the +// ~/.dockercfg file. +func NewAuthConfigurationsFromDockerCfg() (*AuthConfigurations, error) { + p := path.Join(os.Getenv("HOME"), ".dockercfg") + r, err := os.Open(p) + if err != nil { + return nil, err + } + return NewAuthConfigurations(r) +} + +// NewAuthConfigurations returns AuthConfigurations from a JSON encoded string in the +// same format as the .dockercfg file. +func NewAuthConfigurations(r io.Reader) (*AuthConfigurations, error) { + var auth *AuthConfigurations + var confs map[string]dockerConfig + if err := json.NewDecoder(r).Decode(&confs); err != nil { + return nil, err + } + auth, err := authConfigs(confs) + if err != nil { + return nil, err + } + return auth, nil +} + +// authConfigs converts a dockerConfigs map to a AuthConfigurations object. +func authConfigs(confs map[string]dockerConfig) (*AuthConfigurations, error) { + c := &AuthConfigurations{ + Configs: make(map[string]AuthConfiguration), + } + for reg, conf := range confs { + data, err := base64.StdEncoding.DecodeString(conf.Auth) + if err != nil { + return nil, err + } + userpass := strings.Split(string(data), ":") + c.Configs[reg] = AuthConfiguration{ + Email: conf.Email, + Username: userpass[0], + Password: userpass[1], + ServerAddress: reg, + } + } + return c, nil +} + +// AuthCheck validates the given credentials. It returns nil if successful. +// +// See https://goo.gl/vPoEfJ for more details. +func (c *Client) AuthCheck(conf *AuthConfiguration) error { + if conf == nil { + return fmt.Errorf("conf is nil") + } + body, statusCode, err := c.do("POST", "/auth", doOptions{data: conf}) + if err != nil { + return err + } + if statusCode > 400 { + return fmt.Errorf("auth error (%d): %s", statusCode, body) + } + return nil +} + +//****************************************************************// +//change need func +//****************************************************************// + +func (change *Change) String() string { + var kind string + switch change.Kind { + case ChangeModify: + kind = "C" + case ChangeAdd: + kind = "A" + case ChangeDelete: + kind = "D" + } + return fmt.Sprintf("%s %s", kind, change.Path) +} + +//****************************************************************// +//client need func +//****************************************************************// + +// NewAPIVersion returns an instance of APIVersion for the given string. +// +// The given string must be in the form .., where , +// and are integer numbers. +func NewAPIVersion(input string) (APIVersion, error) { + if !strings.Contains(input, ".") { + return nil, fmt.Errorf("Unable to parse version %q", input) + } + arr := strings.Split(input, ".") + ret := make(APIVersion, len(arr)) + var err error + for i, val := range arr { + ret[i], err = strconv.Atoi(val) + if err != nil { + return nil, fmt.Errorf("Unable to parse version %q: %q is not an integer", input, val) + } + } + return ret, nil +} + +func (version APIVersion) String() string { + var str string + for i, val := range version { + str += strconv.Itoa(val) + if i < len(version)-1 { + str += "." + } + } + return str +} + +// LessThan is a function for comparing APIVersion structs +func (version APIVersion) LessThan(other APIVersion) bool { + return version.compare(other) < 0 +} + +// LessThanOrEqualTo is a function for comparing APIVersion structs +func (version APIVersion) LessThanOrEqualTo(other APIVersion) bool { + return version.compare(other) <= 0 +} + +// GreaterThan is a function for comparing APIVersion structs +func (version APIVersion) GreaterThan(other APIVersion) bool { + return version.compare(other) > 0 +} + +// GreaterThanOrEqualTo is a function for comparing APIVersion structs +func (version APIVersion) GreaterThanOrEqualTo(other APIVersion) bool { + return version.compare(other) >= 0 +} + +func (version APIVersion) compare(other APIVersion) int { + for i, v := range version { + if i <= len(other)-1 { + otherVersion := other[i] + + if v < otherVersion { + return -1 + } else if v > otherVersion { + return 1 + } + } + } + if len(version) > len(other) { + return 1 + } + if len(version) < len(other) { + return -1 + } + return 0 +} + +// NewClient returns a Client instance ready for communication with the given +// server endpoint. It will use the latest remote API version available in the +// server. +func NewClient(endpoint string) (*Client, error) { + client, err := NewVersionedClient(endpoint, "") + if err != nil { + return nil, err + } + client.SkipServerVersionCheck = true + return client, nil +} + +// NewTLSClient returns a Client instance ready for TLS communications with the givens +// server endpoint, key and certificates . It will use the latest remote API version +// available in the server. +func NewTLSClient(endpoint string, cert, key, ca string) (*Client, error) { + client, err := NewVersionedTLSClient(endpoint, cert, key, ca, "") + if err != nil { + return nil, err + } + client.SkipServerVersionCheck = true + return client, nil +} + +// NewTLSClientFromBytes returns a Client instance ready for TLS communications with the givens +// server endpoint, key and certificates (passed inline to the function as opposed to being +// read from a local file). It will use the latest remote API version available in the server. +func NewTLSClientFromBytes(endpoint string, certPEMBlock, keyPEMBlock, caPEMCert []byte) (*Client, error) { + client, err := NewVersionedTLSClientFromBytes(endpoint, certPEMBlock, keyPEMBlock, caPEMCert, "") + if err != nil { + return nil, err + } + client.SkipServerVersionCheck = true + return client, nil +} + +// NewVersionedClient returns a Client instance ready for communication with +// the given server endpoint, using a specific remote API version. +func NewVersionedClient(endpoint string, apiVersionString string) (*Client, error) { + u, err := parseEndpoint(endpoint, false) + if err != nil { + return nil, err + } + var requestedAPIVersion APIVersion + if strings.Contains(apiVersionString, ".") { + requestedAPIVersion, err = NewAPIVersion(apiVersionString) + if err != nil { + return nil, err + } + } + + tr := &http.Transport{} + return &Client{ + HTTPClient: http.DefaultClient, + transport: tr, + endpoint: endpoint, + endpointURL: u, + eventMonitor: new(eventMonitoringState), + requestedAPIVersion: requestedAPIVersion, + }, nil +} + +// NewVersionnedTLSClient has been DEPRECATED, please use NewVersionedTLSClient. +func NewVersionnedTLSClient(endpoint string, cert, key, ca, apiVersionString string) (*Client, error) { + return NewVersionedTLSClient(endpoint, cert, key, ca, apiVersionString) +} + +// NewVersionedTLSClient returns a Client instance ready for TLS communications with the givens +// server endpoint, key and certificates, using a specific remote API version. +func NewVersionedTLSClient(endpoint string, cert, key, ca, apiVersionString string) (*Client, error) { + certPEMBlock, err := ioutil.ReadFile(cert) + if err != nil { + return nil, err + } + keyPEMBlock, err := ioutil.ReadFile(key) + if err != nil { + return nil, err + } + caPEMCert, err := ioutil.ReadFile(ca) + if err != nil { + return nil, err + } + return NewVersionedTLSClientFromBytes(endpoint, certPEMBlock, keyPEMBlock, caPEMCert, apiVersionString) +} + +// NewVersionedTLSClientFromBytes returns a Client instance ready for TLS communications with the givens +// server endpoint, key and certificates (passed inline to the function as opposed to being +// read from a local file), using a specific remote API version. +func NewVersionedTLSClientFromBytes(endpoint string, certPEMBlock, keyPEMBlock, caPEMCert []byte, apiVersionString string) (*Client, error) { + u, err := parseEndpoint(endpoint, true) + if err != nil { + return nil, err + } + var requestedAPIVersion APIVersion + if strings.Contains(apiVersionString, ".") { + requestedAPIVersion, err = NewAPIVersion(apiVersionString) + if err != nil { + return nil, err + } + } + if certPEMBlock == nil || keyPEMBlock == nil { + return nil, errors.New("Both cert and key are required") + } + tlsCert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock) + if err != nil { + return nil, err + } + tlsConfig := &tls.Config{Certificates: []tls.Certificate{tlsCert}} + if caPEMCert == nil { + tlsConfig.InsecureSkipVerify = true + } else { + caPool := x509.NewCertPool() + if !caPool.AppendCertsFromPEM(caPEMCert) { + return nil, errors.New("Could not add RootCA pem") + } + tlsConfig.RootCAs = caPool + } + tr := &http.Transport{ + TLSClientConfig: tlsConfig, + } + if err != nil { + return nil, err + } + return &Client{ + HTTPClient: &http.Client{Transport: tr}, + transport: tr, + TLSConfig: tlsConfig, + endpoint: endpoint, + endpointURL: u, + eventMonitor: new(eventMonitoringState), + requestedAPIVersion: requestedAPIVersion, + }, nil +} + +func (c *Client) checkAPIVersion() error { + serverAPIVersionString, err := c.getServerAPIVersionString() + if err != nil { + return err + } + c.serverAPIVersion, err = NewAPIVersion(serverAPIVersionString) + if err != nil { + return err + } + if c.requestedAPIVersion == nil { + c.expectedAPIVersion = c.serverAPIVersion + } else { + c.expectedAPIVersion = c.requestedAPIVersion + } + return nil +} + +// Ping pings the docker server +// +// See http://goo.gl/stJENm for more details. +func (c *Client) Ping() error { + path := "/_ping" + body, status, err := c.do("GET", path, doOptions{}) + if err != nil { + return err + } + if status != http.StatusOK { + return newError(status, body) + } + return nil +} + +func (c *Client) getServerAPIVersionString() (version string, err error) { + body, status, err := c.do("GET", "/version", doOptions{}) + if err != nil { + return "", err + } + if status != http.StatusOK { + return "", fmt.Errorf("Received unexpected status %d while trying to retrieve the server version", status) + } + var versionResponse map[string]interface{} + err = json.Unmarshal(body, &versionResponse) + if err != nil { + return "", err + } + if version, ok := (versionResponse["ApiVersion"]).(string); ok { + return version, nil + } + return "", nil +} + +func (c *Client) do(method, path string, doOptions doOptions) ([]byte, int, error) { + var params io.Reader + if doOptions.data != nil || doOptions.forceJSON { + buf, err := json.Marshal(doOptions.data) + if err != nil { + return nil, -1, err + } + params = bytes.NewBuffer(buf) + } + if path != "/version" && !c.SkipServerVersionCheck && c.expectedAPIVersion == nil { + err := c.checkAPIVersion() + if err != nil { + return nil, -1, err + } + } + req, err := http.NewRequest(method, c.getURL(path), params) + if err != nil { + return nil, -1, err + } + req.Header.Set("User-Agent", generatorAgent) + if doOptions.data != nil { + req.Header.Set("Content-Type", "application/json") + } else if method == "POST" { + req.Header.Set("Content-Type", "plain/text") + } + var resp *http.Response + protocol := c.endpointURL.Scheme + address := c.endpointURL.Path + if protocol == "unix" { + dial, err := net.Dial(protocol, address) + if err != nil { + return nil, -1, err + } + defer dial.Close() + breader := bufio.NewReader(dial) + err = req.Write(dial) + if err != nil { + return nil, -1, err + } + resp, err = http.ReadResponse(breader, req) + } else { + resp, err = c.HTTPClient.Do(req) + } + if err != nil { + if strings.Contains(err.Error(), "connection refused") { + return nil, -1, ErrConnectionRefused + } + return nil, -1, err + } + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, -1, err + } + if resp.StatusCode < 200 || resp.StatusCode >= 400 { + return nil, resp.StatusCode, newError(resp.StatusCode, body) + } + return body, resp.StatusCode, nil +} + +func (c *Client) stream(method, path string, streamOptions streamOptions) error { + if (method == "POST" || method == "PUT") && streamOptions.in == nil { + streamOptions.in = bytes.NewReader(nil) + } + if path != "/version" && !c.SkipServerVersionCheck && c.expectedAPIVersion == nil { + err := c.checkAPIVersion() + if err != nil { + return err + } + } + req, err := http.NewRequest(method, c.getURL(path), streamOptions.in) + if err != nil { + return err + } + req.Header.Set("User-Agent", generatorAgent) + if method == "POST" { + req.Header.Set("Content-Type", "plain/text") + } + for key, val := range streamOptions.headers { + req.Header.Set(key, val) + } + var resp *http.Response + protocol := c.endpointURL.Scheme + address := c.endpointURL.Path + if streamOptions.stdout == nil { + streamOptions.stdout = ioutil.Discard + } + if streamOptions.stderr == nil { + streamOptions.stderr = ioutil.Discard + } + if protocol == "unix" { + dial, err := net.Dial(protocol, address) + if err != nil { + return err + } + defer dial.Close() + breader := bufio.NewReader(dial) + err = req.Write(dial) + if err != nil { + return err + } + if resp, err = http.ReadResponse(breader, req); err != nil { + if strings.Contains(err.Error(), "connection refused") { + return ErrConnectionRefused + } + return err + } + defer resp.Body.Close() + } else { + if resp, err = c.HTTPClient.Do(req); err != nil { + if strings.Contains(err.Error(), "connection refused") { + return ErrConnectionRefused + } + return err + } + defer resp.Body.Close() + defer c.transport.CancelRequest(req) + } + if resp.StatusCode < 200 || resp.StatusCode >= 400 { + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return err + } + return newError(resp.StatusCode, body) + } + if streamOptions.useJSONDecoder || resp.Header.Get("Content-Type") == "application/json" { + // if we want to get raw json stream, just copy it back to output + // without decoding it + if streamOptions.rawJSONStream { + _, err = io.Copy(streamOptions.stdout, resp.Body) + return err + } + dec := json.NewDecoder(resp.Body) + for { + var m jsonMessage + if err := dec.Decode(&m); err == io.EOF { + break + } else if err != nil { + return err + } + if m.Stream != "" { + fmt.Fprint(streamOptions.stdout, m.Stream) + } else if m.Progress != "" { + fmt.Fprintf(streamOptions.stdout, "%s %s\r", m.Status, m.Progress) + } else if m.Error != "" { + return errors.New(m.Error) + } + if m.Status != "" { + fmt.Fprintln(streamOptions.stdout, m.Status) + } + } + } else { + if streamOptions.setRawTerminal { + _, err = io.Copy(streamOptions.stdout, resp.Body) + } else { + _, err = stdcopy.StdCopy(streamOptions.stdout, streamOptions.stderr, resp.Body) + } + return err + } + return nil +} + +func (c *Client) hijack(method, path string, hijackOptions hijackOptions) error { + if path != "/version" && !c.SkipServerVersionCheck && c.expectedAPIVersion == nil { + err := c.checkAPIVersion() + if err != nil { + return err + } + } + + var params io.Reader + if hijackOptions.data != nil { + buf, err := json.Marshal(hijackOptions.data) + if err != nil { + return err + } + params = bytes.NewBuffer(buf) + } + + if hijackOptions.stdout == nil { + hijackOptions.stdout = ioutil.Discard + } + if hijackOptions.stderr == nil { + hijackOptions.stderr = ioutil.Discard + } + req, err := http.NewRequest(method, c.getURL(path), params) + if err != nil { + return err + } + req.Header.Set("Content-Type", "plain/text") + protocol := c.endpointURL.Scheme + address := c.endpointURL.Path + if protocol != "unix" { + protocol = "tcp" + address = c.endpointURL.Host + } + var dial net.Conn + if c.TLSConfig != nil && protocol != "unix" { + dial, err = tlsDial(protocol, address, c.TLSConfig) + if err != nil { + return err + } + } else { + dial, err = net.Dial(protocol, address) + if err != nil { + return err + } + } + clientconn := httputil.NewClientConn(dial, nil) + defer clientconn.Close() + clientconn.Do(req) + if hijackOptions.success != nil { + hijackOptions.success <- struct{}{} + <-hijackOptions.success + } + rwc, br := clientconn.Hijack() + defer rwc.Close() + errChanOut := make(chan error, 1) + errChanIn := make(chan error, 1) + exit := make(chan bool) + go func() { + defer close(exit) + defer close(errChanOut) + var err error + if hijackOptions.setRawTerminal { + // When TTY is ON, use regular copy + _, err = io.Copy(hijackOptions.stdout, br) + } else { + _, err = stdcopy.StdCopy(hijackOptions.stdout, hijackOptions.stderr, br) + } + errChanOut <- err + }() + go func() { + if hijackOptions.in != nil { + _, err := io.Copy(rwc, hijackOptions.in) + errChanIn <- err + } + rwc.(interface { + CloseWrite() error + }).CloseWrite() + }() + <-exit + select { + case err = <-errChanIn: + return err + case err = <-errChanOut: + return err + } +} + +func (c *Client) getURL(path string) string { + urlStr := strings.TrimRight(c.endpointURL.String(), "/") + if c.endpointURL.Scheme == "unix" { + urlStr = "" + } + + if c.requestedAPIVersion != nil { + return fmt.Sprintf("%s/v%s%s", urlStr, c.requestedAPIVersion, path) + } + return fmt.Sprintf("%s%s", urlStr, path) +} + +func queryString(opts interface{}) string { + if opts == nil { + return "" + } + value := reflect.ValueOf(opts) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if value.Kind() != reflect.Struct { + return "" + } + items := url.Values(map[string][]string{}) + for i := 0; i < value.NumField(); i++ { + field := value.Type().Field(i) + if field.PkgPath != "" { + continue + } + key := field.Tag.Get("qs") + if key == "" { + key = strings.ToLower(field.Name) + } else if key == "-" { + continue + } + addQueryStringValue(items, key, value.Field(i)) + } + return items.Encode() +} + +func addQueryStringValue(items url.Values, key string, v reflect.Value) { + switch v.Kind() { + case reflect.Bool: + if v.Bool() { + items.Add(key, "1") + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if v.Int() > 0 { + items.Add(key, strconv.FormatInt(v.Int(), 10)) + } + case reflect.Float32, reflect.Float64: + if v.Float() > 0 { + items.Add(key, strconv.FormatFloat(v.Float(), 'f', -1, 64)) + } + case reflect.String: + if v.String() != "" { + items.Add(key, v.String()) + } + case reflect.Ptr: + if !v.IsNil() { + if b, err := json.Marshal(v.Interface()); err == nil { + items.Add(key, string(b)) + } + } + case reflect.Map: + if len(v.MapKeys()) > 0 { + if b, err := json.Marshal(v.Interface()); err == nil { + items.Add(key, string(b)) + } + } + case reflect.Array, reflect.Slice: + vLen := v.Len() + if vLen > 0 { + for i := 0; i < vLen; i++ { + addQueryStringValue(items, key, v.Index(i)) + } + } + } +} + +func newError(status int, body []byte) *Error { + return &Error{Status: status, Message: string(body)} +} + +func (e *Error) Error() string { + return fmt.Sprintf("API error (%d): %s", e.Status, e.Message) +} + +func parseEndpoint(endpoint string, tls bool) (*url.URL, error) { + u, err := url.Parse(endpoint) + if err != nil { + return nil, ErrInvalidEndpoint + } + if tls { + u.Scheme = "https" + } + switch u.Scheme { + case "unix": + return u, nil + case "http", "https", "tcp": + _, port, err := net.SplitHostPort(u.Host) + if err != nil { + if e, ok := err.(*net.AddrError); ok { + if e.Err == "missing port in address" { + return u, nil + } + } + return nil, ErrInvalidEndpoint + } + number, err := strconv.ParseInt(port, 10, 64) + if err == nil && number > 0 && number < 65536 { + if u.Scheme == "tcp" { + if number == 2376 { + u.Scheme = "https" + } else { + u.Scheme = "http" + } + } + return u, nil + } + return nil, ErrInvalidEndpoint + default: + return nil, ErrInvalidEndpoint + } +} + +//****************************************************************// +//container need func +//****************************************************************// + +// ListContainers returns a slice of containers matching the given criteria. +// +// See https://docs.docker.com/reference/api/docker_remote_api_v1.19/#list-containers for more details. +func (c *Client) ListContainers(opts ListContainersOptions) ([]APIContainers, error) { + path := "/containers/json?" + queryString(opts) + body, _, err := c.do("GET", path, doOptions{}) + if err != nil { + return nil, err + } + var containers []APIContainers + err = json.Unmarshal(body, &containers) + if err != nil { + return nil, err + } + return containers, nil +} + +// Port returns the number of the port. +func (p Port) Port() string { + return strings.Split(string(p), "/")[0] +} + +// Proto returns the name of the protocol. +func (p Port) Proto() string { + parts := strings.Split(string(p), "/") + if len(parts) == 1 { + return "tcp" + } + return parts[1] +} + +// String returns the string representation of a state. +func (s *State) String() string { + if s.Running { + if s.Paused { + return "paused" + } + return fmt.Sprintf("Up %s", time.Now().UTC().Sub(s.StartedAt)) + } + return fmt.Sprintf("Exit %d", s.ExitCode) +} + +// PortMappingAPI translates the port mappings as contained in NetworkSettings +// into the format in which they would appear when returned by the API +func (settings *NetworkSettings) PortMappingAPI() []APIPort { + var mapping []APIPort + for port, bindings := range settings.Ports { + p, _ := parsePort(port.Port()) + if len(bindings) == 0 { + mapping = append(mapping, APIPort{ + PublicPort: int64(p), + Type: port.Proto(), + }) + continue + } + for _, binding := range bindings { + p, _ := parsePort(port.Port()) + h, _ := parsePort(binding.HostPort) + mapping = append(mapping, APIPort{ + PrivatePort: int64(p), + PublicPort: int64(h), + Type: port.Proto(), + IP: binding.HostIP, + }) + } + } + return mapping +} + +func parsePort(rawPort string) (int, error) { + port, err := strconv.ParseUint(rawPort, 10, 16) + if err != nil { + return 0, err + } + return int(port), nil +} + +// RenameContainer updates and existing containers name +// +// See https://docs.docker.com/reference/api/docker_remote_api_v1.19/#rename-a-container for more details. +func (c *Client) RenameContainer(opts RenameContainerOptions) error { + _, _, err := c.do("POST", fmt.Sprintf("/containers/"+opts.ID+"/rename?%s", queryString(opts)), doOptions{}) + return err +} + +// InspectContainer returns information about a container by its ID. +// +// See https://docs.docker.com/reference/api/docker_remote_api_v1.19/#inspect-a-container for more details. +func (c *Client) InspectContainer(id string) (*Container, error) { + path := "/containers/" + id + "/json" + body, status, err := c.do("GET", path, doOptions{}) + if status == http.StatusNotFound { + return nil, &NoSuchContainer{ID: id} + } + if err != nil { + return nil, err + } + var container Container + err = json.Unmarshal(body, &container) + if err != nil { + return nil, err + } + return &container, nil +} + +// ContainerChanges returns changes in the filesystem of the given container. +// +// See https://docs.docker.com/reference/api/docker_remote_api_v1.19/#inspect-changes-on-a-container-s-filesystem for more details. +func (c *Client) ContainerChanges(id string) ([]Change, error) { + path := "/containers/" + id + "/changes" + body, status, err := c.do("GET", path, doOptions{}) + if status == http.StatusNotFound { + return nil, &NoSuchContainer{ID: id} + } + if err != nil { + return nil, err + } + var changes []Change + err = json.Unmarshal(body, &changes) + if err != nil { + return nil, err + } + return changes, nil +} + +// CreateContainer creates a new container, returning the container instance, +// or an error in case of failure. +// +// See https://docs.docker.com/reference/api/docker_remote_api_v1.19/#create-a-container for more details. +func (c *Client) CreateContainer(opts CreateContainerOptions) (*Container, error) { + path := "/containers/create?" + queryString(opts) + body, status, err := c.do( + "POST", + path, + doOptions{ + data: struct { + *Config + HostConfig *HostConfig `json:"HostConfig,omitempty" yaml:"HostConfig,omitempty"` + }{ + opts.Config, + opts.HostConfig, + }, + }, + ) + + if status == http.StatusNotFound { + return nil, ErrNoSuchImage + } + if status == http.StatusConflict { + return nil, ErrContainerAlreadyExists + } + if err != nil { + return nil, err + } + var container Container + err = json.Unmarshal(body, &container) + if err != nil { + return nil, err + } + + container.Name = opts.Name + + return &container, nil +} + +// AlwaysRestart returns a restart policy that tells the Docker daemon to +// always restart the container. +func AlwaysRestart() RestartPolicy { + return RestartPolicy{Name: "always"} +} + +// RestartOnFailure returns a restart policy that tells the Docker daemon to +// restart the container on failures, trying at most maxRetry times. +func RestartOnFailure(maxRetry int) RestartPolicy { + return RestartPolicy{Name: "on-failure", MaximumRetryCount: maxRetry} +} + +// NeverRestart returns a restart policy that tells the Docker daemon to never +// restart the container on failures. +func NeverRestart() RestartPolicy { + return RestartPolicy{Name: "no"} +} + +// StartContainer starts a container, returning an error in case of failure. +// +// See https://docs.docker.com/reference/api/docker_remote_api_v1.19/#start-a-container for more details. +func (c *Client) StartContainer(id string, hostConfig *HostConfig) error { + path := "/containers/" + id + "/start" + _, status, err := c.do("POST", path, doOptions{data: hostConfig, forceJSON: true}) + if status == http.StatusNotFound { + return &NoSuchContainer{ID: id, Err: err} + } + if status == http.StatusNotModified { + return &ContainerAlreadyRunning{ID: id} + } + if err != nil { + return err + } + return nil +} + +// StopContainer stops a container, killing it after the given timeout (in +// seconds). +// +// See https://docs.docker.com/reference/api/docker_remote_api_v1.19/#stop-a-container for more details. +func (c *Client) StopContainer(id string, timeout uint) error { + path := fmt.Sprintf("/containers/%s/stop?t=%d", id, timeout) + _, status, err := c.do("POST", path, doOptions{}) + if status == http.StatusNotFound { + return &NoSuchContainer{ID: id} + } + if status == http.StatusNotModified { + return &ContainerNotRunning{ID: id} + } + if err != nil { + return err + } + return nil +} + +// RestartContainer stops a container, killing it after the given timeout (in +// seconds), during the stop process. +// +// See https://docs.docker.com/reference/api/docker_remote_api_v1.19/#restart-a-container for more details. +func (c *Client) RestartContainer(id string, timeout uint) error { + path := fmt.Sprintf("/containers/%s/restart?t=%d", id, timeout) + _, status, err := c.do("POST", path, doOptions{}) + if status == http.StatusNotFound { + return &NoSuchContainer{ID: id} + } + if err != nil { + return err + } + return nil +} + +// PauseContainer pauses the given container. +// +// See https://docs.docker.com/reference/api/docker_remote_api_v1.19/#pause-a-container for more details. +func (c *Client) PauseContainer(id string) error { + path := fmt.Sprintf("/containers/%s/pause", id) + _, status, err := c.do("POST", path, doOptions{}) + if status == http.StatusNotFound { + return &NoSuchContainer{ID: id} + } + if err != nil { + return err + } + return nil +} + +// UnpauseContainer unpauses the given container. +// +// See https://docs.docker.com/reference/api/docker_remote_api_v1.19/#unpause-a-container for more details. +func (c *Client) UnpauseContainer(id string) error { + path := fmt.Sprintf("/containers/%s/unpause", id) + _, status, err := c.do("POST", path, doOptions{}) + if status == http.StatusNotFound { + return &NoSuchContainer{ID: id} + } + if err != nil { + return err + } + return nil +} + +// TopContainer returns processes running inside a container +// +// See https://docs.docker.com/reference/api/docker_remote_api_v1.19/#list-processes-running-inside-a-container for more details. +func (c *Client) TopContainer(id string, psArgs string) (TopResult, error) { + var args string + var result TopResult + if psArgs != "" { + args = fmt.Sprintf("?ps_args=%s", psArgs) + } + path := fmt.Sprintf("/containers/%s/top%s", id, args) + body, status, err := c.do("GET", path, doOptions{}) + if status == http.StatusNotFound { + return result, &NoSuchContainer{ID: id} + } + if err != nil { + return result, err + } + err = json.Unmarshal(body, &result) + if err != nil { + return result, err + } + return result, nil +} + +// Stats sends container statistics for the given container to the given channel. +// +// This function is blocking, similar to a streaming call for logs, and should be run +// on a separate goroutine from the caller. Note that this function will block until +// the given container is removed, not just exited. When finished, this function +// will close the given channel. +// +// See https://docs.docker.com/reference/api/docker_remote_api_v1.19/#get-container-stats-based-on-resource-usage for more details. +func (c *Client) Stats(opts StatsOptions) (retErr error) { + errC := make(chan error, 1) + readCloser, writeCloser := io.Pipe() + + defer func() { + close(opts.Stats) + if err := <-errC; err != nil && retErr == nil { + retErr = err + } + if err := readCloser.Close(); err != nil && retErr == nil { + retErr = err + } + }() + + go func() { + err := c.stream("GET", fmt.Sprintf("/containers/%s/stats?stream=%v", opts.ID, opts.Stream), streamOptions{ + rawJSONStream: true, + useJSONDecoder: true, + stdout: writeCloser, + }) + if err != nil { + dockerError, ok := err.(*Error) + if ok { + if dockerError.Status == http.StatusNotFound { + err = &NoSuchContainer{ID: opts.ID} + } + } + } + if closeErr := writeCloser.Close(); closeErr != nil && err == nil { + err = closeErr + } + errC <- err + close(errC) + }() + + decoder := json.NewDecoder(readCloser) + stats := new(Stats) + for err := decoder.Decode(&stats); err != io.EOF; err = decoder.Decode(stats) { + if err != nil { + return err + } + opts.Stats <- stats + stats = new(Stats) + } + return nil +} + +// KillContainer kills a container, returning an error in case of failure. +// +// See https://docs.docker.com/reference/api/docker_remote_api_v1.19/#kill-a-container for more details. +func (c *Client) KillContainer(opts KillContainerOptions) error { + path := "/containers/" + opts.ID + "/kill" + "?" + queryString(opts) + _, status, err := c.do("POST", path, doOptions{}) + if status == http.StatusNotFound { + return &NoSuchContainer{ID: opts.ID} + } + if err != nil { + return err + } + return nil +} + +// RemoveContainer removes a container, returning an error in case of failure. +// +// See https://docs.docker.com/reference/api/docker_remote_api_v1.19/#remove-a-container for more details. +func (c *Client) RemoveContainer(opts RemoveContainerOptions) error { + path := "/containers/" + opts.ID + "?" + queryString(opts) + _, status, err := c.do("DELETE", path, doOptions{}) + if status == http.StatusNotFound { + return &NoSuchContainer{ID: opts.ID} + } + if err != nil { + return err + } + return nil +} + +// CopyFromContainer copy files or folders from a container, using a given +// resource. +// +// See https://docs.docker.com/reference/api/docker_remote_api_v1.19/#copy-files-or-folders-from-a-container for more details. +func (c *Client) CopyFromContainer(opts CopyFromContainerOptions) error { + if opts.Container == "" { + return &NoSuchContainer{ID: opts.Container} + } + url := fmt.Sprintf("/containers/%s/copy", opts.Container) + body, status, err := c.do("POST", url, doOptions{data: opts}) + if status == http.StatusNotFound { + return &NoSuchContainer{ID: opts.Container} + } + if err != nil { + return err + } + _, err = io.Copy(opts.OutputStream, bytes.NewBuffer(body)) + return err +} + +// WaitContainer blocks until the given container stops, return the exit code +// of the container status. +// +// See http://goo.gl/J88DHU for more details. +func (c *Client) WaitContainer(id string) (int, error) { + body, status, err := c.do("POST", "/containers/"+id+"/wait", doOptions{}) + if status == http.StatusNotFound { + return 0, &NoSuchContainer{ID: id} + } + if err != nil { + return 0, err + } + var r struct{ StatusCode int } + err = json.Unmarshal(body, &r) + if err != nil { + return 0, err + } + return r.StatusCode, nil +} + +// CommitContainer creates a new image from a container's changes. +// +// See http://goo.gl/Jn8pe8 for more details. +func (c *Client) CommitContainer(opts CommitContainerOptions) (*Image, error) { + path := "/commit?" + queryString(opts) + body, status, err := c.do("POST", path, doOptions{data: opts.Run}) + if status == http.StatusNotFound { + return nil, &NoSuchContainer{ID: opts.Container} + } + if err != nil { + return nil, err + } + var image Image + err = json.Unmarshal(body, &image) + if err != nil { + return nil, err + } + return &image, nil +} + +// AttachToContainer attaches to a container, using the given options. +// +// See http://goo.gl/RRAhws for more details. +func (c *Client) AttachToContainer(opts AttachToContainerOptions) error { + if opts.Container == "" { + return &NoSuchContainer{ID: opts.Container} + } + path := "/containers/" + opts.Container + "/attach?" + queryString(opts) + return c.hijack("POST", path, hijackOptions{ + success: opts.Success, + setRawTerminal: opts.RawTerminal, + in: opts.InputStream, + stdout: opts.OutputStream, + stderr: opts.ErrorStream, + }) +} + +// Logs gets stdout and stderr logs from the specified container. +// +// See http://goo.gl/rLhKSU for more details. +func (c *Client) Logs(opts LogsOptions) error { + if opts.Container == "" { + return &NoSuchContainer{ID: opts.Container} + } + if opts.Tail == "" { + opts.Tail = "all" + } + path := "/containers/" + opts.Container + "/logs?" + queryString(opts) + return c.stream("GET", path, streamOptions{ + setRawTerminal: opts.RawTerminal, + stdout: opts.OutputStream, + stderr: opts.ErrorStream, + }) +} + +// ResizeContainerTTY resizes the terminal to the given height and width. +func (c *Client) ResizeContainerTTY(id string, height, width int) error { + params := make(url.Values) + params.Set("h", strconv.Itoa(height)) + params.Set("w", strconv.Itoa(width)) + _, _, err := c.do("POST", "/containers/"+id+"/resize?"+params.Encode(), doOptions{}) + return err +} + +// ExportContainer export the contents of container id as tar archive +// and prints the exported contents to stdout. +// +// See http://goo.gl/hnzE62 for more details. +func (c *Client) ExportContainer(opts ExportContainerOptions) error { + if opts.ID == "" { + return &NoSuchContainer{ID: opts.ID} + } + url := fmt.Sprintf("/containers/%s/export", opts.ID) + return c.stream("GET", url, streamOptions{ + setRawTerminal: true, + stdout: opts.OutputStream, + }) +} + +func (err *NoSuchContainer) Error() string { + if err.Err != nil { + return err.Err.Error() + } + return "No such container: " + err.ID +} + +func (err *ContainerAlreadyRunning) Error() string { + return "Container already running: " + err.ID +} +func (err *ContainerNotRunning) Error() string { + return "Container not running: " + err.ID +} + +//****************************************************************// +//env need func +//****************************************************************// + +// Get returns the string value of the given key. +func (env *Env) Get(key string) (value string) { + return env.Map()[key] +} + +// Exists checks whether the given key is defined in the internal Env +// representation. +func (env *Env) Exists(key string) bool { + _, exists := env.Map()[key] + return exists +} + +// GetBool returns a boolean representation of the given key. The key is false +// whenever its value if 0, no, false, none or an empty string. Any other value +// will be interpreted as true. +func (env *Env) GetBool(key string) (value bool) { + s := strings.ToLower(strings.Trim(env.Get(key), " \t")) + if s == "" || s == "0" || s == "no" || s == "false" || s == "none" { + return false + } + return true +} + +// SetBool defines a boolean value to the given key. +func (env *Env) SetBool(key string, value bool) { + if value { + env.Set(key, "1") + } else { + env.Set(key, "0") + } +} + +// GetInt returns the value of the provided key, converted to int. +// +// It the value cannot be represented as an integer, it returns -1. +func (env *Env) GetInt(key string) int { + return int(env.GetInt64(key)) +} + +// SetInt defines an integer value to the given key. +func (env *Env) SetInt(key string, value int) { + env.Set(key, strconv.Itoa(value)) +} + +// GetInt64 returns the value of the provided key, converted to int64. +// +// It the value cannot be represented as an integer, it returns -1. +func (env *Env) GetInt64(key string) int64 { + s := strings.Trim(env.Get(key), " \t") + val, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return -1 + } + return val +} + +// SetInt64 defines an integer (64-bit wide) value to the given key. +func (env *Env) SetInt64(key string, value int64) { + env.Set(key, strconv.FormatInt(value, 10)) +} + +// GetJSON unmarshals the value of the provided key in the provided iface. +// +// iface is a value that can be provided to the json.Unmarshal function. +func (env *Env) GetJSON(key string, iface interface{}) error { + sval := env.Get(key) + if sval == "" { + return nil + } + return json.Unmarshal([]byte(sval), iface) +} + +// SetJSON marshals the given value to JSON format and stores it using the +// provided key. +func (env *Env) SetJSON(key string, value interface{}) error { + sval, err := json.Marshal(value) + if err != nil { + return err + } + env.Set(key, string(sval)) + return nil +} + +// GetList returns a list of strings matching the provided key. It handles the +// list as a JSON representation of a list of strings. +// +// If the given key matches to a single string, it will return a list +// containing only the value that matches the key. +func (env *Env) GetList(key string) []string { + sval := env.Get(key) + if sval == "" { + return nil + } + var l []string + if err := json.Unmarshal([]byte(sval), &l); err != nil { + l = append(l, sval) + } + return l +} + +// SetList stores the given list in the provided key, after serializing it to +// JSON format. +func (env *Env) SetList(key string, value []string) error { + return env.SetJSON(key, value) +} + +// Set defines the value of a key to the given string. +func (env *Env) Set(key, value string) { + *env = append(*env, key+"="+value) +} + +// Decode decodes `src` as a json dictionary, and adds each decoded key-value +// pair to the environment. +// +// If `src` cannot be decoded as a json dictionary, an error is returned. +func (env *Env) Decode(src io.Reader) error { + m := make(map[string]interface{}) + if err := json.NewDecoder(src).Decode(&m); err != nil { + return err + } + for k, v := range m { + env.SetAuto(k, v) + } + return nil +} + +// SetAuto will try to define the Set* method to call based on the given value. +func (env *Env) SetAuto(key string, value interface{}) { + if fval, ok := value.(float64); ok { + env.SetInt64(key, int64(fval)) + } else if sval, ok := value.(string); ok { + env.Set(key, sval) + } else if val, err := json.Marshal(value); err == nil { + env.Set(key, string(val)) + } else { + env.Set(key, fmt.Sprintf("%v", value)) + } +} + +// Map returns the map representation of the env. +func (env *Env) Map() map[string]string { + if len(*env) == 0 { + return nil + } + m := make(map[string]string) + for _, kv := range *env { + parts := strings.SplitN(kv, "=", 2) + m[parts[0]] = parts[1] + } + return m +} + +//****************************************************************// +//event need func +//****************************************************************// + +// AddEventListener adds a new listener to container events in the Docker API. +// +// The parameter is a channel through which events will be sent. +func (c *Client) AddEventListener(listener chan<- *APIEvents) error { + var err error + if !c.eventMonitor.isEnabled() { + err = c.eventMonitor.enableEventMonitoring(c) + if err != nil { + return err + } + } + err = c.eventMonitor.addListener(listener) + if err != nil { + return err + } + return nil +} + +// RemoveEventListener removes a listener from the monitor. +func (c *Client) RemoveEventListener(listener chan *APIEvents) error { + err := c.eventMonitor.removeListener(listener) + if err != nil { + return err + } + if len(c.eventMonitor.listeners) == 0 { + err = c.eventMonitor.disableEventMonitoring() + if err != nil { + return err + } + } + return nil +} + +func (eventState *eventMonitoringState) addListener(listener chan<- *APIEvents) error { + eventState.Lock() + defer eventState.Unlock() + if listenerExists(listener, &eventState.listeners) { + return ErrListenerAlreadyExists + } + eventState.Add(1) + eventState.listeners = append(eventState.listeners, listener) + return nil +} + +func (eventState *eventMonitoringState) removeListener(listener chan<- *APIEvents) error { + eventState.Lock() + defer eventState.Unlock() + if listenerExists(listener, &eventState.listeners) { + var newListeners []chan<- *APIEvents + for _, l := range eventState.listeners { + if l != listener { + newListeners = append(newListeners, l) + } + } + eventState.listeners = newListeners + eventState.Add(-1) + } + return nil +} + +func (eventState *eventMonitoringState) closeListeners() { + eventState.Lock() + defer eventState.Unlock() + for _, l := range eventState.listeners { + close(l) + eventState.Add(-1) + } + eventState.listeners = nil +} + +func listenerExists(a chan<- *APIEvents, list *[]chan<- *APIEvents) bool { + for _, b := range *list { + if b == a { + return true + } + } + return false +} + +func (eventState *eventMonitoringState) enableEventMonitoring(c *Client) error { + eventState.Lock() + defer eventState.Unlock() + if !eventState.enabled { + eventState.enabled = true + var lastSeenDefault = int64(0) + eventState.lastSeen = &lastSeenDefault + eventState.C = make(chan *APIEvents, 100) + eventState.errC = make(chan error, 1) + go eventState.monitorEvents(c) + } + return nil +} + +func (eventState *eventMonitoringState) disableEventMonitoring() error { + eventState.Wait() + eventState.Lock() + defer eventState.Unlock() + if eventState.enabled { + eventState.enabled = false + close(eventState.C) + close(eventState.errC) + } + return nil +} + +func (eventState *eventMonitoringState) monitorEvents(c *Client) { + var err error + for eventState.noListeners() { + time.Sleep(10 * time.Millisecond) + } + if err = eventState.connectWithRetry(c); err != nil { + eventState.terminate() + } + for eventState.isEnabled() { + timeout := time.After(100 * time.Millisecond) + select { + case ev, ok := <-eventState.C: + if !ok { + return + } + if ev == EOFEvent { + eventState.closeListeners() + eventState.terminate() + return + } + eventState.updateLastSeen(ev) + go eventState.sendEvent(ev) + case err = <-eventState.errC: + if err == ErrNoListeners { + eventState.terminate() + return + } else if err != nil { + defer func() { go eventState.monitorEvents(c) }() + return + } + case <-timeout: + continue + } + } +} + +func (eventState *eventMonitoringState) connectWithRetry(c *Client) error { + var retries int + var err error + for err = c.eventHijack(atomic.LoadInt64(eventState.lastSeen), eventState.C, eventState.errC); err != nil && retries < maxMonitorConnRetries; retries++ { + waitTime := int64(retryInitialWaitTime * math.Pow(2, float64(retries))) + time.Sleep(time.Duration(waitTime) * time.Millisecond) + err = c.eventHijack(atomic.LoadInt64(eventState.lastSeen), eventState.C, eventState.errC) + } + return err +} + +func (eventState *eventMonitoringState) noListeners() bool { + eventState.RLock() + defer eventState.RUnlock() + return len(eventState.listeners) == 0 +} + +func (eventState *eventMonitoringState) isEnabled() bool { + eventState.RLock() + defer eventState.RUnlock() + return eventState.enabled +} + +func (eventState *eventMonitoringState) sendEvent(event *APIEvents) { + eventState.RLock() + defer eventState.RUnlock() + eventState.Add(1) + defer eventState.Done() + if eventState.enabled { + if len(eventState.listeners) == 0 { + eventState.errC <- ErrNoListeners + return + } + + for _, listener := range eventState.listeners { + listener <- event + } + } +} + +func (eventState *eventMonitoringState) updateLastSeen(e *APIEvents) { + eventState.Lock() + defer eventState.Unlock() + if atomic.LoadInt64(eventState.lastSeen) < e.Time { + atomic.StoreInt64(eventState.lastSeen, e.Time) + } +} + +func (eventState *eventMonitoringState) terminate() { + eventState.disableEventMonitoring() +} + +func (c *Client) eventHijack(startTime int64, eventChan chan *APIEvents, errChan chan error) error { + uri := "/events" + if startTime != 0 { + uri += fmt.Sprintf("?since=%d", startTime) + } + protocol := c.endpointURL.Scheme + address := c.endpointURL.Path + if protocol != "unix" { + protocol = "tcp" + address = c.endpointURL.Host + } + var dial net.Conn + var err error + if c.TLSConfig == nil { + dial, err = net.Dial(protocol, address) + } else { + dial, err = tls.Dial(protocol, address, c.TLSConfig) + } + if err != nil { + return err + } + conn := httputil.NewClientConn(dial, nil) + req, err := http.NewRequest("GET", uri, nil) + if err != nil { + return err + } + res, err := conn.Do(req) + if err != nil { + return err + } + go func(res *http.Response, conn *httputil.ClientConn) { + defer conn.Close() + defer res.Body.Close() + decoder := json.NewDecoder(res.Body) + for { + var event APIEvents + if err = decoder.Decode(&event); err != nil { + if err == io.EOF || err == io.ErrUnexpectedEOF { + if c.eventMonitor.isEnabled() { + // Signal that we're exiting. + eventChan <- EOFEvent + } + break + } + errChan <- err + } + if event.Time == 0 { + continue + } + if !c.eventMonitor.isEnabled() { + return + } + eventChan <- &event + } + }(res, conn) + return nil +} + +//****************************************************************// +//exec need func +//****************************************************************// + +// CreateExec sets up an exec instance in a running container `id`, returning the exec +// instance, or an error in case of failure. +// +// See http://goo.gl/8izrzI for more details +func (c *Client) CreateExec(opts CreateExecOptions) (*Exec, error) { + path := fmt.Sprintf("/containers/%s/exec", opts.Container) + body, status, err := c.do("POST", path, doOptions{data: opts}) + if status == http.StatusNotFound { + return nil, &NoSuchContainer{ID: opts.Container} + } + if err != nil { + return nil, err + } + var exec Exec + err = json.Unmarshal(body, &exec) + if err != nil { + return nil, err + } + + return &exec, nil +} + +// StartExec starts a previously set up exec instance id. If opts.Detach is +// true, it returns after starting the exec command. Otherwise, it sets up an +// interactive session with the exec command. +// +// See http://goo.gl/JW8Lxl for more details +func (c *Client) StartExec(id string, opts StartExecOptions) error { + if id == "" { + return &NoSuchExec{ID: id} + } + + path := fmt.Sprintf("/exec/%s/start", id) + + if opts.Detach { + _, status, err := c.do("POST", path, doOptions{data: opts}) + if status == http.StatusNotFound { + return &NoSuchExec{ID: id} + } + if err != nil { + return err + } + return nil + } + + return c.hijack("POST", path, hijackOptions{ + success: opts.Success, + setRawTerminal: opts.RawTerminal, + in: opts.InputStream, + stdout: opts.OutputStream, + stderr: opts.ErrorStream, + data: opts, + }) +} + +// ResizeExecTTY resizes the tty session used by the exec command id. This API +// is valid only if Tty was specified as part of creating and starting the exec +// command. +// +// See http://goo.gl/YDSx1f for more details +func (c *Client) ResizeExecTTY(id string, height, width int) error { + params := make(url.Values) + params.Set("h", strconv.Itoa(height)) + params.Set("w", strconv.Itoa(width)) + + path := fmt.Sprintf("/exec/%s/resize?%s", id, params.Encode()) + _, _, err := c.do("POST", path, doOptions{}) + return err +} + +// InspectExec returns low-level information about the exec command id. +// +// See http://goo.gl/ypQULN for more details +func (c *Client) InspectExec(id string) (*ExecInspect, error) { + path := fmt.Sprintf("/exec/%s/json", id) + body, status, err := c.do("GET", path, doOptions{}) + if status == http.StatusNotFound { + return nil, &NoSuchExec{ID: id} + } + if err != nil { + return nil, err + } + var exec ExecInspect + err = json.Unmarshal(body, &exec) + if err != nil { + return nil, err + } + return &exec, nil +} + +func (err *NoSuchExec) Error() string { + return "No such exec instance: " + err.ID +} + +//****************************************************************// +//image need func +//****************************************************************// + +// ListImages returns the list of available images in the server. +// +// See http://goo.gl/HRVN1Z for more details. +func (c *Client) ListImages(opts ListImagesOptions) ([]APIImages, error) { + path := "/images/json?" + queryString(opts) + body, _, err := c.do("GET", path, doOptions{}) + if err != nil { + return nil, err + } + var images []APIImages + err = json.Unmarshal(body, &images) + if err != nil { + return nil, err + } + return images, nil +} + +// ImageHistory returns the history of the image by its name or ID. +// +// See http://goo.gl/2oJmNs for more details. +func (c *Client) ImageHistory(name string) ([]ImageHistory, error) { + body, status, err := c.do("GET", "/images/"+name+"/history", doOptions{}) + if status == http.StatusNotFound { + return nil, ErrNoSuchImage + } + if err != nil { + return nil, err + } + var history []ImageHistory + err = json.Unmarshal(body, &history) + if err != nil { + return nil, err + } + return history, nil +} + +// RemoveImage removes an image by its name or ID. +// +// See http://goo.gl/znj0wM for more details. +func (c *Client) RemoveImage(name string) error { + _, status, err := c.do("DELETE", "/images/"+name, doOptions{}) + if status == http.StatusNotFound { + return ErrNoSuchImage + } + return err +} + +// RemoveImageExtended removes an image by its name or ID. +// Extra params can be passed, see RemoveImageOptions +// +// See http://goo.gl/znj0wM for more details. +func (c *Client) RemoveImageExtended(name string, opts RemoveImageOptions) error { + uri := fmt.Sprintf("/images/%s?%s", name, queryString(&opts)) + _, status, err := c.do("DELETE", uri, doOptions{}) + if status == http.StatusNotFound { + return ErrNoSuchImage + } + return err +} + +// InspectImage returns an image by its name or ID. +// +// See http://goo.gl/Q112NY for more details. +func (c *Client) InspectImage(name string) (*Image, error) { + body, status, err := c.do("GET", "/images/"+name+"/json", doOptions{}) + if status == http.StatusNotFound { + return nil, ErrNoSuchImage + } + if err != nil { + return nil, err + } + + var image Image + + // if the caller elected to skip checking the server's version, assume it's the latest + if c.SkipServerVersionCheck || c.expectedAPIVersion.GreaterThanOrEqualTo(apiVersion112) { + err = json.Unmarshal(body, &image) + if err != nil { + return nil, err + } + } else { + var imagePre012 ImagePre012 + err = json.Unmarshal(body, &imagePre012) + if err != nil { + return nil, err + } + + image.ID = imagePre012.ID + image.Parent = imagePre012.Parent + image.Comment = imagePre012.Comment + image.Created = imagePre012.Created + image.Container = imagePre012.Container + image.ContainerConfig = imagePre012.ContainerConfig + image.DockerVersion = imagePre012.DockerVersion + image.Author = imagePre012.Author + image.Config = imagePre012.Config + image.Architecture = imagePre012.Architecture + image.Size = imagePre012.Size + } + + return &image, nil +} + +// PushImage pushes an image to a remote registry, logging progress to w. +// +// An empty instance of AuthConfiguration may be used for unauthenticated +// pushes. +// +// See http://goo.gl/pN8A3P for more details. +func (c *Client) PushImage(opts PushImageOptions, auth AuthConfiguration) error { + if opts.Name == "" { + return ErrNoSuchImage + } + headers, err := headersWithAuth(auth) + if err != nil { + return err + } + name := opts.Name + opts.Name = "" + path := "/images/" + name + "/push?" + queryString(&opts) + return c.stream("POST", path, streamOptions{ + setRawTerminal: true, + rawJSONStream: opts.RawJSONStream, + headers: headers, + stdout: opts.OutputStream, + }) +} + +// PullImage pulls an image from a remote registry, logging progress to opts.OutputStream. +// +// See http://goo.gl/ACyYNS for more details. +func (c *Client) PullImage(opts PullImageOptions, auth AuthConfiguration) error { + if opts.Repository == "" { + return ErrNoSuchImage + } + + headers, err := headersWithAuth(auth) + if err != nil { + return err + } + return c.createImage(queryString(&opts), headers, nil, opts.OutputStream, opts.RawJSONStream) +} + +func (c *Client) createImage(qs string, headers map[string]string, in io.Reader, w io.Writer, rawJSONStream bool) error { + path := "/images/create?" + qs + return c.stream("POST", path, streamOptions{ + setRawTerminal: true, + rawJSONStream: rawJSONStream, + headers: headers, + in: in, + stdout: w, + }) +} + +// LoadImage imports a tarball docker image +// +// See http://goo.gl/Y8NNCq for more details. +func (c *Client) LoadImage(opts LoadImageOptions) error { + return c.stream("POST", "/images/load", streamOptions{ + setRawTerminal: true, + in: opts.InputStream, + }) +} + +// ExportImage exports an image (as a tar file) into the stream +// +// See http://goo.gl/mi6kvk for more details. +func (c *Client) ExportImage(opts ExportImageOptions) error { + return c.stream("GET", fmt.Sprintf("/images/%s/get", opts.Name), streamOptions{ + setRawTerminal: true, + stdout: opts.OutputStream, + }) +} + +// ExportImages exports one or more images (as a tar file) into the stream +// +// See http://goo.gl/YeZzQK for more details. +func (c *Client) ExportImages(opts ExportImagesOptions) error { + if opts.Names == nil || len(opts.Names) == 0 { + return ErrMustSpecifyNames + } + return c.stream("GET", "/images/get?"+queryString(&opts), streamOptions{ + setRawTerminal: true, + stdout: opts.OutputStream, + }) +} + +// ImportImage imports an image from a url, a file or stdin +// +// See http://goo.gl/PhBKnS for more details. +func (c *Client) ImportImage(opts ImportImageOptions) error { + if opts.Repository == "" { + return ErrNoSuchImage + } + if opts.Source != "-" { + opts.InputStream = nil + } + if opts.Source != "-" && !isURL(opts.Source) { + f, err := os.Open(opts.Source) + if err != nil { + return err + } + b, err := ioutil.ReadAll(f) + opts.InputStream = bytes.NewBuffer(b) + opts.Source = "-" + } + return c.createImage(queryString(&opts), nil, opts.InputStream, opts.OutputStream, opts.RawJSONStream) +} + +// BuildImage builds an image from a tarball's url or a Dockerfile in the input +// stream. +// +// See http://goo.gl/7nuGXa for more details. +func (c *Client) BuildImage(opts BuildImageOptions) error { + if opts.OutputStream == nil { + return ErrMissingOutputStream + } + headers, err := headersWithAuth(opts.Auth, opts.AuthConfigs) + if err != nil { + return err + } + + if opts.Remote != "" && opts.Name == "" { + opts.Name = opts.Remote + } + if opts.InputStream != nil || opts.ContextDir != "" { + headers["Content-Type"] = "application/tar" + } else if opts.Remote == "" { + return ErrMissingRepo + } + if opts.ContextDir != "" { + if opts.InputStream != nil { + return ErrMultipleContexts + } + var err error + if opts.InputStream, err = createTarStream(opts.ContextDir, opts.Dockerfile); err != nil { + return err + } + } + + return c.stream("POST", fmt.Sprintf("/build?%s", queryString(&opts)), streamOptions{ + setRawTerminal: true, + rawJSONStream: opts.RawJSONStream, + headers: headers, + in: opts.InputStream, + stdout: opts.OutputStream, + }) +} + +// TagImage adds a tag to the image identified by the given name. +// +// See http://goo.gl/5g6qFy for more details. +func (c *Client) TagImage(name string, opts TagImageOptions) error { + if name == "" { + return ErrNoSuchImage + } + _, status, err := c.do("POST", fmt.Sprintf("/images/"+name+"/tag?%s", + queryString(&opts)), doOptions{}) + + if status == http.StatusNotFound { + return ErrNoSuchImage + } + + return err +} + +func isURL(u string) bool { + p, err := url.Parse(u) + if err != nil { + return false + } + return p.Scheme == "http" || p.Scheme == "https" +} + +func headersWithAuth(auths ...interface{}) (map[string]string, error) { + var headers = make(map[string]string) + + for _, auth := range auths { + switch auth.(type) { + case AuthConfiguration: + var buf bytes.Buffer + if err := json.NewEncoder(&buf).Encode(auth); err != nil { + return nil, err + } + headers["X-Registry-Auth"] = base64.URLEncoding.EncodeToString(buf.Bytes()) + case AuthConfigurations: + var buf bytes.Buffer + if err := json.NewEncoder(&buf).Encode(auth); err != nil { + return nil, err + } + headers["X-Registry-Config"] = base64.URLEncoding.EncodeToString(buf.Bytes()) + } + } + + return headers, nil +} + +// SearchImages search the docker hub with a specific given term. +// +// See http://goo.gl/xI5lLZ for more details. +func (c *Client) SearchImages(term string) ([]APIImageSearch, error) { + body, _, err := c.do("GET", "/images/search?term="+term, doOptions{}) + if err != nil { + return nil, err + } + var searchResult []APIImageSearch + err = json.Unmarshal(body, &searchResult) + if err != nil { + return nil, err + } + return searchResult, nil +} + +//****************************************************************// +//misc need func +//****************************************************************// + +// Version returns version information about the docker server. +// +// See http://goo.gl/BOZrF5 for more details. +func (c *Client) Version() (*Env, error) { + body, _, err := c.do("GET", "/version", doOptions{}) + if err != nil { + return nil, err + } + var env Env + if err := env.Decode(bytes.NewReader(body)); err != nil { + return nil, err + } + return &env, nil +} + +// Info returns system-wide information about the Docker server. +// +// See http://goo.gl/wmqZsW for more details. +func (c *Client) Info() (*Env, error) { + body, _, err := c.do("GET", "/info", doOptions{}) + if err != nil { + return nil, err + } + var info Env + err = info.Decode(bytes.NewReader(body)) + if err != nil { + return nil, err + } + return &info, nil +} + +// ParseRepositoryTag gets the name of the repository and returns it splitted +// in two parts: the repository and the tag. +// +// Some examples: +// +// localhost.localdomain:5000/samalba/hipache:latest -> localhost.localdomain:5000/samalba/hipache, latest +// localhost.localdomain:5000/samalba/hipache -> localhost.localdomain:5000/samalba/hipache, "" +func ParseRepositoryTag(repoTag string) (repository string, tag string) { + n := strings.LastIndex(repoTag, ":") + if n < 0 { + return repoTag, "" + } + if tag := repoTag[n+1:]; !strings.Contains(tag, "/") { + return repoTag[:n], tag + } + return repoTag, "" +} + +//****************************************************************// +//tar need func +//****************************************************************// + +func createTarStream(srcPath, dockerfilePath string) (io.ReadCloser, error) { + excludes, err := parseDockerignore(srcPath) + if err != nil { + return nil, err + } + + includes := []string{"."} + + // If .dockerignore mentions .dockerignore or the Dockerfile + // then make sure we send both files over to the daemon + // because Dockerfile is, obviously, needed no matter what, and + // .dockerignore is needed to know if either one needs to be + // removed. The deamon will remove them for us, if needed, after it + // parses the Dockerfile. + // + // https://github.com/docker/docker/issues/8330 + // + forceIncludeFiles := []string{".dockerignore", dockerfilePath} + + for _, includeFile := range forceIncludeFiles { + if includeFile == "" { + continue + } + keepThem, err := fileutils.Matches(includeFile, excludes) + if err != nil { + return nil, fmt.Errorf("cannot match .dockerfile: '%s', error: %s", includeFile, err) + } + if keepThem { + includes = append(includes, includeFile) + } + } + + if err := validateContextDirectory(srcPath, excludes); err != nil { + return nil, err + } + tarOpts := &archive.TarOptions{ + ExcludePatterns: excludes, + IncludeFiles: includes, + Compression: archive.Uncompressed, + NoLchown: true, + } + return archive.TarWithOptions(srcPath, tarOpts) +} + +// validateContextDirectory checks if all the contents of the directory +// can be read and returns an error if some files can't be read. +// Symlinks which point to non-existing files don't trigger an error +func validateContextDirectory(srcPath string, excludes []string) error { + return filepath.Walk(filepath.Join(srcPath, "."), func(filePath string, f os.FileInfo, err error) error { + // skip this directory/file if it's not in the path, it won't get added to the context + if relFilePath, err := filepath.Rel(srcPath, filePath); err != nil { + return err + } else if skip, err := fileutils.Matches(relFilePath, excludes); err != nil { + return err + } else if skip { + if f.IsDir() { + return filepath.SkipDir + } + return nil + } + + if err != nil { + if os.IsPermission(err) { + return fmt.Errorf("can't stat '%s'", filePath) + } + if os.IsNotExist(err) { + return nil + } + return err + } + + // skip checking if symlinks point to non-existing files, such symlinks can be useful + // also skip named pipes, because they hanging on open + if f.Mode()&(os.ModeSymlink|os.ModeNamedPipe) != 0 { + return nil + } + + if !f.IsDir() { + currentFile, err := os.Open(filePath) + if err != nil && os.IsPermission(err) { + return fmt.Errorf("no permission to read from '%s'", filePath) + } + currentFile.Close() + } + return nil + }) +} + +func parseDockerignore(root string) ([]string, error) { + var excludes []string + ignore, err := ioutil.ReadFile(path.Join(root, ".dockerignore")) + if err != nil && !os.IsNotExist(err) { + return excludes, fmt.Errorf("error reading .dockerignore: '%s'", err) + } + excludes = strings.Split(string(ignore), "\n") + + return excludes, nil +} + +//****************************************************************// +//tls need func +//****************************************************************// + +func (c *tlsClientCon) CloseWrite() error { + // Go standard tls.Conn doesn't provide the CloseWrite() method so we do it + // on its underlying connection. + if cwc, ok := c.rawConn.(interface { + CloseWrite() error + }); ok { + return cwc.CloseWrite() + } + return nil +} + +func tlsDialWithDialer(dialer *net.Dialer, network, addr string, config *tls.Config) (net.Conn, error) { + // We want the Timeout and Deadline values from dialer to cover the + // whole process: TCP connection and TLS handshake. This means that we + // also need to start our own timers now. + timeout := dialer.Timeout + + if !dialer.Deadline.IsZero() { + deadlineTimeout := dialer.Deadline.Sub(time.Now()) + if timeout == 0 || deadlineTimeout < timeout { + timeout = deadlineTimeout + } + } + + var errChannel chan error + + if timeout != 0 { + errChannel = make(chan error, 2) + time.AfterFunc(timeout, func() { + errChannel <- errors.New("") + }) + } + + rawConn, err := dialer.Dial(network, addr) + if err != nil { + return nil, err + } + + colonPos := strings.LastIndex(addr, ":") + if colonPos == -1 { + colonPos = len(addr) + } + hostname := addr[:colonPos] + + // If no ServerName is set, infer the ServerName + // from the hostname we're connecting to. + if config.ServerName == "" { + // Make a copy to avoid polluting argument or default. + c := *config + c.ServerName = hostname + config = &c + } + + conn := tls.Client(rawConn, config) + + if timeout == 0 { + err = conn.Handshake() + } else { + go func() { + errChannel <- conn.Handshake() + }() + + err = <-errChannel + } + + if err != nil { + rawConn.Close() + return nil, err + } + + // This is Docker difference with standard's crypto/tls package: returned a + // wrapper which holds both the TLS and raw connections. + return &tlsClientCon{conn, rawConn}, nil +} + +func tlsDial(network, addr string, config *tls.Config) (net.Conn, error) { + return tlsDialWithDialer(new(net.Dialer), network, addr, config) +} diff --git a/Godeps/_workspace/src/github.com/containerops/wrench/utils/utils.go b/Godeps/_workspace/src/github.com/containerops/wrench/utils/utils.go new file mode 100644 index 0000000..73214ef --- /dev/null +++ b/Godeps/_workspace/src/github.com/containerops/wrench/utils/utils.go @@ -0,0 +1,111 @@ +package utils + +import ( + "crypto/md5" + "encoding/base64" + "encoding/hex" + "errors" + "fmt" + "os" + "reflect" + "regexp" + "strings" + "time" +) + +func IsDirExist(path string) bool { + fi, err := os.Stat(path) + + if err != nil { + return os.IsExist(err) + } else { + return fi.IsDir() + } + + panic("not reached") +} + +func IsFileExist(filename string) bool { + _, err := os.Stat(filename) + return err == nil || os.IsExist(err) +} + +func Contain(obj interface{}, target interface{}) (bool, error) { + targetValue := reflect.ValueOf(target) + + switch reflect.TypeOf(target).Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < targetValue.Len(); i++ { + if targetValue.Index(i).Interface() == obj { + return true, nil + } + } + case reflect.Map: + if targetValue.MapIndex(reflect.ValueOf(obj)).IsValid() { + return true, nil + } + } + + return false, errors.New("not in array") +} + +func EncodeBasicAuth(username string, password string) string { + auth := username + ":" + password + msg := []byte(auth) + authorization := make([]byte, base64.StdEncoding.EncodedLen(len(msg))) + base64.StdEncoding.Encode(authorization, msg) + return string(authorization) +} + +func DecodeBasicAuth(authorization string) (username string, password string, err error) { + basic := strings.Split(strings.TrimSpace(authorization), " ") + if len(basic) <= 1 { + return "", "", err + } + + decLen := base64.StdEncoding.DecodedLen(len(basic[1])) + decoded := make([]byte, decLen) + authByte := []byte(basic[1]) + n, err := base64.StdEncoding.Decode(decoded, authByte) + + if err != nil { + return "", "", err + } + if n > decLen { + return "", "", fmt.Errorf("Something went wrong decoding auth config") + } + + arr := strings.SplitN(string(decoded), ":", 2) + if len(arr) != 2 { + return "", "", fmt.Errorf("Invalid auth configuration file") + } + + username = arr[0] + password = strings.Trim(arr[1], "\x00") + + return username, password, nil +} + +func ValidatePassword(password string) error { + if valida, _ := regexp.MatchString("[:alpha:]", password); valida != true { + return fmt.Errorf("No alpha character in the password.") + } + + if valida, _ := regexp.MatchString("[:digit:]", password); valida != true { + return fmt.Errorf("No digital character in the password.") + } + + if len(password) < 5 || len(password) > 30 { + return fmt.Errorf("Password characters length should be between 5 - 30.") + } + + return nil +} + +func MD5(key string) string { + md5String := fmt.Sprintf("%s%d", key, time.Now().Unix()) + h := md5.New() + h.Write([]byte(md5String)) + + return hex.EncodeToString(h.Sum(nil)) +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/README.md b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/README.md new file mode 100644 index 0000000..7307d96 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/README.md @@ -0,0 +1 @@ +This code provides helper functions for dealing with archive files. diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/archive.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/archive.go new file mode 100644 index 0000000..04e40a9 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/archive.go @@ -0,0 +1,884 @@ +package archive + +import ( + "archive/tar" + "bufio" + "bytes" + "compress/bzip2" + "compress/gzip" + "errors" + "fmt" + "io" + "io/ioutil" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "syscall" + + "github.com/Sirupsen/logrus" + "github.com/docker/docker/pkg/fileutils" + "github.com/docker/docker/pkg/pools" + "github.com/docker/docker/pkg/promise" + "github.com/docker/docker/pkg/system" +) + +type ( + Archive io.ReadCloser + ArchiveReader io.Reader + Compression int + TarChownOptions struct { + UID, GID int + } + TarOptions struct { + IncludeFiles []string + ExcludePatterns []string + Compression Compression + NoLchown bool + ChownOpts *TarChownOptions + Name string + IncludeSourceDir bool + // When unpacking, specifies whether overwriting a directory with a + // non-directory is allowed and vice versa. + NoOverwriteDirNonDir bool + } + + // Archiver allows the reuse of most utility functions of this package + // with a pluggable Untar function. + Archiver struct { + Untar func(io.Reader, string, *TarOptions) error + } + + // breakoutError is used to differentiate errors related to breaking out + // When testing archive breakout in the unit tests, this error is expected + // in order for the test to pass. + breakoutError error +) + +var ( + ErrNotImplemented = errors.New("Function not implemented") + defaultArchiver = &Archiver{Untar} +) + +const ( + Uncompressed Compression = iota + Bzip2 + Gzip + Xz +) + +func IsArchive(header []byte) bool { + compression := DetectCompression(header) + if compression != Uncompressed { + return true + } + r := tar.NewReader(bytes.NewBuffer(header)) + _, err := r.Next() + return err == nil +} + +func DetectCompression(source []byte) Compression { + for compression, m := range map[Compression][]byte{ + Bzip2: {0x42, 0x5A, 0x68}, + Gzip: {0x1F, 0x8B, 0x08}, + Xz: {0xFD, 0x37, 0x7A, 0x58, 0x5A, 0x00}, + } { + if len(source) < len(m) { + logrus.Debugf("Len too short") + continue + } + if bytes.Compare(m, source[:len(m)]) == 0 { + return compression + } + } + return Uncompressed +} + +func xzDecompress(archive io.Reader) (io.ReadCloser, error) { + args := []string{"xz", "-d", "-c", "-q"} + + return CmdStream(exec.Command(args[0], args[1:]...), archive) +} + +func DecompressStream(archive io.Reader) (io.ReadCloser, error) { + p := pools.BufioReader32KPool + buf := p.Get(archive) + bs, err := buf.Peek(10) + if err != nil { + return nil, err + } + + compression := DetectCompression(bs) + switch compression { + case Uncompressed: + readBufWrapper := p.NewReadCloserWrapper(buf, buf) + return readBufWrapper, nil + case Gzip: + gzReader, err := gzip.NewReader(buf) + if err != nil { + return nil, err + } + readBufWrapper := p.NewReadCloserWrapper(buf, gzReader) + return readBufWrapper, nil + case Bzip2: + bz2Reader := bzip2.NewReader(buf) + readBufWrapper := p.NewReadCloserWrapper(buf, bz2Reader) + return readBufWrapper, nil + case Xz: + xzReader, err := xzDecompress(buf) + if err != nil { + return nil, err + } + readBufWrapper := p.NewReadCloserWrapper(buf, xzReader) + return readBufWrapper, nil + default: + return nil, fmt.Errorf("Unsupported compression format %s", (&compression).Extension()) + } +} + +func CompressStream(dest io.WriteCloser, compression Compression) (io.WriteCloser, error) { + p := pools.BufioWriter32KPool + buf := p.Get(dest) + switch compression { + case Uncompressed: + writeBufWrapper := p.NewWriteCloserWrapper(buf, buf) + return writeBufWrapper, nil + case Gzip: + gzWriter := gzip.NewWriter(dest) + writeBufWrapper := p.NewWriteCloserWrapper(buf, gzWriter) + return writeBufWrapper, nil + case Bzip2, Xz: + // archive/bzip2 does not support writing, and there is no xz support at all + // However, this is not a problem as docker only currently generates gzipped tars + return nil, fmt.Errorf("Unsupported compression format %s", (&compression).Extension()) + default: + return nil, fmt.Errorf("Unsupported compression format %s", (&compression).Extension()) + } +} + +func (compression *Compression) Extension() string { + switch *compression { + case Uncompressed: + return "tar" + case Bzip2: + return "tar.bz2" + case Gzip: + return "tar.gz" + case Xz: + return "tar.xz" + } + return "" +} + +type tarAppender struct { + TarWriter *tar.Writer + Buffer *bufio.Writer + + // for hardlink mapping + SeenFiles map[uint64]string +} + +// canonicalTarName provides a platform-independent and consistent posix-style +//path for files and directories to be archived regardless of the platform. +func canonicalTarName(name string, isDir bool) (string, error) { + name, err := CanonicalTarNameForPath(name) + if err != nil { + return "", err + } + + // suffix with '/' for directories + if isDir && !strings.HasSuffix(name, "/") { + name += "/" + } + return name, nil +} + +func (ta *tarAppender) addTarFile(path, name string) error { + fi, err := os.Lstat(path) + if err != nil { + return err + } + + link := "" + if fi.Mode()&os.ModeSymlink != 0 { + if link, err = os.Readlink(path); err != nil { + return err + } + } + + hdr, err := tar.FileInfoHeader(fi, link) + if err != nil { + return err + } + hdr.Mode = int64(chmodTarEntry(os.FileMode(hdr.Mode))) + + name, err = canonicalTarName(name, fi.IsDir()) + if err != nil { + return fmt.Errorf("tar: cannot canonicalize path: %v", err) + } + hdr.Name = name + + nlink, inode, err := setHeaderForSpecialDevice(hdr, ta, name, fi.Sys()) + if err != nil { + return err + } + + // if it's a regular file and has more than 1 link, + // it's hardlinked, so set the type flag accordingly + if fi.Mode().IsRegular() && nlink > 1 { + // a link should have a name that it links too + // and that linked name should be first in the tar archive + if oldpath, ok := ta.SeenFiles[inode]; ok { + hdr.Typeflag = tar.TypeLink + hdr.Linkname = oldpath + hdr.Size = 0 // This Must be here for the writer math to add up! + } else { + ta.SeenFiles[inode] = name + } + } + + capability, _ := system.Lgetxattr(path, "security.capability") + if capability != nil { + hdr.Xattrs = make(map[string]string) + hdr.Xattrs["security.capability"] = string(capability) + } + + if err := ta.TarWriter.WriteHeader(hdr); err != nil { + return err + } + + if hdr.Typeflag == tar.TypeReg { + file, err := os.Open(path) + if err != nil { + return err + } + + ta.Buffer.Reset(ta.TarWriter) + defer ta.Buffer.Reset(nil) + _, err = io.Copy(ta.Buffer, file) + file.Close() + if err != nil { + return err + } + err = ta.Buffer.Flush() + if err != nil { + return err + } + } + + return nil +} + +func createTarFile(path, extractDir string, hdr *tar.Header, reader io.Reader, Lchown bool, chownOpts *TarChownOptions) error { + // hdr.Mode is in linux format, which we can use for sycalls, + // but for os.Foo() calls we need the mode converted to os.FileMode, + // so use hdrInfo.Mode() (they differ for e.g. setuid bits) + hdrInfo := hdr.FileInfo() + + switch hdr.Typeflag { + case tar.TypeDir: + // Create directory unless it exists as a directory already. + // In that case we just want to merge the two + if fi, err := os.Lstat(path); !(err == nil && fi.IsDir()) { + if err := os.Mkdir(path, hdrInfo.Mode()); err != nil { + return err + } + } + + case tar.TypeReg, tar.TypeRegA: + // Source is regular file + file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, hdrInfo.Mode()) + if err != nil { + return err + } + if _, err := io.Copy(file, reader); err != nil { + file.Close() + return err + } + file.Close() + + case tar.TypeBlock, tar.TypeChar, tar.TypeFifo: + // Handle this is an OS-specific way + if err := handleTarTypeBlockCharFifo(hdr, path); err != nil { + return err + } + + case tar.TypeLink: + targetPath := filepath.Join(extractDir, hdr.Linkname) + // check for hardlink breakout + if !strings.HasPrefix(targetPath, extractDir) { + return breakoutError(fmt.Errorf("invalid hardlink %q -> %q", targetPath, hdr.Linkname)) + } + if err := os.Link(targetPath, path); err != nil { + return err + } + + case tar.TypeSymlink: + // path -> hdr.Linkname = targetPath + // e.g. /extractDir/path/to/symlink -> ../2/file = /extractDir/path/2/file + targetPath := filepath.Join(filepath.Dir(path), hdr.Linkname) + + // the reason we don't need to check symlinks in the path (with FollowSymlinkInScope) is because + // that symlink would first have to be created, which would be caught earlier, at this very check: + if !strings.HasPrefix(targetPath, extractDir) { + return breakoutError(fmt.Errorf("invalid symlink %q -> %q", path, hdr.Linkname)) + } + if err := os.Symlink(hdr.Linkname, path); err != nil { + return err + } + + case tar.TypeXGlobalHeader: + logrus.Debugf("PAX Global Extended Headers found and ignored") + return nil + + default: + return fmt.Errorf("Unhandled tar header type %d\n", hdr.Typeflag) + } + + // Lchown is not supported on Windows. + if Lchown && runtime.GOOS != "windows" { + if chownOpts == nil { + chownOpts = &TarChownOptions{UID: hdr.Uid, GID: hdr.Gid} + } + if err := os.Lchown(path, chownOpts.UID, chownOpts.GID); err != nil { + return err + } + } + + for key, value := range hdr.Xattrs { + if err := system.Lsetxattr(path, key, []byte(value), 0); err != nil { + return err + } + } + + // There is no LChmod, so ignore mode for symlink. Also, this + // must happen after chown, as that can modify the file mode + if err := handleLChmod(hdr, path, hdrInfo); err != nil { + return err + } + + ts := []syscall.Timespec{timeToTimespec(hdr.AccessTime), timeToTimespec(hdr.ModTime)} + // syscall.UtimesNano doesn't support a NOFOLLOW flag atm + if hdr.Typeflag == tar.TypeLink { + if fi, err := os.Lstat(hdr.Linkname); err == nil && (fi.Mode()&os.ModeSymlink == 0) { + if err := system.UtimesNano(path, ts); err != nil && err != system.ErrNotSupportedPlatform { + return err + } + } + } else if hdr.Typeflag != tar.TypeSymlink { + if err := system.UtimesNano(path, ts); err != nil && err != system.ErrNotSupportedPlatform { + return err + } + } else { + if err := system.LUtimesNano(path, ts); err != nil && err != system.ErrNotSupportedPlatform { + return err + } + } + return nil +} + +// Tar creates an archive from the directory at `path`, and returns it as a +// stream of bytes. +func Tar(path string, compression Compression) (io.ReadCloser, error) { + return TarWithOptions(path, &TarOptions{Compression: compression}) +} + +// TarWithOptions creates an archive from the directory at `path`, only including files whose relative +// paths are included in `options.IncludeFiles` (if non-nil) or not in `options.ExcludePatterns`. +func TarWithOptions(srcPath string, options *TarOptions) (io.ReadCloser, error) { + + patterns, patDirs, exceptions, err := fileutils.CleanPatterns(options.ExcludePatterns) + + if err != nil { + return nil, err + } + + pipeReader, pipeWriter := io.Pipe() + + compressWriter, err := CompressStream(pipeWriter, options.Compression) + if err != nil { + return nil, err + } + + go func() { + ta := &tarAppender{ + TarWriter: tar.NewWriter(compressWriter), + Buffer: pools.BufioWriter32KPool.Get(nil), + SeenFiles: make(map[uint64]string), + } + + defer func() { + // Make sure to check the error on Close. + if err := ta.TarWriter.Close(); err != nil { + logrus.Debugf("Can't close tar writer: %s", err) + } + if err := compressWriter.Close(); err != nil { + logrus.Debugf("Can't close compress writer: %s", err) + } + if err := pipeWriter.Close(); err != nil { + logrus.Debugf("Can't close pipe writer: %s", err) + } + }() + + // this buffer is needed for the duration of this piped stream + defer pools.BufioWriter32KPool.Put(ta.Buffer) + + // In general we log errors here but ignore them because + // during e.g. a diff operation the container can continue + // mutating the filesystem and we can see transient errors + // from this + + stat, err := os.Lstat(srcPath) + if err != nil { + return + } + + if !stat.IsDir() { + // We can't later join a non-dir with any includes because the + // 'walk' will error if "file/." is stat-ed and "file" is not a + // directory. So, we must split the source path and use the + // basename as the include. + if len(options.IncludeFiles) > 0 { + logrus.Warn("Tar: Can't archive a file with includes") + } + + dir, base := SplitPathDirEntry(srcPath) + srcPath = dir + options.IncludeFiles = []string{base} + } + + if len(options.IncludeFiles) == 0 { + options.IncludeFiles = []string{"."} + } + + seen := make(map[string]bool) + + var renamedRelFilePath string // For when tar.Options.Name is set + for _, include := range options.IncludeFiles { + // We can't use filepath.Join(srcPath, include) because this will + // clean away a trailing "." or "/" which may be important. + walkRoot := strings.Join([]string{srcPath, include}, string(filepath.Separator)) + filepath.Walk(walkRoot, func(filePath string, f os.FileInfo, err error) error { + if err != nil { + logrus.Debugf("Tar: Can't stat file %s to tar: %s", srcPath, err) + return nil + } + + relFilePath, err := filepath.Rel(srcPath, filePath) + if err != nil || (!options.IncludeSourceDir && relFilePath == "." && f.IsDir()) { + // Error getting relative path OR we are looking + // at the source directory path. Skip in both situations. + return nil + } + + if options.IncludeSourceDir && include == "." && relFilePath != "." { + relFilePath = strings.Join([]string{".", relFilePath}, string(filepath.Separator)) + } + + skip := false + + // If "include" is an exact match for the current file + // then even if there's an "excludePatterns" pattern that + // matches it, don't skip it. IOW, assume an explicit 'include' + // is asking for that file no matter what - which is true + // for some files, like .dockerignore and Dockerfile (sometimes) + if include != relFilePath { + skip, err = fileutils.OptimizedMatches(relFilePath, patterns, patDirs) + if err != nil { + logrus.Debugf("Error matching %s: %v", relFilePath, err) + return err + } + } + + if skip { + if !exceptions && f.IsDir() { + return filepath.SkipDir + } + return nil + } + + if seen[relFilePath] { + return nil + } + seen[relFilePath] = true + + // TODO Windows: Verify if this needs to be os.Pathseparator + // Rename the base resource + if options.Name != "" && filePath == srcPath+"/"+filepath.Base(relFilePath) { + renamedRelFilePath = relFilePath + } + // Set this to make sure the items underneath also get renamed + if options.Name != "" { + relFilePath = strings.Replace(relFilePath, renamedRelFilePath, options.Name, 1) + } + + if err := ta.addTarFile(filePath, relFilePath); err != nil { + logrus.Debugf("Can't add file %s to tar: %s", filePath, err) + } + return nil + }) + } + }() + + return pipeReader, nil +} + +func Unpack(decompressedArchive io.Reader, dest string, options *TarOptions) error { + tr := tar.NewReader(decompressedArchive) + trBuf := pools.BufioReader32KPool.Get(nil) + defer pools.BufioReader32KPool.Put(trBuf) + + var dirs []*tar.Header + + // Iterate through the files in the archive. +loop: + for { + hdr, err := tr.Next() + if err == io.EOF { + // end of tar archive + break + } + if err != nil { + return err + } + + // Normalize name, for safety and for a simple is-root check + // This keeps "../" as-is, but normalizes "/../" to "/". Or Windows: + // This keeps "..\" as-is, but normalizes "\..\" to "\". + hdr.Name = filepath.Clean(hdr.Name) + + for _, exclude := range options.ExcludePatterns { + if strings.HasPrefix(hdr.Name, exclude) { + continue loop + } + } + + // After calling filepath.Clean(hdr.Name) above, hdr.Name will now be in + // the filepath format for the OS on which the daemon is running. Hence + // the check for a slash-suffix MUST be done in an OS-agnostic way. + if !strings.HasSuffix(hdr.Name, string(os.PathSeparator)) { + // Not the root directory, ensure that the parent directory exists + parent := filepath.Dir(hdr.Name) + parentPath := filepath.Join(dest, parent) + if _, err := os.Lstat(parentPath); err != nil && os.IsNotExist(err) { + err = system.MkdirAll(parentPath, 0777) + if err != nil { + return err + } + } + } + + path := filepath.Join(dest, hdr.Name) + rel, err := filepath.Rel(dest, path) + if err != nil { + return err + } + if strings.HasPrefix(rel, ".."+string(os.PathSeparator)) { + return breakoutError(fmt.Errorf("%q is outside of %q", hdr.Name, dest)) + } + + // If path exits we almost always just want to remove and replace it + // The only exception is when it is a directory *and* the file from + // the layer is also a directory. Then we want to merge them (i.e. + // just apply the metadata from the layer). + if fi, err := os.Lstat(path); err == nil { + if options.NoOverwriteDirNonDir && fi.IsDir() && hdr.Typeflag != tar.TypeDir { + // If NoOverwriteDirNonDir is true then we cannot replace + // an existing directory with a non-directory from the archive. + return fmt.Errorf("cannot overwrite directory %q with non-directory %q", path, dest) + } + + if options.NoOverwriteDirNonDir && !fi.IsDir() && hdr.Typeflag == tar.TypeDir { + // If NoOverwriteDirNonDir is true then we cannot replace + // an existing non-directory with a directory from the archive. + return fmt.Errorf("cannot overwrite non-directory %q with directory %q", path, dest) + } + + if fi.IsDir() && hdr.Name == "." { + continue + } + + if !(fi.IsDir() && hdr.Typeflag == tar.TypeDir) { + if err := os.RemoveAll(path); err != nil { + return err + } + } + } + trBuf.Reset(tr) + + if err := createTarFile(path, dest, hdr, trBuf, !options.NoLchown, options.ChownOpts); err != nil { + return err + } + + // Directory mtimes must be handled at the end to avoid further + // file creation in them to modify the directory mtime + if hdr.Typeflag == tar.TypeDir { + dirs = append(dirs, hdr) + } + } + + for _, hdr := range dirs { + path := filepath.Join(dest, hdr.Name) + ts := []syscall.Timespec{timeToTimespec(hdr.AccessTime), timeToTimespec(hdr.ModTime)} + if err := syscall.UtimesNano(path, ts); err != nil { + return err + } + } + return nil +} + +// Untar reads a stream of bytes from `archive`, parses it as a tar archive, +// and unpacks it into the directory at `dest`. +// The archive may be compressed with one of the following algorithms: +// identity (uncompressed), gzip, bzip2, xz. +// FIXME: specify behavior when target path exists vs. doesn't exist. +func Untar(archive io.Reader, dest string, options *TarOptions) error { + if archive == nil { + return fmt.Errorf("Empty archive") + } + dest = filepath.Clean(dest) + if options == nil { + options = &TarOptions{} + } + if options.ExcludePatterns == nil { + options.ExcludePatterns = []string{} + } + decompressedArchive, err := DecompressStream(archive) + if err != nil { + return err + } + defer decompressedArchive.Close() + return Unpack(decompressedArchive, dest, options) +} + +func (archiver *Archiver) TarUntar(src, dst string) error { + logrus.Debugf("TarUntar(%s %s)", src, dst) + archive, err := TarWithOptions(src, &TarOptions{Compression: Uncompressed}) + if err != nil { + return err + } + defer archive.Close() + return archiver.Untar(archive, dst, nil) +} + +// TarUntar is a convenience function which calls Tar and Untar, with the output of one piped into the other. +// If either Tar or Untar fails, TarUntar aborts and returns the error. +func TarUntar(src, dst string) error { + return defaultArchiver.TarUntar(src, dst) +} + +func (archiver *Archiver) UntarPath(src, dst string) error { + archive, err := os.Open(src) + if err != nil { + return err + } + defer archive.Close() + if err := archiver.Untar(archive, dst, nil); err != nil { + return err + } + return nil +} + +// UntarPath is a convenience function which looks for an archive +// at filesystem path `src`, and unpacks it at `dst`. +func UntarPath(src, dst string) error { + return defaultArchiver.UntarPath(src, dst) +} + +func (archiver *Archiver) CopyWithTar(src, dst string) error { + srcSt, err := os.Stat(src) + if err != nil { + return err + } + if !srcSt.IsDir() { + return archiver.CopyFileWithTar(src, dst) + } + // Create dst, copy src's content into it + logrus.Debugf("Creating dest directory: %s", dst) + if err := system.MkdirAll(dst, 0755); err != nil && !os.IsExist(err) { + return err + } + logrus.Debugf("Calling TarUntar(%s, %s)", src, dst) + return archiver.TarUntar(src, dst) +} + +// CopyWithTar creates a tar archive of filesystem path `src`, and +// unpacks it at filesystem path `dst`. +// The archive is streamed directly with fixed buffering and no +// intermediary disk IO. +func CopyWithTar(src, dst string) error { + return defaultArchiver.CopyWithTar(src, dst) +} + +func (archiver *Archiver) CopyFileWithTar(src, dst string) (err error) { + logrus.Debugf("CopyFileWithTar(%s, %s)", src, dst) + srcSt, err := os.Stat(src) + if err != nil { + return err + } + + if srcSt.IsDir() { + return fmt.Errorf("Can't copy a directory") + } + + // Clean up the trailing slash. This must be done in an operating + // system specific manner. + if dst[len(dst)-1] == os.PathSeparator { + dst = filepath.Join(dst, filepath.Base(src)) + } + // Create the holding directory if necessary + if err := system.MkdirAll(filepath.Dir(dst), 0700); err != nil && !os.IsExist(err) { + return err + } + + r, w := io.Pipe() + errC := promise.Go(func() error { + defer w.Close() + + srcF, err := os.Open(src) + if err != nil { + return err + } + defer srcF.Close() + + hdr, err := tar.FileInfoHeader(srcSt, "") + if err != nil { + return err + } + hdr.Name = filepath.Base(dst) + hdr.Mode = int64(chmodTarEntry(os.FileMode(hdr.Mode))) + + tw := tar.NewWriter(w) + defer tw.Close() + if err := tw.WriteHeader(hdr); err != nil { + return err + } + if _, err := io.Copy(tw, srcF); err != nil { + return err + } + return nil + }) + defer func() { + if er := <-errC; err != nil { + err = er + } + }() + return archiver.Untar(r, filepath.Dir(dst), nil) +} + +// CopyFileWithTar emulates the behavior of the 'cp' command-line +// for a single file. It copies a regular file from path `src` to +// path `dst`, and preserves all its metadata. +// +// Destination handling is in an operating specific manner depending +// where the daemon is running. If `dst` ends with a trailing slash +// the final destination path will be `dst/base(src)` (Linux) or +// `dst\base(src)` (Windows). +func CopyFileWithTar(src, dst string) (err error) { + return defaultArchiver.CopyFileWithTar(src, dst) +} + +// CmdStream executes a command, and returns its stdout as a stream. +// If the command fails to run or doesn't complete successfully, an error +// will be returned, including anything written on stderr. +func CmdStream(cmd *exec.Cmd, input io.Reader) (io.ReadCloser, error) { + if input != nil { + stdin, err := cmd.StdinPipe() + if err != nil { + return nil, err + } + // Write stdin if any + go func() { + io.Copy(stdin, input) + stdin.Close() + }() + } + stdout, err := cmd.StdoutPipe() + if err != nil { + return nil, err + } + stderr, err := cmd.StderrPipe() + if err != nil { + return nil, err + } + pipeR, pipeW := io.Pipe() + errChan := make(chan []byte) + // Collect stderr, we will use it in case of an error + go func() { + errText, e := ioutil.ReadAll(stderr) + if e != nil { + errText = []byte("(...couldn't fetch stderr: " + e.Error() + ")") + } + errChan <- errText + }() + // Copy stdout to the returned pipe + go func() { + _, err := io.Copy(pipeW, stdout) + if err != nil { + pipeW.CloseWithError(err) + } + errText := <-errChan + if err := cmd.Wait(); err != nil { + pipeW.CloseWithError(fmt.Errorf("%s: %s", err, errText)) + } else { + pipeW.Close() + } + }() + // Run the command and return the pipe + if err := cmd.Start(); err != nil { + return nil, err + } + return pipeR, nil +} + +// NewTempArchive reads the content of src into a temporary file, and returns the contents +// of that file as an archive. The archive can only be read once - as soon as reading completes, +// the file will be deleted. +func NewTempArchive(src Archive, dir string) (*TempArchive, error) { + f, err := ioutil.TempFile(dir, "") + if err != nil { + return nil, err + } + if _, err := io.Copy(f, src); err != nil { + return nil, err + } + if _, err := f.Seek(0, 0); err != nil { + return nil, err + } + st, err := f.Stat() + if err != nil { + return nil, err + } + size := st.Size() + return &TempArchive{File: f, Size: size}, nil +} + +type TempArchive struct { + *os.File + Size int64 // Pre-computed from Stat().Size() as a convenience + read int64 + closed bool +} + +// Close closes the underlying file if it's still open, or does a no-op +// to allow callers to try to close the TempArchive multiple times safely. +func (archive *TempArchive) Close() error { + if archive.closed { + return nil + } + + archive.closed = true + + return archive.File.Close() +} + +func (archive *TempArchive) Read(data []byte) (int, error) { + n, err := archive.File.Read(data) + archive.read += int64(n) + if err != nil || archive.read == archive.Size { + archive.Close() + os.Remove(archive.File.Name()) + } + return n, err +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/archive_test.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/archive_test.go new file mode 100644 index 0000000..b93c76c --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/archive_test.go @@ -0,0 +1,1204 @@ +package archive + +import ( + "archive/tar" + "bytes" + "fmt" + "io" + "io/ioutil" + "os" + "os/exec" + "path" + "path/filepath" + "strings" + "syscall" + "testing" + "time" + + "github.com/docker/docker/pkg/system" +) + +func TestIsArchiveNilHeader(t *testing.T) { + out := IsArchive(nil) + if out { + t.Fatalf("isArchive should return false as nil is not a valid archive header") + } +} + +func TestIsArchiveInvalidHeader(t *testing.T) { + header := []byte{0x00, 0x01, 0x02} + out := IsArchive(header) + if out { + t.Fatalf("isArchive should return false as %s is not a valid archive header", header) + } +} + +func TestIsArchiveBzip2(t *testing.T) { + header := []byte{0x42, 0x5A, 0x68} + out := IsArchive(header) + if !out { + t.Fatalf("isArchive should return true as %s is a bz2 header", header) + } +} + +func TestIsArchive7zip(t *testing.T) { + header := []byte{0x50, 0x4b, 0x03, 0x04} + out := IsArchive(header) + if out { + t.Fatalf("isArchive should return false as %s is a 7z header and it is not supported", header) + } +} + +func TestDecompressStreamGzip(t *testing.T) { + cmd := exec.Command("/bin/sh", "-c", "touch /tmp/archive && gzip -f /tmp/archive") + output, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("Fail to create an archive file for test : %s.", output) + } + archive, err := os.Open("/tmp/archive.gz") + _, err = DecompressStream(archive) + if err != nil { + t.Fatalf("Failed to decompress a gzip file.") + } +} + +func TestDecompressStreamBzip2(t *testing.T) { + cmd := exec.Command("/bin/sh", "-c", "touch /tmp/archive && bzip2 -f /tmp/archive") + output, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("Fail to create an archive file for test : %s.", output) + } + archive, err := os.Open("/tmp/archive.bz2") + _, err = DecompressStream(archive) + if err != nil { + t.Fatalf("Failed to decompress a bzip2 file.") + } +} + +func TestDecompressStreamXz(t *testing.T) { + cmd := exec.Command("/bin/sh", "-c", "touch /tmp/archive && xz -f /tmp/archive") + output, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("Fail to create an archive file for test : %s.", output) + } + archive, err := os.Open("/tmp/archive.xz") + _, err = DecompressStream(archive) + if err != nil { + t.Fatalf("Failed to decompress a xz file.") + } +} + +func TestCompressStreamXzUnsuported(t *testing.T) { + dest, err := os.Create("/tmp/dest") + if err != nil { + t.Fatalf("Fail to create the destination file") + } + _, err = CompressStream(dest, Xz) + if err == nil { + t.Fatalf("Should fail as xz is unsupported for compression format.") + } +} + +func TestCompressStreamBzip2Unsupported(t *testing.T) { + dest, err := os.Create("/tmp/dest") + if err != nil { + t.Fatalf("Fail to create the destination file") + } + _, err = CompressStream(dest, Xz) + if err == nil { + t.Fatalf("Should fail as xz is unsupported for compression format.") + } +} + +func TestCompressStreamInvalid(t *testing.T) { + dest, err := os.Create("/tmp/dest") + if err != nil { + t.Fatalf("Fail to create the destination file") + } + _, err = CompressStream(dest, -1) + if err == nil { + t.Fatalf("Should fail as xz is unsupported for compression format.") + } +} + +func TestExtensionInvalid(t *testing.T) { + compression := Compression(-1) + output := compression.Extension() + if output != "" { + t.Fatalf("The extension of an invalid compression should be an empty string.") + } +} + +func TestExtensionUncompressed(t *testing.T) { + compression := Uncompressed + output := compression.Extension() + if output != "tar" { + t.Fatalf("The extension of a uncompressed archive should be 'tar'.") + } +} +func TestExtensionBzip2(t *testing.T) { + compression := Bzip2 + output := compression.Extension() + if output != "tar.bz2" { + t.Fatalf("The extension of a bzip2 archive should be 'tar.bz2'") + } +} +func TestExtensionGzip(t *testing.T) { + compression := Gzip + output := compression.Extension() + if output != "tar.gz" { + t.Fatalf("The extension of a bzip2 archive should be 'tar.gz'") + } +} +func TestExtensionXz(t *testing.T) { + compression := Xz + output := compression.Extension() + if output != "tar.xz" { + t.Fatalf("The extension of a bzip2 archive should be 'tar.xz'") + } +} + +func TestCmdStreamLargeStderr(t *testing.T) { + cmd := exec.Command("/bin/sh", "-c", "dd if=/dev/zero bs=1k count=1000 of=/dev/stderr; echo hello") + out, err := CmdStream(cmd, nil) + if err != nil { + t.Fatalf("Failed to start command: %s", err) + } + errCh := make(chan error) + go func() { + _, err := io.Copy(ioutil.Discard, out) + errCh <- err + }() + select { + case err := <-errCh: + if err != nil { + t.Fatalf("Command should not have failed (err=%.100s...)", err) + } + case <-time.After(5 * time.Second): + t.Fatalf("Command did not complete in 5 seconds; probable deadlock") + } +} + +func TestCmdStreamBad(t *testing.T) { + badCmd := exec.Command("/bin/sh", "-c", "echo hello; echo >&2 error couldn\\'t reverse the phase pulser; exit 1") + out, err := CmdStream(badCmd, nil) + if err != nil { + t.Fatalf("Failed to start command: %s", err) + } + if output, err := ioutil.ReadAll(out); err == nil { + t.Fatalf("Command should have failed") + } else if err.Error() != "exit status 1: error couldn't reverse the phase pulser\n" { + t.Fatalf("Wrong error value (%s)", err) + } else if s := string(output); s != "hello\n" { + t.Fatalf("Command output should be '%s', not '%s'", "hello\\n", output) + } +} + +func TestCmdStreamGood(t *testing.T) { + cmd := exec.Command("/bin/sh", "-c", "echo hello; exit 0") + out, err := CmdStream(cmd, nil) + if err != nil { + t.Fatal(err) + } + if output, err := ioutil.ReadAll(out); err != nil { + t.Fatalf("Command should not have failed (err=%s)", err) + } else if s := string(output); s != "hello\n" { + t.Fatalf("Command output should be '%s', not '%s'", "hello\\n", output) + } +} + +func TestUntarPathWithInvalidDest(t *testing.T) { + tempFolder, err := ioutil.TempDir("", "docker-archive-test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempFolder) + invalidDestFolder := path.Join(tempFolder, "invalidDest") + // Create a src file + srcFile := path.Join(tempFolder, "src") + _, err = os.Create(srcFile) + if err != nil { + t.Fatalf("Fail to create the source file") + } + err = UntarPath(srcFile, invalidDestFolder) + if err == nil { + t.Fatalf("UntarPath with invalid destination path should throw an error.") + } +} + +func TestUntarPathWithInvalidSrc(t *testing.T) { + dest, err := ioutil.TempDir("", "docker-archive-test") + if err != nil { + t.Fatalf("Fail to create the destination file") + } + defer os.RemoveAll(dest) + err = UntarPath("/invalid/path", dest) + if err == nil { + t.Fatalf("UntarPath with invalid src path should throw an error.") + } +} + +func TestUntarPath(t *testing.T) { + tmpFolder, err := ioutil.TempDir("", "docker-archive-test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tmpFolder) + srcFile := path.Join(tmpFolder, "src") + tarFile := path.Join(tmpFolder, "src.tar") + os.Create(path.Join(tmpFolder, "src")) + cmd := exec.Command("/bin/sh", "-c", "tar cf "+tarFile+" "+srcFile) + _, err = cmd.CombinedOutput() + if err != nil { + t.Fatal(err) + } + destFolder := path.Join(tmpFolder, "dest") + err = os.MkdirAll(destFolder, 0740) + if err != nil { + t.Fatalf("Fail to create the destination file") + } + err = UntarPath(tarFile, destFolder) + if err != nil { + t.Fatalf("UntarPath shouldn't throw an error, %s.", err) + } + expectedFile := path.Join(destFolder, srcFile) + _, err = os.Stat(expectedFile) + if err != nil { + t.Fatalf("Destination folder should contain the source file but did not.") + } +} + +// Do the same test as above but with the destination as file, it should fail +func TestUntarPathWithDestinationFile(t *testing.T) { + tmpFolder, err := ioutil.TempDir("", "docker-archive-test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tmpFolder) + srcFile := path.Join(tmpFolder, "src") + tarFile := path.Join(tmpFolder, "src.tar") + os.Create(path.Join(tmpFolder, "src")) + cmd := exec.Command("/bin/sh", "-c", "tar cf "+tarFile+" "+srcFile) + _, err = cmd.CombinedOutput() + if err != nil { + t.Fatal(err) + } + destFile := path.Join(tmpFolder, "dest") + _, err = os.Create(destFile) + if err != nil { + t.Fatalf("Fail to create the destination file") + } + err = UntarPath(tarFile, destFile) + if err == nil { + t.Fatalf("UntarPath should throw an error if the destination if a file") + } +} + +// Do the same test as above but with the destination folder already exists +// and the destination file is a directory +// It's working, see https://github.com/docker/docker/issues/10040 +func TestUntarPathWithDestinationSrcFileAsFolder(t *testing.T) { + tmpFolder, err := ioutil.TempDir("", "docker-archive-test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tmpFolder) + srcFile := path.Join(tmpFolder, "src") + tarFile := path.Join(tmpFolder, "src.tar") + os.Create(srcFile) + cmd := exec.Command("/bin/sh", "-c", "tar cf "+tarFile+" "+srcFile) + _, err = cmd.CombinedOutput() + if err != nil { + t.Fatal(err) + } + destFolder := path.Join(tmpFolder, "dest") + err = os.MkdirAll(destFolder, 0740) + if err != nil { + t.Fatalf("Fail to create the destination folder") + } + // Let's create a folder that will has the same path as the extracted file (from tar) + destSrcFileAsFolder := path.Join(destFolder, srcFile) + err = os.MkdirAll(destSrcFileAsFolder, 0740) + if err != nil { + t.Fatal(err) + } + err = UntarPath(tarFile, destFolder) + if err != nil { + t.Fatalf("UntarPath should throw not throw an error if the extracted file already exists and is a folder") + } +} + +func TestCopyWithTarInvalidSrc(t *testing.T) { + tempFolder, err := ioutil.TempDir("", "docker-archive-test") + if err != nil { + t.Fatal(nil) + } + destFolder := path.Join(tempFolder, "dest") + invalidSrc := path.Join(tempFolder, "doesnotexists") + err = os.MkdirAll(destFolder, 0740) + if err != nil { + t.Fatal(err) + } + err = CopyWithTar(invalidSrc, destFolder) + if err == nil { + t.Fatalf("archiver.CopyWithTar with invalid src path should throw an error.") + } +} + +func TestCopyWithTarInexistentDestWillCreateIt(t *testing.T) { + tempFolder, err := ioutil.TempDir("", "docker-archive-test") + if err != nil { + t.Fatal(nil) + } + srcFolder := path.Join(tempFolder, "src") + inexistentDestFolder := path.Join(tempFolder, "doesnotexists") + err = os.MkdirAll(srcFolder, 0740) + if err != nil { + t.Fatal(err) + } + err = CopyWithTar(srcFolder, inexistentDestFolder) + if err != nil { + t.Fatalf("CopyWithTar with an inexistent folder shouldn't fail.") + } + _, err = os.Stat(inexistentDestFolder) + if err != nil { + t.Fatalf("CopyWithTar with an inexistent folder should create it.") + } +} + +// Test CopyWithTar with a file as src +func TestCopyWithTarSrcFile(t *testing.T) { + folder, err := ioutil.TempDir("", "docker-archive-test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(folder) + dest := path.Join(folder, "dest") + srcFolder := path.Join(folder, "src") + src := path.Join(folder, path.Join("src", "src")) + err = os.MkdirAll(srcFolder, 0740) + if err != nil { + t.Fatal(err) + } + err = os.MkdirAll(dest, 0740) + if err != nil { + t.Fatal(err) + } + ioutil.WriteFile(src, []byte("content"), 0777) + err = CopyWithTar(src, dest) + if err != nil { + t.Fatalf("archiver.CopyWithTar shouldn't throw an error, %s.", err) + } + _, err = os.Stat(dest) + // FIXME Check the content + if err != nil { + t.Fatalf("Destination file should be the same as the source.") + } +} + +// Test CopyWithTar with a folder as src +func TestCopyWithTarSrcFolder(t *testing.T) { + folder, err := ioutil.TempDir("", "docker-archive-test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(folder) + dest := path.Join(folder, "dest") + src := path.Join(folder, path.Join("src", "folder")) + err = os.MkdirAll(src, 0740) + if err != nil { + t.Fatal(err) + } + err = os.MkdirAll(dest, 0740) + if err != nil { + t.Fatal(err) + } + ioutil.WriteFile(path.Join(src, "file"), []byte("content"), 0777) + err = CopyWithTar(src, dest) + if err != nil { + t.Fatalf("archiver.CopyWithTar shouldn't throw an error, %s.", err) + } + _, err = os.Stat(dest) + // FIXME Check the content (the file inside) + if err != nil { + t.Fatalf("Destination folder should contain the source file but did not.") + } +} + +func TestCopyFileWithTarInvalidSrc(t *testing.T) { + tempFolder, err := ioutil.TempDir("", "docker-archive-test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempFolder) + destFolder := path.Join(tempFolder, "dest") + err = os.MkdirAll(destFolder, 0740) + if err != nil { + t.Fatal(err) + } + invalidFile := path.Join(tempFolder, "doesnotexists") + err = CopyFileWithTar(invalidFile, destFolder) + if err == nil { + t.Fatalf("archiver.CopyWithTar with invalid src path should throw an error.") + } +} + +func TestCopyFileWithTarInexistentDestWillCreateIt(t *testing.T) { + tempFolder, err := ioutil.TempDir("", "docker-archive-test") + if err != nil { + t.Fatal(nil) + } + defer os.RemoveAll(tempFolder) + srcFile := path.Join(tempFolder, "src") + inexistentDestFolder := path.Join(tempFolder, "doesnotexists") + _, err = os.Create(srcFile) + if err != nil { + t.Fatal(err) + } + err = CopyFileWithTar(srcFile, inexistentDestFolder) + if err != nil { + t.Fatalf("CopyWithTar with an inexistent folder shouldn't fail.") + } + _, err = os.Stat(inexistentDestFolder) + if err != nil { + t.Fatalf("CopyWithTar with an inexistent folder should create it.") + } + // FIXME Test the src file and content +} + +func TestCopyFileWithTarSrcFolder(t *testing.T) { + folder, err := ioutil.TempDir("", "docker-archive-copyfilewithtar-test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(folder) + dest := path.Join(folder, "dest") + src := path.Join(folder, "srcfolder") + err = os.MkdirAll(src, 0740) + if err != nil { + t.Fatal(err) + } + err = os.MkdirAll(dest, 0740) + if err != nil { + t.Fatal(err) + } + err = CopyFileWithTar(src, dest) + if err == nil { + t.Fatalf("CopyFileWithTar should throw an error with a folder.") + } +} + +func TestCopyFileWithTarSrcFile(t *testing.T) { + folder, err := ioutil.TempDir("", "docker-archive-test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(folder) + dest := path.Join(folder, "dest") + srcFolder := path.Join(folder, "src") + src := path.Join(folder, path.Join("src", "src")) + err = os.MkdirAll(srcFolder, 0740) + if err != nil { + t.Fatal(err) + } + err = os.MkdirAll(dest, 0740) + if err != nil { + t.Fatal(err) + } + ioutil.WriteFile(src, []byte("content"), 0777) + err = CopyWithTar(src, dest+"/") + if err != nil { + t.Fatalf("archiver.CopyFileWithTar shouldn't throw an error, %s.", err) + } + _, err = os.Stat(dest) + if err != nil { + t.Fatalf("Destination folder should contain the source file but did not.") + } +} + +func TestTarFiles(t *testing.T) { + // try without hardlinks + if err := checkNoChanges(1000, false); err != nil { + t.Fatal(err) + } + // try with hardlinks + if err := checkNoChanges(1000, true); err != nil { + t.Fatal(err) + } +} + +func checkNoChanges(fileNum int, hardlinks bool) error { + srcDir, err := ioutil.TempDir("", "docker-test-srcDir") + if err != nil { + return err + } + defer os.RemoveAll(srcDir) + + destDir, err := ioutil.TempDir("", "docker-test-destDir") + if err != nil { + return err + } + defer os.RemoveAll(destDir) + + _, err = prepareUntarSourceDirectory(fileNum, srcDir, hardlinks) + if err != nil { + return err + } + + err = TarUntar(srcDir, destDir) + if err != nil { + return err + } + + changes, err := ChangesDirs(destDir, srcDir) + if err != nil { + return err + } + if len(changes) > 0 { + return fmt.Errorf("with %d files and %v hardlinks: expected 0 changes, got %d", fileNum, hardlinks, len(changes)) + } + return nil +} + +func tarUntar(t *testing.T, origin string, options *TarOptions) ([]Change, error) { + archive, err := TarWithOptions(origin, options) + if err != nil { + t.Fatal(err) + } + defer archive.Close() + + buf := make([]byte, 10) + if _, err := archive.Read(buf); err != nil { + return nil, err + } + wrap := io.MultiReader(bytes.NewReader(buf), archive) + + detectedCompression := DetectCompression(buf) + compression := options.Compression + if detectedCompression.Extension() != compression.Extension() { + return nil, fmt.Errorf("Wrong compression detected. Actual compression: %s, found %s", compression.Extension(), detectedCompression.Extension()) + } + + tmp, err := ioutil.TempDir("", "docker-test-untar") + if err != nil { + return nil, err + } + defer os.RemoveAll(tmp) + if err := Untar(wrap, tmp, nil); err != nil { + return nil, err + } + if _, err := os.Stat(tmp); err != nil { + return nil, err + } + + return ChangesDirs(origin, tmp) +} + +func TestTarUntar(t *testing.T) { + origin, err := ioutil.TempDir("", "docker-test-untar-origin") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(origin) + if err := ioutil.WriteFile(path.Join(origin, "1"), []byte("hello world"), 0700); err != nil { + t.Fatal(err) + } + if err := ioutil.WriteFile(path.Join(origin, "2"), []byte("welcome!"), 0700); err != nil { + t.Fatal(err) + } + if err := ioutil.WriteFile(path.Join(origin, "3"), []byte("will be ignored"), 0700); err != nil { + t.Fatal(err) + } + + for _, c := range []Compression{ + Uncompressed, + Gzip, + } { + changes, err := tarUntar(t, origin, &TarOptions{ + Compression: c, + ExcludePatterns: []string{"3"}, + }) + + if err != nil { + t.Fatalf("Error tar/untar for compression %s: %s", c.Extension(), err) + } + + if len(changes) != 1 || changes[0].Path != "/3" { + t.Fatalf("Unexpected differences after tarUntar: %v", changes) + } + } +} + +func TestTarUntarWithXattr(t *testing.T) { + origin, err := ioutil.TempDir("", "docker-test-untar-origin") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(origin) + if err := ioutil.WriteFile(path.Join(origin, "1"), []byte("hello world"), 0700); err != nil { + t.Fatal(err) + } + if err := ioutil.WriteFile(path.Join(origin, "2"), []byte("welcome!"), 0700); err != nil { + t.Fatal(err) + } + if err := ioutil.WriteFile(path.Join(origin, "3"), []byte("will be ignored"), 0700); err != nil { + t.Fatal(err) + } + if err := system.Lsetxattr(path.Join(origin, "2"), "security.capability", []byte{0x00}, 0); err != nil { + t.Fatal(err) + } + + for _, c := range []Compression{ + Uncompressed, + Gzip, + } { + changes, err := tarUntar(t, origin, &TarOptions{ + Compression: c, + ExcludePatterns: []string{"3"}, + }) + + if err != nil { + t.Fatalf("Error tar/untar for compression %s: %s", c.Extension(), err) + } + + if len(changes) != 1 || changes[0].Path != "/3" { + t.Fatalf("Unexpected differences after tarUntar: %v", changes) + } + capability, _ := system.Lgetxattr(path.Join(origin, "2"), "security.capability") + if capability == nil && capability[0] != 0x00 { + t.Fatalf("Untar should have kept the 'security.capability' xattr.") + } + } +} + +func TestTarWithOptions(t *testing.T) { + origin, err := ioutil.TempDir("", "docker-test-untar-origin") + if err != nil { + t.Fatal(err) + } + if _, err := ioutil.TempDir(origin, "folder"); err != nil { + t.Fatal(err) + } + defer os.RemoveAll(origin) + if err := ioutil.WriteFile(path.Join(origin, "1"), []byte("hello world"), 0700); err != nil { + t.Fatal(err) + } + if err := ioutil.WriteFile(path.Join(origin, "2"), []byte("welcome!"), 0700); err != nil { + t.Fatal(err) + } + + cases := []struct { + opts *TarOptions + numChanges int + }{ + {&TarOptions{IncludeFiles: []string{"1"}}, 2}, + {&TarOptions{ExcludePatterns: []string{"2"}}, 1}, + {&TarOptions{ExcludePatterns: []string{"1", "folder*"}}, 2}, + {&TarOptions{IncludeFiles: []string{"1", "1"}}, 2}, + {&TarOptions{Name: "test", IncludeFiles: []string{"1"}}, 4}, + } + for _, testCase := range cases { + changes, err := tarUntar(t, origin, testCase.opts) + if err != nil { + t.Fatalf("Error tar/untar when testing inclusion/exclusion: %s", err) + } + if len(changes) != testCase.numChanges { + t.Errorf("Expected %d changes, got %d for %+v:", + testCase.numChanges, len(changes), testCase.opts) + } + } +} + +// Some tar archives such as http://haproxy.1wt.eu/download/1.5/src/devel/haproxy-1.5-dev21.tar.gz +// use PAX Global Extended Headers. +// Failing prevents the archives from being uncompressed during ADD +func TestTypeXGlobalHeaderDoesNotFail(t *testing.T) { + hdr := tar.Header{Typeflag: tar.TypeXGlobalHeader} + tmpDir, err := ioutil.TempDir("", "docker-test-archive-pax-test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tmpDir) + err = createTarFile(filepath.Join(tmpDir, "pax_global_header"), tmpDir, &hdr, nil, true, nil) + if err != nil { + t.Fatal(err) + } +} + +// Some tar have both GNU specific (huge uid) and Ustar specific (long name) things. +// Not supposed to happen (should use PAX instead of Ustar for long name) but it does and it should still work. +func TestUntarUstarGnuConflict(t *testing.T) { + f, err := os.Open("testdata/broken.tar") + if err != nil { + t.Fatal(err) + } + found := false + tr := tar.NewReader(f) + // Iterate through the files in the archive. + for { + hdr, err := tr.Next() + if err == io.EOF { + // end of tar archive + break + } + if err != nil { + t.Fatal(err) + } + if hdr.Name == "root/.cpanm/work/1395823785.24209/Plack-1.0030/blib/man3/Plack::Middleware::LighttpdScriptNameFix.3pm" { + found = true + break + } + } + if !found { + t.Fatalf("%s not found in the archive", "root/.cpanm/work/1395823785.24209/Plack-1.0030/blib/man3/Plack::Middleware::LighttpdScriptNameFix.3pm") + } +} + +func TestTarWithBlockCharFifo(t *testing.T) { + origin, err := ioutil.TempDir("", "docker-test-tar-hardlink") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(origin) + if err := ioutil.WriteFile(path.Join(origin, "1"), []byte("hello world"), 0700); err != nil { + t.Fatal(err) + } + if err := system.Mknod(path.Join(origin, "2"), syscall.S_IFBLK, int(system.Mkdev(int64(12), int64(5)))); err != nil { + t.Fatal(err) + } + if err := system.Mknod(path.Join(origin, "3"), syscall.S_IFCHR, int(system.Mkdev(int64(12), int64(5)))); err != nil { + t.Fatal(err) + } + if err := system.Mknod(path.Join(origin, "4"), syscall.S_IFIFO, int(system.Mkdev(int64(12), int64(5)))); err != nil { + t.Fatal(err) + } + + dest, err := ioutil.TempDir("", "docker-test-tar-hardlink-dest") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dest) + + // we'll do this in two steps to separate failure + fh, err := Tar(origin, Uncompressed) + if err != nil { + t.Fatal(err) + } + + // ensure we can read the whole thing with no error, before writing back out + buf, err := ioutil.ReadAll(fh) + if err != nil { + t.Fatal(err) + } + + bRdr := bytes.NewReader(buf) + err = Untar(bRdr, dest, &TarOptions{Compression: Uncompressed}) + if err != nil { + t.Fatal(err) + } + + changes, err := ChangesDirs(origin, dest) + if err != nil { + t.Fatal(err) + } + if len(changes) > 0 { + t.Fatalf("Tar with special device (block, char, fifo) should keep them (recreate them when untar) : %v", changes) + } +} + +func TestTarWithHardLink(t *testing.T) { + origin, err := ioutil.TempDir("", "docker-test-tar-hardlink") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(origin) + if err := ioutil.WriteFile(path.Join(origin, "1"), []byte("hello world"), 0700); err != nil { + t.Fatal(err) + } + if err := os.Link(path.Join(origin, "1"), path.Join(origin, "2")); err != nil { + t.Fatal(err) + } + + var i1, i2 uint64 + if i1, err = getNlink(path.Join(origin, "1")); err != nil { + t.Fatal(err) + } + // sanity check that we can hardlink + if i1 != 2 { + t.Skipf("skipping since hardlinks don't work here; expected 2 links, got %d", i1) + } + + dest, err := ioutil.TempDir("", "docker-test-tar-hardlink-dest") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dest) + + // we'll do this in two steps to separate failure + fh, err := Tar(origin, Uncompressed) + if err != nil { + t.Fatal(err) + } + + // ensure we can read the whole thing with no error, before writing back out + buf, err := ioutil.ReadAll(fh) + if err != nil { + t.Fatal(err) + } + + bRdr := bytes.NewReader(buf) + err = Untar(bRdr, dest, &TarOptions{Compression: Uncompressed}) + if err != nil { + t.Fatal(err) + } + + if i1, err = getInode(path.Join(dest, "1")); err != nil { + t.Fatal(err) + } + if i2, err = getInode(path.Join(dest, "2")); err != nil { + t.Fatal(err) + } + + if i1 != i2 { + t.Errorf("expected matching inodes, but got %d and %d", i1, i2) + } +} + +func getNlink(path string) (uint64, error) { + stat, err := os.Stat(path) + if err != nil { + return 0, err + } + statT, ok := stat.Sys().(*syscall.Stat_t) + if !ok { + return 0, fmt.Errorf("expected type *syscall.Stat_t, got %t", stat.Sys()) + } + // We need this conversion on ARM64 + return uint64(statT.Nlink), nil +} + +func getInode(path string) (uint64, error) { + stat, err := os.Stat(path) + if err != nil { + return 0, err + } + statT, ok := stat.Sys().(*syscall.Stat_t) + if !ok { + return 0, fmt.Errorf("expected type *syscall.Stat_t, got %t", stat.Sys()) + } + return statT.Ino, nil +} + +func prepareUntarSourceDirectory(numberOfFiles int, targetPath string, makeLinks bool) (int, error) { + fileData := []byte("fooo") + for n := 0; n < numberOfFiles; n++ { + fileName := fmt.Sprintf("file-%d", n) + if err := ioutil.WriteFile(path.Join(targetPath, fileName), fileData, 0700); err != nil { + return 0, err + } + if makeLinks { + if err := os.Link(path.Join(targetPath, fileName), path.Join(targetPath, fileName+"-link")); err != nil { + return 0, err + } + } + } + totalSize := numberOfFiles * len(fileData) + return totalSize, nil +} + +func BenchmarkTarUntar(b *testing.B) { + origin, err := ioutil.TempDir("", "docker-test-untar-origin") + if err != nil { + b.Fatal(err) + } + tempDir, err := ioutil.TempDir("", "docker-test-untar-destination") + if err != nil { + b.Fatal(err) + } + target := path.Join(tempDir, "dest") + n, err := prepareUntarSourceDirectory(100, origin, false) + if err != nil { + b.Fatal(err) + } + defer os.RemoveAll(origin) + defer os.RemoveAll(tempDir) + + b.ResetTimer() + b.SetBytes(int64(n)) + for n := 0; n < b.N; n++ { + err := TarUntar(origin, target) + if err != nil { + b.Fatal(err) + } + os.RemoveAll(target) + } +} + +func BenchmarkTarUntarWithLinks(b *testing.B) { + origin, err := ioutil.TempDir("", "docker-test-untar-origin") + if err != nil { + b.Fatal(err) + } + tempDir, err := ioutil.TempDir("", "docker-test-untar-destination") + if err != nil { + b.Fatal(err) + } + target := path.Join(tempDir, "dest") + n, err := prepareUntarSourceDirectory(100, origin, true) + if err != nil { + b.Fatal(err) + } + defer os.RemoveAll(origin) + defer os.RemoveAll(tempDir) + + b.ResetTimer() + b.SetBytes(int64(n)) + for n := 0; n < b.N; n++ { + err := TarUntar(origin, target) + if err != nil { + b.Fatal(err) + } + os.RemoveAll(target) + } +} + +func TestUntarInvalidFilenames(t *testing.T) { + for i, headers := range [][]*tar.Header{ + { + { + Name: "../victim/dotdot", + Typeflag: tar.TypeReg, + Mode: 0644, + }, + }, + { + { + // Note the leading slash + Name: "/../victim/slash-dotdot", + Typeflag: tar.TypeReg, + Mode: 0644, + }, + }, + } { + if err := testBreakout("untar", "docker-TestUntarInvalidFilenames", headers); err != nil { + t.Fatalf("i=%d. %v", i, err) + } + } +} + +func TestUntarHardlinkToSymlink(t *testing.T) { + for i, headers := range [][]*tar.Header{ + { + { + Name: "symlink1", + Typeflag: tar.TypeSymlink, + Linkname: "regfile", + Mode: 0644, + }, + { + Name: "symlink2", + Typeflag: tar.TypeLink, + Linkname: "symlink1", + Mode: 0644, + }, + { + Name: "regfile", + Typeflag: tar.TypeReg, + Mode: 0644, + }, + }, + } { + if err := testBreakout("untar", "docker-TestUntarHardlinkToSymlink", headers); err != nil { + t.Fatalf("i=%d. %v", i, err) + } + } +} + +func TestUntarInvalidHardlink(t *testing.T) { + for i, headers := range [][]*tar.Header{ + { // try reading victim/hello (../) + { + Name: "dotdot", + Typeflag: tar.TypeLink, + Linkname: "../victim/hello", + Mode: 0644, + }, + }, + { // try reading victim/hello (/../) + { + Name: "slash-dotdot", + Typeflag: tar.TypeLink, + // Note the leading slash + Linkname: "/../victim/hello", + Mode: 0644, + }, + }, + { // try writing victim/file + { + Name: "loophole-victim", + Typeflag: tar.TypeLink, + Linkname: "../victim", + Mode: 0755, + }, + { + Name: "loophole-victim/file", + Typeflag: tar.TypeReg, + Mode: 0644, + }, + }, + { // try reading victim/hello (hardlink, symlink) + { + Name: "loophole-victim", + Typeflag: tar.TypeLink, + Linkname: "../victim", + Mode: 0755, + }, + { + Name: "symlink", + Typeflag: tar.TypeSymlink, + Linkname: "loophole-victim/hello", + Mode: 0644, + }, + }, + { // Try reading victim/hello (hardlink, hardlink) + { + Name: "loophole-victim", + Typeflag: tar.TypeLink, + Linkname: "../victim", + Mode: 0755, + }, + { + Name: "hardlink", + Typeflag: tar.TypeLink, + Linkname: "loophole-victim/hello", + Mode: 0644, + }, + }, + { // Try removing victim directory (hardlink) + { + Name: "loophole-victim", + Typeflag: tar.TypeLink, + Linkname: "../victim", + Mode: 0755, + }, + { + Name: "loophole-victim", + Typeflag: tar.TypeReg, + Mode: 0644, + }, + }, + } { + if err := testBreakout("untar", "docker-TestUntarInvalidHardlink", headers); err != nil { + t.Fatalf("i=%d. %v", i, err) + } + } +} + +func TestUntarInvalidSymlink(t *testing.T) { + for i, headers := range [][]*tar.Header{ + { // try reading victim/hello (../) + { + Name: "dotdot", + Typeflag: tar.TypeSymlink, + Linkname: "../victim/hello", + Mode: 0644, + }, + }, + { // try reading victim/hello (/../) + { + Name: "slash-dotdot", + Typeflag: tar.TypeSymlink, + // Note the leading slash + Linkname: "/../victim/hello", + Mode: 0644, + }, + }, + { // try writing victim/file + { + Name: "loophole-victim", + Typeflag: tar.TypeSymlink, + Linkname: "../victim", + Mode: 0755, + }, + { + Name: "loophole-victim/file", + Typeflag: tar.TypeReg, + Mode: 0644, + }, + }, + { // try reading victim/hello (symlink, symlink) + { + Name: "loophole-victim", + Typeflag: tar.TypeSymlink, + Linkname: "../victim", + Mode: 0755, + }, + { + Name: "symlink", + Typeflag: tar.TypeSymlink, + Linkname: "loophole-victim/hello", + Mode: 0644, + }, + }, + { // try reading victim/hello (symlink, hardlink) + { + Name: "loophole-victim", + Typeflag: tar.TypeSymlink, + Linkname: "../victim", + Mode: 0755, + }, + { + Name: "hardlink", + Typeflag: tar.TypeLink, + Linkname: "loophole-victim/hello", + Mode: 0644, + }, + }, + { // try removing victim directory (symlink) + { + Name: "loophole-victim", + Typeflag: tar.TypeSymlink, + Linkname: "../victim", + Mode: 0755, + }, + { + Name: "loophole-victim", + Typeflag: tar.TypeReg, + Mode: 0644, + }, + }, + { // try writing to victim/newdir/newfile with a symlink in the path + { + // this header needs to be before the next one, or else there is an error + Name: "dir/loophole", + Typeflag: tar.TypeSymlink, + Linkname: "../../victim", + Mode: 0755, + }, + { + Name: "dir/loophole/newdir/newfile", + Typeflag: tar.TypeReg, + Mode: 0644, + }, + }, + } { + if err := testBreakout("untar", "docker-TestUntarInvalidSymlink", headers); err != nil { + t.Fatalf("i=%d. %v", i, err) + } + } +} + +func TestTempArchiveCloseMultipleTimes(t *testing.T) { + reader := ioutil.NopCloser(strings.NewReader("hello")) + tempArchive, err := NewTempArchive(reader, "") + buf := make([]byte, 10) + n, err := tempArchive.Read(buf) + if n != 5 { + t.Fatalf("Expected to read 5 bytes. Read %d instead", n) + } + for i := 0; i < 3; i++ { + if err = tempArchive.Close(); err != nil { + t.Fatalf("i=%d. Unexpected error closing temp archive: %v", i, err) + } + } +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/archive_unix.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/archive_unix.go new file mode 100644 index 0000000..9e1dfad --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/archive_unix.go @@ -0,0 +1,89 @@ +// +build !windows + +package archive + +import ( + "archive/tar" + "errors" + "os" + "syscall" + + "github.com/docker/docker/pkg/system" +) + +// CanonicalTarNameForPath returns platform-specific filepath +// to canonical posix-style path for tar archival. p is relative +// path. +func CanonicalTarNameForPath(p string) (string, error) { + return p, nil // already unix-style +} + +// chmodTarEntry is used to adjust the file permissions used in tar header based +// on the platform the archival is done. + +func chmodTarEntry(perm os.FileMode) os.FileMode { + return perm // noop for unix as golang APIs provide perm bits correctly +} + +func setHeaderForSpecialDevice(hdr *tar.Header, ta *tarAppender, name string, stat interface{}) (nlink uint32, inode uint64, err error) { + s, ok := stat.(*syscall.Stat_t) + + if !ok { + err = errors.New("cannot convert stat value to syscall.Stat_t") + return + } + + nlink = uint32(s.Nlink) + inode = uint64(s.Ino) + + // Currently go does not fil in the major/minors + if s.Mode&syscall.S_IFBLK != 0 || + s.Mode&syscall.S_IFCHR != 0 { + hdr.Devmajor = int64(major(uint64(s.Rdev))) + hdr.Devminor = int64(minor(uint64(s.Rdev))) + } + + return +} + +func major(device uint64) uint64 { + return (device >> 8) & 0xfff +} + +func minor(device uint64) uint64 { + return (device & 0xff) | ((device >> 12) & 0xfff00) +} + +// handleTarTypeBlockCharFifo is an OS-specific helper function used by +// createTarFile to handle the following types of header: Block; Char; Fifo +func handleTarTypeBlockCharFifo(hdr *tar.Header, path string) error { + mode := uint32(hdr.Mode & 07777) + switch hdr.Typeflag { + case tar.TypeBlock: + mode |= syscall.S_IFBLK + case tar.TypeChar: + mode |= syscall.S_IFCHR + case tar.TypeFifo: + mode |= syscall.S_IFIFO + } + + if err := system.Mknod(path, mode, int(system.Mkdev(hdr.Devmajor, hdr.Devminor))); err != nil { + return err + } + return nil +} + +func handleLChmod(hdr *tar.Header, path string, hdrInfo os.FileInfo) error { + if hdr.Typeflag == tar.TypeLink { + if fi, err := os.Lstat(hdr.Linkname); err == nil && (fi.Mode()&os.ModeSymlink == 0) { + if err := os.Chmod(path, hdrInfo.Mode()); err != nil { + return err + } + } + } else if hdr.Typeflag != tar.TypeSymlink { + if err := os.Chmod(path, hdrInfo.Mode()); err != nil { + return err + } + } + return nil +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/archive_unix_test.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/archive_unix_test.go new file mode 100644 index 0000000..18f45c4 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/archive_unix_test.go @@ -0,0 +1,60 @@ +// +build !windows + +package archive + +import ( + "os" + "testing" +) + +func TestCanonicalTarNameForPath(t *testing.T) { + cases := []struct{ in, expected string }{ + {"foo", "foo"}, + {"foo/bar", "foo/bar"}, + {"foo/dir/", "foo/dir/"}, + } + for _, v := range cases { + if out, err := CanonicalTarNameForPath(v.in); err != nil { + t.Fatalf("cannot get canonical name for path: %s: %v", v.in, err) + } else if out != v.expected { + t.Fatalf("wrong canonical tar name. expected:%s got:%s", v.expected, out) + } + } +} + +func TestCanonicalTarName(t *testing.T) { + cases := []struct { + in string + isDir bool + expected string + }{ + {"foo", false, "foo"}, + {"foo", true, "foo/"}, + {"foo/bar", false, "foo/bar"}, + {"foo/bar", true, "foo/bar/"}, + } + for _, v := range cases { + if out, err := canonicalTarName(v.in, v.isDir); err != nil { + t.Fatalf("cannot get canonical name for path: %s: %v", v.in, err) + } else if out != v.expected { + t.Fatalf("wrong canonical tar name. expected:%s got:%s", v.expected, out) + } + } +} + +func TestChmodTarEntry(t *testing.T) { + cases := []struct { + in, expected os.FileMode + }{ + {0000, 0000}, + {0777, 0777}, + {0644, 0644}, + {0755, 0755}, + {0444, 0444}, + } + for _, v := range cases { + if out := chmodTarEntry(v.in); out != v.expected { + t.Fatalf("wrong chmod. expected:%v got:%v", v.expected, out) + } + } +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/archive_windows.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/archive_windows.go new file mode 100644 index 0000000..10db4bd --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/archive_windows.go @@ -0,0 +1,50 @@ +// +build windows + +package archive + +import ( + "archive/tar" + "fmt" + "os" + "strings" +) + +// canonicalTarNameForPath returns platform-specific filepath +// to canonical posix-style path for tar archival. p is relative +// path. +func CanonicalTarNameForPath(p string) (string, error) { + // windows: convert windows style relative path with backslashes + // into forward slashes. Since windows does not allow '/' or '\' + // in file names, it is mostly safe to replace however we must + // check just in case + if strings.Contains(p, "/") { + return "", fmt.Errorf("Windows path contains forward slash: %s", p) + } + return strings.Replace(p, string(os.PathSeparator), "/", -1), nil + +} + +// chmodTarEntry is used to adjust the file permissions used in tar header based +// on the platform the archival is done. +func chmodTarEntry(perm os.FileMode) os.FileMode { + perm &= 0755 + // Add the x bit: make everything +x from windows + perm |= 0111 + + return perm +} + +func setHeaderForSpecialDevice(hdr *tar.Header, ta *tarAppender, name string, stat interface{}) (nlink uint32, inode uint64, err error) { + // do nothing. no notion of Rdev, Inode, Nlink in stat on Windows + return +} + +// handleTarTypeBlockCharFifo is an OS-specific helper function used by +// createTarFile to handle the following types of header: Block; Char; Fifo +func handleTarTypeBlockCharFifo(hdr *tar.Header, path string) error { + return nil +} + +func handleLChmod(hdr *tar.Header, path string, hdrInfo os.FileInfo) error { + return nil +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/archive_windows_test.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/archive_windows_test.go new file mode 100644 index 0000000..72bc71e --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/archive_windows_test.go @@ -0,0 +1,65 @@ +// +build windows + +package archive + +import ( + "os" + "testing" +) + +func TestCanonicalTarNameForPath(t *testing.T) { + cases := []struct { + in, expected string + shouldFail bool + }{ + {"foo", "foo", false}, + {"foo/bar", "___", true}, // unix-styled windows path must fail + {`foo\bar`, "foo/bar", false}, + } + for _, v := range cases { + if out, err := CanonicalTarNameForPath(v.in); err != nil && !v.shouldFail { + t.Fatalf("cannot get canonical name for path: %s: %v", v.in, err) + } else if v.shouldFail && err == nil { + t.Fatalf("canonical path call should have failed with error. in=%s out=%s", v.in, out) + } else if !v.shouldFail && out != v.expected { + t.Fatalf("wrong canonical tar name. expected:%s got:%s", v.expected, out) + } + } +} + +func TestCanonicalTarName(t *testing.T) { + cases := []struct { + in string + isDir bool + expected string + }{ + {"foo", false, "foo"}, + {"foo", true, "foo/"}, + {`foo\bar`, false, "foo/bar"}, + {`foo\bar`, true, "foo/bar/"}, + } + for _, v := range cases { + if out, err := canonicalTarName(v.in, v.isDir); err != nil { + t.Fatalf("cannot get canonical name for path: %s: %v", v.in, err) + } else if out != v.expected { + t.Fatalf("wrong canonical tar name. expected:%s got:%s", v.expected, out) + } + } +} + +func TestChmodTarEntry(t *testing.T) { + cases := []struct { + in, expected os.FileMode + }{ + {0000, 0111}, + {0777, 0755}, + {0644, 0755}, + {0755, 0755}, + {0444, 0555}, + } + for _, v := range cases { + if out := chmodTarEntry(v.in); out != v.expected { + t.Fatalf("wrong chmod. expected:%v got:%v", v.expected, out) + } + } +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/changes.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/changes.go new file mode 100644 index 0000000..689d9a2 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/changes.go @@ -0,0 +1,383 @@ +package archive + +import ( + "archive/tar" + "bytes" + "fmt" + "io" + "io/ioutil" + "os" + "path/filepath" + "sort" + "strings" + "syscall" + "time" + + "github.com/Sirupsen/logrus" + "github.com/docker/docker/pkg/pools" + "github.com/docker/docker/pkg/system" +) + +type ChangeType int + +const ( + ChangeModify = iota + ChangeAdd + ChangeDelete +) + +type Change struct { + Path string + Kind ChangeType +} + +func (change *Change) String() string { + var kind string + switch change.Kind { + case ChangeModify: + kind = "C" + case ChangeAdd: + kind = "A" + case ChangeDelete: + kind = "D" + } + return fmt.Sprintf("%s %s", kind, change.Path) +} + +// for sort.Sort +type changesByPath []Change + +func (c changesByPath) Less(i, j int) bool { return c[i].Path < c[j].Path } +func (c changesByPath) Len() int { return len(c) } +func (c changesByPath) Swap(i, j int) { c[j], c[i] = c[i], c[j] } + +// Gnu tar and the go tar writer don't have sub-second mtime +// precision, which is problematic when we apply changes via tar +// files, we handle this by comparing for exact times, *or* same +// second count and either a or b having exactly 0 nanoseconds +func sameFsTime(a, b time.Time) bool { + return a == b || + (a.Unix() == b.Unix() && + (a.Nanosecond() == 0 || b.Nanosecond() == 0)) +} + +func sameFsTimeSpec(a, b syscall.Timespec) bool { + return a.Sec == b.Sec && + (a.Nsec == b.Nsec || a.Nsec == 0 || b.Nsec == 0) +} + +// Changes walks the path rw and determines changes for the files in the path, +// with respect to the parent layers +func Changes(layers []string, rw string) ([]Change, error) { + var ( + changes []Change + changedDirs = make(map[string]struct{}) + ) + + err := filepath.Walk(rw, func(path string, f os.FileInfo, err error) error { + if err != nil { + return err + } + + // Rebase path + path, err = filepath.Rel(rw, path) + if err != nil { + return err + } + + // As this runs on the daemon side, file paths are OS specific. + path = filepath.Join(string(os.PathSeparator), path) + + // Skip root + if path == string(os.PathSeparator) { + return nil + } + + // Skip AUFS metadata + if matched, err := filepath.Match(string(os.PathSeparator)+".wh..wh.*", path); err != nil || matched { + return err + } + + change := Change{ + Path: path, + } + + // Find out what kind of modification happened + file := filepath.Base(path) + // If there is a whiteout, then the file was removed + if strings.HasPrefix(file, ".wh.") { + originalFile := file[len(".wh."):] + change.Path = filepath.Join(filepath.Dir(path), originalFile) + change.Kind = ChangeDelete + } else { + // Otherwise, the file was added + change.Kind = ChangeAdd + + // ...Unless it already existed in a top layer, in which case, it's a modification + for _, layer := range layers { + stat, err := os.Stat(filepath.Join(layer, path)) + if err != nil && !os.IsNotExist(err) { + return err + } + if err == nil { + // The file existed in the top layer, so that's a modification + + // However, if it's a directory, maybe it wasn't actually modified. + // If you modify /foo/bar/baz, then /foo will be part of the changed files only because it's the parent of bar + if stat.IsDir() && f.IsDir() { + if f.Size() == stat.Size() && f.Mode() == stat.Mode() && sameFsTime(f.ModTime(), stat.ModTime()) { + // Both directories are the same, don't record the change + return nil + } + } + change.Kind = ChangeModify + break + } + } + } + + // If /foo/bar/file.txt is modified, then /foo/bar must be part of the changed files. + // This block is here to ensure the change is recorded even if the + // modify time, mode and size of the parent directoriy in the rw and ro layers are all equal. + // Check https://github.com/docker/docker/pull/13590 for details. + if f.IsDir() { + changedDirs[path] = struct{}{} + } + if change.Kind == ChangeAdd || change.Kind == ChangeDelete { + parent := filepath.Dir(path) + if _, ok := changedDirs[parent]; !ok && parent != "/" { + changes = append(changes, Change{Path: parent, Kind: ChangeModify}) + changedDirs[parent] = struct{}{} + } + } + + // Record change + changes = append(changes, change) + return nil + }) + if err != nil && !os.IsNotExist(err) { + return nil, err + } + return changes, nil +} + +type FileInfo struct { + parent *FileInfo + name string + stat *system.Stat_t + children map[string]*FileInfo + capability []byte + added bool +} + +func (root *FileInfo) LookUp(path string) *FileInfo { + // As this runs on the daemon side, file paths are OS specific. + parent := root + if path == string(os.PathSeparator) { + return root + } + + pathElements := strings.Split(path, string(os.PathSeparator)) + for _, elem := range pathElements { + if elem != "" { + child := parent.children[elem] + if child == nil { + return nil + } + parent = child + } + } + return parent +} + +func (info *FileInfo) path() string { + if info.parent == nil { + // As this runs on the daemon side, file paths are OS specific. + return string(os.PathSeparator) + } + return filepath.Join(info.parent.path(), info.name) +} + +func (info *FileInfo) addChanges(oldInfo *FileInfo, changes *[]Change) { + + sizeAtEntry := len(*changes) + + if oldInfo == nil { + // add + change := Change{ + Path: info.path(), + Kind: ChangeAdd, + } + *changes = append(*changes, change) + info.added = true + } + + // We make a copy so we can modify it to detect additions + // also, we only recurse on the old dir if the new info is a directory + // otherwise any previous delete/change is considered recursive + oldChildren := make(map[string]*FileInfo) + if oldInfo != nil && info.isDir() { + for k, v := range oldInfo.children { + oldChildren[k] = v + } + } + + for name, newChild := range info.children { + oldChild, _ := oldChildren[name] + if oldChild != nil { + // change? + oldStat := oldChild.stat + newStat := newChild.stat + // Note: We can't compare inode or ctime or blocksize here, because these change + // when copying a file into a container. However, that is not generally a problem + // because any content change will change mtime, and any status change should + // be visible when actually comparing the stat fields. The only time this + // breaks down is if some code intentionally hides a change by setting + // back mtime + if statDifferent(oldStat, newStat) || + bytes.Compare(oldChild.capability, newChild.capability) != 0 { + change := Change{ + Path: newChild.path(), + Kind: ChangeModify, + } + *changes = append(*changes, change) + newChild.added = true + } + + // Remove from copy so we can detect deletions + delete(oldChildren, name) + } + + newChild.addChanges(oldChild, changes) + } + for _, oldChild := range oldChildren { + // delete + change := Change{ + Path: oldChild.path(), + Kind: ChangeDelete, + } + *changes = append(*changes, change) + } + + // If there were changes inside this directory, we need to add it, even if the directory + // itself wasn't changed. This is needed to properly save and restore filesystem permissions. + // As this runs on the daemon side, file paths are OS specific. + if len(*changes) > sizeAtEntry && info.isDir() && !info.added && info.path() != string(os.PathSeparator) { + change := Change{ + Path: info.path(), + Kind: ChangeModify, + } + // Let's insert the directory entry before the recently added entries located inside this dir + *changes = append(*changes, change) // just to resize the slice, will be overwritten + copy((*changes)[sizeAtEntry+1:], (*changes)[sizeAtEntry:]) + (*changes)[sizeAtEntry] = change + } + +} + +func (info *FileInfo) Changes(oldInfo *FileInfo) []Change { + var changes []Change + + info.addChanges(oldInfo, &changes) + + return changes +} + +func newRootFileInfo() *FileInfo { + // As this runs on the daemon side, file paths are OS specific. + root := &FileInfo{ + name: string(os.PathSeparator), + children: make(map[string]*FileInfo), + } + return root +} + +// ChangesDirs compares two directories and generates an array of Change objects describing the changes. +// If oldDir is "", then all files in newDir will be Add-Changes. +func ChangesDirs(newDir, oldDir string) ([]Change, error) { + var ( + oldRoot, newRoot *FileInfo + ) + if oldDir == "" { + emptyDir, err := ioutil.TempDir("", "empty") + if err != nil { + return nil, err + } + defer os.Remove(emptyDir) + oldDir = emptyDir + } + oldRoot, newRoot, err := collectFileInfoForChanges(oldDir, newDir) + if err != nil { + return nil, err + } + + return newRoot.Changes(oldRoot), nil +} + +// ChangesSize calculates the size in bytes of the provided changes, based on newDir. +func ChangesSize(newDir string, changes []Change) int64 { + var size int64 + for _, change := range changes { + if change.Kind == ChangeModify || change.Kind == ChangeAdd { + file := filepath.Join(newDir, change.Path) + fileInfo, _ := os.Lstat(file) + if fileInfo != nil && !fileInfo.IsDir() { + size += fileInfo.Size() + } + } + } + return size +} + +// ExportChanges produces an Archive from the provided changes, relative to dir. +func ExportChanges(dir string, changes []Change) (Archive, error) { + reader, writer := io.Pipe() + go func() { + ta := &tarAppender{ + TarWriter: tar.NewWriter(writer), + Buffer: pools.BufioWriter32KPool.Get(nil), + SeenFiles: make(map[uint64]string), + } + // this buffer is needed for the duration of this piped stream + defer pools.BufioWriter32KPool.Put(ta.Buffer) + + sort.Sort(changesByPath(changes)) + + // In general we log errors here but ignore them because + // during e.g. a diff operation the container can continue + // mutating the filesystem and we can see transient errors + // from this + for _, change := range changes { + if change.Kind == ChangeDelete { + whiteOutDir := filepath.Dir(change.Path) + whiteOutBase := filepath.Base(change.Path) + whiteOut := filepath.Join(whiteOutDir, ".wh."+whiteOutBase) + timestamp := time.Now() + hdr := &tar.Header{ + Name: whiteOut[1:], + Size: 0, + ModTime: timestamp, + AccessTime: timestamp, + ChangeTime: timestamp, + } + if err := ta.TarWriter.WriteHeader(hdr); err != nil { + logrus.Debugf("Can't write whiteout header: %s", err) + } + } else { + path := filepath.Join(dir, change.Path) + if err := ta.addTarFile(path, change.Path[1:]); err != nil { + logrus.Debugf("Can't add file %s to tar: %s", path, err) + } + } + } + + // Make sure to check the error on Close. + if err := ta.TarWriter.Close(); err != nil { + logrus.Debugf("Can't close layer: %s", err) + } + if err := writer.Close(); err != nil { + logrus.Debugf("failed close Changes writer: %s", err) + } + }() + return reader, nil +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/changes_linux.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/changes_linux.go new file mode 100644 index 0000000..dee8b7c --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/changes_linux.go @@ -0,0 +1,285 @@ +package archive + +import ( + "bytes" + "fmt" + "os" + "path/filepath" + "sort" + "syscall" + "unsafe" + + "github.com/docker/docker/pkg/system" +) + +// walker is used to implement collectFileInfoForChanges on linux. Where this +// method in general returns the entire contents of two directory trees, we +// optimize some FS calls out on linux. In particular, we take advantage of the +// fact that getdents(2) returns the inode of each file in the directory being +// walked, which, when walking two trees in parallel to generate a list of +// changes, can be used to prune subtrees without ever having to lstat(2) them +// directly. Eliminating stat calls in this way can save up to seconds on large +// images. +type walker struct { + dir1 string + dir2 string + root1 *FileInfo + root2 *FileInfo +} + +// collectFileInfoForChanges returns a complete representation of the trees +// rooted at dir1 and dir2, with one important exception: any subtree or +// leaf where the inode and device numbers are an exact match between dir1 +// and dir2 will be pruned from the results. This method is *only* to be used +// to generating a list of changes between the two directories, as it does not +// reflect the full contents. +func collectFileInfoForChanges(dir1, dir2 string) (*FileInfo, *FileInfo, error) { + w := &walker{ + dir1: dir1, + dir2: dir2, + root1: newRootFileInfo(), + root2: newRootFileInfo(), + } + + i1, err := os.Lstat(w.dir1) + if err != nil { + return nil, nil, err + } + i2, err := os.Lstat(w.dir2) + if err != nil { + return nil, nil, err + } + + if err := w.walk("/", i1, i2); err != nil { + return nil, nil, err + } + + return w.root1, w.root2, nil +} + +// Given a FileInfo, its path info, and a reference to the root of the tree +// being constructed, register this file with the tree. +func walkchunk(path string, fi os.FileInfo, dir string, root *FileInfo) error { + if fi == nil { + return nil + } + parent := root.LookUp(filepath.Dir(path)) + if parent == nil { + return fmt.Errorf("collectFileInfoForChanges: Unexpectedly no parent for %s", path) + } + info := &FileInfo{ + name: filepath.Base(path), + children: make(map[string]*FileInfo), + parent: parent, + } + cpath := filepath.Join(dir, path) + stat, err := system.FromStatT(fi.Sys().(*syscall.Stat_t)) + if err != nil { + return err + } + info.stat = stat + info.capability, _ = system.Lgetxattr(cpath, "security.capability") // lgetxattr(2): fs access + parent.children[info.name] = info + return nil +} + +// Walk a subtree rooted at the same path in both trees being iterated. For +// example, /docker/overlay/1234/a/b/c/d and /docker/overlay/8888/a/b/c/d +func (w *walker) walk(path string, i1, i2 os.FileInfo) (err error) { + // Register these nodes with the return trees, unless we're still at the + // (already-created) roots: + if path != "/" { + if err := walkchunk(path, i1, w.dir1, w.root1); err != nil { + return err + } + if err := walkchunk(path, i2, w.dir2, w.root2); err != nil { + return err + } + } + + is1Dir := i1 != nil && i1.IsDir() + is2Dir := i2 != nil && i2.IsDir() + + sameDevice := false + if i1 != nil && i2 != nil { + si1 := i1.Sys().(*syscall.Stat_t) + si2 := i2.Sys().(*syscall.Stat_t) + if si1.Dev == si2.Dev { + sameDevice = true + } + } + + // If these files are both non-existent, or leaves (non-dirs), we are done. + if !is1Dir && !is2Dir { + return nil + } + + // Fetch the names of all the files contained in both directories being walked: + var names1, names2 []nameIno + if is1Dir { + names1, err = readdirnames(filepath.Join(w.dir1, path)) // getdents(2): fs access + if err != nil { + return err + } + } + if is2Dir { + names2, err = readdirnames(filepath.Join(w.dir2, path)) // getdents(2): fs access + if err != nil { + return err + } + } + + // We have lists of the files contained in both parallel directories, sorted + // in the same order. Walk them in parallel, generating a unique merged list + // of all items present in either or both directories. + var names []string + ix1 := 0 + ix2 := 0 + + for { + if ix1 >= len(names1) { + break + } + if ix2 >= len(names2) { + break + } + + ni1 := names1[ix1] + ni2 := names2[ix2] + + switch bytes.Compare([]byte(ni1.name), []byte(ni2.name)) { + case -1: // ni1 < ni2 -- advance ni1 + // we will not encounter ni1 in names2 + names = append(names, ni1.name) + ix1++ + case 0: // ni1 == ni2 + if ni1.ino != ni2.ino || !sameDevice { + names = append(names, ni1.name) + } + ix1++ + ix2++ + case 1: // ni1 > ni2 -- advance ni2 + // we will not encounter ni2 in names1 + names = append(names, ni2.name) + ix2++ + } + } + for ix1 < len(names1) { + names = append(names, names1[ix1].name) + ix1++ + } + for ix2 < len(names2) { + names = append(names, names2[ix2].name) + ix2++ + } + + // For each of the names present in either or both of the directories being + // iterated, stat the name under each root, and recurse the pair of them: + for _, name := range names { + fname := filepath.Join(path, name) + var cInfo1, cInfo2 os.FileInfo + if is1Dir { + cInfo1, err = os.Lstat(filepath.Join(w.dir1, fname)) // lstat(2): fs access + if err != nil && !os.IsNotExist(err) { + return err + } + } + if is2Dir { + cInfo2, err = os.Lstat(filepath.Join(w.dir2, fname)) // lstat(2): fs access + if err != nil && !os.IsNotExist(err) { + return err + } + } + if err = w.walk(fname, cInfo1, cInfo2); err != nil { + return err + } + } + return nil +} + +// {name,inode} pairs used to support the early-pruning logic of the walker type +type nameIno struct { + name string + ino uint64 +} + +type nameInoSlice []nameIno + +func (s nameInoSlice) Len() int { return len(s) } +func (s nameInoSlice) Swap(i, j int) { s[i], s[j] = s[j], s[i] } +func (s nameInoSlice) Less(i, j int) bool { return s[i].name < s[j].name } + +// readdirnames is a hacked-apart version of the Go stdlib code, exposing inode +// numbers further up the stack when reading directory contents. Unlike +// os.Readdirnames, which returns a list of filenames, this function returns a +// list of {filename,inode} pairs. +func readdirnames(dirname string) (names []nameIno, err error) { + var ( + size = 100 + buf = make([]byte, 4096) + nbuf int + bufp int + nb int + ) + + f, err := os.Open(dirname) + if err != nil { + return nil, err + } + defer f.Close() + + names = make([]nameIno, 0, size) // Empty with room to grow. + for { + // Refill the buffer if necessary + if bufp >= nbuf { + bufp = 0 + nbuf, err = syscall.ReadDirent(int(f.Fd()), buf) // getdents on linux + if nbuf < 0 { + nbuf = 0 + } + if err != nil { + return nil, os.NewSyscallError("readdirent", err) + } + if nbuf <= 0 { + break // EOF + } + } + + // Drain the buffer + nb, names = parseDirent(buf[bufp:nbuf], names) + bufp += nb + } + + sl := nameInoSlice(names) + sort.Sort(sl) + return sl, nil +} + +// parseDirent is a minor modification of syscall.ParseDirent (linux version) +// which returns {name,inode} pairs instead of just names. +func parseDirent(buf []byte, names []nameIno) (consumed int, newnames []nameIno) { + origlen := len(buf) + for len(buf) > 0 { + dirent := (*syscall.Dirent)(unsafe.Pointer(&buf[0])) + buf = buf[dirent.Reclen:] + if dirent.Ino == 0 { // File absent in directory. + continue + } + bytes := (*[10000]byte)(unsafe.Pointer(&dirent.Name[0])) + var name = string(bytes[0:clen(bytes[:])]) + if name == "." || name == ".." { // Useless names + continue + } + names = append(names, nameIno{name, dirent.Ino}) + } + return origlen - len(buf), names +} + +func clen(n []byte) int { + for i := 0; i < len(n); i++ { + if n[i] == 0 { + return i + } + } + return len(n) +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/changes_other.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/changes_other.go new file mode 100644 index 0000000..da70ed3 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/changes_other.go @@ -0,0 +1,97 @@ +// +build !linux + +package archive + +import ( + "fmt" + "os" + "path/filepath" + "runtime" + "strings" + + "github.com/docker/docker/pkg/system" +) + +func collectFileInfoForChanges(oldDir, newDir string) (*FileInfo, *FileInfo, error) { + var ( + oldRoot, newRoot *FileInfo + err1, err2 error + errs = make(chan error, 2) + ) + go func() { + oldRoot, err1 = collectFileInfo(oldDir) + errs <- err1 + }() + go func() { + newRoot, err2 = collectFileInfo(newDir) + errs <- err2 + }() + + // block until both routines have returned + for i := 0; i < 2; i++ { + if err := <-errs; err != nil { + return nil, nil, err + } + } + + return oldRoot, newRoot, nil +} + +func collectFileInfo(sourceDir string) (*FileInfo, error) { + root := newRootFileInfo() + + err := filepath.Walk(sourceDir, func(path string, f os.FileInfo, err error) error { + if err != nil { + return err + } + + // Rebase path + relPath, err := filepath.Rel(sourceDir, path) + if err != nil { + return err + } + + // As this runs on the daemon side, file paths are OS specific. + relPath = filepath.Join(string(os.PathSeparator), relPath) + + // See https://github.com/golang/go/issues/9168 - bug in filepath.Join. + // Temporary workaround. If the returned path starts with two backslashes, + // trim it down to a single backslash. Only relevant on Windows. + if runtime.GOOS == "windows" { + if strings.HasPrefix(relPath, `\\`) { + relPath = relPath[1:] + } + } + + if relPath == string(os.PathSeparator) { + return nil + } + + parent := root.LookUp(filepath.Dir(relPath)) + if parent == nil { + return fmt.Errorf("collectFileInfo: Unexpectedly no parent for %s", relPath) + } + + info := &FileInfo{ + name: filepath.Base(relPath), + children: make(map[string]*FileInfo), + parent: parent, + } + + s, err := system.Lstat(path) + if err != nil { + return err + } + info.stat = s + + info.capability, _ = system.Lgetxattr(path, "security.capability") + + parent.children[info.name] = info + + return nil + }) + if err != nil { + return nil, err + } + return root, nil +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/changes_posix_test.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/changes_posix_test.go new file mode 100644 index 0000000..9d528e6 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/changes_posix_test.go @@ -0,0 +1,127 @@ +package archive + +import ( + "archive/tar" + "fmt" + "io" + "io/ioutil" + "os" + "path" + "sort" + "testing" +) + +func TestHardLinkOrder(t *testing.T) { + names := []string{"file1.txt", "file2.txt", "file3.txt"} + msg := []byte("Hey y'all") + + // Create dir + src, err := ioutil.TempDir("", "docker-hardlink-test-src-") + if err != nil { + t.Fatal(err) + } + //defer os.RemoveAll(src) + for _, name := range names { + func() { + fh, err := os.Create(path.Join(src, name)) + if err != nil { + t.Fatal(err) + } + defer fh.Close() + if _, err = fh.Write(msg); err != nil { + t.Fatal(err) + } + }() + } + // Create dest, with changes that includes hardlinks + dest, err := ioutil.TempDir("", "docker-hardlink-test-dest-") + if err != nil { + t.Fatal(err) + } + os.RemoveAll(dest) // we just want the name, at first + if err := copyDir(src, dest); err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dest) + for _, name := range names { + for i := 0; i < 5; i++ { + if err := os.Link(path.Join(dest, name), path.Join(dest, fmt.Sprintf("%s.link%d", name, i))); err != nil { + t.Fatal(err) + } + } + } + + // get changes + changes, err := ChangesDirs(dest, src) + if err != nil { + t.Fatal(err) + } + + // sort + sort.Sort(changesByPath(changes)) + + // ExportChanges + ar, err := ExportChanges(dest, changes) + if err != nil { + t.Fatal(err) + } + hdrs, err := walkHeaders(ar) + if err != nil { + t.Fatal(err) + } + + // reverse sort + sort.Sort(sort.Reverse(changesByPath(changes))) + // ExportChanges + arRev, err := ExportChanges(dest, changes) + if err != nil { + t.Fatal(err) + } + hdrsRev, err := walkHeaders(arRev) + if err != nil { + t.Fatal(err) + } + + // line up the two sets + sort.Sort(tarHeaders(hdrs)) + sort.Sort(tarHeaders(hdrsRev)) + + // compare Size and LinkName + for i := range hdrs { + if hdrs[i].Name != hdrsRev[i].Name { + t.Errorf("headers - expected name %q; but got %q", hdrs[i].Name, hdrsRev[i].Name) + } + if hdrs[i].Size != hdrsRev[i].Size { + t.Errorf("headers - %q expected size %d; but got %d", hdrs[i].Name, hdrs[i].Size, hdrsRev[i].Size) + } + if hdrs[i].Typeflag != hdrsRev[i].Typeflag { + t.Errorf("headers - %q expected type %d; but got %d", hdrs[i].Name, hdrs[i].Typeflag, hdrsRev[i].Typeflag) + } + if hdrs[i].Linkname != hdrsRev[i].Linkname { + t.Errorf("headers - %q expected linkname %q; but got %q", hdrs[i].Name, hdrs[i].Linkname, hdrsRev[i].Linkname) + } + } + +} + +type tarHeaders []tar.Header + +func (th tarHeaders) Len() int { return len(th) } +func (th tarHeaders) Swap(i, j int) { th[j], th[i] = th[i], th[j] } +func (th tarHeaders) Less(i, j int) bool { return th[i].Name < th[j].Name } + +func walkHeaders(r io.Reader) ([]tar.Header, error) { + t := tar.NewReader(r) + headers := []tar.Header{} + for { + hdr, err := t.Next() + if err != nil { + if err == io.EOF { + break + } + return headers, err + } + headers = append(headers, *hdr) + } + return headers, nil +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/changes_test.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/changes_test.go new file mode 100644 index 0000000..509bdb2 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/changes_test.go @@ -0,0 +1,495 @@ +package archive + +import ( + "io/ioutil" + "os" + "os/exec" + "path" + "sort" + "testing" + "time" +) + +func max(x, y int) int { + if x >= y { + return x + } + return y +} + +func copyDir(src, dst string) error { + cmd := exec.Command("cp", "-a", src, dst) + if err := cmd.Run(); err != nil { + return err + } + return nil +} + +type FileType uint32 + +const ( + Regular FileType = iota + Dir + Symlink +) + +type FileData struct { + filetype FileType + path string + contents string + permissions os.FileMode +} + +func createSampleDir(t *testing.T, root string) { + files := []FileData{ + {Regular, "file1", "file1\n", 0600}, + {Regular, "file2", "file2\n", 0666}, + {Regular, "file3", "file3\n", 0404}, + {Regular, "file4", "file4\n", 0600}, + {Regular, "file5", "file5\n", 0600}, + {Regular, "file6", "file6\n", 0600}, + {Regular, "file7", "file7\n", 0600}, + {Dir, "dir1", "", 0740}, + {Regular, "dir1/file1-1", "file1-1\n", 01444}, + {Regular, "dir1/file1-2", "file1-2\n", 0666}, + {Dir, "dir2", "", 0700}, + {Regular, "dir2/file2-1", "file2-1\n", 0666}, + {Regular, "dir2/file2-2", "file2-2\n", 0666}, + {Dir, "dir3", "", 0700}, + {Regular, "dir3/file3-1", "file3-1\n", 0666}, + {Regular, "dir3/file3-2", "file3-2\n", 0666}, + {Dir, "dir4", "", 0700}, + {Regular, "dir4/file3-1", "file4-1\n", 0666}, + {Regular, "dir4/file3-2", "file4-2\n", 0666}, + {Symlink, "symlink1", "target1", 0666}, + {Symlink, "symlink2", "target2", 0666}, + } + + now := time.Now() + for _, info := range files { + p := path.Join(root, info.path) + if info.filetype == Dir { + if err := os.MkdirAll(p, info.permissions); err != nil { + t.Fatal(err) + } + } else if info.filetype == Regular { + if err := ioutil.WriteFile(p, []byte(info.contents), info.permissions); err != nil { + t.Fatal(err) + } + } else if info.filetype == Symlink { + if err := os.Symlink(info.contents, p); err != nil { + t.Fatal(err) + } + } + + if info.filetype != Symlink { + // Set a consistent ctime, atime for all files and dirs + if err := os.Chtimes(p, now, now); err != nil { + t.Fatal(err) + } + } + } +} + +func TestChangeString(t *testing.T) { + modifiyChange := Change{"change", ChangeModify} + toString := modifiyChange.String() + if toString != "C change" { + t.Fatalf("String() of a change with ChangeModifiy Kind should have been %s but was %s", "C change", toString) + } + addChange := Change{"change", ChangeAdd} + toString = addChange.String() + if toString != "A change" { + t.Fatalf("String() of a change with ChangeAdd Kind should have been %s but was %s", "A change", toString) + } + deleteChange := Change{"change", ChangeDelete} + toString = deleteChange.String() + if toString != "D change" { + t.Fatalf("String() of a change with ChangeDelete Kind should have been %s but was %s", "D change", toString) + } +} + +func TestChangesWithNoChanges(t *testing.T) { + rwLayer, err := ioutil.TempDir("", "docker-changes-test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(rwLayer) + layer, err := ioutil.TempDir("", "docker-changes-test-layer") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(layer) + createSampleDir(t, layer) + changes, err := Changes([]string{layer}, rwLayer) + if err != nil { + t.Fatal(err) + } + if len(changes) != 0 { + t.Fatalf("Changes with no difference should have detect no changes, but detected %d", len(changes)) + } +} + +func TestChangesWithChanges(t *testing.T) { + // Mock the readonly layer + layer, err := ioutil.TempDir("", "docker-changes-test-layer") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(layer) + createSampleDir(t, layer) + os.MkdirAll(path.Join(layer, "dir1/subfolder"), 0740) + + // Mock the RW layer + rwLayer, err := ioutil.TempDir("", "docker-changes-test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(rwLayer) + + // Create a folder in RW layer + dir1 := path.Join(rwLayer, "dir1") + os.MkdirAll(dir1, 0740) + deletedFile := path.Join(dir1, ".wh.file1-2") + ioutil.WriteFile(deletedFile, []byte{}, 0600) + modifiedFile := path.Join(dir1, "file1-1") + ioutil.WriteFile(modifiedFile, []byte{0x00}, 01444) + // Let's add a subfolder for a newFile + subfolder := path.Join(dir1, "subfolder") + os.MkdirAll(subfolder, 0740) + newFile := path.Join(subfolder, "newFile") + ioutil.WriteFile(newFile, []byte{}, 0740) + + changes, err := Changes([]string{layer}, rwLayer) + if err != nil { + t.Fatal(err) + } + + expectedChanges := []Change{ + {"/dir1", ChangeModify}, + {"/dir1/file1-1", ChangeModify}, + {"/dir1/file1-2", ChangeDelete}, + {"/dir1/subfolder", ChangeModify}, + {"/dir1/subfolder/newFile", ChangeAdd}, + } + checkChanges(expectedChanges, changes, t) +} + +// See https://github.com/docker/docker/pull/13590 +func TestChangesWithChangesGH13590(t *testing.T) { + baseLayer, err := ioutil.TempDir("", "docker-changes-test.") + defer os.RemoveAll(baseLayer) + + dir3 := path.Join(baseLayer, "dir1/dir2/dir3") + os.MkdirAll(dir3, 07400) + + file := path.Join(dir3, "file.txt") + ioutil.WriteFile(file, []byte("hello"), 0666) + + layer, err := ioutil.TempDir("", "docker-changes-test2.") + defer os.RemoveAll(layer) + + // Test creating a new file + if err := copyDir(baseLayer+"/dir1", layer+"/"); err != nil { + t.Fatalf("Cmd failed: %q", err) + } + + os.Remove(path.Join(layer, "dir1/dir2/dir3/file.txt")) + file = path.Join(layer, "dir1/dir2/dir3/file1.txt") + ioutil.WriteFile(file, []byte("bye"), 0666) + + changes, err := Changes([]string{baseLayer}, layer) + if err != nil { + t.Fatal(err) + } + + expectedChanges := []Change{ + {"/dir1/dir2/dir3", ChangeModify}, + {"/dir1/dir2/dir3/file1.txt", ChangeAdd}, + } + checkChanges(expectedChanges, changes, t) + + // Now test changing a file + layer, err = ioutil.TempDir("", "docker-changes-test3.") + defer os.RemoveAll(layer) + + if err := copyDir(baseLayer+"/dir1", layer+"/"); err != nil { + t.Fatalf("Cmd failed: %q", err) + } + + file = path.Join(layer, "dir1/dir2/dir3/file.txt") + ioutil.WriteFile(file, []byte("bye"), 0666) + + changes, err = Changes([]string{baseLayer}, layer) + if err != nil { + t.Fatal(err) + } + + expectedChanges = []Change{ + {"/dir1/dir2/dir3/file.txt", ChangeModify}, + } + checkChanges(expectedChanges, changes, t) +} + +// Create an directory, copy it, make sure we report no changes between the two +func TestChangesDirsEmpty(t *testing.T) { + src, err := ioutil.TempDir("", "docker-changes-test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(src) + createSampleDir(t, src) + dst := src + "-copy" + if err := copyDir(src, dst); err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dst) + changes, err := ChangesDirs(dst, src) + if err != nil { + t.Fatal(err) + } + + if len(changes) != 0 { + t.Fatalf("Reported changes for identical dirs: %v", changes) + } + os.RemoveAll(src) + os.RemoveAll(dst) +} + +func mutateSampleDir(t *testing.T, root string) { + // Remove a regular file + if err := os.RemoveAll(path.Join(root, "file1")); err != nil { + t.Fatal(err) + } + + // Remove a directory + if err := os.RemoveAll(path.Join(root, "dir1")); err != nil { + t.Fatal(err) + } + + // Remove a symlink + if err := os.RemoveAll(path.Join(root, "symlink1")); err != nil { + t.Fatal(err) + } + + // Rewrite a file + if err := ioutil.WriteFile(path.Join(root, "file2"), []byte("fileNN\n"), 0777); err != nil { + t.Fatal(err) + } + + // Replace a file + if err := os.RemoveAll(path.Join(root, "file3")); err != nil { + t.Fatal(err) + } + if err := ioutil.WriteFile(path.Join(root, "file3"), []byte("fileMM\n"), 0404); err != nil { + t.Fatal(err) + } + + // Touch file + if err := os.Chtimes(path.Join(root, "file4"), time.Now().Add(time.Second), time.Now().Add(time.Second)); err != nil { + t.Fatal(err) + } + + // Replace file with dir + if err := os.RemoveAll(path.Join(root, "file5")); err != nil { + t.Fatal(err) + } + if err := os.MkdirAll(path.Join(root, "file5"), 0666); err != nil { + t.Fatal(err) + } + + // Create new file + if err := ioutil.WriteFile(path.Join(root, "filenew"), []byte("filenew\n"), 0777); err != nil { + t.Fatal(err) + } + + // Create new dir + if err := os.MkdirAll(path.Join(root, "dirnew"), 0766); err != nil { + t.Fatal(err) + } + + // Create a new symlink + if err := os.Symlink("targetnew", path.Join(root, "symlinknew")); err != nil { + t.Fatal(err) + } + + // Change a symlink + if err := os.RemoveAll(path.Join(root, "symlink2")); err != nil { + t.Fatal(err) + } + if err := os.Symlink("target2change", path.Join(root, "symlink2")); err != nil { + t.Fatal(err) + } + + // Replace dir with file + if err := os.RemoveAll(path.Join(root, "dir2")); err != nil { + t.Fatal(err) + } + if err := ioutil.WriteFile(path.Join(root, "dir2"), []byte("dir2\n"), 0777); err != nil { + t.Fatal(err) + } + + // Touch dir + if err := os.Chtimes(path.Join(root, "dir3"), time.Now().Add(time.Second), time.Now().Add(time.Second)); err != nil { + t.Fatal(err) + } +} + +func TestChangesDirsMutated(t *testing.T) { + src, err := ioutil.TempDir("", "docker-changes-test") + if err != nil { + t.Fatal(err) + } + createSampleDir(t, src) + dst := src + "-copy" + if err := copyDir(src, dst); err != nil { + t.Fatal(err) + } + defer os.RemoveAll(src) + defer os.RemoveAll(dst) + + mutateSampleDir(t, dst) + + changes, err := ChangesDirs(dst, src) + if err != nil { + t.Fatal(err) + } + + sort.Sort(changesByPath(changes)) + + expectedChanges := []Change{ + {"/dir1", ChangeDelete}, + {"/dir2", ChangeModify}, + {"/dirnew", ChangeAdd}, + {"/file1", ChangeDelete}, + {"/file2", ChangeModify}, + {"/file3", ChangeModify}, + {"/file4", ChangeModify}, + {"/file5", ChangeModify}, + {"/filenew", ChangeAdd}, + {"/symlink1", ChangeDelete}, + {"/symlink2", ChangeModify}, + {"/symlinknew", ChangeAdd}, + } + + for i := 0; i < max(len(changes), len(expectedChanges)); i++ { + if i >= len(expectedChanges) { + t.Fatalf("unexpected change %s\n", changes[i].String()) + } + if i >= len(changes) { + t.Fatalf("no change for expected change %s\n", expectedChanges[i].String()) + } + if changes[i].Path == expectedChanges[i].Path { + if changes[i] != expectedChanges[i] { + t.Fatalf("Wrong change for %s, expected %s, got %s\n", changes[i].Path, changes[i].String(), expectedChanges[i].String()) + } + } else if changes[i].Path < expectedChanges[i].Path { + t.Fatalf("unexpected change %s\n", changes[i].String()) + } else { + t.Fatalf("no change for expected change %s != %s\n", expectedChanges[i].String(), changes[i].String()) + } + } +} + +func TestApplyLayer(t *testing.T) { + src, err := ioutil.TempDir("", "docker-changes-test") + if err != nil { + t.Fatal(err) + } + createSampleDir(t, src) + defer os.RemoveAll(src) + dst := src + "-copy" + if err := copyDir(src, dst); err != nil { + t.Fatal(err) + } + mutateSampleDir(t, dst) + defer os.RemoveAll(dst) + + changes, err := ChangesDirs(dst, src) + if err != nil { + t.Fatal(err) + } + + layer, err := ExportChanges(dst, changes) + if err != nil { + t.Fatal(err) + } + + layerCopy, err := NewTempArchive(layer, "") + if err != nil { + t.Fatal(err) + } + + if _, err := ApplyLayer(src, layerCopy); err != nil { + t.Fatal(err) + } + + changes2, err := ChangesDirs(src, dst) + if err != nil { + t.Fatal(err) + } + + if len(changes2) != 0 { + t.Fatalf("Unexpected differences after reapplying mutation: %v", changes2) + } +} + +func TestChangesSizeWithNoChanges(t *testing.T) { + size := ChangesSize("/tmp", nil) + if size != 0 { + t.Fatalf("ChangesSizes with no changes should be 0, was %d", size) + } +} + +func TestChangesSizeWithOnlyDeleteChanges(t *testing.T) { + changes := []Change{ + {Path: "deletedPath", Kind: ChangeDelete}, + } + size := ChangesSize("/tmp", changes) + if size != 0 { + t.Fatalf("ChangesSizes with only delete changes should be 0, was %d", size) + } +} + +func TestChangesSize(t *testing.T) { + parentPath, err := ioutil.TempDir("", "docker-changes-test") + defer os.RemoveAll(parentPath) + addition := path.Join(parentPath, "addition") + if err := ioutil.WriteFile(addition, []byte{0x01, 0x01, 0x01}, 0744); err != nil { + t.Fatal(err) + } + modification := path.Join(parentPath, "modification") + if err = ioutil.WriteFile(modification, []byte{0x01, 0x01, 0x01}, 0744); err != nil { + t.Fatal(err) + } + changes := []Change{ + {Path: "addition", Kind: ChangeAdd}, + {Path: "modification", Kind: ChangeModify}, + } + size := ChangesSize(parentPath, changes) + if size != 6 { + t.Fatalf("ChangesSizes with only delete changes should be 0, was %d", size) + } +} + +func checkChanges(expectedChanges, changes []Change, t *testing.T) { + sort.Sort(changesByPath(expectedChanges)) + sort.Sort(changesByPath(changes)) + for i := 0; i < max(len(changes), len(expectedChanges)); i++ { + if i >= len(expectedChanges) { + t.Fatalf("unexpected change %s\n", changes[i].String()) + } + if i >= len(changes) { + t.Fatalf("no change for expected change %s\n", expectedChanges[i].String()) + } + if changes[i].Path == expectedChanges[i].Path { + if changes[i] != expectedChanges[i] { + t.Fatalf("Wrong change for %s, expected %s, got %s\n", changes[i].Path, changes[i].String(), expectedChanges[i].String()) + } + } else if changes[i].Path < expectedChanges[i].Path { + t.Fatalf("unexpected change %s\n", changes[i].String()) + } else { + t.Fatalf("no change for expected change %s != %s\n", expectedChanges[i].String(), changes[i].String()) + } + } +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/changes_unix.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/changes_unix.go new file mode 100644 index 0000000..d780f16 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/changes_unix.go @@ -0,0 +1,27 @@ +// +build !windows + +package archive + +import ( + "syscall" + + "github.com/docker/docker/pkg/system" +) + +func statDifferent(oldStat *system.Stat_t, newStat *system.Stat_t) bool { + // Don't look at size for dirs, its not a good measure of change + if oldStat.Mode() != newStat.Mode() || + oldStat.Uid() != newStat.Uid() || + oldStat.Gid() != newStat.Gid() || + oldStat.Rdev() != newStat.Rdev() || + // Don't look at size for dirs, its not a good measure of change + (oldStat.Mode()&syscall.S_IFDIR != syscall.S_IFDIR && + (!sameFsTimeSpec(oldStat.Mtim(), newStat.Mtim()) || (oldStat.Size() != newStat.Size()))) { + return true + } + return false +} + +func (info *FileInfo) isDir() bool { + return info.parent == nil || info.stat.Mode()&syscall.S_IFDIR != 0 +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/changes_windows.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/changes_windows.go new file mode 100644 index 0000000..4809b7a --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/changes_windows.go @@ -0,0 +1,20 @@ +package archive + +import ( + "github.com/docker/docker/pkg/system" +) + +func statDifferent(oldStat *system.Stat_t, newStat *system.Stat_t) bool { + + // Don't look at size for dirs, its not a good measure of change + if oldStat.ModTime() != newStat.ModTime() || + oldStat.Mode() != newStat.Mode() || + oldStat.Size() != newStat.Size() && !oldStat.IsDir() { + return true + } + return false +} + +func (info *FileInfo) isDir() bool { + return info.parent == nil || info.stat.IsDir() +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/copy.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/copy.go new file mode 100644 index 0000000..fee4a02 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/copy.go @@ -0,0 +1,308 @@ +package archive + +import ( + "archive/tar" + "errors" + "io" + "io/ioutil" + "os" + "path" + "path/filepath" + "strings" + + log "github.com/Sirupsen/logrus" +) + +// Errors used or returned by this file. +var ( + ErrNotDirectory = errors.New("not a directory") + ErrDirNotExists = errors.New("no such directory") + ErrCannotCopyDir = errors.New("cannot copy directory") + ErrInvalidCopySource = errors.New("invalid copy source content") +) + +// PreserveTrailingDotOrSeparator returns the given cleaned path (after +// processing using any utility functions from the path or filepath stdlib +// packages) and appends a trailing `/.` or `/` if its corresponding original +// path (from before being processed by utility functions from the path or +// filepath stdlib packages) ends with a trailing `/.` or `/`. If the cleaned +// path already ends in a `.` path segment, then another is not added. If the +// clean path already ends in a path separator, then another is not added. +func PreserveTrailingDotOrSeparator(cleanedPath, originalPath string) string { + if !SpecifiesCurrentDir(cleanedPath) && SpecifiesCurrentDir(originalPath) { + if !HasTrailingPathSeparator(cleanedPath) { + // Add a separator if it doesn't already end with one (a cleaned + // path would only end in a separator if it is the root). + cleanedPath += string(filepath.Separator) + } + cleanedPath += "." + } + + if !HasTrailingPathSeparator(cleanedPath) && HasTrailingPathSeparator(originalPath) { + cleanedPath += string(filepath.Separator) + } + + return cleanedPath +} + +// AssertsDirectory returns whether the given path is +// asserted to be a directory, i.e., the path ends with +// a trailing '/' or `/.`, assuming a path separator of `/`. +func AssertsDirectory(path string) bool { + return HasTrailingPathSeparator(path) || SpecifiesCurrentDir(path) +} + +// HasTrailingPathSeparator returns whether the given +// path ends with the system's path separator character. +func HasTrailingPathSeparator(path string) bool { + return len(path) > 0 && os.IsPathSeparator(path[len(path)-1]) +} + +// SpecifiesCurrentDir returns whether the given path specifies +// a "current directory", i.e., the last path segment is `.`. +func SpecifiesCurrentDir(path string) bool { + return filepath.Base(path) == "." +} + +// SplitPathDirEntry splits the given path between its +// parent directory and its basename in that directory. +func SplitPathDirEntry(localizedPath string) (dir, base string) { + normalizedPath := filepath.ToSlash(localizedPath) + vol := filepath.VolumeName(normalizedPath) + normalizedPath = normalizedPath[len(vol):] + + if normalizedPath == "/" { + // Specifies the root path. + return filepath.FromSlash(vol + normalizedPath), "." + } + + trimmedPath := vol + strings.TrimRight(normalizedPath, "/") + + dir = filepath.FromSlash(path.Dir(trimmedPath)) + base = filepath.FromSlash(path.Base(trimmedPath)) + + return dir, base +} + +// TarResource archives the resource at the given sourcePath into a Tar +// archive. A non-nil error is returned if sourcePath does not exist or is +// asserted to be a directory but exists as another type of file. +// +// This function acts as a convenient wrapper around TarWithOptions, which +// requires a directory as the source path. TarResource accepts either a +// directory or a file path and correctly sets the Tar options. +func TarResource(sourcePath string) (content Archive, err error) { + if _, err = os.Lstat(sourcePath); err != nil { + // Catches the case where the source does not exist or is not a + // directory if asserted to be a directory, as this also causes an + // error. + return + } + + if len(sourcePath) > 1 && HasTrailingPathSeparator(sourcePath) { + // In the case where the source path is a symbolic link AND it ends + // with a path separator, we will want to evaluate the symbolic link. + trimmedPath := sourcePath[:len(sourcePath)-1] + stat, err := os.Lstat(trimmedPath) + if err != nil { + return nil, err + } + + if stat.Mode()&os.ModeSymlink != 0 { + if sourcePath, err = filepath.EvalSymlinks(trimmedPath); err != nil { + return nil, err + } + } + } + + // Separate the source path between it's directory and + // the entry in that directory which we are archiving. + sourceDir, sourceBase := SplitPathDirEntry(sourcePath) + + filter := []string{sourceBase} + + log.Debugf("copying %q from %q", sourceBase, sourceDir) + + return TarWithOptions(sourceDir, &TarOptions{ + Compression: Uncompressed, + IncludeFiles: filter, + IncludeSourceDir: true, + }) +} + +// CopyInfo holds basic info about the source +// or destination path of a copy operation. +type CopyInfo struct { + Path string + Exists bool + IsDir bool +} + +// CopyInfoStatPath stats the given path to create a CopyInfo +// struct representing that resource. If mustExist is true, then +// it is an error if there is no file or directory at the given path. +func CopyInfoStatPath(path string, mustExist bool) (CopyInfo, error) { + pathInfo := CopyInfo{Path: path} + + fileInfo, err := os.Lstat(path) + + if err == nil { + pathInfo.Exists, pathInfo.IsDir = true, fileInfo.IsDir() + } else if os.IsNotExist(err) && !mustExist { + err = nil + } + + return pathInfo, err +} + +// PrepareArchiveCopy prepares the given srcContent archive, which should +// contain the archived resource described by srcInfo, to the destination +// described by dstInfo. Returns the possibly modified content archive along +// with the path to the destination directory which it should be extracted to. +func PrepareArchiveCopy(srcContent ArchiveReader, srcInfo, dstInfo CopyInfo) (dstDir string, content Archive, err error) { + // Separate the destination path between its directory and base + // components in case the source archive contents need to be rebased. + dstDir, dstBase := SplitPathDirEntry(dstInfo.Path) + _, srcBase := SplitPathDirEntry(srcInfo.Path) + + switch { + case dstInfo.Exists && dstInfo.IsDir: + // The destination exists as a directory. No alteration + // to srcContent is needed as its contents can be + // simply extracted to the destination directory. + return dstInfo.Path, ioutil.NopCloser(srcContent), nil + case dstInfo.Exists && srcInfo.IsDir: + // The destination exists as some type of file and the source + // content is a directory. This is an error condition since + // you cannot copy a directory to an existing file location. + return "", nil, ErrCannotCopyDir + case dstInfo.Exists: + // The destination exists as some type of file and the source content + // is also a file. The source content entry will have to be renamed to + // have a basename which matches the destination path's basename. + return dstDir, rebaseArchiveEntries(srcContent, srcBase, dstBase), nil + case srcInfo.IsDir: + // The destination does not exist and the source content is an archive + // of a directory. The archive should be extracted to the parent of + // the destination path instead, and when it is, the directory that is + // created as a result should take the name of the destination path. + // The source content entries will have to be renamed to have a + // basename which matches the destination path's basename. + return dstDir, rebaseArchiveEntries(srcContent, srcBase, dstBase), nil + case AssertsDirectory(dstInfo.Path): + // The destination does not exist and is asserted to be created as a + // directory, but the source content is not a directory. This is an + // error condition since you cannot create a directory from a file + // source. + return "", nil, ErrDirNotExists + default: + // The last remaining case is when the destination does not exist, is + // not asserted to be a directory, and the source content is not an + // archive of a directory. It this case, the destination file will need + // to be created when the archive is extracted and the source content + // entry will have to be renamed to have a basename which matches the + // destination path's basename. + return dstDir, rebaseArchiveEntries(srcContent, srcBase, dstBase), nil + } + +} + +// rebaseArchiveEntries rewrites the given srcContent archive replacing +// an occurance of oldBase with newBase at the beginning of entry names. +func rebaseArchiveEntries(srcContent ArchiveReader, oldBase, newBase string) Archive { + rebased, w := io.Pipe() + + go func() { + srcTar := tar.NewReader(srcContent) + rebasedTar := tar.NewWriter(w) + + for { + hdr, err := srcTar.Next() + if err == io.EOF { + // Signals end of archive. + rebasedTar.Close() + w.Close() + return + } + if err != nil { + w.CloseWithError(err) + return + } + + hdr.Name = strings.Replace(hdr.Name, oldBase, newBase, 1) + + if err = rebasedTar.WriteHeader(hdr); err != nil { + w.CloseWithError(err) + return + } + + if _, err = io.Copy(rebasedTar, srcTar); err != nil { + w.CloseWithError(err) + return + } + } + }() + + return rebased +} + +// CopyResource performs an archive copy from the given source path to the +// given destination path. The source path MUST exist and the destination +// path's parent directory must exist. +func CopyResource(srcPath, dstPath string) error { + var ( + srcInfo CopyInfo + err error + ) + + // Clean the source and destination paths. + srcPath = PreserveTrailingDotOrSeparator(filepath.Clean(srcPath), srcPath) + dstPath = PreserveTrailingDotOrSeparator(filepath.Clean(dstPath), dstPath) + + if srcInfo, err = CopyInfoStatPath(srcPath, true); err != nil { + return err + } + + content, err := TarResource(srcPath) + if err != nil { + return err + } + defer content.Close() + + return CopyTo(content, srcInfo, dstPath) +} + +// CopyTo handles extracting the given content whose +// entries should be sourced from srcInfo to dstPath. +func CopyTo(content ArchiveReader, srcInfo CopyInfo, dstPath string) error { + dstInfo, err := CopyInfoStatPath(dstPath, false) + if err != nil { + return err + } + + if !dstInfo.Exists { + // Ensure destination parent dir exists. + dstParent, _ := SplitPathDirEntry(dstPath) + + dstStat, err := os.Lstat(dstParent) + if err != nil { + return err + } + if !dstStat.IsDir() { + return ErrNotDirectory + } + } + + dstDir, copyArchive, err := PrepareArchiveCopy(content, srcInfo, dstInfo) + if err != nil { + return err + } + defer copyArchive.Close() + + options := &TarOptions{ + NoLchown: true, + NoOverwriteDirNonDir: true, + } + + return Untar(copyArchive, dstDir, options) +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/copy_test.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/copy_test.go new file mode 100644 index 0000000..dd0b323 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/copy_test.go @@ -0,0 +1,637 @@ +package archive + +import ( + "bytes" + "crypto/sha256" + "encoding/hex" + "fmt" + "io" + "io/ioutil" + "os" + "path/filepath" + "strings" + "testing" +) + +func removeAllPaths(paths ...string) { + for _, path := range paths { + os.RemoveAll(path) + } +} + +func getTestTempDirs(t *testing.T) (tmpDirA, tmpDirB string) { + var err error + + if tmpDirA, err = ioutil.TempDir("", "archive-copy-test"); err != nil { + t.Fatal(err) + } + + if tmpDirB, err = ioutil.TempDir("", "archive-copy-test"); err != nil { + t.Fatal(err) + } + + return +} + +func isNotDir(err error) bool { + return strings.Contains(err.Error(), "not a directory") +} + +func joinTrailingSep(pathElements ...string) string { + joined := filepath.Join(pathElements...) + + return fmt.Sprintf("%s%c", joined, filepath.Separator) +} + +func fileContentsEqual(t *testing.T, filenameA, filenameB string) (err error) { + t.Logf("checking for equal file contents: %q and %q\n", filenameA, filenameB) + + fileA, err := os.Open(filenameA) + if err != nil { + return + } + defer fileA.Close() + + fileB, err := os.Open(filenameB) + if err != nil { + return + } + defer fileB.Close() + + hasher := sha256.New() + + if _, err = io.Copy(hasher, fileA); err != nil { + return + } + + hashA := hasher.Sum(nil) + hasher.Reset() + + if _, err = io.Copy(hasher, fileB); err != nil { + return + } + + hashB := hasher.Sum(nil) + + if !bytes.Equal(hashA, hashB) { + err = fmt.Errorf("file content hashes not equal - expected %s, got %s", hex.EncodeToString(hashA), hex.EncodeToString(hashB)) + } + + return +} + +func dirContentsEqual(t *testing.T, newDir, oldDir string) (err error) { + t.Logf("checking for equal directory contents: %q and %q\n", newDir, oldDir) + + var changes []Change + + if changes, err = ChangesDirs(newDir, oldDir); err != nil { + return + } + + if len(changes) != 0 { + err = fmt.Errorf("expected no changes between directories, but got: %v", changes) + } + + return +} + +func logDirContents(t *testing.T, dirPath string) { + logWalkedPaths := filepath.WalkFunc(func(path string, info os.FileInfo, err error) error { + if err != nil { + t.Errorf("stat error for path %q: %s", path, err) + return nil + } + + if info.IsDir() { + path = joinTrailingSep(path) + } + + t.Logf("\t%s", path) + + return nil + }) + + t.Logf("logging directory contents: %q", dirPath) + + if err := filepath.Walk(dirPath, logWalkedPaths); err != nil { + t.Fatal(err) + } +} + +func testCopyHelper(t *testing.T, srcPath, dstPath string) (err error) { + t.Logf("copying from %q to %q", srcPath, dstPath) + + return CopyResource(srcPath, dstPath) +} + +// Basic assumptions about SRC and DST: +// 1. SRC must exist. +// 2. If SRC ends with a trailing separator, it must be a directory. +// 3. DST parent directory must exist. +// 4. If DST exists as a file, it must not end with a trailing separator. + +// First get these easy error cases out of the way. + +// Test for error when SRC does not exist. +func TestCopyErrSrcNotExists(t *testing.T) { + tmpDirA, tmpDirB := getTestTempDirs(t) + defer removeAllPaths(tmpDirA, tmpDirB) + + content, err := TarResource(filepath.Join(tmpDirA, "file1")) + if err == nil { + content.Close() + t.Fatal("expected IsNotExist error, but got nil instead") + } + + if !os.IsNotExist(err) { + t.Fatalf("expected IsNotExist error, but got %T: %s", err, err) + } +} + +// Test for error when SRC ends in a trailing +// path separator but it exists as a file. +func TestCopyErrSrcNotDir(t *testing.T) { + tmpDirA, tmpDirB := getTestTempDirs(t) + defer removeAllPaths(tmpDirA, tmpDirB) + + // Load A with some sample files and directories. + createSampleDir(t, tmpDirA) + + content, err := TarResource(joinTrailingSep(tmpDirA, "file1")) + if err == nil { + content.Close() + t.Fatal("expected IsNotDir error, but got nil instead") + } + + if !isNotDir(err) { + t.Fatalf("expected IsNotDir error, but got %T: %s", err, err) + } +} + +// Test for error when SRC is a valid file or directory, +// but the DST parent directory does not exist. +func TestCopyErrDstParentNotExists(t *testing.T) { + tmpDirA, tmpDirB := getTestTempDirs(t) + defer removeAllPaths(tmpDirA, tmpDirB) + + // Load A with some sample files and directories. + createSampleDir(t, tmpDirA) + + srcInfo := CopyInfo{Path: filepath.Join(tmpDirA, "file1"), Exists: true, IsDir: false} + + // Try with a file source. + content, err := TarResource(srcInfo.Path) + if err != nil { + t.Fatalf("unexpected error %T: %s", err, err) + } + defer content.Close() + + // Copy to a file whose parent does not exist. + if err = CopyTo(content, srcInfo, filepath.Join(tmpDirB, "fakeParentDir", "file1")); err == nil { + t.Fatal("expected IsNotExist error, but got nil instead") + } + + if !os.IsNotExist(err) { + t.Fatalf("expected IsNotExist error, but got %T: %s", err, err) + } + + // Try with a directory source. + srcInfo = CopyInfo{Path: filepath.Join(tmpDirA, "dir1"), Exists: true, IsDir: true} + + content, err = TarResource(srcInfo.Path) + if err != nil { + t.Fatalf("unexpected error %T: %s", err, err) + } + defer content.Close() + + // Copy to a directory whose parent does not exist. + if err = CopyTo(content, srcInfo, joinTrailingSep(tmpDirB, "fakeParentDir", "fakeDstDir")); err == nil { + t.Fatal("expected IsNotExist error, but got nil instead") + } + + if !os.IsNotExist(err) { + t.Fatalf("expected IsNotExist error, but got %T: %s", err, err) + } +} + +// Test for error when DST ends in a trailing +// path separator but exists as a file. +func TestCopyErrDstNotDir(t *testing.T) { + tmpDirA, tmpDirB := getTestTempDirs(t) + defer removeAllPaths(tmpDirA, tmpDirB) + + // Load A and B with some sample files and directories. + createSampleDir(t, tmpDirA) + createSampleDir(t, tmpDirB) + + // Try with a file source. + srcInfo := CopyInfo{Path: filepath.Join(tmpDirA, "file1"), Exists: true, IsDir: false} + + content, err := TarResource(srcInfo.Path) + if err != nil { + t.Fatalf("unexpected error %T: %s", err, err) + } + defer content.Close() + + if err = CopyTo(content, srcInfo, joinTrailingSep(tmpDirB, "file1")); err == nil { + t.Fatal("expected IsNotDir error, but got nil instead") + } + + if !isNotDir(err) { + t.Fatalf("expected IsNotDir error, but got %T: %s", err, err) + } + + // Try with a directory source. + srcInfo = CopyInfo{Path: filepath.Join(tmpDirA, "dir1"), Exists: true, IsDir: true} + + content, err = TarResource(srcInfo.Path) + if err != nil { + t.Fatalf("unexpected error %T: %s", err, err) + } + defer content.Close() + + if err = CopyTo(content, srcInfo, joinTrailingSep(tmpDirB, "file1")); err == nil { + t.Fatal("expected IsNotDir error, but got nil instead") + } + + if !isNotDir(err) { + t.Fatalf("expected IsNotDir error, but got %T: %s", err, err) + } +} + +// Possibilities are reduced to the remaining 10 cases: +// +// case | srcIsDir | onlyDirContents | dstExists | dstIsDir | dstTrSep | action +// =================================================================================================== +// A | no | - | no | - | no | create file +// B | no | - | no | - | yes | error +// C | no | - | yes | no | - | overwrite file +// D | no | - | yes | yes | - | create file in dst dir +// E | yes | no | no | - | - | create dir, copy contents +// F | yes | no | yes | no | - | error +// G | yes | no | yes | yes | - | copy dir and contents +// H | yes | yes | no | - | - | create dir, copy contents +// I | yes | yes | yes | no | - | error +// J | yes | yes | yes | yes | - | copy dir contents +// + +// A. SRC specifies a file and DST (no trailing path separator) doesn't +// exist. This should create a file with the name DST and copy the +// contents of the source file into it. +func TestCopyCaseA(t *testing.T) { + tmpDirA, tmpDirB := getTestTempDirs(t) + defer removeAllPaths(tmpDirA, tmpDirB) + + // Load A with some sample files and directories. + createSampleDir(t, tmpDirA) + + srcPath := filepath.Join(tmpDirA, "file1") + dstPath := filepath.Join(tmpDirB, "itWorks.txt") + + var err error + + if err = testCopyHelper(t, srcPath, dstPath); err != nil { + t.Fatalf("unexpected error %T: %s", err, err) + } + + if err = fileContentsEqual(t, srcPath, dstPath); err != nil { + t.Fatal(err) + } +} + +// B. SRC specifies a file and DST (with trailing path separator) doesn't +// exist. This should cause an error because the copy operation cannot +// create a directory when copying a single file. +func TestCopyCaseB(t *testing.T) { + tmpDirA, tmpDirB := getTestTempDirs(t) + defer removeAllPaths(tmpDirA, tmpDirB) + + // Load A with some sample files and directories. + createSampleDir(t, tmpDirA) + + srcPath := filepath.Join(tmpDirA, "file1") + dstDir := joinTrailingSep(tmpDirB, "testDir") + + var err error + + if err = testCopyHelper(t, srcPath, dstDir); err == nil { + t.Fatal("expected ErrDirNotExists error, but got nil instead") + } + + if err != ErrDirNotExists { + t.Fatalf("expected ErrDirNotExists error, but got %T: %s", err, err) + } +} + +// C. SRC specifies a file and DST exists as a file. This should overwrite +// the file at DST with the contents of the source file. +func TestCopyCaseC(t *testing.T) { + tmpDirA, tmpDirB := getTestTempDirs(t) + defer removeAllPaths(tmpDirA, tmpDirB) + + // Load A and B with some sample files and directories. + createSampleDir(t, tmpDirA) + createSampleDir(t, tmpDirB) + + srcPath := filepath.Join(tmpDirA, "file1") + dstPath := filepath.Join(tmpDirB, "file2") + + var err error + + // Ensure they start out different. + if err = fileContentsEqual(t, srcPath, dstPath); err == nil { + t.Fatal("expected different file contents") + } + + if err = testCopyHelper(t, srcPath, dstPath); err != nil { + t.Fatalf("unexpected error %T: %s", err, err) + } + + if err = fileContentsEqual(t, srcPath, dstPath); err != nil { + t.Fatal(err) + } +} + +// D. SRC specifies a file and DST exists as a directory. This should place +// a copy of the source file inside it using the basename from SRC. Ensure +// this works whether DST has a trailing path separator or not. +func TestCopyCaseD(t *testing.T) { + tmpDirA, tmpDirB := getTestTempDirs(t) + defer removeAllPaths(tmpDirA, tmpDirB) + + // Load A and B with some sample files and directories. + createSampleDir(t, tmpDirA) + createSampleDir(t, tmpDirB) + + srcPath := filepath.Join(tmpDirA, "file1") + dstDir := filepath.Join(tmpDirB, "dir1") + dstPath := filepath.Join(dstDir, "file1") + + var err error + + // Ensure that dstPath doesn't exist. + if _, err = os.Stat(dstPath); !os.IsNotExist(err) { + t.Fatalf("did not expect dstPath %q to exist", dstPath) + } + + if err = testCopyHelper(t, srcPath, dstDir); err != nil { + t.Fatalf("unexpected error %T: %s", err, err) + } + + if err = fileContentsEqual(t, srcPath, dstPath); err != nil { + t.Fatal(err) + } + + // Now try again but using a trailing path separator for dstDir. + + if err = os.RemoveAll(dstDir); err != nil { + t.Fatalf("unable to remove dstDir: %s", err) + } + + if err = os.MkdirAll(dstDir, os.FileMode(0755)); err != nil { + t.Fatalf("unable to make dstDir: %s", err) + } + + dstDir = joinTrailingSep(tmpDirB, "dir1") + + if err = testCopyHelper(t, srcPath, dstDir); err != nil { + t.Fatalf("unexpected error %T: %s", err, err) + } + + if err = fileContentsEqual(t, srcPath, dstPath); err != nil { + t.Fatal(err) + } +} + +// E. SRC specifies a directory and DST does not exist. This should create a +// directory at DST and copy the contents of the SRC directory into the DST +// directory. Ensure this works whether DST has a trailing path separator or +// not. +func TestCopyCaseE(t *testing.T) { + tmpDirA, tmpDirB := getTestTempDirs(t) + defer removeAllPaths(tmpDirA, tmpDirB) + + // Load A with some sample files and directories. + createSampleDir(t, tmpDirA) + + srcDir := filepath.Join(tmpDirA, "dir1") + dstDir := filepath.Join(tmpDirB, "testDir") + + var err error + + if err = testCopyHelper(t, srcDir, dstDir); err != nil { + t.Fatalf("unexpected error %T: %s", err, err) + } + + if err = dirContentsEqual(t, dstDir, srcDir); err != nil { + t.Log("dir contents not equal") + logDirContents(t, tmpDirA) + logDirContents(t, tmpDirB) + t.Fatal(err) + } + + // Now try again but using a trailing path separator for dstDir. + + if err = os.RemoveAll(dstDir); err != nil { + t.Fatalf("unable to remove dstDir: %s", err) + } + + dstDir = joinTrailingSep(tmpDirB, "testDir") + + if err = testCopyHelper(t, srcDir, dstDir); err != nil { + t.Fatalf("unexpected error %T: %s", err, err) + } + + if err = dirContentsEqual(t, dstDir, srcDir); err != nil { + t.Fatal(err) + } +} + +// F. SRC specifies a directory and DST exists as a file. This should cause an +// error as it is not possible to overwrite a file with a directory. +func TestCopyCaseF(t *testing.T) { + tmpDirA, tmpDirB := getTestTempDirs(t) + defer removeAllPaths(tmpDirA, tmpDirB) + + // Load A and B with some sample files and directories. + createSampleDir(t, tmpDirA) + createSampleDir(t, tmpDirB) + + srcDir := filepath.Join(tmpDirA, "dir1") + dstFile := filepath.Join(tmpDirB, "file1") + + var err error + + if err = testCopyHelper(t, srcDir, dstFile); err == nil { + t.Fatal("expected ErrCannotCopyDir error, but got nil instead") + } + + if err != ErrCannotCopyDir { + t.Fatalf("expected ErrCannotCopyDir error, but got %T: %s", err, err) + } +} + +// G. SRC specifies a directory and DST exists as a directory. This should copy +// the SRC directory and all its contents to the DST directory. Ensure this +// works whether DST has a trailing path separator or not. +func TestCopyCaseG(t *testing.T) { + tmpDirA, tmpDirB := getTestTempDirs(t) + defer removeAllPaths(tmpDirA, tmpDirB) + + // Load A and B with some sample files and directories. + createSampleDir(t, tmpDirA) + createSampleDir(t, tmpDirB) + + srcDir := filepath.Join(tmpDirA, "dir1") + dstDir := filepath.Join(tmpDirB, "dir2") + resultDir := filepath.Join(dstDir, "dir1") + + var err error + + if err = testCopyHelper(t, srcDir, dstDir); err != nil { + t.Fatalf("unexpected error %T: %s", err, err) + } + + if err = dirContentsEqual(t, resultDir, srcDir); err != nil { + t.Fatal(err) + } + + // Now try again but using a trailing path separator for dstDir. + + if err = os.RemoveAll(dstDir); err != nil { + t.Fatalf("unable to remove dstDir: %s", err) + } + + if err = os.MkdirAll(dstDir, os.FileMode(0755)); err != nil { + t.Fatalf("unable to make dstDir: %s", err) + } + + dstDir = joinTrailingSep(tmpDirB, "dir2") + + if err = testCopyHelper(t, srcDir, dstDir); err != nil { + t.Fatalf("unexpected error %T: %s", err, err) + } + + if err = dirContentsEqual(t, resultDir, srcDir); err != nil { + t.Fatal(err) + } +} + +// H. SRC specifies a directory's contents only and DST does not exist. This +// should create a directory at DST and copy the contents of the SRC +// directory (but not the directory itself) into the DST directory. Ensure +// this works whether DST has a trailing path separator or not. +func TestCopyCaseH(t *testing.T) { + tmpDirA, tmpDirB := getTestTempDirs(t) + defer removeAllPaths(tmpDirA, tmpDirB) + + // Load A with some sample files and directories. + createSampleDir(t, tmpDirA) + + srcDir := joinTrailingSep(tmpDirA, "dir1") + "." + dstDir := filepath.Join(tmpDirB, "testDir") + + var err error + + if err = testCopyHelper(t, srcDir, dstDir); err != nil { + t.Fatalf("unexpected error %T: %s", err, err) + } + + if err = dirContentsEqual(t, dstDir, srcDir); err != nil { + t.Log("dir contents not equal") + logDirContents(t, tmpDirA) + logDirContents(t, tmpDirB) + t.Fatal(err) + } + + // Now try again but using a trailing path separator for dstDir. + + if err = os.RemoveAll(dstDir); err != nil { + t.Fatalf("unable to remove dstDir: %s", err) + } + + dstDir = joinTrailingSep(tmpDirB, "testDir") + + if err = testCopyHelper(t, srcDir, dstDir); err != nil { + t.Fatalf("unexpected error %T: %s", err, err) + } + + if err = dirContentsEqual(t, dstDir, srcDir); err != nil { + t.Log("dir contents not equal") + logDirContents(t, tmpDirA) + logDirContents(t, tmpDirB) + t.Fatal(err) + } +} + +// I. SRC specifies a directory's contents only and DST exists as a file. This +// should cause an error as it is not possible to overwrite a file with a +// directory. +func TestCopyCaseI(t *testing.T) { + tmpDirA, tmpDirB := getTestTempDirs(t) + defer removeAllPaths(tmpDirA, tmpDirB) + + // Load A and B with some sample files and directories. + createSampleDir(t, tmpDirA) + createSampleDir(t, tmpDirB) + + srcDir := joinTrailingSep(tmpDirA, "dir1") + "." + dstFile := filepath.Join(tmpDirB, "file1") + + var err error + + if err = testCopyHelper(t, srcDir, dstFile); err == nil { + t.Fatal("expected ErrCannotCopyDir error, but got nil instead") + } + + if err != ErrCannotCopyDir { + t.Fatalf("expected ErrCannotCopyDir error, but got %T: %s", err, err) + } +} + +// J. SRC specifies a directory's contents only and DST exists as a directory. +// This should copy the contents of the SRC directory (but not the directory +// itself) into the DST directory. Ensure this works whether DST has a +// trailing path separator or not. +func TestCopyCaseJ(t *testing.T) { + tmpDirA, tmpDirB := getTestTempDirs(t) + defer removeAllPaths(tmpDirA, tmpDirB) + + // Load A and B with some sample files and directories. + createSampleDir(t, tmpDirA) + createSampleDir(t, tmpDirB) + + srcDir := joinTrailingSep(tmpDirA, "dir1") + "." + dstDir := filepath.Join(tmpDirB, "dir5") + + var err error + + if err = testCopyHelper(t, srcDir, dstDir); err != nil { + t.Fatalf("unexpected error %T: %s", err, err) + } + + if err = dirContentsEqual(t, dstDir, srcDir); err != nil { + t.Fatal(err) + } + + // Now try again but using a trailing path separator for dstDir. + + if err = os.RemoveAll(dstDir); err != nil { + t.Fatalf("unable to remove dstDir: %s", err) + } + + if err = os.MkdirAll(dstDir, os.FileMode(0755)); err != nil { + t.Fatalf("unable to make dstDir: %s", err) + } + + dstDir = joinTrailingSep(tmpDirB, "dir5") + + if err = testCopyHelper(t, srcDir, dstDir); err != nil { + t.Fatalf("unexpected error %T: %s", err, err) + } + + if err = dirContentsEqual(t, dstDir, srcDir); err != nil { + t.Fatal(err) + } +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/diff.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/diff.go new file mode 100644 index 0000000..aed8542 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/diff.go @@ -0,0 +1,194 @@ +package archive + +import ( + "archive/tar" + "fmt" + "io" + "io/ioutil" + "os" + "path/filepath" + "runtime" + "strings" + "syscall" + + "github.com/Sirupsen/logrus" + "github.com/docker/docker/pkg/pools" + "github.com/docker/docker/pkg/system" +) + +func UnpackLayer(dest string, layer ArchiveReader) (size int64, err error) { + tr := tar.NewReader(layer) + trBuf := pools.BufioReader32KPool.Get(tr) + defer pools.BufioReader32KPool.Put(trBuf) + + var dirs []*tar.Header + + aufsTempdir := "" + aufsHardlinks := make(map[string]*tar.Header) + + // Iterate through the files in the archive. + for { + hdr, err := tr.Next() + if err == io.EOF { + // end of tar archive + break + } + if err != nil { + return 0, err + } + + size += hdr.Size + + // Normalize name, for safety and for a simple is-root check + hdr.Name = filepath.Clean(hdr.Name) + + // Windows does not support filenames with colons in them. Ignore + // these files. This is not a problem though (although it might + // appear that it is). Let's suppose a client is running docker pull. + // The daemon it points to is Windows. Would it make sense for the + // client to be doing a docker pull Ubuntu for example (which has files + // with colons in the name under /usr/share/man/man3)? No, absolutely + // not as it would really only make sense that they were pulling a + // Windows image. However, for development, it is necessary to be able + // to pull Linux images which are in the repository. + // + // TODO Windows. Once the registry is aware of what images are Windows- + // specific or Linux-specific, this warning should be changed to an error + // to cater for the situation where someone does manage to upload a Linux + // image but have it tagged as Windows inadvertantly. + if runtime.GOOS == "windows" { + if strings.Contains(hdr.Name, ":") { + logrus.Warnf("Windows: Ignoring %s (is this a Linux image?)", hdr.Name) + continue + } + } + + // Note as these operations are platform specific, so must the slash be. + if !strings.HasSuffix(hdr.Name, string(os.PathSeparator)) { + // Not the root directory, ensure that the parent directory exists. + // This happened in some tests where an image had a tarfile without any + // parent directories. + parent := filepath.Dir(hdr.Name) + parentPath := filepath.Join(dest, parent) + + if _, err := os.Lstat(parentPath); err != nil && os.IsNotExist(err) { + err = system.MkdirAll(parentPath, 0600) + if err != nil { + return 0, err + } + } + } + + // Skip AUFS metadata dirs + if strings.HasPrefix(hdr.Name, ".wh..wh.") { + // Regular files inside /.wh..wh.plnk can be used as hardlink targets + // We don't want this directory, but we need the files in them so that + // such hardlinks can be resolved. + if strings.HasPrefix(hdr.Name, ".wh..wh.plnk") && hdr.Typeflag == tar.TypeReg { + basename := filepath.Base(hdr.Name) + aufsHardlinks[basename] = hdr + if aufsTempdir == "" { + if aufsTempdir, err = ioutil.TempDir("", "dockerplnk"); err != nil { + return 0, err + } + defer os.RemoveAll(aufsTempdir) + } + if err := createTarFile(filepath.Join(aufsTempdir, basename), dest, hdr, tr, true, nil); err != nil { + return 0, err + } + } + continue + } + path := filepath.Join(dest, hdr.Name) + rel, err := filepath.Rel(dest, path) + if err != nil { + return 0, err + } + + // Note as these operations are platform specific, so must the slash be. + if strings.HasPrefix(rel, ".."+string(os.PathSeparator)) { + return 0, breakoutError(fmt.Errorf("%q is outside of %q", hdr.Name, dest)) + } + base := filepath.Base(path) + + if strings.HasPrefix(base, ".wh.") { + originalBase := base[len(".wh."):] + originalPath := filepath.Join(filepath.Dir(path), originalBase) + if err := os.RemoveAll(originalPath); err != nil { + return 0, err + } + } else { + // If path exits we almost always just want to remove and replace it. + // The only exception is when it is a directory *and* the file from + // the layer is also a directory. Then we want to merge them (i.e. + // just apply the metadata from the layer). + if fi, err := os.Lstat(path); err == nil { + if !(fi.IsDir() && hdr.Typeflag == tar.TypeDir) { + if err := os.RemoveAll(path); err != nil { + return 0, err + } + } + } + + trBuf.Reset(tr) + srcData := io.Reader(trBuf) + srcHdr := hdr + + // Hard links into /.wh..wh.plnk don't work, as we don't extract that directory, so + // we manually retarget these into the temporary files we extracted them into + if hdr.Typeflag == tar.TypeLink && strings.HasPrefix(filepath.Clean(hdr.Linkname), ".wh..wh.plnk") { + linkBasename := filepath.Base(hdr.Linkname) + srcHdr = aufsHardlinks[linkBasename] + if srcHdr == nil { + return 0, fmt.Errorf("Invalid aufs hardlink") + } + tmpFile, err := os.Open(filepath.Join(aufsTempdir, linkBasename)) + if err != nil { + return 0, err + } + defer tmpFile.Close() + srcData = tmpFile + } + + if err := createTarFile(path, dest, srcHdr, srcData, true, nil); err != nil { + return 0, err + } + + // Directory mtimes must be handled at the end to avoid further + // file creation in them to modify the directory mtime + if hdr.Typeflag == tar.TypeDir { + dirs = append(dirs, hdr) + } + } + } + + for _, hdr := range dirs { + path := filepath.Join(dest, hdr.Name) + ts := []syscall.Timespec{timeToTimespec(hdr.AccessTime), timeToTimespec(hdr.ModTime)} + if err := syscall.UtimesNano(path, ts); err != nil { + return 0, err + } + } + + return size, nil +} + +// ApplyLayer parses a diff in the standard layer format from `layer`, and +// applies it to the directory `dest`. Returns the size in bytes of the +// contents of the layer. +func ApplyLayer(dest string, layer ArchiveReader) (int64, error) { + dest = filepath.Clean(dest) + + // We need to be able to set any perms + oldmask, err := system.Umask(0) + if err != nil { + return 0, err + } + defer system.Umask(oldmask) // ignore err, ErrNotSupportedPlatform + + layer, err = DecompressStream(layer) + if err != nil { + return 0, err + } + return UnpackLayer(dest, layer) +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/diff_test.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/diff_test.go new file mode 100644 index 0000000..01ed437 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/diff_test.go @@ -0,0 +1,190 @@ +package archive + +import ( + "archive/tar" + "testing" +) + +func TestApplyLayerInvalidFilenames(t *testing.T) { + for i, headers := range [][]*tar.Header{ + { + { + Name: "../victim/dotdot", + Typeflag: tar.TypeReg, + Mode: 0644, + }, + }, + { + { + // Note the leading slash + Name: "/../victim/slash-dotdot", + Typeflag: tar.TypeReg, + Mode: 0644, + }, + }, + } { + if err := testBreakout("applylayer", "docker-TestApplyLayerInvalidFilenames", headers); err != nil { + t.Fatalf("i=%d. %v", i, err) + } + } +} + +func TestApplyLayerInvalidHardlink(t *testing.T) { + for i, headers := range [][]*tar.Header{ + { // try reading victim/hello (../) + { + Name: "dotdot", + Typeflag: tar.TypeLink, + Linkname: "../victim/hello", + Mode: 0644, + }, + }, + { // try reading victim/hello (/../) + { + Name: "slash-dotdot", + Typeflag: tar.TypeLink, + // Note the leading slash + Linkname: "/../victim/hello", + Mode: 0644, + }, + }, + { // try writing victim/file + { + Name: "loophole-victim", + Typeflag: tar.TypeLink, + Linkname: "../victim", + Mode: 0755, + }, + { + Name: "loophole-victim/file", + Typeflag: tar.TypeReg, + Mode: 0644, + }, + }, + { // try reading victim/hello (hardlink, symlink) + { + Name: "loophole-victim", + Typeflag: tar.TypeLink, + Linkname: "../victim", + Mode: 0755, + }, + { + Name: "symlink", + Typeflag: tar.TypeSymlink, + Linkname: "loophole-victim/hello", + Mode: 0644, + }, + }, + { // Try reading victim/hello (hardlink, hardlink) + { + Name: "loophole-victim", + Typeflag: tar.TypeLink, + Linkname: "../victim", + Mode: 0755, + }, + { + Name: "hardlink", + Typeflag: tar.TypeLink, + Linkname: "loophole-victim/hello", + Mode: 0644, + }, + }, + { // Try removing victim directory (hardlink) + { + Name: "loophole-victim", + Typeflag: tar.TypeLink, + Linkname: "../victim", + Mode: 0755, + }, + { + Name: "loophole-victim", + Typeflag: tar.TypeReg, + Mode: 0644, + }, + }, + } { + if err := testBreakout("applylayer", "docker-TestApplyLayerInvalidHardlink", headers); err != nil { + t.Fatalf("i=%d. %v", i, err) + } + } +} + +func TestApplyLayerInvalidSymlink(t *testing.T) { + for i, headers := range [][]*tar.Header{ + { // try reading victim/hello (../) + { + Name: "dotdot", + Typeflag: tar.TypeSymlink, + Linkname: "../victim/hello", + Mode: 0644, + }, + }, + { // try reading victim/hello (/../) + { + Name: "slash-dotdot", + Typeflag: tar.TypeSymlink, + // Note the leading slash + Linkname: "/../victim/hello", + Mode: 0644, + }, + }, + { // try writing victim/file + { + Name: "loophole-victim", + Typeflag: tar.TypeSymlink, + Linkname: "../victim", + Mode: 0755, + }, + { + Name: "loophole-victim/file", + Typeflag: tar.TypeReg, + Mode: 0644, + }, + }, + { // try reading victim/hello (symlink, symlink) + { + Name: "loophole-victim", + Typeflag: tar.TypeSymlink, + Linkname: "../victim", + Mode: 0755, + }, + { + Name: "symlink", + Typeflag: tar.TypeSymlink, + Linkname: "loophole-victim/hello", + Mode: 0644, + }, + }, + { // try reading victim/hello (symlink, hardlink) + { + Name: "loophole-victim", + Typeflag: tar.TypeSymlink, + Linkname: "../victim", + Mode: 0755, + }, + { + Name: "hardlink", + Typeflag: tar.TypeLink, + Linkname: "loophole-victim/hello", + Mode: 0644, + }, + }, + { // try removing victim directory (symlink) + { + Name: "loophole-victim", + Typeflag: tar.TypeSymlink, + Linkname: "../victim", + Mode: 0755, + }, + { + Name: "loophole-victim", + Typeflag: tar.TypeReg, + Mode: 0644, + }, + }, + } { + if err := testBreakout("applylayer", "docker-TestApplyLayerInvalidSymlink", headers); err != nil { + t.Fatalf("i=%d. %v", i, err) + } + } +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/example_changes.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/example_changes.go new file mode 100644 index 0000000..cedd46a --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/example_changes.go @@ -0,0 +1,97 @@ +// +build ignore + +// Simple tool to create an archive stream from an old and new directory +// +// By default it will stream the comparison of two temporary directories with junk files +package main + +import ( + "flag" + "fmt" + "io" + "io/ioutil" + "os" + "path" + + "github.com/Sirupsen/logrus" + "github.com/docker/docker/pkg/archive" +) + +var ( + flDebug = flag.Bool("D", false, "debugging output") + flNewDir = flag.String("newdir", "", "") + flOldDir = flag.String("olddir", "", "") + log = logrus.New() +) + +func main() { + flag.Usage = func() { + fmt.Println("Produce a tar from comparing two directory paths. By default a demo tar is created of around 200 files (including hardlinks)") + fmt.Printf("%s [OPTIONS]\n", os.Args[0]) + flag.PrintDefaults() + } + flag.Parse() + log.Out = os.Stderr + if (len(os.Getenv("DEBUG")) > 0) || *flDebug { + logrus.SetLevel(logrus.DebugLevel) + } + var newDir, oldDir string + + if len(*flNewDir) == 0 { + var err error + newDir, err = ioutil.TempDir("", "docker-test-newDir") + if err != nil { + log.Fatal(err) + } + defer os.RemoveAll(newDir) + if _, err := prepareUntarSourceDirectory(100, newDir, true); err != nil { + log.Fatal(err) + } + } else { + newDir = *flNewDir + } + + if len(*flOldDir) == 0 { + oldDir, err := ioutil.TempDir("", "docker-test-oldDir") + if err != nil { + log.Fatal(err) + } + defer os.RemoveAll(oldDir) + } else { + oldDir = *flOldDir + } + + changes, err := archive.ChangesDirs(newDir, oldDir) + if err != nil { + log.Fatal(err) + } + + a, err := archive.ExportChanges(newDir, changes) + if err != nil { + log.Fatal(err) + } + defer a.Close() + + i, err := io.Copy(os.Stdout, a) + if err != nil && err != io.EOF { + log.Fatal(err) + } + fmt.Fprintf(os.Stderr, "wrote archive of %d bytes", i) +} + +func prepareUntarSourceDirectory(numberOfFiles int, targetPath string, makeLinks bool) (int, error) { + fileData := []byte("fooo") + for n := 0; n < numberOfFiles; n++ { + fileName := fmt.Sprintf("file-%d", n) + if err := ioutil.WriteFile(path.Join(targetPath, fileName), fileData, 0700); err != nil { + return 0, err + } + if makeLinks { + if err := os.Link(path.Join(targetPath, fileName), path.Join(targetPath, fileName+"-link")); err != nil { + return 0, err + } + } + } + totalSize := numberOfFiles * len(fileData) + return totalSize, nil +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/testdata/broken.tar b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/testdata/broken.tar new file mode 100644 index 0000000..8f10ea6 Binary files /dev/null and b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/testdata/broken.tar differ diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/time_linux.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/time_linux.go new file mode 100644 index 0000000..3448569 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/time_linux.go @@ -0,0 +1,16 @@ +package archive + +import ( + "syscall" + "time" +) + +func timeToTimespec(time time.Time) (ts syscall.Timespec) { + if time.IsZero() { + // Return UTIME_OMIT special value + ts.Sec = 0 + ts.Nsec = ((1 << 30) - 2) + return + } + return syscall.NsecToTimespec(time.UnixNano()) +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/time_unsupported.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/time_unsupported.go new file mode 100644 index 0000000..e85aac0 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/time_unsupported.go @@ -0,0 +1,16 @@ +// +build !linux + +package archive + +import ( + "syscall" + "time" +) + +func timeToTimespec(time time.Time) (ts syscall.Timespec) { + nsec := int64(0) + if !time.IsZero() { + nsec = time.UnixNano() + } + return syscall.NsecToTimespec(nsec) +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/utils_test.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/utils_test.go new file mode 100644 index 0000000..f5cacea --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/utils_test.go @@ -0,0 +1,166 @@ +package archive + +import ( + "archive/tar" + "bytes" + "fmt" + "io" + "io/ioutil" + "os" + "path/filepath" + "time" +) + +var testUntarFns = map[string]func(string, io.Reader) error{ + "untar": func(dest string, r io.Reader) error { + return Untar(r, dest, nil) + }, + "applylayer": func(dest string, r io.Reader) error { + _, err := ApplyLayer(dest, ArchiveReader(r)) + return err + }, +} + +// testBreakout is a helper function that, within the provided `tmpdir` directory, +// creates a `victim` folder with a generated `hello` file in it. +// `untar` extracts to a directory named `dest`, the tar file created from `headers`. +// +// Here are the tested scenarios: +// - removed `victim` folder (write) +// - removed files from `victim` folder (write) +// - new files in `victim` folder (write) +// - modified files in `victim` folder (write) +// - file in `dest` with same content as `victim/hello` (read) +// +// When using testBreakout make sure you cover one of the scenarios listed above. +func testBreakout(untarFn string, tmpdir string, headers []*tar.Header) error { + tmpdir, err := ioutil.TempDir("", tmpdir) + if err != nil { + return err + } + defer os.RemoveAll(tmpdir) + + dest := filepath.Join(tmpdir, "dest") + if err := os.Mkdir(dest, 0755); err != nil { + return err + } + + victim := filepath.Join(tmpdir, "victim") + if err := os.Mkdir(victim, 0755); err != nil { + return err + } + hello := filepath.Join(victim, "hello") + helloData, err := time.Now().MarshalText() + if err != nil { + return err + } + if err := ioutil.WriteFile(hello, helloData, 0644); err != nil { + return err + } + helloStat, err := os.Stat(hello) + if err != nil { + return err + } + + reader, writer := io.Pipe() + go func() { + t := tar.NewWriter(writer) + for _, hdr := range headers { + t.WriteHeader(hdr) + } + t.Close() + }() + + untar := testUntarFns[untarFn] + if untar == nil { + return fmt.Errorf("could not find untar function %q in testUntarFns", untarFn) + } + if err := untar(dest, reader); err != nil { + if _, ok := err.(breakoutError); !ok { + // If untar returns an error unrelated to an archive breakout, + // then consider this an unexpected error and abort. + return err + } + // Here, untar detected the breakout. + // Let's move on verifying that indeed there was no breakout. + fmt.Printf("breakoutError: %v\n", err) + } + + // Check victim folder + f, err := os.Open(victim) + if err != nil { + // codepath taken if victim folder was removed + return fmt.Errorf("archive breakout: error reading %q: %v", victim, err) + } + defer f.Close() + + // Check contents of victim folder + // + // We are only interested in getting 2 files from the victim folder, because if all is well + // we expect only one result, the `hello` file. If there is a second result, it cannot + // hold the same name `hello` and we assume that a new file got created in the victim folder. + // That is enough to detect an archive breakout. + names, err := f.Readdirnames(2) + if err != nil { + // codepath taken if victim is not a folder + return fmt.Errorf("archive breakout: error reading directory content of %q: %v", victim, err) + } + for _, name := range names { + if name != "hello" { + // codepath taken if new file was created in victim folder + return fmt.Errorf("archive breakout: new file %q", name) + } + } + + // Check victim/hello + f, err = os.Open(hello) + if err != nil { + // codepath taken if read permissions were removed + return fmt.Errorf("archive breakout: could not lstat %q: %v", hello, err) + } + defer f.Close() + b, err := ioutil.ReadAll(f) + if err != nil { + return err + } + fi, err := f.Stat() + if err != nil { + return err + } + if helloStat.IsDir() != fi.IsDir() || + // TODO: cannot check for fi.ModTime() change + helloStat.Mode() != fi.Mode() || + helloStat.Size() != fi.Size() || + !bytes.Equal(helloData, b) { + // codepath taken if hello has been modified + return fmt.Errorf("archive breakout: file %q has been modified. Contents: expected=%q, got=%q. FileInfo: expected=%#v, got=%#v", hello, helloData, b, helloStat, fi) + } + + // Check that nothing in dest/ has the same content as victim/hello. + // Since victim/hello was generated with time.Now(), it is safe to assume + // that any file whose content matches exactly victim/hello, managed somehow + // to access victim/hello. + return filepath.Walk(dest, func(path string, info os.FileInfo, err error) error { + if info.IsDir() { + if err != nil { + // skip directory if error + return filepath.SkipDir + } + // enter directory + return nil + } + if err != nil { + // skip file if error + return nil + } + b, err := ioutil.ReadFile(path) + if err != nil { + // Houston, we have a problem. Aborting (space)walk. + return err + } + if bytes.Equal(helloData, b) { + return fmt.Errorf("archive breakout: file %q has been accessed via %q", hello, path) + } + return nil + }) +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/wrap.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/wrap.go new file mode 100644 index 0000000..dfb335c --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/wrap.go @@ -0,0 +1,59 @@ +package archive + +import ( + "archive/tar" + "bytes" + "io/ioutil" +) + +// Generate generates a new archive from the content provided +// as input. +// +// `files` is a sequence of path/content pairs. A new file is +// added to the archive for each pair. +// If the last pair is incomplete, the file is created with an +// empty content. For example: +// +// Generate("foo.txt", "hello world", "emptyfile") +// +// The above call will return an archive with 2 files: +// * ./foo.txt with content "hello world" +// * ./empty with empty content +// +// FIXME: stream content instead of buffering +// FIXME: specify permissions and other archive metadata +func Generate(input ...string) (Archive, error) { + files := parseStringPairs(input...) + buf := new(bytes.Buffer) + tw := tar.NewWriter(buf) + for _, file := range files { + name, content := file[0], file[1] + hdr := &tar.Header{ + Name: name, + Size: int64(len(content)), + } + if err := tw.WriteHeader(hdr); err != nil { + return nil, err + } + if _, err := tw.Write([]byte(content)); err != nil { + return nil, err + } + } + if err := tw.Close(); err != nil { + return nil, err + } + return ioutil.NopCloser(buf), nil +} + +func parseStringPairs(input ...string) (output [][2]string) { + output = make([][2]string, 0, len(input)/2+1) + for i := 0; i < len(input); i += 2 { + var pair [2]string + pair[0] = input[i] + if i+1 < len(input) { + pair[1] = input[i+1] + } + output = append(output, pair) + } + return +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/wrap_test.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/wrap_test.go new file mode 100644 index 0000000..46ab366 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/archive/wrap_test.go @@ -0,0 +1,98 @@ +package archive + +import ( + "archive/tar" + "bytes" + "io" + "testing" +) + +func TestGenerateEmptyFile(t *testing.T) { + archive, err := Generate("emptyFile") + if err != nil { + t.Fatal(err) + } + if archive == nil { + t.Fatal("The generated archive should not be nil.") + } + + expectedFiles := [][]string{ + {"emptyFile", ""}, + } + + tr := tar.NewReader(archive) + actualFiles := make([][]string, 0, 10) + i := 0 + for { + hdr, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + t.Fatal(err) + } + buf := new(bytes.Buffer) + buf.ReadFrom(tr) + content := buf.String() + actualFiles = append(actualFiles, []string{hdr.Name, content}) + i++ + } + if len(actualFiles) != len(expectedFiles) { + t.Fatalf("Number of expected file %d, got %d.", len(expectedFiles), len(actualFiles)) + } + for i := 0; i < len(expectedFiles); i++ { + actual := actualFiles[i] + expected := expectedFiles[i] + if actual[0] != expected[0] { + t.Fatalf("Expected name '%s', Actual name '%s'", expected[0], actual[0]) + } + if actual[1] != expected[1] { + t.Fatalf("Expected content '%s', Actual content '%s'", expected[1], actual[1]) + } + } +} + +func TestGenerateWithContent(t *testing.T) { + archive, err := Generate("file", "content") + if err != nil { + t.Fatal(err) + } + if archive == nil { + t.Fatal("The generated archive should not be nil.") + } + + expectedFiles := [][]string{ + {"file", "content"}, + } + + tr := tar.NewReader(archive) + actualFiles := make([][]string, 0, 10) + i := 0 + for { + hdr, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + t.Fatal(err) + } + buf := new(bytes.Buffer) + buf.ReadFrom(tr) + content := buf.String() + actualFiles = append(actualFiles, []string{hdr.Name, content}) + i++ + } + if len(actualFiles) != len(expectedFiles) { + t.Fatalf("Number of expected file %d, got %d.", len(expectedFiles), len(actualFiles)) + } + for i := 0; i < len(expectedFiles); i++ { + actual := actualFiles[i] + expected := expectedFiles[i] + if actual[0] != expected[0] { + t.Fatalf("Expected name '%s', Actual name '%s'", expected[0], actual[0]) + } + if actual[1] != expected[1] { + t.Fatalf("Expected content '%s', Actual content '%s'", expected[1], actual[1]) + } + } +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/fileutils/fileutils.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/fileutils/fileutils.go new file mode 100644 index 0000000..3eaf7f8 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/fileutils/fileutils.go @@ -0,0 +1,196 @@ +package fileutils + +import ( + "errors" + "fmt" + "io" + "io/ioutil" + "os" + "path/filepath" + "strings" + + "github.com/Sirupsen/logrus" +) + +// exclusion return true if the specified pattern is an exclusion +func exclusion(pattern string) bool { + return pattern[0] == '!' +} + +// empty return true if the specified pattern is empty +func empty(pattern string) bool { + return pattern == "" +} + +// CleanPatterns takes a slice of patterns returns a new +// slice of patterns cleaned with filepath.Clean, stripped +// of any empty patterns and lets the caller know whether the +// slice contains any exception patterns (prefixed with !). +func CleanPatterns(patterns []string) ([]string, [][]string, bool, error) { + // Loop over exclusion patterns and: + // 1. Clean them up. + // 2. Indicate whether we are dealing with any exception rules. + // 3. Error if we see a single exclusion marker on it's own (!). + cleanedPatterns := []string{} + patternDirs := [][]string{} + exceptions := false + for _, pattern := range patterns { + // Eliminate leading and trailing whitespace. + pattern = strings.TrimSpace(pattern) + if empty(pattern) { + continue + } + if exclusion(pattern) { + if len(pattern) == 1 { + return nil, nil, false, errors.New("Illegal exclusion pattern: !") + } + exceptions = true + } + pattern = filepath.Clean(pattern) + cleanedPatterns = append(cleanedPatterns, pattern) + if exclusion(pattern) { + pattern = pattern[1:] + } + patternDirs = append(patternDirs, strings.Split(pattern, "/")) + } + + return cleanedPatterns, patternDirs, exceptions, nil +} + +// Matches returns true if file matches any of the patterns +// and isn't excluded by any of the subsequent patterns. +func Matches(file string, patterns []string) (bool, error) { + file = filepath.Clean(file) + + if file == "." { + // Don't let them exclude everything, kind of silly. + return false, nil + } + + patterns, patDirs, _, err := CleanPatterns(patterns) + if err != nil { + return false, err + } + + return OptimizedMatches(file, patterns, patDirs) +} + +// OptimizedMatches is basically the same as fileutils.Matches() but optimized for archive.go. +// It will assume that the inputs have been preprocessed and therefore the function +// doen't need to do as much error checking and clean-up. This was done to avoid +// repeating these steps on each file being checked during the archive process. +// The more generic fileutils.Matches() can't make these assumptions. +func OptimizedMatches(file string, patterns []string, patDirs [][]string) (bool, error) { + matched := false + parentPath := filepath.Dir(file) + parentPathDirs := strings.Split(parentPath, "/") + + for i, pattern := range patterns { + negative := false + + if exclusion(pattern) { + negative = true + pattern = pattern[1:] + } + + match, err := filepath.Match(pattern, file) + if err != nil { + return false, err + } + + if !match && parentPath != "." { + // Check to see if the pattern matches one of our parent dirs. + if len(patDirs[i]) <= len(parentPathDirs) { + match, _ = filepath.Match(strings.Join(patDirs[i], "/"), + strings.Join(parentPathDirs[:len(patDirs[i])], "/")) + } + } + + if match { + matched = !negative + } + } + + if matched { + logrus.Debugf("Skipping excluded path: %s", file) + } + + return matched, nil +} + +// CopyFile copies from src to dst until either EOF is reached +// on src or an error occurs. It verifies src exists and remove +// the dst if it exists. +func CopyFile(src, dst string) (int64, error) { + cleanSrc := filepath.Clean(src) + cleanDst := filepath.Clean(dst) + if cleanSrc == cleanDst { + return 0, nil + } + sf, err := os.Open(cleanSrc) + if err != nil { + return 0, err + } + defer sf.Close() + if err := os.Remove(cleanDst); err != nil && !os.IsNotExist(err) { + return 0, err + } + df, err := os.Create(cleanDst) + if err != nil { + return 0, err + } + defer df.Close() + return io.Copy(df, sf) +} + +// GetTotalUsedFds Returns the number of used File Descriptors by +// reading it via /proc filesystem. +func GetTotalUsedFds() int { + if fds, err := ioutil.ReadDir(fmt.Sprintf("/proc/%d/fd", os.Getpid())); err != nil { + logrus.Errorf("Error opening /proc/%d/fd: %s", os.Getpid(), err) + } else { + return len(fds) + } + return -1 +} + +// ReadSymlinkedDirectory returns the target directory of a symlink. +// The target of the symbolic link may not be a file. +func ReadSymlinkedDirectory(path string) (string, error) { + var realPath string + var err error + if realPath, err = filepath.Abs(path); err != nil { + return "", fmt.Errorf("unable to get absolute path for %s: %s", path, err) + } + if realPath, err = filepath.EvalSymlinks(realPath); err != nil { + return "", fmt.Errorf("failed to canonicalise path for %s: %s", path, err) + } + realPathInfo, err := os.Stat(realPath) + if err != nil { + return "", fmt.Errorf("failed to stat target '%s' of '%s': %s", realPath, path, err) + } + if !realPathInfo.Mode().IsDir() { + return "", fmt.Errorf("canonical path points to a file '%s'", realPath) + } + return realPath, nil +} + +// CreateIfNotExists creates a file or a directory only if it does not already exist. +func CreateIfNotExists(path string, isDir bool) error { + if _, err := os.Stat(path); err != nil { + if os.IsNotExist(err) { + if isDir { + return os.MkdirAll(path, 0755) + } + if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { + return err + } + f, err := os.OpenFile(path, os.O_CREATE, 0755) + if err != nil { + return err + } + f.Close() + } + } + return nil +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/fileutils/fileutils_test.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/fileutils/fileutils_test.go new file mode 100644 index 0000000..b544ffb --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/fileutils/fileutils_test.go @@ -0,0 +1,402 @@ +package fileutils + +import ( + "io/ioutil" + "os" + "path" + "path/filepath" + "testing" +) + +// CopyFile with invalid src +func TestCopyFileWithInvalidSrc(t *testing.T) { + tempFolder, err := ioutil.TempDir("", "docker-fileutils-test") + defer os.RemoveAll(tempFolder) + if err != nil { + t.Fatal(err) + } + bytes, err := CopyFile("/invalid/file/path", path.Join(tempFolder, "dest")) + if err == nil { + t.Fatal("Should have fail to copy an invalid src file") + } + if bytes != 0 { + t.Fatal("Should have written 0 bytes") + } + +} + +// CopyFile with invalid dest +func TestCopyFileWithInvalidDest(t *testing.T) { + tempFolder, err := ioutil.TempDir("", "docker-fileutils-test") + defer os.RemoveAll(tempFolder) + if err != nil { + t.Fatal(err) + } + src := path.Join(tempFolder, "file") + err = ioutil.WriteFile(src, []byte("content"), 0740) + if err != nil { + t.Fatal(err) + } + bytes, err := CopyFile(src, path.Join(tempFolder, "/invalid/dest/path")) + if err == nil { + t.Fatal("Should have fail to copy an invalid src file") + } + if bytes != 0 { + t.Fatal("Should have written 0 bytes") + } + +} + +// CopyFile with same src and dest +func TestCopyFileWithSameSrcAndDest(t *testing.T) { + tempFolder, err := ioutil.TempDir("", "docker-fileutils-test") + defer os.RemoveAll(tempFolder) + if err != nil { + t.Fatal(err) + } + file := path.Join(tempFolder, "file") + err = ioutil.WriteFile(file, []byte("content"), 0740) + if err != nil { + t.Fatal(err) + } + bytes, err := CopyFile(file, file) + if err != nil { + t.Fatal(err) + } + if bytes != 0 { + t.Fatal("Should have written 0 bytes as it is the same file.") + } +} + +// CopyFile with same src and dest but path is different and not clean +func TestCopyFileWithSameSrcAndDestWithPathNameDifferent(t *testing.T) { + tempFolder, err := ioutil.TempDir("", "docker-fileutils-test") + defer os.RemoveAll(tempFolder) + if err != nil { + t.Fatal(err) + } + testFolder := path.Join(tempFolder, "test") + err = os.MkdirAll(testFolder, 0740) + if err != nil { + t.Fatal(err) + } + file := path.Join(testFolder, "file") + sameFile := testFolder + "/../test/file" + err = ioutil.WriteFile(file, []byte("content"), 0740) + if err != nil { + t.Fatal(err) + } + bytes, err := CopyFile(file, sameFile) + if err != nil { + t.Fatal(err) + } + if bytes != 0 { + t.Fatal("Should have written 0 bytes as it is the same file.") + } +} + +func TestCopyFile(t *testing.T) { + tempFolder, err := ioutil.TempDir("", "docker-fileutils-test") + defer os.RemoveAll(tempFolder) + if err != nil { + t.Fatal(err) + } + src := path.Join(tempFolder, "src") + dest := path.Join(tempFolder, "dest") + ioutil.WriteFile(src, []byte("content"), 0777) + ioutil.WriteFile(dest, []byte("destContent"), 0777) + bytes, err := CopyFile(src, dest) + if err != nil { + t.Fatal(err) + } + if bytes != 7 { + t.Fatalf("Should have written %d bytes but wrote %d", 7, bytes) + } + actual, err := ioutil.ReadFile(dest) + if err != nil { + t.Fatal(err) + } + if string(actual) != "content" { + t.Fatalf("Dest content was '%s', expected '%s'", string(actual), "content") + } +} + +// Reading a symlink to a directory must return the directory +func TestReadSymlinkedDirectoryExistingDirectory(t *testing.T) { + var err error + if err = os.Mkdir("/tmp/testReadSymlinkToExistingDirectory", 0777); err != nil { + t.Errorf("failed to create directory: %s", err) + } + + if err = os.Symlink("/tmp/testReadSymlinkToExistingDirectory", "/tmp/dirLinkTest"); err != nil { + t.Errorf("failed to create symlink: %s", err) + } + + var path string + if path, err = ReadSymlinkedDirectory("/tmp/dirLinkTest"); err != nil { + t.Fatalf("failed to read symlink to directory: %s", err) + } + + if path != "/tmp/testReadSymlinkToExistingDirectory" { + t.Fatalf("symlink returned unexpected directory: %s", path) + } + + if err = os.Remove("/tmp/testReadSymlinkToExistingDirectory"); err != nil { + t.Errorf("failed to remove temporary directory: %s", err) + } + + if err = os.Remove("/tmp/dirLinkTest"); err != nil { + t.Errorf("failed to remove symlink: %s", err) + } +} + +// Reading a non-existing symlink must fail +func TestReadSymlinkedDirectoryNonExistingSymlink(t *testing.T) { + var path string + var err error + if path, err = ReadSymlinkedDirectory("/tmp/test/foo/Non/ExistingPath"); err == nil { + t.Fatalf("error expected for non-existing symlink") + } + + if path != "" { + t.Fatalf("expected empty path, but '%s' was returned", path) + } +} + +// Reading a symlink to a file must fail +func TestReadSymlinkedDirectoryToFile(t *testing.T) { + var err error + var file *os.File + + if file, err = os.Create("/tmp/testReadSymlinkToFile"); err != nil { + t.Fatalf("failed to create file: %s", err) + } + + file.Close() + + if err = os.Symlink("/tmp/testReadSymlinkToFile", "/tmp/fileLinkTest"); err != nil { + t.Errorf("failed to create symlink: %s", err) + } + + var path string + if path, err = ReadSymlinkedDirectory("/tmp/fileLinkTest"); err == nil { + t.Fatalf("ReadSymlinkedDirectory on a symlink to a file should've failed") + } + + if path != "" { + t.Fatalf("path should've been empty: %s", path) + } + + if err = os.Remove("/tmp/testReadSymlinkToFile"); err != nil { + t.Errorf("failed to remove file: %s", err) + } + + if err = os.Remove("/tmp/fileLinkTest"); err != nil { + t.Errorf("failed to remove symlink: %s", err) + } +} + +func TestWildcardMatches(t *testing.T) { + match, _ := Matches("fileutils.go", []string{"*"}) + if match != true { + t.Errorf("failed to get a wildcard match, got %v", match) + } +} + +// A simple pattern match should return true. +func TestPatternMatches(t *testing.T) { + match, _ := Matches("fileutils.go", []string{"*.go"}) + if match != true { + t.Errorf("failed to get a match, got %v", match) + } +} + +// An exclusion followed by an inclusion should return true. +func TestExclusionPatternMatchesPatternBefore(t *testing.T) { + match, _ := Matches("fileutils.go", []string{"!fileutils.go", "*.go"}) + if match != true { + t.Errorf("failed to get true match on exclusion pattern, got %v", match) + } +} + +// A folder pattern followed by an exception should return false. +func TestPatternMatchesFolderExclusions(t *testing.T) { + match, _ := Matches("docs/README.md", []string{"docs", "!docs/README.md"}) + if match != false { + t.Errorf("failed to get a false match on exclusion pattern, got %v", match) + } +} + +// A folder pattern followed by an exception should return false. +func TestPatternMatchesFolderWithSlashExclusions(t *testing.T) { + match, _ := Matches("docs/README.md", []string{"docs/", "!docs/README.md"}) + if match != false { + t.Errorf("failed to get a false match on exclusion pattern, got %v", match) + } +} + +// A folder pattern followed by an exception should return false. +func TestPatternMatchesFolderWildcardExclusions(t *testing.T) { + match, _ := Matches("docs/README.md", []string{"docs/*", "!docs/README.md"}) + if match != false { + t.Errorf("failed to get a false match on exclusion pattern, got %v", match) + } +} + +// A pattern followed by an exclusion should return false. +func TestExclusionPatternMatchesPatternAfter(t *testing.T) { + match, _ := Matches("fileutils.go", []string{"*.go", "!fileutils.go"}) + if match != false { + t.Errorf("failed to get false match on exclusion pattern, got %v", match) + } +} + +// A filename evaluating to . should return false. +func TestExclusionPatternMatchesWholeDirectory(t *testing.T) { + match, _ := Matches(".", []string{"*.go"}) + if match != false { + t.Errorf("failed to get false match on ., got %v", match) + } +} + +// A single ! pattern should return an error. +func TestSingleExclamationError(t *testing.T) { + _, err := Matches("fileutils.go", []string{"!"}) + if err == nil { + t.Errorf("failed to get an error for a single exclamation point, got %v", err) + } +} + +// A string preceded with a ! should return true from Exclusion. +func TestExclusion(t *testing.T) { + exclusion := exclusion("!") + if !exclusion { + t.Errorf("failed to get true for a single !, got %v", exclusion) + } +} + +// Matches with no patterns +func TestMatchesWithNoPatterns(t *testing.T) { + matches, err := Matches("/any/path/there", []string{}) + if err != nil { + t.Fatal(err) + } + if matches { + t.Fatalf("Should not have match anything") + } +} + +// Matches with malformed patterns +func TestMatchesWithMalformedPatterns(t *testing.T) { + matches, err := Matches("/any/path/there", []string{"["}) + if err == nil { + t.Fatal("Should have failed because of a malformed syntax in the pattern") + } + if matches { + t.Fatalf("Should not have match anything") + } +} + +// An empty string should return true from Empty. +func TestEmpty(t *testing.T) { + empty := empty("") + if !empty { + t.Errorf("failed to get true for an empty string, got %v", empty) + } +} + +func TestCleanPatterns(t *testing.T) { + cleaned, _, _, _ := CleanPatterns([]string{"docs", "config"}) + if len(cleaned) != 2 { + t.Errorf("expected 2 element slice, got %v", len(cleaned)) + } +} + +func TestCleanPatternsStripEmptyPatterns(t *testing.T) { + cleaned, _, _, _ := CleanPatterns([]string{"docs", "config", ""}) + if len(cleaned) != 2 { + t.Errorf("expected 2 element slice, got %v", len(cleaned)) + } +} + +func TestCleanPatternsExceptionFlag(t *testing.T) { + _, _, exceptions, _ := CleanPatterns([]string{"docs", "!docs/README.md"}) + if !exceptions { + t.Errorf("expected exceptions to be true, got %v", exceptions) + } +} + +func TestCleanPatternsLeadingSpaceTrimmed(t *testing.T) { + _, _, exceptions, _ := CleanPatterns([]string{"docs", " !docs/README.md"}) + if !exceptions { + t.Errorf("expected exceptions to be true, got %v", exceptions) + } +} + +func TestCleanPatternsTrailingSpaceTrimmed(t *testing.T) { + _, _, exceptions, _ := CleanPatterns([]string{"docs", "!docs/README.md "}) + if !exceptions { + t.Errorf("expected exceptions to be true, got %v", exceptions) + } +} + +func TestCleanPatternsErrorSingleException(t *testing.T) { + _, _, _, err := CleanPatterns([]string{"!"}) + if err == nil { + t.Errorf("expected error on single exclamation point, got %v", err) + } +} + +func TestCleanPatternsFolderSplit(t *testing.T) { + _, dirs, _, _ := CleanPatterns([]string{"docs/config/CONFIG.md"}) + if dirs[0][0] != "docs" { + t.Errorf("expected first element in dirs slice to be docs, got %v", dirs[0][1]) + } + if dirs[0][1] != "config" { + t.Errorf("expected first element in dirs slice to be config, got %v", dirs[0][1]) + } +} + +func TestCreateIfNotExistsDir(t *testing.T) { + tempFolder, err := ioutil.TempDir("", "docker-fileutils-test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempFolder) + + folderToCreate := filepath.Join(tempFolder, "tocreate") + + if err := CreateIfNotExists(folderToCreate, true); err != nil { + t.Fatal(err) + } + fileinfo, err := os.Stat(folderToCreate) + if err != nil { + t.Fatalf("Should have create a folder, got %v", err) + } + + if !fileinfo.IsDir() { + t.Fatalf("Should have been a dir, seems it's not") + } +} + +func TestCreateIfNotExistsFile(t *testing.T) { + tempFolder, err := ioutil.TempDir("", "docker-fileutils-test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempFolder) + + fileToCreate := filepath.Join(tempFolder, "file/to/create") + + if err := CreateIfNotExists(fileToCreate, false); err != nil { + t.Fatal(err) + } + fileinfo, err := os.Stat(fileToCreate) + if err != nil { + t.Fatalf("Should have create a file, got %v", err) + } + + if fileinfo.IsDir() { + t.Fatalf("Should have been a file, seems it's not") + } +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/ioutils/fmt.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/ioutils/fmt.go new file mode 100644 index 0000000..801132f --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/ioutils/fmt.go @@ -0,0 +1,14 @@ +package ioutils + +import ( + "fmt" + "io" +) + +// FprintfIfNotEmpty prints the string value if it's not empty +func FprintfIfNotEmpty(w io.Writer, format, value string) (int, error) { + if value != "" { + return fmt.Fprintf(w, format, value) + } + return 0, nil +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/ioutils/fmt_test.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/ioutils/fmt_test.go new file mode 100644 index 0000000..8968863 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/ioutils/fmt_test.go @@ -0,0 +1,17 @@ +package ioutils + +import "testing" + +func TestFprintfIfNotEmpty(t *testing.T) { + wc := NewWriteCounter(&NopWriter{}) + n, _ := FprintfIfNotEmpty(wc, "foo%s", "") + + if wc.Count != 0 || n != 0 { + t.Errorf("Wrong count: %v vs. %v vs. 0", wc.Count, n) + } + + n, _ = FprintfIfNotEmpty(wc, "foo%s", "bar") + if wc.Count != 6 || n != 6 { + t.Errorf("Wrong count: %v vs. %v vs. 6", wc.Count, n) + } +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/ioutils/multireader.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/ioutils/multireader.go new file mode 100644 index 0000000..f231aa9 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/ioutils/multireader.go @@ -0,0 +1,226 @@ +package ioutils + +import ( + "bytes" + "fmt" + "io" + "os" +) + +type pos struct { + idx int + offset int64 +} + +type multiReadSeeker struct { + readers []io.ReadSeeker + pos *pos + posIdx map[io.ReadSeeker]int +} + +func (r *multiReadSeeker) Seek(offset int64, whence int) (int64, error) { + var tmpOffset int64 + switch whence { + case os.SEEK_SET: + for i, rdr := range r.readers { + // get size of the current reader + s, err := rdr.Seek(0, os.SEEK_END) + if err != nil { + return -1, err + } + + if offset > tmpOffset+s { + if i == len(r.readers)-1 { + rdrOffset := s + (offset - tmpOffset) + if _, err := rdr.Seek(rdrOffset, os.SEEK_SET); err != nil { + return -1, err + } + r.pos = &pos{i, rdrOffset} + return offset, nil + } + + tmpOffset += s + continue + } + + rdrOffset := offset - tmpOffset + idx := i + + rdr.Seek(rdrOffset, os.SEEK_SET) + // make sure all following readers are at 0 + for _, rdr := range r.readers[i+1:] { + rdr.Seek(0, os.SEEK_SET) + } + + if rdrOffset == s && i != len(r.readers)-1 { + idx += 1 + rdrOffset = 0 + } + r.pos = &pos{idx, rdrOffset} + return offset, nil + } + case os.SEEK_END: + for _, rdr := range r.readers { + s, err := rdr.Seek(0, os.SEEK_END) + if err != nil { + return -1, err + } + tmpOffset += s + } + r.Seek(tmpOffset+offset, os.SEEK_SET) + return tmpOffset + offset, nil + case os.SEEK_CUR: + if r.pos == nil { + return r.Seek(offset, os.SEEK_SET) + } + // Just return the current offset + if offset == 0 { + return r.getCurOffset() + } + + curOffset, err := r.getCurOffset() + if err != nil { + return -1, err + } + rdr, rdrOffset, err := r.getReaderForOffset(curOffset + offset) + if err != nil { + return -1, err + } + + r.pos = &pos{r.posIdx[rdr], rdrOffset} + return curOffset + offset, nil + default: + return -1, fmt.Errorf("Invalid whence: %d", whence) + } + + return -1, fmt.Errorf("Error seeking for whence: %d, offset: %d", whence, offset) +} + +func (r *multiReadSeeker) getReaderForOffset(offset int64) (io.ReadSeeker, int64, error) { + var rdr io.ReadSeeker + var rdrOffset int64 + + for i, rdr := range r.readers { + offsetTo, err := r.getOffsetToReader(rdr) + if err != nil { + return nil, -1, err + } + if offsetTo > offset { + rdr = r.readers[i-1] + rdrOffset = offsetTo - offset + break + } + + if rdr == r.readers[len(r.readers)-1] { + rdrOffset = offsetTo + offset + break + } + } + + return rdr, rdrOffset, nil +} + +func (r *multiReadSeeker) getCurOffset() (int64, error) { + var totalSize int64 + for _, rdr := range r.readers[:r.pos.idx+1] { + if r.posIdx[rdr] == r.pos.idx { + totalSize += r.pos.offset + break + } + + size, err := getReadSeekerSize(rdr) + if err != nil { + return -1, fmt.Errorf("error getting seeker size: %v", err) + } + totalSize += size + } + return totalSize, nil +} + +func (r *multiReadSeeker) getOffsetToReader(rdr io.ReadSeeker) (int64, error) { + var offset int64 + for _, r := range r.readers { + if r == rdr { + break + } + + size, err := getReadSeekerSize(rdr) + if err != nil { + return -1, err + } + offset += size + } + return offset, nil +} + +func (r *multiReadSeeker) Read(b []byte) (int, error) { + if r.pos == nil { + r.pos = &pos{0, 0} + } + + bCap := int64(cap(b)) + buf := bytes.NewBuffer(nil) + var rdr io.ReadSeeker + + for _, rdr = range r.readers[r.pos.idx:] { + readBytes, err := io.CopyN(buf, rdr, bCap) + if err != nil && err != io.EOF { + return -1, err + } + bCap -= readBytes + + if bCap == 0 { + break + } + } + + rdrPos, err := rdr.Seek(0, os.SEEK_CUR) + if err != nil { + return -1, err + } + r.pos = &pos{r.posIdx[rdr], rdrPos} + return buf.Read(b) +} + +func getReadSeekerSize(rdr io.ReadSeeker) (int64, error) { + // save the current position + pos, err := rdr.Seek(0, os.SEEK_CUR) + if err != nil { + return -1, err + } + + // get the size + size, err := rdr.Seek(0, os.SEEK_END) + if err != nil { + return -1, err + } + + // reset the position + if _, err := rdr.Seek(pos, os.SEEK_SET); err != nil { + return -1, err + } + return size, nil +} + +// MultiReadSeeker returns a ReadSeeker that's the logical concatenation of the provided +// input readseekers. After calling this method the initial position is set to the +// beginning of the first ReadSeeker. At the end of a ReadSeeker, Read always advances +// to the beginning of the next ReadSeeker and returns EOF at the end of the last ReadSeeker. +// Seek can be used over the sum of lengths of all readseekers. +// +// When a MultiReadSeeker is used, no Read and Seek operations should be made on +// its ReadSeeker components. Also, users should make no assumption on the state +// of individual readseekers while the MultiReadSeeker is used. +func MultiReadSeeker(readers ...io.ReadSeeker) io.ReadSeeker { + if len(readers) == 1 { + return readers[0] + } + idx := make(map[io.ReadSeeker]int) + for i, rdr := range readers { + idx[rdr] = i + } + return &multiReadSeeker{ + readers: readers, + posIdx: idx, + } +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/ioutils/multireader_test.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/ioutils/multireader_test.go new file mode 100644 index 0000000..de495b5 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/ioutils/multireader_test.go @@ -0,0 +1,149 @@ +package ioutils + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "os" + "strings" + "testing" +) + +func TestMultiReadSeekerReadAll(t *testing.T) { + str := "hello world" + s1 := strings.NewReader(str + " 1") + s2 := strings.NewReader(str + " 2") + s3 := strings.NewReader(str + " 3") + mr := MultiReadSeeker(s1, s2, s3) + + expectedSize := int64(s1.Len() + s2.Len() + s3.Len()) + + b, err := ioutil.ReadAll(mr) + if err != nil { + t.Fatal(err) + } + + expected := "hello world 1hello world 2hello world 3" + if string(b) != expected { + t.Fatalf("ReadAll failed, got: %q, expected %q", string(b), expected) + } + + size, err := mr.Seek(0, os.SEEK_END) + if err != nil { + t.Fatal(err) + } + if size != expectedSize { + t.Fatalf("reader size does not match, got %d, expected %d", size, expectedSize) + } + + // Reset the position and read again + pos, err := mr.Seek(0, os.SEEK_SET) + if err != nil { + t.Fatal(err) + } + if pos != 0 { + t.Fatalf("expected position to be set to 0, got %d", pos) + } + + b, err = ioutil.ReadAll(mr) + if err != nil { + t.Fatal(err) + } + + if string(b) != expected { + t.Fatalf("ReadAll failed, got: %q, expected %q", string(b), expected) + } +} + +func TestMultiReadSeekerReadEach(t *testing.T) { + str := "hello world" + s1 := strings.NewReader(str + " 1") + s2 := strings.NewReader(str + " 2") + s3 := strings.NewReader(str + " 3") + mr := MultiReadSeeker(s1, s2, s3) + + var totalBytes int64 + for i, s := range []*strings.Reader{s1, s2, s3} { + sLen := int64(s.Len()) + buf := make([]byte, s.Len()) + expected := []byte(fmt.Sprintf("%s %d", str, i+1)) + + if _, err := mr.Read(buf); err != nil && err != io.EOF { + t.Fatal(err) + } + + if !bytes.Equal(buf, expected) { + t.Fatalf("expected %q to be %q", string(buf), string(expected)) + } + + pos, err := mr.Seek(0, os.SEEK_CUR) + if err != nil { + t.Fatalf("iteration: %d, error: %v", i+1, err) + } + + // check that the total bytes read is the current position of the seeker + totalBytes += sLen + if pos != totalBytes { + t.Fatalf("expected current position to be: %d, got: %d, iteration: %d", totalBytes, pos, i+1) + } + + // This tests not only that SEEK_SET and SEEK_CUR give the same values, but that the next iteration is in the expected position as well + newPos, err := mr.Seek(pos, os.SEEK_SET) + if err != nil { + t.Fatal(err) + } + if newPos != pos { + t.Fatalf("expected to get same position when calling SEEK_SET with value from SEEK_CUR, cur: %d, set: %d", pos, newPos) + } + } +} + +func TestMultiReadSeekerReadSpanningChunks(t *testing.T) { + str := "hello world" + s1 := strings.NewReader(str + " 1") + s2 := strings.NewReader(str + " 2") + s3 := strings.NewReader(str + " 3") + mr := MultiReadSeeker(s1, s2, s3) + + buf := make([]byte, s1.Len()+3) + _, err := mr.Read(buf) + if err != nil { + t.Fatal(err) + } + + // expected is the contents of s1 + 3 bytes from s2, ie, the `hel` at the end of this string + expected := "hello world 1hel" + if string(buf) != expected { + t.Fatalf("expected %s to be %s", string(buf), expected) + } +} + +func TestMultiReadSeekerNegativeSeek(t *testing.T) { + str := "hello world" + s1 := strings.NewReader(str + " 1") + s2 := strings.NewReader(str + " 2") + s3 := strings.NewReader(str + " 3") + mr := MultiReadSeeker(s1, s2, s3) + + s1Len := s1.Len() + s2Len := s2.Len() + s3Len := s3.Len() + + s, err := mr.Seek(int64(-1*s3.Len()), os.SEEK_END) + if err != nil { + t.Fatal(err) + } + if s != int64(s1Len+s2Len) { + t.Fatalf("expected %d to be %d", s, s1.Len()+s2.Len()) + } + + buf := make([]byte, s3Len) + if _, err := mr.Read(buf); err != nil && err != io.EOF { + t.Fatal(err) + } + expected := fmt.Sprintf("%s %d", str, 3) + if string(buf) != fmt.Sprintf("%s %d", str, 3) { + t.Fatalf("expected %q to be %q", string(buf), expected) + } +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/ioutils/readers.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/ioutils/readers.go new file mode 100644 index 0000000..ff09baa --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/ioutils/readers.go @@ -0,0 +1,254 @@ +package ioutils + +import ( + "bytes" + "crypto/rand" + "crypto/sha256" + "encoding/hex" + "io" + "math/big" + "sync" + "time" +) + +type readCloserWrapper struct { + io.Reader + closer func() error +} + +func (r *readCloserWrapper) Close() error { + return r.closer() +} + +func NewReadCloserWrapper(r io.Reader, closer func() error) io.ReadCloser { + return &readCloserWrapper{ + Reader: r, + closer: closer, + } +} + +type readerErrWrapper struct { + reader io.Reader + closer func() +} + +func (r *readerErrWrapper) Read(p []byte) (int, error) { + n, err := r.reader.Read(p) + if err != nil { + r.closer() + } + return n, err +} + +func NewReaderErrWrapper(r io.Reader, closer func()) io.Reader { + return &readerErrWrapper{ + reader: r, + closer: closer, + } +} + +// bufReader allows the underlying reader to continue to produce +// output by pre-emptively reading from the wrapped reader. +// This is achieved by buffering this data in bufReader's +// expanding buffer. +type bufReader struct { + sync.Mutex + buf *bytes.Buffer + reader io.Reader + err error + wait sync.Cond + drainBuf []byte + reuseBuf []byte + maxReuse int64 + resetTimeout time.Duration + bufLenResetThreshold int64 + maxReadDataReset int64 +} + +func NewBufReader(r io.Reader) *bufReader { + var timeout int + if randVal, err := rand.Int(rand.Reader, big.NewInt(120)); err == nil { + timeout = int(randVal.Int64()) + 180 + } else { + timeout = 300 + } + reader := &bufReader{ + buf: &bytes.Buffer{}, + drainBuf: make([]byte, 1024), + reuseBuf: make([]byte, 4096), + maxReuse: 1000, + resetTimeout: time.Second * time.Duration(timeout), + bufLenResetThreshold: 100 * 1024, + maxReadDataReset: 10 * 1024 * 1024, + reader: r, + } + reader.wait.L = &reader.Mutex + go reader.drain() + return reader +} + +func NewBufReaderWithDrainbufAndBuffer(r io.Reader, drainBuffer []byte, buffer *bytes.Buffer) *bufReader { + reader := &bufReader{ + buf: buffer, + drainBuf: drainBuffer, + reader: r, + } + reader.wait.L = &reader.Mutex + go reader.drain() + return reader +} + +func (r *bufReader) drain() { + var ( + duration time.Duration + lastReset time.Time + now time.Time + reset bool + bufLen int64 + dataSinceReset int64 + maxBufLen int64 + reuseBufLen int64 + reuseCount int64 + ) + reuseBufLen = int64(len(r.reuseBuf)) + lastReset = time.Now() + for { + n, err := r.reader.Read(r.drainBuf) + dataSinceReset += int64(n) + r.Lock() + bufLen = int64(r.buf.Len()) + if bufLen > maxBufLen { + maxBufLen = bufLen + } + + // Avoid unbounded growth of the buffer over time. + // This has been discovered to be the only non-intrusive + // solution to the unbounded growth of the buffer. + // Alternative solutions such as compression, multiple + // buffers, channels and other similar pieces of code + // were reducing throughput, overall Docker performance + // or simply crashed Docker. + // This solution releases the buffer when specific + // conditions are met to avoid the continuous resizing + // of the buffer for long lived containers. + // + // Move data to the front of the buffer if it's + // smaller than what reuseBuf can store + if bufLen > 0 && reuseBufLen >= bufLen { + n, _ := r.buf.Read(r.reuseBuf) + r.buf.Write(r.reuseBuf[0:n]) + // Take action if the buffer has been reused too many + // times and if there's data in the buffer. + // The timeout is also used as means to avoid doing + // these operations more often or less often than + // required. + // The various conditions try to detect heavy activity + // in the buffer which might be indicators of heavy + // growth of the buffer. + } else if reuseCount >= r.maxReuse && bufLen > 0 { + now = time.Now() + duration = now.Sub(lastReset) + timeoutReached := duration >= r.resetTimeout + + // The timeout has been reached and the + // buffered data couldn't be moved to the front + // of the buffer, so the buffer gets reset. + if timeoutReached && bufLen > reuseBufLen { + reset = true + } + // The amount of buffered data is too high now, + // reset the buffer. + if timeoutReached && maxBufLen >= r.bufLenResetThreshold { + reset = true + } + // Reset the buffer if a certain amount of + // data has gone through the buffer since the + // last reset. + if timeoutReached && dataSinceReset >= r.maxReadDataReset { + reset = true + } + // The buffered data is moved to a fresh buffer, + // swap the old buffer with the new one and + // reset all counters. + if reset { + newbuf := &bytes.Buffer{} + newbuf.ReadFrom(r.buf) + r.buf = newbuf + lastReset = now + reset = false + dataSinceReset = 0 + maxBufLen = 0 + reuseCount = 0 + } + } + if err != nil { + r.err = err + } else { + r.buf.Write(r.drainBuf[0:n]) + } + reuseCount++ + r.wait.Signal() + r.Unlock() + callSchedulerIfNecessary() + if err != nil { + break + } + } +} + +func (r *bufReader) Read(p []byte) (n int, err error) { + r.Lock() + defer r.Unlock() + for { + n, err = r.buf.Read(p) + if n > 0 { + return n, err + } + if r.err != nil { + return 0, r.err + } + r.wait.Wait() + } +} + +func (r *bufReader) Close() error { + closer, ok := r.reader.(io.ReadCloser) + if !ok { + return nil + } + return closer.Close() +} + +func HashData(src io.Reader) (string, error) { + h := sha256.New() + if _, err := io.Copy(h, src); err != nil { + return "", err + } + return "sha256:" + hex.EncodeToString(h.Sum(nil)), nil +} + +type OnEOFReader struct { + Rc io.ReadCloser + Fn func() +} + +func (r *OnEOFReader) Read(p []byte) (n int, err error) { + n, err = r.Rc.Read(p) + if err == io.EOF { + r.runFunc() + } + return +} + +func (r *OnEOFReader) Close() error { + err := r.Rc.Close() + r.runFunc() + return err +} + +func (r *OnEOFReader) runFunc() { + if fn := r.Fn; fn != nil { + fn() + r.Fn = nil + } +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/ioutils/readers_test.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/ioutils/readers_test.go new file mode 100644 index 0000000..0a39b6e --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/ioutils/readers_test.go @@ -0,0 +1,216 @@ +package ioutils + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "strings" + "testing" +) + +// Implement io.Reader +type errorReader struct{} + +func (r *errorReader) Read(p []byte) (int, error) { + return 0, fmt.Errorf("Error reader always fail.") +} + +func TestReadCloserWrapperClose(t *testing.T) { + reader := strings.NewReader("A string reader") + wrapper := NewReadCloserWrapper(reader, func() error { + return fmt.Errorf("This will be called when closing") + }) + err := wrapper.Close() + if err == nil || !strings.Contains(err.Error(), "This will be called when closing") { + t.Fatalf("readCloserWrapper should have call the anonymous func and thus, fail.") + } +} + +func TestReaderErrWrapperReadOnError(t *testing.T) { + called := false + reader := &errorReader{} + wrapper := NewReaderErrWrapper(reader, func() { + called = true + }) + _, err := wrapper.Read([]byte{}) + if err == nil || !strings.Contains(err.Error(), "Error reader always fail.") { + t.Fatalf("readErrWrapper should returned an error") + } + if !called { + t.Fatalf("readErrWrapper should have call the anonymous function on failure") + } +} + +func TestReaderErrWrapperRead(t *testing.T) { + reader := strings.NewReader("a string reader.") + wrapper := NewReaderErrWrapper(reader, func() { + t.Fatalf("readErrWrapper should not have called the anonymous function") + }) + // Read 20 byte (should be ok with the string above) + num, err := wrapper.Read(make([]byte, 20)) + if err != nil { + t.Fatal(err) + } + if num != 16 { + t.Fatalf("readerErrWrapper should have read 16 byte, but read %d", num) + } +} + +func TestNewBufReaderWithDrainbufAndBuffer(t *testing.T) { + reader, writer := io.Pipe() + + drainBuffer := make([]byte, 1024) + buffer := bytes.Buffer{} + bufreader := NewBufReaderWithDrainbufAndBuffer(reader, drainBuffer, &buffer) + + // Write everything down to a Pipe + // Usually, a pipe should block but because of the buffered reader, + // the writes will go through + done := make(chan bool) + go func() { + writer.Write([]byte("hello world")) + writer.Close() + done <- true + }() + + // Drain the reader *after* everything has been written, just to verify + // it is indeed buffering + <-done + + output, err := ioutil.ReadAll(bufreader) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(output, []byte("hello world")) { + t.Error(string(output)) + } +} + +func TestBufReader(t *testing.T) { + reader, writer := io.Pipe() + bufreader := NewBufReader(reader) + + // Write everything down to a Pipe + // Usually, a pipe should block but because of the buffered reader, + // the writes will go through + done := make(chan bool) + go func() { + writer.Write([]byte("hello world")) + writer.Close() + done <- true + }() + + // Drain the reader *after* everything has been written, just to verify + // it is indeed buffering + <-done + output, err := ioutil.ReadAll(bufreader) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(output, []byte("hello world")) { + t.Error(string(output)) + } +} + +func TestBufReaderCloseWithNonReaderCloser(t *testing.T) { + reader := strings.NewReader("buffer") + bufreader := NewBufReader(reader) + + if err := bufreader.Close(); err != nil { + t.Fatal(err) + } + +} + +// implements io.ReadCloser +type simpleReaderCloser struct{} + +func (r *simpleReaderCloser) Read(p []byte) (n int, err error) { + return 0, nil +} + +func (r *simpleReaderCloser) Close() error { + return nil +} + +func TestBufReaderCloseWithReaderCloser(t *testing.T) { + reader := &simpleReaderCloser{} + bufreader := NewBufReader(reader) + + err := bufreader.Close() + if err != nil { + t.Fatal(err) + } + +} + +func TestHashData(t *testing.T) { + reader := strings.NewReader("hash-me") + actual, err := HashData(reader) + if err != nil { + t.Fatal(err) + } + expected := "sha256:4d11186aed035cc624d553e10db358492c84a7cd6b9670d92123c144930450aa" + if actual != expected { + t.Fatalf("Expecting %s, got %s", expected, actual) + } +} + +type repeatedReader struct { + readCount int + maxReads int + data []byte +} + +func newRepeatedReader(max int, data []byte) *repeatedReader { + return &repeatedReader{0, max, data} +} + +func (r *repeatedReader) Read(p []byte) (int, error) { + if r.readCount >= r.maxReads { + return 0, io.EOF + } + r.readCount++ + n := copy(p, r.data) + return n, nil +} + +func testWithData(data []byte, reads int) { + reader := newRepeatedReader(reads, data) + bufReader := NewBufReader(reader) + io.Copy(ioutil.Discard, bufReader) +} + +func Benchmark1M10BytesReads(b *testing.B) { + reads := 1000000 + readSize := int64(10) + data := make([]byte, readSize) + b.SetBytes(readSize * int64(reads)) + b.ResetTimer() + for i := 0; i < b.N; i++ { + testWithData(data, reads) + } +} + +func Benchmark1M1024BytesReads(b *testing.B) { + reads := 1000000 + readSize := int64(1024) + data := make([]byte, readSize) + b.SetBytes(readSize * int64(reads)) + b.ResetTimer() + for i := 0; i < b.N; i++ { + testWithData(data, reads) + } +} + +func Benchmark10k32KBytesReads(b *testing.B) { + reads := 10000 + readSize := int64(32 * 1024) + data := make([]byte, readSize) + b.SetBytes(readSize * int64(reads)) + b.ResetTimer() + for i := 0; i < b.N; i++ { + testWithData(data, reads) + } +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/ioutils/scheduler.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/ioutils/scheduler.go new file mode 100644 index 0000000..3c88f29 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/ioutils/scheduler.go @@ -0,0 +1,6 @@ +// +build !gccgo + +package ioutils + +func callSchedulerIfNecessary() { +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/ioutils/scheduler_gccgo.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/ioutils/scheduler_gccgo.go new file mode 100644 index 0000000..c11d02b --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/ioutils/scheduler_gccgo.go @@ -0,0 +1,13 @@ +// +build gccgo + +package ioutils + +import ( + "runtime" +) + +func callSchedulerIfNecessary() { + //allow or force Go scheduler to switch context, without explicitly + //forcing this will make it hang when using gccgo implementation + runtime.Gosched() +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/ioutils/writeflusher.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/ioutils/writeflusher.go new file mode 100644 index 0000000..2509547 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/ioutils/writeflusher.go @@ -0,0 +1,47 @@ +package ioutils + +import ( + "io" + "net/http" + "sync" +) + +type WriteFlusher struct { + sync.Mutex + w io.Writer + flusher http.Flusher + flushed bool +} + +func (wf *WriteFlusher) Write(b []byte) (n int, err error) { + wf.Lock() + defer wf.Unlock() + n, err = wf.w.Write(b) + wf.flushed = true + wf.flusher.Flush() + return n, err +} + +// Flush the stream immediately. +func (wf *WriteFlusher) Flush() { + wf.Lock() + defer wf.Unlock() + wf.flushed = true + wf.flusher.Flush() +} + +func (wf *WriteFlusher) Flushed() bool { + wf.Lock() + defer wf.Unlock() + return wf.flushed +} + +func NewWriteFlusher(w io.Writer) *WriteFlusher { + var flusher http.Flusher + if f, ok := w.(http.Flusher); ok { + flusher = f + } else { + flusher = &NopFlusher{} + } + return &WriteFlusher{w: w, flusher: flusher} +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/ioutils/writers.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/ioutils/writers.go new file mode 100644 index 0000000..43fdc44 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/ioutils/writers.go @@ -0,0 +1,60 @@ +package ioutils + +import "io" + +type NopWriter struct{} + +func (*NopWriter) Write(buf []byte) (int, error) { + return len(buf), nil +} + +type nopWriteCloser struct { + io.Writer +} + +func (w *nopWriteCloser) Close() error { return nil } + +func NopWriteCloser(w io.Writer) io.WriteCloser { + return &nopWriteCloser{w} +} + +type NopFlusher struct{} + +func (f *NopFlusher) Flush() {} + +type writeCloserWrapper struct { + io.Writer + closer func() error +} + +func (r *writeCloserWrapper) Close() error { + return r.closer() +} + +func NewWriteCloserWrapper(r io.Writer, closer func() error) io.WriteCloser { + return &writeCloserWrapper{ + Writer: r, + closer: closer, + } +} + +// Wrap a concrete io.Writer and hold a count of the number +// of bytes written to the writer during a "session". +// This can be convenient when write return is masked +// (e.g., json.Encoder.Encode()) +type WriteCounter struct { + Count int64 + Writer io.Writer +} + +func NewWriteCounter(w io.Writer) *WriteCounter { + return &WriteCounter{ + Writer: w, + } +} + +func (wc *WriteCounter) Write(p []byte) (count int, err error) { + count, err = wc.Writer.Write(p) + wc.Count += int64(count) + return +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/ioutils/writers_test.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/ioutils/writers_test.go new file mode 100644 index 0000000..564b1cd --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/ioutils/writers_test.go @@ -0,0 +1,65 @@ +package ioutils + +import ( + "bytes" + "strings" + "testing" +) + +func TestWriteCloserWrapperClose(t *testing.T) { + called := false + writer := bytes.NewBuffer([]byte{}) + wrapper := NewWriteCloserWrapper(writer, func() error { + called = true + return nil + }) + if err := wrapper.Close(); err != nil { + t.Fatal(err) + } + if !called { + t.Fatalf("writeCloserWrapper should have call the anonymous function.") + } +} + +func TestNopWriteCloser(t *testing.T) { + writer := bytes.NewBuffer([]byte{}) + wrapper := NopWriteCloser(writer) + if err := wrapper.Close(); err != nil { + t.Fatal("NopWriteCloser always return nil on Close.") + } + +} + +func TestNopWriter(t *testing.T) { + nw := &NopWriter{} + l, err := nw.Write([]byte{'c'}) + if err != nil { + t.Fatal(err) + } + if l != 1 { + t.Fatalf("Expected 1 got %d", l) + } +} + +func TestWriteCounter(t *testing.T) { + dummy1 := "This is a dummy string." + dummy2 := "This is another dummy string." + totalLength := int64(len(dummy1) + len(dummy2)) + + reader1 := strings.NewReader(dummy1) + reader2 := strings.NewReader(dummy2) + + var buffer bytes.Buffer + wc := NewWriteCounter(&buffer) + + reader1.WriteTo(wc) + reader2.WriteTo(wc) + + if wc.Count != totalLength { + t.Errorf("Wrong count: %d vs. %d", wc.Count, totalLength) + } + + if buffer.String() != dummy1+dummy2 { + t.Error("Wrong message written") + } +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/pools/pools.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/pools/pools.go new file mode 100644 index 0000000..76e84f9 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/pools/pools.go @@ -0,0 +1,119 @@ +// Package pools provides a collection of pools which provide various +// data types with buffers. These can be used to lower the number of +// memory allocations and reuse buffers. +// +// New pools should be added to this package to allow them to be +// shared across packages. +// +// Utility functions which operate on pools should be added to this +// package to allow them to be reused. +package pools + +import ( + "bufio" + "io" + "sync" + + "github.com/docker/docker/pkg/ioutils" +) + +var ( + // BufioReader32KPool is a pool which returns bufio.Reader with a 32K buffer. + BufioReader32KPool *BufioReaderPool + // BufioWriter32KPool is a pool which returns bufio.Writer with a 32K buffer. + BufioWriter32KPool *BufioWriterPool +) + +const buffer32K = 32 * 1024 + +// BufioReaderPool is a bufio reader that uses sync.Pool. +type BufioReaderPool struct { + pool sync.Pool +} + +func init() { + BufioReader32KPool = newBufioReaderPoolWithSize(buffer32K) + BufioWriter32KPool = newBufioWriterPoolWithSize(buffer32K) +} + +// newBufioReaderPoolWithSize is unexported because new pools should be +// added here to be shared where required. +func newBufioReaderPoolWithSize(size int) *BufioReaderPool { + pool := sync.Pool{ + New: func() interface{} { return bufio.NewReaderSize(nil, size) }, + } + return &BufioReaderPool{pool: pool} +} + +// Get returns a bufio.Reader which reads from r. The buffer size is that of the pool. +func (bufPool *BufioReaderPool) Get(r io.Reader) *bufio.Reader { + buf := bufPool.pool.Get().(*bufio.Reader) + buf.Reset(r) + return buf +} + +// Put puts the bufio.Reader back into the pool. +func (bufPool *BufioReaderPool) Put(b *bufio.Reader) { + b.Reset(nil) + bufPool.pool.Put(b) +} + +// Copy is a convenience wrapper which uses a buffer to avoid allocation in io.Copy. +func Copy(dst io.Writer, src io.Reader) (written int64, err error) { + buf := BufioReader32KPool.Get(src) + written, err = io.Copy(dst, buf) + BufioReader32KPool.Put(buf) + return +} + +// NewReadCloserWrapper returns a wrapper which puts the bufio.Reader back +// into the pool and closes the reader if it's an io.ReadCloser. +func (bufPool *BufioReaderPool) NewReadCloserWrapper(buf *bufio.Reader, r io.Reader) io.ReadCloser { + return ioutils.NewReadCloserWrapper(r, func() error { + if readCloser, ok := r.(io.ReadCloser); ok { + readCloser.Close() + } + bufPool.Put(buf) + return nil + }) +} + +// BufioWriterPool is a bufio writer that uses sync.Pool. +type BufioWriterPool struct { + pool sync.Pool +} + +// newBufioWriterPoolWithSize is unexported because new pools should be +// added here to be shared where required. +func newBufioWriterPoolWithSize(size int) *BufioWriterPool { + pool := sync.Pool{ + New: func() interface{} { return bufio.NewWriterSize(nil, size) }, + } + return &BufioWriterPool{pool: pool} +} + +// Get returns a bufio.Writer which writes to w. The buffer size is that of the pool. +func (bufPool *BufioWriterPool) Get(w io.Writer) *bufio.Writer { + buf := bufPool.pool.Get().(*bufio.Writer) + buf.Reset(w) + return buf +} + +// Put puts the bufio.Writer back into the pool. +func (bufPool *BufioWriterPool) Put(b *bufio.Writer) { + b.Reset(nil) + bufPool.pool.Put(b) +} + +// NewWriteCloserWrapper returns a wrapper which puts the bufio.Writer back +// into the pool and closes the writer if it's an io.Writecloser. +func (bufPool *BufioWriterPool) NewWriteCloserWrapper(buf *bufio.Writer, w io.Writer) io.WriteCloser { + return ioutils.NewWriteCloserWrapper(w, func() error { + buf.Flush() + if writeCloser, ok := w.(io.WriteCloser); ok { + writeCloser.Close() + } + bufPool.Put(buf) + return nil + }) +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/pools/pools_test.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/pools/pools_test.go new file mode 100644 index 0000000..7868980 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/pools/pools_test.go @@ -0,0 +1,162 @@ +package pools + +import ( + "bufio" + "bytes" + "io" + "strings" + "testing" +) + +func TestBufioReaderPoolGetWithNoReaderShouldCreateOne(t *testing.T) { + reader := BufioReader32KPool.Get(nil) + if reader == nil { + t.Fatalf("BufioReaderPool should have create a bufio.Reader but did not.") + } +} + +func TestBufioReaderPoolPutAndGet(t *testing.T) { + sr := bufio.NewReader(strings.NewReader("foobar")) + reader := BufioReader32KPool.Get(sr) + if reader == nil { + t.Fatalf("BufioReaderPool should not return a nil reader.") + } + // verify the first 3 byte + buf1 := make([]byte, 3) + _, err := reader.Read(buf1) + if err != nil { + t.Fatal(err) + } + if actual := string(buf1); actual != "foo" { + t.Fatalf("The first letter should have been 'foo' but was %v", actual) + } + BufioReader32KPool.Put(reader) + // Try to read the next 3 bytes + _, err = sr.Read(make([]byte, 3)) + if err == nil || err != io.EOF { + t.Fatalf("The buffer should have been empty, issue an EOF error.") + } +} + +type simpleReaderCloser struct { + io.Reader + closed bool +} + +func (r *simpleReaderCloser) Close() error { + r.closed = true + return nil +} + +func TestNewReadCloserWrapperWithAReadCloser(t *testing.T) { + br := bufio.NewReader(strings.NewReader("")) + sr := &simpleReaderCloser{ + Reader: strings.NewReader("foobar"), + closed: false, + } + reader := BufioReader32KPool.NewReadCloserWrapper(br, sr) + if reader == nil { + t.Fatalf("NewReadCloserWrapper should not return a nil reader.") + } + // Verify the content of reader + buf := make([]byte, 3) + _, err := reader.Read(buf) + if err != nil { + t.Fatal(err) + } + if actual := string(buf); actual != "foo" { + t.Fatalf("The first 3 letter should have been 'foo' but were %v", actual) + } + reader.Close() + // Read 3 more bytes "bar" + _, err = reader.Read(buf) + if err != nil { + t.Fatal(err) + } + if actual := string(buf); actual != "bar" { + t.Fatalf("The first 3 letter should have been 'bar' but were %v", actual) + } + if !sr.closed { + t.Fatalf("The ReaderCloser should have been closed, it is not.") + } +} + +func TestBufioWriterPoolGetWithNoReaderShouldCreateOne(t *testing.T) { + writer := BufioWriter32KPool.Get(nil) + if writer == nil { + t.Fatalf("BufioWriterPool should have create a bufio.Writer but did not.") + } +} + +func TestBufioWriterPoolPutAndGet(t *testing.T) { + buf := new(bytes.Buffer) + bw := bufio.NewWriter(buf) + writer := BufioWriter32KPool.Get(bw) + if writer == nil { + t.Fatalf("BufioReaderPool should not return a nil writer.") + } + written, err := writer.Write([]byte("foobar")) + if err != nil { + t.Fatal(err) + } + if written != 6 { + t.Fatalf("Should have written 6 bytes, but wrote %v bytes", written) + } + // Make sure we Flush all the way ? + writer.Flush() + bw.Flush() + if len(buf.Bytes()) != 6 { + t.Fatalf("The buffer should contain 6 bytes ('foobar') but contains %v ('%v')", buf.Bytes(), string(buf.Bytes())) + } + // Reset the buffer + buf.Reset() + BufioWriter32KPool.Put(writer) + // Try to write something + written, err = writer.Write([]byte("barfoo")) + if err != nil { + t.Fatal(err) + } + // If we now try to flush it, it should panic (the writer is nil) + // recover it + defer func() { + if r := recover(); r == nil { + t.Fatal("Trying to flush the writter should have 'paniced', did not.") + } + }() + writer.Flush() +} + +type simpleWriterCloser struct { + io.Writer + closed bool +} + +func (r *simpleWriterCloser) Close() error { + r.closed = true + return nil +} + +func TestNewWriteCloserWrapperWithAWriteCloser(t *testing.T) { + buf := new(bytes.Buffer) + bw := bufio.NewWriter(buf) + sw := &simpleWriterCloser{ + Writer: new(bytes.Buffer), + closed: false, + } + bw.Flush() + writer := BufioWriter32KPool.NewWriteCloserWrapper(bw, sw) + if writer == nil { + t.Fatalf("BufioReaderPool should not return a nil writer.") + } + written, err := writer.Write([]byte("foobar")) + if err != nil { + t.Fatal(err) + } + if written != 6 { + t.Fatalf("Should have written 6 bytes, but wrote %v bytes", written) + } + writer.Close() + if !sw.closed { + t.Fatalf("The ReaderCloser should have been closed, it is not.") + } +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/promise/promise.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/promise/promise.go new file mode 100644 index 0000000..dd52b90 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/promise/promise.go @@ -0,0 +1,11 @@ +package promise + +// Go is a basic promise implementation: it wraps calls a function in a goroutine, +// and returns a channel which will later return the function's return value. +func Go(f func() error) chan error { + ch := make(chan error, 1) + go func() { + ch <- f() + }() + return ch +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/stdcopy/stdcopy.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/stdcopy/stdcopy.go new file mode 100644 index 0000000..684b4d4 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/stdcopy/stdcopy.go @@ -0,0 +1,168 @@ +package stdcopy + +import ( + "encoding/binary" + "errors" + "io" + + "github.com/Sirupsen/logrus" +) + +const ( + StdWriterPrefixLen = 8 + StdWriterFdIndex = 0 + StdWriterSizeIndex = 4 +) + +type StdType [StdWriterPrefixLen]byte + +var ( + Stdin StdType = StdType{0: 0} + Stdout StdType = StdType{0: 1} + Stderr StdType = StdType{0: 2} +) + +type StdWriter struct { + io.Writer + prefix StdType + sizeBuf []byte +} + +func (w *StdWriter) Write(buf []byte) (n int, err error) { + var n1, n2 int + if w == nil || w.Writer == nil { + return 0, errors.New("Writer not instantiated") + } + binary.BigEndian.PutUint32(w.prefix[4:], uint32(len(buf))) + n1, err = w.Writer.Write(w.prefix[:]) + if err != nil { + n = n1 - StdWriterPrefixLen + } else { + n2, err = w.Writer.Write(buf) + n = n1 + n2 - StdWriterPrefixLen + } + if n < 0 { + n = 0 + } + return +} + +// NewStdWriter instantiates a new Writer. +// Everything written to it will be encapsulated using a custom format, +// and written to the underlying `w` stream. +// This allows multiple write streams (e.g. stdout and stderr) to be muxed into a single connection. +// `t` indicates the id of the stream to encapsulate. +// It can be stdcopy.Stdin, stdcopy.Stdout, stdcopy.Stderr. +func NewStdWriter(w io.Writer, t StdType) *StdWriter { + return &StdWriter{ + Writer: w, + prefix: t, + sizeBuf: make([]byte, 4), + } +} + +var ErrInvalidStdHeader = errors.New("Unrecognized input header") + +// StdCopy is a modified version of io.Copy. +// +// StdCopy will demultiplex `src`, assuming that it contains two streams, +// previously multiplexed together using a StdWriter instance. +// As it reads from `src`, StdCopy will write to `dstout` and `dsterr`. +// +// StdCopy will read until it hits EOF on `src`. It will then return a nil error. +// In other words: if `err` is non nil, it indicates a real underlying error. +// +// `written` will hold the total number of bytes written to `dstout` and `dsterr`. +func StdCopy(dstout, dsterr io.Writer, src io.Reader) (written int64, err error) { + var ( + buf = make([]byte, 32*1024+StdWriterPrefixLen+1) + bufLen = len(buf) + nr, nw int + er, ew error + out io.Writer + frameSize int + ) + + for { + // Make sure we have at least a full header + for nr < StdWriterPrefixLen { + var nr2 int + nr2, er = src.Read(buf[nr:]) + nr += nr2 + if er == io.EOF { + if nr < StdWriterPrefixLen { + logrus.Debugf("Corrupted prefix: %v", buf[:nr]) + return written, nil + } + break + } + if er != nil { + logrus.Debugf("Error reading header: %s", er) + return 0, er + } + } + + // Check the first byte to know where to write + switch buf[StdWriterFdIndex] { + case 0: + fallthrough + case 1: + // Write on stdout + out = dstout + case 2: + // Write on stderr + out = dsterr + default: + logrus.Debugf("Error selecting output fd: (%d)", buf[StdWriterFdIndex]) + return 0, ErrInvalidStdHeader + } + + // Retrieve the size of the frame + frameSize = int(binary.BigEndian.Uint32(buf[StdWriterSizeIndex : StdWriterSizeIndex+4])) + logrus.Debugf("framesize: %d", frameSize) + + // Check if the buffer is big enough to read the frame. + // Extend it if necessary. + if frameSize+StdWriterPrefixLen > bufLen { + logrus.Debugf("Extending buffer cap by %d (was %d)", frameSize+StdWriterPrefixLen-bufLen+1, len(buf)) + buf = append(buf, make([]byte, frameSize+StdWriterPrefixLen-bufLen+1)...) + bufLen = len(buf) + } + + // While the amount of bytes read is less than the size of the frame + header, we keep reading + for nr < frameSize+StdWriterPrefixLen { + var nr2 int + nr2, er = src.Read(buf[nr:]) + nr += nr2 + if er == io.EOF { + if nr < frameSize+StdWriterPrefixLen { + logrus.Debugf("Corrupted frame: %v", buf[StdWriterPrefixLen:nr]) + return written, nil + } + break + } + if er != nil { + logrus.Debugf("Error reading frame: %s", er) + return 0, er + } + } + + // Write the retrieved frame (without header) + nw, ew = out.Write(buf[StdWriterPrefixLen : frameSize+StdWriterPrefixLen]) + if ew != nil { + logrus.Debugf("Error writing frame: %s", ew) + return 0, ew + } + // If the frame has not been fully written: error + if nw != frameSize { + logrus.Debugf("Error Short Write: (%d on %d)", nw, frameSize) + return 0, io.ErrShortWrite + } + written += int64(nw) + + // Move the rest of the buffer to the beginning + copy(buf, buf[frameSize+StdWriterPrefixLen:]) + // Move the index + nr -= frameSize + StdWriterPrefixLen + } +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/stdcopy/stdcopy_test.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/stdcopy/stdcopy_test.go new file mode 100644 index 0000000..a9fd73a --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/stdcopy/stdcopy_test.go @@ -0,0 +1,85 @@ +package stdcopy + +import ( + "bytes" + "io/ioutil" + "strings" + "testing" +) + +func TestNewStdWriter(t *testing.T) { + writer := NewStdWriter(ioutil.Discard, Stdout) + if writer == nil { + t.Fatalf("NewStdWriter with an invalid StdType should not return nil.") + } +} + +func TestWriteWithUnitializedStdWriter(t *testing.T) { + writer := StdWriter{ + Writer: nil, + prefix: Stdout, + sizeBuf: make([]byte, 4), + } + n, err := writer.Write([]byte("Something here")) + if n != 0 || err == nil { + t.Fatalf("Should fail when given an uncomplete or uninitialized StdWriter") + } +} + +func TestWriteWithNilBytes(t *testing.T) { + writer := NewStdWriter(ioutil.Discard, Stdout) + n, err := writer.Write(nil) + if err != nil { + t.Fatalf("Shouldn't have fail when given no data") + } + if n > 0 { + t.Fatalf("Write should have written 0 byte, but has written %d", n) + } +} + +func TestWrite(t *testing.T) { + writer := NewStdWriter(ioutil.Discard, Stdout) + data := []byte("Test StdWrite.Write") + n, err := writer.Write(data) + if err != nil { + t.Fatalf("Error while writing with StdWrite") + } + if n != len(data) { + t.Fatalf("Write should have writen %d byte but wrote %d.", len(data), n) + } +} + +func TestStdCopyWithInvalidInputHeader(t *testing.T) { + dstOut := NewStdWriter(ioutil.Discard, Stdout) + dstErr := NewStdWriter(ioutil.Discard, Stderr) + src := strings.NewReader("Invalid input") + _, err := StdCopy(dstOut, dstErr, src) + if err == nil { + t.Fatal("StdCopy with invalid input header should fail.") + } +} + +func TestStdCopyWithCorruptedPrefix(t *testing.T) { + data := []byte{0x01, 0x02, 0x03} + src := bytes.NewReader(data) + written, err := StdCopy(nil, nil, src) + if err != nil { + t.Fatalf("StdCopy should not return an error with corrupted prefix.") + } + if written != 0 { + t.Fatalf("StdCopy should have written 0, but has written %d", written) + } +} + +func BenchmarkWrite(b *testing.B) { + w := NewStdWriter(ioutil.Discard, Stdout) + data := []byte("Test line for testing stdwriter performance\n") + data = bytes.Repeat(data, 100) + b.SetBytes(int64(len(data))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + if _, err := w.Write(data); err != nil { + b.Fatal(err) + } + } +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/system/errors.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/errors.go new file mode 100644 index 0000000..6304518 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/errors.go @@ -0,0 +1,9 @@ +package system + +import ( + "errors" +) + +var ( + ErrNotSupportedPlatform = errors.New("platform and architecture is not supported") +) diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/system/events_windows.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/events_windows.go new file mode 100644 index 0000000..23f7c61 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/events_windows.go @@ -0,0 +1,83 @@ +package system + +// This file implements syscalls for Win32 events which are not implemented +// in golang. + +import ( + "syscall" + "unsafe" +) + +const ( + EVENT_ALL_ACCESS = 0x1F0003 + EVENT_MODIFY_STATUS = 0x0002 +) + +var ( + procCreateEvent = modkernel32.NewProc("CreateEventW") + procOpenEvent = modkernel32.NewProc("OpenEventW") + procSetEvent = modkernel32.NewProc("SetEvent") + procResetEvent = modkernel32.NewProc("ResetEvent") + procPulseEvent = modkernel32.NewProc("PulseEvent") +) + +func CreateEvent(eventAttributes *syscall.SecurityAttributes, manualReset bool, initialState bool, name string) (handle syscall.Handle, err error) { + namep, _ := syscall.UTF16PtrFromString(name) + var _p1 uint32 = 0 + if manualReset { + _p1 = 1 + } + var _p2 uint32 = 0 + if initialState { + _p2 = 1 + } + r0, _, e1 := procCreateEvent.Call(uintptr(unsafe.Pointer(eventAttributes)), uintptr(_p1), uintptr(_p2), uintptr(unsafe.Pointer(namep))) + use(unsafe.Pointer(namep)) + handle = syscall.Handle(r0) + if handle == syscall.InvalidHandle { + err = e1 + } + return +} + +func OpenEvent(desiredAccess uint32, inheritHandle bool, name string) (handle syscall.Handle, err error) { + namep, _ := syscall.UTF16PtrFromString(name) + var _p1 uint32 = 0 + if inheritHandle { + _p1 = 1 + } + r0, _, e1 := procOpenEvent.Call(uintptr(desiredAccess), uintptr(_p1), uintptr(unsafe.Pointer(namep))) + use(unsafe.Pointer(namep)) + handle = syscall.Handle(r0) + if handle == syscall.InvalidHandle { + err = e1 + } + return +} + +func SetEvent(handle syscall.Handle) (err error) { + return setResetPulse(handle, procSetEvent) +} + +func ResetEvent(handle syscall.Handle) (err error) { + return setResetPulse(handle, procResetEvent) +} + +func PulseEvent(handle syscall.Handle) (err error) { + return setResetPulse(handle, procPulseEvent) +} + +func setResetPulse(handle syscall.Handle, proc *syscall.LazyProc) (err error) { + r0, _, _ := proc.Call(uintptr(handle)) + if r0 != 0 { + err = syscall.Errno(r0) + } + return +} + +var temp unsafe.Pointer + +// use ensures a variable is kept alive without the GC freeing while still needed +func use(p unsafe.Pointer) { + temp = p +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/system/filesys.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/filesys.go new file mode 100644 index 0000000..e1f70e8 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/filesys.go @@ -0,0 +1,11 @@ +// +build !windows + +package system + +import ( + "os" +) + +func MkdirAll(path string, perm os.FileMode) error { + return os.MkdirAll(path, perm) +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/system/filesys_windows.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/filesys_windows.go new file mode 100644 index 0000000..90b5006 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/filesys_windows.go @@ -0,0 +1,64 @@ +// +build windows + +package system + +import ( + "os" + "regexp" + "syscall" +) + +// MkdirAll implementation that is volume path aware for Windows. +func MkdirAll(path string, perm os.FileMode) error { + if re := regexp.MustCompile(`^\\\\\?\\Volume{[a-z0-9-]+}$`); re.MatchString(path) { + return nil + } + + // The rest of this method is copied from os.MkdirAll and should be kept + // as-is to ensure compatibility. + + // Fast path: if we can tell whether path is a directory or file, stop with success or error. + dir, err := os.Stat(path) + if err == nil { + if dir.IsDir() { + return nil + } + return &os.PathError{ + Op: "mkdir", + Path: path, + Err: syscall.ENOTDIR, + } + } + + // Slow path: make sure parent exists and then call Mkdir for path. + i := len(path) + for i > 0 && os.IsPathSeparator(path[i-1]) { // Skip trailing path separator. + i-- + } + + j := i + for j > 0 && !os.IsPathSeparator(path[j-1]) { // Scan backward over element. + j-- + } + + if j > 1 { + // Create parent + err = MkdirAll(path[0:j-1], perm) + if err != nil { + return err + } + } + + // Parent now exists; invoke Mkdir and use its result. + err = os.Mkdir(path, perm) + if err != nil { + // Handle arguments like "foo/." by + // double-checking that directory doesn't exist. + dir, err1 := os.Lstat(path) + if err1 == nil && dir.IsDir() { + return nil + } + return err + } + return nil +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/system/lstat.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/lstat.go new file mode 100644 index 0000000..d0e43b3 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/lstat.go @@ -0,0 +1,19 @@ +// +build !windows + +package system + +import ( + "syscall" +) + +// Lstat takes a path to a file and returns +// a system.Stat_t type pertaining to that file. +// +// Throws an error if the file does not exist +func Lstat(path string) (*Stat_t, error) { + s := &syscall.Stat_t{} + if err := syscall.Lstat(path, s); err != nil { + return nil, err + } + return fromStatT(s) +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/system/lstat_test.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/lstat_test.go new file mode 100644 index 0000000..6bac492 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/lstat_test.go @@ -0,0 +1,28 @@ +package system + +import ( + "os" + "testing" +) + +// TestLstat tests Lstat for existing and non existing files +func TestLstat(t *testing.T) { + file, invalid, _, dir := prepareFiles(t) + defer os.RemoveAll(dir) + + statFile, err := Lstat(file) + if err != nil { + t.Fatal(err) + } + if statFile == nil { + t.Fatal("returned empty stat for existing file") + } + + statInvalid, err := Lstat(invalid) + if err == nil { + t.Fatal("did not return error for non-existing file") + } + if statInvalid != nil { + t.Fatal("returned non-nil stat for non-existing file") + } +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/system/lstat_windows.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/lstat_windows.go new file mode 100644 index 0000000..eee1be2 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/lstat_windows.go @@ -0,0 +1,29 @@ +// +build windows + +package system + +import ( + "os" +) + +// Some explanation for my own sanity, and hopefully maintainers in the +// future. +// +// Lstat calls os.Lstat to get a fileinfo interface back. +// This is then copied into our own locally defined structure. +// Note the Linux version uses fromStatT to do the copy back, +// but that not strictly necessary when already in an OS specific module. + +func Lstat(path string) (*Stat_t, error) { + fi, err := os.Lstat(path) + if err != nil { + return nil, err + } + + return &Stat_t{ + name: fi.Name(), + size: fi.Size(), + mode: fi.Mode(), + modTime: fi.ModTime(), + isDir: fi.IsDir()}, nil +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/system/meminfo.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/meminfo.go new file mode 100644 index 0000000..3b6e947 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/meminfo.go @@ -0,0 +1,17 @@ +package system + +// MemInfo contains memory statistics of the host system. +type MemInfo struct { + // Total usable RAM (i.e. physical RAM minus a few reserved bits and the + // kernel binary code). + MemTotal int64 + + // Amount of free memory. + MemFree int64 + + // Total amount of swap space available. + SwapTotal int64 + + // Amount of swap space that is currently unused. + SwapFree int64 +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/system/meminfo_linux.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/meminfo_linux.go new file mode 100644 index 0000000..e2ca140 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/meminfo_linux.go @@ -0,0 +1,71 @@ +package system + +import ( + "bufio" + "errors" + "io" + "os" + "strconv" + "strings" + + "github.com/docker/docker/pkg/units" +) + +var ( + ErrMalformed = errors.New("malformed file") +) + +// ReadMemInfo retrieves memory statistics of the host system and returns a +// MemInfo type. +func ReadMemInfo() (*MemInfo, error) { + file, err := os.Open("/proc/meminfo") + if err != nil { + return nil, err + } + defer file.Close() + return parseMemInfo(file) +} + +// parseMemInfo parses the /proc/meminfo file into +// a MemInfo object given a io.Reader to the file. +// +// Throws error if there are problems reading from the file +func parseMemInfo(reader io.Reader) (*MemInfo, error) { + meminfo := &MemInfo{} + scanner := bufio.NewScanner(reader) + for scanner.Scan() { + // Expected format: ["MemTotal:", "1234", "kB"] + parts := strings.Fields(scanner.Text()) + + // Sanity checks: Skip malformed entries. + if len(parts) < 3 || parts[2] != "kB" { + continue + } + + // Convert to bytes. + size, err := strconv.Atoi(parts[1]) + if err != nil { + continue + } + bytes := int64(size) * units.KiB + + switch parts[0] { + case "MemTotal:": + meminfo.MemTotal = bytes + case "MemFree:": + meminfo.MemFree = bytes + case "SwapTotal:": + meminfo.SwapTotal = bytes + case "SwapFree:": + meminfo.SwapFree = bytes + } + + } + + // Handle errors that may have occurred during the reading of the file. + if err := scanner.Err(); err != nil { + return nil, err + } + + return meminfo, nil +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/system/meminfo_linux_test.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/meminfo_linux_test.go new file mode 100644 index 0000000..10ddf79 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/meminfo_linux_test.go @@ -0,0 +1,38 @@ +package system + +import ( + "strings" + "testing" + + "github.com/docker/docker/pkg/units" +) + +// TestMemInfo tests parseMemInfo with a static meminfo string +func TestMemInfo(t *testing.T) { + const input = ` + MemTotal: 1 kB + MemFree: 2 kB + SwapTotal: 3 kB + SwapFree: 4 kB + Malformed1: + Malformed2: 1 + Malformed3: 2 MB + Malformed4: X kB + ` + meminfo, err := parseMemInfo(strings.NewReader(input)) + if err != nil { + t.Fatal(err) + } + if meminfo.MemTotal != 1*units.KiB { + t.Fatalf("Unexpected MemTotal: %d", meminfo.MemTotal) + } + if meminfo.MemFree != 2*units.KiB { + t.Fatalf("Unexpected MemFree: %d", meminfo.MemFree) + } + if meminfo.SwapTotal != 3*units.KiB { + t.Fatalf("Unexpected SwapTotal: %d", meminfo.SwapTotal) + } + if meminfo.SwapFree != 4*units.KiB { + t.Fatalf("Unexpected SwapFree: %d", meminfo.SwapFree) + } +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/system/meminfo_unsupported.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/meminfo_unsupported.go new file mode 100644 index 0000000..604d338 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/meminfo_unsupported.go @@ -0,0 +1,7 @@ +// +build !linux,!windows + +package system + +func ReadMemInfo() (*MemInfo, error) { + return nil, ErrNotSupportedPlatform +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/system/meminfo_windows.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/meminfo_windows.go new file mode 100644 index 0000000..d466425 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/meminfo_windows.go @@ -0,0 +1,44 @@ +package system + +import ( + "syscall" + "unsafe" +) + +var ( + modkernel32 = syscall.NewLazyDLL("kernel32.dll") + + procGlobalMemoryStatusEx = modkernel32.NewProc("GlobalMemoryStatusEx") +) + +// https://msdn.microsoft.com/en-us/library/windows/desktop/aa366589(v=vs.85).aspx +// https://msdn.microsoft.com/en-us/library/windows/desktop/aa366770(v=vs.85).aspx +type memorystatusex struct { + dwLength uint32 + dwMemoryLoad uint32 + ullTotalPhys uint64 + ullAvailPhys uint64 + ullTotalPageFile uint64 + ullAvailPageFile uint64 + ullTotalVirtual uint64 + ullAvailVirtual uint64 + ullAvailExtendedVirtual uint64 +} + +// ReadMemInfo retrieves memory statistics of the host system and returns a +// MemInfo type. +func ReadMemInfo() (*MemInfo, error) { + msi := &memorystatusex{ + dwLength: 64, + } + r1, _, _ := procGlobalMemoryStatusEx.Call(uintptr(unsafe.Pointer(msi))) + if r1 == 0 { + return &MemInfo{}, nil + } + return &MemInfo{ + MemTotal: int64(msi.ullTotalPhys), + MemFree: int64(msi.ullAvailPhys), + SwapTotal: int64(msi.ullTotalPageFile), + SwapFree: int64(msi.ullAvailPageFile), + }, nil +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/system/mknod.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/mknod.go new file mode 100644 index 0000000..26617eb --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/mknod.go @@ -0,0 +1,20 @@ +// +build !windows + +package system + +import ( + "syscall" +) + +// Mknod creates a filesystem node (file, device special file or named pipe) named path +// with attributes specified by mode and dev +func Mknod(path string, mode uint32, dev int) error { + return syscall.Mknod(path, mode, dev) +} + +// Linux device nodes are a bit weird due to backwards compat with 16 bit device nodes. +// They are, from low to high: the lower 8 bits of the minor, then 12 bits of the major, +// then the top 12 bits of the minor +func Mkdev(major int64, minor int64) uint32 { + return uint32(((minor & 0xfff00) << 12) | ((major & 0xfff) << 8) | (minor & 0xff)) +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/system/mknod_windows.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/mknod_windows.go new file mode 100644 index 0000000..1811542 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/mknod_windows.go @@ -0,0 +1,11 @@ +// +build windows + +package system + +func Mknod(path string, mode uint32, dev int) error { + return ErrNotSupportedPlatform +} + +func Mkdev(major int64, minor int64) uint32 { + panic("Mkdev not implemented on Windows.") +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/system/stat.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/stat.go new file mode 100644 index 0000000..e2ecfe5 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/stat.go @@ -0,0 +1,46 @@ +// +build !windows + +package system + +import ( + "syscall" +) + +// Stat_t type contains status of a file. It contains metadata +// like permission, owner, group, size, etc about a file +type Stat_t struct { + mode uint32 + uid uint32 + gid uint32 + rdev uint64 + size int64 + mtim syscall.Timespec +} + +func (s Stat_t) Mode() uint32 { + return s.mode +} + +func (s Stat_t) Uid() uint32 { + return s.uid +} + +func (s Stat_t) Gid() uint32 { + return s.gid +} + +func (s Stat_t) Rdev() uint64 { + return s.rdev +} + +func (s Stat_t) Size() int64 { + return s.size +} + +func (s Stat_t) Mtim() syscall.Timespec { + return s.mtim +} + +func (s Stat_t) GetLastModification() syscall.Timespec { + return s.Mtim() +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/system/stat_linux.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/stat_linux.go new file mode 100644 index 0000000..80262d9 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/stat_linux.go @@ -0,0 +1,33 @@ +package system + +import ( + "syscall" +) + +// fromStatT converts a syscall.Stat_t type to a system.Stat_t type +func fromStatT(s *syscall.Stat_t) (*Stat_t, error) { + return &Stat_t{size: s.Size, + mode: s.Mode, + uid: s.Uid, + gid: s.Gid, + rdev: s.Rdev, + mtim: s.Mtim}, nil +} + +// FromStatT exists only on linux, and loads a system.Stat_t from a +// syscal.Stat_t. +func FromStatT(s *syscall.Stat_t) (*Stat_t, error) { + return fromStatT(s) +} + +// Stat takes a path to a file and returns +// a system.Stat_t type pertaining to that file. +// +// Throws an error if the file does not exist +func Stat(path string) (*Stat_t, error) { + s := &syscall.Stat_t{} + if err := syscall.Stat(path, s); err != nil { + return nil, err + } + return fromStatT(s) +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/system/stat_test.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/stat_test.go new file mode 100644 index 0000000..4534129 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/stat_test.go @@ -0,0 +1,37 @@ +package system + +import ( + "os" + "syscall" + "testing" +) + +// TestFromStatT tests fromStatT for a tempfile +func TestFromStatT(t *testing.T) { + file, _, _, dir := prepareFiles(t) + defer os.RemoveAll(dir) + + stat := &syscall.Stat_t{} + err := syscall.Lstat(file, stat) + + s, err := fromStatT(stat) + if err != nil { + t.Fatal(err) + } + + if stat.Mode != s.Mode() { + t.Fatal("got invalid mode") + } + if stat.Uid != s.Uid() { + t.Fatal("got invalid uid") + } + if stat.Gid != s.Gid() { + t.Fatal("got invalid gid") + } + if stat.Rdev != s.Rdev() { + t.Fatal("got invalid rdev") + } + if stat.Mtim != s.Mtim() { + t.Fatal("got invalid mtim") + } +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/system/stat_unsupported.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/stat_unsupported.go new file mode 100644 index 0000000..7e0d034 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/stat_unsupported.go @@ -0,0 +1,17 @@ +// +build !linux,!windows + +package system + +import ( + "syscall" +) + +// fromStatT creates a system.Stat_t type from a syscall.Stat_t type +func fromStatT(s *syscall.Stat_t) (*Stat_t, error) { + return &Stat_t{size: s.Size, + mode: uint32(s.Mode), + uid: s.Uid, + gid: s.Gid, + rdev: uint64(s.Rdev), + mtim: s.Mtimespec}, nil +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/system/stat_windows.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/stat_windows.go new file mode 100644 index 0000000..b1fd39e --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/stat_windows.go @@ -0,0 +1,36 @@ +// +build windows + +package system + +import ( + "os" + "time" +) + +type Stat_t struct { + name string + size int64 + mode os.FileMode + modTime time.Time + isDir bool +} + +func (s Stat_t) Name() string { + return s.name +} + +func (s Stat_t) Size() int64 { + return s.size +} + +func (s Stat_t) Mode() os.FileMode { + return s.mode +} + +func (s Stat_t) ModTime() time.Time { + return s.modTime +} + +func (s Stat_t) IsDir() bool { + return s.isDir +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/system/umask.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/umask.go new file mode 100644 index 0000000..fddbecd --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/umask.go @@ -0,0 +1,11 @@ +// +build !windows + +package system + +import ( + "syscall" +) + +func Umask(newmask int) (oldmask int, err error) { + return syscall.Umask(newmask), nil +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/system/umask_windows.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/umask_windows.go new file mode 100644 index 0000000..3be563f --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/umask_windows.go @@ -0,0 +1,8 @@ +// +build windows + +package system + +func Umask(newmask int) (oldmask int, err error) { + // should not be called on cli code path + return 0, ErrNotSupportedPlatform +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/system/utimes_darwin.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/utimes_darwin.go new file mode 100644 index 0000000..4c6002f --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/utimes_darwin.go @@ -0,0 +1,11 @@ +package system + +import "syscall" + +func LUtimesNano(path string, ts []syscall.Timespec) error { + return ErrNotSupportedPlatform +} + +func UtimesNano(path string, ts []syscall.Timespec) error { + return syscall.UtimesNano(path, ts) +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/system/utimes_freebsd.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/utimes_freebsd.go new file mode 100644 index 0000000..ceaa044 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/utimes_freebsd.go @@ -0,0 +1,24 @@ +package system + +import ( + "syscall" + "unsafe" +) + +func LUtimesNano(path string, ts []syscall.Timespec) error { + var _path *byte + _path, err := syscall.BytePtrFromString(path) + if err != nil { + return err + } + + if _, _, err := syscall.Syscall(syscall.SYS_LUTIMES, uintptr(unsafe.Pointer(_path)), uintptr(unsafe.Pointer(&ts[0])), 0); err != 0 && err != syscall.ENOSYS { + return err + } + + return nil +} + +func UtimesNano(path string, ts []syscall.Timespec) error { + return syscall.UtimesNano(path, ts) +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/system/utimes_linux.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/utimes_linux.go new file mode 100644 index 0000000..8f90298 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/utimes_linux.go @@ -0,0 +1,28 @@ +package system + +import ( + "syscall" + "unsafe" +) + +func LUtimesNano(path string, ts []syscall.Timespec) error { + // These are not currently available in syscall + AT_FDCWD := -100 + AT_SYMLINK_NOFOLLOW := 0x100 + + var _path *byte + _path, err := syscall.BytePtrFromString(path) + if err != nil { + return err + } + + if _, _, err := syscall.Syscall6(syscall.SYS_UTIMENSAT, uintptr(AT_FDCWD), uintptr(unsafe.Pointer(_path)), uintptr(unsafe.Pointer(&ts[0])), uintptr(AT_SYMLINK_NOFOLLOW), 0, 0); err != 0 && err != syscall.ENOSYS { + return err + } + + return nil +} + +func UtimesNano(path string, ts []syscall.Timespec) error { + return syscall.UtimesNano(path, ts) +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/system/utimes_test.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/utimes_test.go new file mode 100644 index 0000000..350cce1 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/utimes_test.go @@ -0,0 +1,66 @@ +package system + +import ( + "io/ioutil" + "os" + "path/filepath" + "syscall" + "testing" +) + +// prepareFiles creates files for testing in the temp directory +func prepareFiles(t *testing.T) (string, string, string, string) { + dir, err := ioutil.TempDir("", "docker-system-test") + if err != nil { + t.Fatal(err) + } + + file := filepath.Join(dir, "exist") + if err := ioutil.WriteFile(file, []byte("hello"), 0644); err != nil { + t.Fatal(err) + } + + invalid := filepath.Join(dir, "doesnt-exist") + + symlink := filepath.Join(dir, "symlink") + if err := os.Symlink(file, symlink); err != nil { + t.Fatal(err) + } + + return file, invalid, symlink, dir +} + +func TestLUtimesNano(t *testing.T) { + file, invalid, symlink, dir := prepareFiles(t) + defer os.RemoveAll(dir) + + before, err := os.Stat(file) + if err != nil { + t.Fatal(err) + } + + ts := []syscall.Timespec{{0, 0}, {0, 0}} + if err := LUtimesNano(symlink, ts); err != nil { + t.Fatal(err) + } + + symlinkInfo, err := os.Lstat(symlink) + if err != nil { + t.Fatal(err) + } + if before.ModTime().Unix() == symlinkInfo.ModTime().Unix() { + t.Fatal("The modification time of the symlink should be different") + } + + fileInfo, err := os.Stat(file) + if err != nil { + t.Fatal(err) + } + if before.ModTime().Unix() != fileInfo.ModTime().Unix() { + t.Fatal("The modification time of the file should be same") + } + + if err := LUtimesNano(invalid, ts); err == nil { + t.Fatal("Doesn't return an error on a non-existing file") + } +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/system/utimes_unsupported.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/utimes_unsupported.go new file mode 100644 index 0000000..adf2734 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/utimes_unsupported.go @@ -0,0 +1,13 @@ +// +build !linux,!freebsd,!darwin + +package system + +import "syscall" + +func LUtimesNano(path string, ts []syscall.Timespec) error { + return ErrNotSupportedPlatform +} + +func UtimesNano(path string, ts []syscall.Timespec) error { + return ErrNotSupportedPlatform +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/system/xattrs_linux.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/xattrs_linux.go new file mode 100644 index 0000000..00edb20 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/xattrs_linux.go @@ -0,0 +1,59 @@ +package system + +import ( + "syscall" + "unsafe" +) + +// Returns a nil slice and nil error if the xattr is not set +func Lgetxattr(path string, attr string) ([]byte, error) { + pathBytes, err := syscall.BytePtrFromString(path) + if err != nil { + return nil, err + } + attrBytes, err := syscall.BytePtrFromString(attr) + if err != nil { + return nil, err + } + + dest := make([]byte, 128) + destBytes := unsafe.Pointer(&dest[0]) + sz, _, errno := syscall.Syscall6(syscall.SYS_LGETXATTR, uintptr(unsafe.Pointer(pathBytes)), uintptr(unsafe.Pointer(attrBytes)), uintptr(destBytes), uintptr(len(dest)), 0, 0) + if errno == syscall.ENODATA { + return nil, nil + } + if errno == syscall.ERANGE { + dest = make([]byte, sz) + destBytes := unsafe.Pointer(&dest[0]) + sz, _, errno = syscall.Syscall6(syscall.SYS_LGETXATTR, uintptr(unsafe.Pointer(pathBytes)), uintptr(unsafe.Pointer(attrBytes)), uintptr(destBytes), uintptr(len(dest)), 0, 0) + } + if errno != 0 { + return nil, errno + } + + return dest[:sz], nil +} + +var _zero uintptr + +func Lsetxattr(path string, attr string, data []byte, flags int) error { + pathBytes, err := syscall.BytePtrFromString(path) + if err != nil { + return err + } + attrBytes, err := syscall.BytePtrFromString(attr) + if err != nil { + return err + } + var dataBytes unsafe.Pointer + if len(data) > 0 { + dataBytes = unsafe.Pointer(&data[0]) + } else { + dataBytes = unsafe.Pointer(&_zero) + } + _, _, errno := syscall.Syscall6(syscall.SYS_LSETXATTR, uintptr(unsafe.Pointer(pathBytes)), uintptr(unsafe.Pointer(attrBytes)), uintptr(dataBytes), uintptr(len(data)), uintptr(flags), 0) + if errno != 0 { + return errno + } + return nil +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/system/xattrs_unsupported.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/xattrs_unsupported.go new file mode 100644 index 0000000..0060c16 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/system/xattrs_unsupported.go @@ -0,0 +1,11 @@ +// +build !linux + +package system + +func Lgetxattr(path string, attr string) ([]byte, error) { + return nil, ErrNotSupportedPlatform +} + +func Lsetxattr(path string, attr string, data []byte, flags int) error { + return ErrNotSupportedPlatform +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/units/duration.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/units/duration.go new file mode 100644 index 0000000..c219a8a --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/units/duration.go @@ -0,0 +1,33 @@ +// Package units provides helper function to parse and print size and time units +// in human-readable format. +package units + +import ( + "fmt" + "time" +) + +// HumanDuration returns a human-readable approximation of a duration +// (eg. "About a minute", "4 hours ago", etc.). +func HumanDuration(d time.Duration) string { + if seconds := int(d.Seconds()); seconds < 1 { + return "Less than a second" + } else if seconds < 60 { + return fmt.Sprintf("%d seconds", seconds) + } else if minutes := int(d.Minutes()); minutes == 1 { + return "About a minute" + } else if minutes < 60 { + return fmt.Sprintf("%d minutes", minutes) + } else if hours := int(d.Hours()); hours == 1 { + return "About an hour" + } else if hours < 48 { + return fmt.Sprintf("%d hours", hours) + } else if hours < 24*7*2 { + return fmt.Sprintf("%d days", hours/24) + } else if hours < 24*30*3 { + return fmt.Sprintf("%d weeks", hours/24/7) + } else if hours < 24*365*2 { + return fmt.Sprintf("%d months", hours/24/30) + } + return fmt.Sprintf("%d years", int(d.Hours())/24/365) +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/units/duration_test.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/units/duration_test.go new file mode 100644 index 0000000..fcfb6b7 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/units/duration_test.go @@ -0,0 +1,46 @@ +package units + +import ( + "testing" + "time" +) + +func TestHumanDuration(t *testing.T) { + // Useful duration abstractions + day := 24 * time.Hour + week := 7 * day + month := 30 * day + year := 365 * day + + assertEquals(t, "Less than a second", HumanDuration(450*time.Millisecond)) + assertEquals(t, "47 seconds", HumanDuration(47*time.Second)) + assertEquals(t, "About a minute", HumanDuration(1*time.Minute)) + assertEquals(t, "3 minutes", HumanDuration(3*time.Minute)) + assertEquals(t, "35 minutes", HumanDuration(35*time.Minute)) + assertEquals(t, "35 minutes", HumanDuration(35*time.Minute+40*time.Second)) + assertEquals(t, "About an hour", HumanDuration(1*time.Hour)) + assertEquals(t, "About an hour", HumanDuration(1*time.Hour+45*time.Minute)) + assertEquals(t, "3 hours", HumanDuration(3*time.Hour)) + assertEquals(t, "3 hours", HumanDuration(3*time.Hour+59*time.Minute)) + assertEquals(t, "4 hours", HumanDuration(3*time.Hour+60*time.Minute)) + assertEquals(t, "24 hours", HumanDuration(24*time.Hour)) + assertEquals(t, "36 hours", HumanDuration(1*day+12*time.Hour)) + assertEquals(t, "2 days", HumanDuration(2*day)) + assertEquals(t, "7 days", HumanDuration(7*day)) + assertEquals(t, "13 days", HumanDuration(13*day+5*time.Hour)) + assertEquals(t, "2 weeks", HumanDuration(2*week)) + assertEquals(t, "2 weeks", HumanDuration(2*week+4*day)) + assertEquals(t, "3 weeks", HumanDuration(3*week)) + assertEquals(t, "4 weeks", HumanDuration(4*week)) + assertEquals(t, "4 weeks", HumanDuration(4*week+3*day)) + assertEquals(t, "4 weeks", HumanDuration(1*month)) + assertEquals(t, "6 weeks", HumanDuration(1*month+2*week)) + assertEquals(t, "8 weeks", HumanDuration(2*month)) + assertEquals(t, "3 months", HumanDuration(3*month+1*week)) + assertEquals(t, "5 months", HumanDuration(5*month+2*week)) + assertEquals(t, "13 months", HumanDuration(13*month)) + assertEquals(t, "23 months", HumanDuration(23*month)) + assertEquals(t, "24 months", HumanDuration(24*month)) + assertEquals(t, "2 years", HumanDuration(24*month+2*week)) + assertEquals(t, "3 years", HumanDuration(3*year+2*month)) +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/units/size.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/units/size.go new file mode 100644 index 0000000..2fde3b4 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/units/size.go @@ -0,0 +1,95 @@ +package units + +import ( + "fmt" + "regexp" + "strconv" + "strings" +) + +// See: http://en.wikipedia.org/wiki/Binary_prefix +const ( + // Decimal + + KB = 1000 + MB = 1000 * KB + GB = 1000 * MB + TB = 1000 * GB + PB = 1000 * TB + + // Binary + + KiB = 1024 + MiB = 1024 * KiB + GiB = 1024 * MiB + TiB = 1024 * GiB + PiB = 1024 * TiB +) + +type unitMap map[string]int64 + +var ( + decimalMap = unitMap{"k": KB, "m": MB, "g": GB, "t": TB, "p": PB} + binaryMap = unitMap{"k": KiB, "m": MiB, "g": GiB, "t": TiB, "p": PiB} + sizeRegex = regexp.MustCompile(`^(\d+)([kKmMgGtTpP])?[bB]?$`) +) + +var decimapAbbrs = []string{"B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"} +var binaryAbbrs = []string{"B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB", "YiB"} + +// CustomSize returns a human-readable approximation of a size +// using custom format. +func CustomSize(format string, size float64, base float64, _map []string) string { + i := 0 + for size >= base { + size = size / base + i++ + } + return fmt.Sprintf(format, size, _map[i]) +} + +// HumanSize returns a human-readable approximation of a size +// using SI standard (eg. "44kB", "17MB"). +func HumanSize(size float64) string { + return CustomSize("%.4g %s", size, 1000.0, decimapAbbrs) +} + +// BytesSize returns a human-readable size in bytes, kibibytes, +// mebibytes, gibibytes, or tebibytes (eg. "44kiB", "17MiB"). +func BytesSize(size float64) string { + return CustomSize("%.4g %s", size, 1024.0, binaryAbbrs) +} + +// FromHumanSize returns an integer from a human-readable specification of a +// size using SI standard (eg. "44kB", "17MB"). +func FromHumanSize(size string) (int64, error) { + return parseSize(size, decimalMap) +} + +// RAMInBytes parses a human-readable string representing an amount of RAM +// in bytes, kibibytes, mebibytes, gibibytes, or tebibytes and +// returns the number of bytes, or -1 if the string is unparseable. +// Units are case-insensitive, and the 'b' suffix is optional. +func RAMInBytes(size string) (int64, error) { + return parseSize(size, binaryMap) +} + +// Parses the human-readable size string into the amount it represents. +func parseSize(sizeStr string, uMap unitMap) (int64, error) { + matches := sizeRegex.FindStringSubmatch(sizeStr) + if len(matches) != 3 { + return -1, fmt.Errorf("invalid size: '%s'", sizeStr) + } + + size, err := strconv.ParseInt(matches[1], 10, 0) + if err != nil { + return -1, err + } + + unitPrefix := strings.ToLower(matches[2]) + if mul, ok := uMap[unitPrefix]; ok { + size *= mul + } + + return size, nil +} diff --git a/Godeps/_workspace/src/github.com/docker/docker/pkg/units/size_test.go b/Godeps/_workspace/src/github.com/docker/docker/pkg/units/size_test.go new file mode 100644 index 0000000..67c3b81 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/docker/pkg/units/size_test.go @@ -0,0 +1,108 @@ +package units + +import ( + "reflect" + "runtime" + "strings" + "testing" +) + +func TestBytesSize(t *testing.T) { + assertEquals(t, "1 KiB", BytesSize(1024)) + assertEquals(t, "1 MiB", BytesSize(1024*1024)) + assertEquals(t, "1 MiB", BytesSize(1048576)) + assertEquals(t, "2 MiB", BytesSize(2*MiB)) + assertEquals(t, "3.42 GiB", BytesSize(3.42*GiB)) + assertEquals(t, "5.372 TiB", BytesSize(5.372*TiB)) + assertEquals(t, "2.22 PiB", BytesSize(2.22*PiB)) +} + +func TestHumanSize(t *testing.T) { + assertEquals(t, "1 kB", HumanSize(1000)) + assertEquals(t, "1.024 kB", HumanSize(1024)) + assertEquals(t, "1 MB", HumanSize(1000000)) + assertEquals(t, "1.049 MB", HumanSize(1048576)) + assertEquals(t, "2 MB", HumanSize(2*MB)) + assertEquals(t, "3.42 GB", HumanSize(float64(3.42*GB))) + assertEquals(t, "5.372 TB", HumanSize(float64(5.372*TB))) + assertEquals(t, "2.22 PB", HumanSize(float64(2.22*PB))) +} + +func TestFromHumanSize(t *testing.T) { + assertSuccessEquals(t, 32, FromHumanSize, "32") + assertSuccessEquals(t, 32, FromHumanSize, "32b") + assertSuccessEquals(t, 32, FromHumanSize, "32B") + assertSuccessEquals(t, 32*KB, FromHumanSize, "32k") + assertSuccessEquals(t, 32*KB, FromHumanSize, "32K") + assertSuccessEquals(t, 32*KB, FromHumanSize, "32kb") + assertSuccessEquals(t, 32*KB, FromHumanSize, "32Kb") + assertSuccessEquals(t, 32*MB, FromHumanSize, "32Mb") + assertSuccessEquals(t, 32*GB, FromHumanSize, "32Gb") + assertSuccessEquals(t, 32*TB, FromHumanSize, "32Tb") + assertSuccessEquals(t, 32*PB, FromHumanSize, "32Pb") + + assertError(t, FromHumanSize, "") + assertError(t, FromHumanSize, "hello") + assertError(t, FromHumanSize, "-32") + assertError(t, FromHumanSize, "32.3") + assertError(t, FromHumanSize, " 32 ") + assertError(t, FromHumanSize, "32.3Kb") + assertError(t, FromHumanSize, "32 mb") + assertError(t, FromHumanSize, "32m b") + assertError(t, FromHumanSize, "32bm") +} + +func TestRAMInBytes(t *testing.T) { + assertSuccessEquals(t, 32, RAMInBytes, "32") + assertSuccessEquals(t, 32, RAMInBytes, "32b") + assertSuccessEquals(t, 32, RAMInBytes, "32B") + assertSuccessEquals(t, 32*KiB, RAMInBytes, "32k") + assertSuccessEquals(t, 32*KiB, RAMInBytes, "32K") + assertSuccessEquals(t, 32*KiB, RAMInBytes, "32kb") + assertSuccessEquals(t, 32*KiB, RAMInBytes, "32Kb") + assertSuccessEquals(t, 32*MiB, RAMInBytes, "32Mb") + assertSuccessEquals(t, 32*GiB, RAMInBytes, "32Gb") + assertSuccessEquals(t, 32*TiB, RAMInBytes, "32Tb") + assertSuccessEquals(t, 32*PiB, RAMInBytes, "32Pb") + assertSuccessEquals(t, 32*PiB, RAMInBytes, "32PB") + assertSuccessEquals(t, 32*PiB, RAMInBytes, "32P") + + assertError(t, RAMInBytes, "") + assertError(t, RAMInBytes, "hello") + assertError(t, RAMInBytes, "-32") + assertError(t, RAMInBytes, "32.3") + assertError(t, RAMInBytes, " 32 ") + assertError(t, RAMInBytes, "32.3Kb") + assertError(t, RAMInBytes, "32 mb") + assertError(t, RAMInBytes, "32m b") + assertError(t, RAMInBytes, "32bm") +} + +func assertEquals(t *testing.T, expected, actual interface{}) { + if expected != actual { + t.Errorf("Expected '%v' but got '%v'", expected, actual) + } +} + +// func that maps to the parse function signatures as testing abstraction +type parseFn func(string) (int64, error) + +// Define 'String()' for pretty-print +func (fn parseFn) String() string { + fnName := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name() + return fnName[strings.LastIndex(fnName, ".")+1:] +} + +func assertSuccessEquals(t *testing.T, expected int64, fn parseFn, arg string) { + res, err := fn(arg) + if err != nil || res != expected { + t.Errorf("%s(\"%s\") -> expected '%d' but got '%d' with error '%v'", fn, arg, expected, res, err) + } +} + +func assertError(t *testing.T, fn parseFn, arg string) { + res, err := fn(arg) + if err == nil && res != -1 { + t.Errorf("%s(\"%s\") -> expected error but got '%d'", fn, arg, res) + } +} diff --git a/Godeps/_workspace/src/github.com/docker/libtrust/CONTRIBUTING.md b/Godeps/_workspace/src/github.com/docker/libtrust/CONTRIBUTING.md new file mode 100644 index 0000000..05be0f8 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/libtrust/CONTRIBUTING.md @@ -0,0 +1,13 @@ +# Contributing to libtrust + +Want to hack on libtrust? Awesome! Here are instructions to get you +started. + +libtrust is a part of the [Docker](https://www.docker.com) project, and follows +the same rules and principles. If you're already familiar with the way +Docker does things, you'll feel right at home. + +Otherwise, go read +[Docker's contributions guidelines](https://github.com/docker/docker/blob/master/CONTRIBUTING.md). + +Happy hacking! diff --git a/Godeps/_workspace/src/github.com/docker/libtrust/LICENSE b/Godeps/_workspace/src/github.com/docker/libtrust/LICENSE new file mode 100644 index 0000000..2744858 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/libtrust/LICENSE @@ -0,0 +1,191 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + Copyright 2014 Docker, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/Godeps/_workspace/src/github.com/docker/libtrust/MAINTAINERS b/Godeps/_workspace/src/github.com/docker/libtrust/MAINTAINERS new file mode 100644 index 0000000..9768175 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/libtrust/MAINTAINERS @@ -0,0 +1,3 @@ +Solomon Hykes +Josh Hawn (github: jlhawn) +Derek McGowan (github: dmcgowan) diff --git a/Godeps/_workspace/src/github.com/docker/libtrust/README.md b/Godeps/_workspace/src/github.com/docker/libtrust/README.md new file mode 100644 index 0000000..8e7db38 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/libtrust/README.md @@ -0,0 +1,18 @@ +# libtrust + +Libtrust is library for managing authentication and authorization using public key cryptography. + +Authentication is handled using the identity attached to the public key. +Libtrust provides multiple methods to prove possession of the private key associated with an identity. + - TLS x509 certificates + - Signature verification + - Key Challenge + +Authorization and access control is managed through a distributed trust graph. +Trust servers are used as the authorities of the trust graph and allow caching portions of the graph for faster access. + +## Copyright and license + +Code and documentation copyright 2014 Docker, inc. Code released under the Apache 2.0 license. +Docs released under Creative commons. + diff --git a/Godeps/_workspace/src/github.com/docker/libtrust/certificates.go b/Godeps/_workspace/src/github.com/docker/libtrust/certificates.go new file mode 100644 index 0000000..3dcca33 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/libtrust/certificates.go @@ -0,0 +1,175 @@ +package libtrust + +import ( + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "io/ioutil" + "math/big" + "net" + "time" +) + +type certTemplateInfo struct { + commonName string + domains []string + ipAddresses []net.IP + isCA bool + clientAuth bool + serverAuth bool +} + +func generateCertTemplate(info *certTemplateInfo) *x509.Certificate { + // Generate a certificate template which is valid from the past week to + // 10 years from now. The usage of the certificate depends on the + // specified fields in the given certTempInfo object. + var ( + keyUsage x509.KeyUsage + extKeyUsage []x509.ExtKeyUsage + ) + + if info.isCA { + keyUsage = x509.KeyUsageCertSign + } + + if info.clientAuth { + extKeyUsage = append(extKeyUsage, x509.ExtKeyUsageClientAuth) + } + + if info.serverAuth { + extKeyUsage = append(extKeyUsage, x509.ExtKeyUsageServerAuth) + } + + return &x509.Certificate{ + SerialNumber: big.NewInt(0), + Subject: pkix.Name{ + CommonName: info.commonName, + }, + NotBefore: time.Now().Add(-time.Hour * 24 * 7), + NotAfter: time.Now().Add(time.Hour * 24 * 365 * 10), + DNSNames: info.domains, + IPAddresses: info.ipAddresses, + IsCA: info.isCA, + KeyUsage: keyUsage, + ExtKeyUsage: extKeyUsage, + BasicConstraintsValid: info.isCA, + } +} + +func generateCert(pub PublicKey, priv PrivateKey, subInfo, issInfo *certTemplateInfo) (cert *x509.Certificate, err error) { + pubCertTemplate := generateCertTemplate(subInfo) + privCertTemplate := generateCertTemplate(issInfo) + + certDER, err := x509.CreateCertificate( + rand.Reader, pubCertTemplate, privCertTemplate, + pub.CryptoPublicKey(), priv.CryptoPrivateKey(), + ) + if err != nil { + return nil, fmt.Errorf("failed to create certificate: %s", err) + } + + cert, err = x509.ParseCertificate(certDER) + if err != nil { + return nil, fmt.Errorf("failed to parse certificate: %s", err) + } + + return +} + +// GenerateSelfSignedServerCert creates a self-signed certificate for the +// given key which is to be used for TLS servers with the given domains and +// IP addresses. +func GenerateSelfSignedServerCert(key PrivateKey, domains []string, ipAddresses []net.IP) (*x509.Certificate, error) { + info := &certTemplateInfo{ + commonName: key.KeyID(), + domains: domains, + ipAddresses: ipAddresses, + serverAuth: true, + } + + return generateCert(key.PublicKey(), key, info, info) +} + +// GenerateSelfSignedClientCert creates a self-signed certificate for the +// given key which is to be used for TLS clients. +func GenerateSelfSignedClientCert(key PrivateKey) (*x509.Certificate, error) { + info := &certTemplateInfo{ + commonName: key.KeyID(), + clientAuth: true, + } + + return generateCert(key.PublicKey(), key, info, info) +} + +// GenerateCACert creates a certificate which can be used as a trusted +// certificate authority. +func GenerateCACert(signer PrivateKey, trustedKey PublicKey) (*x509.Certificate, error) { + subjectInfo := &certTemplateInfo{ + commonName: trustedKey.KeyID(), + isCA: true, + } + issuerInfo := &certTemplateInfo{ + commonName: signer.KeyID(), + } + + return generateCert(trustedKey, signer, subjectInfo, issuerInfo) +} + +// GenerateCACertPool creates a certificate authority pool to be used for a +// TLS configuration. Any self-signed certificates issued by the specified +// trusted keys will be verified during a TLS handshake +func GenerateCACertPool(signer PrivateKey, trustedKeys []PublicKey) (*x509.CertPool, error) { + certPool := x509.NewCertPool() + + for _, trustedKey := range trustedKeys { + cert, err := GenerateCACert(signer, trustedKey) + if err != nil { + return nil, fmt.Errorf("failed to generate CA certificate: %s", err) + } + + certPool.AddCert(cert) + } + + return certPool, nil +} + +// LoadCertificateBundle loads certificates from the given file. The file should be pem encoded +// containing one or more certificates. The expected pem type is "CERTIFICATE". +func LoadCertificateBundle(filename string) ([]*x509.Certificate, error) { + b, err := ioutil.ReadFile(filename) + if err != nil { + return nil, err + } + certificates := []*x509.Certificate{} + var block *pem.Block + block, b = pem.Decode(b) + for ; block != nil; block, b = pem.Decode(b) { + if block.Type == "CERTIFICATE" { + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, err + } + certificates = append(certificates, cert) + } else { + return nil, fmt.Errorf("invalid pem block type: %s", block.Type) + } + } + + return certificates, nil +} + +// LoadCertificatePool loads a CA pool from the given file. The file should be pem encoded +// containing one or more certificates. The expected pem type is "CERTIFICATE". +func LoadCertificatePool(filename string) (*x509.CertPool, error) { + certs, err := LoadCertificateBundle(filename) + if err != nil { + return nil, err + } + pool := x509.NewCertPool() + for _, cert := range certs { + pool.AddCert(cert) + } + return pool, nil +} diff --git a/Godeps/_workspace/src/github.com/docker/libtrust/certificates_test.go b/Godeps/_workspace/src/github.com/docker/libtrust/certificates_test.go new file mode 100644 index 0000000..c111f35 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/libtrust/certificates_test.go @@ -0,0 +1,111 @@ +package libtrust + +import ( + "encoding/pem" + "io/ioutil" + "net" + "os" + "path" + "testing" +) + +func TestGenerateCertificates(t *testing.T) { + key, err := GenerateECP256PrivateKey() + if err != nil { + t.Fatal(err) + } + + _, err = GenerateSelfSignedServerCert(key, []string{"localhost"}, []net.IP{net.ParseIP("127.0.0.1")}) + if err != nil { + t.Fatal(err) + } + + _, err = GenerateSelfSignedClientCert(key) + if err != nil { + t.Fatal(err) + } +} + +func TestGenerateCACertPool(t *testing.T) { + key, err := GenerateECP256PrivateKey() + if err != nil { + t.Fatal(err) + } + + caKey1, err := GenerateECP256PrivateKey() + if err != nil { + t.Fatal(err) + } + + caKey2, err := GenerateECP256PrivateKey() + if err != nil { + t.Fatal(err) + } + + _, err = GenerateCACertPool(key, []PublicKey{caKey1.PublicKey(), caKey2.PublicKey()}) + if err != nil { + t.Fatal(err) + } +} + +func TestLoadCertificates(t *testing.T) { + key, err := GenerateECP256PrivateKey() + if err != nil { + t.Fatal(err) + } + + caKey1, err := GenerateECP256PrivateKey() + if err != nil { + t.Fatal(err) + } + caKey2, err := GenerateECP256PrivateKey() + if err != nil { + t.Fatal(err) + } + + cert1, err := GenerateCACert(caKey1, key) + if err != nil { + t.Fatal(err) + } + cert2, err := GenerateCACert(caKey2, key) + if err != nil { + t.Fatal(err) + } + + d, err := ioutil.TempDir("/tmp", "cert-test") + if err != nil { + t.Fatal(err) + } + caFile := path.Join(d, "ca.pem") + f, err := os.OpenFile(caFile, os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + t.Fatal(err) + } + + err = pem.Encode(f, &pem.Block{Type: "CERTIFICATE", Bytes: cert1.Raw}) + if err != nil { + t.Fatal(err) + } + err = pem.Encode(f, &pem.Block{Type: "CERTIFICATE", Bytes: cert2.Raw}) + if err != nil { + t.Fatal(err) + } + f.Close() + + certs, err := LoadCertificateBundle(caFile) + if err != nil { + t.Fatal(err) + } + if len(certs) != 2 { + t.Fatalf("Wrong number of certs received, expected: %d, received %d", 2, len(certs)) + } + + pool, err := LoadCertificatePool(caFile) + if err != nil { + t.Fatal(err) + } + + if len(pool.Subjects()) != 2 { + t.Fatalf("Invalid certificate pool") + } +} diff --git a/Godeps/_workspace/src/github.com/docker/libtrust/doc.go b/Godeps/_workspace/src/github.com/docker/libtrust/doc.go new file mode 100644 index 0000000..ec5d215 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/libtrust/doc.go @@ -0,0 +1,9 @@ +/* +Package libtrust provides an interface for managing authentication and +authorization using public key cryptography. Authentication is handled +using the identity attached to the public key and verified through TLS +x509 certificates, a key challenge, or signature. Authorization and +access control is managed through a trust graph distributed between +both remote trust servers and locally cached and managed data. +*/ +package libtrust diff --git a/Godeps/_workspace/src/github.com/docker/libtrust/ec_key.go b/Godeps/_workspace/src/github.com/docker/libtrust/ec_key.go new file mode 100644 index 0000000..00bbe4b --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/libtrust/ec_key.go @@ -0,0 +1,428 @@ +package libtrust + +import ( + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "io" + "math/big" +) + +/* + * EC DSA PUBLIC KEY + */ + +// ecPublicKey implements a libtrust.PublicKey using elliptic curve digital +// signature algorithms. +type ecPublicKey struct { + *ecdsa.PublicKey + curveName string + signatureAlgorithm *signatureAlgorithm + extended map[string]interface{} +} + +func fromECPublicKey(cryptoPublicKey *ecdsa.PublicKey) (*ecPublicKey, error) { + curve := cryptoPublicKey.Curve + + switch { + case curve == elliptic.P256(): + return &ecPublicKey{cryptoPublicKey, "P-256", es256, map[string]interface{}{}}, nil + case curve == elliptic.P384(): + return &ecPublicKey{cryptoPublicKey, "P-384", es384, map[string]interface{}{}}, nil + case curve == elliptic.P521(): + return &ecPublicKey{cryptoPublicKey, "P-521", es512, map[string]interface{}{}}, nil + default: + return nil, errors.New("unsupported elliptic curve") + } +} + +// KeyType returns the key type for elliptic curve keys, i.e., "EC". +func (k *ecPublicKey) KeyType() string { + return "EC" +} + +// CurveName returns the elliptic curve identifier. +// Possible values are "P-256", "P-384", and "P-521". +func (k *ecPublicKey) CurveName() string { + return k.curveName +} + +// KeyID returns a distinct identifier which is unique to this Public Key. +func (k *ecPublicKey) KeyID() string { + return keyIDFromCryptoKey(k) +} + +func (k *ecPublicKey) String() string { + return fmt.Sprintf("EC Public Key <%s>", k.KeyID()) +} + +// Verify verifyies the signature of the data in the io.Reader using this +// PublicKey. The alg parameter should identify the digital signature +// algorithm which was used to produce the signature and should be supported +// by this public key. Returns a nil error if the signature is valid. +func (k *ecPublicKey) Verify(data io.Reader, alg string, signature []byte) error { + // For EC keys there is only one supported signature algorithm depending + // on the curve parameters. + if k.signatureAlgorithm.HeaderParam() != alg { + return fmt.Errorf("unable to verify signature: EC Public Key with curve %q does not support signature algorithm %q", k.curveName, alg) + } + + // signature is the concatenation of (r, s), base64Url encoded. + sigLength := len(signature) + expectedOctetLength := 2 * ((k.Params().BitSize + 7) >> 3) + if sigLength != expectedOctetLength { + return fmt.Errorf("signature length is %d octets long, should be %d", sigLength, expectedOctetLength) + } + + rBytes, sBytes := signature[:sigLength/2], signature[sigLength/2:] + r := new(big.Int).SetBytes(rBytes) + s := new(big.Int).SetBytes(sBytes) + + hasher := k.signatureAlgorithm.HashID().New() + _, err := io.Copy(hasher, data) + if err != nil { + return fmt.Errorf("error reading data to sign: %s", err) + } + hash := hasher.Sum(nil) + + if !ecdsa.Verify(k.PublicKey, hash, r, s) { + return errors.New("invalid signature") + } + + return nil +} + +// CryptoPublicKey returns the internal object which can be used as a +// crypto.PublicKey for use with other standard library operations. The type +// is either *rsa.PublicKey or *ecdsa.PublicKey +func (k *ecPublicKey) CryptoPublicKey() crypto.PublicKey { + return k.PublicKey +} + +func (k *ecPublicKey) toMap() map[string]interface{} { + jwk := make(map[string]interface{}) + for k, v := range k.extended { + jwk[k] = v + } + jwk["kty"] = k.KeyType() + jwk["kid"] = k.KeyID() + jwk["crv"] = k.CurveName() + + xBytes := k.X.Bytes() + yBytes := k.Y.Bytes() + octetLength := (k.Params().BitSize + 7) >> 3 + // MUST include leading zeros in the output so that x, y are each + // *octetLength* bytes long. + xBuf := make([]byte, octetLength-len(xBytes), octetLength) + yBuf := make([]byte, octetLength-len(yBytes), octetLength) + xBuf = append(xBuf, xBytes...) + yBuf = append(yBuf, yBytes...) + + jwk["x"] = joseBase64UrlEncode(xBuf) + jwk["y"] = joseBase64UrlEncode(yBuf) + + return jwk +} + +// MarshalJSON serializes this Public Key using the JWK JSON serialization format for +// elliptic curve keys. +func (k *ecPublicKey) MarshalJSON() (data []byte, err error) { + return json.Marshal(k.toMap()) +} + +// PEMBlock serializes this Public Key to DER-encoded PKIX format. +func (k *ecPublicKey) PEMBlock() (*pem.Block, error) { + derBytes, err := x509.MarshalPKIXPublicKey(k.PublicKey) + if err != nil { + return nil, fmt.Errorf("unable to serialize EC PublicKey to DER-encoded PKIX format: %s", err) + } + k.extended["kid"] = k.KeyID() // For display purposes. + return createPemBlock("PUBLIC KEY", derBytes, k.extended) +} + +func (k *ecPublicKey) AddExtendedField(field string, value interface{}) { + k.extended[field] = value +} + +func (k *ecPublicKey) GetExtendedField(field string) interface{} { + v, ok := k.extended[field] + if !ok { + return nil + } + return v +} + +func ecPublicKeyFromMap(jwk map[string]interface{}) (*ecPublicKey, error) { + // JWK key type (kty) has already been determined to be "EC". + // Need to extract 'crv', 'x', 'y', and 'kid' and check for + // consistency. + + // Get the curve identifier value. + crv, err := stringFromMap(jwk, "crv") + if err != nil { + return nil, fmt.Errorf("JWK EC Public Key curve identifier: %s", err) + } + + var ( + curve elliptic.Curve + sigAlg *signatureAlgorithm + ) + + switch { + case crv == "P-256": + curve = elliptic.P256() + sigAlg = es256 + case crv == "P-384": + curve = elliptic.P384() + sigAlg = es384 + case crv == "P-521": + curve = elliptic.P521() + sigAlg = es512 + default: + return nil, fmt.Errorf("JWK EC Public Key curve identifier not supported: %q\n", crv) + } + + // Get the X and Y coordinates for the public key point. + xB64Url, err := stringFromMap(jwk, "x") + if err != nil { + return nil, fmt.Errorf("JWK EC Public Key x-coordinate: %s", err) + } + x, err := parseECCoordinate(xB64Url, curve) + if err != nil { + return nil, fmt.Errorf("JWK EC Public Key x-coordinate: %s", err) + } + + yB64Url, err := stringFromMap(jwk, "y") + if err != nil { + return nil, fmt.Errorf("JWK EC Public Key y-coordinate: %s", err) + } + y, err := parseECCoordinate(yB64Url, curve) + if err != nil { + return nil, fmt.Errorf("JWK EC Public Key y-coordinate: %s", err) + } + + key := &ecPublicKey{ + PublicKey: &ecdsa.PublicKey{Curve: curve, X: x, Y: y}, + curveName: crv, signatureAlgorithm: sigAlg, + } + + // Key ID is optional too, but if it exists, it should match the key. + _, ok := jwk["kid"] + if ok { + kid, err := stringFromMap(jwk, "kid") + if err != nil { + return nil, fmt.Errorf("JWK EC Public Key ID: %s", err) + } + if kid != key.KeyID() { + return nil, fmt.Errorf("JWK EC Public Key ID does not match: %s", kid) + } + } + + key.extended = jwk + + return key, nil +} + +/* + * EC DSA PRIVATE KEY + */ + +// ecPrivateKey implements a JWK Private Key using elliptic curve digital signature +// algorithms. +type ecPrivateKey struct { + ecPublicKey + *ecdsa.PrivateKey +} + +func fromECPrivateKey(cryptoPrivateKey *ecdsa.PrivateKey) (*ecPrivateKey, error) { + publicKey, err := fromECPublicKey(&cryptoPrivateKey.PublicKey) + if err != nil { + return nil, err + } + + return &ecPrivateKey{*publicKey, cryptoPrivateKey}, nil +} + +// PublicKey returns the Public Key data associated with this Private Key. +func (k *ecPrivateKey) PublicKey() PublicKey { + return &k.ecPublicKey +} + +func (k *ecPrivateKey) String() string { + return fmt.Sprintf("EC Private Key <%s>", k.KeyID()) +} + +// Sign signs the data read from the io.Reader using a signature algorithm supported +// by the elliptic curve private key. If the specified hashing algorithm is +// supported by this key, that hash function is used to generate the signature +// otherwise the the default hashing algorithm for this key is used. Returns +// the signature and the name of the JWK signature algorithm used, e.g., +// "ES256", "ES384", "ES512". +func (k *ecPrivateKey) Sign(data io.Reader, hashID crypto.Hash) (signature []byte, alg string, err error) { + // Generate a signature of the data using the internal alg. + // The given hashId is only a suggestion, and since EC keys only support + // on signature/hash algorithm given the curve name, we disregard it for + // the elliptic curve JWK signature implementation. + hasher := k.signatureAlgorithm.HashID().New() + _, err = io.Copy(hasher, data) + if err != nil { + return nil, "", fmt.Errorf("error reading data to sign: %s", err) + } + hash := hasher.Sum(nil) + + r, s, err := ecdsa.Sign(rand.Reader, k.PrivateKey, hash) + if err != nil { + return nil, "", fmt.Errorf("error producing signature: %s", err) + } + rBytes, sBytes := r.Bytes(), s.Bytes() + octetLength := (k.ecPublicKey.Params().BitSize + 7) >> 3 + // MUST include leading zeros in the output + rBuf := make([]byte, octetLength-len(rBytes), octetLength) + sBuf := make([]byte, octetLength-len(sBytes), octetLength) + + rBuf = append(rBuf, rBytes...) + sBuf = append(sBuf, sBytes...) + + signature = append(rBuf, sBuf...) + alg = k.signatureAlgorithm.HeaderParam() + + return +} + +// CryptoPrivateKey returns the internal object which can be used as a +// crypto.PublicKey for use with other standard library operations. The type +// is either *rsa.PublicKey or *ecdsa.PublicKey +func (k *ecPrivateKey) CryptoPrivateKey() crypto.PrivateKey { + return k.PrivateKey +} + +func (k *ecPrivateKey) toMap() map[string]interface{} { + jwk := k.ecPublicKey.toMap() + + dBytes := k.D.Bytes() + // The length of this octet string MUST be ceiling(log-base-2(n)/8) + // octets (where n is the order of the curve). This is because the private + // key d must be in the interval [1, n-1] so the bitlength of d should be + // no larger than the bitlength of n-1. The easiest way to find the octet + // length is to take bitlength(n-1), add 7 to force a carry, and shift this + // bit sequence right by 3, which is essentially dividing by 8 and adding + // 1 if there is any remainder. Thus, the private key value d should be + // output to (bitlength(n-1)+7)>>3 octets. + n := k.ecPublicKey.Params().N + octetLength := (new(big.Int).Sub(n, big.NewInt(1)).BitLen() + 7) >> 3 + // Create a buffer with the necessary zero-padding. + dBuf := make([]byte, octetLength-len(dBytes), octetLength) + dBuf = append(dBuf, dBytes...) + + jwk["d"] = joseBase64UrlEncode(dBuf) + + return jwk +} + +// MarshalJSON serializes this Private Key using the JWK JSON serialization format for +// elliptic curve keys. +func (k *ecPrivateKey) MarshalJSON() (data []byte, err error) { + return json.Marshal(k.toMap()) +} + +// PEMBlock serializes this Private Key to DER-encoded PKIX format. +func (k *ecPrivateKey) PEMBlock() (*pem.Block, error) { + derBytes, err := x509.MarshalECPrivateKey(k.PrivateKey) + if err != nil { + return nil, fmt.Errorf("unable to serialize EC PrivateKey to DER-encoded PKIX format: %s", err) + } + k.extended["keyID"] = k.KeyID() // For display purposes. + return createPemBlock("EC PRIVATE KEY", derBytes, k.extended) +} + +func ecPrivateKeyFromMap(jwk map[string]interface{}) (*ecPrivateKey, error) { + dB64Url, err := stringFromMap(jwk, "d") + if err != nil { + return nil, fmt.Errorf("JWK EC Private Key: %s", err) + } + + // JWK key type (kty) has already been determined to be "EC". + // Need to extract the public key information, then extract the private + // key value 'd'. + publicKey, err := ecPublicKeyFromMap(jwk) + if err != nil { + return nil, err + } + + d, err := parseECPrivateParam(dB64Url, publicKey.Curve) + if err != nil { + return nil, fmt.Errorf("JWK EC Private Key d-param: %s", err) + } + + key := &ecPrivateKey{ + ecPublicKey: *publicKey, + PrivateKey: &ecdsa.PrivateKey{ + PublicKey: *publicKey.PublicKey, + D: d, + }, + } + + return key, nil +} + +/* + * Key Generation Functions. + */ + +func generateECPrivateKey(curve elliptic.Curve) (k *ecPrivateKey, err error) { + k = new(ecPrivateKey) + k.PrivateKey, err = ecdsa.GenerateKey(curve, rand.Reader) + if err != nil { + return nil, err + } + + k.ecPublicKey.PublicKey = &k.PrivateKey.PublicKey + k.extended = make(map[string]interface{}) + + return +} + +// GenerateECP256PrivateKey generates a key pair using elliptic curve P-256. +func GenerateECP256PrivateKey() (PrivateKey, error) { + k, err := generateECPrivateKey(elliptic.P256()) + if err != nil { + return nil, fmt.Errorf("error generating EC P-256 key: %s", err) + } + + k.curveName = "P-256" + k.signatureAlgorithm = es256 + + return k, nil +} + +// GenerateECP384PrivateKey generates a key pair using elliptic curve P-384. +func GenerateECP384PrivateKey() (PrivateKey, error) { + k, err := generateECPrivateKey(elliptic.P384()) + if err != nil { + return nil, fmt.Errorf("error generating EC P-384 key: %s", err) + } + + k.curveName = "P-384" + k.signatureAlgorithm = es384 + + return k, nil +} + +// GenerateECP521PrivateKey generates aß key pair using elliptic curve P-521. +func GenerateECP521PrivateKey() (PrivateKey, error) { + k, err := generateECPrivateKey(elliptic.P521()) + if err != nil { + return nil, fmt.Errorf("error generating EC P-521 key: %s", err) + } + + k.curveName = "P-521" + k.signatureAlgorithm = es512 + + return k, nil +} diff --git a/Godeps/_workspace/src/github.com/docker/libtrust/ec_key_test.go b/Godeps/_workspace/src/github.com/docker/libtrust/ec_key_test.go new file mode 100644 index 0000000..26ac381 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/libtrust/ec_key_test.go @@ -0,0 +1,157 @@ +package libtrust + +import ( + "bytes" + "encoding/json" + "testing" +) + +func generateECTestKeys(t *testing.T) []PrivateKey { + p256Key, err := GenerateECP256PrivateKey() + if err != nil { + t.Fatal(err) + } + + p384Key, err := GenerateECP384PrivateKey() + if err != nil { + t.Fatal(err) + } + + p521Key, err := GenerateECP521PrivateKey() + if err != nil { + t.Fatal(err) + } + + return []PrivateKey{p256Key, p384Key, p521Key} +} + +func TestECKeys(t *testing.T) { + ecKeys := generateECTestKeys(t) + + for _, ecKey := range ecKeys { + if ecKey.KeyType() != "EC" { + t.Fatalf("key type must be %q, instead got %q", "EC", ecKey.KeyType()) + } + } +} + +func TestECSignVerify(t *testing.T) { + ecKeys := generateECTestKeys(t) + + message := "Hello, World!" + data := bytes.NewReader([]byte(message)) + + sigAlgs := []*signatureAlgorithm{es256, es384, es512} + + for i, ecKey := range ecKeys { + sigAlg := sigAlgs[i] + + t.Logf("%s signature of %q with kid: %s\n", sigAlg.HeaderParam(), message, ecKey.KeyID()) + + data.Seek(0, 0) // Reset the byte reader + + // Sign + sig, alg, err := ecKey.Sign(data, sigAlg.HashID()) + if err != nil { + t.Fatal(err) + } + + data.Seek(0, 0) // Reset the byte reader + + // Verify + err = ecKey.Verify(data, alg, sig) + if err != nil { + t.Fatal(err) + } + } +} + +func TestMarshalUnmarshalECKeys(t *testing.T) { + ecKeys := generateECTestKeys(t) + data := bytes.NewReader([]byte("This is a test. I repeat: this is only a test.")) + sigAlgs := []*signatureAlgorithm{es256, es384, es512} + + for i, ecKey := range ecKeys { + sigAlg := sigAlgs[i] + privateJWKJSON, err := json.MarshalIndent(ecKey, "", " ") + if err != nil { + t.Fatal(err) + } + + publicJWKJSON, err := json.MarshalIndent(ecKey.PublicKey(), "", " ") + if err != nil { + t.Fatal(err) + } + + t.Logf("JWK Private Key: %s", string(privateJWKJSON)) + t.Logf("JWK Public Key: %s", string(publicJWKJSON)) + + privKey2, err := UnmarshalPrivateKeyJWK(privateJWKJSON) + if err != nil { + t.Fatal(err) + } + + pubKey2, err := UnmarshalPublicKeyJWK(publicJWKJSON) + if err != nil { + t.Fatal(err) + } + + // Ensure we can sign/verify a message with the unmarshalled keys. + data.Seek(0, 0) // Reset the byte reader + signature, alg, err := privKey2.Sign(data, sigAlg.HashID()) + if err != nil { + t.Fatal(err) + } + + data.Seek(0, 0) // Reset the byte reader + err = pubKey2.Verify(data, alg, signature) + if err != nil { + t.Fatal(err) + } + } +} + +func TestFromCryptoECKeys(t *testing.T) { + ecKeys := generateECTestKeys(t) + + for _, ecKey := range ecKeys { + cryptoPrivateKey := ecKey.CryptoPrivateKey() + cryptoPublicKey := ecKey.CryptoPublicKey() + + pubKey, err := FromCryptoPublicKey(cryptoPublicKey) + if err != nil { + t.Fatal(err) + } + + if pubKey.KeyID() != ecKey.KeyID() { + t.Fatal("public key key ID mismatch") + } + + privKey, err := FromCryptoPrivateKey(cryptoPrivateKey) + if err != nil { + t.Fatal(err) + } + + if privKey.KeyID() != ecKey.KeyID() { + t.Fatal("public key key ID mismatch") + } + } +} + +func TestExtendedFields(t *testing.T) { + key, err := GenerateECP256PrivateKey() + if err != nil { + t.Fatal(err) + } + + key.AddExtendedField("test", "foobar") + val := key.GetExtendedField("test") + + gotVal, ok := val.(string) + if !ok { + t.Fatalf("value is not a string") + } else if gotVal != val { + t.Fatalf("value %q is not equal to %q", gotVal, val) + } + +} diff --git a/Godeps/_workspace/src/github.com/docker/libtrust/filter.go b/Godeps/_workspace/src/github.com/docker/libtrust/filter.go new file mode 100644 index 0000000..5b2b4fc --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/libtrust/filter.go @@ -0,0 +1,50 @@ +package libtrust + +import ( + "path/filepath" +) + +// FilterByHosts filters the list of PublicKeys to only those which contain a +// 'hosts' pattern which matches the given host. If *includeEmpty* is true, +// then keys which do not specify any hosts are also returned. +func FilterByHosts(keys []PublicKey, host string, includeEmpty bool) ([]PublicKey, error) { + filtered := make([]PublicKey, 0, len(keys)) + + for _, pubKey := range keys { + var hosts []string + switch v := pubKey.GetExtendedField("hosts").(type) { + case []string: + hosts = v + case []interface{}: + for _, value := range v { + h, ok := value.(string) + if !ok { + continue + } + hosts = append(hosts, h) + } + } + + if len(hosts) == 0 { + if includeEmpty { + filtered = append(filtered, pubKey) + } + continue + } + + // Check if any hosts match pattern + for _, hostPattern := range hosts { + match, err := filepath.Match(hostPattern, host) + if err != nil { + return nil, err + } + + if match { + filtered = append(filtered, pubKey) + continue + } + } + } + + return filtered, nil +} diff --git a/Godeps/_workspace/src/github.com/docker/libtrust/filter_test.go b/Godeps/_workspace/src/github.com/docker/libtrust/filter_test.go new file mode 100644 index 0000000..997e554 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/libtrust/filter_test.go @@ -0,0 +1,81 @@ +package libtrust + +import ( + "testing" +) + +func compareKeySlices(t *testing.T, sliceA, sliceB []PublicKey) { + if len(sliceA) != len(sliceB) { + t.Fatalf("slice size %d, expected %d", len(sliceA), len(sliceB)) + } + + for i, itemA := range sliceA { + itemB := sliceB[i] + if itemA != itemB { + t.Fatalf("slice index %d not equal: %#v != %#v", i, itemA, itemB) + } + } +} + +func TestFilter(t *testing.T) { + keys := make([]PublicKey, 0, 8) + + // Create 8 keys and add host entries. + for i := 0; i < cap(keys); i++ { + key, err := GenerateECP256PrivateKey() + if err != nil { + t.Fatal(err) + } + + // we use both []interface{} and []string here because jwt uses + // []interface{} format, while PEM uses []string + switch { + case i == 0: + // Don't add entries for this key, key 0. + break + case i%2 == 0: + // Should catch keys 2, 4, and 6. + key.AddExtendedField("hosts", []interface{}{"*.even.example.com"}) + case i == 7: + // Should catch only the last key, and make it match any hostname. + key.AddExtendedField("hosts", []string{"*"}) + default: + // should catch keys 1, 3, 5. + key.AddExtendedField("hosts", []string{"*.example.com"}) + } + + keys = append(keys, key) + } + + // Should match 2 keys, the empty one, and the one that matches all hosts. + matchedKeys, err := FilterByHosts(keys, "foo.bar.com", true) + if err != nil { + t.Fatal(err) + } + expectedMatch := []PublicKey{keys[0], keys[7]} + compareKeySlices(t, expectedMatch, matchedKeys) + + // Should match 1 key, the one that matches any host. + matchedKeys, err = FilterByHosts(keys, "foo.bar.com", false) + if err != nil { + t.Fatal(err) + } + expectedMatch = []PublicKey{keys[7]} + compareKeySlices(t, expectedMatch, matchedKeys) + + // Should match keys that end in "example.com", and the key that matches anything. + matchedKeys, err = FilterByHosts(keys, "foo.example.com", false) + if err != nil { + t.Fatal(err) + } + expectedMatch = []PublicKey{keys[1], keys[3], keys[5], keys[7]} + compareKeySlices(t, expectedMatch, matchedKeys) + + // Should match all of the keys except the empty key. + matchedKeys, err = FilterByHosts(keys, "foo.even.example.com", false) + if err != nil { + t.Fatal(err) + } + expectedMatch = keys[1:] + compareKeySlices(t, expectedMatch, matchedKeys) +} diff --git a/Godeps/_workspace/src/github.com/docker/libtrust/hash.go b/Godeps/_workspace/src/github.com/docker/libtrust/hash.go new file mode 100644 index 0000000..a2df787 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/libtrust/hash.go @@ -0,0 +1,56 @@ +package libtrust + +import ( + "crypto" + _ "crypto/sha256" // Registrer SHA224 and SHA256 + _ "crypto/sha512" // Registrer SHA384 and SHA512 + "fmt" +) + +type signatureAlgorithm struct { + algHeaderParam string + hashID crypto.Hash +} + +func (h *signatureAlgorithm) HeaderParam() string { + return h.algHeaderParam +} + +func (h *signatureAlgorithm) HashID() crypto.Hash { + return h.hashID +} + +var ( + rs256 = &signatureAlgorithm{"RS256", crypto.SHA256} + rs384 = &signatureAlgorithm{"RS384", crypto.SHA384} + rs512 = &signatureAlgorithm{"RS512", crypto.SHA512} + es256 = &signatureAlgorithm{"ES256", crypto.SHA256} + es384 = &signatureAlgorithm{"ES384", crypto.SHA384} + es512 = &signatureAlgorithm{"ES512", crypto.SHA512} +) + +func rsaSignatureAlgorithmByName(alg string) (*signatureAlgorithm, error) { + switch { + case alg == "RS256": + return rs256, nil + case alg == "RS384": + return rs384, nil + case alg == "RS512": + return rs512, nil + default: + return nil, fmt.Errorf("RSA Digital Signature Algorithm %q not supported", alg) + } +} + +func rsaPKCS1v15SignatureAlgorithmForHashID(hashID crypto.Hash) *signatureAlgorithm { + switch { + case hashID == crypto.SHA512: + return rs512 + case hashID == crypto.SHA384: + return rs384 + case hashID == crypto.SHA256: + fallthrough + default: + return rs256 + } +} diff --git a/Godeps/_workspace/src/github.com/docker/libtrust/jsonsign.go b/Godeps/_workspace/src/github.com/docker/libtrust/jsonsign.go new file mode 100644 index 0000000..cb2ca9a --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/libtrust/jsonsign.go @@ -0,0 +1,657 @@ +package libtrust + +import ( + "bytes" + "crypto" + "crypto/x509" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "sort" + "time" + "unicode" +) + +var ( + // ErrInvalidSignContent is used when the content to be signed is invalid. + ErrInvalidSignContent = errors.New("invalid sign content") + + // ErrInvalidJSONContent is used when invalid json is encountered. + ErrInvalidJSONContent = errors.New("invalid json content") + + // ErrMissingSignatureKey is used when the specified signature key + // does not exist in the JSON content. + ErrMissingSignatureKey = errors.New("missing signature key") +) + +type jsHeader struct { + JWK PublicKey `json:"jwk,omitempty"` + Algorithm string `json:"alg"` + Chain []string `json:"x5c,omitempty"` +} + +type jsSignature struct { + Header jsHeader `json:"header"` + Signature string `json:"signature"` + Protected string `json:"protected,omitempty"` +} + +type jsSignaturesSorted []jsSignature + +func (jsbkid jsSignaturesSorted) Swap(i, j int) { jsbkid[i], jsbkid[j] = jsbkid[j], jsbkid[i] } +func (jsbkid jsSignaturesSorted) Len() int { return len(jsbkid) } + +func (jsbkid jsSignaturesSorted) Less(i, j int) bool { + ki, kj := jsbkid[i].Header.JWK.KeyID(), jsbkid[j].Header.JWK.KeyID() + si, sj := jsbkid[i].Signature, jsbkid[j].Signature + + if ki == kj { + return si < sj + } + + return ki < kj +} + +type signKey struct { + PrivateKey + Chain []*x509.Certificate +} + +// JSONSignature represents a signature of a json object. +type JSONSignature struct { + payload string + signatures []jsSignature + indent string + formatLength int + formatTail []byte +} + +func newJSONSignature() *JSONSignature { + return &JSONSignature{ + signatures: make([]jsSignature, 0, 1), + } +} + +// Payload returns the encoded payload of the signature. This +// payload should not be signed directly +func (js *JSONSignature) Payload() ([]byte, error) { + return joseBase64UrlDecode(js.payload) +} + +func (js *JSONSignature) protectedHeader() (string, error) { + protected := map[string]interface{}{ + "formatLength": js.formatLength, + "formatTail": joseBase64UrlEncode(js.formatTail), + "time": time.Now().UTC().Format(time.RFC3339), + } + protectedBytes, err := json.Marshal(protected) + if err != nil { + return "", err + } + + return joseBase64UrlEncode(protectedBytes), nil +} + +func (js *JSONSignature) signBytes(protectedHeader string) ([]byte, error) { + buf := make([]byte, len(js.payload)+len(protectedHeader)+1) + copy(buf, protectedHeader) + buf[len(protectedHeader)] = '.' + copy(buf[len(protectedHeader)+1:], js.payload) + return buf, nil +} + +// Sign adds a signature using the given private key. +func (js *JSONSignature) Sign(key PrivateKey) error { + protected, err := js.protectedHeader() + if err != nil { + return err + } + signBytes, err := js.signBytes(protected) + if err != nil { + return err + } + sigBytes, algorithm, err := key.Sign(bytes.NewReader(signBytes), crypto.SHA256) + if err != nil { + return err + } + + js.signatures = append(js.signatures, jsSignature{ + Header: jsHeader{ + JWK: key.PublicKey(), + Algorithm: algorithm, + }, + Signature: joseBase64UrlEncode(sigBytes), + Protected: protected, + }) + + return nil +} + +// SignWithChain adds a signature using the given private key +// and setting the x509 chain. The public key of the first element +// in the chain must be the public key corresponding with the sign key. +func (js *JSONSignature) SignWithChain(key PrivateKey, chain []*x509.Certificate) error { + // Ensure key.Chain[0] is public key for key + //key.Chain.PublicKey + //key.PublicKey().CryptoPublicKey() + + // Verify chain + protected, err := js.protectedHeader() + if err != nil { + return err + } + signBytes, err := js.signBytes(protected) + if err != nil { + return err + } + sigBytes, algorithm, err := key.Sign(bytes.NewReader(signBytes), crypto.SHA256) + if err != nil { + return err + } + + header := jsHeader{ + Chain: make([]string, len(chain)), + Algorithm: algorithm, + } + + for i, cert := range chain { + header.Chain[i] = base64.StdEncoding.EncodeToString(cert.Raw) + } + + js.signatures = append(js.signatures, jsSignature{ + Header: header, + Signature: joseBase64UrlEncode(sigBytes), + Protected: protected, + }) + + return nil +} + +// Verify verifies all the signatures and returns the list of +// public keys used to sign. Any x509 chains are not checked. +func (js *JSONSignature) Verify() ([]PublicKey, error) { + keys := make([]PublicKey, len(js.signatures)) + for i, signature := range js.signatures { + signBytes, err := js.signBytes(signature.Protected) + if err != nil { + return nil, err + } + var publicKey PublicKey + if len(signature.Header.Chain) > 0 { + certBytes, err := base64.StdEncoding.DecodeString(signature.Header.Chain[0]) + if err != nil { + return nil, err + } + cert, err := x509.ParseCertificate(certBytes) + if err != nil { + return nil, err + } + publicKey, err = FromCryptoPublicKey(cert.PublicKey) + if err != nil { + return nil, err + } + } else if signature.Header.JWK != nil { + publicKey = signature.Header.JWK + } else { + return nil, errors.New("missing public key") + } + + sigBytes, err := joseBase64UrlDecode(signature.Signature) + if err != nil { + return nil, err + } + + err = publicKey.Verify(bytes.NewReader(signBytes), signature.Header.Algorithm, sigBytes) + if err != nil { + return nil, err + } + + keys[i] = publicKey + } + return keys, nil +} + +// VerifyChains verifies all the signatures and the chains associated +// with each signature and returns the list of verified chains. +// Signatures without an x509 chain are not checked. +func (js *JSONSignature) VerifyChains(ca *x509.CertPool) ([][]*x509.Certificate, error) { + chains := make([][]*x509.Certificate, 0, len(js.signatures)) + for _, signature := range js.signatures { + signBytes, err := js.signBytes(signature.Protected) + if err != nil { + return nil, err + } + var publicKey PublicKey + if len(signature.Header.Chain) > 0 { + certBytes, err := base64.StdEncoding.DecodeString(signature.Header.Chain[0]) + if err != nil { + return nil, err + } + cert, err := x509.ParseCertificate(certBytes) + if err != nil { + return nil, err + } + publicKey, err = FromCryptoPublicKey(cert.PublicKey) + if err != nil { + return nil, err + } + intermediates := x509.NewCertPool() + if len(signature.Header.Chain) > 1 { + intermediateChain := signature.Header.Chain[1:] + for i := range intermediateChain { + certBytes, err := base64.StdEncoding.DecodeString(intermediateChain[i]) + if err != nil { + return nil, err + } + intermediate, err := x509.ParseCertificate(certBytes) + if err != nil { + return nil, err + } + intermediates.AddCert(intermediate) + } + } + + verifyOptions := x509.VerifyOptions{ + Intermediates: intermediates, + Roots: ca, + } + + verifiedChains, err := cert.Verify(verifyOptions) + if err != nil { + return nil, err + } + chains = append(chains, verifiedChains...) + + sigBytes, err := joseBase64UrlDecode(signature.Signature) + if err != nil { + return nil, err + } + + err = publicKey.Verify(bytes.NewReader(signBytes), signature.Header.Algorithm, sigBytes) + if err != nil { + return nil, err + } + } + + } + return chains, nil +} + +// JWS returns JSON serialized JWS according to +// http://tools.ietf.org/html/draft-ietf-jose-json-web-signature-31#section-7.2 +func (js *JSONSignature) JWS() ([]byte, error) { + if len(js.signatures) == 0 { + return nil, errors.New("missing signature") + } + + sort.Sort(jsSignaturesSorted(js.signatures)) + + jsonMap := map[string]interface{}{ + "payload": js.payload, + "signatures": js.signatures, + } + + return json.MarshalIndent(jsonMap, "", " ") +} + +func notSpace(r rune) bool { + return !unicode.IsSpace(r) +} + +func detectJSONIndent(jsonContent []byte) (indent string) { + if len(jsonContent) > 2 && jsonContent[0] == '{' && jsonContent[1] == '\n' { + quoteIndex := bytes.IndexRune(jsonContent[1:], '"') + if quoteIndex > 0 { + indent = string(jsonContent[2 : quoteIndex+1]) + } + } + return +} + +type jsParsedHeader struct { + JWK json.RawMessage `json:"jwk"` + Algorithm string `json:"alg"` + Chain []string `json:"x5c"` +} + +type jsParsedSignature struct { + Header jsParsedHeader `json:"header"` + Signature string `json:"signature"` + Protected string `json:"protected"` +} + +// ParseJWS parses a JWS serialized JSON object into a Json Signature. +func ParseJWS(content []byte) (*JSONSignature, error) { + type jsParsed struct { + Payload string `json:"payload"` + Signatures []jsParsedSignature `json:"signatures"` + } + parsed := &jsParsed{} + err := json.Unmarshal(content, parsed) + if err != nil { + return nil, err + } + if len(parsed.Signatures) == 0 { + return nil, errors.New("missing signatures") + } + payload, err := joseBase64UrlDecode(parsed.Payload) + if err != nil { + return nil, err + } + + js, err := NewJSONSignature(payload) + if err != nil { + return nil, err + } + js.signatures = make([]jsSignature, len(parsed.Signatures)) + for i, signature := range parsed.Signatures { + header := jsHeader{ + Algorithm: signature.Header.Algorithm, + } + if signature.Header.Chain != nil { + header.Chain = signature.Header.Chain + } + if signature.Header.JWK != nil { + publicKey, err := UnmarshalPublicKeyJWK([]byte(signature.Header.JWK)) + if err != nil { + return nil, err + } + header.JWK = publicKey + } + js.signatures[i] = jsSignature{ + Header: header, + Signature: signature.Signature, + Protected: signature.Protected, + } + } + + return js, nil +} + +// NewJSONSignature returns a new unsigned JWS from a json byte array. +// JSONSignature will need to be signed before serializing or storing. +// Optionally, one or more signatures can be provided as byte buffers, +// containing serialized JWS signatures, to assemble a fully signed JWS +// package. It is the callers responsibility to ensure uniqueness of the +// provided signatures. +func NewJSONSignature(content []byte, signatures ...[]byte) (*JSONSignature, error) { + var dataMap map[string]interface{} + err := json.Unmarshal(content, &dataMap) + if err != nil { + return nil, err + } + + js := newJSONSignature() + js.indent = detectJSONIndent(content) + + js.payload = joseBase64UrlEncode(content) + + // Find trailing } and whitespace, put in protected header + closeIndex := bytes.LastIndexFunc(content, notSpace) + if content[closeIndex] != '}' { + return nil, ErrInvalidJSONContent + } + lastRuneIndex := bytes.LastIndexFunc(content[:closeIndex], notSpace) + if content[lastRuneIndex] == ',' { + return nil, ErrInvalidJSONContent + } + js.formatLength = lastRuneIndex + 1 + js.formatTail = content[js.formatLength:] + + if len(signatures) > 0 { + for _, signature := range signatures { + var parsedJSig jsParsedSignature + + if err := json.Unmarshal(signature, &parsedJSig); err != nil { + return nil, err + } + + // TODO(stevvooe): A lot of the code below is repeated in + // ParseJWS. It will require more refactoring to fix that. + jsig := jsSignature{ + Header: jsHeader{ + Algorithm: parsedJSig.Header.Algorithm, + }, + Signature: parsedJSig.Signature, + Protected: parsedJSig.Protected, + } + + if parsedJSig.Header.Chain != nil { + jsig.Header.Chain = parsedJSig.Header.Chain + } + + if parsedJSig.Header.JWK != nil { + publicKey, err := UnmarshalPublicKeyJWK([]byte(parsedJSig.Header.JWK)) + if err != nil { + return nil, err + } + jsig.Header.JWK = publicKey + } + + js.signatures = append(js.signatures, jsig) + } + } + + return js, nil +} + +// NewJSONSignatureFromMap returns a new unsigned JSONSignature from a map or +// struct. JWS will need to be signed before serializing or storing. +func NewJSONSignatureFromMap(content interface{}) (*JSONSignature, error) { + switch content.(type) { + case map[string]interface{}: + case struct{}: + default: + return nil, errors.New("invalid data type") + } + + js := newJSONSignature() + js.indent = " " + + payload, err := json.MarshalIndent(content, "", js.indent) + if err != nil { + return nil, err + } + js.payload = joseBase64UrlEncode(payload) + + // Remove '\n}' from formatted section, put in protected header + js.formatLength = len(payload) - 2 + js.formatTail = payload[js.formatLength:] + + return js, nil +} + +func readIntFromMap(key string, m map[string]interface{}) (int, bool) { + value, ok := m[key] + if !ok { + return 0, false + } + switch v := value.(type) { + case int: + return v, true + case float64: + return int(v), true + default: + return 0, false + } +} + +func readStringFromMap(key string, m map[string]interface{}) (v string, ok bool) { + value, ok := m[key] + if !ok { + return "", false + } + v, ok = value.(string) + return +} + +// ParsePrettySignature parses a formatted signature into a +// JSON signature. If the signatures are missing the format information +// an error is thrown. The formatted signature must be created by +// the same method as format signature. +func ParsePrettySignature(content []byte, signatureKey string) (*JSONSignature, error) { + var contentMap map[string]json.RawMessage + err := json.Unmarshal(content, &contentMap) + if err != nil { + return nil, fmt.Errorf("error unmarshalling content: %s", err) + } + sigMessage, ok := contentMap[signatureKey] + if !ok { + return nil, ErrMissingSignatureKey + } + + var signatureBlocks []jsParsedSignature + err = json.Unmarshal([]byte(sigMessage), &signatureBlocks) + if err != nil { + return nil, fmt.Errorf("error unmarshalling signatures: %s", err) + } + + js := newJSONSignature() + js.signatures = make([]jsSignature, len(signatureBlocks)) + + for i, signatureBlock := range signatureBlocks { + protectedBytes, err := joseBase64UrlDecode(signatureBlock.Protected) + if err != nil { + return nil, fmt.Errorf("base64 decode error: %s", err) + } + var protectedHeader map[string]interface{} + err = json.Unmarshal(protectedBytes, &protectedHeader) + if err != nil { + return nil, fmt.Errorf("error unmarshalling protected header: %s", err) + } + + formatLength, ok := readIntFromMap("formatLength", protectedHeader) + if !ok { + return nil, errors.New("missing formatted length") + } + encodedTail, ok := readStringFromMap("formatTail", protectedHeader) + if !ok { + return nil, errors.New("missing formatted tail") + } + formatTail, err := joseBase64UrlDecode(encodedTail) + if err != nil { + return nil, fmt.Errorf("base64 decode error on tail: %s", err) + } + if js.formatLength == 0 { + js.formatLength = formatLength + } else if js.formatLength != formatLength { + return nil, errors.New("conflicting format length") + } + if len(js.formatTail) == 0 { + js.formatTail = formatTail + } else if bytes.Compare(js.formatTail, formatTail) != 0 { + return nil, errors.New("conflicting format tail") + } + + header := jsHeader{ + Algorithm: signatureBlock.Header.Algorithm, + Chain: signatureBlock.Header.Chain, + } + if signatureBlock.Header.JWK != nil { + publicKey, err := UnmarshalPublicKeyJWK([]byte(signatureBlock.Header.JWK)) + if err != nil { + return nil, fmt.Errorf("error unmarshalling public key: %s", err) + } + header.JWK = publicKey + } + js.signatures[i] = jsSignature{ + Header: header, + Signature: signatureBlock.Signature, + Protected: signatureBlock.Protected, + } + } + if js.formatLength > len(content) { + return nil, errors.New("invalid format length") + } + formatted := make([]byte, js.formatLength+len(js.formatTail)) + copy(formatted, content[:js.formatLength]) + copy(formatted[js.formatLength:], js.formatTail) + js.indent = detectJSONIndent(formatted) + js.payload = joseBase64UrlEncode(formatted) + + return js, nil +} + +// PrettySignature formats a json signature into an easy to read +// single json serialized object. +func (js *JSONSignature) PrettySignature(signatureKey string) ([]byte, error) { + if len(js.signatures) == 0 { + return nil, errors.New("no signatures") + } + payload, err := joseBase64UrlDecode(js.payload) + if err != nil { + return nil, err + } + payload = payload[:js.formatLength] + + sort.Sort(jsSignaturesSorted(js.signatures)) + + var marshalled []byte + var marshallErr error + if js.indent != "" { + marshalled, marshallErr = json.MarshalIndent(js.signatures, js.indent, js.indent) + } else { + marshalled, marshallErr = json.Marshal(js.signatures) + } + if marshallErr != nil { + return nil, marshallErr + } + + buf := bytes.NewBuffer(make([]byte, 0, len(payload)+len(marshalled)+34)) + buf.Write(payload) + buf.WriteByte(',') + if js.indent != "" { + buf.WriteByte('\n') + buf.WriteString(js.indent) + buf.WriteByte('"') + buf.WriteString(signatureKey) + buf.WriteString("\": ") + buf.Write(marshalled) + buf.WriteByte('\n') + } else { + buf.WriteByte('"') + buf.WriteString(signatureKey) + buf.WriteString("\":") + buf.Write(marshalled) + } + buf.WriteByte('}') + + return buf.Bytes(), nil +} + +// Signatures provides the signatures on this JWS as opaque blobs, sorted by +// keyID. These blobs can be stored and reassembled with payloads. Internally, +// they are simply marshaled json web signatures but implementations should +// not rely on this. +func (js *JSONSignature) Signatures() ([][]byte, error) { + sort.Sort(jsSignaturesSorted(js.signatures)) + + var sb [][]byte + for _, jsig := range js.signatures { + p, err := json.Marshal(jsig) + if err != nil { + return nil, err + } + + sb = append(sb, p) + } + + return sb, nil +} + +// Merge combines the signatures from one or more other signatures into the +// method receiver. If the payloads differ for any argument, an error will be +// returned and the receiver will not be modified. +func (js *JSONSignature) Merge(others ...*JSONSignature) error { + merged := js.signatures + for _, other := range others { + if js.payload != other.payload { + return fmt.Errorf("payloads differ from merge target") + } + merged = append(merged, other.signatures...) + } + + js.signatures = merged + return nil +} diff --git a/Godeps/_workspace/src/github.com/docker/libtrust/jsonsign_test.go b/Godeps/_workspace/src/github.com/docker/libtrust/jsonsign_test.go new file mode 100644 index 0000000..b4f2697 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/libtrust/jsonsign_test.go @@ -0,0 +1,380 @@ +package libtrust + +import ( + "bytes" + "crypto/rand" + "crypto/x509" + "encoding/json" + "fmt" + "io" + "testing" + + "github.com/docker/libtrust/testutil" +) + +func createTestJSON(sigKey string, indent string) (map[string]interface{}, []byte) { + testMap := map[string]interface{}{ + "name": "dmcgowan/mycontainer", + "config": map[string]interface{}{ + "ports": []int{9101, 9102}, + "run": "/bin/echo \"Hello\"", + }, + "layers": []string{ + "2893c080-27f5-11e4-8c21-0800200c9a66", + "c54bc25b-fbb2-497b-a899-a8bc1b5b9d55", + "4d5d7e03-f908-49f3-a7f6-9ba28dfe0fb4", + "0b6da891-7f7f-4abf-9c97-7887549e696c", + "1d960389-ae4f-4011-85fd-18d0f96a67ad", + }, + } + formattedSection := `{"config":{"ports":[9101,9102],"run":"/bin/echo \"Hello\""},"layers":["2893c080-27f5-11e4-8c21-0800200c9a66","c54bc25b-fbb2-497b-a899-a8bc1b5b9d55","4d5d7e03-f908-49f3-a7f6-9ba28dfe0fb4","0b6da891-7f7f-4abf-9c97-7887549e696c","1d960389-ae4f-4011-85fd-18d0f96a67ad"],"name":"dmcgowan/mycontainer","%s":[{"header":{` + formattedSection = fmt.Sprintf(formattedSection, sigKey) + if indent != "" { + buf := bytes.NewBuffer(nil) + json.Indent(buf, []byte(formattedSection), "", indent) + return testMap, buf.Bytes() + } + return testMap, []byte(formattedSection) + +} + +func TestSignJSON(t *testing.T) { + key, err := GenerateECP256PrivateKey() + if err != nil { + t.Fatalf("Error generating EC key: %s", err) + } + + testMap, _ := createTestJSON("buildSignatures", " ") + indented, err := json.MarshalIndent(testMap, "", " ") + if err != nil { + t.Fatalf("Marshall error: %s", err) + } + + js, err := NewJSONSignature(indented) + if err != nil { + t.Fatalf("Error creating JSON signature: %s", err) + } + err = js.Sign(key) + if err != nil { + t.Fatalf("Error signing content: %s", err) + } + + keys, err := js.Verify() + if err != nil { + t.Fatalf("Error verifying signature: %s", err) + } + if len(keys) != 1 { + t.Fatalf("Error wrong number of keys returned") + } + if keys[0].KeyID() != key.KeyID() { + t.Fatalf("Unexpected public key returned") + } + +} + +func TestSignMap(t *testing.T) { + key, err := GenerateECP256PrivateKey() + if err != nil { + t.Fatalf("Error generating EC key: %s", err) + } + + testMap, _ := createTestJSON("buildSignatures", " ") + js, err := NewJSONSignatureFromMap(testMap) + if err != nil { + t.Fatalf("Error creating JSON signature: %s", err) + } + err = js.Sign(key) + if err != nil { + t.Fatalf("Error signing JSON signature: %s", err) + } + + keys, err := js.Verify() + if err != nil { + t.Fatalf("Error verifying signature: %s", err) + } + if len(keys) != 1 { + t.Fatalf("Error wrong number of keys returned") + } + if keys[0].KeyID() != key.KeyID() { + t.Fatalf("Unexpected public key returned") + } +} + +func TestFormattedJson(t *testing.T) { + key, err := GenerateECP256PrivateKey() + if err != nil { + t.Fatalf("Error generating EC key: %s", err) + } + + testMap, firstSection := createTestJSON("buildSignatures", " ") + indented, err := json.MarshalIndent(testMap, "", " ") + if err != nil { + t.Fatalf("Marshall error: %s", err) + } + + js, err := NewJSONSignature(indented) + if err != nil { + t.Fatalf("Error creating JSON signature: %s", err) + } + err = js.Sign(key) + if err != nil { + t.Fatalf("Error signing content: %s", err) + } + + b, err := js.PrettySignature("buildSignatures") + if err != nil { + t.Fatalf("Error signing map: %s", err) + } + + if bytes.Compare(b[:len(firstSection)], firstSection) != 0 { + t.Fatalf("Wrong signed value\nExpected:\n%s\nActual:\n%s", firstSection, b[:len(firstSection)]) + } + + parsed, err := ParsePrettySignature(b, "buildSignatures") + if err != nil { + t.Fatalf("Error parsing formatted signature: %s", err) + } + + keys, err := parsed.Verify() + if err != nil { + t.Fatalf("Error verifying signature: %s", err) + } + if len(keys) != 1 { + t.Fatalf("Error wrong number of keys returned") + } + if keys[0].KeyID() != key.KeyID() { + t.Fatalf("Unexpected public key returned") + } + + var unmarshalled map[string]interface{} + err = json.Unmarshal(b, &unmarshalled) + if err != nil { + t.Fatalf("Could not unmarshall after parse: %s", err) + } + +} + +func TestFormattedFlatJson(t *testing.T) { + key, err := GenerateECP256PrivateKey() + if err != nil { + t.Fatalf("Error generating EC key: %s", err) + } + + testMap, firstSection := createTestJSON("buildSignatures", "") + unindented, err := json.Marshal(testMap) + if err != nil { + t.Fatalf("Marshall error: %s", err) + } + + js, err := NewJSONSignature(unindented) + if err != nil { + t.Fatalf("Error creating JSON signature: %s", err) + } + err = js.Sign(key) + if err != nil { + t.Fatalf("Error signing JSON signature: %s", err) + } + + b, err := js.PrettySignature("buildSignatures") + if err != nil { + t.Fatalf("Error signing map: %s", err) + } + + if bytes.Compare(b[:len(firstSection)], firstSection) != 0 { + t.Fatalf("Wrong signed value\nExpected:\n%s\nActual:\n%s", firstSection, b[:len(firstSection)]) + } + + parsed, err := ParsePrettySignature(b, "buildSignatures") + if err != nil { + t.Fatalf("Error parsing formatted signature: %s", err) + } + + keys, err := parsed.Verify() + if err != nil { + t.Fatalf("Error verifying signature: %s", err) + } + if len(keys) != 1 { + t.Fatalf("Error wrong number of keys returned") + } + if keys[0].KeyID() != key.KeyID() { + t.Fatalf("Unexpected public key returned") + } +} + +func generateTrustChain(t *testing.T, key PrivateKey, ca *x509.Certificate) (PrivateKey, []*x509.Certificate) { + parent := ca + parentKey := key + chain := make([]*x509.Certificate, 6) + for i := 5; i > 0; i-- { + intermediatekey, err := GenerateECP256PrivateKey() + if err != nil { + t.Fatalf("Error generate key: %s", err) + } + chain[i], err = testutil.GenerateIntermediate(intermediatekey.CryptoPublicKey(), parentKey.CryptoPrivateKey(), parent) + if err != nil { + t.Fatalf("Error generating intermdiate certificate: %s", err) + } + parent = chain[i] + parentKey = intermediatekey + } + trustKey, err := GenerateECP256PrivateKey() + if err != nil { + t.Fatalf("Error generate key: %s", err) + } + chain[0], err = testutil.GenerateTrustCert(trustKey.CryptoPublicKey(), parentKey.CryptoPrivateKey(), parent) + if err != nil { + t.Fatalf("Error generate trust cert: %s", err) + } + + return trustKey, chain +} + +func TestChainVerify(t *testing.T) { + caKey, err := GenerateECP256PrivateKey() + if err != nil { + t.Fatalf("Error generating key: %s", err) + } + ca, err := testutil.GenerateTrustCA(caKey.CryptoPublicKey(), caKey.CryptoPrivateKey()) + if err != nil { + t.Fatalf("Error generating ca: %s", err) + } + trustKey, chain := generateTrustChain(t, caKey, ca) + + testMap, _ := createTestJSON("verifySignatures", " ") + js, err := NewJSONSignatureFromMap(testMap) + if err != nil { + t.Fatalf("Error creating JSONSignature from map: %s", err) + } + + err = js.SignWithChain(trustKey, chain) + if err != nil { + t.Fatalf("Error signing with chain: %s", err) + } + + pool := x509.NewCertPool() + pool.AddCert(ca) + chains, err := js.VerifyChains(pool) + if err != nil { + t.Fatalf("Error verifying content: %s", err) + } + if len(chains) != 1 { + t.Fatalf("Unexpected chains length: %d", len(chains)) + } + if len(chains[0]) != 7 { + t.Fatalf("Unexpected chain length: %d", len(chains[0])) + } +} + +func TestInvalidChain(t *testing.T) { + caKey, err := GenerateECP256PrivateKey() + if err != nil { + t.Fatalf("Error generating key: %s", err) + } + ca, err := testutil.GenerateTrustCA(caKey.CryptoPublicKey(), caKey.CryptoPrivateKey()) + if err != nil { + t.Fatalf("Error generating ca: %s", err) + } + trustKey, chain := generateTrustChain(t, caKey, ca) + + testMap, _ := createTestJSON("verifySignatures", " ") + js, err := NewJSONSignatureFromMap(testMap) + if err != nil { + t.Fatalf("Error creating JSONSignature from map: %s", err) + } + + err = js.SignWithChain(trustKey, chain[:5]) + if err != nil { + t.Fatalf("Error signing with chain: %s", err) + } + + pool := x509.NewCertPool() + pool.AddCert(ca) + chains, err := js.VerifyChains(pool) + if err == nil { + t.Fatalf("Expected error verifying with bad chain") + } + if len(chains) != 0 { + t.Fatalf("Unexpected chains returned from invalid verify") + } +} + +func TestMergeSignatures(t *testing.T) { + pk1, err := GenerateECP256PrivateKey() + if err != nil { + t.Fatalf("unexpected error generating private key 1: %v", err) + } + + pk2, err := GenerateECP256PrivateKey() + if err != nil { + t.Fatalf("unexpected error generating private key 2: %v", err) + } + + payload := make([]byte, 1<<10) + if _, err = io.ReadFull(rand.Reader, payload); err != nil { + t.Fatalf("error generating payload: %v", err) + } + + payload, _ = json.Marshal(map[string]interface{}{"data": payload}) + + sig1, err := NewJSONSignature(payload) + if err != nil { + t.Fatalf("unexpected error creating signature 1: %v", err) + } + + if err := sig1.Sign(pk1); err != nil { + t.Fatalf("unexpected error signing with pk1: %v", err) + } + + sig2, err := NewJSONSignature(payload) + if err != nil { + t.Fatalf("unexpected error creating signature 2: %v", err) + } + + if err := sig2.Sign(pk2); err != nil { + t.Fatalf("unexpected error signing with pk2: %v", err) + } + + // Now, we actually merge into sig1 + if err := sig1.Merge(sig2); err != nil { + t.Fatalf("unexpected error merging: %v", err) + } + + // Verify the new signature package + pubkeys, err := sig1.Verify() + if err != nil { + t.Fatalf("unexpected error during verify: %v", err) + } + + // Make sure the pubkeys match the two private keys from before + privkeys := map[string]PrivateKey{ + pk1.KeyID(): pk1, + pk2.KeyID(): pk2, + } + + found := map[string]struct{}{} + + for _, pubkey := range pubkeys { + if _, ok := privkeys[pubkey.KeyID()]; !ok { + t.Fatalf("unexpected public key found during verification: %v", pubkey) + } + + found[pubkey.KeyID()] = struct{}{} + } + + // Make sure we've found all the private keys from verification + for keyid, _ := range privkeys { + if _, ok := found[keyid]; !ok { + t.Fatalf("public key %v not found during verification", keyid) + } + } + + // Create another signature, with a different payload, and ensure we get an error. + sig3, err := NewJSONSignature([]byte("{}")) + if err != nil { + t.Fatalf("unexpected error making signature for sig3: %v", err) + } + + if err := sig1.Merge(sig3); err == nil { + t.Fatalf("error expected during invalid merge with different payload") + } +} diff --git a/Godeps/_workspace/src/github.com/docker/libtrust/key.go b/Godeps/_workspace/src/github.com/docker/libtrust/key.go new file mode 100644 index 0000000..73642db --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/libtrust/key.go @@ -0,0 +1,253 @@ +package libtrust + +import ( + "crypto" + "crypto/ecdsa" + "crypto/rsa" + "crypto/x509" + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "io" +) + +// PublicKey is a generic interface for a Public Key. +type PublicKey interface { + // KeyType returns the key type for this key. For elliptic curve keys, + // this value should be "EC". For RSA keys, this value should be "RSA". + KeyType() string + // KeyID returns a distinct identifier which is unique to this Public Key. + // The format generated by this library is a base32 encoding of a 240 bit + // hash of the public key data divided into 12 groups like so: + // ABCD:EFGH:IJKL:MNOP:QRST:UVWX:YZ23:4567:ABCD:EFGH:IJKL:MNOP + KeyID() string + // Verify verifyies the signature of the data in the io.Reader using this + // Public Key. The alg parameter should identify the digital signature + // algorithm which was used to produce the signature and should be + // supported by this public key. Returns a nil error if the signature + // is valid. + Verify(data io.Reader, alg string, signature []byte) error + // CryptoPublicKey returns the internal object which can be used as a + // crypto.PublicKey for use with other standard library operations. The type + // is either *rsa.PublicKey or *ecdsa.PublicKey + CryptoPublicKey() crypto.PublicKey + // These public keys can be serialized to the standard JSON encoding for + // JSON Web Keys. See section 6 of the IETF draft RFC for JOSE JSON Web + // Algorithms. + MarshalJSON() ([]byte, error) + // These keys can also be serialized to the standard PEM encoding. + PEMBlock() (*pem.Block, error) + // The string representation of a key is its key type and ID. + String() string + AddExtendedField(string, interface{}) + GetExtendedField(string) interface{} +} + +// PrivateKey is a generic interface for a Private Key. +type PrivateKey interface { + // A PrivateKey contains all fields and methods of a PublicKey of the + // same type. The MarshalJSON method also outputs the private key as a + // JSON Web Key, and the PEMBlock method outputs the private key as a + // PEM block. + PublicKey + // PublicKey returns the PublicKey associated with this PrivateKey. + PublicKey() PublicKey + // Sign signs the data read from the io.Reader using a signature algorithm + // supported by the private key. If the specified hashing algorithm is + // supported by this key, that hash function is used to generate the + // signature otherwise the the default hashing algorithm for this key is + // used. Returns the signature and identifier of the algorithm used. + Sign(data io.Reader, hashID crypto.Hash) (signature []byte, alg string, err error) + // CryptoPrivateKey returns the internal object which can be used as a + // crypto.PublicKey for use with other standard library operations. The + // type is either *rsa.PublicKey or *ecdsa.PublicKey + CryptoPrivateKey() crypto.PrivateKey +} + +// FromCryptoPublicKey returns a libtrust PublicKey representation of the given +// *ecdsa.PublicKey or *rsa.PublicKey. Returns a non-nil error when the given +// key is of an unsupported type. +func FromCryptoPublicKey(cryptoPublicKey crypto.PublicKey) (PublicKey, error) { + switch cryptoPublicKey := cryptoPublicKey.(type) { + case *ecdsa.PublicKey: + return fromECPublicKey(cryptoPublicKey) + case *rsa.PublicKey: + return fromRSAPublicKey(cryptoPublicKey), nil + default: + return nil, fmt.Errorf("public key type %T is not supported", cryptoPublicKey) + } +} + +// FromCryptoPrivateKey returns a libtrust PrivateKey representation of the given +// *ecdsa.PrivateKey or *rsa.PrivateKey. Returns a non-nil error when the given +// key is of an unsupported type. +func FromCryptoPrivateKey(cryptoPrivateKey crypto.PrivateKey) (PrivateKey, error) { + switch cryptoPrivateKey := cryptoPrivateKey.(type) { + case *ecdsa.PrivateKey: + return fromECPrivateKey(cryptoPrivateKey) + case *rsa.PrivateKey: + return fromRSAPrivateKey(cryptoPrivateKey), nil + default: + return nil, fmt.Errorf("private key type %T is not supported", cryptoPrivateKey) + } +} + +// UnmarshalPublicKeyPEM parses the PEM encoded data and returns a libtrust +// PublicKey or an error if there is a problem with the encoding. +func UnmarshalPublicKeyPEM(data []byte) (PublicKey, error) { + pemBlock, _ := pem.Decode(data) + if pemBlock == nil { + return nil, errors.New("unable to find PEM encoded data") + } else if pemBlock.Type != "PUBLIC KEY" { + return nil, fmt.Errorf("unable to get PublicKey from PEM type: %s", pemBlock.Type) + } + + return pubKeyFromPEMBlock(pemBlock) +} + +// UnmarshalPublicKeyPEMBundle parses the PEM encoded data as a bundle of +// PEM blocks appended one after the other and returns a slice of PublicKey +// objects that it finds. +func UnmarshalPublicKeyPEMBundle(data []byte) ([]PublicKey, error) { + pubKeys := []PublicKey{} + + for { + var pemBlock *pem.Block + pemBlock, data = pem.Decode(data) + if pemBlock == nil { + break + } else if pemBlock.Type != "PUBLIC KEY" { + return nil, fmt.Errorf("unable to get PublicKey from PEM type: %s", pemBlock.Type) + } + + pubKey, err := pubKeyFromPEMBlock(pemBlock) + if err != nil { + return nil, err + } + + pubKeys = append(pubKeys, pubKey) + } + + return pubKeys, nil +} + +// UnmarshalPrivateKeyPEM parses the PEM encoded data and returns a libtrust +// PrivateKey or an error if there is a problem with the encoding. +func UnmarshalPrivateKeyPEM(data []byte) (PrivateKey, error) { + pemBlock, _ := pem.Decode(data) + if pemBlock == nil { + return nil, errors.New("unable to find PEM encoded data") + } + + var key PrivateKey + + switch { + case pemBlock.Type == "RSA PRIVATE KEY": + rsaPrivateKey, err := x509.ParsePKCS1PrivateKey(pemBlock.Bytes) + if err != nil { + return nil, fmt.Errorf("unable to decode RSA Private Key PEM data: %s", err) + } + key = fromRSAPrivateKey(rsaPrivateKey) + case pemBlock.Type == "EC PRIVATE KEY": + ecPrivateKey, err := x509.ParseECPrivateKey(pemBlock.Bytes) + if err != nil { + return nil, fmt.Errorf("unable to decode EC Private Key PEM data: %s", err) + } + key, err = fromECPrivateKey(ecPrivateKey) + if err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("unable to get PrivateKey from PEM type: %s", pemBlock.Type) + } + + addPEMHeadersToKey(pemBlock, key.PublicKey()) + + return key, nil +} + +// UnmarshalPublicKeyJWK unmarshals the given JSON Web Key into a generic +// Public Key to be used with libtrust. +func UnmarshalPublicKeyJWK(data []byte) (PublicKey, error) { + jwk := make(map[string]interface{}) + + err := json.Unmarshal(data, &jwk) + if err != nil { + return nil, fmt.Errorf( + "decoding JWK Public Key JSON data: %s\n", err, + ) + } + + // Get the Key Type value. + kty, err := stringFromMap(jwk, "kty") + if err != nil { + return nil, fmt.Errorf("JWK Public Key type: %s", err) + } + + switch { + case kty == "EC": + // Call out to unmarshal EC public key. + return ecPublicKeyFromMap(jwk) + case kty == "RSA": + // Call out to unmarshal RSA public key. + return rsaPublicKeyFromMap(jwk) + default: + return nil, fmt.Errorf( + "JWK Public Key type not supported: %q\n", kty, + ) + } +} + +// UnmarshalPublicKeyJWKSet parses the JSON encoded data as a JSON Web Key Set +// and returns a slice of Public Key objects. +func UnmarshalPublicKeyJWKSet(data []byte) ([]PublicKey, error) { + rawKeys, err := loadJSONKeySetRaw(data) + if err != nil { + return nil, err + } + + pubKeys := make([]PublicKey, 0, len(rawKeys)) + + for _, rawKey := range rawKeys { + pubKey, err := UnmarshalPublicKeyJWK(rawKey) + if err != nil { + return nil, err + } + pubKeys = append(pubKeys, pubKey) + } + + return pubKeys, nil +} + +// UnmarshalPrivateKeyJWK unmarshals the given JSON Web Key into a generic +// Private Key to be used with libtrust. +func UnmarshalPrivateKeyJWK(data []byte) (PrivateKey, error) { + jwk := make(map[string]interface{}) + + err := json.Unmarshal(data, &jwk) + if err != nil { + return nil, fmt.Errorf( + "decoding JWK Private Key JSON data: %s\n", err, + ) + } + + // Get the Key Type value. + kty, err := stringFromMap(jwk, "kty") + if err != nil { + return nil, fmt.Errorf("JWK Private Key type: %s", err) + } + + switch { + case kty == "EC": + // Call out to unmarshal EC private key. + return ecPrivateKeyFromMap(jwk) + case kty == "RSA": + // Call out to unmarshal RSA private key. + return rsaPrivateKeyFromMap(jwk) + default: + return nil, fmt.Errorf( + "JWK Private Key type not supported: %q\n", kty, + ) + } +} diff --git a/Godeps/_workspace/src/github.com/docker/libtrust/key_files.go b/Godeps/_workspace/src/github.com/docker/libtrust/key_files.go new file mode 100644 index 0000000..c526de5 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/libtrust/key_files.go @@ -0,0 +1,255 @@ +package libtrust + +import ( + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "io/ioutil" + "os" + "strings" +) + +var ( + // ErrKeyFileDoesNotExist indicates that the private key file does not exist. + ErrKeyFileDoesNotExist = errors.New("key file does not exist") +) + +func readKeyFileBytes(filename string) ([]byte, error) { + data, err := ioutil.ReadFile(filename) + if err != nil { + if os.IsNotExist(err) { + err = ErrKeyFileDoesNotExist + } else { + err = fmt.Errorf("unable to read key file %s: %s", filename, err) + } + + return nil, err + } + + return data, nil +} + +/* + Loading and Saving of Public and Private Keys in either PEM or JWK format. +*/ + +// LoadKeyFile opens the given filename and attempts to read a Private Key +// encoded in either PEM or JWK format (if .json or .jwk file extension). +func LoadKeyFile(filename string) (PrivateKey, error) { + contents, err := readKeyFileBytes(filename) + if err != nil { + return nil, err + } + + var key PrivateKey + + if strings.HasSuffix(filename, ".json") || strings.HasSuffix(filename, ".jwk") { + key, err = UnmarshalPrivateKeyJWK(contents) + if err != nil { + return nil, fmt.Errorf("unable to decode private key JWK: %s", err) + } + } else { + key, err = UnmarshalPrivateKeyPEM(contents) + if err != nil { + return nil, fmt.Errorf("unable to decode private key PEM: %s", err) + } + } + + return key, nil +} + +// LoadPublicKeyFile opens the given filename and attempts to read a Public Key +// encoded in either PEM or JWK format (if .json or .jwk file extension). +func LoadPublicKeyFile(filename string) (PublicKey, error) { + contents, err := readKeyFileBytes(filename) + if err != nil { + return nil, err + } + + var key PublicKey + + if strings.HasSuffix(filename, ".json") || strings.HasSuffix(filename, ".jwk") { + key, err = UnmarshalPublicKeyJWK(contents) + if err != nil { + return nil, fmt.Errorf("unable to decode public key JWK: %s", err) + } + } else { + key, err = UnmarshalPublicKeyPEM(contents) + if err != nil { + return nil, fmt.Errorf("unable to decode public key PEM: %s", err) + } + } + + return key, nil +} + +// SaveKey saves the given key to a file using the provided filename. +// This process will overwrite any existing file at the provided location. +func SaveKey(filename string, key PrivateKey) error { + var encodedKey []byte + var err error + + if strings.HasSuffix(filename, ".json") || strings.HasSuffix(filename, ".jwk") { + // Encode in JSON Web Key format. + encodedKey, err = json.MarshalIndent(key, "", " ") + if err != nil { + return fmt.Errorf("unable to encode private key JWK: %s", err) + } + } else { + // Encode in PEM format. + pemBlock, err := key.PEMBlock() + if err != nil { + return fmt.Errorf("unable to encode private key PEM: %s", err) + } + encodedKey = pem.EncodeToMemory(pemBlock) + } + + err = ioutil.WriteFile(filename, encodedKey, os.FileMode(0600)) + if err != nil { + return fmt.Errorf("unable to write private key file %s: %s", filename, err) + } + + return nil +} + +// SavePublicKey saves the given public key to the file. +func SavePublicKey(filename string, key PublicKey) error { + var encodedKey []byte + var err error + + if strings.HasSuffix(filename, ".json") || strings.HasSuffix(filename, ".jwk") { + // Encode in JSON Web Key format. + encodedKey, err = json.MarshalIndent(key, "", " ") + if err != nil { + return fmt.Errorf("unable to encode public key JWK: %s", err) + } + } else { + // Encode in PEM format. + pemBlock, err := key.PEMBlock() + if err != nil { + return fmt.Errorf("unable to encode public key PEM: %s", err) + } + encodedKey = pem.EncodeToMemory(pemBlock) + } + + err = ioutil.WriteFile(filename, encodedKey, os.FileMode(0644)) + if err != nil { + return fmt.Errorf("unable to write public key file %s: %s", filename, err) + } + + return nil +} + +// Public Key Set files + +type jwkSet struct { + Keys []json.RawMessage `json:"keys"` +} + +// LoadKeySetFile loads a key set +func LoadKeySetFile(filename string) ([]PublicKey, error) { + if strings.HasSuffix(filename, ".json") || strings.HasSuffix(filename, ".jwk") { + return loadJSONKeySetFile(filename) + } + + // Must be a PEM format file + return loadPEMKeySetFile(filename) +} + +func loadJSONKeySetRaw(data []byte) ([]json.RawMessage, error) { + if len(data) == 0 { + // This is okay, just return an empty slice. + return []json.RawMessage{}, nil + } + + keySet := jwkSet{} + + err := json.Unmarshal(data, &keySet) + if err != nil { + return nil, fmt.Errorf("unable to decode JSON Web Key Set: %s", err) + } + + return keySet.Keys, nil +} + +func loadJSONKeySetFile(filename string) ([]PublicKey, error) { + contents, err := readKeyFileBytes(filename) + if err != nil && err != ErrKeyFileDoesNotExist { + return nil, err + } + + return UnmarshalPublicKeyJWKSet(contents) +} + +func loadPEMKeySetFile(filename string) ([]PublicKey, error) { + data, err := readKeyFileBytes(filename) + if err != nil && err != ErrKeyFileDoesNotExist { + return nil, err + } + + return UnmarshalPublicKeyPEMBundle(data) +} + +// AddKeySetFile adds a key to a key set +func AddKeySetFile(filename string, key PublicKey) error { + if strings.HasSuffix(filename, ".json") || strings.HasSuffix(filename, ".jwk") { + return addKeySetJSONFile(filename, key) + } + + // Must be a PEM format file + return addKeySetPEMFile(filename, key) +} + +func addKeySetJSONFile(filename string, key PublicKey) error { + encodedKey, err := json.Marshal(key) + if err != nil { + return fmt.Errorf("unable to encode trusted client key: %s", err) + } + + contents, err := readKeyFileBytes(filename) + if err != nil && err != ErrKeyFileDoesNotExist { + return err + } + + rawEntries, err := loadJSONKeySetRaw(contents) + if err != nil { + return err + } + + rawEntries = append(rawEntries, json.RawMessage(encodedKey)) + entriesWrapper := jwkSet{Keys: rawEntries} + + encodedEntries, err := json.MarshalIndent(entriesWrapper, "", " ") + if err != nil { + return fmt.Errorf("unable to encode trusted client keys: %s", err) + } + + err = ioutil.WriteFile(filename, encodedEntries, os.FileMode(0644)) + if err != nil { + return fmt.Errorf("unable to write trusted client keys file %s: %s", filename, err) + } + + return nil +} + +func addKeySetPEMFile(filename string, key PublicKey) error { + // Encode to PEM, open file for appending, write PEM. + file, err := os.OpenFile(filename, os.O_CREATE|os.O_APPEND|os.O_RDWR, os.FileMode(0644)) + if err != nil { + return fmt.Errorf("unable to open trusted client keys file %s: %s", filename, err) + } + defer file.Close() + + pemBlock, err := key.PEMBlock() + if err != nil { + return fmt.Errorf("unable to encoded trusted key: %s", err) + } + + _, err = file.Write(pem.EncodeToMemory(pemBlock)) + if err != nil { + return fmt.Errorf("unable to write trusted keys file: %s", err) + } + + return nil +} diff --git a/Godeps/_workspace/src/github.com/docker/libtrust/key_files_test.go b/Godeps/_workspace/src/github.com/docker/libtrust/key_files_test.go new file mode 100644 index 0000000..57e691f --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/libtrust/key_files_test.go @@ -0,0 +1,220 @@ +package libtrust + +import ( + "errors" + "io/ioutil" + "os" + "testing" +) + +func makeTempFile(t *testing.T, prefix string) (filename string) { + file, err := ioutil.TempFile("", prefix) + if err != nil { + t.Fatal(err) + } + + filename = file.Name() + file.Close() + + return +} + +func TestKeyFiles(t *testing.T) { + key, err := GenerateECP256PrivateKey() + if err != nil { + t.Fatal(err) + } + + testKeyFiles(t, key) + + key, err = GenerateRSA2048PrivateKey() + if err != nil { + t.Fatal(err) + } + + testKeyFiles(t, key) +} + +func testKeyFiles(t *testing.T, key PrivateKey) { + var err error + + privateKeyFilename := makeTempFile(t, "private_key") + privateKeyFilenamePEM := privateKeyFilename + ".pem" + privateKeyFilenameJWK := privateKeyFilename + ".jwk" + + publicKeyFilename := makeTempFile(t, "public_key") + publicKeyFilenamePEM := publicKeyFilename + ".pem" + publicKeyFilenameJWK := publicKeyFilename + ".jwk" + + if err = SaveKey(privateKeyFilenamePEM, key); err != nil { + t.Fatal(err) + } + + if err = SaveKey(privateKeyFilenameJWK, key); err != nil { + t.Fatal(err) + } + + if err = SavePublicKey(publicKeyFilenamePEM, key.PublicKey()); err != nil { + t.Fatal(err) + } + + if err = SavePublicKey(publicKeyFilenameJWK, key.PublicKey()); err != nil { + t.Fatal(err) + } + + loadedPEMKey, err := LoadKeyFile(privateKeyFilenamePEM) + if err != nil { + t.Fatal(err) + } + + loadedJWKKey, err := LoadKeyFile(privateKeyFilenameJWK) + if err != nil { + t.Fatal(err) + } + + loadedPEMPublicKey, err := LoadPublicKeyFile(publicKeyFilenamePEM) + if err != nil { + t.Fatal(err) + } + + loadedJWKPublicKey, err := LoadPublicKeyFile(publicKeyFilenameJWK) + if err != nil { + t.Fatal(err) + } + + if key.KeyID() != loadedPEMKey.KeyID() { + t.Fatal(errors.New("key IDs do not match")) + } + + if key.KeyID() != loadedJWKKey.KeyID() { + t.Fatal(errors.New("key IDs do not match")) + } + + if key.KeyID() != loadedPEMPublicKey.KeyID() { + t.Fatal(errors.New("key IDs do not match")) + } + + if key.KeyID() != loadedJWKPublicKey.KeyID() { + t.Fatal(errors.New("key IDs do not match")) + } + + os.Remove(privateKeyFilename) + os.Remove(privateKeyFilenamePEM) + os.Remove(privateKeyFilenameJWK) + os.Remove(publicKeyFilename) + os.Remove(publicKeyFilenamePEM) + os.Remove(publicKeyFilenameJWK) +} + +func TestTrustedHostKeysFile(t *testing.T) { + trustedHostKeysFilename := makeTempFile(t, "trusted_host_keys") + trustedHostKeysFilenamePEM := trustedHostKeysFilename + ".pem" + trustedHostKeysFilenameJWK := trustedHostKeysFilename + ".json" + + testTrustedHostKeysFile(t, trustedHostKeysFilenamePEM) + testTrustedHostKeysFile(t, trustedHostKeysFilenameJWK) + + os.Remove(trustedHostKeysFilename) + os.Remove(trustedHostKeysFilenamePEM) + os.Remove(trustedHostKeysFilenameJWK) +} + +func testTrustedHostKeysFile(t *testing.T, trustedHostKeysFilename string) { + hostAddress1 := "docker.example.com:2376" + hostKey1, err := GenerateECP256PrivateKey() + if err != nil { + t.Fatal(err) + } + + hostKey1.AddExtendedField("hosts", []string{hostAddress1}) + err = AddKeySetFile(trustedHostKeysFilename, hostKey1.PublicKey()) + if err != nil { + t.Fatal(err) + } + + trustedHostKeysMapping, err := LoadKeySetFile(trustedHostKeysFilename) + if err != nil { + t.Fatal(err) + } + + for addr, hostKey := range trustedHostKeysMapping { + t.Logf("Host Address: %d\n", addr) + t.Logf("Host Key: %s\n\n", hostKey) + } + + hostAddress2 := "192.168.59.103:2376" + hostKey2, err := GenerateRSA2048PrivateKey() + if err != nil { + t.Fatal(err) + } + + hostKey2.AddExtendedField("hosts", hostAddress2) + err = AddKeySetFile(trustedHostKeysFilename, hostKey2.PublicKey()) + if err != nil { + t.Fatal(err) + } + + trustedHostKeysMapping, err = LoadKeySetFile(trustedHostKeysFilename) + if err != nil { + t.Fatal(err) + } + + for addr, hostKey := range trustedHostKeysMapping { + t.Logf("Host Address: %d\n", addr) + t.Logf("Host Key: %s\n\n", hostKey) + } + +} + +func TestTrustedClientKeysFile(t *testing.T) { + trustedClientKeysFilename := makeTempFile(t, "trusted_client_keys") + trustedClientKeysFilenamePEM := trustedClientKeysFilename + ".pem" + trustedClientKeysFilenameJWK := trustedClientKeysFilename + ".json" + + testTrustedClientKeysFile(t, trustedClientKeysFilenamePEM) + testTrustedClientKeysFile(t, trustedClientKeysFilenameJWK) + + os.Remove(trustedClientKeysFilename) + os.Remove(trustedClientKeysFilenamePEM) + os.Remove(trustedClientKeysFilenameJWK) +} + +func testTrustedClientKeysFile(t *testing.T, trustedClientKeysFilename string) { + clientKey1, err := GenerateECP256PrivateKey() + if err != nil { + t.Fatal(err) + } + + err = AddKeySetFile(trustedClientKeysFilename, clientKey1.PublicKey()) + if err != nil { + t.Fatal(err) + } + + trustedClientKeys, err := LoadKeySetFile(trustedClientKeysFilename) + if err != nil { + t.Fatal(err) + } + + for _, clientKey := range trustedClientKeys { + t.Logf("Client Key: %s\n", clientKey) + } + + clientKey2, err := GenerateRSA2048PrivateKey() + if err != nil { + t.Fatal(err) + } + + err = AddKeySetFile(trustedClientKeysFilename, clientKey2.PublicKey()) + if err != nil { + t.Fatal(err) + } + + trustedClientKeys, err = LoadKeySetFile(trustedClientKeysFilename) + if err != nil { + t.Fatal(err) + } + + for _, clientKey := range trustedClientKeys { + t.Logf("Client Key: %s\n", clientKey) + } +} diff --git a/Godeps/_workspace/src/github.com/docker/libtrust/key_manager.go b/Godeps/_workspace/src/github.com/docker/libtrust/key_manager.go new file mode 100644 index 0000000..9a98ae3 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/libtrust/key_manager.go @@ -0,0 +1,175 @@ +package libtrust + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "io/ioutil" + "net" + "os" + "path" + "sync" +) + +// ClientKeyManager manages client keys on the filesystem +type ClientKeyManager struct { + key PrivateKey + clientFile string + clientDir string + + clientLock sync.RWMutex + clients []PublicKey + + configLock sync.Mutex + configs []*tls.Config +} + +// NewClientKeyManager loads a new manager from a set of key files +// and managed by the given private key. +func NewClientKeyManager(trustKey PrivateKey, clientFile, clientDir string) (*ClientKeyManager, error) { + m := &ClientKeyManager{ + key: trustKey, + clientFile: clientFile, + clientDir: clientDir, + } + if err := m.loadKeys(); err != nil { + return nil, err + } + // TODO Start watching file and directory + + return m, nil +} + +func (c *ClientKeyManager) loadKeys() (err error) { + // Load authorized keys file + var clients []PublicKey + if c.clientFile != "" { + clients, err = LoadKeySetFile(c.clientFile) + if err != nil { + return fmt.Errorf("unable to load authorized keys: %s", err) + } + } + + // Add clients from authorized keys directory + files, err := ioutil.ReadDir(c.clientDir) + if err != nil && !os.IsNotExist(err) { + return fmt.Errorf("unable to open authorized keys directory: %s", err) + } + for _, f := range files { + if !f.IsDir() { + publicKey, err := LoadPublicKeyFile(path.Join(c.clientDir, f.Name())) + if err != nil { + return fmt.Errorf("unable to load authorized key file: %s", err) + } + clients = append(clients, publicKey) + } + } + + c.clientLock.Lock() + c.clients = clients + c.clientLock.Unlock() + + return nil +} + +// RegisterTLSConfig registers a tls configuration to manager +// such that any changes to the keys may be reflected in +// the tls client CA pool +func (c *ClientKeyManager) RegisterTLSConfig(tlsConfig *tls.Config) error { + c.clientLock.RLock() + certPool, err := GenerateCACertPool(c.key, c.clients) + if err != nil { + return fmt.Errorf("CA pool generation error: %s", err) + } + c.clientLock.RUnlock() + + tlsConfig.ClientCAs = certPool + + c.configLock.Lock() + c.configs = append(c.configs, tlsConfig) + c.configLock.Unlock() + + return nil +} + +// NewIdentityAuthTLSConfig creates a tls.Config for the server to use for +// libtrust identity authentication for the domain specified +func NewIdentityAuthTLSConfig(trustKey PrivateKey, clients *ClientKeyManager, addr string, domain string) (*tls.Config, error) { + tlsConfig := newTLSConfig() + + tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert + if err := clients.RegisterTLSConfig(tlsConfig); err != nil { + return nil, err + } + + // Generate cert + ips, domains, err := parseAddr(addr) + if err != nil { + return nil, err + } + // add domain that it expects clients to use + domains = append(domains, domain) + x509Cert, err := GenerateSelfSignedServerCert(trustKey, domains, ips) + if err != nil { + return nil, fmt.Errorf("certificate generation error: %s", err) + } + tlsConfig.Certificates = []tls.Certificate{{ + Certificate: [][]byte{x509Cert.Raw}, + PrivateKey: trustKey.CryptoPrivateKey(), + Leaf: x509Cert, + }} + + return tlsConfig, nil +} + +// NewCertAuthTLSConfig creates a tls.Config for the server to use for +// certificate authentication +func NewCertAuthTLSConfig(caPath, certPath, keyPath string) (*tls.Config, error) { + tlsConfig := newTLSConfig() + + cert, err := tls.LoadX509KeyPair(certPath, keyPath) + if err != nil { + return nil, fmt.Errorf("Couldn't load X509 key pair (%s, %s): %s. Key encrypted?", certPath, keyPath, err) + } + tlsConfig.Certificates = []tls.Certificate{cert} + + // Verify client certificates against a CA? + if caPath != "" { + certPool := x509.NewCertPool() + file, err := ioutil.ReadFile(caPath) + if err != nil { + return nil, fmt.Errorf("Couldn't read CA certificate: %s", err) + } + certPool.AppendCertsFromPEM(file) + + tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert + tlsConfig.ClientCAs = certPool + } + + return tlsConfig, nil +} + +func newTLSConfig() *tls.Config { + return &tls.Config{ + NextProtos: []string{"http/1.1"}, + // Avoid fallback on insecure SSL protocols + MinVersion: tls.VersionTLS10, + } +} + +// parseAddr parses an address into an array of IPs and domains +func parseAddr(addr string) ([]net.IP, []string, error) { + host, _, err := net.SplitHostPort(addr) + if err != nil { + return nil, nil, err + } + var domains []string + var ips []net.IP + ip := net.ParseIP(host) + if ip != nil { + ips = []net.IP{ip} + } else { + domains = []string{host} + } + return ips, domains, nil +} diff --git a/Godeps/_workspace/src/github.com/docker/libtrust/key_test.go b/Godeps/_workspace/src/github.com/docker/libtrust/key_test.go new file mode 100644 index 0000000..f6c59cc --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/libtrust/key_test.go @@ -0,0 +1,80 @@ +package libtrust + +import ( + "testing" +) + +type generateFunc func() (PrivateKey, error) + +func runGenerateBench(b *testing.B, f generateFunc, name string) { + for i := 0; i < b.N; i++ { + _, err := f() + if err != nil { + b.Fatalf("Error generating %s: %s", name, err) + } + } +} + +func runFingerprintBench(b *testing.B, f generateFunc, name string) { + b.StopTimer() + // Don't count this relatively slow generation call. + key, err := f() + if err != nil { + b.Fatalf("Error generating %s: %s", name, err) + } + b.StartTimer() + + for i := 0; i < b.N; i++ { + if key.KeyID() == "" { + b.Fatalf("Error generating key ID for %s", name) + } + } +} + +func BenchmarkECP256Generate(b *testing.B) { + runGenerateBench(b, GenerateECP256PrivateKey, "P256") +} + +func BenchmarkECP384Generate(b *testing.B) { + runGenerateBench(b, GenerateECP384PrivateKey, "P384") +} + +func BenchmarkECP521Generate(b *testing.B) { + runGenerateBench(b, GenerateECP521PrivateKey, "P521") +} + +func BenchmarkRSA2048Generate(b *testing.B) { + runGenerateBench(b, GenerateRSA2048PrivateKey, "RSA2048") +} + +func BenchmarkRSA3072Generate(b *testing.B) { + runGenerateBench(b, GenerateRSA3072PrivateKey, "RSA3072") +} + +func BenchmarkRSA4096Generate(b *testing.B) { + runGenerateBench(b, GenerateRSA4096PrivateKey, "RSA4096") +} + +func BenchmarkECP256Fingerprint(b *testing.B) { + runFingerprintBench(b, GenerateECP256PrivateKey, "P256") +} + +func BenchmarkECP384Fingerprint(b *testing.B) { + runFingerprintBench(b, GenerateECP384PrivateKey, "P384") +} + +func BenchmarkECP521Fingerprint(b *testing.B) { + runFingerprintBench(b, GenerateECP521PrivateKey, "P521") +} + +func BenchmarkRSA2048Fingerprint(b *testing.B) { + runFingerprintBench(b, GenerateRSA2048PrivateKey, "RSA2048") +} + +func BenchmarkRSA3072Fingerprint(b *testing.B) { + runFingerprintBench(b, GenerateRSA3072PrivateKey, "RSA3072") +} + +func BenchmarkRSA4096Fingerprint(b *testing.B) { + runFingerprintBench(b, GenerateRSA4096PrivateKey, "RSA4096") +} diff --git a/Godeps/_workspace/src/github.com/docker/libtrust/rsa_key.go b/Godeps/_workspace/src/github.com/docker/libtrust/rsa_key.go new file mode 100644 index 0000000..dac4cac --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/libtrust/rsa_key.go @@ -0,0 +1,427 @@ +package libtrust + +import ( + "crypto" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "io" + "math/big" +) + +/* + * RSA DSA PUBLIC KEY + */ + +// rsaPublicKey implements a JWK Public Key using RSA digital signature algorithms. +type rsaPublicKey struct { + *rsa.PublicKey + extended map[string]interface{} +} + +func fromRSAPublicKey(cryptoPublicKey *rsa.PublicKey) *rsaPublicKey { + return &rsaPublicKey{cryptoPublicKey, map[string]interface{}{}} +} + +// KeyType returns the JWK key type for RSA keys, i.e., "RSA". +func (k *rsaPublicKey) KeyType() string { + return "RSA" +} + +// KeyID returns a distinct identifier which is unique to this Public Key. +func (k *rsaPublicKey) KeyID() string { + return keyIDFromCryptoKey(k) +} + +func (k *rsaPublicKey) String() string { + return fmt.Sprintf("RSA Public Key <%s>", k.KeyID()) +} + +// Verify verifyies the signature of the data in the io.Reader using this Public Key. +// The alg parameter should be the name of the JWA digital signature algorithm +// which was used to produce the signature and should be supported by this +// public key. Returns a nil error if the signature is valid. +func (k *rsaPublicKey) Verify(data io.Reader, alg string, signature []byte) error { + // Verify the signature of the given date, return non-nil error if valid. + sigAlg, err := rsaSignatureAlgorithmByName(alg) + if err != nil { + return fmt.Errorf("unable to verify Signature: %s", err) + } + + hasher := sigAlg.HashID().New() + _, err = io.Copy(hasher, data) + if err != nil { + return fmt.Errorf("error reading data to sign: %s", err) + } + hash := hasher.Sum(nil) + + err = rsa.VerifyPKCS1v15(k.PublicKey, sigAlg.HashID(), hash, signature) + if err != nil { + return fmt.Errorf("invalid %s signature: %s", sigAlg.HeaderParam(), err) + } + + return nil +} + +// CryptoPublicKey returns the internal object which can be used as a +// crypto.PublicKey for use with other standard library operations. The type +// is either *rsa.PublicKey or *ecdsa.PublicKey +func (k *rsaPublicKey) CryptoPublicKey() crypto.PublicKey { + return k.PublicKey +} + +func (k *rsaPublicKey) toMap() map[string]interface{} { + jwk := make(map[string]interface{}) + for k, v := range k.extended { + jwk[k] = v + } + jwk["kty"] = k.KeyType() + jwk["kid"] = k.KeyID() + jwk["n"] = joseBase64UrlEncode(k.N.Bytes()) + jwk["e"] = joseBase64UrlEncode(serializeRSAPublicExponentParam(k.E)) + + return jwk +} + +// MarshalJSON serializes this Public Key using the JWK JSON serialization format for +// RSA keys. +func (k *rsaPublicKey) MarshalJSON() (data []byte, err error) { + return json.Marshal(k.toMap()) +} + +// PEMBlock serializes this Public Key to DER-encoded PKIX format. +func (k *rsaPublicKey) PEMBlock() (*pem.Block, error) { + derBytes, err := x509.MarshalPKIXPublicKey(k.PublicKey) + if err != nil { + return nil, fmt.Errorf("unable to serialize RSA PublicKey to DER-encoded PKIX format: %s", err) + } + k.extended["kid"] = k.KeyID() // For display purposes. + return createPemBlock("PUBLIC KEY", derBytes, k.extended) +} + +func (k *rsaPublicKey) AddExtendedField(field string, value interface{}) { + k.extended[field] = value +} + +func (k *rsaPublicKey) GetExtendedField(field string) interface{} { + v, ok := k.extended[field] + if !ok { + return nil + } + return v +} + +func rsaPublicKeyFromMap(jwk map[string]interface{}) (*rsaPublicKey, error) { + // JWK key type (kty) has already been determined to be "RSA". + // Need to extract 'n', 'e', and 'kid' and check for + // consistency. + + // Get the modulus parameter N. + nB64Url, err := stringFromMap(jwk, "n") + if err != nil { + return nil, fmt.Errorf("JWK RSA Public Key modulus: %s", err) + } + + n, err := parseRSAModulusParam(nB64Url) + if err != nil { + return nil, fmt.Errorf("JWK RSA Public Key modulus: %s", err) + } + + // Get the public exponent E. + eB64Url, err := stringFromMap(jwk, "e") + if err != nil { + return nil, fmt.Errorf("JWK RSA Public Key exponent: %s", err) + } + + e, err := parseRSAPublicExponentParam(eB64Url) + if err != nil { + return nil, fmt.Errorf("JWK RSA Public Key exponent: %s", err) + } + + key := &rsaPublicKey{ + PublicKey: &rsa.PublicKey{N: n, E: e}, + } + + // Key ID is optional, but if it exists, it should match the key. + _, ok := jwk["kid"] + if ok { + kid, err := stringFromMap(jwk, "kid") + if err != nil { + return nil, fmt.Errorf("JWK RSA Public Key ID: %s", err) + } + if kid != key.KeyID() { + return nil, fmt.Errorf("JWK RSA Public Key ID does not match: %s", kid) + } + } + + if _, ok := jwk["d"]; ok { + return nil, fmt.Errorf("JWK RSA Public Key cannot contain private exponent") + } + + key.extended = jwk + + return key, nil +} + +/* + * RSA DSA PRIVATE KEY + */ + +// rsaPrivateKey implements a JWK Private Key using RSA digital signature algorithms. +type rsaPrivateKey struct { + rsaPublicKey + *rsa.PrivateKey +} + +func fromRSAPrivateKey(cryptoPrivateKey *rsa.PrivateKey) *rsaPrivateKey { + return &rsaPrivateKey{ + *fromRSAPublicKey(&cryptoPrivateKey.PublicKey), + cryptoPrivateKey, + } +} + +// PublicKey returns the Public Key data associated with this Private Key. +func (k *rsaPrivateKey) PublicKey() PublicKey { + return &k.rsaPublicKey +} + +func (k *rsaPrivateKey) String() string { + return fmt.Sprintf("RSA Private Key <%s>", k.KeyID()) +} + +// Sign signs the data read from the io.Reader using a signature algorithm supported +// by the RSA private key. If the specified hashing algorithm is supported by +// this key, that hash function is used to generate the signature otherwise the +// the default hashing algorithm for this key is used. Returns the signature +// and the name of the JWK signature algorithm used, e.g., "RS256", "RS384", +// "RS512". +func (k *rsaPrivateKey) Sign(data io.Reader, hashID crypto.Hash) (signature []byte, alg string, err error) { + // Generate a signature of the data using the internal alg. + sigAlg := rsaPKCS1v15SignatureAlgorithmForHashID(hashID) + hasher := sigAlg.HashID().New() + + _, err = io.Copy(hasher, data) + if err != nil { + return nil, "", fmt.Errorf("error reading data to sign: %s", err) + } + hash := hasher.Sum(nil) + + signature, err = rsa.SignPKCS1v15(rand.Reader, k.PrivateKey, sigAlg.HashID(), hash) + if err != nil { + return nil, "", fmt.Errorf("error producing signature: %s", err) + } + + alg = sigAlg.HeaderParam() + + return +} + +// CryptoPrivateKey returns the internal object which can be used as a +// crypto.PublicKey for use with other standard library operations. The type +// is either *rsa.PublicKey or *ecdsa.PublicKey +func (k *rsaPrivateKey) CryptoPrivateKey() crypto.PrivateKey { + return k.PrivateKey +} + +func (k *rsaPrivateKey) toMap() map[string]interface{} { + k.Precompute() // Make sure the precomputed values are stored. + jwk := k.rsaPublicKey.toMap() + + jwk["d"] = joseBase64UrlEncode(k.D.Bytes()) + jwk["p"] = joseBase64UrlEncode(k.Primes[0].Bytes()) + jwk["q"] = joseBase64UrlEncode(k.Primes[1].Bytes()) + jwk["dp"] = joseBase64UrlEncode(k.Precomputed.Dp.Bytes()) + jwk["dq"] = joseBase64UrlEncode(k.Precomputed.Dq.Bytes()) + jwk["qi"] = joseBase64UrlEncode(k.Precomputed.Qinv.Bytes()) + + otherPrimes := k.Primes[2:] + + if len(otherPrimes) > 0 { + otherPrimesInfo := make([]interface{}, len(otherPrimes)) + for i, r := range otherPrimes { + otherPrimeInfo := make(map[string]string, 3) + otherPrimeInfo["r"] = joseBase64UrlEncode(r.Bytes()) + crtVal := k.Precomputed.CRTValues[i] + otherPrimeInfo["d"] = joseBase64UrlEncode(crtVal.Exp.Bytes()) + otherPrimeInfo["t"] = joseBase64UrlEncode(crtVal.Coeff.Bytes()) + otherPrimesInfo[i] = otherPrimeInfo + } + jwk["oth"] = otherPrimesInfo + } + + return jwk +} + +// MarshalJSON serializes this Private Key using the JWK JSON serialization format for +// RSA keys. +func (k *rsaPrivateKey) MarshalJSON() (data []byte, err error) { + return json.Marshal(k.toMap()) +} + +// PEMBlock serializes this Private Key to DER-encoded PKIX format. +func (k *rsaPrivateKey) PEMBlock() (*pem.Block, error) { + derBytes := x509.MarshalPKCS1PrivateKey(k.PrivateKey) + k.extended["keyID"] = k.KeyID() // For display purposes. + return createPemBlock("RSA PRIVATE KEY", derBytes, k.extended) +} + +func rsaPrivateKeyFromMap(jwk map[string]interface{}) (*rsaPrivateKey, error) { + // The JWA spec for RSA Private Keys (draft rfc section 5.3.2) states that + // only the private key exponent 'd' is REQUIRED, the others are just for + // signature/decryption optimizations and SHOULD be included when the JWK + // is produced. We MAY choose to accept a JWK which only includes 'd', but + // we're going to go ahead and not choose to accept it without the extra + // fields. Only the 'oth' field will be optional (for multi-prime keys). + privateExponent, err := parseRSAPrivateKeyParamFromMap(jwk, "d") + if err != nil { + return nil, fmt.Errorf("JWK RSA Private Key exponent: %s", err) + } + firstPrimeFactor, err := parseRSAPrivateKeyParamFromMap(jwk, "p") + if err != nil { + return nil, fmt.Errorf("JWK RSA Private Key prime factor: %s", err) + } + secondPrimeFactor, err := parseRSAPrivateKeyParamFromMap(jwk, "q") + if err != nil { + return nil, fmt.Errorf("JWK RSA Private Key prime factor: %s", err) + } + firstFactorCRT, err := parseRSAPrivateKeyParamFromMap(jwk, "dp") + if err != nil { + return nil, fmt.Errorf("JWK RSA Private Key CRT exponent: %s", err) + } + secondFactorCRT, err := parseRSAPrivateKeyParamFromMap(jwk, "dq") + if err != nil { + return nil, fmt.Errorf("JWK RSA Private Key CRT exponent: %s", err) + } + crtCoeff, err := parseRSAPrivateKeyParamFromMap(jwk, "qi") + if err != nil { + return nil, fmt.Errorf("JWK RSA Private Key CRT coefficient: %s", err) + } + + var oth interface{} + if _, ok := jwk["oth"]; ok { + oth = jwk["oth"] + delete(jwk, "oth") + } + + // JWK key type (kty) has already been determined to be "RSA". + // Need to extract the public key information, then extract the private + // key values. + publicKey, err := rsaPublicKeyFromMap(jwk) + if err != nil { + return nil, err + } + + privateKey := &rsa.PrivateKey{ + PublicKey: *publicKey.PublicKey, + D: privateExponent, + Primes: []*big.Int{firstPrimeFactor, secondPrimeFactor}, + Precomputed: rsa.PrecomputedValues{ + Dp: firstFactorCRT, + Dq: secondFactorCRT, + Qinv: crtCoeff, + }, + } + + if oth != nil { + // Should be an array of more JSON objects. + otherPrimesInfo, ok := oth.([]interface{}) + if !ok { + return nil, errors.New("JWK RSA Private Key: Invalid other primes info: must be an array") + } + numOtherPrimeFactors := len(otherPrimesInfo) + if numOtherPrimeFactors == 0 { + return nil, errors.New("JWK RSA Privake Key: Invalid other primes info: must be absent or non-empty") + } + otherPrimeFactors := make([]*big.Int, numOtherPrimeFactors) + productOfPrimes := new(big.Int).Mul(firstPrimeFactor, secondPrimeFactor) + crtValues := make([]rsa.CRTValue, numOtherPrimeFactors) + + for i, val := range otherPrimesInfo { + otherPrimeinfo, ok := val.(map[string]interface{}) + if !ok { + return nil, errors.New("JWK RSA Private Key: Invalid other prime info: must be a JSON object") + } + + otherPrimeFactor, err := parseRSAPrivateKeyParamFromMap(otherPrimeinfo, "r") + if err != nil { + return nil, fmt.Errorf("JWK RSA Private Key prime factor: %s", err) + } + otherFactorCRT, err := parseRSAPrivateKeyParamFromMap(otherPrimeinfo, "d") + if err != nil { + return nil, fmt.Errorf("JWK RSA Private Key CRT exponent: %s", err) + } + otherCrtCoeff, err := parseRSAPrivateKeyParamFromMap(otherPrimeinfo, "t") + if err != nil { + return nil, fmt.Errorf("JWK RSA Private Key CRT coefficient: %s", err) + } + + crtValue := crtValues[i] + crtValue.Exp = otherFactorCRT + crtValue.Coeff = otherCrtCoeff + crtValue.R = productOfPrimes + otherPrimeFactors[i] = otherPrimeFactor + productOfPrimes = new(big.Int).Mul(productOfPrimes, otherPrimeFactor) + } + + privateKey.Primes = append(privateKey.Primes, otherPrimeFactors...) + privateKey.Precomputed.CRTValues = crtValues + } + + key := &rsaPrivateKey{ + rsaPublicKey: *publicKey, + PrivateKey: privateKey, + } + + return key, nil +} + +/* + * Key Generation Functions. + */ + +func generateRSAPrivateKey(bits int) (k *rsaPrivateKey, err error) { + k = new(rsaPrivateKey) + k.PrivateKey, err = rsa.GenerateKey(rand.Reader, bits) + if err != nil { + return nil, err + } + + k.rsaPublicKey.PublicKey = &k.PrivateKey.PublicKey + k.extended = make(map[string]interface{}) + + return +} + +// GenerateRSA2048PrivateKey generates a key pair using 2048-bit RSA. +func GenerateRSA2048PrivateKey() (PrivateKey, error) { + k, err := generateRSAPrivateKey(2048) + if err != nil { + return nil, fmt.Errorf("error generating RSA 2048-bit key: %s", err) + } + + return k, nil +} + +// GenerateRSA3072PrivateKey generates a key pair using 3072-bit RSA. +func GenerateRSA3072PrivateKey() (PrivateKey, error) { + k, err := generateRSAPrivateKey(3072) + if err != nil { + return nil, fmt.Errorf("error generating RSA 3072-bit key: %s", err) + } + + return k, nil +} + +// GenerateRSA4096PrivateKey generates a key pair using 4096-bit RSA. +func GenerateRSA4096PrivateKey() (PrivateKey, error) { + k, err := generateRSAPrivateKey(4096) + if err != nil { + return nil, fmt.Errorf("error generating RSA 4096-bit key: %s", err) + } + + return k, nil +} diff --git a/Godeps/_workspace/src/github.com/docker/libtrust/rsa_key_test.go b/Godeps/_workspace/src/github.com/docker/libtrust/rsa_key_test.go new file mode 100644 index 0000000..5ec7707 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/libtrust/rsa_key_test.go @@ -0,0 +1,157 @@ +package libtrust + +import ( + "bytes" + "encoding/json" + "log" + "testing" +) + +var rsaKeys []PrivateKey + +func init() { + var err error + rsaKeys, err = generateRSATestKeys() + if err != nil { + log.Fatal(err) + } +} + +func generateRSATestKeys() (keys []PrivateKey, err error) { + log.Println("Generating RSA 2048-bit Test Key") + rsa2048Key, err := GenerateRSA2048PrivateKey() + if err != nil { + return + } + + log.Println("Generating RSA 3072-bit Test Key") + rsa3072Key, err := GenerateRSA3072PrivateKey() + if err != nil { + return + } + + log.Println("Generating RSA 4096-bit Test Key") + rsa4096Key, err := GenerateRSA4096PrivateKey() + if err != nil { + return + } + + log.Println("Done generating RSA Test Keys!") + keys = []PrivateKey{rsa2048Key, rsa3072Key, rsa4096Key} + + return +} + +func TestRSAKeys(t *testing.T) { + for _, rsaKey := range rsaKeys { + if rsaKey.KeyType() != "RSA" { + t.Fatalf("key type must be %q, instead got %q", "RSA", rsaKey.KeyType()) + } + } +} + +func TestRSASignVerify(t *testing.T) { + message := "Hello, World!" + data := bytes.NewReader([]byte(message)) + + sigAlgs := []*signatureAlgorithm{rs256, rs384, rs512} + + for i, rsaKey := range rsaKeys { + sigAlg := sigAlgs[i] + + t.Logf("%s signature of %q with kid: %s\n", sigAlg.HeaderParam(), message, rsaKey.KeyID()) + + data.Seek(0, 0) // Reset the byte reader + + // Sign + sig, alg, err := rsaKey.Sign(data, sigAlg.HashID()) + if err != nil { + t.Fatal(err) + } + + data.Seek(0, 0) // Reset the byte reader + + // Verify + err = rsaKey.Verify(data, alg, sig) + if err != nil { + t.Fatal(err) + } + } +} + +func TestMarshalUnmarshalRSAKeys(t *testing.T) { + data := bytes.NewReader([]byte("This is a test. I repeat: this is only a test.")) + sigAlgs := []*signatureAlgorithm{rs256, rs384, rs512} + + for i, rsaKey := range rsaKeys { + sigAlg := sigAlgs[i] + privateJWKJSON, err := json.MarshalIndent(rsaKey, "", " ") + if err != nil { + t.Fatal(err) + } + + publicJWKJSON, err := json.MarshalIndent(rsaKey.PublicKey(), "", " ") + if err != nil { + t.Fatal(err) + } + + t.Logf("JWK Private Key: %s", string(privateJWKJSON)) + t.Logf("JWK Public Key: %s", string(publicJWKJSON)) + + privKey2, err := UnmarshalPrivateKeyJWK(privateJWKJSON) + if err != nil { + t.Fatal(err) + } + + pubKey2, err := UnmarshalPublicKeyJWK(publicJWKJSON) + if err != nil { + t.Fatal(err) + } + + // Ensure we can sign/verify a message with the unmarshalled keys. + data.Seek(0, 0) // Reset the byte reader + signature, alg, err := privKey2.Sign(data, sigAlg.HashID()) + if err != nil { + t.Fatal(err) + } + + data.Seek(0, 0) // Reset the byte reader + err = pubKey2.Verify(data, alg, signature) + if err != nil { + t.Fatal(err) + } + + // It's a good idea to validate the Private Key to make sure our + // (un)marshal process didn't corrupt the extra parameters. + k := privKey2.(*rsaPrivateKey) + err = k.PrivateKey.Validate() + if err != nil { + t.Fatal(err) + } + } +} + +func TestFromCryptoRSAKeys(t *testing.T) { + for _, rsaKey := range rsaKeys { + cryptoPrivateKey := rsaKey.CryptoPrivateKey() + cryptoPublicKey := rsaKey.CryptoPublicKey() + + pubKey, err := FromCryptoPublicKey(cryptoPublicKey) + if err != nil { + t.Fatal(err) + } + + if pubKey.KeyID() != rsaKey.KeyID() { + t.Fatal("public key key ID mismatch") + } + + privKey, err := FromCryptoPrivateKey(cryptoPrivateKey) + if err != nil { + t.Fatal(err) + } + + if privKey.KeyID() != rsaKey.KeyID() { + t.Fatal("public key key ID mismatch") + } + } +} diff --git a/Godeps/_workspace/src/github.com/docker/libtrust/testutil/certificates.go b/Godeps/_workspace/src/github.com/docker/libtrust/testutil/certificates.go new file mode 100644 index 0000000..89debf6 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/libtrust/testutil/certificates.go @@ -0,0 +1,94 @@ +package testutil + +import ( + "crypto" + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "math/big" + "time" +) + +// GenerateTrustCA generates a new certificate authority for testing. +func GenerateTrustCA(pub crypto.PublicKey, priv crypto.PrivateKey) (*x509.Certificate, error) { + cert := &x509.Certificate{ + SerialNumber: big.NewInt(0), + Subject: pkix.Name{ + CommonName: "CA Root", + }, + NotBefore: time.Now().Add(-time.Second), + NotAfter: time.Now().Add(time.Hour), + IsCA: true, + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + BasicConstraintsValid: true, + } + + certDER, err := x509.CreateCertificate(rand.Reader, cert, cert, pub, priv) + if err != nil { + return nil, err + } + + cert, err = x509.ParseCertificate(certDER) + if err != nil { + return nil, err + } + + return cert, nil +} + +// GenerateIntermediate generates an intermediate certificate for testing using +// the parent certificate (likely a CA) and the provided keys. +func GenerateIntermediate(key crypto.PublicKey, parentKey crypto.PrivateKey, parent *x509.Certificate) (*x509.Certificate, error) { + cert := &x509.Certificate{ + SerialNumber: big.NewInt(0), + Subject: pkix.Name{ + CommonName: "Intermediate", + }, + NotBefore: time.Now().Add(-time.Second), + NotAfter: time.Now().Add(time.Hour), + IsCA: true, + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + BasicConstraintsValid: true, + } + + certDER, err := x509.CreateCertificate(rand.Reader, cert, parent, key, parentKey) + if err != nil { + return nil, err + } + + cert, err = x509.ParseCertificate(certDER) + if err != nil { + return nil, err + } + + return cert, nil +} + +// GenerateTrustCert generates a new trust certificate for testing. Unlike the +// intermediate certificates, this certificate should be used for signature +// only, not creating certificates. +func GenerateTrustCert(key crypto.PublicKey, parentKey crypto.PrivateKey, parent *x509.Certificate) (*x509.Certificate, error) { + cert := &x509.Certificate{ + SerialNumber: big.NewInt(0), + Subject: pkix.Name{ + CommonName: "Trust Cert", + }, + NotBefore: time.Now().Add(-time.Second), + NotAfter: time.Now().Add(time.Hour), + IsCA: true, + KeyUsage: x509.KeyUsageDigitalSignature, + BasicConstraintsValid: true, + } + + certDER, err := x509.CreateCertificate(rand.Reader, cert, parent, key, parentKey) + if err != nil { + return nil, err + } + + cert, err = x509.ParseCertificate(certDER) + if err != nil { + return nil, err + } + + return cert, nil +} diff --git a/Godeps/_workspace/src/github.com/docker/libtrust/tlsdemo/README.md b/Godeps/_workspace/src/github.com/docker/libtrust/tlsdemo/README.md new file mode 100644 index 0000000..24124db --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/libtrust/tlsdemo/README.md @@ -0,0 +1,50 @@ +## Libtrust TLS Config Demo + +This program generates key pairs and trust files for a TLS client and server. + +To generate the keys, run: + +``` +$ go run genkeys.go +``` + +The generated files are: + +``` +$ ls -l client_data/ server_data/ +client_data/: +total 24 +-rw------- 1 jlhawn staff 281 Aug 8 16:21 private_key.json +-rw-r--r-- 1 jlhawn staff 225 Aug 8 16:21 public_key.json +-rw-r--r-- 1 jlhawn staff 275 Aug 8 16:21 trusted_hosts.json + +server_data/: +total 24 +-rw-r--r-- 1 jlhawn staff 348 Aug 8 16:21 trusted_clients.json +-rw------- 1 jlhawn staff 281 Aug 8 16:21 private_key.json +-rw-r--r-- 1 jlhawn staff 225 Aug 8 16:21 public_key.json +``` + +The private key and public key for the client and server are stored in `private_key.json` and `public_key.json`, respectively, and in their respective directories. They are represented as JSON Web Keys: JSON objects which represent either an ECDSA or RSA private key. The host keys trusted by the client are stored in `trusted_hosts.json` and contain a mapping of an internet address, `:`, to a JSON Web Key which is a JSON object representing either an ECDSA or RSA public key of the trusted server. The client keys trusted by the server are stored in `trusted_clients.json` and contain an array of JSON objects which contain a comment field which can be used describe the key and a JSON Web Key which is a JSON object representing either an ECDSA or RSA public key of the trusted client. + +To start the server, run: + +``` +$ go run server.go +``` + +This starts an HTTPS server which listens on `localhost:8888`. The server configures itself with a certificate which is valid for both `localhost` and `127.0.0.1` and uses the key from `server_data/private_key.json`. It accepts connections from clients which present a certificate for a key that it is configured to trust from the `trusted_clients.json` file and returns a simple 'hello' message. + +To make a request using the client, run: + +``` +$ go run client.go +``` + +This command creates an HTTPS client which makes a GET request to `https://localhost:8888`. The client configures itself with a certificate using the key from `client_data/private_key.json`. It only connects to a server which presents a certificate signed by the key specified for the `localhost:8888` address from `client_data/trusted_hosts.json` and made to be used for the `localhost` hostname. If the connection succeeds, it prints the response from the server. + +The file `gencert.go` can be used to generate PEM encoded version of the client key and certificate. If you save them to `key.pem` and `cert.pem` respectively, you can use them with `curl` to test out the server (if it is still running). + +``` +curl --cert cert.pem --key key.pem -k https://localhost:8888 +``` diff --git a/Godeps/_workspace/src/github.com/docker/libtrust/tlsdemo/client.go b/Godeps/_workspace/src/github.com/docker/libtrust/tlsdemo/client.go new file mode 100644 index 0000000..0a699a0 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/libtrust/tlsdemo/client.go @@ -0,0 +1,89 @@ +package main + +import ( + "crypto/tls" + "fmt" + "io/ioutil" + "log" + "net" + "net/http" + + "github.com/docker/libtrust" +) + +var ( + serverAddress = "localhost:8888" + privateKeyFilename = "client_data/private_key.pem" + trustedHostsFilename = "client_data/trusted_hosts.pem" +) + +func main() { + // Load Client Key. + clientKey, err := libtrust.LoadKeyFile(privateKeyFilename) + if err != nil { + log.Fatal(err) + } + + // Generate Client Certificate. + selfSignedClientCert, err := libtrust.GenerateSelfSignedClientCert(clientKey) + if err != nil { + log.Fatal(err) + } + + // Load trusted host keys. + hostKeys, err := libtrust.LoadKeySetFile(trustedHostsFilename) + if err != nil { + log.Fatal(err) + } + + // Ensure the host we want to connect to is trusted! + host, _, err := net.SplitHostPort(serverAddress) + if err != nil { + log.Fatal(err) + } + serverKeys, err := libtrust.FilterByHosts(hostKeys, host, false) + if err != nil { + log.Fatalf("%q is not a known and trusted host", host) + } + + // Generate a CA pool with the trusted host's key. + caPool, err := libtrust.GenerateCACertPool(clientKey, serverKeys) + if err != nil { + log.Fatal(err) + } + + // Create HTTP Client. + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + Certificates: []tls.Certificate{ + tls.Certificate{ + Certificate: [][]byte{selfSignedClientCert.Raw}, + PrivateKey: clientKey.CryptoPrivateKey(), + Leaf: selfSignedClientCert, + }, + }, + RootCAs: caPool, + }, + }, + } + + var makeRequest = func(url string) { + resp, err := client.Get(url) + if err != nil { + log.Fatal(err) + } + defer resp.Body.Close() + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + log.Fatal(err) + } + + log.Println(resp.Status) + log.Println(string(body)) + } + + // Make the request to the trusted server! + makeRequest(fmt.Sprintf("https://%s", serverAddress)) +} diff --git a/Godeps/_workspace/src/github.com/docker/libtrust/tlsdemo/gencert.go b/Godeps/_workspace/src/github.com/docker/libtrust/tlsdemo/gencert.go new file mode 100644 index 0000000..c65f3b6 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/libtrust/tlsdemo/gencert.go @@ -0,0 +1,62 @@ +package main + +import ( + "encoding/pem" + "fmt" + "log" + "net" + + "github.com/docker/libtrust" +) + +var ( + serverAddress = "localhost:8888" + clientPrivateKeyFilename = "client_data/private_key.pem" + trustedHostsFilename = "client_data/trusted_hosts.pem" +) + +func main() { + key, err := libtrust.LoadKeyFile(clientPrivateKeyFilename) + if err != nil { + log.Fatal(err) + } + + keyPEMBlock, err := key.PEMBlock() + if err != nil { + log.Fatal(err) + } + + encodedPrivKey := pem.EncodeToMemory(keyPEMBlock) + fmt.Printf("Client Key:\n\n%s\n", string(encodedPrivKey)) + + cert, err := libtrust.GenerateSelfSignedClientCert(key) + if err != nil { + log.Fatal(err) + } + + encodedCert := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}) + fmt.Printf("Client Cert:\n\n%s\n", string(encodedCert)) + + trustedServerKeys, err := libtrust.LoadKeySetFile(trustedHostsFilename) + if err != nil { + log.Fatal(err) + } + + hostname, _, err := net.SplitHostPort(serverAddress) + if err != nil { + log.Fatal(err) + } + + trustedServerKeys, err = libtrust.FilterByHosts(trustedServerKeys, hostname, false) + if err != nil { + log.Fatal(err) + } + + caCert, err := libtrust.GenerateCACert(key, trustedServerKeys[0]) + if err != nil { + log.Fatal(err) + } + + encodedCert = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: caCert.Raw}) + fmt.Printf("CA Cert:\n\n%s\n", string(encodedCert)) +} diff --git a/Godeps/_workspace/src/github.com/docker/libtrust/tlsdemo/genkeys.go b/Godeps/_workspace/src/github.com/docker/libtrust/tlsdemo/genkeys.go new file mode 100644 index 0000000..9dc8842 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/libtrust/tlsdemo/genkeys.go @@ -0,0 +1,61 @@ +package main + +import ( + "log" + + "github.com/docker/libtrust" +) + +func main() { + // Generate client key. + clientKey, err := libtrust.GenerateECP256PrivateKey() + if err != nil { + log.Fatal(err) + } + + // Add a comment for the client key. + clientKey.AddExtendedField("comment", "TLS Demo Client") + + // Save the client key, public and private versions. + err = libtrust.SaveKey("client_data/private_key.pem", clientKey) + if err != nil { + log.Fatal(err) + } + + err = libtrust.SavePublicKey("client_data/public_key.pem", clientKey.PublicKey()) + if err != nil { + log.Fatal(err) + } + + // Generate server key. + serverKey, err := libtrust.GenerateECP256PrivateKey() + if err != nil { + log.Fatal(err) + } + + // Set the list of addresses to use for the server. + serverKey.AddExtendedField("hosts", []string{"localhost", "docker.example.com"}) + + // Save the server key, public and private versions. + err = libtrust.SaveKey("server_data/private_key.pem", serverKey) + if err != nil { + log.Fatal(err) + } + + err = libtrust.SavePublicKey("server_data/public_key.pem", serverKey.PublicKey()) + if err != nil { + log.Fatal(err) + } + + // Generate Authorized Keys file for server. + err = libtrust.AddKeySetFile("server_data/trusted_clients.pem", clientKey.PublicKey()) + if err != nil { + log.Fatal(err) + } + + // Generate Known Host Keys file for client. + err = libtrust.AddKeySetFile("client_data/trusted_hosts.pem", serverKey.PublicKey()) + if err != nil { + log.Fatal(err) + } +} diff --git a/Godeps/_workspace/src/github.com/docker/libtrust/tlsdemo/server.go b/Godeps/_workspace/src/github.com/docker/libtrust/tlsdemo/server.go new file mode 100644 index 0000000..d3cb2ea --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/libtrust/tlsdemo/server.go @@ -0,0 +1,80 @@ +package main + +import ( + "crypto/tls" + "fmt" + "html" + "log" + "net" + "net/http" + + "github.com/docker/libtrust" +) + +var ( + serverAddress = "localhost:8888" + privateKeyFilename = "server_data/private_key.pem" + authorizedClientsFilename = "server_data/trusted_clients.pem" +) + +func requestHandler(w http.ResponseWriter, r *http.Request) { + clientCert := r.TLS.PeerCertificates[0] + keyID := clientCert.Subject.CommonName + log.Printf("Request from keyID: %s\n", keyID) + fmt.Fprintf(w, "Hello, client! I'm a server! And you are %T: %s.\n", clientCert.PublicKey, html.EscapeString(keyID)) +} + +func main() { + // Load server key. + serverKey, err := libtrust.LoadKeyFile(privateKeyFilename) + if err != nil { + log.Fatal(err) + } + + // Generate server certificate. + selfSignedServerCert, err := libtrust.GenerateSelfSignedServerCert( + serverKey, []string{"localhost"}, []net.IP{net.ParseIP("127.0.0.1")}, + ) + if err != nil { + log.Fatal(err) + } + + // Load authorized client keys. + authorizedClients, err := libtrust.LoadKeySetFile(authorizedClientsFilename) + if err != nil { + log.Fatal(err) + } + + // Create CA pool using trusted client keys. + caPool, err := libtrust.GenerateCACertPool(serverKey, authorizedClients) + if err != nil { + log.Fatal(err) + } + + // Create TLS config, requiring client certificates. + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{ + tls.Certificate{ + Certificate: [][]byte{selfSignedServerCert.Raw}, + PrivateKey: serverKey.CryptoPrivateKey(), + Leaf: selfSignedServerCert, + }, + }, + ClientAuth: tls.RequireAndVerifyClientCert, + ClientCAs: caPool, + } + + // Create HTTP server with simple request handler. + server := &http.Server{ + Addr: serverAddress, + Handler: http.HandlerFunc(requestHandler), + } + + // Listen and server HTTPS using the libtrust TLS config. + listener, err := net.Listen("tcp", server.Addr) + if err != nil { + log.Fatal(err) + } + tlsListener := tls.NewListener(listener, tlsConfig) + server.Serve(tlsListener) +} diff --git a/Godeps/_workspace/src/github.com/docker/libtrust/trustgraph/graph.go b/Godeps/_workspace/src/github.com/docker/libtrust/trustgraph/graph.go new file mode 100644 index 0000000..72b0fc3 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/libtrust/trustgraph/graph.go @@ -0,0 +1,50 @@ +package trustgraph + +import "github.com/docker/libtrust" + +// TrustGraph represents a graph of authorization mapping +// public keys to nodes and grants between nodes. +type TrustGraph interface { + // Verifies that the given public key is allowed to perform + // the given action on the given node according to the trust + // graph. + Verify(libtrust.PublicKey, string, uint16) (bool, error) + + // GetGrants returns an array of all grant chains which are used to + // allow the requested permission. + GetGrants(libtrust.PublicKey, string, uint16) ([][]*Grant, error) +} + +// Grant represents a transfer of permission from one part of the +// trust graph to another. This is the only way to delegate +// permission between two different sub trees in the graph. +type Grant struct { + // Subject is the namespace being granted + Subject string + + // Permissions is a bit map of permissions + Permission uint16 + + // Grantee represents the node being granted + // a permission scope. The grantee can be + // either a namespace item or a key id where namespace + // items will always start with a '/'. + Grantee string + + // statement represents the statement used to create + // this object. + statement *Statement +} + +// Permissions +// Read node 0x01 (can read node, no sub nodes) +// Write node 0x02 (can write to node object, cannot create subnodes) +// Read subtree 0x04 (delegates read to each sub node) +// Write subtree 0x08 (delegates write to each sub node, included create on the subject) +// +// Permission shortcuts +// ReadItem = 0x01 +// WriteItem = 0x03 +// ReadAccess = 0x07 +// WriteAccess = 0x0F +// Delegate = 0x0F diff --git a/Godeps/_workspace/src/github.com/docker/libtrust/trustgraph/memory_graph.go b/Godeps/_workspace/src/github.com/docker/libtrust/trustgraph/memory_graph.go new file mode 100644 index 0000000..247bfa7 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/libtrust/trustgraph/memory_graph.go @@ -0,0 +1,133 @@ +package trustgraph + +import ( + "strings" + + "github.com/docker/libtrust" +) + +type grantNode struct { + grants []*Grant + children map[string]*grantNode +} + +type memoryGraph struct { + roots map[string]*grantNode +} + +func newGrantNode() *grantNode { + return &grantNode{ + grants: []*Grant{}, + children: map[string]*grantNode{}, + } +} + +// NewMemoryGraph returns a new in memory trust graph created from +// a static list of grants. This graph is immutable after creation +// and any alterations should create a new instance. +func NewMemoryGraph(grants []*Grant) TrustGraph { + roots := map[string]*grantNode{} + for _, grant := range grants { + parts := strings.Split(grant.Grantee, "/") + nodes := roots + var node *grantNode + var nodeOk bool + for _, part := range parts { + node, nodeOk = nodes[part] + if !nodeOk { + node = newGrantNode() + nodes[part] = node + } + if part != "" { + node.grants = append(node.grants, grant) + } + nodes = node.children + } + } + return &memoryGraph{roots} +} + +func (g *memoryGraph) getGrants(name string) []*Grant { + nameParts := strings.Split(name, "/") + nodes := g.roots + var node *grantNode + var nodeOk bool + for _, part := range nameParts { + node, nodeOk = nodes[part] + if !nodeOk { + return nil + } + nodes = node.children + } + return node.grants +} + +func isSubName(name, sub string) bool { + if strings.HasPrefix(name, sub) { + if len(name) == len(sub) || name[len(sub)] == '/' { + return true + } + } + return false +} + +type walkFunc func(*Grant, []*Grant) bool + +func foundWalkFunc(*Grant, []*Grant) bool { + return true +} + +func (g *memoryGraph) walkGrants(start, target string, permission uint16, f walkFunc, chain []*Grant, visited map[*Grant]bool, collect bool) bool { + if visited == nil { + visited = map[*Grant]bool{} + } + grants := g.getGrants(start) + subGrants := make([]*Grant, 0, len(grants)) + for _, grant := range grants { + if visited[grant] { + continue + } + visited[grant] = true + if grant.Permission&permission == permission { + if isSubName(target, grant.Subject) { + if f(grant, chain) { + return true + } + } else { + subGrants = append(subGrants, grant) + } + } + } + for _, grant := range subGrants { + var chainCopy []*Grant + if collect { + chainCopy = make([]*Grant, len(chain)+1) + copy(chainCopy, chain) + chainCopy[len(chainCopy)-1] = grant + } else { + chainCopy = nil + } + + if g.walkGrants(grant.Subject, target, permission, f, chainCopy, visited, collect) { + return true + } + } + return false +} + +func (g *memoryGraph) Verify(key libtrust.PublicKey, node string, permission uint16) (bool, error) { + return g.walkGrants(key.KeyID(), node, permission, foundWalkFunc, nil, nil, false), nil +} + +func (g *memoryGraph) GetGrants(key libtrust.PublicKey, node string, permission uint16) ([][]*Grant, error) { + grants := [][]*Grant{} + collect := func(grant *Grant, chain []*Grant) bool { + grantChain := make([]*Grant, len(chain)+1) + copy(grantChain, chain) + grantChain[len(grantChain)-1] = grant + grants = append(grants, grantChain) + return false + } + g.walkGrants(key.KeyID(), node, permission, collect, nil, nil, true) + return grants, nil +} diff --git a/Godeps/_workspace/src/github.com/docker/libtrust/trustgraph/memory_graph_test.go b/Godeps/_workspace/src/github.com/docker/libtrust/trustgraph/memory_graph_test.go new file mode 100644 index 0000000..49fd0f3 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/libtrust/trustgraph/memory_graph_test.go @@ -0,0 +1,174 @@ +package trustgraph + +import ( + "fmt" + "testing" + + "github.com/docker/libtrust" +) + +func createTestKeysAndGrants(count int) ([]*Grant, []libtrust.PrivateKey) { + grants := make([]*Grant, count) + keys := make([]libtrust.PrivateKey, count) + for i := 0; i < count; i++ { + pk, err := libtrust.GenerateECP256PrivateKey() + if err != nil { + panic(err) + } + grant := &Grant{ + Subject: fmt.Sprintf("/user-%d", i+1), + Permission: 0x0f, + Grantee: pk.KeyID(), + } + keys[i] = pk + grants[i] = grant + } + return grants, keys +} + +func testVerified(t *testing.T, g TrustGraph, k libtrust.PublicKey, keyName, target string, permission uint16) { + if ok, err := g.Verify(k, target, permission); err != nil { + t.Fatalf("Unexpected error during verification: %s", err) + } else if !ok { + t.Errorf("key failed verification\n\tKey: %s(%s)\n\tNamespace: %s", keyName, k.KeyID(), target) + } +} + +func testNotVerified(t *testing.T, g TrustGraph, k libtrust.PublicKey, keyName, target string, permission uint16) { + if ok, err := g.Verify(k, target, permission); err != nil { + t.Fatalf("Unexpected error during verification: %s", err) + } else if ok { + t.Errorf("key should have failed verification\n\tKey: %s(%s)\n\tNamespace: %s", keyName, k.KeyID(), target) + } +} + +func TestVerify(t *testing.T) { + grants, keys := createTestKeysAndGrants(4) + extraGrants := make([]*Grant, 3) + extraGrants[0] = &Grant{ + Subject: "/user-3", + Permission: 0x0f, + Grantee: "/user-2", + } + extraGrants[1] = &Grant{ + Subject: "/user-3/sub-project", + Permission: 0x0f, + Grantee: "/user-4", + } + extraGrants[2] = &Grant{ + Subject: "/user-4", + Permission: 0x07, + Grantee: "/user-1", + } + grants = append(grants, extraGrants...) + + g := NewMemoryGraph(grants) + + testVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-1", 0x0f) + testVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-1/some-project/sub-value", 0x0f) + testVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-4", 0x07) + testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-2/", 0x0f) + testVerified(t, g, keys[2].PublicKey(), "user-key-3", "/user-3/sub-value", 0x0f) + testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-3/sub-value", 0x0f) + testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-3", 0x0f) + testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-3/", 0x0f) + testVerified(t, g, keys[3].PublicKey(), "user-key-4", "/user-3/sub-project", 0x0f) + testVerified(t, g, keys[3].PublicKey(), "user-key-4", "/user-3/sub-project/app", 0x0f) + testVerified(t, g, keys[3].PublicKey(), "user-key-4", "/user-4", 0x0f) + + testNotVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-2", 0x0f) + testNotVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-3/sub-value", 0x0f) + testNotVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-4", 0x0f) + testNotVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-1/", 0x0f) + testNotVerified(t, g, keys[2].PublicKey(), "user-key-3", "/user-2", 0x0f) + testNotVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-4", 0x0f) + testNotVerified(t, g, keys[3].PublicKey(), "user-key-4", "/user-3", 0x0f) +} + +func TestCircularWalk(t *testing.T) { + grants, keys := createTestKeysAndGrants(3) + user1Grant := &Grant{ + Subject: "/user-2", + Permission: 0x0f, + Grantee: "/user-1", + } + user2Grant := &Grant{ + Subject: "/user-1", + Permission: 0x0f, + Grantee: "/user-2", + } + grants = append(grants, user1Grant, user2Grant) + + g := NewMemoryGraph(grants) + + testVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-1", 0x0f) + testVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-2", 0x0f) + testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-2", 0x0f) + testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-1", 0x0f) + testVerified(t, g, keys[2].PublicKey(), "user-key-3", "/user-3", 0x0f) + + testNotVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-3", 0x0f) + testNotVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-3", 0x0f) +} + +func assertGrantSame(t *testing.T, actual, expected *Grant) { + if actual != expected { + t.Fatalf("Unexpected grant retrieved\n\tExpected: %v\n\tActual: %v", expected, actual) + } +} + +func TestGetGrants(t *testing.T) { + grants, keys := createTestKeysAndGrants(5) + extraGrants := make([]*Grant, 4) + extraGrants[0] = &Grant{ + Subject: "/user-3/friend-project", + Permission: 0x0f, + Grantee: "/user-2/friends", + } + extraGrants[1] = &Grant{ + Subject: "/user-3/sub-project", + Permission: 0x0f, + Grantee: "/user-4", + } + extraGrants[2] = &Grant{ + Subject: "/user-2/friends", + Permission: 0x0f, + Grantee: "/user-5/fun-project", + } + extraGrants[3] = &Grant{ + Subject: "/user-5/fun-project", + Permission: 0x0f, + Grantee: "/user-1", + } + grants = append(grants, extraGrants...) + + g := NewMemoryGraph(grants) + + grantChains, err := g.GetGrants(keys[3], "/user-3/sub-project/specific-app", 0x0f) + if err != nil { + t.Fatalf("Error getting grants: %s", err) + } + if len(grantChains) != 1 { + t.Fatalf("Expected number of grant chains returned, expected %d, received %d", 1, len(grantChains)) + } + if len(grantChains[0]) != 2 { + t.Fatalf("Unexpected number of grants retrieved\n\tExpected: %d\n\tActual: %d", 2, len(grantChains[0])) + } + assertGrantSame(t, grantChains[0][0], grants[3]) + assertGrantSame(t, grantChains[0][1], extraGrants[1]) + + grantChains, err = g.GetGrants(keys[0], "/user-3/friend-project/fun-app", 0x0f) + if err != nil { + t.Fatalf("Error getting grants: %s", err) + } + if len(grantChains) != 1 { + t.Fatalf("Expected number of grant chains returned, expected %d, received %d", 1, len(grantChains)) + } + if len(grantChains[0]) != 4 { + t.Fatalf("Unexpected number of grants retrieved\n\tExpected: %d\n\tActual: %d", 2, len(grantChains[0])) + } + assertGrantSame(t, grantChains[0][0], grants[0]) + assertGrantSame(t, grantChains[0][1], extraGrants[3]) + assertGrantSame(t, grantChains[0][2], extraGrants[2]) + assertGrantSame(t, grantChains[0][3], extraGrants[0]) +} diff --git a/Godeps/_workspace/src/github.com/docker/libtrust/trustgraph/statement.go b/Godeps/_workspace/src/github.com/docker/libtrust/trustgraph/statement.go new file mode 100644 index 0000000..7a74b55 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/libtrust/trustgraph/statement.go @@ -0,0 +1,227 @@ +package trustgraph + +import ( + "crypto/x509" + "encoding/json" + "io" + "io/ioutil" + "sort" + "strings" + "time" + + "github.com/docker/libtrust" +) + +type jsonGrant struct { + Subject string `json:"subject"` + Permission uint16 `json:"permission"` + Grantee string `json:"grantee"` +} + +type jsonRevocation struct { + Subject string `json:"subject"` + Revocation uint16 `json:"revocation"` + Grantee string `json:"grantee"` +} + +type jsonStatement struct { + Revocations []*jsonRevocation `json:"revocations"` + Grants []*jsonGrant `json:"grants"` + Expiration time.Time `json:"expiration"` + IssuedAt time.Time `json:"issuedAt"` +} + +func (g *jsonGrant) Grant(statement *Statement) *Grant { + return &Grant{ + Subject: g.Subject, + Permission: g.Permission, + Grantee: g.Grantee, + statement: statement, + } +} + +// Statement represents a set of grants made from a verifiable +// authority. A statement has an expiration associated with it +// set by the authority. +type Statement struct { + jsonStatement + + signature *libtrust.JSONSignature +} + +// IsExpired returns whether the statement has expired +func (s *Statement) IsExpired() bool { + return s.Expiration.Before(time.Now().Add(-10 * time.Second)) +} + +// Bytes returns an indented json representation of the statement +// in a byte array. This value can be written to a file or stream +// without alteration. +func (s *Statement) Bytes() ([]byte, error) { + return s.signature.PrettySignature("signatures") +} + +// LoadStatement loads and verifies a statement from an input stream. +func LoadStatement(r io.Reader, authority *x509.CertPool) (*Statement, error) { + b, err := ioutil.ReadAll(r) + if err != nil { + return nil, err + } + js, err := libtrust.ParsePrettySignature(b, "signatures") + if err != nil { + return nil, err + } + payload, err := js.Payload() + if err != nil { + return nil, err + } + var statement Statement + err = json.Unmarshal(payload, &statement.jsonStatement) + if err != nil { + return nil, err + } + + if authority == nil { + _, err = js.Verify() + if err != nil { + return nil, err + } + } else { + _, err = js.VerifyChains(authority) + if err != nil { + return nil, err + } + } + statement.signature = js + + return &statement, nil +} + +// CreateStatements creates and signs a statement from a stream of grants +// and revocations in a JSON array. +func CreateStatement(grants, revocations io.Reader, expiration time.Duration, key libtrust.PrivateKey, chain []*x509.Certificate) (*Statement, error) { + var statement Statement + err := json.NewDecoder(grants).Decode(&statement.jsonStatement.Grants) + if err != nil { + return nil, err + } + err = json.NewDecoder(revocations).Decode(&statement.jsonStatement.Revocations) + if err != nil { + return nil, err + } + statement.jsonStatement.Expiration = time.Now().UTC().Add(expiration) + statement.jsonStatement.IssuedAt = time.Now().UTC() + + b, err := json.MarshalIndent(&statement.jsonStatement, "", " ") + if err != nil { + return nil, err + } + + statement.signature, err = libtrust.NewJSONSignature(b) + if err != nil { + return nil, err + } + err = statement.signature.SignWithChain(key, chain) + if err != nil { + return nil, err + } + + return &statement, nil +} + +type statementList []*Statement + +func (s statementList) Len() int { + return len(s) +} + +func (s statementList) Less(i, j int) bool { + return s[i].IssuedAt.Before(s[j].IssuedAt) +} + +func (s statementList) Swap(i, j int) { + s[i], s[j] = s[j], s[i] +} + +// CollapseStatements returns a single list of the valid statements as well as the +// time when the next grant will expire. +func CollapseStatements(statements []*Statement, useExpired bool) ([]*Grant, time.Time, error) { + sorted := make(statementList, 0, len(statements)) + for _, statement := range statements { + if useExpired || !statement.IsExpired() { + sorted = append(sorted, statement) + } + } + sort.Sort(sorted) + + var minExpired time.Time + var grantCount int + roots := map[string]*grantNode{} + for i, statement := range sorted { + if statement.Expiration.Before(minExpired) || i == 0 { + minExpired = statement.Expiration + } + for _, grant := range statement.Grants { + parts := strings.Split(grant.Grantee, "/") + nodes := roots + g := grant.Grant(statement) + grantCount = grantCount + 1 + + for _, part := range parts { + node, nodeOk := nodes[part] + if !nodeOk { + node = newGrantNode() + nodes[part] = node + } + node.grants = append(node.grants, g) + nodes = node.children + } + } + + for _, revocation := range statement.Revocations { + parts := strings.Split(revocation.Grantee, "/") + nodes := roots + + var node *grantNode + var nodeOk bool + for _, part := range parts { + node, nodeOk = nodes[part] + if !nodeOk { + break + } + nodes = node.children + } + if node != nil { + for _, grant := range node.grants { + if isSubName(grant.Subject, revocation.Subject) { + grant.Permission = grant.Permission &^ revocation.Revocation + } + } + } + } + } + + retGrants := make([]*Grant, 0, grantCount) + for _, rootNodes := range roots { + retGrants = append(retGrants, rootNodes.grants...) + } + + return retGrants, minExpired, nil +} + +// FilterStatements filters the statements to statements including the given grants. +func FilterStatements(grants []*Grant) ([]*Statement, error) { + statements := map[*Statement]bool{} + for _, grant := range grants { + if grant.statement != nil { + statements[grant.statement] = true + } + } + retStatements := make([]*Statement, len(statements)) + var i int + for statement := range statements { + retStatements[i] = statement + i++ + } + return retStatements, nil +} diff --git a/Godeps/_workspace/src/github.com/docker/libtrust/trustgraph/statement_test.go b/Godeps/_workspace/src/github.com/docker/libtrust/trustgraph/statement_test.go new file mode 100644 index 0000000..e509468 --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/libtrust/trustgraph/statement_test.go @@ -0,0 +1,417 @@ +package trustgraph + +import ( + "bytes" + "crypto/x509" + "encoding/json" + "testing" + "time" + + "github.com/docker/libtrust" + "github.com/docker/libtrust/testutil" +) + +const testStatementExpiration = time.Hour * 5 + +func generateStatement(grants []*Grant, key libtrust.PrivateKey, chain []*x509.Certificate) (*Statement, error) { + var statement Statement + + statement.Grants = make([]*jsonGrant, len(grants)) + for i, grant := range grants { + statement.Grants[i] = &jsonGrant{ + Subject: grant.Subject, + Permission: grant.Permission, + Grantee: grant.Grantee, + } + } + statement.IssuedAt = time.Now() + statement.Expiration = time.Now().Add(testStatementExpiration) + statement.Revocations = make([]*jsonRevocation, 0) + + marshalled, err := json.MarshalIndent(statement.jsonStatement, "", " ") + if err != nil { + return nil, err + } + + sig, err := libtrust.NewJSONSignature(marshalled) + if err != nil { + return nil, err + } + err = sig.SignWithChain(key, chain) + if err != nil { + return nil, err + } + statement.signature = sig + + return &statement, nil +} + +func generateTrustChain(t *testing.T, chainLen int) (libtrust.PrivateKey, *x509.CertPool, []*x509.Certificate) { + caKey, err := libtrust.GenerateECP256PrivateKey() + if err != nil { + t.Fatalf("Error generating key: %s", err) + } + ca, err := testutil.GenerateTrustCA(caKey.CryptoPublicKey(), caKey.CryptoPrivateKey()) + if err != nil { + t.Fatalf("Error generating ca: %s", err) + } + + parent := ca + parentKey := caKey + chain := make([]*x509.Certificate, chainLen) + for i := chainLen - 1; i > 0; i-- { + intermediatekey, err := libtrust.GenerateECP256PrivateKey() + if err != nil { + t.Fatalf("Error generate key: %s", err) + } + chain[i], err = testutil.GenerateIntermediate(intermediatekey.CryptoPublicKey(), parentKey.CryptoPrivateKey(), parent) + if err != nil { + t.Fatalf("Error generating intermdiate certificate: %s", err) + } + parent = chain[i] + parentKey = intermediatekey + } + trustKey, err := libtrust.GenerateECP256PrivateKey() + if err != nil { + t.Fatalf("Error generate key: %s", err) + } + chain[0], err = testutil.GenerateTrustCert(trustKey.CryptoPublicKey(), parentKey.CryptoPrivateKey(), parent) + if err != nil { + t.Fatalf("Error generate trust cert: %s", err) + } + + caPool := x509.NewCertPool() + caPool.AddCert(ca) + + return trustKey, caPool, chain +} + +func TestLoadStatement(t *testing.T) { + grantCount := 4 + grants, _ := createTestKeysAndGrants(grantCount) + + trustKey, caPool, chain := generateTrustChain(t, 6) + + statement, err := generateStatement(grants, trustKey, chain) + if err != nil { + t.Fatalf("Error generating statement: %s", err) + } + + statementBytes, err := statement.Bytes() + if err != nil { + t.Fatalf("Error getting statement bytes: %s", err) + } + + s2, err := LoadStatement(bytes.NewReader(statementBytes), caPool) + if err != nil { + t.Fatalf("Error loading statement: %s", err) + } + if len(s2.Grants) != grantCount { + t.Fatalf("Unexpected grant length\n\tExpected: %d\n\tActual: %d", grantCount, len(s2.Grants)) + } + + pool := x509.NewCertPool() + _, err = LoadStatement(bytes.NewReader(statementBytes), pool) + if err == nil { + t.Fatalf("No error thrown verifying without an authority") + } else if _, ok := err.(x509.UnknownAuthorityError); !ok { + t.Fatalf("Unexpected error verifying without authority: %s", err) + } + + s2, err = LoadStatement(bytes.NewReader(statementBytes), nil) + if err != nil { + t.Fatalf("Error loading statement: %s", err) + } + if len(s2.Grants) != grantCount { + t.Fatalf("Unexpected grant length\n\tExpected: %d\n\tActual: %d", grantCount, len(s2.Grants)) + } + + badData := make([]byte, len(statementBytes)) + copy(badData, statementBytes) + badData[0] = '[' + _, err = LoadStatement(bytes.NewReader(badData), nil) + if err == nil { + t.Fatalf("No error thrown parsing bad json") + } + + alteredData := make([]byte, len(statementBytes)) + copy(alteredData, statementBytes) + alteredData[30] = '0' + _, err = LoadStatement(bytes.NewReader(alteredData), nil) + if err == nil { + t.Fatalf("No error thrown from bad data") + } +} + +func TestCollapseGrants(t *testing.T) { + grantCount := 8 + grants, keys := createTestKeysAndGrants(grantCount) + linkGrants := make([]*Grant, 4) + linkGrants[0] = &Grant{ + Subject: "/user-3", + Permission: 0x0f, + Grantee: "/user-2", + } + linkGrants[1] = &Grant{ + Subject: "/user-3/sub-project", + Permission: 0x0f, + Grantee: "/user-4", + } + linkGrants[2] = &Grant{ + Subject: "/user-6", + Permission: 0x0f, + Grantee: "/user-7", + } + linkGrants[3] = &Grant{ + Subject: "/user-6/sub-project/specific-app", + Permission: 0x0f, + Grantee: "/user-5", + } + trustKey, pool, chain := generateTrustChain(t, 3) + + statements := make([]*Statement, 3) + var err error + statements[0], err = generateStatement(grants[0:4], trustKey, chain) + if err != nil { + t.Fatalf("Error generating statement: %s", err) + } + statements[1], err = generateStatement(grants[4:], trustKey, chain) + if err != nil { + t.Fatalf("Error generating statement: %s", err) + } + statements[2], err = generateStatement(linkGrants, trustKey, chain) + if err != nil { + t.Fatalf("Error generating statement: %s", err) + } + + statementsCopy := make([]*Statement, len(statements)) + for i, statement := range statements { + b, err := statement.Bytes() + if err != nil { + t.Fatalf("Error getting statement bytes: %s", err) + } + verifiedStatement, err := LoadStatement(bytes.NewReader(b), pool) + if err != nil { + t.Fatalf("Error loading statement: %s", err) + } + // Force sort by reversing order + statementsCopy[len(statementsCopy)-i-1] = verifiedStatement + } + statements = statementsCopy + + collapsedGrants, expiration, err := CollapseStatements(statements, false) + if len(collapsedGrants) != 12 { + t.Fatalf("Unexpected number of grants\n\tExpected: %d\n\tActual: %d", 12, len(collapsedGrants)) + } + if expiration.After(time.Now().Add(time.Hour*5)) || expiration.Before(time.Now()) { + t.Fatalf("Unexpected expiration time: %s", expiration.String()) + } + g := NewMemoryGraph(collapsedGrants) + + testVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-1", 0x0f) + testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-2", 0x0f) + testVerified(t, g, keys[2].PublicKey(), "user-key-3", "/user-3", 0x0f) + testVerified(t, g, keys[3].PublicKey(), "user-key-4", "/user-4", 0x0f) + testVerified(t, g, keys[4].PublicKey(), "user-key-5", "/user-5", 0x0f) + testVerified(t, g, keys[5].PublicKey(), "user-key-6", "/user-6", 0x0f) + testVerified(t, g, keys[6].PublicKey(), "user-key-7", "/user-7", 0x0f) + testVerified(t, g, keys[7].PublicKey(), "user-key-8", "/user-8", 0x0f) + testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-3", 0x0f) + testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-3/sub-project/specific-app", 0x0f) + testVerified(t, g, keys[3].PublicKey(), "user-key-4", "/user-3/sub-project", 0x0f) + testVerified(t, g, keys[6].PublicKey(), "user-key-7", "/user-6", 0x0f) + testVerified(t, g, keys[6].PublicKey(), "user-key-7", "/user-6/sub-project/specific-app", 0x0f) + testVerified(t, g, keys[4].PublicKey(), "user-key-5", "/user-6/sub-project/specific-app", 0x0f) + + testNotVerified(t, g, keys[3].PublicKey(), "user-key-4", "/user-3", 0x0f) + testNotVerified(t, g, keys[3].PublicKey(), "user-key-4", "/user-6/sub-project", 0x0f) + testNotVerified(t, g, keys[4].PublicKey(), "user-key-5", "/user-6/sub-project", 0x0f) + + // Add revocation grant + statements = append(statements, &Statement{ + jsonStatement{ + IssuedAt: time.Now(), + Expiration: time.Now().Add(testStatementExpiration), + Grants: []*jsonGrant{}, + Revocations: []*jsonRevocation{ + &jsonRevocation{ + Subject: "/user-1", + Revocation: 0x0f, + Grantee: keys[0].KeyID(), + }, + &jsonRevocation{ + Subject: "/user-2", + Revocation: 0x08, + Grantee: keys[1].KeyID(), + }, + &jsonRevocation{ + Subject: "/user-6", + Revocation: 0x0f, + Grantee: "/user-7", + }, + &jsonRevocation{ + Subject: "/user-9", + Revocation: 0x0f, + Grantee: "/user-10", + }, + }, + }, + nil, + }) + + collapsedGrants, expiration, err = CollapseStatements(statements, false) + if len(collapsedGrants) != 12 { + t.Fatalf("Unexpected number of grants\n\tExpected: %d\n\tActual: %d", 12, len(collapsedGrants)) + } + if expiration.After(time.Now().Add(time.Hour*5)) || expiration.Before(time.Now()) { + t.Fatalf("Unexpected expiration time: %s", expiration.String()) + } + g = NewMemoryGraph(collapsedGrants) + + testNotVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-1", 0x0f) + testNotVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-2", 0x0f) + testNotVerified(t, g, keys[6].PublicKey(), "user-key-7", "/user-6/sub-project/specific-app", 0x0f) + + testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-2", 0x07) +} + +func TestFilterStatements(t *testing.T) { + grantCount := 8 + grants, keys := createTestKeysAndGrants(grantCount) + linkGrants := make([]*Grant, 3) + linkGrants[0] = &Grant{ + Subject: "/user-3", + Permission: 0x0f, + Grantee: "/user-2", + } + linkGrants[1] = &Grant{ + Subject: "/user-5", + Permission: 0x0f, + Grantee: "/user-4", + } + linkGrants[2] = &Grant{ + Subject: "/user-7", + Permission: 0x0f, + Grantee: "/user-6", + } + + trustKey, _, chain := generateTrustChain(t, 3) + + statements := make([]*Statement, 5) + var err error + statements[0], err = generateStatement(grants[0:2], trustKey, chain) + if err != nil { + t.Fatalf("Error generating statement: %s", err) + } + statements[1], err = generateStatement(grants[2:4], trustKey, chain) + if err != nil { + t.Fatalf("Error generating statement: %s", err) + } + statements[2], err = generateStatement(grants[4:6], trustKey, chain) + if err != nil { + t.Fatalf("Error generating statement: %s", err) + } + statements[3], err = generateStatement(grants[6:], trustKey, chain) + if err != nil { + t.Fatalf("Error generating statement: %s", err) + } + statements[4], err = generateStatement(linkGrants, trustKey, chain) + if err != nil { + t.Fatalf("Error generating statement: %s", err) + } + collapsed, _, err := CollapseStatements(statements, false) + if err != nil { + t.Fatalf("Error collapsing grants: %s", err) + } + + // Filter 1, all 5 statements + filter1, err := FilterStatements(collapsed) + if err != nil { + t.Fatalf("Error filtering statements: %s", err) + } + if len(filter1) != 5 { + t.Fatalf("Wrong number of statements, expected %d, received %d", 5, len(filter1)) + } + + // Filter 2, one statement + filter2, err := FilterStatements([]*Grant{collapsed[0]}) + if err != nil { + t.Fatalf("Error filtering statements: %s", err) + } + if len(filter2) != 1 { + t.Fatalf("Wrong number of statements, expected %d, received %d", 1, len(filter2)) + } + + // Filter 3, 2 statements, from graph lookup + g := NewMemoryGraph(collapsed) + lookupGrants, err := g.GetGrants(keys[1], "/user-3", 0x0f) + if err != nil { + t.Fatalf("Error looking up grants: %s", err) + } + if len(lookupGrants) != 1 { + t.Fatalf("Wrong numberof grant chains returned from lookup, expected %d, received %d", 1, len(lookupGrants)) + } + if len(lookupGrants[0]) != 2 { + t.Fatalf("Wrong number of grants looked up, expected %d, received %d", 2, len(lookupGrants)) + } + filter3, err := FilterStatements(lookupGrants[0]) + if err != nil { + t.Fatalf("Error filtering statements: %s", err) + } + if len(filter3) != 2 { + t.Fatalf("Wrong number of statements, expected %d, received %d", 2, len(filter3)) + } + +} + +func TestCreateStatement(t *testing.T) { + grantJSON := bytes.NewReader([]byte(`[ + { + "subject": "/user-2", + "permission": 15, + "grantee": "/user-1" + }, + { + "subject": "/user-7", + "permission": 1, + "grantee": "/user-9" + }, + { + "subject": "/user-3", + "permission": 15, + "grantee": "/user-2" + } +]`)) + revocationJSON := bytes.NewReader([]byte(`[ + { + "subject": "user-8", + "revocation": 12, + "grantee": "user-9" + } +]`)) + + trustKey, pool, chain := generateTrustChain(t, 3) + + statement, err := CreateStatement(grantJSON, revocationJSON, testStatementExpiration, trustKey, chain) + if err != nil { + t.Fatalf("Error creating statement: %s", err) + } + + b, err := statement.Bytes() + if err != nil { + t.Fatalf("Error retrieving bytes: %s", err) + } + + verified, err := LoadStatement(bytes.NewReader(b), pool) + if err != nil { + t.Fatalf("Error loading statement: %s", err) + } + + if len(verified.Grants) != 3 { + t.Errorf("Unexpected number of grants, expected %d, received %d", 3, len(verified.Grants)) + } + + if len(verified.Revocations) != 1 { + t.Errorf("Unexpected number of revocations, expected %d, received %d", 1, len(verified.Revocations)) + } +} diff --git a/Godeps/_workspace/src/github.com/docker/libtrust/util.go b/Godeps/_workspace/src/github.com/docker/libtrust/util.go new file mode 100644 index 0000000..d88176c --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/libtrust/util.go @@ -0,0 +1,363 @@ +package libtrust + +import ( + "bytes" + "crypto" + "crypto/elliptic" + "crypto/tls" + "crypto/x509" + "encoding/base32" + "encoding/base64" + "encoding/binary" + "encoding/pem" + "errors" + "fmt" + "math/big" + "net/url" + "os" + "path/filepath" + "strings" + "time" +) + +// LoadOrCreateTrustKey will load a PrivateKey from the specified path +func LoadOrCreateTrustKey(trustKeyPath string) (PrivateKey, error) { + if err := os.MkdirAll(filepath.Dir(trustKeyPath), 0700); err != nil { + return nil, err + } + + trustKey, err := LoadKeyFile(trustKeyPath) + if err == ErrKeyFileDoesNotExist { + trustKey, err = GenerateECP256PrivateKey() + if err != nil { + return nil, fmt.Errorf("error generating key: %s", err) + } + + if err := SaveKey(trustKeyPath, trustKey); err != nil { + return nil, fmt.Errorf("error saving key file: %s", err) + } + + dir, file := filepath.Split(trustKeyPath) + if err := SavePublicKey(filepath.Join(dir, "public-"+file), trustKey.PublicKey()); err != nil { + return nil, fmt.Errorf("error saving public key file: %s", err) + } + } else if err != nil { + return nil, fmt.Errorf("error loading key file: %s", err) + } + return trustKey, nil +} + +// NewIdentityAuthTLSClientConfig returns a tls.Config configured to use identity +// based authentication from the specified dockerUrl, the rootConfigPath and +// the server name to which it is connecting. +// If trustUnknownHosts is true it will automatically add the host to the +// known-hosts.json in rootConfigPath. +func NewIdentityAuthTLSClientConfig(dockerUrl string, trustUnknownHosts bool, rootConfigPath string, serverName string) (*tls.Config, error) { + tlsConfig := newTLSConfig() + + trustKeyPath := filepath.Join(rootConfigPath, "key.json") + knownHostsPath := filepath.Join(rootConfigPath, "known-hosts.json") + + u, err := url.Parse(dockerUrl) + if err != nil { + return nil, fmt.Errorf("unable to parse machine url") + } + + if u.Scheme == "unix" { + return nil, nil + } + + addr := u.Host + proto := "tcp" + + trustKey, err := LoadOrCreateTrustKey(trustKeyPath) + if err != nil { + return nil, fmt.Errorf("unable to load trust key: %s", err) + } + + knownHosts, err := LoadKeySetFile(knownHostsPath) + if err != nil { + return nil, fmt.Errorf("could not load trusted hosts file: %s", err) + } + + allowedHosts, err := FilterByHosts(knownHosts, addr, false) + if err != nil { + return nil, fmt.Errorf("error filtering hosts: %s", err) + } + + certPool, err := GenerateCACertPool(trustKey, allowedHosts) + if err != nil { + return nil, fmt.Errorf("Could not create CA pool: %s", err) + } + + tlsConfig.ServerName = serverName + tlsConfig.RootCAs = certPool + + x509Cert, err := GenerateSelfSignedClientCert(trustKey) + if err != nil { + return nil, fmt.Errorf("certificate generation error: %s", err) + } + + tlsConfig.Certificates = []tls.Certificate{{ + Certificate: [][]byte{x509Cert.Raw}, + PrivateKey: trustKey.CryptoPrivateKey(), + Leaf: x509Cert, + }} + + tlsConfig.InsecureSkipVerify = true + + testConn, err := tls.Dial(proto, addr, tlsConfig) + if err != nil { + return nil, fmt.Errorf("tls Handshake error: %s", err) + } + + opts := x509.VerifyOptions{ + Roots: tlsConfig.RootCAs, + CurrentTime: time.Now(), + DNSName: tlsConfig.ServerName, + Intermediates: x509.NewCertPool(), + } + + certs := testConn.ConnectionState().PeerCertificates + for i, cert := range certs { + if i == 0 { + continue + } + opts.Intermediates.AddCert(cert) + } + + if _, err := certs[0].Verify(opts); err != nil { + if _, ok := err.(x509.UnknownAuthorityError); ok { + if trustUnknownHosts { + pubKey, err := FromCryptoPublicKey(certs[0].PublicKey) + if err != nil { + return nil, fmt.Errorf("error extracting public key from cert: %s", err) + } + + pubKey.AddExtendedField("hosts", []string{addr}) + + if err := AddKeySetFile(knownHostsPath, pubKey); err != nil { + return nil, fmt.Errorf("error adding machine to known hosts: %s", err) + } + } else { + return nil, fmt.Errorf("unable to connect. unknown host: %s", addr) + } + } + } + + testConn.Close() + tlsConfig.InsecureSkipVerify = false + + return tlsConfig, nil +} + +// joseBase64UrlEncode encodes the given data using the standard base64 url +// encoding format but with all trailing '=' characters ommitted in accordance +// with the jose specification. +// http://tools.ietf.org/html/draft-ietf-jose-json-web-signature-31#section-2 +func joseBase64UrlEncode(b []byte) string { + return strings.TrimRight(base64.URLEncoding.EncodeToString(b), "=") +} + +// joseBase64UrlDecode decodes the given string using the standard base64 url +// decoder but first adds the appropriate number of trailing '=' characters in +// accordance with the jose specification. +// http://tools.ietf.org/html/draft-ietf-jose-json-web-signature-31#section-2 +func joseBase64UrlDecode(s string) ([]byte, error) { + s = strings.Replace(s, "\n", "", -1) + s = strings.Replace(s, " ", "", -1) + switch len(s) % 4 { + case 0: + case 2: + s += "==" + case 3: + s += "=" + default: + return nil, errors.New("illegal base64url string") + } + return base64.URLEncoding.DecodeString(s) +} + +func keyIDEncode(b []byte) string { + s := strings.TrimRight(base32.StdEncoding.EncodeToString(b), "=") + var buf bytes.Buffer + var i int + for i = 0; i < len(s)/4-1; i++ { + start := i * 4 + end := start + 4 + buf.WriteString(s[start:end] + ":") + } + buf.WriteString(s[i*4:]) + return buf.String() +} + +func keyIDFromCryptoKey(pubKey PublicKey) string { + // Generate and return a 'libtrust' fingerprint of the public key. + // For an RSA key this should be: + // SHA256(DER encoded ASN1) + // Then truncated to 240 bits and encoded into 12 base32 groups like so: + // ABCD:EFGH:IJKL:MNOP:QRST:UVWX:YZ23:4567:ABCD:EFGH:IJKL:MNOP + derBytes, err := x509.MarshalPKIXPublicKey(pubKey.CryptoPublicKey()) + if err != nil { + return "" + } + hasher := crypto.SHA256.New() + hasher.Write(derBytes) + return keyIDEncode(hasher.Sum(nil)[:30]) +} + +func stringFromMap(m map[string]interface{}, key string) (string, error) { + val, ok := m[key] + if !ok { + return "", fmt.Errorf("%q value not specified", key) + } + + str, ok := val.(string) + if !ok { + return "", fmt.Errorf("%q value must be a string", key) + } + delete(m, key) + + return str, nil +} + +func parseECCoordinate(cB64Url string, curve elliptic.Curve) (*big.Int, error) { + curveByteLen := (curve.Params().BitSize + 7) >> 3 + + cBytes, err := joseBase64UrlDecode(cB64Url) + if err != nil { + return nil, fmt.Errorf("invalid base64 URL encoding: %s", err) + } + cByteLength := len(cBytes) + if cByteLength != curveByteLen { + return nil, fmt.Errorf("invalid number of octets: got %d, should be %d", cByteLength, curveByteLen) + } + return new(big.Int).SetBytes(cBytes), nil +} + +func parseECPrivateParam(dB64Url string, curve elliptic.Curve) (*big.Int, error) { + dBytes, err := joseBase64UrlDecode(dB64Url) + if err != nil { + return nil, fmt.Errorf("invalid base64 URL encoding: %s", err) + } + + // The length of this octet string MUST be ceiling(log-base-2(n)/8) + // octets (where n is the order of the curve). This is because the private + // key d must be in the interval [1, n-1] so the bitlength of d should be + // no larger than the bitlength of n-1. The easiest way to find the octet + // length is to take bitlength(n-1), add 7 to force a carry, and shift this + // bit sequence right by 3, which is essentially dividing by 8 and adding + // 1 if there is any remainder. Thus, the private key value d should be + // output to (bitlength(n-1)+7)>>3 octets. + n := curve.Params().N + octetLength := (new(big.Int).Sub(n, big.NewInt(1)).BitLen() + 7) >> 3 + dByteLength := len(dBytes) + + if dByteLength != octetLength { + return nil, fmt.Errorf("invalid number of octets: got %d, should be %d", dByteLength, octetLength) + } + + return new(big.Int).SetBytes(dBytes), nil +} + +func parseRSAModulusParam(nB64Url string) (*big.Int, error) { + nBytes, err := joseBase64UrlDecode(nB64Url) + if err != nil { + return nil, fmt.Errorf("invalid base64 URL encoding: %s", err) + } + + return new(big.Int).SetBytes(nBytes), nil +} + +func serializeRSAPublicExponentParam(e int) []byte { + // We MUST use the minimum number of octets to represent E. + // E is supposed to be 65537 for performance and security reasons + // and is what golang's rsa package generates, but it might be + // different if imported from some other generator. + buf := make([]byte, 4) + binary.BigEndian.PutUint32(buf, uint32(e)) + var i int + for i = 0; i < 8; i++ { + if buf[i] != 0 { + break + } + } + return buf[i:] +} + +func parseRSAPublicExponentParam(eB64Url string) (int, error) { + eBytes, err := joseBase64UrlDecode(eB64Url) + if err != nil { + return 0, fmt.Errorf("invalid base64 URL encoding: %s", err) + } + // Only the minimum number of bytes were used to represent E, but + // binary.BigEndian.Uint32 expects at least 4 bytes, so we need + // to add zero padding if necassary. + byteLen := len(eBytes) + buf := make([]byte, 4-byteLen, 4) + eBytes = append(buf, eBytes...) + + return int(binary.BigEndian.Uint32(eBytes)), nil +} + +func parseRSAPrivateKeyParamFromMap(m map[string]interface{}, key string) (*big.Int, error) { + b64Url, err := stringFromMap(m, key) + if err != nil { + return nil, err + } + + paramBytes, err := joseBase64UrlDecode(b64Url) + if err != nil { + return nil, fmt.Errorf("invaled base64 URL encoding: %s", err) + } + + return new(big.Int).SetBytes(paramBytes), nil +} + +func createPemBlock(name string, derBytes []byte, headers map[string]interface{}) (*pem.Block, error) { + pemBlock := &pem.Block{Type: name, Bytes: derBytes, Headers: map[string]string{}} + for k, v := range headers { + switch val := v.(type) { + case string: + pemBlock.Headers[k] = val + case []string: + if k == "hosts" { + pemBlock.Headers[k] = strings.Join(val, ",") + } else { + // Return error, non-encodable type + } + default: + // Return error, non-encodable type + } + } + + return pemBlock, nil +} + +func pubKeyFromPEMBlock(pemBlock *pem.Block) (PublicKey, error) { + cryptoPublicKey, err := x509.ParsePKIXPublicKey(pemBlock.Bytes) + if err != nil { + return nil, fmt.Errorf("unable to decode Public Key PEM data: %s", err) + } + + pubKey, err := FromCryptoPublicKey(cryptoPublicKey) + if err != nil { + return nil, err + } + + addPEMHeadersToKey(pemBlock, pubKey) + + return pubKey, nil +} + +func addPEMHeadersToKey(pemBlock *pem.Block, pubKey PublicKey) { + for key, value := range pemBlock.Headers { + var safeVal interface{} + if key == "hosts" { + safeVal = strings.Split(value, ",") + } else { + safeVal = value + } + pubKey.AddExtendedField(key, safeVal) + } +} diff --git a/Godeps/_workspace/src/github.com/docker/libtrust/util_test.go b/Godeps/_workspace/src/github.com/docker/libtrust/util_test.go new file mode 100644 index 0000000..83b7cfb --- /dev/null +++ b/Godeps/_workspace/src/github.com/docker/libtrust/util_test.go @@ -0,0 +1,45 @@ +package libtrust + +import ( + "encoding/pem" + "reflect" + "testing" +) + +func TestAddPEMHeadersToKey(t *testing.T) { + pk := &rsaPublicKey{nil, map[string]interface{}{}} + blk := &pem.Block{Headers: map[string]string{"hosts": "localhost,127.0.0.1"}} + addPEMHeadersToKey(blk, pk) + + val := pk.GetExtendedField("hosts") + hosts, ok := val.([]string) + if !ok { + t.Fatalf("hosts type(%v), expected []string", reflect.TypeOf(val)) + } + expected := []string{"localhost", "127.0.0.1"} + if !reflect.DeepEqual(hosts, expected) { + t.Errorf("hosts(%v), expected %v", hosts, expected) + } +} + +func TestBase64URL(t *testing.T) { + clean := "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJwMnMiOiIyV0NUY0paMVJ2ZF9DSnVKcmlwUTF3IiwicDJjIjo0MDk2LCJlbmMiOiJBMTI4Q0JDLUhTMjU2IiwiY3R5IjoiandrK2pzb24ifQ" + + tests := []string{ + clean, // clean roundtrip + "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJwMnMiOiIyV0NUY0paMVJ2\nZF9DSnVKcmlwUTF3IiwicDJjIjo0MDk2LCJlbmMiOiJBMTI4Q0JDLUhTMjU2\nIiwiY3R5IjoiandrK2pzb24ifQ", // with newlines + "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJwMnMiOiIyV0NUY0paMVJ2 \n ZF9DSnVKcmlwUTF3IiwicDJjIjo0MDk2LCJlbmMiOiJBMTI4Q0JDLUhTMjU2 \n IiwiY3R5IjoiandrK2pzb24ifQ", // with newlines and spaces + } + + for i, test := range tests { + b, err := joseBase64UrlDecode(test) + if err != nil { + t.Fatalf("on test %d: %s", i, err) + } + got := joseBase64UrlEncode(b) + + if got != clean { + t.Errorf("expected %q, got %q", clean, got) + } + } +} diff --git a/Godeps/_workspace/src/github.com/macaron-contrib/session/.gitignore b/Godeps/_workspace/src/github.com/macaron-contrib/session/.gitignore deleted file mode 100644 index 9297dbc..0000000 --- a/Godeps/_workspace/src/github.com/macaron-contrib/session/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -ledis/tmp.db -nodb/tmp.db \ No newline at end of file diff --git a/Godeps/_workspace/src/github.com/macaron-contrib/session/LICENSE b/Godeps/_workspace/src/github.com/macaron-contrib/session/LICENSE deleted file mode 100644 index 8405e89..0000000 --- a/Godeps/_workspace/src/github.com/macaron-contrib/session/LICENSE +++ /dev/null @@ -1,191 +0,0 @@ -Apache License -Version 2.0, January 2004 -http://www.apache.org/licenses/ - -TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - -1. Definitions. - -"License" shall mean the terms and conditions for use, reproduction, and -distribution as defined by Sections 1 through 9 of this document. - -"Licensor" shall mean the copyright owner or entity authorized by the copyright -owner that is granting the License. - -"Legal Entity" shall mean the union of the acting entity and all other entities -that control, are controlled by, or are under common control with that entity. -For the purposes of this definition, "control" means (i) the power, direct or -indirect, to cause the direction or management of such entity, whether by -contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the -outstanding shares, or (iii) beneficial ownership of such entity. - -"You" (or "Your") shall mean an individual or Legal Entity exercising -permissions granted by this License. - -"Source" form shall mean the preferred form for making modifications, including -but not limited to software source code, documentation source, and configuration -files. - -"Object" form shall mean any form resulting from mechanical transformation or -translation of a Source form, including but not limited to compiled object code, -generated documentation, and conversions to other media types. - -"Work" shall mean the work of authorship, whether in Source or Object form, made -available under the License, as indicated by a copyright notice that is included -in or attached to the work (an example is provided in the Appendix below). - -"Derivative Works" shall mean any work, whether in Source or Object form, that -is based on (or derived from) the Work and for which the editorial revisions, -annotations, elaborations, or other modifications represent, as a whole, an -original work of authorship. For the purposes of this License, Derivative Works -shall not include works that remain separable from, or merely link (or bind by -name) to the interfaces of, the Work and Derivative Works thereof. - -"Contribution" shall mean any work of authorship, including the original version -of the Work and any modifications or additions to that Work or Derivative Works -thereof, that is intentionally submitted to Licensor for inclusion in the Work -by the copyright owner or by an individual or Legal Entity authorized to submit -on behalf of the copyright owner. For the purposes of this definition, -"submitted" means any form of electronic, verbal, or written communication sent -to the Licensor or its representatives, including but not limited to -communication on electronic mailing lists, source code control systems, and -issue tracking systems that are managed by, or on behalf of, the Licensor for -the purpose of discussing and improving the Work, but excluding communication -that is conspicuously marked or otherwise designated in writing by the copyright -owner as "Not a Contribution." - -"Contributor" shall mean Licensor and any individual or Legal Entity on behalf -of whom a Contribution has been received by Licensor and subsequently -incorporated within the Work. - -2. Grant of Copyright License. - -Subject to the terms and conditions of this License, each Contributor hereby -grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, -irrevocable copyright license to reproduce, prepare Derivative Works of, -publicly display, publicly perform, sublicense, and distribute the Work and such -Derivative Works in Source or Object form. - -3. Grant of Patent License. - -Subject to the terms and conditions of this License, each Contributor hereby -grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, -irrevocable (except as stated in this section) patent license to make, have -made, use, offer to sell, sell, import, and otherwise transfer the Work, where -such license applies only to those patent claims licensable by such Contributor -that are necessarily infringed by their Contribution(s) alone or by combination -of their Contribution(s) with the Work to which such Contribution(s) was -submitted. If You institute patent litigation against any entity (including a -cross-claim or counterclaim in a lawsuit) alleging that the Work or a -Contribution incorporated within the Work constitutes direct or contributory -patent infringement, then any patent licenses granted to You under this License -for that Work shall terminate as of the date such litigation is filed. - -4. Redistribution. - -You may reproduce and distribute copies of the Work or Derivative Works thereof -in any medium, with or without modifications, and in Source or Object form, -provided that You meet the following conditions: - -You must give any other recipients of the Work or Derivative Works a copy of -this License; and -You must cause any modified files to carry prominent notices stating that You -changed the files; and -You must retain, in the Source form of any Derivative Works that You distribute, -all copyright, patent, trademark, and attribution notices from the Source form -of the Work, excluding those notices that do not pertain to any part of the -Derivative Works; and -If the Work includes a "NOTICE" text file as part of its distribution, then any -Derivative Works that You distribute must include a readable copy of the -attribution notices contained within such NOTICE file, excluding those notices -that do not pertain to any part of the Derivative Works, in at least one of the -following places: within a NOTICE text file distributed as part of the -Derivative Works; within the Source form or documentation, if provided along -with the Derivative Works; or, within a display generated by the Derivative -Works, if and wherever such third-party notices normally appear. The contents of -the NOTICE file are for informational purposes only and do not modify the -License. You may add Your own attribution notices within Derivative Works that -You distribute, alongside or as an addendum to the NOTICE text from the Work, -provided that such additional attribution notices cannot be construed as -modifying the License. -You may add Your own copyright statement to Your modifications and may provide -additional or different license terms and conditions for use, reproduction, or -distribution of Your modifications, or for any such Derivative Works as a whole, -provided Your use, reproduction, and distribution of the Work otherwise complies -with the conditions stated in this License. - -5. Submission of Contributions. - -Unless You explicitly state otherwise, any Contribution intentionally submitted -for inclusion in the Work by You to the Licensor shall be under the terms and -conditions of this License, without any additional terms or conditions. -Notwithstanding the above, nothing herein shall supersede or modify the terms of -any separate license agreement you may have executed with Licensor regarding -such Contributions. - -6. Trademarks. - -This License does not grant permission to use the trade names, trademarks, -service marks, or product names of the Licensor, except as required for -reasonable and customary use in describing the origin of the Work and -reproducing the content of the NOTICE file. - -7. Disclaimer of Warranty. - -Unless required by applicable law or agreed to in writing, Licensor provides the -Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, -including, without limitation, any warranties or conditions of TITLE, -NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are -solely responsible for determining the appropriateness of using or -redistributing the Work and assume any risks associated with Your exercise of -permissions under this License. - -8. Limitation of Liability. - -In no event and under no legal theory, whether in tort (including negligence), -contract, or otherwise, unless required by applicable law (such as deliberate -and grossly negligent acts) or agreed to in writing, shall any Contributor be -liable to You for damages, including any direct, indirect, special, incidental, -or consequential damages of any character arising as a result of this License or -out of the use or inability to use the Work (including but not limited to -damages for loss of goodwill, work stoppage, computer failure or malfunction, or -any and all other commercial damages or losses), even if such Contributor has -been advised of the possibility of such damages. - -9. Accepting Warranty or Additional Liability. - -While redistributing the Work or Derivative Works thereof, You may choose to -offer, and charge a fee for, acceptance of support, warranty, indemnity, or -other liability obligations and/or rights consistent with this License. However, -in accepting such obligations, You may act only on Your own behalf and on Your -sole responsibility, not on behalf of any other Contributor, and only if You -agree to indemnify, defend, and hold each Contributor harmless for any liability -incurred by, or claims asserted against, such Contributor by reason of your -accepting any such warranty or additional liability. - -END OF TERMS AND CONDITIONS - -APPENDIX: How to apply the Apache License to your work - -To apply the Apache License to your work, attach the following boilerplate -notice, with the fields enclosed by brackets "[]" replaced with your own -identifying information. (Don't include the brackets!) The text should be -enclosed in the appropriate comment syntax for the file format. We also -recommend that a file or class name and description of purpose be included on -the same "printed page" as the copyright notice for easier identification within -third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. \ No newline at end of file diff --git a/Godeps/_workspace/src/github.com/macaron-contrib/session/README.md b/Godeps/_workspace/src/github.com/macaron-contrib/session/README.md deleted file mode 100644 index 01de811..0000000 --- a/Godeps/_workspace/src/github.com/macaron-contrib/session/README.md +++ /dev/null @@ -1,21 +0,0 @@ -session [![Build Status](https://drone.io/github.com/macaron-contrib/session/status.png)](https://drone.io/github.com/macaron-contrib/session/latest) [![](http://gocover.io/_badge/github.com/macaron-contrib/session)](http://gocover.io/github.com/macaron-contrib/session) -======= - -Middleware session provides session management for [Macaron](https://github.com/Unknwon/macaron). It can use many session providers, including memory, file, Redis, Memcache, PostgreSQL, MySQL, Couchbase, Ledis and Nodb. - -### Installation - - go get github.com/macaron-contrib/session - -## Getting Help - -- [API Reference](https://gowalker.org/github.com/macaron-contrib/session) -- [Documentation](http://macaron.gogs.io/docs/middlewares/session) - -## Credits - -This package is forked from [beego/session](https://github.com/astaxie/beego/tree/master/session) with reconstruction(over 80%). - -## License - -This project is under Apache v2 License. See the [LICENSE](LICENSE) file for the full license text. diff --git a/Godeps/_workspace/src/github.com/macaron-contrib/session/couchbase/couchbase.go b/Godeps/_workspace/src/github.com/macaron-contrib/session/couchbase/couchbase.go deleted file mode 100644 index 93953c6..0000000 --- a/Godeps/_workspace/src/github.com/macaron-contrib/session/couchbase/couchbase.go +++ /dev/null @@ -1,223 +0,0 @@ -// Copyright 2013 Beego Authors -// Copyright 2014 Unknwon -// -// Licensed under the Apache License, Version 2.0 (the "License"): you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -// License for the specific language governing permissions and limitations -// under the License. - -package session - -import ( - "strings" - "sync" - - "github.com/couchbaselabs/go-couchbase" - - "github.com/macaron-contrib/session" -) - -// CouchbaseSessionStore represents a couchbase session store implementation. -type CouchbaseSessionStore struct { - b *couchbase.Bucket - sid string - lock sync.RWMutex - data map[interface{}]interface{} - maxlifetime int64 -} - -// Set sets value to given key in session. -func (s *CouchbaseSessionStore) Set(key, val interface{}) error { - s.lock.Lock() - defer s.lock.Unlock() - - s.data[key] = val - return nil -} - -// Get gets value by given key in session. -func (s *CouchbaseSessionStore) Get(key interface{}) interface{} { - s.lock.RLock() - defer s.lock.RUnlock() - - return s.data[key] -} - -// Delete delete a key from session. -func (s *CouchbaseSessionStore) Delete(key interface{}) error { - s.lock.Lock() - defer s.lock.Unlock() - - delete(s.data, key) - return nil -} - -// ID returns current session ID. -func (s *CouchbaseSessionStore) ID() string { - return s.sid -} - -// Release releases resource and save data to provider. -func (s *CouchbaseSessionStore) Release() error { - defer s.b.Close() - - data, err := session.EncodeGob(s.data) - if err != nil { - return err - } - - return s.b.Set(s.sid, int(s.maxlifetime), data) -} - -// Flush deletes all session data. -func (s *CouchbaseSessionStore) Flush() error { - s.lock.Lock() - defer s.lock.Unlock() - - s.data = make(map[interface{}]interface{}) - return nil -} - -// CouchbaseProvider represents a couchbase session provider implementation. -type CouchbaseProvider struct { - maxlifetime int64 - connStr string - pool string - bucket string - b *couchbase.Bucket -} - -func (cp *CouchbaseProvider) getBucket() *couchbase.Bucket { - c, err := couchbase.Connect(cp.connStr) - if err != nil { - return nil - } - - pool, err := c.GetPool(cp.pool) - if err != nil { - return nil - } - - bucket, err := pool.GetBucket(cp.bucket) - if err != nil { - return nil - } - - return bucket -} - -// Init initializes memory session provider. -// connStr is couchbase server REST/JSON URL -// e.g. http://host:port/, Pool, Bucket -func (p *CouchbaseProvider) Init(maxlifetime int64, connStr string) error { - p.maxlifetime = maxlifetime - configs := strings.Split(connStr, ",") - if len(configs) > 0 { - p.connStr = configs[0] - } - if len(configs) > 1 { - p.pool = configs[1] - } - if len(configs) > 2 { - p.bucket = configs[2] - } - - return nil -} - -// Read returns raw session store by session ID. -func (p *CouchbaseProvider) Read(sid string) (session.RawStore, error) { - p.b = p.getBucket() - - var doc []byte - - err := p.b.Get(sid, &doc) - var kv map[interface{}]interface{} - if doc == nil { - kv = make(map[interface{}]interface{}) - } else { - kv, err = session.DecodeGob(doc) - if err != nil { - return nil, err - } - } - - cs := &CouchbaseSessionStore{b: p.b, sid: sid, data: kv, maxlifetime: p.maxlifetime} - return cs, nil -} - -// Exist returns true if session with given ID exists. -func (p *CouchbaseProvider) Exist(sid string) bool { - p.b = p.getBucket() - defer p.b.Close() - - var doc []byte - - if err := p.b.Get(sid, &doc); err != nil || doc == nil { - return false - } else { - return true - } -} - -// Destory deletes a session by session ID. -func (p *CouchbaseProvider) Destory(sid string) error { - p.b = p.getBucket() - defer p.b.Close() - - p.b.Delete(sid) - return nil -} - -// Regenerate regenerates a session store from old session ID to new one. -func (p *CouchbaseProvider) Regenerate(oldsid, sid string) (session.RawStore, error) { - p.b = p.getBucket() - - var doc []byte - if err := p.b.Get(oldsid, &doc); err != nil || doc == nil { - p.b.Set(sid, int(p.maxlifetime), "") - } else { - err := p.b.Delete(oldsid) - if err != nil { - return nil, err - } - _, _ = p.b.Add(sid, int(p.maxlifetime), doc) - } - - err := p.b.Get(sid, &doc) - if err != nil { - return nil, err - } - var kv map[interface{}]interface{} - if doc == nil { - kv = make(map[interface{}]interface{}) - } else { - kv, err = session.DecodeGob(doc) - if err != nil { - return nil, err - } - } - - cs := &CouchbaseSessionStore{b: p.b, sid: sid, data: kv, maxlifetime: p.maxlifetime} - return cs, nil -} - -// Count counts and returns number of sessions. -func (p *CouchbaseProvider) Count() int { - // FIXME - return 0 -} - -// GC calls GC to clean expired sessions. -func (p *CouchbaseProvider) GC() {} - -func init() { - session.Register("couchbase", &CouchbaseProvider{}) -} diff --git a/Godeps/_workspace/src/github.com/macaron-contrib/session/file.go b/Godeps/_workspace/src/github.com/macaron-contrib/session/file.go deleted file mode 100644 index cab807d..0000000 --- a/Godeps/_workspace/src/github.com/macaron-contrib/session/file.go +++ /dev/null @@ -1,243 +0,0 @@ -// Copyright 2013 Beego Authors -// Copyright 2014 Unknwon -// -// Licensed under the Apache License, Version 2.0 (the "License"): you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -// License for the specific language governing permissions and limitations -// under the License. - -package session - -import ( - "fmt" - "io/ioutil" - "log" - "os" - "path" - "path/filepath" - "sync" - "time" - - "github.com/Unknwon/com" -) - -// FileStore represents a file session store implementation. -type FileStore struct { - p *FileProvider - sid string - lock sync.RWMutex - data map[interface{}]interface{} -} - -// NewFileStore creates and returns a file session store. -func NewFileStore(p *FileProvider, sid string, kv map[interface{}]interface{}) *FileStore { - return &FileStore{ - p: p, - sid: sid, - data: kv, - } -} - -// Set sets value to given key in session. -func (s *FileStore) Set(key, val interface{}) error { - s.lock.Lock() - defer s.lock.Unlock() - - s.data[key] = val - return nil -} - -// Get gets value by given key in session. -func (s *FileStore) Get(key interface{}) interface{} { - s.lock.RLock() - defer s.lock.RUnlock() - - return s.data[key] -} - -// Delete delete a key from session. -func (s *FileStore) Delete(key interface{}) error { - s.lock.Lock() - defer s.lock.Unlock() - - delete(s.data, key) - return nil -} - -// ID returns current session ID. -func (s *FileStore) ID() string { - return s.sid -} - -// Release releases resource and save data to provider. -func (s *FileStore) Release() error { - data, err := EncodeGob(s.data) - if err != nil { - return err - } - - return ioutil.WriteFile(s.p.filepath(s.sid), data, os.ModePerm) -} - -// Flush deletes all session data. -func (s *FileStore) Flush() error { - s.lock.Lock() - defer s.lock.Unlock() - - s.data = make(map[interface{}]interface{}) - return nil -} - -// FileProvider represents a file session provider implementation. -type FileProvider struct { - maxlifetime int64 - rootPath string -} - -// Init initializes file session provider with given root path. -func (p *FileProvider) Init(maxlifetime int64, rootPath string) error { - p.maxlifetime = maxlifetime - p.rootPath = rootPath - return nil -} - -func (p *FileProvider) filepath(sid string) string { - return path.Join(p.rootPath, string(sid[0]), string(sid[1]), sid) -} - -// Read returns raw session store by session ID. -func (p *FileProvider) Read(sid string) (_ RawStore, err error) { - filename := p.filepath(sid) - if err = os.MkdirAll(path.Dir(filename), os.ModePerm); err != nil { - return nil, err - } - - var f *os.File - if com.IsFile(filename) { - f, err = os.OpenFile(filename, os.O_RDWR, os.ModePerm) - } else { - f, err = os.Create(filename) - } - if err != nil { - return nil, err - } - defer f.Close() - - if err = os.Chtimes(filename, time.Now(), time.Now()); err != nil { - return nil, err - } - - var kv map[interface{}]interface{} - data, err := ioutil.ReadAll(f) - if err != nil { - return nil, err - } - if len(data) == 0 { - kv = make(map[interface{}]interface{}) - } else { - kv, err = DecodeGob(data) - if err != nil { - return nil, err - } - } - return NewFileStore(p, sid, kv), nil -} - -// Exist returns true if session with given ID exists. -func (p *FileProvider) Exist(sid string) bool { - return com.IsFile(p.filepath(sid)) -} - -// Destory deletes a session by session ID. -func (p *FileProvider) Destory(sid string) error { - return os.Remove(p.filepath(sid)) -} - -func (p *FileProvider) regenerate(oldsid, sid string) (err error) { - filename := p.filepath(sid) - if com.IsExist(filename) { - return fmt.Errorf("new sid '%s' already exists", sid) - } - - oldname := p.filepath(oldsid) - if !com.IsFile(oldname) { - data, err := EncodeGob(make(map[interface{}]interface{})) - if err != nil { - return err - } - if err = os.MkdirAll(path.Dir(oldname), os.ModePerm); err != nil { - return err - } - if err = ioutil.WriteFile(oldname, data, os.ModePerm); err != nil { - return err - } - } - - if err = os.MkdirAll(path.Dir(filename), os.ModePerm); err != nil { - return err - } - if err = os.Rename(oldname, filename); err != nil { - return err - } - return nil -} - -// Regenerate regenerates a session store from old session ID to new one. -func (p *FileProvider) Regenerate(oldsid, sid string) (_ RawStore, err error) { - if err := p.regenerate(oldsid, sid); err != nil { - return nil, err - } - - return p.Read(sid) -} - -// Count counts and returns number of sessions. -func (p *FileProvider) Count() int { - count := 0 - if err := filepath.Walk(p.rootPath, func(path string, fi os.FileInfo, err error) error { - if err != nil { - return err - } - - if !fi.IsDir() { - count++ - } - return nil - }); err != nil { - log.Printf("error counting session files: %v", err) - return 0 - } - return count -} - -// GC calls GC to clean expired sessions. -func (p *FileProvider) GC() { - if !com.IsExist(p.rootPath) { - return - } - - if err := filepath.Walk(p.rootPath, func(path string, fi os.FileInfo, err error) error { - if err != nil { - return err - } - - if !fi.IsDir() && - (fi.ModTime().Unix()+p.maxlifetime) < time.Now().Unix() { - return os.Remove(path) - } - return nil - }); err != nil { - log.Printf("error garbage collecting session files: %v", err) - } -} - -func init() { - Register("file", &FileProvider{}) -} diff --git a/Godeps/_workspace/src/github.com/macaron-contrib/session/file_test.go b/Godeps/_workspace/src/github.com/macaron-contrib/session/file_test.go deleted file mode 100644 index 9c83555..0000000 --- a/Godeps/_workspace/src/github.com/macaron-contrib/session/file_test.go +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2014 Unknwon -// -// Licensed under the Apache License, Version 2.0 (the "License"): you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -// License for the specific language governing permissions and limitations -// under the License. - -package session - -import ( - "os" - "path" - "testing" - - . "github.com/smartystreets/goconvey/convey" -) - -func Test_FileProvider(t *testing.T) { - Convey("Test file session provider", t, func() { - dir := path.Join(os.TempDir(), "data/sessions") - os.RemoveAll(dir) - testProvider(Options{ - Provider: "file", - ProviderConfig: dir, - }) - }) -} diff --git a/Godeps/_workspace/src/github.com/macaron-contrib/session/ledis/ledis.go b/Godeps/_workspace/src/github.com/macaron-contrib/session/ledis/ledis.go deleted file mode 100644 index afde713..0000000 --- a/Godeps/_workspace/src/github.com/macaron-contrib/session/ledis/ledis.go +++ /dev/null @@ -1,222 +0,0 @@ -// Copyright 2013 Beego Authors -// Copyright 2014 Unknwon -// -// Licensed under the Apache License, Version 2.0 (the "License"): you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -// License for the specific language governing permissions and limitations -// under the License. - -package session - -import ( - "fmt" - "strings" - "sync" - - "github.com/Unknwon/com" - "github.com/siddontang/ledisdb/config" - "github.com/siddontang/ledisdb/ledis" - "gopkg.in/ini.v1" - - "github.com/macaron-contrib/session" -) - -// LedisStore represents a ledis session store implementation. -type LedisStore struct { - c *ledis.DB - sid string - expire int64 - lock sync.RWMutex - data map[interface{}]interface{} -} - -// NewLedisStore creates and returns a ledis session store. -func NewLedisStore(c *ledis.DB, sid string, expire int64, kv map[interface{}]interface{}) *LedisStore { - return &LedisStore{ - c: c, - expire: expire, - sid: sid, - data: kv, - } -} - -// Set sets value to given key in session. -func (s *LedisStore) Set(key, val interface{}) error { - s.lock.Lock() - defer s.lock.Unlock() - - s.data[key] = val - return nil -} - -// Get gets value by given key in session. -func (s *LedisStore) Get(key interface{}) interface{} { - s.lock.RLock() - defer s.lock.RUnlock() - - return s.data[key] -} - -// Delete delete a key from session. -func (s *LedisStore) Delete(key interface{}) error { - s.lock.Lock() - defer s.lock.Unlock() - - delete(s.data, key) - return nil -} - -// ID returns current session ID. -func (s *LedisStore) ID() string { - return s.sid -} - -// Release releases resource and save data to provider. -func (s *LedisStore) Release() error { - data, err := session.EncodeGob(s.data) - if err != nil { - return err - } - - if err = s.c.Set([]byte(s.sid), data); err != nil { - return err - } - _, err = s.c.Expire([]byte(s.sid), s.expire) - return err -} - -// Flush deletes all session data. -func (s *LedisStore) Flush() error { - s.lock.Lock() - defer s.lock.Unlock() - - s.data = make(map[interface{}]interface{}) - return nil -} - -// LedisProvider represents a ledis session provider implementation. -type LedisProvider struct { - c *ledis.DB - expire int64 -} - -// Init initializes ledis session provider. -// configs: data_dir=./app.db,db=0 -func (p *LedisProvider) Init(expire int64, configs string) error { - p.expire = expire - - cfg, err := ini.Load([]byte(strings.Replace(configs, ",", "\n", -1))) - if err != nil { - return err - } - - db := 0 - opt := new(config.Config) - for k, v := range cfg.Section("").KeysHash() { - switch k { - case "data_dir": - opt.DataDir = v - case "db": - db = com.StrTo(v).MustInt() - default: - return fmt.Errorf("session/ledis: unsupported option '%s'", k) - } - } - - l, err := ledis.Open(opt) - if err != nil { - return fmt.Errorf("session/ledis: error opening db: %v", err) - } - p.c, err = l.Select(db) - return err -} - -// Read returns raw session store by session ID. -func (p *LedisProvider) Read(sid string) (session.RawStore, error) { - if !p.Exist(sid) { - if err := p.c.Set([]byte(sid), []byte("")); err != nil { - return nil, err - } - } - - var kv map[interface{}]interface{} - kvs, err := p.c.Get([]byte(sid)) - if err != nil { - return nil, err - } - if len(kvs) == 0 { - kv = make(map[interface{}]interface{}) - } else { - kv, err = session.DecodeGob(kvs) - if err != nil { - return nil, err - } - } - - return NewLedisStore(p.c, sid, p.expire, kv), nil -} - -// Exist returns true if session with given ID exists. -func (p *LedisProvider) Exist(sid string) bool { - count, err := p.c.Exists([]byte(sid)) - return err == nil && count > 0 -} - -// Destory deletes a session by session ID. -func (p *LedisProvider) Destory(sid string) error { - _, err := p.c.Del([]byte(sid)) - return err -} - -// Regenerate regenerates a session store from old session ID to new one. -func (p *LedisProvider) Regenerate(oldsid, sid string) (_ session.RawStore, err error) { - if p.Exist(sid) { - return nil, fmt.Errorf("new sid '%s' already exists", sid) - } - - kvs := make([]byte, 0) - if p.Exist(oldsid) { - if kvs, err = p.c.Get([]byte(oldsid)); err != nil { - return nil, err - } else if _, err = p.c.Del([]byte(oldsid)); err != nil { - return nil, err - } - } - if err = p.c.SetEX([]byte(sid), p.expire, kvs); err != nil { - return nil, err - } - - var kv map[interface{}]interface{} - if len(kvs) == 0 { - kv = make(map[interface{}]interface{}) - } else { - kv, err = session.DecodeGob([]byte(kvs)) - if err != nil { - return nil, err - } - } - - return NewLedisStore(p.c, sid, p.expire, kv), nil -} - -// Count counts and returns number of sessions. -func (p *LedisProvider) Count() int { - // FIXME: how come this library does not have DbSize() method? - return -1 -} - -// GC calls GC to clean expired sessions. -func (p *LedisProvider) GC() { - // FIXME: wtf??? -} - -func init() { - session.Register("ledis", &LedisProvider{}) -} diff --git a/Godeps/_workspace/src/github.com/macaron-contrib/session/ledis/ledis.goconvey b/Godeps/_workspace/src/github.com/macaron-contrib/session/ledis/ledis.goconvey deleted file mode 100644 index 8485e98..0000000 --- a/Godeps/_workspace/src/github.com/macaron-contrib/session/ledis/ledis.goconvey +++ /dev/null @@ -1 +0,0 @@ -ignore \ No newline at end of file diff --git a/Godeps/_workspace/src/github.com/macaron-contrib/session/ledis/ledis_test.go b/Godeps/_workspace/src/github.com/macaron-contrib/session/ledis/ledis_test.go deleted file mode 100644 index dac42a3..0000000 --- a/Godeps/_workspace/src/github.com/macaron-contrib/session/ledis/ledis_test.go +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright 2014 Unknwon -// -// Licensed under the Apache License, Version 2.0 (the "License"): you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -// License for the specific language governing permissions and limitations -// under the License. - -package session - -import ( - "net/http" - "net/http/httptest" - "testing" - - "github.com/Unknwon/macaron" - . "github.com/smartystreets/goconvey/convey" - - "github.com/macaron-contrib/session" -) - -func Test_LedisProvider(t *testing.T) { - Convey("Test ledis session provider", t, func() { - opt := session.Options{ - Provider: "ledis", - ProviderConfig: "data_dir=./tmp.db", - } - - Convey("Basic operation", func() { - m := macaron.New() - m.Use(session.Sessioner(opt)) - - m.Get("/", func(ctx *macaron.Context, sess session.Store) { - sess.Set("uname", "unknwon") - }) - m.Get("/reg", func(ctx *macaron.Context, sess session.Store) { - raw, err := sess.RegenerateId(ctx) - So(err, ShouldBeNil) - So(raw, ShouldNotBeNil) - - uname := raw.Get("uname") - So(uname, ShouldNotBeNil) - So(uname, ShouldEqual, "unknwon") - }) - m.Get("/get", func(ctx *macaron.Context, sess session.Store) { - sid := sess.ID() - So(sid, ShouldNotBeEmpty) - - raw, err := sess.Read(sid) - So(err, ShouldBeNil) - So(raw, ShouldNotBeNil) - - uname := sess.Get("uname") - So(uname, ShouldNotBeNil) - So(uname, ShouldEqual, "unknwon") - - So(sess.Delete("uname"), ShouldBeNil) - So(sess.Get("uname"), ShouldBeNil) - - So(sess.Destory(ctx), ShouldBeNil) - }) - - resp := httptest.NewRecorder() - req, err := http.NewRequest("GET", "/", nil) - So(err, ShouldBeNil) - m.ServeHTTP(resp, req) - - cookie := resp.Header().Get("Set-Cookie") - - resp = httptest.NewRecorder() - req, err = http.NewRequest("GET", "/reg", nil) - So(err, ShouldBeNil) - req.Header.Set("Cookie", cookie) - m.ServeHTTP(resp, req) - - cookie = resp.Header().Get("Set-Cookie") - - resp = httptest.NewRecorder() - req, err = http.NewRequest("GET", "/get", nil) - So(err, ShouldBeNil) - req.Header.Set("Cookie", cookie) - m.ServeHTTP(resp, req) - - Convey("Regenrate empty session", func() { - m.Get("/empty", func(ctx *macaron.Context, sess session.Store) { - raw, err := sess.RegenerateId(ctx) - So(err, ShouldBeNil) - So(raw, ShouldNotBeNil) - }) - - resp = httptest.NewRecorder() - req, err = http.NewRequest("GET", "/empty", nil) - So(err, ShouldBeNil) - req.Header.Set("Cookie", "MacaronSession=ad2c7e3cbecfcf486; Path=/;") - m.ServeHTTP(resp, req) - }) - }) - }) -} diff --git a/Godeps/_workspace/src/github.com/macaron-contrib/session/memcache/memcache.go b/Godeps/_workspace/src/github.com/macaron-contrib/session/memcache/memcache.go deleted file mode 100644 index b4fcdde..0000000 --- a/Godeps/_workspace/src/github.com/macaron-contrib/session/memcache/memcache.go +++ /dev/null @@ -1,199 +0,0 @@ -// Copyright 2013 Beego Authors -// Copyright 2014 Unknwon -// -// Licensed under the Apache License, Version 2.0 (the "License"): you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -// License for the specific language governing permissions and limitations -// under the License. - -package session - -import ( - "fmt" - "strings" - "sync" - - "github.com/bradfitz/gomemcache/memcache" - - "github.com/macaron-contrib/session" -) - -// MemcacheStore represents a memcache session store implementation. -type MemcacheStore struct { - c *memcache.Client - sid string - expire int32 - lock sync.RWMutex - data map[interface{}]interface{} -} - -// NewMemcacheStore creates and returns a memcache session store. -func NewMemcacheStore(c *memcache.Client, sid string, expire int32, kv map[interface{}]interface{}) *MemcacheStore { - return &MemcacheStore{ - c: c, - sid: sid, - expire: expire, - data: kv, - } -} - -func NewItem(sid string, data []byte, expire int32) *memcache.Item { - return &memcache.Item{ - Key: sid, - Value: data, - Expiration: expire, - } -} - -// Set sets value to given key in session. -func (s *MemcacheStore) Set(key, val interface{}) error { - s.lock.Lock() - defer s.lock.Unlock() - - s.data[key] = val - return nil -} - -// Get gets value by given key in session. -func (s *MemcacheStore) Get(key interface{}) interface{} { - s.lock.RLock() - defer s.lock.RUnlock() - - return s.data[key] -} - -// Delete delete a key from session. -func (s *MemcacheStore) Delete(key interface{}) error { - s.lock.Lock() - defer s.lock.Unlock() - - delete(s.data, key) - return nil -} - -// ID returns current session ID. -func (s *MemcacheStore) ID() string { - return s.sid -} - -// Release releases resource and save data to provider. -func (s *MemcacheStore) Release() error { - data, err := session.EncodeGob(s.data) - if err != nil { - return err - } - - return s.c.Set(NewItem(s.sid, data, s.expire)) -} - -// Flush deletes all session data. -func (s *MemcacheStore) Flush() error { - s.lock.Lock() - defer s.lock.Unlock() - - s.data = make(map[interface{}]interface{}) - return nil -} - -// MemcacheProvider represents a memcache session provider implementation. -type MemcacheProvider struct { - c *memcache.Client - expire int32 -} - -// Init initializes memcache session provider. -// connStrs: 127.0.0.1:9090;127.0.0.1:9091 -func (p *MemcacheProvider) Init(expire int64, connStrs string) error { - p.expire = int32(expire) - p.c = memcache.New(strings.Split(connStrs, ";")...) - return nil -} - -// Read returns raw session store by session ID. -func (p *MemcacheProvider) Read(sid string) (session.RawStore, error) { - if !p.Exist(sid) { - if err := p.c.Set(NewItem(sid, []byte(""), p.expire)); err != nil { - return nil, err - } - } - - var kv map[interface{}]interface{} - item, err := p.c.Get(sid) - if err != nil { - return nil, err - } - if len(item.Value) == 0 { - kv = make(map[interface{}]interface{}) - } else { - kv, err = session.DecodeGob(item.Value) - if err != nil { - return nil, err - } - } - - return NewMemcacheStore(p.c, sid, p.expire, kv), nil -} - -// Exist returns true if session with given ID exists. -func (p *MemcacheProvider) Exist(sid string) bool { - _, err := p.c.Get(sid) - return err == nil -} - -// Destory deletes a session by session ID. -func (p *MemcacheProvider) Destory(sid string) error { - return p.c.Delete(sid) -} - -// Regenerate regenerates a session store from old session ID to new one. -func (p *MemcacheProvider) Regenerate(oldsid, sid string) (_ session.RawStore, err error) { - if p.Exist(sid) { - return nil, fmt.Errorf("new sid '%s' already exists", sid) - } - - item := NewItem(sid, []byte(""), p.expire) - if p.Exist(oldsid) { - item, err = p.c.Get(oldsid) - if err != nil { - return nil, err - } else if err = p.c.Delete(oldsid); err != nil { - return nil, err - } - item.Key = sid - } - if err = p.c.Set(item); err != nil { - return nil, err - } - - var kv map[interface{}]interface{} - if len(item.Value) == 0 { - kv = make(map[interface{}]interface{}) - } else { - kv, err = session.DecodeGob(item.Value) - if err != nil { - return nil, err - } - } - - return NewMemcacheStore(p.c, sid, p.expire, kv), nil -} - -// Count counts and returns number of sessions. -func (p *MemcacheProvider) Count() int { - // FIXME: how come this library does not have Stats method? - return -1 -} - -// GC calls GC to clean expired sessions. -func (p *MemcacheProvider) GC() {} - -func init() { - session.Register("memcache", &MemcacheProvider{}) -} diff --git a/Godeps/_workspace/src/github.com/macaron-contrib/session/memcache/memcache.goconvey b/Godeps/_workspace/src/github.com/macaron-contrib/session/memcache/memcache.goconvey deleted file mode 100644 index 8485e98..0000000 --- a/Godeps/_workspace/src/github.com/macaron-contrib/session/memcache/memcache.goconvey +++ /dev/null @@ -1 +0,0 @@ -ignore \ No newline at end of file diff --git a/Godeps/_workspace/src/github.com/macaron-contrib/session/memcache/memcache_test.go b/Godeps/_workspace/src/github.com/macaron-contrib/session/memcache/memcache_test.go deleted file mode 100644 index beb272d..0000000 --- a/Godeps/_workspace/src/github.com/macaron-contrib/session/memcache/memcache_test.go +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright 2014 Unknwon -// -// Licensed under the Apache License, Version 2.0 (the "License"): you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -// License for the specific language governing permissions and limitations -// under the License. - -package session - -import ( - "net/http" - "net/http/httptest" - "testing" - - "github.com/Unknwon/macaron" - . "github.com/smartystreets/goconvey/convey" - - "github.com/macaron-contrib/session" -) - -func Test_MemcacheProvider(t *testing.T) { - Convey("Test memcache session provider", t, func() { - opt := session.Options{ - Provider: "memcache", - ProviderConfig: "127.0.0.1:9090", - } - - Convey("Basic operation", func() { - m := macaron.New() - m.Use(session.Sessioner(opt)) - - m.Get("/", func(ctx *macaron.Context, sess session.Store) { - sess.Set("uname", "unknwon") - }) - m.Get("/reg", func(ctx *macaron.Context, sess session.Store) { - raw, err := sess.RegenerateId(ctx) - So(err, ShouldBeNil) - So(raw, ShouldNotBeNil) - - uname := raw.Get("uname") - So(uname, ShouldNotBeNil) - So(uname, ShouldEqual, "unknwon") - }) - m.Get("/get", func(ctx *macaron.Context, sess session.Store) { - sid := sess.ID() - So(sid, ShouldNotBeEmpty) - - raw, err := sess.Read(sid) - So(err, ShouldBeNil) - So(raw, ShouldNotBeNil) - - uname := sess.Get("uname") - So(uname, ShouldNotBeNil) - So(uname, ShouldEqual, "unknwon") - - So(sess.Delete("uname"), ShouldBeNil) - So(sess.Get("uname"), ShouldBeNil) - - So(sess.Destory(ctx), ShouldBeNil) - }) - - resp := httptest.NewRecorder() - req, err := http.NewRequest("GET", "/", nil) - So(err, ShouldBeNil) - m.ServeHTTP(resp, req) - - cookie := resp.Header().Get("Set-Cookie") - - resp = httptest.NewRecorder() - req, err = http.NewRequest("GET", "/reg", nil) - So(err, ShouldBeNil) - req.Header.Set("Cookie", cookie) - m.ServeHTTP(resp, req) - - cookie = resp.Header().Get("Set-Cookie") - - resp = httptest.NewRecorder() - req, err = http.NewRequest("GET", "/get", nil) - So(err, ShouldBeNil) - req.Header.Set("Cookie", cookie) - m.ServeHTTP(resp, req) - }) - - Convey("Regenrate empty session", func() { - m := macaron.New() - m.Use(session.Sessioner(opt)) - m.Get("/", func(ctx *macaron.Context, sess session.Store) { - raw, err := sess.RegenerateId(ctx) - So(err, ShouldBeNil) - So(raw, ShouldNotBeNil) - }) - - resp := httptest.NewRecorder() - req, err := http.NewRequest("GET", "/", nil) - So(err, ShouldBeNil) - req.Header.Set("Cookie", "MacaronSession=ad2c7e3cbecfcf486; Path=/;") - m.ServeHTTP(resp, req) - }) - }) -} diff --git a/Godeps/_workspace/src/github.com/macaron-contrib/session/memory.go b/Godeps/_workspace/src/github.com/macaron-contrib/session/memory.go deleted file mode 100644 index e717635..0000000 --- a/Godeps/_workspace/src/github.com/macaron-contrib/session/memory.go +++ /dev/null @@ -1,212 +0,0 @@ -// Copyright 2013 Beego Authors -// Copyright 2014 Unknwon -// -// Licensed under the Apache License, Version 2.0 (the "License"): you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -// License for the specific language governing permissions and limitations -// under the License. - -package session - -import ( - "container/list" - "fmt" - "sync" - "time" -) - -// MemStore represents a in-memory session store implementation. -type MemStore struct { - sid string - lock sync.RWMutex - data map[interface{}]interface{} - lastAccess time.Time -} - -// NewMemStore creates and returns a memory session store. -func NewMemStore(sid string) *MemStore { - return &MemStore{ - sid: sid, - data: make(map[interface{}]interface{}), - lastAccess: time.Now(), - } -} - -// Set sets value to given key in session. -func (s *MemStore) Set(key, val interface{}) error { - s.lock.Lock() - defer s.lock.Unlock() - - s.data[key] = val - return nil -} - -// Get gets value by given key in session. -func (s *MemStore) Get(key interface{}) interface{} { - s.lock.RLock() - defer s.lock.RUnlock() - - return s.data[key] -} - -// Delete deletes a key from session. -func (s *MemStore) Delete(key interface{}) error { - s.lock.Lock() - defer s.lock.Unlock() - - delete(s.data, key) - return nil -} - -// ID returns current session ID. -func (s *MemStore) ID() string { - return s.sid -} - -// Release releases resource and save data to provider. -func (_ *MemStore) Release() error { - return nil -} - -// Flush deletes all session data. -func (s *MemStore) Flush() error { - s.lock.Lock() - defer s.lock.Unlock() - - s.data = make(map[interface{}]interface{}) - return nil -} - -// MemProvider represents a in-memory session provider implementation. -type MemProvider struct { - lock sync.RWMutex - maxLifetime int64 - data map[string]*list.Element - // A priority list whose lastAccess newer gets higer priority. - list *list.List -} - -// Init initializes memory session provider. -func (p *MemProvider) Init(maxLifetime int64, _ string) error { - p.maxLifetime = maxLifetime - return nil -} - -// update expands time of session store by given ID. -func (p *MemProvider) update(sid string) error { - p.lock.Lock() - defer p.lock.Unlock() - - if e, ok := p.data[sid]; ok { - e.Value.(*MemStore).lastAccess = time.Now() - p.list.MoveToFront(e) - return nil - } - return nil -} - -// Read returns raw session store by session ID. -func (p *MemProvider) Read(sid string) (_ RawStore, err error) { - p.lock.RLock() - e, ok := p.data[sid] - p.lock.RUnlock() - - if ok { - if err = p.update(sid); err != nil { - return nil, err - } - return e.Value.(*MemStore), nil - } - - // Create a new session. - p.lock.Lock() - defer p.lock.Unlock() - - s := NewMemStore(sid) - p.data[sid] = p.list.PushBack(s) - return s, nil -} - -// Exist returns true if session with given ID exists. -func (p *MemProvider) Exist(sid string) bool { - p.lock.RLock() - defer p.lock.RUnlock() - - _, ok := p.data[sid] - return ok -} - -// Destory deletes a session by session ID. -func (p *MemProvider) Destory(sid string) error { - p.lock.Lock() - defer p.lock.Unlock() - - e, ok := p.data[sid] - if !ok { - return nil - } - - p.list.Remove(e) - delete(p.data, sid) - return nil -} - -// Regenerate regenerates a session store from old session ID to new one. -func (p *MemProvider) Regenerate(oldsid, sid string) (RawStore, error) { - if p.Exist(sid) { - return nil, fmt.Errorf("new sid '%s' already exists", sid) - } - - s, err := p.Read(oldsid) - if err != nil { - return nil, err - } - - if err = p.Destory(oldsid); err != nil { - return nil, err - } - - s.(*MemStore).sid = sid - p.data[sid] = p.list.PushBack(s) - return s, nil -} - -// Count counts and returns number of sessions. -func (p *MemProvider) Count() int { - return p.list.Len() -} - -// GC calls GC to clean expired sessions. -func (p *MemProvider) GC() { - p.lock.RLock() - for { - // No session in the list. - e := p.list.Back() - if e == nil { - break - } - - if (e.Value.(*MemStore).lastAccess.Unix() + p.maxLifetime) < time.Now().Unix() { - p.lock.RUnlock() - p.lock.Lock() - p.list.Remove(e) - delete(p.data, e.Value.(*MemStore).sid) - p.lock.Unlock() - p.lock.RLock() - } else { - break - } - } - p.lock.RUnlock() -} - -func init() { - Register("memory", &MemProvider{list: list.New(), data: make(map[string]*list.Element)}) -} diff --git a/Godeps/_workspace/src/github.com/macaron-contrib/session/memory_test.go b/Godeps/_workspace/src/github.com/macaron-contrib/session/memory_test.go deleted file mode 100644 index 41659bb..0000000 --- a/Godeps/_workspace/src/github.com/macaron-contrib/session/memory_test.go +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2014 Unknwon -// -// Licensed under the Apache License, Version 2.0 (the "License"): you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -// License for the specific language governing permissions and limitations -// under the License. - -package session - -import ( - "testing" - - . "github.com/smartystreets/goconvey/convey" -) - -func Test_MemProvider(t *testing.T) { - Convey("Test memory session provider", t, func() { - testProvider(Options{}) - }) -} diff --git a/Godeps/_workspace/src/github.com/macaron-contrib/session/mysql/mysql.go b/Godeps/_workspace/src/github.com/macaron-contrib/session/mysql/mysql.go deleted file mode 100644 index 7997e03..0000000 --- a/Godeps/_workspace/src/github.com/macaron-contrib/session/mysql/mysql.go +++ /dev/null @@ -1,195 +0,0 @@ -// Copyright 2013 Beego Authors -// Copyright 2014 Unknwon -// -// Licensed under the Apache License, Version 2.0 (the "License"): you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -// License for the specific language governing permissions and limitations -// under the License. - -package session - -import ( - "database/sql" - "fmt" - "log" - "sync" - "time" - - _ "github.com/go-sql-driver/mysql" - - "github.com/macaron-contrib/session" -) - -// MysqlStore represents a mysql session store implementation. -type MysqlStore struct { - c *sql.DB - sid string - lock sync.RWMutex - data map[interface{}]interface{} -} - -// NewMysqlStore creates and returns a mysql session store. -func NewMysqlStore(c *sql.DB, sid string, kv map[interface{}]interface{}) *MysqlStore { - return &MysqlStore{ - c: c, - sid: sid, - data: kv, - } -} - -// Set sets value to given key in session. -func (s *MysqlStore) Set(key, val interface{}) error { - s.lock.Lock() - defer s.lock.Unlock() - - s.data[key] = val - return nil -} - -// Get gets value by given key in session. -func (s *MysqlStore) Get(key interface{}) interface{} { - s.lock.RLock() - defer s.lock.RUnlock() - - return s.data[key] -} - -// Delete delete a key from session. -func (s *MysqlStore) Delete(key interface{}) error { - s.lock.Lock() - defer s.lock.Unlock() - - delete(s.data, key) - return nil -} - -// ID returns current session ID. -func (s *MysqlStore) ID() string { - return s.sid -} - -// Release releases resource and save data to provider. -func (s *MysqlStore) Release() error { - data, err := session.EncodeGob(s.data) - if err != nil { - return err - } - - _, err = s.c.Exec("UPDATE session SET data=?, expiry=? WHERE `key`=?", - data, time.Now().Unix(), s.sid) - return err -} - -// Flush deletes all session data. -func (s *MysqlStore) Flush() error { - s.lock.Lock() - defer s.lock.Unlock() - - s.data = make(map[interface{}]interface{}) - return nil -} - -// MysqlProvider represents a mysql session provider implementation. -type MysqlProvider struct { - c *sql.DB - expire int64 -} - -// Init initializes mysql session provider. -// connStr: username:password@protocol(address)/dbname?param=value -func (p *MysqlProvider) Init(expire int64, connStr string) (err error) { - p.expire = expire - - p.c, err = sql.Open("mysql", connStr) - if err != nil { - return err - } - return p.c.Ping() -} - -// Read returns raw session store by session ID. -func (p *MysqlProvider) Read(sid string) (session.RawStore, error) { - var data []byte - err := p.c.QueryRow("SELECT data FROM session WHERE `key`=?", sid).Scan(&data) - if err == sql.ErrNoRows { - _, err = p.c.Exec("INSERT INTO session(`key`,data,expiry) VALUES(?,?,?)", - sid, "", time.Now().Unix()) - } - if err != nil { - return nil, err - } - - var kv map[interface{}]interface{} - if len(data) == 0 { - kv = make(map[interface{}]interface{}) - } else { - kv, err = session.DecodeGob(data) - if err != nil { - return nil, err - } - } - - return NewMysqlStore(p.c, sid, kv), nil -} - -// Exist returns true if session with given ID exists. -func (p *MysqlProvider) Exist(sid string) bool { - var data []byte - err := p.c.QueryRow("SELECT data FROM session WHERE `key`=?", sid).Scan(&data) - if err != nil && err != sql.ErrNoRows { - panic("session/mysql: error checking existence: " + err.Error()) - } - return err != sql.ErrNoRows -} - -// Destory deletes a session by session ID. -func (p *MysqlProvider) Destory(sid string) error { - _, err := p.c.Exec("DELETE FROM session WHERE `key`=?", sid) - return err -} - -// Regenerate regenerates a session store from old session ID to new one. -func (p *MysqlProvider) Regenerate(oldsid, sid string) (_ session.RawStore, err error) { - if p.Exist(sid) { - return nil, fmt.Errorf("new sid '%s' already exists", sid) - } - - if !p.Exist(oldsid) { - if _, err = p.c.Exec("INSERT INTO session(`key`,data,expiry) VALUES(?,?,?)", - oldsid, "", time.Now().Unix()); err != nil { - return nil, err - } - } - - if _, err = p.c.Exec("UPDATE session SET `key`=? WHERE `key`=?", sid, oldsid); err != nil { - return nil, err - } - - return p.Read(sid) -} - -// Count counts and returns number of sessions. -func (p *MysqlProvider) Count() (total int) { - if err := p.c.QueryRow("SELECT COUNT(*) AS NUM FROM session").Scan(&total); err != nil { - panic("session/mysql: error counting records: " + err.Error()) - } - return total -} - -// GC calls GC to clean expired sessions. -func (p *MysqlProvider) GC() { - if _, err := p.c.Exec("DELETE FROM session WHERE UNIX_TIMESTAMP(NOW()) - expiry > ?", p.expire); err != nil { - log.Printf("session/mysql: error garbage collecting: %v", err) - } -} - -func init() { - session.Register("mysql", &MysqlProvider{}) -} diff --git a/Godeps/_workspace/src/github.com/macaron-contrib/session/mysql/mysql.goconvey b/Godeps/_workspace/src/github.com/macaron-contrib/session/mysql/mysql.goconvey deleted file mode 100644 index 8485e98..0000000 --- a/Godeps/_workspace/src/github.com/macaron-contrib/session/mysql/mysql.goconvey +++ /dev/null @@ -1 +0,0 @@ -ignore \ No newline at end of file diff --git a/Godeps/_workspace/src/github.com/macaron-contrib/session/mysql/mysql_test.go b/Godeps/_workspace/src/github.com/macaron-contrib/session/mysql/mysql_test.go deleted file mode 100644 index 15b3996..0000000 --- a/Godeps/_workspace/src/github.com/macaron-contrib/session/mysql/mysql_test.go +++ /dev/null @@ -1,138 +0,0 @@ -// Copyright 2014 Unknwon -// -// Licensed under the Apache License, Version 2.0 (the "License"): you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -// License for the specific language governing permissions and limitations -// under the License. - -package session - -import ( - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/Unknwon/macaron" - . "github.com/smartystreets/goconvey/convey" - - "github.com/macaron-contrib/session" -) - -func Test_MysqlProvider(t *testing.T) { - Convey("Test mysql session provider", t, func() { - opt := session.Options{ - Provider: "mysql", - ProviderConfig: "root:@tcp(localhost:3306)/macaron?charset=utf8", - } - - Convey("Basic operation", func() { - m := macaron.New() - m.Use(session.Sessioner(opt)) - - m.Get("/", func(ctx *macaron.Context, sess session.Store) { - sess.Set("uname", "unknwon") - }) - m.Get("/reg", func(ctx *macaron.Context, sess session.Store) { - raw, err := sess.RegenerateId(ctx) - So(err, ShouldBeNil) - So(raw, ShouldNotBeNil) - - uname := raw.Get("uname") - So(uname, ShouldNotBeNil) - So(uname, ShouldEqual, "unknwon") - }) - m.Get("/get", func(ctx *macaron.Context, sess session.Store) { - sid := sess.ID() - So(sid, ShouldNotBeEmpty) - - raw, err := sess.Read(sid) - So(err, ShouldBeNil) - So(raw, ShouldNotBeNil) - So(raw.Release(), ShouldBeNil) - - uname := sess.Get("uname") - So(uname, ShouldNotBeNil) - So(uname, ShouldEqual, "unknwon") - - So(sess.Delete("uname"), ShouldBeNil) - So(sess.Get("uname"), ShouldBeNil) - - So(sess.Destory(ctx), ShouldBeNil) - }) - - resp := httptest.NewRecorder() - req, err := http.NewRequest("GET", "/", nil) - So(err, ShouldBeNil) - m.ServeHTTP(resp, req) - - cookie := resp.Header().Get("Set-Cookie") - - resp = httptest.NewRecorder() - req, err = http.NewRequest("GET", "/reg", nil) - So(err, ShouldBeNil) - req.Header.Set("Cookie", cookie) - m.ServeHTTP(resp, req) - - cookie = resp.Header().Get("Set-Cookie") - - resp = httptest.NewRecorder() - req, err = http.NewRequest("GET", "/get", nil) - So(err, ShouldBeNil) - req.Header.Set("Cookie", cookie) - m.ServeHTTP(resp, req) - }) - - Convey("Regenrate empty session", func() { - m := macaron.New() - m.Use(session.Sessioner(opt)) - m.Get("/", func(ctx *macaron.Context, sess session.Store) { - raw, err := sess.RegenerateId(ctx) - So(err, ShouldBeNil) - So(raw, ShouldNotBeNil) - - So(sess.Destory(ctx), ShouldBeNil) - }) - - resp := httptest.NewRecorder() - req, err := http.NewRequest("GET", "/", nil) - So(err, ShouldBeNil) - req.Header.Set("Cookie", "MacaronSession=ad2c7e3cbecfcf48; Path=/;") - m.ServeHTTP(resp, req) - }) - - Convey("GC session", func() { - m := macaron.New() - opt2 := opt - opt2.Gclifetime = 1 - m.Use(session.Sessioner(opt2)) - - m.Get("/", func(sess session.Store) { - sess.Set("uname", "unknwon") - So(sess.ID(), ShouldNotBeEmpty) - uname := sess.Get("uname") - So(uname, ShouldNotBeNil) - So(uname, ShouldEqual, "unknwon") - - So(sess.Flush(), ShouldBeNil) - So(sess.Get("uname"), ShouldBeNil) - - time.Sleep(2 * time.Second) - sess.GC() - So(sess.Count(), ShouldEqual, 0) - }) - - resp := httptest.NewRecorder() - req, err := http.NewRequest("GET", "/", nil) - So(err, ShouldBeNil) - m.ServeHTTP(resp, req) - }) - }) -} diff --git a/Godeps/_workspace/src/github.com/macaron-contrib/session/nodb/nodb.go b/Godeps/_workspace/src/github.com/macaron-contrib/session/nodb/nodb.go deleted file mode 100644 index 7f017bf..0000000 --- a/Godeps/_workspace/src/github.com/macaron-contrib/session/nodb/nodb.go +++ /dev/null @@ -1,203 +0,0 @@ -// Copyright 2015 Unknwon -// -// Licensed under the Apache License, Version 2.0 (the "License"): you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -// License for the specific language governing permissions and limitations -// under the License. - -package session - -import ( - "fmt" - "sync" - - "github.com/lunny/nodb" - "github.com/lunny/nodb/config" - - "github.com/macaron-contrib/session" -) - -// NodbStore represents a nodb session store implementation. -type NodbStore struct { - c *nodb.DB - sid string - expire int64 - lock sync.RWMutex - data map[interface{}]interface{} -} - -// NewNodbStore creates and returns a ledis session store. -func NewNodbStore(c *nodb.DB, sid string, expire int64, kv map[interface{}]interface{}) *NodbStore { - return &NodbStore{ - c: c, - expire: expire, - sid: sid, - data: kv, - } -} - -// Set sets value to given key in session. -func (s *NodbStore) Set(key, val interface{}) error { - s.lock.Lock() - defer s.lock.Unlock() - - s.data[key] = val - return nil -} - -// Get gets value by given key in session. -func (s *NodbStore) Get(key interface{}) interface{} { - s.lock.RLock() - defer s.lock.RUnlock() - - return s.data[key] -} - -// Delete delete a key from session. -func (s *NodbStore) Delete(key interface{}) error { - s.lock.Lock() - defer s.lock.Unlock() - - delete(s.data, key) - return nil -} - -// ID returns current session ID. -func (s *NodbStore) ID() string { - return s.sid -} - -// Release releases resource and save data to provider. -func (s *NodbStore) Release() error { - data, err := session.EncodeGob(s.data) - if err != nil { - return err - } - - if err = s.c.Set([]byte(s.sid), data); err != nil { - return err - } - _, err = s.c.Expire([]byte(s.sid), s.expire) - return err -} - -// Flush deletes all session data. -func (s *NodbStore) Flush() error { - s.lock.Lock() - defer s.lock.Unlock() - - s.data = make(map[interface{}]interface{}) - return nil -} - -// NodbProvider represents a ledis session provider implementation. -type NodbProvider struct { - c *nodb.DB - expire int64 -} - -// Init initializes nodb session provider. -func (p *NodbProvider) Init(expire int64, configs string) error { - p.expire = expire - - cfg := new(config.Config) - cfg.DataDir = configs - dbs, err := nodb.Open(cfg) - if err != nil { - return fmt.Errorf("session/nodb: error opening db: %v", err) - } - - p.c, err = dbs.Select(0) - return err -} - -// Read returns raw session store by session ID. -func (p *NodbProvider) Read(sid string) (session.RawStore, error) { - if !p.Exist(sid) { - if err := p.c.Set([]byte(sid), []byte("")); err != nil { - return nil, err - } - } - - var kv map[interface{}]interface{} - kvs, err := p.c.Get([]byte(sid)) - if err != nil { - return nil, err - } - if len(kvs) == 0 { - kv = make(map[interface{}]interface{}) - } else { - kv, err = session.DecodeGob(kvs) - if err != nil { - return nil, err - } - } - - return NewNodbStore(p.c, sid, p.expire, kv), nil -} - -// Exist returns true if session with given ID exists. -func (p *NodbProvider) Exist(sid string) bool { - count, err := p.c.Exists([]byte(sid)) - return err == nil && count > 0 -} - -// Destory deletes a session by session ID. -func (p *NodbProvider) Destory(sid string) error { - _, err := p.c.Del([]byte(sid)) - return err -} - -// Regenerate regenerates a session store from old session ID to new one. -func (p *NodbProvider) Regenerate(oldsid, sid string) (_ session.RawStore, err error) { - if p.Exist(sid) { - return nil, fmt.Errorf("new sid '%s' already exists", sid) - } - - kvs := make([]byte, 0) - if p.Exist(oldsid) { - if kvs, err = p.c.Get([]byte(oldsid)); err != nil { - return nil, err - } else if _, err = p.c.Del([]byte(oldsid)); err != nil { - return nil, err - } - } - - if err = p.c.Set([]byte(sid), kvs); err != nil { - return nil, err - } else if _, err = p.c.Expire([]byte(sid), p.expire); err != nil { - return nil, err - } - - var kv map[interface{}]interface{} - if len(kvs) == 0 { - kv = make(map[interface{}]interface{}) - } else { - kv, err = session.DecodeGob([]byte(kvs)) - if err != nil { - return nil, err - } - } - - return NewNodbStore(p.c, sid, p.expire, kv), nil -} - -// Count counts and returns number of sessions. -func (p *NodbProvider) Count() int { - // FIXME: how come this library does not have DbSize() method? - return -1 -} - -// GC calls GC to clean expired sessions. -func (p *NodbProvider) GC() {} - -func init() { - session.Register("nodb", &NodbProvider{}) -} diff --git a/Godeps/_workspace/src/github.com/macaron-contrib/session/nodb/nodb.goconvey b/Godeps/_workspace/src/github.com/macaron-contrib/session/nodb/nodb.goconvey deleted file mode 100644 index 8485e98..0000000 --- a/Godeps/_workspace/src/github.com/macaron-contrib/session/nodb/nodb.goconvey +++ /dev/null @@ -1 +0,0 @@ -ignore \ No newline at end of file diff --git a/Godeps/_workspace/src/github.com/macaron-contrib/session/nodb/nodb_test.go b/Godeps/_workspace/src/github.com/macaron-contrib/session/nodb/nodb_test.go deleted file mode 100644 index c86ba98..0000000 --- a/Godeps/_workspace/src/github.com/macaron-contrib/session/nodb/nodb_test.go +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright 2015 Unknwon -// -// Licensed under the Apache License, Version 2.0 (the "License"): you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -// License for the specific language governing permissions and limitations -// under the License. - -package session - -import ( - "net/http" - "net/http/httptest" - "testing" - - "github.com/Unknwon/macaron" - . "github.com/smartystreets/goconvey/convey" - - "github.com/macaron-contrib/session" -) - -func Test_LedisProvider(t *testing.T) { - Convey("Test nodb session provider", t, func() { - opt := session.Options{ - Provider: "nodb", - ProviderConfig: "./tmp.db", - } - - Convey("Basic operation", func() { - m := macaron.New() - m.Use(session.Sessioner(opt)) - - m.Get("/", func(ctx *macaron.Context, sess session.Store) { - sess.Set("uname", "unknwon") - }) - m.Get("/reg", func(ctx *macaron.Context, sess session.Store) { - raw, err := sess.RegenerateId(ctx) - So(err, ShouldBeNil) - So(raw, ShouldNotBeNil) - - uname := raw.Get("uname") - So(uname, ShouldNotBeNil) - So(uname, ShouldEqual, "unknwon") - }) - m.Get("/get", func(ctx *macaron.Context, sess session.Store) { - sid := sess.ID() - So(sid, ShouldNotBeEmpty) - - raw, err := sess.Read(sid) - So(err, ShouldBeNil) - So(raw, ShouldNotBeNil) - - uname := sess.Get("uname") - So(uname, ShouldNotBeNil) - So(uname, ShouldEqual, "unknwon") - - So(sess.Delete("uname"), ShouldBeNil) - So(sess.Get("uname"), ShouldBeNil) - - So(sess.Destory(ctx), ShouldBeNil) - }) - - resp := httptest.NewRecorder() - req, err := http.NewRequest("GET", "/", nil) - So(err, ShouldBeNil) - m.ServeHTTP(resp, req) - - cookie := resp.Header().Get("Set-Cookie") - - resp = httptest.NewRecorder() - req, err = http.NewRequest("GET", "/reg", nil) - So(err, ShouldBeNil) - req.Header.Set("Cookie", cookie) - m.ServeHTTP(resp, req) - - cookie = resp.Header().Get("Set-Cookie") - - resp = httptest.NewRecorder() - req, err = http.NewRequest("GET", "/get", nil) - So(err, ShouldBeNil) - req.Header.Set("Cookie", cookie) - m.ServeHTTP(resp, req) - - Convey("Regenrate empty session", func() { - m.Get("/empty", func(ctx *macaron.Context, sess session.Store) { - raw, err := sess.RegenerateId(ctx) - So(err, ShouldBeNil) - So(raw, ShouldNotBeNil) - }) - - resp = httptest.NewRecorder() - req, err = http.NewRequest("GET", "/empty", nil) - So(err, ShouldBeNil) - req.Header.Set("Cookie", "MacaronSession=ad2c7e3cbecfcf486; Path=/;") - m.ServeHTTP(resp, req) - }) - }) - }) -} diff --git a/Godeps/_workspace/src/github.com/macaron-contrib/session/postgres/postgres.go b/Godeps/_workspace/src/github.com/macaron-contrib/session/postgres/postgres.go deleted file mode 100644 index 5cb4c82..0000000 --- a/Godeps/_workspace/src/github.com/macaron-contrib/session/postgres/postgres.go +++ /dev/null @@ -1,196 +0,0 @@ -// Copyright 2013 Beego Authors -// Copyright 2014 Unknwon -// -// Licensed under the Apache License, Version 2.0 (the "License"): you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -// License for the specific language governing permissions and limitations -// under the License. - -package session - -import ( - "database/sql" - "fmt" - "log" - "sync" - "time" - - _ "github.com/lib/pq" - - "github.com/macaron-contrib/session" -) - -// PostgresStore represents a postgres session store implementation. -type PostgresStore struct { - c *sql.DB - sid string - lock sync.RWMutex - data map[interface{}]interface{} -} - -// NewPostgresStore creates and returns a postgres session store. -func NewPostgresStore(c *sql.DB, sid string, kv map[interface{}]interface{}) *PostgresStore { - return &PostgresStore{ - c: c, - sid: sid, - data: kv, - } -} - -// Set sets value to given key in session. -func (s *PostgresStore) Set(key, value interface{}) error { - s.lock.Lock() - defer s.lock.Unlock() - - s.data[key] = value - return nil -} - -// Get gets value by given key in session. -func (s *PostgresStore) Get(key interface{}) interface{} { - s.lock.RLock() - defer s.lock.RUnlock() - - return s.data[key] -} - -// Delete delete a key from session. -func (s *PostgresStore) Delete(key interface{}) error { - s.lock.Lock() - defer s.lock.Unlock() - - delete(s.data, key) - return nil -} - -// ID returns current session ID. -func (s *PostgresStore) ID() string { - return s.sid -} - -// save postgres session values to database. -// must call this method to save values to database. -func (s *PostgresStore) Release() error { - data, err := session.EncodeGob(s.data) - if err != nil { - return err - } - - _, err = s.c.Exec("UPDATE session SET data=$1, expiry=$2 WHERE key=$3", - data, time.Now().Unix(), s.sid) - return err -} - -// Flush deletes all session data. -func (s *PostgresStore) Flush() error { - s.lock.Lock() - defer s.lock.Unlock() - - s.data = make(map[interface{}]interface{}) - return nil -} - -// PostgresProvider represents a postgres session provider implementation. -type PostgresProvider struct { - c *sql.DB - maxlifetime int64 -} - -// Init initializes postgres session provider. -// connStr: user=a password=b host=localhost port=5432 dbname=c sslmode=disable -func (p *PostgresProvider) Init(maxlifetime int64, connStr string) (err error) { - p.maxlifetime = maxlifetime - - p.c, err = sql.Open("postgres", connStr) - if err != nil { - return err - } - return p.c.Ping() -} - -// Read returns raw session store by session ID. -func (p *PostgresProvider) Read(sid string) (session.RawStore, error) { - var data []byte - err := p.c.QueryRow("SELECT data FROM session WHERE key=$1", sid).Scan(&data) - if err == sql.ErrNoRows { - _, err = p.c.Exec("INSERT INTO session(key,data,expiry) VALUES($1,$2,$3)", - sid, "", time.Now().Unix()) - } - if err != nil { - return nil, err - } - - var kv map[interface{}]interface{} - if len(data) == 0 { - kv = make(map[interface{}]interface{}) - } else { - kv, err = session.DecodeGob(data) - if err != nil { - return nil, err - } - } - - return NewPostgresStore(p.c, sid, kv), nil -} - -// Exist returns true if session with given ID exists. -func (p *PostgresProvider) Exist(sid string) bool { - var data []byte - err := p.c.QueryRow("SELECT data FROM session WHERE key=$1", sid).Scan(&data) - if err != nil && err != sql.ErrNoRows { - panic("session/postgres: error checking existence: " + err.Error()) - } - return err != sql.ErrNoRows -} - -// Destory deletes a session by session ID. -func (p *PostgresProvider) Destory(sid string) error { - _, err := p.c.Exec("DELETE FROM session WHERE key=$1", sid) - return err -} - -// Regenerate regenerates a session store from old session ID to new one. -func (p *PostgresProvider) Regenerate(oldsid, sid string) (_ session.RawStore, err error) { - if p.Exist(sid) { - return nil, fmt.Errorf("new sid '%s' already exists", sid) - } - - if !p.Exist(oldsid) { - if _, err = p.c.Exec("INSERT INTO session(key,data,expiry) VALUES($1,$2,$3)", - oldsid, "", time.Now().Unix()); err != nil { - return nil, err - } - } - - if _, err = p.c.Exec("UPDATE session SET key=$1 WHERE key=$2", sid, oldsid); err != nil { - return nil, err - } - - return p.Read(sid) -} - -// Count counts and returns number of sessions. -func (p *PostgresProvider) Count() (total int) { - if err := p.c.QueryRow("SELECT COUNT(*) AS NUM FROM session").Scan(&total); err != nil { - panic("session/postgres: error counting records: " + err.Error()) - } - return total -} - -// GC calls GC to clean expired sessions. -func (p *PostgresProvider) GC() { - if _, err := p.c.Exec("DELETE FROM session WHERE EXTRACT(EPOCH FROM NOW()) - expiry > $1", p.maxlifetime); err != nil { - log.Printf("session/postgres: error garbage collecting: %v", err) - } -} - -func init() { - session.Register("postgres", &PostgresProvider{}) -} diff --git a/Godeps/_workspace/src/github.com/macaron-contrib/session/postgres/postgres.goconvey b/Godeps/_workspace/src/github.com/macaron-contrib/session/postgres/postgres.goconvey deleted file mode 100644 index 8485e98..0000000 --- a/Godeps/_workspace/src/github.com/macaron-contrib/session/postgres/postgres.goconvey +++ /dev/null @@ -1 +0,0 @@ -ignore \ No newline at end of file diff --git a/Godeps/_workspace/src/github.com/macaron-contrib/session/postgres/postgres_test.go b/Godeps/_workspace/src/github.com/macaron-contrib/session/postgres/postgres_test.go deleted file mode 100644 index ea212c7..0000000 --- a/Godeps/_workspace/src/github.com/macaron-contrib/session/postgres/postgres_test.go +++ /dev/null @@ -1,138 +0,0 @@ -// Copyright 2014 Unknwon -// -// Licensed under the Apache License, Version 2.0 (the "License"): you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -// License for the specific language governing permissions and limitations -// under the License. - -package session - -import ( - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/Unknwon/macaron" - . "github.com/smartystreets/goconvey/convey" - - "github.com/macaron-contrib/session" -) - -func Test_PostgresProvider(t *testing.T) { - Convey("Test postgres session provider", t, func() { - opt := session.Options{ - Provider: "postgres", - ProviderConfig: "user=jiahuachen dbname=macaron port=5432 sslmode=disable", - } - - Convey("Basic operation", func() { - m := macaron.New() - m.Use(session.Sessioner(opt)) - - m.Get("/", func(ctx *macaron.Context, sess session.Store) { - sess.Set("uname", "unknwon") - }) - m.Get("/reg", func(ctx *macaron.Context, sess session.Store) { - raw, err := sess.RegenerateId(ctx) - So(err, ShouldBeNil) - So(raw, ShouldNotBeNil) - - uname := raw.Get("uname") - So(uname, ShouldNotBeNil) - So(uname, ShouldEqual, "unknwon") - }) - m.Get("/get", func(ctx *macaron.Context, sess session.Store) { - sid := sess.ID() - So(sid, ShouldNotBeEmpty) - - raw, err := sess.Read(sid) - So(err, ShouldBeNil) - So(raw, ShouldNotBeNil) - So(raw.Release(), ShouldBeNil) - - uname := sess.Get("uname") - So(uname, ShouldNotBeNil) - So(uname, ShouldEqual, "unknwon") - - So(sess.Delete("uname"), ShouldBeNil) - So(sess.Get("uname"), ShouldBeNil) - - So(sess.Destory(ctx), ShouldBeNil) - }) - - resp := httptest.NewRecorder() - req, err := http.NewRequest("GET", "/", nil) - So(err, ShouldBeNil) - m.ServeHTTP(resp, req) - - cookie := resp.Header().Get("Set-Cookie") - - resp = httptest.NewRecorder() - req, err = http.NewRequest("GET", "/reg", nil) - So(err, ShouldBeNil) - req.Header.Set("Cookie", cookie) - m.ServeHTTP(resp, req) - - cookie = resp.Header().Get("Set-Cookie") - - resp = httptest.NewRecorder() - req, err = http.NewRequest("GET", "/get", nil) - So(err, ShouldBeNil) - req.Header.Set("Cookie", cookie) - m.ServeHTTP(resp, req) - }) - - Convey("Regenrate empty session", func() { - m := macaron.New() - m.Use(session.Sessioner(opt)) - m.Get("/", func(ctx *macaron.Context, sess session.Store) { - raw, err := sess.RegenerateId(ctx) - So(err, ShouldBeNil) - So(raw, ShouldNotBeNil) - - So(sess.Destory(ctx), ShouldBeNil) - }) - - resp := httptest.NewRecorder() - req, err := http.NewRequest("GET", "/", nil) - So(err, ShouldBeNil) - req.Header.Set("Cookie", "MacaronSession=ad2c7e3cbecfcf48; Path=/;") - m.ServeHTTP(resp, req) - }) - - Convey("GC session", func() { - m := macaron.New() - opt2 := opt - opt2.Gclifetime = 1 - m.Use(session.Sessioner(opt2)) - - m.Get("/", func(sess session.Store) { - sess.Set("uname", "unknwon") - So(sess.ID(), ShouldNotBeEmpty) - uname := sess.Get("uname") - So(uname, ShouldNotBeNil) - So(uname, ShouldEqual, "unknwon") - - So(sess.Flush(), ShouldBeNil) - So(sess.Get("uname"), ShouldBeNil) - - time.Sleep(2 * time.Second) - sess.GC() - So(sess.Count(), ShouldEqual, 0) - }) - - resp := httptest.NewRecorder() - req, err := http.NewRequest("GET", "/", nil) - So(err, ShouldBeNil) - m.ServeHTTP(resp, req) - }) - }) -} diff --git a/Godeps/_workspace/src/github.com/macaron-contrib/session/redis/redis.go b/Godeps/_workspace/src/github.com/macaron-contrib/session/redis/redis.go deleted file mode 100644 index 6d6a2c4..0000000 --- a/Godeps/_workspace/src/github.com/macaron-contrib/session/redis/redis.go +++ /dev/null @@ -1,227 +0,0 @@ -// Copyright 2013 Beego Authors -// Copyright 2014 Unknwon -// -// Licensed under the Apache License, Version 2.0 (the "License"): you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -// License for the specific language governing permissions and limitations -// under the License. - -package session - -import ( - "fmt" - "strings" - "sync" - "time" - - "github.com/Unknwon/com" - "gopkg.in/ini.v1" - "gopkg.in/redis.v2" - - "github.com/macaron-contrib/session" -) - -// RedisStore represents a redis session store implementation. -type RedisStore struct { - c *redis.Client - sid string - duration time.Duration - lock sync.RWMutex - data map[interface{}]interface{} -} - -// NewRedisStore creates and returns a redis session store. -func NewRedisStore(c *redis.Client, sid string, dur time.Duration, kv map[interface{}]interface{}) *RedisStore { - return &RedisStore{ - c: c, - sid: sid, - duration: dur, - data: kv, - } -} - -// Set sets value to given key in session. -func (s *RedisStore) Set(key, val interface{}) error { - s.lock.Lock() - defer s.lock.Unlock() - - s.data[key] = val - return nil -} - -// Get gets value by given key in session. -func (s *RedisStore) Get(key interface{}) interface{} { - s.lock.RLock() - defer s.lock.RUnlock() - - return s.data[key] -} - -// Delete delete a key from session. -func (s *RedisStore) Delete(key interface{}) error { - s.lock.Lock() - defer s.lock.Unlock() - - delete(s.data, key) - return nil -} - -// ID returns current session ID. -func (s *RedisStore) ID() string { - return s.sid -} - -// Release releases resource and save data to provider. -func (s *RedisStore) Release() error { - data, err := session.EncodeGob(s.data) - if err != nil { - return err - } - - return s.c.SetEx(s.sid, s.duration, string(data)).Err() -} - -// Flush deletes all session data. -func (s *RedisStore) Flush() error { - s.lock.Lock() - defer s.lock.Unlock() - - s.data = make(map[interface{}]interface{}) - return nil -} - -// RedisProvider represents a redis session provider implementation. -type RedisProvider struct { - c *redis.Client - duration time.Duration -} - -// Init initializes redis session provider. -// configs: network=tcp,addr=:6379,password=macaron,db=0,pool_size=100,idle_timeout=180 -func (p *RedisProvider) Init(maxlifetime int64, configs string) (err error) { - p.duration, err = time.ParseDuration(fmt.Sprintf("%ds", maxlifetime)) - if err != nil { - return err - } - - cfg, err := ini.Load([]byte(strings.Replace(configs, ",", "\n", -1))) - if err != nil { - return err - } - - opt := &redis.Options{ - Network: "tcp", - } - for k, v := range cfg.Section("").KeysHash() { - switch k { - case "network": - opt.Network = v - case "addr": - opt.Addr = v - case "password": - opt.Password = v - case "db": - opt.DB = com.StrTo(v).MustInt64() - case "pool_size": - opt.PoolSize = com.StrTo(v).MustInt() - case "idle_timeout": - opt.IdleTimeout, err = time.ParseDuration(v + "s") - if err != nil { - return fmt.Errorf("error parsing idle timeout: %v", err) - } - default: - return fmt.Errorf("session/redis: unsupported option '%s'", k) - } - } - - p.c = redis.NewClient(opt) - return p.c.Ping().Err() -} - -// Read returns raw session store by session ID. -func (p *RedisProvider) Read(sid string) (session.RawStore, error) { - if !p.Exist(sid) { - if err := p.c.Set(sid, "").Err(); err != nil { - return nil, err - } - } - - var kv map[interface{}]interface{} - kvs, err := p.c.Get(sid).Result() - if err != nil { - return nil, err - } - if len(kvs) == 0 { - kv = make(map[interface{}]interface{}) - } else { - kv, err = session.DecodeGob([]byte(kvs)) - if err != nil { - return nil, err - } - } - - return NewRedisStore(p.c, sid, p.duration, kv), nil -} - -// Exist returns true if session with given ID exists. -func (p *RedisProvider) Exist(sid string) bool { - has, err := p.c.Exists(sid).Result() - return err == nil && has -} - -// Destory deletes a session by session ID. -func (p *RedisProvider) Destory(sid string) error { - return p.c.Del(sid).Err() -} - -// Regenerate regenerates a session store from old session ID to new one. -func (p *RedisProvider) Regenerate(oldsid, sid string) (_ session.RawStore, err error) { - if p.Exist(sid) { - return nil, fmt.Errorf("new sid '%s' already exists", sid) - } else if !p.Exist(oldsid) { - // Make a fake old session. - if err = p.c.SetEx(oldsid, p.duration, "").Err(); err != nil { - return nil, err - } - } - - if err = p.c.Rename(oldsid, sid).Err(); err != nil { - return nil, err - } - - var kv map[interface{}]interface{} - kvs, err := p.c.Get(sid).Result() - if err != nil { - return nil, err - } - - if len(kvs) == 0 { - kv = make(map[interface{}]interface{}) - } else { - kv, err = session.DecodeGob([]byte(kvs)) - if err != nil { - return nil, err - } - } - - return NewRedisStore(p.c, sid, p.duration, kv), nil -} - -// Count counts and returns number of sessions. -func (p *RedisProvider) Count() int { - return int(p.c.DbSize().Val()) -} - -// GC calls GC to clean expired sessions. -func (_ *RedisProvider) GC() {} - -func init() { - session.Register("redis", &RedisProvider{}) -} diff --git a/Godeps/_workspace/src/github.com/macaron-contrib/session/redis/redis.goconvey b/Godeps/_workspace/src/github.com/macaron-contrib/session/redis/redis.goconvey deleted file mode 100644 index 8485e98..0000000 --- a/Godeps/_workspace/src/github.com/macaron-contrib/session/redis/redis.goconvey +++ /dev/null @@ -1 +0,0 @@ -ignore \ No newline at end of file diff --git a/Godeps/_workspace/src/github.com/macaron-contrib/session/redis/redis_test.go b/Godeps/_workspace/src/github.com/macaron-contrib/session/redis/redis_test.go deleted file mode 100644 index 9fd8e65..0000000 --- a/Godeps/_workspace/src/github.com/macaron-contrib/session/redis/redis_test.go +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright 2014 Unknwon -// -// Licensed under the Apache License, Version 2.0 (the "License"): you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -// License for the specific language governing permissions and limitations -// under the License. - -package session - -import ( - "net/http" - "net/http/httptest" - "testing" - - "github.com/Unknwon/macaron" - . "github.com/smartystreets/goconvey/convey" - - "github.com/macaron-contrib/session" -) - -func Test_RedisProvider(t *testing.T) { - Convey("Test redis session provider", t, func() { - opt := session.Options{ - Provider: "redis", - ProviderConfig: "addr=:6379", - } - - Convey("Basic operation", func() { - m := macaron.New() - m.Use(session.Sessioner(opt)) - - m.Get("/", func(ctx *macaron.Context, sess session.Store) { - sess.Set("uname", "unknwon") - }) - m.Get("/reg", func(ctx *macaron.Context, sess session.Store) { - raw, err := sess.RegenerateId(ctx) - So(err, ShouldBeNil) - So(raw, ShouldNotBeNil) - - uname := raw.Get("uname") - So(uname, ShouldNotBeNil) - So(uname, ShouldEqual, "unknwon") - }) - m.Get("/get", func(ctx *macaron.Context, sess session.Store) { - sid := sess.ID() - So(sid, ShouldNotBeEmpty) - - raw, err := sess.Read(sid) - So(err, ShouldBeNil) - So(raw, ShouldNotBeNil) - - uname := sess.Get("uname") - So(uname, ShouldNotBeNil) - So(uname, ShouldEqual, "unknwon") - - So(sess.Delete("uname"), ShouldBeNil) - So(sess.Get("uname"), ShouldBeNil) - - So(sess.Destory(ctx), ShouldBeNil) - }) - - resp := httptest.NewRecorder() - req, err := http.NewRequest("GET", "/", nil) - So(err, ShouldBeNil) - m.ServeHTTP(resp, req) - - cookie := resp.Header().Get("Set-Cookie") - - resp = httptest.NewRecorder() - req, err = http.NewRequest("GET", "/reg", nil) - So(err, ShouldBeNil) - req.Header.Set("Cookie", cookie) - m.ServeHTTP(resp, req) - - cookie = resp.Header().Get("Set-Cookie") - - resp = httptest.NewRecorder() - req, err = http.NewRequest("GET", "/get", nil) - So(err, ShouldBeNil) - req.Header.Set("Cookie", cookie) - m.ServeHTTP(resp, req) - }) - - Convey("Regenrate empty session", func() { - m := macaron.New() - m.Use(session.Sessioner(opt)) - m.Get("/", func(ctx *macaron.Context, sess session.Store) { - raw, err := sess.RegenerateId(ctx) - So(err, ShouldBeNil) - So(raw, ShouldNotBeNil) - }) - - resp := httptest.NewRecorder() - req, err := http.NewRequest("GET", "/", nil) - So(err, ShouldBeNil) - req.Header.Set("Cookie", "MacaronSession=ad2c7e3cbecfcf486; Path=/;") - m.ServeHTTP(resp, req) - }) - }) -} diff --git a/Godeps/_workspace/src/github.com/macaron-contrib/session/session.go b/Godeps/_workspace/src/github.com/macaron-contrib/session/session.go deleted file mode 100644 index 9cc1d52..0000000 --- a/Godeps/_workspace/src/github.com/macaron-contrib/session/session.go +++ /dev/null @@ -1,401 +0,0 @@ -// Copyright 2013 Beego Authors -// Copyright 2014 Unknwon -// -// Licensed under the Apache License, Version 2.0 (the "License"): you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -// License for the specific language governing permissions and limitations -// under the License. - -// Package session a middleware that provides the session management of Macaron. -package session - -// NOTE: last sync 000033e on Nov 4, 2014. - -import ( - "encoding/hex" - "fmt" - "net/http" - "net/url" - "time" - - "github.com/Unknwon/macaron" -) - -const _VERSION = "0.1.6" - -func Version() string { - return _VERSION -} - -// RawStore is the interface that operates the session data. -type RawStore interface { - // Set sets value to given key in session. - Set(interface{}, interface{}) error - // Get gets value by given key in session. - Get(interface{}) interface{} - // Delete deletes a key from session. - Delete(interface{}) error - // ID returns current session ID. - ID() string - // Release releases session resource and save data to provider. - Release() error - // Flush deletes all session data. - Flush() error -} - -// Store is the interface that contains all data for one session process with specific ID. -type Store interface { - RawStore - // Read returns raw session store by session ID. - Read(string) (RawStore, error) - // Destory deletes a session. - Destory(*macaron.Context) error - // RegenerateId regenerates a session store from old session ID to new one. - RegenerateId(*macaron.Context) (RawStore, error) - // Count counts and returns number of sessions. - Count() int - // GC calls GC to clean expired sessions. - GC() -} - -type store struct { - RawStore - *Manager -} - -var _ Store = &store{} - -// Options represents a struct for specifying configuration options for the session middleware. -type Options struct { - // Name of provider. Default is "memory". - Provider string - // Provider configuration, it's corresponding to provider. - ProviderConfig string - // Cookie name to save session ID. Default is "MacaronSession". - CookieName string - // Cookie path to store. Default is "/". - CookiePath string - // GC interval time in seconds. Default is 3600. - Gclifetime int64 - // Max life time in seconds. Default is whatever GC interval time is. - Maxlifetime int64 - // Use HTTPS only. Default is false. - Secure bool - // Cookie life time. Default is 0. - CookieLifeTime int - // Cookie domain name. Default is empty. - Domain string - // Session ID length. Default is 16. - IDLength int - // Configuration section name. Default is "session". - Section string -} - -func prepareOptions(options []Options) Options { - var opt Options - if len(options) > 0 { - opt = options[0] - } - if len(opt.Section) == 0 { - opt.Section = "session" - } - sec := macaron.Config().Section(opt.Section) - - if len(opt.Provider) == 0 { - opt.Provider = sec.Key("PROVIDER").MustString("memory") - } - if len(opt.ProviderConfig) == 0 { - opt.ProviderConfig = sec.Key("PROVIDER_CONFIG").MustString("data/sessions") - } - if len(opt.CookieName) == 0 { - opt.CookieName = sec.Key("COOKIE_NAME").MustString("MacaronSession") - } - if len(opt.CookiePath) == 0 { - opt.CookiePath = sec.Key("COOKIE_PATH").MustString("/") - } - if opt.Gclifetime == 0 { - opt.Gclifetime = sec.Key("GC_INTERVAL_TIME").MustInt64(3600) - } - if opt.Maxlifetime == 0 { - opt.Maxlifetime = sec.Key("MAX_LIFE_TIME").MustInt64(opt.Gclifetime) - } - if !opt.Secure { - opt.Secure = sec.Key("SECURE").MustBool() - } - if opt.CookieLifeTime == 0 { - opt.CookieLifeTime = sec.Key("COOKIE_LIFE_TIME").MustInt() - } - if len(opt.Domain) == 0 { - opt.Domain = sec.Key("DOMAIN").String() - } - if opt.IDLength == 0 { - opt.IDLength = sec.Key("ID_LENGTH").MustInt(16) - } - - return opt -} - -// Sessioner is a middleware that maps a session.SessionStore service into the Macaron handler chain. -// An single variadic session.Options struct can be optionally provided to configure. -func Sessioner(options ...Options) macaron.Handler { - opt := prepareOptions(options) - manager, err := NewManager(opt.Provider, opt) - if err != nil { - panic(err) - } - go manager.startGC() - - return func(ctx *macaron.Context) { - sess, err := manager.Start(ctx) - if err != nil { - panic("session(start): " + err.Error()) - } - - // Get flash. - vals, _ := url.ParseQuery(ctx.GetCookie("macaron_flash")) - if len(vals) > 0 { - f := &Flash{Values: vals} - f.ErrorMsg = f.Get("error") - f.SuccessMsg = f.Get("success") - f.InfoMsg = f.Get("info") - f.WarningMsg = f.Get("warning") - ctx.Data["Flash"] = f - ctx.SetCookie("macaron_flash", "", -1, opt.CookiePath) - } - - f := &Flash{ctx, url.Values{}, "", "", "", ""} - ctx.Resp.Before(func(macaron.ResponseWriter) { - if flash := f.Encode(); len(flash) > 0 { - ctx.SetCookie("macaron_flash", flash, 0, opt.CookiePath) - } - }) - - ctx.Map(f) - s := store{ - RawStore: sess, - Manager: manager, - } - - ctx.MapTo(s, (*Store)(nil)) - - ctx.Next() - - if err = sess.Release(); err != nil { - panic("session(release): " + err.Error()) - } - } -} - -// Provider is the interface that provides session manipulations. -type Provider interface { - // Init initializes session provider. - Init(gclifetime int64, config string) error - // Read returns raw session store by session ID. - Read(sid string) (RawStore, error) - // Exist returns true if session with given ID exists. - Exist(sid string) bool - // Destory deletes a session by session ID. - Destory(sid string) error - // Regenerate regenerates a session store from old session ID to new one. - Regenerate(oldsid, sid string) (RawStore, error) - // Count counts and returns number of sessions. - Count() int - // GC calls GC to clean expired sessions. - GC() -} - -var providers = make(map[string]Provider) - -// Register registers a provider. -func Register(name string, provider Provider) { - if provider == nil { - panic("session: cannot register provider with nil value") - } - if _, dup := providers[name]; dup { - panic(fmt.Errorf("session: cannot register provider '%s' twice", name)) - } - providers[name] = provider -} - -// _____ -// / \ _____ ____ _____ ____ ___________ -// / \ / \\__ \ / \\__ \ / ___\_/ __ \_ __ \ -// / Y \/ __ \| | \/ __ \_/ /_/ > ___/| | \/ -// \____|__ (____ /___| (____ /\___ / \___ >__| -// \/ \/ \/ \//_____/ \/ - -// Manager represents a struct that contains session provider and its configuration. -type Manager struct { - provider Provider - opt Options -} - -// NewManager creates and returns a new session manager by given provider name and configuration. -// It panics when given provider isn't registered. -func NewManager(name string, opt Options) (*Manager, error) { - p, ok := providers[name] - if !ok { - return nil, fmt.Errorf("session: unknown provider '%s'(forgotten import?)", name) - } - return &Manager{p, opt}, p.Init(opt.Maxlifetime, opt.ProviderConfig) -} - -// sessionId generates a new session ID with rand string, unix nano time, remote addr by hash function. -func (m *Manager) sessionId() string { - return hex.EncodeToString(generateRandomKey(m.opt.IDLength / 2)) -} - -// Start starts a session by generating new one -// or retrieve existence one by reading session ID from HTTP request if it's valid. -func (m *Manager) Start(ctx *macaron.Context) (RawStore, error) { - sid := ctx.GetCookie(m.opt.CookieName) - if len(sid) > 0 && m.provider.Exist(sid) { - return m.provider.Read(sid) - } - - sid = m.sessionId() - sess, err := m.provider.Read(sid) - if err != nil { - return nil, err - } - - cookie := &http.Cookie{ - Name: m.opt.CookieName, - Value: sid, - Path: m.opt.CookiePath, - HttpOnly: true, - Secure: m.opt.Secure, - Domain: m.opt.Domain, - } - if m.opt.CookieLifeTime >= 0 { - cookie.MaxAge = m.opt.CookieLifeTime - } - http.SetCookie(ctx.Resp, cookie) - ctx.Req.AddCookie(cookie) - return sess, nil -} - -// Read returns raw session store by session ID. -func (m *Manager) Read(sid string) (RawStore, error) { - return m.provider.Read(sid) -} - -// Destory deletes a session by given ID. -func (m *Manager) Destory(ctx *macaron.Context) error { - sid := ctx.GetCookie(m.opt.CookieName) - if len(sid) == 0 { - return nil - } - - if err := m.provider.Destory(sid); err != nil { - return err - } - cookie := &http.Cookie{ - Name: m.opt.CookieName, - Path: m.opt.CookiePath, - HttpOnly: true, - Expires: time.Now(), - MaxAge: -1, - } - http.SetCookie(ctx.Resp, cookie) - return nil -} - -// RegenerateId regenerates a session store from old session ID to new one. -func (m *Manager) RegenerateId(ctx *macaron.Context) (sess RawStore, err error) { - sid := m.sessionId() - oldsid := ctx.GetCookie(m.opt.CookieName) - sess, err = m.provider.Regenerate(oldsid, sid) - if err != nil { - return nil, err - } - ck := &http.Cookie{ - Name: m.opt.CookieName, - Value: sid, - Path: m.opt.CookiePath, - HttpOnly: true, - Secure: m.opt.Secure, - Domain: m.opt.Domain, - } - if m.opt.CookieLifeTime >= 0 { - ck.MaxAge = m.opt.CookieLifeTime - } - http.SetCookie(ctx.Resp, ck) - ctx.Req.AddCookie(ck) - return sess, nil -} - -// Count counts and returns number of sessions. -func (m *Manager) Count() int { - return m.provider.Count() -} - -// GC starts GC job in a certain period. -func (m *Manager) GC() { - m.provider.GC() -} - -// startGC starts GC job in a certain period. -func (m *Manager) startGC() { - m.GC() - time.AfterFunc(time.Duration(m.opt.Gclifetime)*time.Second, func() { m.startGC() }) -} - -// SetSecure indicates whether to set cookie with HTTPS or not. -func (m *Manager) SetSecure(secure bool) { - m.opt.Secure = secure -} - -// ___________.____ _____ _________ ___ ___ -// \_ _____/| | / _ \ / _____// | \ -// | __) | | / /_\ \ \_____ \/ ~ \ -// | \ | |___/ | \/ \ Y / -// \___ / |_______ \____|__ /_______ /\___|_ / -// \/ \/ \/ \/ \/ - -type Flash struct { - ctx *macaron.Context - url.Values - ErrorMsg, WarningMsg, InfoMsg, SuccessMsg string -} - -func (f *Flash) set(name, msg string, current ...bool) { - isShow := false - if (len(current) == 0 && macaron.FlashNow) || - (len(current) > 0 && current[0]) { - isShow = true - } - - if isShow { - f.ctx.Data["Flash"] = f - } else { - f.Set(name, msg) - } -} - -func (f *Flash) Error(msg string, current ...bool) { - f.ErrorMsg = msg - f.set("error", msg, current...) -} - -func (f *Flash) Warning(msg string, current ...bool) { - f.WarningMsg = msg - f.set("warning", msg, current...) -} - -func (f *Flash) Info(msg string, current ...bool) { - f.InfoMsg = msg - f.set("info", msg, current...) -} - -func (f *Flash) Success(msg string, current ...bool) { - f.SuccessMsg = msg - f.set("success", msg, current...) -} diff --git a/Godeps/_workspace/src/github.com/macaron-contrib/session/session_test.go b/Godeps/_workspace/src/github.com/macaron-contrib/session/session_test.go deleted file mode 100644 index 82efc27..0000000 --- a/Godeps/_workspace/src/github.com/macaron-contrib/session/session_test.go +++ /dev/null @@ -1,200 +0,0 @@ -// Copyright 2014 Unknwon -// -// Licensed under the Apache License, Version 2.0 (the "License"): you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -// License for the specific language governing permissions and limitations -// under the License. - -package session - -import ( - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/Unknwon/macaron" - . "github.com/smartystreets/goconvey/convey" -) - -func Test_Version(t *testing.T) { - Convey("Check package version", t, func() { - So(Version(), ShouldEqual, _VERSION) - }) -} - -func Test_Sessioner(t *testing.T) { - Convey("Use session middleware", t, func() { - m := macaron.New() - m.Use(Sessioner()) - m.Get("/", func() {}) - - resp := httptest.NewRecorder() - req, err := http.NewRequest("GET", "/", nil) - So(err, ShouldBeNil) - m.ServeHTTP(resp, req) - }) - - Convey("Register invalid provider", t, func() { - Convey("Provider not exists", func() { - defer func() { - So(recover(), ShouldNotBeNil) - }() - - m := macaron.New() - m.Use(Sessioner(Options{ - Provider: "fake", - })) - }) - - Convey("Provider value is nil", func() { - defer func() { - So(recover(), ShouldNotBeNil) - }() - - Register("fake", nil) - }) - - Convey("Register twice", func() { - defer func() { - So(recover(), ShouldNotBeNil) - }() - - Register("memory", &MemProvider{}) - }) - }) -} - -func testProvider(opt Options) { - Convey("Basic operation", func() { - m := macaron.New() - m.Use(Sessioner(opt)) - - m.Get("/", func(ctx *macaron.Context, sess Store) { - sess.Set("uname", "unknwon") - }) - m.Get("/reg", func(ctx *macaron.Context, sess Store) { - raw, err := sess.RegenerateId(ctx) - So(err, ShouldBeNil) - So(raw, ShouldNotBeNil) - - uname := raw.Get("uname") - So(uname, ShouldNotBeNil) - So(uname, ShouldEqual, "unknwon") - }) - m.Get("/get", func(ctx *macaron.Context, sess Store) { - sid := sess.ID() - So(sid, ShouldNotBeEmpty) - - raw, err := sess.Read(sid) - So(err, ShouldBeNil) - So(raw, ShouldNotBeNil) - - uname := sess.Get("uname") - So(uname, ShouldNotBeNil) - So(uname, ShouldEqual, "unknwon") - - So(sess.Delete("uname"), ShouldBeNil) - So(sess.Get("uname"), ShouldBeNil) - - So(sess.Destory(ctx), ShouldBeNil) - }) - - resp := httptest.NewRecorder() - req, err := http.NewRequest("GET", "/", nil) - So(err, ShouldBeNil) - m.ServeHTTP(resp, req) - - cookie := resp.Header().Get("Set-Cookie") - - resp = httptest.NewRecorder() - req, err = http.NewRequest("GET", "/reg", nil) - So(err, ShouldBeNil) - req.Header.Set("Cookie", cookie) - m.ServeHTTP(resp, req) - - cookie = resp.Header().Get("Set-Cookie") - - resp = httptest.NewRecorder() - req, err = http.NewRequest("GET", "/get", nil) - So(err, ShouldBeNil) - req.Header.Set("Cookie", cookie) - m.ServeHTTP(resp, req) - }) - - Convey("Regenrate empty session", func() { - m := macaron.New() - m.Use(Sessioner(opt)) - m.Get("/", func(ctx *macaron.Context, sess Store) { - raw, err := sess.RegenerateId(ctx) - So(err, ShouldBeNil) - So(raw, ShouldNotBeNil) - }) - - resp := httptest.NewRecorder() - req, err := http.NewRequest("GET", "/", nil) - So(err, ShouldBeNil) - req.Header.Set("Cookie", "MacaronSession=ad2c7e3cbecfcf486; Path=/;") - m.ServeHTTP(resp, req) - }) - - Convey("GC session", func() { - m := macaron.New() - opt2 := opt - opt2.Gclifetime = 1 - m.Use(Sessioner(opt2)) - - m.Get("/", func(sess Store) { - sess.Set("uname", "unknwon") - So(sess.ID(), ShouldNotBeEmpty) - uname := sess.Get("uname") - So(uname, ShouldNotBeNil) - So(uname, ShouldEqual, "unknwon") - - So(sess.Flush(), ShouldBeNil) - So(sess.Get("uname"), ShouldBeNil) - - time.Sleep(2 * time.Second) - sess.GC() - So(sess.Count(), ShouldEqual, 0) - }) - - resp := httptest.NewRecorder() - req, err := http.NewRequest("GET", "/", nil) - So(err, ShouldBeNil) - m.ServeHTTP(resp, req) - }) -} - -func Test_Flash(t *testing.T) { - Convey("Test flash", t, func() { - m := macaron.New() - m.Use(Sessioner()) - m.Get("/set", func(f *Flash) string { - f.Success("success") - f.Error("error") - f.Warning("warning") - f.Info("info") - return "" - }) - m.Get("/get", func() {}) - - resp := httptest.NewRecorder() - req, err := http.NewRequest("GET", "/set", nil) - So(err, ShouldBeNil) - m.ServeHTTP(resp, req) - - resp = httptest.NewRecorder() - req, err = http.NewRequest("GET", "/get", nil) - So(err, ShouldBeNil) - req.Header.Set("Cookie", "macaron_flash=error%3Derror%26info%3Dinfo%26success%3Dsuccess%26warning%3Dwarning; Path=/") - m.ServeHTTP(resp, req) - }) -} diff --git a/Godeps/_workspace/src/github.com/macaron-contrib/session/utils.go b/Godeps/_workspace/src/github.com/macaron-contrib/session/utils.go deleted file mode 100644 index 6c9ea49..0000000 --- a/Godeps/_workspace/src/github.com/macaron-contrib/session/utils.go +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2013 Beego Authors -// Copyright 2014 Unknwon -// -// Licensed under the Apache License, Version 2.0 (the "License"): you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -// License for the specific language governing permissions and limitations -// under the License. - -package session - -import ( - "bytes" - "crypto/rand" - "encoding/gob" - "io" - - "github.com/Unknwon/com" -) - -func EncodeGob(obj map[interface{}]interface{}) ([]byte, error) { - for _, v := range obj { - gob.Register(v) - } - buf := bytes.NewBuffer(nil) - err := gob.NewEncoder(buf).Encode(obj) - return buf.Bytes(), err -} - -func DecodeGob(encoded []byte) (out map[interface{}]interface{}, err error) { - buf := bytes.NewBuffer(encoded) - err = gob.NewDecoder(buf).Decode(&out) - return out, err -} - -// generateRandomKey creates a random key with the given strength. -func generateRandomKey(strength int) []byte { - k := make([]byte, strength) - if n, err := io.ReadFull(rand.Reader, k); n != strength || err != nil { - return com.RandomCreateBytes(strength) - } - return k -} diff --git a/Godeps/_workspace/src/github.com/satori/go.uuid/.travis.yml b/Godeps/_workspace/src/github.com/satori/go.uuid/.travis.yml new file mode 100644 index 0000000..0bbdc41 --- /dev/null +++ b/Godeps/_workspace/src/github.com/satori/go.uuid/.travis.yml @@ -0,0 +1,10 @@ +language: go +go: + - 1.0 + - 1.1 + - 1.2 + - 1.3 + - 1.4 +sudo: false +notifications: + email: false diff --git a/Godeps/_workspace/src/github.com/satori/go.uuid/LICENSE b/Godeps/_workspace/src/github.com/satori/go.uuid/LICENSE new file mode 100644 index 0000000..6a1fb91 --- /dev/null +++ b/Godeps/_workspace/src/github.com/satori/go.uuid/LICENSE @@ -0,0 +1,20 @@ +Copyright (C) 2013-2015 by Maxim Bublis + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/Godeps/_workspace/src/github.com/satori/go.uuid/README.md b/Godeps/_workspace/src/github.com/satori/go.uuid/README.md new file mode 100644 index 0000000..759f77c --- /dev/null +++ b/Godeps/_workspace/src/github.com/satori/go.uuid/README.md @@ -0,0 +1,66 @@ +# UUID package for Go language + +[![Build Status](https://travis-ci.org/satori/go.uuid.png?branch=master)](https://travis-ci.org/satori/go.uuid) +[![GoDoc](http://godoc.org/github.com/satori/go.uuid?status.png)](http://godoc.org/github.com/satori/go.uuid) + +This package provides pure Go implementation of Universally Unique Identifier (UUID). Supported both creation and parsing of UUIDs. + +With 100% test coverage and benchmarks out of box. + +Supported versions: +* Version 1, based on timestamp and MAC address (RFC 4122) +* Version 2, based on timestamp, MAC address and POSIX UID/GID (DCE 1.1) +* Version 3, based on MD5 hashing (RFC 4122) +* Version 4, based on random numbers (RFC 4122) +* Version 5, based on SHA-1 hashing (RFC 4122) + +## Installation + +Use the `go` command: + + $ go get github.com/satori/go.uuid + +## Requirements + +UUID package requires any stable version of Go Programming Language. + +It is tested against following versions of Go: 1.0-1.4 + +## Example + +```go +package main + +import ( + "fmt" + "github.com/satori/go.uuid" +) + +func main() { + // Creating UUID Version 4 + u1 := uuid.NewV4() + fmt.Printf("UUIDv4: %s\n", u1) + + // Parsing UUID from string input + u2, err := uuid.FromString("6ba7b810-9dad-11d1-80b4-00c04fd430c8") + if err != nil { + fmt.Printf("Something gone wrong: %s", err) + } + fmt.Printf("Successfully parsed: %s", u2) +} +``` + +## Documentation + +[Documentation](http://godoc.org/github.com/satori/go.uuid) is hosted at GoDoc project. + +## Links +* [RFC 4122](http://tools.ietf.org/html/rfc4122) +* [DCE 1.1: Authentication and Security Services](http://pubs.opengroup.org/onlinepubs/9696989899/chap5.htm#tagcjh_08_02_01_01) + +## Copyright + +Copyright (C) 2013-2015 by Maxim Bublis . + +UUID package released under MIT License. +See [LICENSE](https://github.com/satori/go.uuid/blob/master/LICENSE) for details. diff --git a/Godeps/_workspace/src/github.com/satori/go.uuid/benchmarks_test.go b/Godeps/_workspace/src/github.com/satori/go.uuid/benchmarks_test.go new file mode 100644 index 0000000..9a85f7c --- /dev/null +++ b/Godeps/_workspace/src/github.com/satori/go.uuid/benchmarks_test.go @@ -0,0 +1,121 @@ +// Copyright (C) 2013-2014 by Maxim Bublis +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to +// permit persons to whom the Software is furnished to do so, subject to +// the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +package uuid + +import ( + "testing" +) + +func BenchmarkFromBytes(b *testing.B) { + bytes := []byte{0x6b, 0xa7, 0xb8, 0x10, 0x9d, 0xad, 0x11, 0xd1, 0x80, 0xb4, 0x00, 0xc0, 0x4f, 0xd4, 0x30, 0xc8} + for i := 0; i < b.N; i++ { + FromBytes(bytes) + } +} + +func BenchmarkFromString(b *testing.B) { + s := "6ba7b810-9dad-11d1-80b4-00c04fd430c8" + for i := 0; i < b.N; i++ { + FromString(s) + } +} + +func BenchmarkFromStringUrn(b *testing.B) { + s := "urn:uuid:6ba7b810-9dad-11d1-80b4-00c04fd430c8" + for i := 0; i < b.N; i++ { + FromString(s) + } +} + +func BenchmarkFromStringWithBrackets(b *testing.B) { + s := "{6ba7b810-9dad-11d1-80b4-00c04fd430c8}" + for i := 0; i < b.N; i++ { + FromString(s) + } +} + +func BenchmarkNewV1(b *testing.B) { + for i := 0; i < b.N; i++ { + NewV1() + } +} + +func BenchmarkNewV2(b *testing.B) { + for i := 0; i < b.N; i++ { + NewV2(DomainPerson) + } +} + +func BenchmarkNewV3(b *testing.B) { + for i := 0; i < b.N; i++ { + NewV3(NamespaceDNS, "www.example.com") + } +} + +func BenchmarkNewV4(b *testing.B) { + for i := 0; i < b.N; i++ { + NewV4() + } +} + +func BenchmarkNewV5(b *testing.B) { + for i := 0; i < b.N; i++ { + NewV5(NamespaceDNS, "www.example.com") + } +} + +func BenchmarkMarshalBinary(b *testing.B) { + u := NewV4() + for i := 0; i < b.N; i++ { + u.MarshalBinary() + } +} + +func BenchmarkMarshalText(b *testing.B) { + u := NewV4() + for i := 0; i < b.N; i++ { + u.MarshalText() + } +} + +func BenchmarkUnmarshalBinary(b *testing.B) { + bytes := []byte{0x6b, 0xa7, 0xb8, 0x10, 0x9d, 0xad, 0x11, 0xd1, 0x80, 0xb4, 0x00, 0xc0, 0x4f, 0xd4, 0x30, 0xc8} + u := UUID{} + for i := 0; i < b.N; i++ { + u.UnmarshalBinary(bytes) + } +} + +func BenchmarkUnmarshalText(b *testing.B) { + bytes := []byte("6ba7b810-9dad-11d1-80b4-00c04fd430c8") + u := UUID{} + for i := 0; i < b.N; i++ { + u.UnmarshalText(bytes) + } +} + +func BenchmarkMarshalToString(b *testing.B) { + u := NewV4() + for i := 0; i < b.N; i++ { + u.String() + } +} diff --git a/Godeps/_workspace/src/github.com/satori/go.uuid/uuid.go b/Godeps/_workspace/src/github.com/satori/go.uuid/uuid.go new file mode 100644 index 0000000..b4dc4ea --- /dev/null +++ b/Godeps/_workspace/src/github.com/satori/go.uuid/uuid.go @@ -0,0 +1,429 @@ +// Copyright (C) 2013-2015 by Maxim Bublis +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to +// permit persons to whom the Software is furnished to do so, subject to +// the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +// Package uuid provides implementation of Universally Unique Identifier (UUID). +// Supported versions are 1, 3, 4 and 5 (as specified in RFC 4122) and +// version 2 (as specified in DCE 1.1). +package uuid + +import ( + "bytes" + "crypto/md5" + "crypto/rand" + "crypto/sha1" + "encoding/binary" + "encoding/hex" + "fmt" + "hash" + "net" + "os" + "sync" + "time" +) + +// UUID layout variants. +const ( + VariantNCS = iota + VariantRFC4122 + VariantMicrosoft + VariantFuture +) + +// UUID DCE domains. +const ( + DomainPerson = iota + DomainGroup + DomainOrg +) + +// Difference in 100-nanosecond intervals between +// UUID epoch (October 15, 1582) and Unix epoch (January 1, 1970). +const epochStart = 122192928000000000 + +// Used in string method conversion +const dash byte = '-' + +// UUID v1/v2 storage. +var ( + storageMutex sync.Mutex + storageOnce sync.Once + epochFunc = unixTimeFunc + clockSequence uint16 + lastTime uint64 + hardwareAddr [6]byte + posixUID = uint32(os.Getuid()) + posixGID = uint32(os.Getgid()) +) + +// String parse helpers. +var ( + urnPrefix = []byte("urn:uuid:") + byteGroups = []int{8, 4, 4, 4, 12} +) + +func initClockSequence() { + buf := make([]byte, 2) + safeRandom(buf) + clockSequence = binary.BigEndian.Uint16(buf) +} + +func initHardwareAddr() { + interfaces, err := net.Interfaces() + if err == nil { + for _, iface := range interfaces { + if len(iface.HardwareAddr) >= 6 { + copy(hardwareAddr[:], iface.HardwareAddr) + return + } + } + } + + // Initialize hardwareAddr randomly in case + // of real network interfaces absence + safeRandom(hardwareAddr[:]) + + // Set multicast bit as recommended in RFC 4122 + hardwareAddr[0] |= 0x01 +} + +func initStorage() { + initClockSequence() + initHardwareAddr() +} + +func safeRandom(dest []byte) { + if _, err := rand.Read(dest); err != nil { + panic(err) + } +} + +// Returns difference in 100-nanosecond intervals between +// UUID epoch (October 15, 1582) and current time. +// This is default epoch calculation function. +func unixTimeFunc() uint64 { + return epochStart + uint64(time.Now().UnixNano()/100) +} + +// UUID representation compliant with specification +// described in RFC 4122. +type UUID [16]byte + +// The nil UUID is special form of UUID that is specified to have all +// 128 bits set to zero. +var Nil = UUID{} + +// Predefined namespace UUIDs. +var ( + NamespaceDNS, _ = FromString("6ba7b810-9dad-11d1-80b4-00c04fd430c8") + NamespaceURL, _ = FromString("6ba7b811-9dad-11d1-80b4-00c04fd430c8") + NamespaceOID, _ = FromString("6ba7b812-9dad-11d1-80b4-00c04fd430c8") + NamespaceX500, _ = FromString("6ba7b814-9dad-11d1-80b4-00c04fd430c8") +) + +// And returns result of binary AND of two UUIDs. +func And(u1 UUID, u2 UUID) UUID { + u := UUID{} + for i := 0; i < 16; i++ { + u[i] = u1[i] & u2[i] + } + return u +} + +// Or returns result of binary OR of two UUIDs. +func Or(u1 UUID, u2 UUID) UUID { + u := UUID{} + for i := 0; i < 16; i++ { + u[i] = u1[i] | u2[i] + } + return u +} + +// Equal returns true if u1 and u2 equals, otherwise returns false. +func Equal(u1 UUID, u2 UUID) bool { + return bytes.Equal(u1[:], u2[:]) +} + +// Version returns algorithm version used to generate UUID. +func (u UUID) Version() uint { + return uint(u[6] >> 4) +} + +// Variant returns UUID layout variant. +func (u UUID) Variant() uint { + switch { + case (u[8] & 0x80) == 0x00: + return VariantNCS + case (u[8]&0xc0)|0x80 == 0x80: + return VariantRFC4122 + case (u[8]&0xe0)|0xc0 == 0xc0: + return VariantMicrosoft + } + return VariantFuture +} + +// Bytes returns bytes slice representation of UUID. +func (u UUID) Bytes() []byte { + return u[:] +} + +// Returns canonical string representation of UUID: +// xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx. +func (u UUID) String() string { + buf := make([]byte, 36) + + hex.Encode(buf[0:8], u[0:4]) + buf[8] = dash + hex.Encode(buf[9:13], u[4:6]) + buf[13] = dash + hex.Encode(buf[14:18], u[6:8]) + buf[18] = dash + hex.Encode(buf[19:23], u[8:10]) + buf[23] = dash + hex.Encode(buf[24:], u[10:]) + + return string(buf) +} + +// SetVersion sets version bits. +func (u *UUID) SetVersion(v byte) { + u[6] = (u[6] & 0x0f) | (v << 4) +} + +// SetVariant sets variant bits as described in RFC 4122. +func (u *UUID) SetVariant() { + u[8] = (u[8] & 0xbf) | 0x80 +} + +// MarshalText implements the encoding.TextMarshaler interface. +// The encoding is the same as returned by String. +func (u UUID) MarshalText() (text []byte, err error) { + text = []byte(u.String()) + return +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface. +// Following formats are supported: +// "6ba7b810-9dad-11d1-80b4-00c04fd430c8", +// "{6ba7b810-9dad-11d1-80b4-00c04fd430c8}", +// "urn:uuid:6ba7b810-9dad-11d1-80b4-00c04fd430c8" +func (u *UUID) UnmarshalText(text []byte) (err error) { + if len(text) < 32 { + err = fmt.Errorf("uuid: invalid UUID string: %s", text) + return + } + + if bytes.Equal(text[:9], urnPrefix) { + text = text[9:] + } else if text[0] == '{' { + text = text[1:] + } + + b := u[:] + + for _, byteGroup := range byteGroups { + if text[0] == '-' { + text = text[1:] + } + + _, err = hex.Decode(b[:byteGroup/2], text[:byteGroup]) + + if err != nil { + return + } + + text = text[byteGroup:] + b = b[byteGroup/2:] + } + + return +} + +// MarshalBinary implements the encoding.BinaryMarshaler interface. +func (u UUID) MarshalBinary() (data []byte, err error) { + data = u.Bytes() + return +} + +// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface. +// It will return error if the slice isn't 16 bytes long. +func (u *UUID) UnmarshalBinary(data []byte) (err error) { + if len(data) != 16 { + err = fmt.Errorf("uuid: UUID must be exactly 16 bytes long, got %d bytes", len(data)) + return + } + copy(u[:], data) + + return +} + +// Scan implements the sql.Scanner interface. +// A 16-byte slice is handled by UnmarshalBinary, while +// a longer byte slice or a string is handled by UnmarshalText. +func (u *UUID) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + if len(src) == 16 { + return u.UnmarshalBinary(src) + } + return u.UnmarshalText(src) + + case string: + return u.UnmarshalText([]byte(src)) + } + + return fmt.Errorf("uuid: cannot convert %T to UUID", src) +} + +// FromBytes returns UUID converted from raw byte slice input. +// It will return error if the slice isn't 16 bytes long. +func FromBytes(input []byte) (u UUID, err error) { + err = u.UnmarshalBinary(input) + return +} + +// FromBytesOrNil returns UUID converted from raw byte slice input. +// Same behavior as FromBytes, but returns a Nil UUID on error. +func FromBytesOrNil(input []byte) UUID { + uuid, err := FromBytes(input) + if err != nil { + return Nil + } + return uuid +} + +// FromString returns UUID parsed from string input. +// Input is expected in a form accepted by UnmarshalText. +func FromString(input string) (u UUID, err error) { + err = u.UnmarshalText([]byte(input)) + return +} + +// FromStringOrNil returns UUID parsed from string input. +// Same behavior as FromString, but returns a Nil UUID on error. +func FromStringOrNil(input string) UUID { + uuid, err := FromString(input) + if err != nil { + return Nil + } + return uuid +} + +// Returns UUID v1/v2 storage state. +// Returns epoch timestamp, clock sequence, and hardware address. +func getStorage() (uint64, uint16, []byte) { + storageOnce.Do(initStorage) + + storageMutex.Lock() + defer storageMutex.Unlock() + + timeNow := epochFunc() + // Clock changed backwards since last UUID generation. + // Should increase clock sequence. + if timeNow <= lastTime { + clockSequence++ + } + lastTime = timeNow + + return timeNow, clockSequence, hardwareAddr[:] +} + +// NewV1 returns UUID based on current timestamp and MAC address. +func NewV1() UUID { + u := UUID{} + + timeNow, clockSeq, hardwareAddr := getStorage() + + binary.BigEndian.PutUint32(u[0:], uint32(timeNow)) + binary.BigEndian.PutUint16(u[4:], uint16(timeNow>>32)) + binary.BigEndian.PutUint16(u[6:], uint16(timeNow>>48)) + binary.BigEndian.PutUint16(u[8:], clockSeq) + + copy(u[10:], hardwareAddr) + + u.SetVersion(1) + u.SetVariant() + + return u +} + +// NewV2 returns DCE Security UUID based on POSIX UID/GID. +func NewV2(domain byte) UUID { + u := UUID{} + + timeNow, clockSeq, hardwareAddr := getStorage() + + switch domain { + case DomainPerson: + binary.BigEndian.PutUint32(u[0:], posixUID) + case DomainGroup: + binary.BigEndian.PutUint32(u[0:], posixGID) + } + + binary.BigEndian.PutUint16(u[4:], uint16(timeNow>>32)) + binary.BigEndian.PutUint16(u[6:], uint16(timeNow>>48)) + binary.BigEndian.PutUint16(u[8:], clockSeq) + u[9] = domain + + copy(u[10:], hardwareAddr) + + u.SetVersion(2) + u.SetVariant() + + return u +} + +// NewV3 returns UUID based on MD5 hash of namespace UUID and name. +func NewV3(ns UUID, name string) UUID { + u := newFromHash(md5.New(), ns, name) + u.SetVersion(3) + u.SetVariant() + + return u +} + +// NewV4 returns random generated UUID. +func NewV4() UUID { + u := UUID{} + safeRandom(u[:]) + u.SetVersion(4) + u.SetVariant() + + return u +} + +// NewV5 returns UUID based on SHA-1 hash of namespace UUID and name. +func NewV5(ns UUID, name string) UUID { + u := newFromHash(sha1.New(), ns, name) + u.SetVersion(5) + u.SetVariant() + + return u +} + +// Returns UUID based on hashing of namespace UUID and name. +func newFromHash(h hash.Hash, ns UUID, name string) UUID { + u := UUID{} + h.Write(ns[:]) + h.Write([]byte(name)) + copy(u[:], h.Sum(nil)) + + return u +} diff --git a/Godeps/_workspace/src/github.com/satori/go.uuid/uuid_test.go b/Godeps/_workspace/src/github.com/satori/go.uuid/uuid_test.go new file mode 100644 index 0000000..c77d2d3 --- /dev/null +++ b/Godeps/_workspace/src/github.com/satori/go.uuid/uuid_test.go @@ -0,0 +1,492 @@ +// Copyright (C) 2013, 2015 by Maxim Bublis +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to +// permit persons to whom the Software is furnished to do so, subject to +// the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +package uuid + +import ( + "bytes" + "testing" +) + +func TestBytes(t *testing.T) { + u := UUID{0x6b, 0xa7, 0xb8, 0x10, 0x9d, 0xad, 0x11, 0xd1, 0x80, 0xb4, 0x00, 0xc0, 0x4f, 0xd4, 0x30, 0xc8} + + bytes1 := []byte{0x6b, 0xa7, 0xb8, 0x10, 0x9d, 0xad, 0x11, 0xd1, 0x80, 0xb4, 0x00, 0xc0, 0x4f, 0xd4, 0x30, 0xc8} + + if !bytes.Equal(u.Bytes(), bytes1) { + t.Errorf("Incorrect bytes representation for UUID: %s", u) + } +} + +func TestString(t *testing.T) { + if NamespaceDNS.String() != "6ba7b810-9dad-11d1-80b4-00c04fd430c8" { + t.Errorf("Incorrect string representation for UUID: %s", NamespaceDNS.String()) + } +} + +func TestEqual(t *testing.T) { + if !Equal(NamespaceDNS, NamespaceDNS) { + t.Errorf("Incorrect comparison of %s and %s", NamespaceDNS, NamespaceDNS) + } + + if Equal(NamespaceDNS, NamespaceURL) { + t.Errorf("Incorrect comparison of %s and %s", NamespaceDNS, NamespaceURL) + } +} + +func TestOr(t *testing.T) { + u1 := UUID{0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff} + u2 := UUID{0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00} + + u := UUID{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff} + + if !Equal(u, Or(u1, u2)) { + t.Errorf("Incorrect bitwise OR result %s", Or(u1, u2)) + } +} + +func TestAnd(t *testing.T) { + u1 := UUID{0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff} + u2 := UUID{0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00, 0xff, 0x00} + + u := UUID{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} + + if !Equal(u, And(u1, u2)) { + t.Errorf("Incorrect bitwise AND result %s", And(u1, u2)) + } +} + +func TestVersion(t *testing.T) { + u := UUID{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} + + if u.Version() != 1 { + t.Errorf("Incorrect version for UUID: %d", u.Version()) + } +} + +func TestSetVersion(t *testing.T) { + u := UUID{} + u.SetVersion(4) + + if u.Version() != 4 { + t.Errorf("Incorrect version for UUID after u.setVersion(4): %d", u.Version()) + } +} + +func TestVariant(t *testing.T) { + u1 := UUID{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} + + if u1.Variant() != VariantNCS { + t.Errorf("Incorrect variant for UUID variant %d: %d", VariantNCS, u1.Variant()) + } + + u2 := UUID{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} + + if u2.Variant() != VariantRFC4122 { + t.Errorf("Incorrect variant for UUID variant %d: %d", VariantRFC4122, u2.Variant()) + } + + u3 := UUID{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xc0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} + + if u3.Variant() != VariantMicrosoft { + t.Errorf("Incorrect variant for UUID variant %d: %d", VariantMicrosoft, u3.Variant()) + } + + u4 := UUID{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xe0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} + + if u4.Variant() != VariantFuture { + t.Errorf("Incorrect variant for UUID variant %d: %d", VariantFuture, u4.Variant()) + } +} + +func TestSetVariant(t *testing.T) { + u := new(UUID) + u.SetVariant() + + if u.Variant() != VariantRFC4122 { + t.Errorf("Incorrect variant for UUID after u.setVariant(): %d", u.Variant()) + } +} + +func TestFromBytes(t *testing.T) { + u := UUID{0x6b, 0xa7, 0xb8, 0x10, 0x9d, 0xad, 0x11, 0xd1, 0x80, 0xb4, 0x00, 0xc0, 0x4f, 0xd4, 0x30, 0xc8} + b1 := []byte{0x6b, 0xa7, 0xb8, 0x10, 0x9d, 0xad, 0x11, 0xd1, 0x80, 0xb4, 0x00, 0xc0, 0x4f, 0xd4, 0x30, 0xc8} + + u1, err := FromBytes(b1) + if err != nil { + t.Errorf("Error parsing UUID from bytes: %s", err) + } + + if !Equal(u, u1) { + t.Errorf("UUIDs should be equal: %s and %s", u, u1) + } + + b2 := []byte{} + + _, err = FromBytes(b2) + if err == nil { + t.Errorf("Should return error parsing from empty byte slice, got %s", err) + } +} + +func TestMarshalBinary(t *testing.T) { + u := UUID{0x6b, 0xa7, 0xb8, 0x10, 0x9d, 0xad, 0x11, 0xd1, 0x80, 0xb4, 0x00, 0xc0, 0x4f, 0xd4, 0x30, 0xc8} + b1 := []byte{0x6b, 0xa7, 0xb8, 0x10, 0x9d, 0xad, 0x11, 0xd1, 0x80, 0xb4, 0x00, 0xc0, 0x4f, 0xd4, 0x30, 0xc8} + + b2, err := u.MarshalBinary() + if err != nil { + t.Errorf("Error marshaling UUID: %s", err) + } + + if !bytes.Equal(b1, b2) { + t.Errorf("Marshaled UUID should be %s, got %s", b1, b2) + } +} + +func TestUnmarshalBinary(t *testing.T) { + u := UUID{0x6b, 0xa7, 0xb8, 0x10, 0x9d, 0xad, 0x11, 0xd1, 0x80, 0xb4, 0x00, 0xc0, 0x4f, 0xd4, 0x30, 0xc8} + b1 := []byte{0x6b, 0xa7, 0xb8, 0x10, 0x9d, 0xad, 0x11, 0xd1, 0x80, 0xb4, 0x00, 0xc0, 0x4f, 0xd4, 0x30, 0xc8} + + u1 := UUID{} + err := u1.UnmarshalBinary(b1) + if err != nil { + t.Errorf("Error unmarshaling UUID: %s", err) + } + + if !Equal(u, u1) { + t.Errorf("UUIDs should be equal: %s and %s", u, u1) + } + + b2 := []byte{} + u2 := UUID{} + + err = u2.UnmarshalBinary(b2) + if err == nil { + t.Errorf("Should return error unmarshalling from empty byte slice, got %s", err) + } +} + +func TestFromString(t *testing.T) { + u := UUID{0x6b, 0xa7, 0xb8, 0x10, 0x9d, 0xad, 0x11, 0xd1, 0x80, 0xb4, 0x00, 0xc0, 0x4f, 0xd4, 0x30, 0xc8} + + s1 := "6ba7b810-9dad-11d1-80b4-00c04fd430c8" + s2 := "{6ba7b810-9dad-11d1-80b4-00c04fd430c8}" + s3 := "urn:uuid:6ba7b810-9dad-11d1-80b4-00c04fd430c8" + + _, err := FromString("") + if err == nil { + t.Errorf("Should return error trying to parse empty string, got %s", err) + } + + u1, err := FromString(s1) + if err != nil { + t.Errorf("Error parsing UUID from string: %s", err) + } + + if !Equal(u, u1) { + t.Errorf("UUIDs should be equal: %s and %s", u, u1) + } + + u2, err := FromString(s2) + if err != nil { + t.Errorf("Error parsing UUID from string: %s", err) + } + + if !Equal(u, u2) { + t.Errorf("UUIDs should be equal: %s and %s", u, u2) + } + + u3, err := FromString(s3) + if err != nil { + t.Errorf("Error parsing UUID from string: %s", err) + } + + if !Equal(u, u3) { + t.Errorf("UUIDs should be equal: %s and %s", u, u3) + } +} + +func TestFromStringOrNil(t *testing.T) { + u := FromStringOrNil("") + if u != Nil { + t.Errorf("Should return Nil UUID on parse failure, got %s", u) + } +} + +func TestFromBytesOrNil(t *testing.T) { + b := []byte{} + u := FromBytesOrNil(b) + if u != Nil { + t.Errorf("Should return Nil UUID on parse failure, got %s", u) + } +} + +func TestMarshalText(t *testing.T) { + u := UUID{0x6b, 0xa7, 0xb8, 0x10, 0x9d, 0xad, 0x11, 0xd1, 0x80, 0xb4, 0x00, 0xc0, 0x4f, 0xd4, 0x30, 0xc8} + b1 := []byte("6ba7b810-9dad-11d1-80b4-00c04fd430c8") + + b2, err := u.MarshalText() + if err != nil { + t.Errorf("Error marshaling UUID: %s", err) + } + + if !bytes.Equal(b1, b2) { + t.Errorf("Marshaled UUID should be %s, got %s", b1, b2) + } +} + +func TestUnmarshalText(t *testing.T) { + u := UUID{0x6b, 0xa7, 0xb8, 0x10, 0x9d, 0xad, 0x11, 0xd1, 0x80, 0xb4, 0x00, 0xc0, 0x4f, 0xd4, 0x30, 0xc8} + b1 := []byte("6ba7b810-9dad-11d1-80b4-00c04fd430c8") + + u1 := UUID{} + err := u1.UnmarshalText(b1) + if err != nil { + t.Errorf("Error unmarshaling UUID: %s", err) + } + + if !Equal(u, u1) { + t.Errorf("UUIDs should be equal: %s and %s", u, u1) + } + + b2 := []byte("") + u2 := UUID{} + + err = u2.UnmarshalText(b2) + if err == nil { + t.Errorf("Should return error trying to unmarshal from empty string") + } +} + +func TestScanBinary(t *testing.T) { + u := UUID{0x6b, 0xa7, 0xb8, 0x10, 0x9d, 0xad, 0x11, 0xd1, 0x80, 0xb4, 0x00, 0xc0, 0x4f, 0xd4, 0x30, 0xc8} + b1 := []byte{0x6b, 0xa7, 0xb8, 0x10, 0x9d, 0xad, 0x11, 0xd1, 0x80, 0xb4, 0x00, 0xc0, 0x4f, 0xd4, 0x30, 0xc8} + + u1 := UUID{} + err := u1.Scan(b1) + if err != nil { + t.Errorf("Error unmarshaling UUID: %s", err) + } + + if !Equal(u, u1) { + t.Errorf("UUIDs should be equal: %s and %s", u, u1) + } + + b2 := []byte{} + u2 := UUID{} + + err = u2.Scan(b2) + if err == nil { + t.Errorf("Should return error unmarshalling from empty byte slice, got %s", err) + } +} + +func TestScanString(t *testing.T) { + u := UUID{0x6b, 0xa7, 0xb8, 0x10, 0x9d, 0xad, 0x11, 0xd1, 0x80, 0xb4, 0x00, 0xc0, 0x4f, 0xd4, 0x30, 0xc8} + s1 := "6ba7b810-9dad-11d1-80b4-00c04fd430c8" + + u1 := UUID{} + err := u1.Scan(s1) + if err != nil { + t.Errorf("Error unmarshaling UUID: %s", err) + } + + if !Equal(u, u1) { + t.Errorf("UUIDs should be equal: %s and %s", u, u1) + } + + s2 := "" + u2 := UUID{} + + err = u2.Scan(s2) + if err == nil { + t.Errorf("Should return error trying to unmarshal from empty string") + } +} + +func TestScanText(t *testing.T) { + u := UUID{0x6b, 0xa7, 0xb8, 0x10, 0x9d, 0xad, 0x11, 0xd1, 0x80, 0xb4, 0x00, 0xc0, 0x4f, 0xd4, 0x30, 0xc8} + b1 := []byte("6ba7b810-9dad-11d1-80b4-00c04fd430c8") + + u1 := UUID{} + err := u1.Scan(b1) + if err != nil { + t.Errorf("Error unmarshaling UUID: %s", err) + } + + if !Equal(u, u1) { + t.Errorf("UUIDs should be equal: %s and %s", u, u1) + } + + b2 := []byte("") + u2 := UUID{} + + err = u2.Scan(b2) + if err == nil { + t.Errorf("Should return error trying to unmarshal from empty string") + } +} + +func TestScanUnsupported(t *testing.T) { + u := UUID{} + + err := u.Scan(true) + if err == nil { + t.Errorf("Should return error trying to unmarshal from bool") + } +} + +func TestNewV1(t *testing.T) { + u := NewV1() + + if u.Version() != 1 { + t.Errorf("UUIDv1 generated with incorrect version: %d", u.Version()) + } + + if u.Variant() != VariantRFC4122 { + t.Errorf("UUIDv1 generated with incorrect variant: %d", u.Variant()) + } + + u1 := NewV1() + u2 := NewV1() + + if Equal(u1, u2) { + t.Errorf("UUIDv1 generated two equal UUIDs: %s and %s", u1, u2) + } + + oldFunc := epochFunc + epochFunc = func() uint64 { return 0 } + + u3 := NewV1() + u4 := NewV1() + + if Equal(u3, u4) { + t.Errorf("UUIDv1 generated two equal UUIDs: %s and %s", u3, u4) + } + + epochFunc = oldFunc +} + +func TestNewV2(t *testing.T) { + u1 := NewV2(DomainPerson) + + if u1.Version() != 2 { + t.Errorf("UUIDv2 generated with incorrect version: %d", u1.Version()) + } + + if u1.Variant() != VariantRFC4122 { + t.Errorf("UUIDv2 generated with incorrect variant: %d", u1.Variant()) + } + + u2 := NewV2(DomainGroup) + + if u2.Version() != 2 { + t.Errorf("UUIDv2 generated with incorrect version: %d", u2.Version()) + } + + if u2.Variant() != VariantRFC4122 { + t.Errorf("UUIDv2 generated with incorrect variant: %d", u2.Variant()) + } +} + +func TestNewV3(t *testing.T) { + u := NewV3(NamespaceDNS, "www.example.com") + + if u.Version() != 3 { + t.Errorf("UUIDv3 generated with incorrect version: %d", u.Version()) + } + + if u.Variant() != VariantRFC4122 { + t.Errorf("UUIDv3 generated with incorrect variant: %d", u.Variant()) + } + + if u.String() != "5df41881-3aed-3515-88a7-2f4a814cf09e" { + t.Errorf("UUIDv3 generated incorrectly: %s", u.String()) + } + + u = NewV3(NamespaceDNS, "python.org") + + if u.String() != "6fa459ea-ee8a-3ca4-894e-db77e160355e" { + t.Errorf("UUIDv3 generated incorrectly: %s", u.String()) + } + + u1 := NewV3(NamespaceDNS, "golang.org") + u2 := NewV3(NamespaceDNS, "golang.org") + if !Equal(u1, u2) { + t.Errorf("UUIDv3 generated different UUIDs for same namespace and name: %s and %s", u1, u2) + } + + u3 := NewV3(NamespaceDNS, "example.com") + if Equal(u1, u3) { + t.Errorf("UUIDv3 generated same UUIDs for different names in same namespace: %s and %s", u1, u2) + } + + u4 := NewV3(NamespaceURL, "golang.org") + if Equal(u1, u4) { + t.Errorf("UUIDv3 generated same UUIDs for sane names in different namespaces: %s and %s", u1, u4) + } +} + +func TestNewV4(t *testing.T) { + u := NewV4() + + if u.Version() != 4 { + t.Errorf("UUIDv4 generated with incorrect version: %d", u.Version()) + } + + if u.Variant() != VariantRFC4122 { + t.Errorf("UUIDv4 generated with incorrect variant: %d", u.Variant()) + } +} + +func TestNewV5(t *testing.T) { + u := NewV5(NamespaceDNS, "www.example.com") + + if u.Version() != 5 { + t.Errorf("UUIDv5 generated with incorrect version: %d", u.Version()) + } + + if u.Variant() != VariantRFC4122 { + t.Errorf("UUIDv5 generated with incorrect variant: %d", u.Variant()) + } + + u = NewV5(NamespaceDNS, "python.org") + + if u.String() != "886313e1-3b8a-5372-9b90-0c9aee199e5d" { + t.Errorf("UUIDv5 generated incorrectly: %s", u.String()) + } + + u1 := NewV5(NamespaceDNS, "golang.org") + u2 := NewV5(NamespaceDNS, "golang.org") + if !Equal(u1, u2) { + t.Errorf("UUIDv5 generated different UUIDs for same namespace and name: %s and %s", u1, u2) + } + + u3 := NewV5(NamespaceDNS, "example.com") + if Equal(u1, u3) { + t.Errorf("UUIDv5 generated same UUIDs for different names in same namespace: %s and %s", u1, u2) + } + + u4 := NewV5(NamespaceURL, "golang.org") + if Equal(u1, u4) { + t.Errorf("UUIDv3 generated same UUIDs for sane names in different namespaces: %s and %s", u1, u4) + } +} diff --git a/Godeps/_workspace/src/gopkg.in/bsm/ratelimit.v1/.travis.yml b/Godeps/_workspace/src/gopkg.in/bsm/ratelimit.v1/.travis.yml new file mode 100644 index 0000000..14543fd --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/bsm/ratelimit.v1/.travis.yml @@ -0,0 +1,7 @@ +language: go +script: make testall +go: + - 1.4 + - 1.3 + - 1.2 + - tip diff --git a/Godeps/_workspace/src/gopkg.in/bsm/ratelimit.v1/Makefile b/Godeps/_workspace/src/gopkg.in/bsm/ratelimit.v1/Makefile new file mode 100644 index 0000000..fb960c6 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/bsm/ratelimit.v1/Makefile @@ -0,0 +1,13 @@ +default: test + +testdeps: + @go get github.com/onsi/ginkgo + @go get github.com/onsi/gomega + +test: testdeps + @go test ./... + +testrace: testdeps + @go test ./... -race + +testall: test testrace diff --git a/Godeps/_workspace/src/gopkg.in/bsm/ratelimit.v1/README.md b/Godeps/_workspace/src/gopkg.in/bsm/ratelimit.v1/README.md new file mode 100644 index 0000000..538eec8 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/bsm/ratelimit.v1/README.md @@ -0,0 +1,54 @@ +# RateLimit [![Build Status](https://travis-ci.org/bsm/ratelimit.png?branch=master)](https://travis-ci.org/bsm/ratelimit) + +Simple, thread-safe Go rate-limiter. +Inspired by Antti Huima's algorithm on http://stackoverflow.com/a/668327 + +### Example + +```go +package main + +import ( + "github.com/bsm/redeo" + "log" +) + +func main() { + // Create a new rate-limiter, allowing up-to 10 calls + // per second + rl := ratelimit.New(10, time.Second) + + for i:=0; i<20; i++ { + if rl.Limit() { + fmt.Println("DOH! Over limit!") + } else { + fmt.Println("OK") + } + } +} +``` + +### Licence + +``` +Copyright (c) 2015 Black Square Media + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +``` diff --git a/Godeps/_workspace/src/gopkg.in/bsm/ratelimit.v1/ratelimit.go b/Godeps/_workspace/src/gopkg.in/bsm/ratelimit.v1/ratelimit.go new file mode 100644 index 0000000..5808cdb --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/bsm/ratelimit.v1/ratelimit.go @@ -0,0 +1,83 @@ +/* +Simple, thread-safe Go rate-limiter. +Inspired by Antti Huima's algorithm on http://stackoverflow.com/a/668327 + +Example: + + // Create a new rate-limiter, allowing up-to 10 calls + // per second + rl := ratelimit.New(10, time.Second) + + for i:=0; i<20; i++ { + if rl.Limit() { + fmt.Println("DOH! Over limit!") + } else { + fmt.Println("OK") + } + } +*/ +package ratelimit + +import ( + "sync/atomic" + "time" +) + +// RateLimit instances are thread-safe. +type RateLimiter struct { + allowance, max, unit, lastCheck uint64 +} + +// New creates a new rate limiter instance +func New(rate int, per time.Duration) *RateLimiter { + nano := uint64(per) + if nano < 1 { + nano = uint64(time.Second) + } + if rate < 1 { + rate = 1 + } + + return &RateLimiter{ + allowance: uint64(rate) * nano, // store our allowance, in ns units + max: uint64(rate) * nano, // remember our maximum allowance + unit: nano, // remember our unit size + + lastCheck: uint64(time.Now().UnixNano()), + } +} + +// Limit returns true if rate was exceeded +func (rl *RateLimiter) Limit() bool { + // Calculate the number of ns that have passed since our last call + now := uint64(time.Now().UnixNano()) + passed := now - atomic.SwapUint64(&rl.lastCheck, now) + + // Add them to our allowance + current := atomic.AddUint64(&rl.allowance, passed) + + // Ensure our allowance is not over maximum + if current > rl.max { + atomic.AddUint64(&rl.allowance, rl.max-current) + current = rl.max + } + + // If our allowance is less than one unit, rate-limit! + if current < rl.unit { + return true + } + + // Not limited, subtract a unit + atomic.AddUint64(&rl.allowance, -rl.unit) + return false +} + +// Undo reverts the last Limit() call, returning consumed allowance +func (rl *RateLimiter) Undo() { + current := atomic.AddUint64(&rl.allowance, rl.unit) + + // Ensure our allowance is not over maximum + if current > rl.max { + atomic.AddUint64(&rl.allowance, rl.max-current) + } +} diff --git a/Godeps/_workspace/src/gopkg.in/bsm/ratelimit.v1/ratelimit_test.go b/Godeps/_workspace/src/gopkg.in/bsm/ratelimit.v1/ratelimit_test.go new file mode 100644 index 0000000..234e289 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/bsm/ratelimit.v1/ratelimit_test.go @@ -0,0 +1,81 @@ +package ratelimit + +import ( + "sync" + "testing" + "time" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("RateLimiter", func() { + + It("should accurately rate-limit at small rates", func() { + n := 10 + rl := New(n, time.Minute) + for i := 0; i < n; i++ { + Expect(rl.Limit()).To(BeFalse(), "on cycle %d", i) + } + Expect(rl.Limit()).To(BeTrue()) + }) + + It("should accurately rate-limit at large rates", func() { + n := 100000 + rl := New(n, time.Hour) + for i := 0; i < n; i++ { + Expect(rl.Limit()).To(BeFalse(), "on cycle %d", i) + } + Expect(rl.Limit()).To(BeTrue()) + }) + + It("should correctly increase allowance", func() { + n := 25 + rl := New(n, 50*time.Millisecond) + for i := 0; i < n; i++ { + Expect(rl.Limit()).To(BeFalse(), "on cycle %d", i) + } + Expect(rl.Limit()).To(BeTrue()) + Eventually(rl.Limit, "60ms", "10ms").Should(BeFalse()) + }) + + It("should undo", func() { + rl := New(1, time.Minute) + Expect(rl.Limit()).To(BeFalse()) + Expect(rl.Limit()).To(BeTrue()) + Expect(rl.Limit()).To(BeTrue()) + rl.Undo() + rl.Undo() + Expect(rl.Limit()).To(BeFalse()) + Expect(rl.Limit()).To(BeTrue()) + }) + + It("should be thread-safe", func() { + c := 10 + n := 10000 + wg := sync.WaitGroup{} + rl := New(c*n, time.Minute) + for i := 0; i < c; i++ { + wg.Add(1) + + go func(thread int) { + defer GinkgoRecover() + defer wg.Done() + + for j := 0; j < n; j++ { + Expect(rl.Limit()).To(BeFalse(), "thread %d, cycle %d", thread, j) + } + }(i) + } + wg.Wait() + Expect(rl.Limit()).To(BeTrue()) + }) + +}) + +// -------------------------------------------------------------------- + +func TestGinkgoSuite(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "github.com/bsm/ratelimit") +} diff --git a/Godeps/_workspace/src/gopkg.in/ini.v1/README.md b/Godeps/_workspace/src/gopkg.in/ini.v1/README.md index 6d77181..d006003 100644 --- a/Godeps/_workspace/src/gopkg.in/ini.v1/README.md +++ b/Godeps/_workspace/src/gopkg.in/ini.v1/README.md @@ -144,7 +144,7 @@ v = cfg.Section("").Key("TIME").MustTime() // RFC3339 // when key not found or fail to parse value to given type. // Except method MustString, which you have to pass a default value. -v = cfg.Seciont("").Key("String").MustString("default") +v = cfg.Section("").Key("String").MustString("default") v = cfg.Section("").Key("BOOL").MustBool(true) v = cfg.Section("").Key("FLOAT64").MustFloat64(1.25) v = cfg.Section("").Key("INT").MustInt(10) @@ -174,6 +174,32 @@ Earth ------ end --- */ ``` +That's cool, how about continuation lines? + +```ini +[advance] +two_lines = how about \ + continuation lines? +lots_of_lines = 1 \ + 2 \ + 3 \ + 4 +``` + +Piece of cake! + +```go +cfg.Section("advance").Key("two_lines").String() // how about continuation lines? +cfg.Section("advance").Key("lots_of_lines").String() // 1 2 3 4 +``` + +Note that single quotes around values will be stripped: + +```ini +foo = "some value" // foo: some value +bar = 'some value' // bar: some value +``` + That's all? Hmm, no. #### Helper methods of working with values @@ -329,7 +355,7 @@ p := &Person{ #### Name Mapper -To save your time and make your code cleaner, this library supports [`NameMapper`](https://gowalker.org/gopkg.in/ini.v1#NameMapper) between struct field and actual secion and key name. +To save your time and make your code cleaner, this library supports [`NameMapper`](https://gowalker.org/gopkg.in/ini.v1#NameMapper) between struct field and actual section and key name. There are 2 built-in name mappers: @@ -347,7 +373,7 @@ func main() { err = ini.MapToWithMapper(&Info{}, ini.TitleUnderscore, []byte("packag_name=ini")) // ... - cfg, err := ini.Load("PACKAGE_NAME=ini") + cfg, err := ini.Load([]byte("PACKAGE_NAME=ini")) // ... info := new(Info) cfg.NameMapper = ini.AllCapsUnderscore diff --git a/Godeps/_workspace/src/gopkg.in/ini.v1/README_ZH.md b/Godeps/_workspace/src/gopkg.in/ini.v1/README_ZH.md index c455cb6..2ef3c4d 100644 --- a/Godeps/_workspace/src/gopkg.in/ini.v1/README_ZH.md +++ b/Godeps/_workspace/src/gopkg.in/ini.v1/README_ZH.md @@ -169,6 +169,32 @@ Earth ------ end --- */ ``` +赞爆了!那要是我属于一行的内容写不下想要写到第二行怎么办? + +```ini +[advance] +two_lines = how about \ + continuation lines? +lots_of_lines = 1 \ + 2 \ + 3 \ + 4 +``` + +简直是小菜一碟! + +```go +cfg.Section("advance").Key("two_lines").String() // how about continuation lines? +cfg.Section("advance").Key("lots_of_lines").String() // 1 2 3 4 +``` + +需要注意的是,值两侧的单引号会被自动剔除: + +```ini +foo = "some value" // foo: some value +bar = 'some value' // bar: some value +``` + 这就是全部了?哈哈,当然不是。 #### 操作键值的辅助方法 @@ -340,7 +366,7 @@ func main() { err = ini.MapToWithMapper(&Info{}, ini.TitleUnderscore, []byte("packag_name=ini")) // ... - cfg, err := ini.Load("PACKAGE_NAME=ini") + cfg, err := ini.Load([]byte("PACKAGE_NAME=ini")) // ... info := new(Info) cfg.NameMapper = ini.AllCapsUnderscore diff --git a/Godeps/_workspace/src/gopkg.in/ini.v1/ini.go b/Godeps/_workspace/src/gopkg.in/ini.v1/ini.go index 6674baf..4b74eed 100644 --- a/Godeps/_workspace/src/gopkg.in/ini.v1/ini.go +++ b/Godeps/_workspace/src/gopkg.in/ini.v1/ini.go @@ -35,7 +35,7 @@ const ( // Maximum allowed depth when recursively substituing variable names. _DEPTH_VALUES = 99 - _VERSION = "1.2.6" + _VERSION = "1.3.4" ) func Version() string { @@ -179,6 +179,11 @@ func (k *Key) Int64() (int64, error) { return strconv.ParseInt(k.String(), 10, 64) } +// Duration returns time.Duration type value. +func (k *Key) Duration() (time.Duration, error) { + return time.ParseDuration(k.String()) +} + // TimeFormat parses with given format and returns time.Time type value. func (k *Key) TimeFormat(format string) (time.Time, error) { return time.Parse(format, k.String()) @@ -238,6 +243,16 @@ func (k *Key) MustInt64(defaultVal ...int64) int64 { return val } +// MustDuration always returns value without error, +// it returns zero value if error occurs. +func (k *Key) MustDuration(defaultVal ...time.Duration) time.Duration { + val, err := k.Duration() + if len(defaultVal) > 0 && err != nil { + return defaultVal[0] + } + return val +} + // MustTimeFormat always parses with given format and returns value without error, // it returns zero value if error occurs. func (k *Key) MustTimeFormat(format string, defaultVal ...time.Time) time.Time { @@ -483,10 +498,12 @@ func (s *Section) GetKey(name string) (*Key, error) { // FIXME: change to section level lock? if s.f.BlockMode { s.f.lock.RLock() - defer s.f.lock.RUnlock() } - key := s.keys[name] + if s.f.BlockMode { + s.f.lock.RUnlock() + } + if key == nil { // Check if it is a child-section. if i := strings.LastIndex(s.name, "."); i > -1 { @@ -730,6 +747,55 @@ func cutComment(str string) string { return str[:i] } +func checkMultipleLines(buf *bufio.Reader, line, val, valQuote string) (string, error) { + isEnd := false + for { + next, err := buf.ReadString('\n') + if err != nil { + if err != io.EOF { + return "", err + } + isEnd = true + } + pos := strings.LastIndex(next, valQuote) + if pos > -1 { + val += next[:pos] + break + } + val += next + if isEnd { + return "", fmt.Errorf("error parsing line: missing closing key quote from '%s' to '%s'", line, next) + } + } + return val, nil +} + +func checkContinuationLines(buf *bufio.Reader, val string) (string, bool, error) { + isEnd := false + for { + valLen := len(val) + if valLen == 0 || val[valLen-1] != '\\' { + break + } + val = val[:valLen-1] + + next, err := buf.ReadString('\n') + if err != nil { + if err != io.EOF { + return "", isEnd, err + } + isEnd = true + } + + next = strings.TrimSpace(next) + if len(next) == 0 { + break + } + val += next + } + return val, isEnd, nil +} + // parse parses data through an io.Reader. func (f *File) parse(reader io.Reader) error { buf := bufio.NewReader(reader) @@ -781,8 +847,7 @@ func (f *File) parse(reader io.Reader) error { } continue case line[0] == '[' && line[length-1] == ']': // New sction. - name := strings.TrimSpace(line[1 : length-1]) - section, err = f.NewSection(name) + section, err = f.NewSection(strings.TrimSpace(line[1 : length-1])) if err != nil { return err } @@ -856,39 +921,39 @@ func (f *File) parse(reader io.Reader) error { } if firstChar == "`" { valQuote = "`" - } else if lineRightLength >= 6 && lineRight[0:3] == `"""` { - valQuote = `"""` + } else if firstChar == `"` { + if lineRightLength >= 3 && lineRight[0:3] == `"""` { + valQuote = `"""` + } else { + valQuote = `"` + } + } else if firstChar == `'` { + valQuote = `'` } + if len(valQuote) > 0 { qLen := len(valQuote) pos := strings.LastIndex(lineRight[qLen:], valQuote) - // For multiple lines value. + // For multiple-line value check. if pos == -1 { - isEnd := false + if valQuote == `"` || valQuote == `'` { + return fmt.Errorf("error parsing line: single quote does not allow multiple-line value: %s", line) + } + val = lineRight[qLen:] + "\n" - for { - next, err := buf.ReadString('\n') - if err != nil { - if err != io.EOF { - return err - } - isEnd = true - } - pos = strings.LastIndex(next, valQuote) - if pos > -1 { - val += next[:pos] - break - } - val += next - if isEnd { - return fmt.Errorf("error parsing line: missing closing key quote from '%s' to '%s'", line, next) - } + val, err = checkMultipleLines(buf, line, val, valQuote) + if err != nil { + return err } } else { val = lineRight[qLen : pos+qLen] } } else { val = strings.TrimSpace(cutComment(lineRight[0:])) + val, isEnd, err = checkContinuationLines(buf, val) + if err != nil { + return err + } } k, err := section.NewKey(kname, val) @@ -939,8 +1004,8 @@ func (f *File) Append(source interface{}, others ...interface{}) error { return f.Reload() } -// SaveTo writes content to filesystem. -func (f *File) SaveTo(filename string) (err error) { +// WriteTo writes file content into io.Writer. +func (f *File) WriteTo(w io.Writer) (n int64, err error) { equalSign := "=" if PrettyFormat { equalSign = " = " @@ -955,13 +1020,13 @@ func (f *File) SaveTo(filename string) (err error) { sec.Comment = "; " + sec.Comment } if _, err = buf.WriteString(sec.Comment + LineBreak); err != nil { - return err + return 0, err } } if i > 0 { if _, err = buf.WriteString("[" + sname + "]" + LineBreak); err != nil { - return err + return 0, err } } else { // Write nothing if default section is empty. @@ -977,7 +1042,7 @@ func (f *File) SaveTo(filename string) (err error) { key.Comment = "; " + key.Comment } if _, err = buf.WriteString(key.Comment + LineBreak); err != nil { - return err + return 0, err } } @@ -992,26 +1057,32 @@ func (f *File) SaveTo(filename string) (err error) { val := key.value // In case key value contains "\n", "`" or "\"". - if strings.Contains(val, "\n") || strings.Contains(val, "`") || strings.Contains(val, `"`) { + if strings.Contains(val, "\n") || strings.Contains(val, "`") || strings.Contains(val, `"`) || + strings.Contains(val, "#") { val = `"""` + val + `"""` } if _, err = buf.WriteString(kname + equalSign + val + LineBreak); err != nil { - return err + return 0, err } } // Put a line between sections. if _, err = buf.WriteString(LineBreak); err != nil { - return err + return 0, err } } + return buf.WriteTo(w) +} + +// SaveTo writes content to file system. +func (f *File) SaveTo(filename string) error { fw, err := os.Create(filename) if err != nil { return err } - if _, err = buf.WriteTo(fw); err != nil { - return err - } - return fw.Close() + defer fw.Close() + + _, err = f.WriteTo(fw) + return err } diff --git a/Godeps/_workspace/src/gopkg.in/ini.v1/ini_test.go b/Godeps/_workspace/src/gopkg.in/ini.v1/ini_test.go index c6daf81..72a09ee 100644 --- a/Godeps/_workspace/src/gopkg.in/ini.v1/ini_test.go +++ b/Godeps/_workspace/src/gopkg.in/ini.v1/ini_test.go @@ -66,6 +66,7 @@ BOOL_FALSE = false FLOAT64 = 1.25 INT = 10 TIME = 2015-01-01T20:17:05Z +DURATION = 2h45m [array] STRINGS = en, zh, de @@ -74,8 +75,11 @@ INTS = 1, 2, 3 TIMES = 2015-01-01T20:17:05Z,2015-01-01T20:17:05Z,2015-01-01T20:17:05Z [note] +empty_lines = next line is empty\ [advance] +value with quotes = "some value" +value quote2 again = 'some value' true = """"2+3=5"""" "1+1=2" = true """6+1=7""" = true @@ -83,7 +87,15 @@ true = """"2+3=5"""" """"6+6"""" = 12 ` + "`" + `7-2=4` + "`" + ` = false ADDRESS = ` + "`" + `404 road, -NotFound, State, 50000` + "`" +NotFound, State, 50000` + "`" + ` + +two_lines = how about \ + continuation lines? +lots_of_lines = 1 \ + 2 \ + 3 \ + 4 \ +` func Test_Load(t *testing.T) { Convey("Load from data sources", t, func() { @@ -140,6 +152,9 @@ func Test_Load(t *testing.T) { Convey("Load with bad values", func() { _, err := Load([]byte(`name="""Unknwon`)) So(err, ShouldNotBeNil) + + _, err = Load([]byte(`key = "value`)) + So(err, ShouldNotBeNil) }) }) } @@ -232,6 +247,10 @@ func Test_Values(t *testing.T) { So(sec.Key("INT").MustInt64(), ShouldEqual, 10) So(sec.Key("TIME").MustTime().String(), ShouldEqual, t.String()) + dur, err := time.ParseDuration("2h45m") + So(err, ShouldBeNil) + So(sec.Key("DURATION").MustDuration().Seconds(), ShouldEqual, dur.Seconds()) + Convey("Must get values with default value", func() { So(sec.Key("STRING_404").MustString("404"), ShouldEqual, "404") So(sec.Key("BOOL_404").MustBool(true), ShouldBeTrue) @@ -242,6 +261,8 @@ func Test_Values(t *testing.T) { t, err := time.Parse(time.RFC3339, "2014-01-01T20:17:05Z") So(err, ShouldBeNil) So(sec.Key("TIME_404").MustTime(t).String(), ShouldEqual, t.String()) + + So(sec.Key("DURATION_404").MustDuration(dur).Seconds(), ShouldEqual, dur.Seconds()) }) }) }) @@ -331,7 +352,7 @@ func Test_Values(t *testing.T) { }) Convey("Get key strings", func() { - So(strings.Join(cfg.Section("types").KeyStrings(), ","), ShouldEqual, "STRING,BOOL,BOOL_FALSE,FLOAT64,INT,TIME") + So(strings.Join(cfg.Section("types").KeyStrings(), ","), ShouldEqual, "STRING,BOOL,BOOL_FALSE,FLOAT64,INT,TIME,DURATION") }) Convey("Delete a key", func() { @@ -414,6 +435,7 @@ func Test_File_SaveTo(t *testing.T) { cfg.Section("").Key("NAME").Comment = "Package name" cfg.Section("author").Comment = `Information about package author # Bio can be written in multiple lines.` + cfg.Section("advanced").Key("val w/ pound").SetValue("my#password") So(cfg.SaveTo("testdata/conf_out.ini"), ShouldBeNil) }) } diff --git a/Godeps/_workspace/src/gopkg.in/ini.v1/struct.go b/Godeps/_workspace/src/gopkg.in/ini.v1/struct.go index 09ea816..d9e010e 100644 --- a/Godeps/_workspace/src/gopkg.in/ini.v1/struct.go +++ b/Godeps/_workspace/src/gopkg.in/ini.v1/struct.go @@ -78,8 +78,8 @@ var reflectTime = reflect.TypeOf(time.Now()).Kind() // setWithProperType sets proper value to field based on its type, // but it does not return error for failing parsing, // because we want to use default value that is already assigned to strcut. -func setWithProperType(kind reflect.Kind, key *Key, field reflect.Value, delim string) error { - switch kind { +func setWithProperType(t reflect.Type, key *Key, field reflect.Value, delim string) error { + switch t.Kind() { case reflect.String: if len(key.String()) == 0 { return nil @@ -92,6 +92,12 @@ func setWithProperType(kind reflect.Kind, key *Key, field reflect.Value, delim s } field.SetBool(boolVal) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + durationVal, err := key.Duration() + if err == nil { + field.Set(reflect.ValueOf(durationVal)) + return nil + } + intVal, err := key.Int64() if err != nil { return nil @@ -134,7 +140,7 @@ func setWithProperType(kind reflect.Kind, key *Key, field reflect.Value, delim s } field.Set(slice) default: - return fmt.Errorf("unsupported type '%s'", kind) + return fmt.Errorf("unsupported type '%s'", t) } return nil } @@ -177,7 +183,7 @@ func (s *Section) mapTo(val reflect.Value) error { } if key, err := s.GetKey(fieldName); err == nil { - if err = setWithProperType(tpField.Type.Kind(), key, field, parseDelim(tpField.Tag.Get("delim"))); err != nil { + if err = setWithProperType(tpField.Type, key, field, parseDelim(tpField.Tag.Get("delim"))); err != nil { return fmt.Errorf("error mapping field(%s): %v", fieldName, err) } } diff --git a/Godeps/_workspace/src/gopkg.in/ini.v1/struct_test.go b/Godeps/_workspace/src/gopkg.in/ini.v1/struct_test.go index f6fad19..8938b62 100644 --- a/Godeps/_workspace/src/gopkg.in/ini.v1/struct_test.go +++ b/Godeps/_workspace/src/gopkg.in/ini.v1/struct_test.go @@ -39,6 +39,7 @@ type testStruct struct { Male bool Money float64 Born time.Time + Time time.Duration `ini:"Duration"` Others testNested *testEmbeded `ini:"grade"` Unused int `ini:"-"` @@ -50,6 +51,7 @@ Age = 21 Male = true Money = 1.25 Born = 1993-10-07T20:17:05Z +Duration = 2h45m [Others] Cities = HangZhou|Boston @@ -58,6 +60,10 @@ Note = Hello world! [grade] GPA = 2.8 + +[foo.bar] +Here = there +When = then ` type unsupport struct { @@ -87,6 +93,10 @@ type defaultValue struct { Cities []string } +type fooBar struct { + Here, When string +} + const _INVALID_DATA_CONF_STRUCT = ` Name = Age = age @@ -110,12 +120,26 @@ func Test_Struct(t *testing.T) { So(err, ShouldBeNil) So(ts.Born.String(), ShouldEqual, t.String()) + dur, err := time.ParseDuration("2h45m") + So(err, ShouldBeNil) + So(ts.Time.Seconds(), ShouldEqual, dur.Seconds()) + So(strings.Join(ts.Others.Cities, ","), ShouldEqual, "HangZhou,Boston") So(ts.Others.Visits[0].String(), ShouldEqual, t.String()) So(ts.Others.Note, ShouldEqual, "Hello world!") So(ts.testEmbeded.GPA, ShouldEqual, 2.8) }) + Convey("Map section to struct", t, func() { + foobar := new(fooBar) + f, err := Load([]byte(_CONF_DATA_STRUCT)) + So(err, ShouldBeNil) + + So(f.Section("foo.bar").MapTo(foobar), ShouldBeNil) + So(foobar.Here, ShouldEqual, "there") + So(foobar.When, ShouldEqual, "then") + }) + Convey("Map to non-pointer struct", t, func() { cfg, err := Load([]byte(_CONF_DATA_STRUCT)) So(err, ShouldBeNil) diff --git a/Godeps/_workspace/src/gopkg.in/redis.v2/.travis.yml b/Godeps/_workspace/src/gopkg.in/redis.v2/.travis.yml deleted file mode 100644 index c3cf4b8..0000000 --- a/Godeps/_workspace/src/gopkg.in/redis.v2/.travis.yml +++ /dev/null @@ -1,19 +0,0 @@ -language: go - -services: -- redis-server - -go: - - 1.1 - - 1.2 - - 1.3 - - tip - -install: - - go get gopkg.in/bufio.v1 - - go get gopkg.in/check.v1 - - mkdir -p $HOME/gopath/src/gopkg.in - - ln -s `pwd` $HOME/gopath/src/gopkg.in/redis.v2 - -before_script: - - redis-server testdata/sentinel.conf --sentinel & diff --git a/Godeps/_workspace/src/gopkg.in/redis.v2/Makefile b/Godeps/_workspace/src/gopkg.in/redis.v2/Makefile deleted file mode 100644 index b250d9b..0000000 --- a/Godeps/_workspace/src/gopkg.in/redis.v2/Makefile +++ /dev/null @@ -1,3 +0,0 @@ -all: - go test gopkg.in/redis.v2 -cpu=1,2,4 - go test gopkg.in/redis.v2 -short -race diff --git a/Godeps/_workspace/src/gopkg.in/redis.v2/README.md b/Godeps/_workspace/src/gopkg.in/redis.v2/README.md deleted file mode 100644 index ddf875f..0000000 --- a/Godeps/_workspace/src/gopkg.in/redis.v2/README.md +++ /dev/null @@ -1,46 +0,0 @@ -Redis client for Golang [![Build Status](https://travis-ci.org/go-redis/redis.png?branch=master)](https://travis-ci.org/go-redis/redis) -======================= - -Supports: - -- Redis 2.8 commands except QUIT, MONITOR, SLOWLOG and SYNC. -- Pub/sub. -- Transactions. -- Pipelining. -- Connection pool. -- TLS connections. -- Thread safety. -- Timeouts. -- Redis Sentinel. - -API docs: http://godoc.org/gopkg.in/redis.v2. -Examples: http://godoc.org/gopkg.in/redis.v2#pkg-examples. - -Installation ------------- - -Install: - - go get gopkg.in/redis.v2 - -Look and feel -------------- - -Some corner cases: - - SORT list LIMIT 0 2 ASC - vals, err := client.Sort("list", redis.Sort{Offset: 0, Count: 2, Order: "ASC"}).Result() - - ZRANGEBYSCORE zset -inf +inf WITHSCORES LIMIT 0 2 - vals, err := client.ZRangeByScoreWithScores("zset", redis.ZRangeByScore{ - Min: "-inf", - Max: "+inf", - Offset: 0, - Count: 2, - }).Result() - - ZINTERSTORE out 2 zset1 zset2 WEIGHTS 2 3 AGGREGATE SUM - vals, err := client.ZInterStore("out", redis.ZStore{Weights: []int64{2, 3}}, "zset1", "zset2").Result() - - EVAL "return {KEYS[1],ARGV[1]}" 1 "key" "hello" - vals, err := client.Eval("return {KEYS[1],ARGV[1]}", []string{"key"}, []string{"hello"}).Result() diff --git a/Godeps/_workspace/src/gopkg.in/redis.v2/commands.go b/Godeps/_workspace/src/gopkg.in/redis.v2/commands.go deleted file mode 100644 index 6068bab..0000000 --- a/Godeps/_workspace/src/gopkg.in/redis.v2/commands.go +++ /dev/null @@ -1,1246 +0,0 @@ -package redis - -import ( - "io" - "strconv" - "time" -) - -func formatFloat(f float64) string { - return strconv.FormatFloat(f, 'f', -1, 64) -} - -func readTimeout(sec int64) time.Duration { - if sec == 0 { - return 0 - } - return time.Duration(sec+1) * time.Second -} - -//------------------------------------------------------------------------------ - -func (c *Client) Auth(password string) *StatusCmd { - cmd := NewStatusCmd("AUTH", password) - c.Process(cmd) - return cmd -} - -func (c *Client) Echo(message string) *StringCmd { - cmd := NewStringCmd("ECHO", message) - c.Process(cmd) - return cmd -} - -func (c *Client) Ping() *StatusCmd { - cmd := NewStatusCmd("PING") - c.Process(cmd) - return cmd -} - -func (c *Client) Quit() *StatusCmd { - panic("not implemented") -} - -func (c *Client) Select(index int64) *StatusCmd { - cmd := NewStatusCmd("SELECT", strconv.FormatInt(index, 10)) - c.Process(cmd) - return cmd -} - -//------------------------------------------------------------------------------ - -func (c *Client) Del(keys ...string) *IntCmd { - args := append([]string{"DEL"}, keys...) - cmd := NewIntCmd(args...) - c.Process(cmd) - return cmd -} - -func (c *Client) Dump(key string) *StringCmd { - cmd := NewStringCmd("DUMP", key) - c.Process(cmd) - return cmd -} - -func (c *Client) Exists(key string) *BoolCmd { - cmd := NewBoolCmd("EXISTS", key) - c.Process(cmd) - return cmd -} - -func (c *Client) Expire(key string, dur time.Duration) *BoolCmd { - cmd := NewBoolCmd("EXPIRE", key, strconv.FormatInt(int64(dur/time.Second), 10)) - c.Process(cmd) - return cmd -} - -func (c *Client) ExpireAt(key string, tm time.Time) *BoolCmd { - cmd := NewBoolCmd("EXPIREAT", key, strconv.FormatInt(tm.Unix(), 10)) - c.Process(cmd) - return cmd -} - -func (c *Client) Keys(pattern string) *StringSliceCmd { - cmd := NewStringSliceCmd("KEYS", pattern) - c.Process(cmd) - return cmd -} - -func (c *Client) Migrate(host, port, key string, db, timeout int64) *StatusCmd { - cmd := NewStatusCmd( - "MIGRATE", - host, - port, - key, - strconv.FormatInt(db, 10), - strconv.FormatInt(timeout, 10), - ) - cmd.setReadTimeout(readTimeout(timeout)) - c.Process(cmd) - return cmd -} - -func (c *Client) Move(key string, db int64) *BoolCmd { - cmd := NewBoolCmd("MOVE", key, strconv.FormatInt(db, 10)) - c.Process(cmd) - return cmd -} - -func (c *Client) ObjectRefCount(keys ...string) *IntCmd { - args := append([]string{"OBJECT", "REFCOUNT"}, keys...) - cmd := NewIntCmd(args...) - c.Process(cmd) - return cmd -} - -func (c *Client) ObjectEncoding(keys ...string) *StringCmd { - args := append([]string{"OBJECT", "ENCODING"}, keys...) - cmd := NewStringCmd(args...) - c.Process(cmd) - return cmd -} - -func (c *Client) ObjectIdleTime(keys ...string) *DurationCmd { - args := append([]string{"OBJECT", "IDLETIME"}, keys...) - cmd := NewDurationCmd(time.Second, args...) - c.Process(cmd) - return cmd -} - -func (c *Client) Persist(key string) *BoolCmd { - cmd := NewBoolCmd("PERSIST", key) - c.Process(cmd) - return cmd -} - -func (c *Client) PExpire(key string, dur time.Duration) *BoolCmd { - cmd := NewBoolCmd("PEXPIRE", key, strconv.FormatInt(int64(dur/time.Millisecond), 10)) - c.Process(cmd) - return cmd -} - -func (c *Client) PExpireAt(key string, tm time.Time) *BoolCmd { - cmd := NewBoolCmd( - "PEXPIREAT", - key, - strconv.FormatInt(tm.UnixNano()/int64(time.Millisecond), 10), - ) - c.Process(cmd) - return cmd -} - -func (c *Client) PTTL(key string) *DurationCmd { - cmd := NewDurationCmd(time.Millisecond, "PTTL", key) - c.Process(cmd) - return cmd -} - -func (c *Client) RandomKey() *StringCmd { - cmd := NewStringCmd("RANDOMKEY") - c.Process(cmd) - return cmd -} - -func (c *Client) Rename(key, newkey string) *StatusCmd { - cmd := NewStatusCmd("RENAME", key, newkey) - c.Process(cmd) - return cmd -} - -func (c *Client) RenameNX(key, newkey string) *BoolCmd { - cmd := NewBoolCmd("RENAMENX", key, newkey) - c.Process(cmd) - return cmd -} - -func (c *Client) Restore(key string, ttl int64, value string) *StatusCmd { - cmd := NewStatusCmd( - "RESTORE", - key, - strconv.FormatInt(ttl, 10), - value, - ) - c.Process(cmd) - return cmd -} - -type Sort struct { - By string - Offset, Count float64 - Get []string - Order string - IsAlpha bool - Store string -} - -func (c *Client) Sort(key string, sort Sort) *StringSliceCmd { - args := []string{"SORT", key} - if sort.By != "" { - args = append(args, "BY", sort.By) - } - if sort.Offset != 0 || sort.Count != 0 { - args = append(args, "LIMIT", formatFloat(sort.Offset), formatFloat(sort.Count)) - } - for _, get := range sort.Get { - args = append(args, "GET", get) - } - if sort.Order != "" { - args = append(args, sort.Order) - } - if sort.IsAlpha { - args = append(args, "ALPHA") - } - if sort.Store != "" { - args = append(args, "STORE", sort.Store) - } - cmd := NewStringSliceCmd(args...) - c.Process(cmd) - return cmd -} - -func (c *Client) TTL(key string) *DurationCmd { - cmd := NewDurationCmd(time.Second, "TTL", key) - c.Process(cmd) - return cmd -} - -func (c *Client) Type(key string) *StatusCmd { - cmd := NewStatusCmd("TYPE", key) - c.Process(cmd) - return cmd -} - -func (c *Client) Scan(cursor int64, match string, count int64) *ScanCmd { - args := []string{"SCAN", strconv.FormatInt(cursor, 10)} - if match != "" { - args = append(args, "MATCH", match) - } - if count > 0 { - args = append(args, "COUNT", strconv.FormatInt(count, 10)) - } - cmd := NewScanCmd(args...) - c.Process(cmd) - return cmd -} - -func (c *Client) SScan(key string, cursor int64, match string, count int64) *ScanCmd { - args := []string{"SSCAN", key, strconv.FormatInt(cursor, 10)} - if match != "" { - args = append(args, "MATCH", match) - } - if count > 0 { - args = append(args, "COUNT", strconv.FormatInt(count, 10)) - } - cmd := NewScanCmd(args...) - c.Process(cmd) - return cmd -} - -func (c *Client) HScan(key string, cursor int64, match string, count int64) *ScanCmd { - args := []string{"HSCAN", key, strconv.FormatInt(cursor, 10)} - if match != "" { - args = append(args, "MATCH", match) - } - if count > 0 { - args = append(args, "COUNT", strconv.FormatInt(count, 10)) - } - cmd := NewScanCmd(args...) - c.Process(cmd) - return cmd -} - -func (c *Client) ZScan(key string, cursor int64, match string, count int64) *ScanCmd { - args := []string{"ZSCAN", key, strconv.FormatInt(cursor, 10)} - if match != "" { - args = append(args, "MATCH", match) - } - if count > 0 { - args = append(args, "COUNT", strconv.FormatInt(count, 10)) - } - cmd := NewScanCmd(args...) - c.Process(cmd) - return cmd -} - -//------------------------------------------------------------------------------ - -func (c *Client) Append(key, value string) *IntCmd { - cmd := NewIntCmd("APPEND", key, value) - c.Process(cmd) - return cmd -} - -type BitCount struct { - Start, End int64 -} - -func (c *Client) BitCount(key string, bitCount *BitCount) *IntCmd { - args := []string{"BITCOUNT", key} - if bitCount != nil { - args = append( - args, - strconv.FormatInt(bitCount.Start, 10), - strconv.FormatInt(bitCount.End, 10), - ) - } - cmd := NewIntCmd(args...) - c.Process(cmd) - return cmd -} - -func (c *Client) bitOp(op, destKey string, keys ...string) *IntCmd { - args := []string{"BITOP", op, destKey} - args = append(args, keys...) - cmd := NewIntCmd(args...) - c.Process(cmd) - return cmd -} - -func (c *Client) BitOpAnd(destKey string, keys ...string) *IntCmd { - return c.bitOp("AND", destKey, keys...) -} - -func (c *Client) BitOpOr(destKey string, keys ...string) *IntCmd { - return c.bitOp("OR", destKey, keys...) -} - -func (c *Client) BitOpXor(destKey string, keys ...string) *IntCmd { - return c.bitOp("XOR", destKey, keys...) -} - -func (c *Client) BitOpNot(destKey string, key string) *IntCmd { - return c.bitOp("NOT", destKey, key) -} - -func (c *Client) Decr(key string) *IntCmd { - cmd := NewIntCmd("DECR", key) - c.Process(cmd) - return cmd -} - -func (c *Client) DecrBy(key string, decrement int64) *IntCmd { - cmd := NewIntCmd("DECRBY", key, strconv.FormatInt(decrement, 10)) - c.Process(cmd) - return cmd -} - -func (c *Client) Get(key string) *StringCmd { - cmd := NewStringCmd("GET", key) - c.Process(cmd) - return cmd -} - -func (c *Client) GetBit(key string, offset int64) *IntCmd { - cmd := NewIntCmd("GETBIT", key, strconv.FormatInt(offset, 10)) - c.Process(cmd) - return cmd -} - -func (c *Client) GetRange(key string, start, end int64) *StringCmd { - cmd := NewStringCmd( - "GETRANGE", - key, - strconv.FormatInt(start, 10), - strconv.FormatInt(end, 10), - ) - c.Process(cmd) - return cmd -} - -func (c *Client) GetSet(key, value string) *StringCmd { - cmd := NewStringCmd("GETSET", key, value) - c.Process(cmd) - return cmd -} - -func (c *Client) Incr(key string) *IntCmd { - cmd := NewIntCmd("INCR", key) - c.Process(cmd) - return cmd -} - -func (c *Client) IncrBy(key string, value int64) *IntCmd { - cmd := NewIntCmd("INCRBY", key, strconv.FormatInt(value, 10)) - c.Process(cmd) - return cmd -} - -func (c *Client) IncrByFloat(key string, value float64) *FloatCmd { - cmd := NewFloatCmd("INCRBYFLOAT", key, formatFloat(value)) - c.Process(cmd) - return cmd -} - -func (c *Client) MGet(keys ...string) *SliceCmd { - args := append([]string{"MGET"}, keys...) - cmd := NewSliceCmd(args...) - c.Process(cmd) - return cmd -} - -func (c *Client) MSet(pairs ...string) *StatusCmd { - args := append([]string{"MSET"}, pairs...) - cmd := NewStatusCmd(args...) - c.Process(cmd) - return cmd -} - -func (c *Client) MSetNX(pairs ...string) *BoolCmd { - args := append([]string{"MSETNX"}, pairs...) - cmd := NewBoolCmd(args...) - c.Process(cmd) - return cmd -} - -func (c *Client) PSetEx(key string, dur time.Duration, value string) *StatusCmd { - cmd := NewStatusCmd( - "PSETEX", - key, - strconv.FormatInt(int64(dur/time.Millisecond), 10), - value, - ) - c.Process(cmd) - return cmd -} - -func (c *Client) Set(key, value string) *StatusCmd { - cmd := NewStatusCmd("SET", key, value) - c.Process(cmd) - return cmd -} - -func (c *Client) SetBit(key string, offset int64, value int) *IntCmd { - cmd := NewIntCmd( - "SETBIT", - key, - strconv.FormatInt(offset, 10), - strconv.FormatInt(int64(value), 10), - ) - c.Process(cmd) - return cmd -} - -func (c *Client) SetEx(key string, dur time.Duration, value string) *StatusCmd { - cmd := NewStatusCmd("SETEX", key, strconv.FormatInt(int64(dur/time.Second), 10), value) - c.Process(cmd) - return cmd -} - -func (c *Client) SetNX(key, value string) *BoolCmd { - cmd := NewBoolCmd("SETNX", key, value) - c.Process(cmd) - return cmd -} - -func (c *Client) SetRange(key string, offset int64, value string) *IntCmd { - cmd := NewIntCmd("SETRANGE", key, strconv.FormatInt(offset, 10), value) - c.Process(cmd) - return cmd -} - -func (c *Client) StrLen(key string) *IntCmd { - cmd := NewIntCmd("STRLEN", key) - c.Process(cmd) - return cmd -} - -//------------------------------------------------------------------------------ - -func (c *Client) HDel(key string, fields ...string) *IntCmd { - args := append([]string{"HDEL", key}, fields...) - cmd := NewIntCmd(args...) - c.Process(cmd) - return cmd -} - -func (c *Client) HExists(key, field string) *BoolCmd { - cmd := NewBoolCmd("HEXISTS", key, field) - c.Process(cmd) - return cmd -} - -func (c *Client) HGet(key, field string) *StringCmd { - cmd := NewStringCmd("HGET", key, field) - c.Process(cmd) - return cmd -} - -func (c *Client) HGetAll(key string) *StringSliceCmd { - cmd := NewStringSliceCmd("HGETALL", key) - c.Process(cmd) - return cmd -} - -func (c *Client) HGetAllMap(key string) *StringStringMapCmd { - cmd := NewStringStringMapCmd("HGETALL", key) - c.Process(cmd) - return cmd -} - -func (c *Client) HIncrBy(key, field string, incr int64) *IntCmd { - cmd := NewIntCmd("HINCRBY", key, field, strconv.FormatInt(incr, 10)) - c.Process(cmd) - return cmd -} - -func (c *Client) HIncrByFloat(key, field string, incr float64) *FloatCmd { - cmd := NewFloatCmd("HINCRBYFLOAT", key, field, formatFloat(incr)) - c.Process(cmd) - return cmd -} - -func (c *Client) HKeys(key string) *StringSliceCmd { - cmd := NewStringSliceCmd("HKEYS", key) - c.Process(cmd) - return cmd -} - -func (c *Client) HLen(key string) *IntCmd { - cmd := NewIntCmd("HLEN", key) - c.Process(cmd) - return cmd -} - -func (c *Client) HMGet(key string, fields ...string) *SliceCmd { - args := append([]string{"HMGET", key}, fields...) - cmd := NewSliceCmd(args...) - c.Process(cmd) - return cmd -} - -func (c *Client) HMSet(key, field, value string, pairs ...string) *StatusCmd { - args := append([]string{"HMSET", key, field, value}, pairs...) - cmd := NewStatusCmd(args...) - c.Process(cmd) - return cmd -} - -func (c *Client) HSet(key, field, value string) *BoolCmd { - cmd := NewBoolCmd("HSET", key, field, value) - c.Process(cmd) - return cmd -} - -func (c *Client) HSetNX(key, field, value string) *BoolCmd { - cmd := NewBoolCmd("HSETNX", key, field, value) - c.Process(cmd) - return cmd -} - -func (c *Client) HVals(key string) *StringSliceCmd { - cmd := NewStringSliceCmd("HVALS", key) - c.Process(cmd) - return cmd -} - -//------------------------------------------------------------------------------ - -func (c *Client) BLPop(timeout int64, keys ...string) *StringSliceCmd { - args := append([]string{"BLPOP"}, keys...) - args = append(args, strconv.FormatInt(timeout, 10)) - cmd := NewStringSliceCmd(args...) - cmd.setReadTimeout(readTimeout(timeout)) - c.Process(cmd) - return cmd -} - -func (c *Client) BRPop(timeout int64, keys ...string) *StringSliceCmd { - args := append([]string{"BRPOP"}, keys...) - args = append(args, strconv.FormatInt(timeout, 10)) - cmd := NewStringSliceCmd(args...) - cmd.setReadTimeout(readTimeout(timeout)) - c.Process(cmd) - return cmd -} - -func (c *Client) BRPopLPush(source, destination string, timeout int64) *StringCmd { - cmd := NewStringCmd( - "BRPOPLPUSH", - source, - destination, - strconv.FormatInt(timeout, 10), - ) - cmd.setReadTimeout(readTimeout(timeout)) - c.Process(cmd) - return cmd -} - -func (c *Client) LIndex(key string, index int64) *StringCmd { - cmd := NewStringCmd("LINDEX", key, strconv.FormatInt(index, 10)) - c.Process(cmd) - return cmd -} - -func (c *Client) LInsert(key, op, pivot, value string) *IntCmd { - cmd := NewIntCmd("LINSERT", key, op, pivot, value) - c.Process(cmd) - return cmd -} - -func (c *Client) LLen(key string) *IntCmd { - cmd := NewIntCmd("LLEN", key) - c.Process(cmd) - return cmd -} - -func (c *Client) LPop(key string) *StringCmd { - cmd := NewStringCmd("LPOP", key) - c.Process(cmd) - return cmd -} - -func (c *Client) LPush(key string, values ...string) *IntCmd { - args := append([]string{"LPUSH", key}, values...) - cmd := NewIntCmd(args...) - c.Process(cmd) - return cmd -} - -func (c *Client) LPushX(key, value string) *IntCmd { - cmd := NewIntCmd("LPUSHX", key, value) - c.Process(cmd) - return cmd -} - -func (c *Client) LRange(key string, start, stop int64) *StringSliceCmd { - cmd := NewStringSliceCmd( - "LRANGE", - key, - strconv.FormatInt(start, 10), - strconv.FormatInt(stop, 10), - ) - c.Process(cmd) - return cmd -} - -func (c *Client) LRem(key string, count int64, value string) *IntCmd { - cmd := NewIntCmd("LREM", key, strconv.FormatInt(count, 10), value) - c.Process(cmd) - return cmd -} - -func (c *Client) LSet(key string, index int64, value string) *StatusCmd { - cmd := NewStatusCmd("LSET", key, strconv.FormatInt(index, 10), value) - c.Process(cmd) - return cmd -} - -func (c *Client) LTrim(key string, start, stop int64) *StatusCmd { - cmd := NewStatusCmd( - "LTRIM", - key, - strconv.FormatInt(start, 10), - strconv.FormatInt(stop, 10), - ) - c.Process(cmd) - return cmd -} - -func (c *Client) RPop(key string) *StringCmd { - cmd := NewStringCmd("RPOP", key) - c.Process(cmd) - return cmd -} - -func (c *Client) RPopLPush(source, destination string) *StringCmd { - cmd := NewStringCmd("RPOPLPUSH", source, destination) - c.Process(cmd) - return cmd -} - -func (c *Client) RPush(key string, values ...string) *IntCmd { - args := append([]string{"RPUSH", key}, values...) - cmd := NewIntCmd(args...) - c.Process(cmd) - return cmd -} - -func (c *Client) RPushX(key string, value string) *IntCmd { - cmd := NewIntCmd("RPUSHX", key, value) - c.Process(cmd) - return cmd -} - -//------------------------------------------------------------------------------ - -func (c *Client) SAdd(key string, members ...string) *IntCmd { - args := append([]string{"SADD", key}, members...) - cmd := NewIntCmd(args...) - c.Process(cmd) - return cmd -} - -func (c *Client) SCard(key string) *IntCmd { - cmd := NewIntCmd("SCARD", key) - c.Process(cmd) - return cmd -} - -func (c *Client) SDiff(keys ...string) *StringSliceCmd { - args := append([]string{"SDIFF"}, keys...) - cmd := NewStringSliceCmd(args...) - c.Process(cmd) - return cmd -} - -func (c *Client) SDiffStore(destination string, keys ...string) *IntCmd { - args := append([]string{"SDIFFSTORE", destination}, keys...) - cmd := NewIntCmd(args...) - c.Process(cmd) - return cmd -} - -func (c *Client) SInter(keys ...string) *StringSliceCmd { - args := append([]string{"SINTER"}, keys...) - cmd := NewStringSliceCmd(args...) - c.Process(cmd) - return cmd -} - -func (c *Client) SInterStore(destination string, keys ...string) *IntCmd { - args := append([]string{"SINTERSTORE", destination}, keys...) - cmd := NewIntCmd(args...) - c.Process(cmd) - return cmd -} - -func (c *Client) SIsMember(key, member string) *BoolCmd { - cmd := NewBoolCmd("SISMEMBER", key, member) - c.Process(cmd) - return cmd -} - -func (c *Client) SMembers(key string) *StringSliceCmd { - cmd := NewStringSliceCmd("SMEMBERS", key) - c.Process(cmd) - return cmd -} - -func (c *Client) SMove(source, destination, member string) *BoolCmd { - cmd := NewBoolCmd("SMOVE", source, destination, member) - c.Process(cmd) - return cmd -} - -func (c *Client) SPop(key string) *StringCmd { - cmd := NewStringCmd("SPOP", key) - c.Process(cmd) - return cmd -} - -func (c *Client) SRandMember(key string) *StringCmd { - cmd := NewStringCmd("SRANDMEMBER", key) - c.Process(cmd) - return cmd -} - -func (c *Client) SRem(key string, members ...string) *IntCmd { - args := append([]string{"SREM", key}, members...) - cmd := NewIntCmd(args...) - c.Process(cmd) - return cmd -} - -func (c *Client) SUnion(keys ...string) *StringSliceCmd { - args := append([]string{"SUNION"}, keys...) - cmd := NewStringSliceCmd(args...) - c.Process(cmd) - return cmd -} - -func (c *Client) SUnionStore(destination string, keys ...string) *IntCmd { - args := append([]string{"SUNIONSTORE", destination}, keys...) - cmd := NewIntCmd(args...) - c.Process(cmd) - return cmd -} - -//------------------------------------------------------------------------------ - -type Z struct { - Score float64 - Member string -} - -type ZStore struct { - Weights []int64 - Aggregate string -} - -func (c *Client) ZAdd(key string, members ...Z) *IntCmd { - args := []string{"ZADD", key} - for _, m := range members { - args = append(args, formatFloat(m.Score), m.Member) - } - cmd := NewIntCmd(args...) - c.Process(cmd) - return cmd -} - -func (c *Client) ZCard(key string) *IntCmd { - cmd := NewIntCmd("ZCARD", key) - c.Process(cmd) - return cmd -} - -func (c *Client) ZCount(key, min, max string) *IntCmd { - cmd := NewIntCmd("ZCOUNT", key, min, max) - c.Process(cmd) - return cmd -} - -func (c *Client) ZIncrBy(key string, increment float64, member string) *FloatCmd { - cmd := NewFloatCmd("ZINCRBY", key, formatFloat(increment), member) - c.Process(cmd) - return cmd -} - -func (c *Client) ZInterStore( - destination string, - store ZStore, - keys ...string, -) *IntCmd { - args := []string{"ZINTERSTORE", destination, strconv.FormatInt(int64(len(keys)), 10)} - args = append(args, keys...) - if len(store.Weights) > 0 { - args = append(args, "WEIGHTS") - for _, weight := range store.Weights { - args = append(args, strconv.FormatInt(weight, 10)) - } - } - if store.Aggregate != "" { - args = append(args, "AGGREGATE", store.Aggregate) - } - cmd := NewIntCmd(args...) - c.Process(cmd) - return cmd -} - -func (c *Client) zRange(key string, start, stop int64, withScores bool) *StringSliceCmd { - args := []string{ - "ZRANGE", - key, - strconv.FormatInt(start, 10), - strconv.FormatInt(stop, 10), - } - if withScores { - args = append(args, "WITHSCORES") - } - cmd := NewStringSliceCmd(args...) - c.Process(cmd) - return cmd -} - -func (c *Client) ZRange(key string, start, stop int64) *StringSliceCmd { - return c.zRange(key, start, stop, false) -} - -func (c *Client) ZRangeWithScores(key string, start, stop int64) *ZSliceCmd { - args := []string{ - "ZRANGE", - key, - strconv.FormatInt(start, 10), - strconv.FormatInt(stop, 10), - "WITHSCORES", - } - cmd := NewZSliceCmd(args...) - c.Process(cmd) - return cmd -} - -type ZRangeByScore struct { - Min, Max string - - Offset, Count int64 -} - -func (c *Client) zRangeByScore(key string, opt ZRangeByScore, withScores bool) *StringSliceCmd { - args := []string{"ZRANGEBYSCORE", key, opt.Min, opt.Max} - if withScores { - args = append(args, "WITHSCORES") - } - if opt.Offset != 0 || opt.Count != 0 { - args = append( - args, - "LIMIT", - strconv.FormatInt(opt.Offset, 10), - strconv.FormatInt(opt.Count, 10), - ) - } - cmd := NewStringSliceCmd(args...) - c.Process(cmd) - return cmd -} - -func (c *Client) ZRangeByScore(key string, opt ZRangeByScore) *StringSliceCmd { - return c.zRangeByScore(key, opt, false) -} - -func (c *Client) ZRangeByScoreWithScores(key string, opt ZRangeByScore) *ZSliceCmd { - args := []string{"ZRANGEBYSCORE", key, opt.Min, opt.Max, "WITHSCORES"} - if opt.Offset != 0 || opt.Count != 0 { - args = append( - args, - "LIMIT", - strconv.FormatInt(opt.Offset, 10), - strconv.FormatInt(opt.Count, 10), - ) - } - cmd := NewZSliceCmd(args...) - c.Process(cmd) - return cmd -} - -func (c *Client) ZRank(key, member string) *IntCmd { - cmd := NewIntCmd("ZRANK", key, member) - c.Process(cmd) - return cmd -} - -func (c *Client) ZRem(key string, members ...string) *IntCmd { - args := append([]string{"ZREM", key}, members...) - cmd := NewIntCmd(args...) - c.Process(cmd) - return cmd -} - -func (c *Client) ZRemRangeByRank(key string, start, stop int64) *IntCmd { - cmd := NewIntCmd( - "ZREMRANGEBYRANK", - key, - strconv.FormatInt(start, 10), - strconv.FormatInt(stop, 10), - ) - c.Process(cmd) - return cmd -} - -func (c *Client) ZRemRangeByScore(key, min, max string) *IntCmd { - cmd := NewIntCmd("ZREMRANGEBYSCORE", key, min, max) - c.Process(cmd) - return cmd -} - -func (c *Client) zRevRange(key, start, stop string, withScores bool) *StringSliceCmd { - args := []string{"ZREVRANGE", key, start, stop} - if withScores { - args = append(args, "WITHSCORES") - } - cmd := NewStringSliceCmd(args...) - c.Process(cmd) - return cmd -} - -func (c *Client) ZRevRange(key, start, stop string) *StringSliceCmd { - return c.zRevRange(key, start, stop, false) -} - -func (c *Client) ZRevRangeWithScores(key, start, stop string) *ZSliceCmd { - args := []string{"ZREVRANGE", key, start, stop, "WITHSCORES"} - cmd := NewZSliceCmd(args...) - c.Process(cmd) - return cmd -} - -func (c *Client) zRevRangeByScore(key string, opt ZRangeByScore, withScores bool) *StringSliceCmd { - args := []string{"ZREVRANGEBYSCORE", key, opt.Max, opt.Min} - if withScores { - args = append(args, "WITHSCORES") - } - if opt.Offset != 0 || opt.Count != 0 { - args = append( - args, - "LIMIT", - strconv.FormatInt(opt.Offset, 10), - strconv.FormatInt(opt.Count, 10), - ) - } - cmd := NewStringSliceCmd(args...) - c.Process(cmd) - return cmd -} - -func (c *Client) ZRevRangeByScore(key string, opt ZRangeByScore) *StringSliceCmd { - return c.zRevRangeByScore(key, opt, false) -} - -func (c *Client) ZRevRangeByScoreWithScores(key string, opt ZRangeByScore) *ZSliceCmd { - args := []string{"ZREVRANGEBYSCORE", key, opt.Max, opt.Min, "WITHSCORES"} - if opt.Offset != 0 || opt.Count != 0 { - args = append( - args, - "LIMIT", - strconv.FormatInt(opt.Offset, 10), - strconv.FormatInt(opt.Count, 10), - ) - } - cmd := NewZSliceCmd(args...) - c.Process(cmd) - return cmd -} - -func (c *Client) ZRevRank(key, member string) *IntCmd { - cmd := NewIntCmd("ZREVRANK", key, member) - c.Process(cmd) - return cmd -} - -func (c *Client) ZScore(key, member string) *FloatCmd { - cmd := NewFloatCmd("ZSCORE", key, member) - c.Process(cmd) - return cmd -} - -func (c *Client) ZUnionStore( - destination string, - store ZStore, - keys ...string, -) *IntCmd { - args := []string{"ZUNIONSTORE", destination, strconv.FormatInt(int64(len(keys)), 10)} - args = append(args, keys...) - if len(store.Weights) > 0 { - args = append(args, "WEIGHTS") - for _, weight := range store.Weights { - args = append(args, strconv.FormatInt(weight, 10)) - } - } - if store.Aggregate != "" { - args = append(args, "AGGREGATE", store.Aggregate) - } - cmd := NewIntCmd(args...) - c.Process(cmd) - return cmd -} - -//------------------------------------------------------------------------------ - -func (c *Client) BgRewriteAOF() *StatusCmd { - cmd := NewStatusCmd("BGREWRITEAOF") - c.Process(cmd) - return cmd -} - -func (c *Client) BgSave() *StatusCmd { - cmd := NewStatusCmd("BGSAVE") - c.Process(cmd) - return cmd -} - -func (c *Client) ClientKill(ipPort string) *StatusCmd { - cmd := NewStatusCmd("CLIENT", "KILL", ipPort) - c.Process(cmd) - return cmd -} - -func (c *Client) ClientList() *StringCmd { - cmd := NewStringCmd("CLIENT", "LIST") - c.Process(cmd) - return cmd -} - -func (c *Client) ConfigGet(parameter string) *SliceCmd { - cmd := NewSliceCmd("CONFIG", "GET", parameter) - c.Process(cmd) - return cmd -} - -func (c *Client) ConfigResetStat() *StatusCmd { - cmd := NewStatusCmd("CONFIG", "RESETSTAT") - c.Process(cmd) - return cmd -} - -func (c *Client) ConfigSet(parameter, value string) *StatusCmd { - cmd := NewStatusCmd("CONFIG", "SET", parameter, value) - c.Process(cmd) - return cmd -} - -func (c *Client) DbSize() *IntCmd { - cmd := NewIntCmd("DBSIZE") - c.Process(cmd) - return cmd -} - -func (c *Client) FlushAll() *StatusCmd { - cmd := NewStatusCmd("FLUSHALL") - c.Process(cmd) - return cmd -} - -func (c *Client) FlushDb() *StatusCmd { - cmd := NewStatusCmd("FLUSHDB") - c.Process(cmd) - return cmd -} - -func (c *Client) Info() *StringCmd { - cmd := NewStringCmd("INFO") - c.Process(cmd) - return cmd -} - -func (c *Client) LastSave() *IntCmd { - cmd := NewIntCmd("LASTSAVE") - c.Process(cmd) - return cmd -} - -func (c *Client) Save() *StatusCmd { - cmd := NewStatusCmd("SAVE") - c.Process(cmd) - return cmd -} - -func (c *Client) shutdown(modifier string) *StatusCmd { - var args []string - if modifier == "" { - args = []string{"SHUTDOWN"} - } else { - args = []string{"SHUTDOWN", modifier} - } - cmd := NewStatusCmd(args...) - c.Process(cmd) - if err := cmd.Err(); err != nil { - if err == io.EOF { - // Server quit as expected. - cmd.err = nil - } - } else { - // Server did not quit. String reply contains the reason. - cmd.err = errorf(cmd.val) - cmd.val = "" - } - return cmd -} - -func (c *Client) Shutdown() *StatusCmd { - return c.shutdown("") -} - -func (c *Client) ShutdownSave() *StatusCmd { - return c.shutdown("SAVE") -} - -func (c *Client) ShutdownNoSave() *StatusCmd { - return c.shutdown("NOSAVE") -} - -func (c *Client) SlaveOf(host, port string) *StatusCmd { - cmd := NewStatusCmd("SLAVEOF", host, port) - c.Process(cmd) - return cmd -} - -func (c *Client) SlowLog() { - panic("not implemented") -} - -func (c *Client) Sync() { - panic("not implemented") -} - -func (c *Client) Time() *StringSliceCmd { - cmd := NewStringSliceCmd("TIME") - c.Process(cmd) - return cmd -} - -//------------------------------------------------------------------------------ - -func (c *Client) Eval(script string, keys []string, args []string) *Cmd { - cmdArgs := []string{"EVAL", script, strconv.FormatInt(int64(len(keys)), 10)} - cmdArgs = append(cmdArgs, keys...) - cmdArgs = append(cmdArgs, args...) - cmd := NewCmd(cmdArgs...) - c.Process(cmd) - return cmd -} - -func (c *Client) EvalSha(sha1 string, keys []string, args []string) *Cmd { - cmdArgs := []string{"EVALSHA", sha1, strconv.FormatInt(int64(len(keys)), 10)} - cmdArgs = append(cmdArgs, keys...) - cmdArgs = append(cmdArgs, args...) - cmd := NewCmd(cmdArgs...) - c.Process(cmd) - return cmd -} - -func (c *Client) ScriptExists(scripts ...string) *BoolSliceCmd { - args := append([]string{"SCRIPT", "EXISTS"}, scripts...) - cmd := NewBoolSliceCmd(args...) - c.Process(cmd) - return cmd -} - -func (c *Client) ScriptFlush() *StatusCmd { - cmd := NewStatusCmd("SCRIPT", "FLUSH") - c.Process(cmd) - return cmd -} - -func (c *Client) ScriptKill() *StatusCmd { - cmd := NewStatusCmd("SCRIPT", "KILL") - c.Process(cmd) - return cmd -} - -func (c *Client) ScriptLoad(script string) *StringCmd { - cmd := NewStringCmd("SCRIPT", "LOAD", script) - c.Process(cmd) - return cmd -} - -//------------------------------------------------------------------------------ - -func (c *Client) DebugObject(key string) *StringCmd { - cmd := NewStringCmd("DEBUG", "OBJECT", key) - c.Process(cmd) - return cmd -} - -//------------------------------------------------------------------------------ - -func (c *Client) PubSubChannels(pattern string) *StringSliceCmd { - args := []string{"PUBSUB", "CHANNELS"} - if pattern != "*" { - args = append(args, pattern) - } - cmd := NewStringSliceCmd(args...) - c.Process(cmd) - return cmd -} - -func (c *Client) PubSubNumSub(channels ...string) *SliceCmd { - args := []string{"PUBSUB", "NUMSUB"} - args = append(args, channels...) - cmd := NewSliceCmd(args...) - c.Process(cmd) - return cmd -} - -func (c *Client) PubSubNumPat() *IntCmd { - cmd := NewIntCmd("PUBSUB", "NUMPAT") - c.Process(cmd) - return cmd -} diff --git a/Godeps/_workspace/src/gopkg.in/redis.v2/error.go b/Godeps/_workspace/src/gopkg.in/redis.v2/error.go deleted file mode 100644 index 667fffd..0000000 --- a/Godeps/_workspace/src/gopkg.in/redis.v2/error.go +++ /dev/null @@ -1,23 +0,0 @@ -package redis - -import ( - "fmt" -) - -// Redis nil reply. -var Nil = errorf("redis: nil") - -// Redis transaction failed. -var TxFailedErr = errorf("redis: transaction failed") - -type redisError struct { - s string -} - -func errorf(s string, args ...interface{}) redisError { - return redisError{s: fmt.Sprintf(s, args...)} -} - -func (err redisError) Error() string { - return err.s -} diff --git a/Godeps/_workspace/src/gopkg.in/redis.v2/example_test.go b/Godeps/_workspace/src/gopkg.in/redis.v2/example_test.go deleted file mode 100644 index dbc9513..0000000 --- a/Godeps/_workspace/src/gopkg.in/redis.v2/example_test.go +++ /dev/null @@ -1,180 +0,0 @@ -package redis_test - -import ( - "fmt" - "strconv" - - "gopkg.in/redis.v2" -) - -var client *redis.Client - -func init() { - client = redis.NewTCPClient(&redis.Options{ - Addr: ":6379", - }) - client.FlushDb() -} - -func ExampleNewTCPClient() { - client := redis.NewTCPClient(&redis.Options{ - Addr: "localhost:6379", - Password: "", // no password set - DB: 0, // use default DB - }) - - pong, err := client.Ping().Result() - fmt.Println(pong, err) - // Output: PONG -} - -func ExampleNewFailoverClient() { - client := redis.NewFailoverClient(&redis.FailoverOptions{ - MasterName: "master", - SentinelAddrs: []string{":26379"}, - }) - - pong, err := client.Ping().Result() - fmt.Println(pong, err) - // Output: PONG -} - -func ExampleClient() { - if err := client.Set("foo", "bar").Err(); err != nil { - panic(err) - } - - v, err := client.Get("hello").Result() - fmt.Printf("%q %q %v", v, err, err == redis.Nil) - // Output: "" "redis: nil" true -} - -func ExampleClient_Incr() { - if err := client.Incr("counter").Err(); err != nil { - panic(err) - } - - n, err := client.Get("counter").Int64() - fmt.Println(n, err) - // Output: 1 -} - -func ExampleClient_Pipelined() { - cmds, err := client.Pipelined(func(c *redis.Pipeline) error { - c.Set("key1", "hello1") - c.Get("key1") - return nil - }) - fmt.Println(err) - set := cmds[0].(*redis.StatusCmd) - fmt.Println(set) - get := cmds[1].(*redis.StringCmd) - fmt.Println(get) - // Output: - // SET key1 hello1: OK - // GET key1: hello1 -} - -func ExamplePipeline() { - pipeline := client.Pipeline() - set := pipeline.Set("key1", "hello1") - get := pipeline.Get("key1") - cmds, err := pipeline.Exec() - fmt.Println(cmds, err) - fmt.Println(set) - fmt.Println(get) - // Output: [SET key1 hello1: OK GET key1: hello1] - // SET key1 hello1: OK - // GET key1: hello1 -} - -func ExampleMulti() { - incr := func(tx *redis.Multi) ([]redis.Cmder, error) { - s, err := tx.Get("key").Result() - if err != nil && err != redis.Nil { - return nil, err - } - n, _ := strconv.ParseInt(s, 10, 64) - - return tx.Exec(func() error { - tx.Set("key", strconv.FormatInt(n+1, 10)) - return nil - }) - } - - client.Del("key") - - tx := client.Multi() - defer tx.Close() - - watch := tx.Watch("key") - _ = watch.Err() - - for { - cmds, err := incr(tx) - if err == redis.TxFailedErr { - continue - } else if err != nil { - panic(err) - } - fmt.Println(cmds, err) - break - } - - // Output: [SET key 1: OK] -} - -func ExamplePubSub() { - pubsub := client.PubSub() - defer pubsub.Close() - - err := pubsub.Subscribe("mychannel") - _ = err - - msg, err := pubsub.Receive() - fmt.Println(msg, err) - - pub := client.Publish("mychannel", "hello") - _ = pub.Err() - - msg, err = pubsub.Receive() - fmt.Println(msg, err) - - // Output: subscribe: mychannel - // Message -} - -func ExampleScript() { - setnx := redis.NewScript(` - if redis.call("get", KEYS[1]) == false then - redis.call("set", KEYS[1], ARGV[1]) - return 1 - end - return 0 - `) - - v1, err := setnx.Run(client, []string{"keynx"}, []string{"foo"}).Result() - fmt.Println(v1.(int64), err) - - v2, err := setnx.Run(client, []string{"keynx"}, []string{"bar"}).Result() - fmt.Println(v2.(int64), err) - - get := client.Get("keynx") - fmt.Println(get) - - // Output: 1 - // 0 - // GET keynx: foo -} - -func Example_customCommand() { - Get := func(client *redis.Client, key string) *redis.StringCmd { - cmd := redis.NewStringCmd("GET", key) - client.Process(cmd) - return cmd - } - - v, err := Get(client, "key_does_not_exist").Result() - fmt.Printf("%q %s", v, err) - // Output: "" redis: nil -} diff --git a/Godeps/_workspace/src/gopkg.in/redis.v2/export_test.go b/Godeps/_workspace/src/gopkg.in/redis.v2/export_test.go deleted file mode 100644 index 7f7fa67..0000000 --- a/Godeps/_workspace/src/gopkg.in/redis.v2/export_test.go +++ /dev/null @@ -1,5 +0,0 @@ -package redis - -func (c *baseClient) Pool() pool { - return c.connPool -} diff --git a/Godeps/_workspace/src/gopkg.in/redis.v2/parser.go b/Godeps/_workspace/src/gopkg.in/redis.v2/parser.go deleted file mode 100644 index b4c380c..0000000 --- a/Godeps/_workspace/src/gopkg.in/redis.v2/parser.go +++ /dev/null @@ -1,262 +0,0 @@ -package redis - -import ( - "errors" - "fmt" - "strconv" - - "gopkg.in/bufio.v1" -) - -type multiBulkParser func(rd *bufio.Reader, n int64) (interface{}, error) - -var ( - errReaderTooSmall = errors.New("redis: reader is too small") -) - -//------------------------------------------------------------------------------ - -func appendArgs(buf []byte, args []string) []byte { - buf = append(buf, '*') - buf = strconv.AppendUint(buf, uint64(len(args)), 10) - buf = append(buf, '\r', '\n') - for _, arg := range args { - buf = append(buf, '$') - buf = strconv.AppendUint(buf, uint64(len(arg)), 10) - buf = append(buf, '\r', '\n') - buf = append(buf, arg...) - buf = append(buf, '\r', '\n') - } - return buf -} - -//------------------------------------------------------------------------------ - -func readLine(rd *bufio.Reader) ([]byte, error) { - line, isPrefix, err := rd.ReadLine() - if err != nil { - return line, err - } - if isPrefix { - return line, errReaderTooSmall - } - return line, nil -} - -func readN(rd *bufio.Reader, n int) ([]byte, error) { - b, err := rd.ReadN(n) - if err == bufio.ErrBufferFull { - tmp := make([]byte, n) - r := copy(tmp, b) - b = tmp - - for { - nn, err := rd.Read(b[r:]) - r += nn - if r >= n { - // Ignore error if we read enough. - break - } - if err != nil { - return nil, err - } - } - } else if err != nil { - return nil, err - } - return b, nil -} - -//------------------------------------------------------------------------------ - -func parseReq(rd *bufio.Reader) ([]string, error) { - line, err := readLine(rd) - if err != nil { - return nil, err - } - - if line[0] != '*' { - return []string{string(line)}, nil - } - numReplies, err := strconv.ParseInt(string(line[1:]), 10, 64) - if err != nil { - return nil, err - } - - args := make([]string, 0, numReplies) - for i := int64(0); i < numReplies; i++ { - line, err = readLine(rd) - if err != nil { - return nil, err - } - if line[0] != '$' { - return nil, fmt.Errorf("redis: expected '$', but got %q", line) - } - - argLen, err := strconv.ParseInt(string(line[1:]), 10, 32) - if err != nil { - return nil, err - } - - arg, err := readN(rd, int(argLen)+2) - if err != nil { - return nil, err - } - args = append(args, string(arg[:argLen])) - } - return args, nil -} - -//------------------------------------------------------------------------------ - -func parseReply(rd *bufio.Reader, p multiBulkParser) (interface{}, error) { - line, err := readLine(rd) - if err != nil { - return nil, err - } - - switch line[0] { - case '-': - return nil, errorf(string(line[1:])) - case '+': - return string(line[1:]), nil - case ':': - v, err := strconv.ParseInt(string(line[1:]), 10, 64) - if err != nil { - return nil, err - } - return v, nil - case '$': - if len(line) == 3 && line[1] == '-' && line[2] == '1' { - return nil, Nil - } - - replyLen, err := strconv.Atoi(string(line[1:])) - if err != nil { - return nil, err - } - - b, err := readN(rd, replyLen+2) - if err != nil { - return nil, err - } - return string(b[:replyLen]), nil - case '*': - if len(line) == 3 && line[1] == '-' && line[2] == '1' { - return nil, Nil - } - - repliesNum, err := strconv.ParseInt(string(line[1:]), 10, 64) - if err != nil { - return nil, err - } - - return p(rd, repliesNum) - } - return nil, fmt.Errorf("redis: can't parse %q", line) -} - -func parseSlice(rd *bufio.Reader, n int64) (interface{}, error) { - vals := make([]interface{}, 0, n) - for i := int64(0); i < n; i++ { - v, err := parseReply(rd, parseSlice) - if err == Nil { - vals = append(vals, nil) - } else if err != nil { - return nil, err - } else { - vals = append(vals, v) - } - } - return vals, nil -} - -func parseStringSlice(rd *bufio.Reader, n int64) (interface{}, error) { - vals := make([]string, 0, n) - for i := int64(0); i < n; i++ { - viface, err := parseReply(rd, nil) - if err != nil { - return nil, err - } - v, ok := viface.(string) - if !ok { - return nil, fmt.Errorf("got %T, expected string", viface) - } - vals = append(vals, v) - } - return vals, nil -} - -func parseBoolSlice(rd *bufio.Reader, n int64) (interface{}, error) { - vals := make([]bool, 0, n) - for i := int64(0); i < n; i++ { - viface, err := parseReply(rd, nil) - if err != nil { - return nil, err - } - v, ok := viface.(int64) - if !ok { - return nil, fmt.Errorf("got %T, expected int64", viface) - } - vals = append(vals, v == 1) - } - return vals, nil -} - -func parseStringStringMap(rd *bufio.Reader, n int64) (interface{}, error) { - m := make(map[string]string, n/2) - for i := int64(0); i < n; i += 2 { - keyiface, err := parseReply(rd, nil) - if err != nil { - return nil, err - } - key, ok := keyiface.(string) - if !ok { - return nil, fmt.Errorf("got %T, expected string", keyiface) - } - - valueiface, err := parseReply(rd, nil) - if err != nil { - return nil, err - } - value, ok := valueiface.(string) - if !ok { - return nil, fmt.Errorf("got %T, expected string", valueiface) - } - - m[key] = value - } - return m, nil -} - -func parseZSlice(rd *bufio.Reader, n int64) (interface{}, error) { - zz := make([]Z, n/2) - for i := int64(0); i < n; i += 2 { - z := &zz[i/2] - - memberiface, err := parseReply(rd, nil) - if err != nil { - return nil, err - } - member, ok := memberiface.(string) - if !ok { - return nil, fmt.Errorf("got %T, expected string", memberiface) - } - z.Member = member - - scoreiface, err := parseReply(rd, nil) - if err != nil { - return nil, err - } - scorestr, ok := scoreiface.(string) - if !ok { - return nil, fmt.Errorf("got %T, expected string", scoreiface) - } - score, err := strconv.ParseFloat(scorestr, 64) - if err != nil { - return nil, err - } - z.Score = score - } - return zz, nil -} diff --git a/Godeps/_workspace/src/gopkg.in/redis.v2/pipeline.go b/Godeps/_workspace/src/gopkg.in/redis.v2/pipeline.go deleted file mode 100644 index 540d6c5..0000000 --- a/Godeps/_workspace/src/gopkg.in/redis.v2/pipeline.go +++ /dev/null @@ -1,91 +0,0 @@ -package redis - -// Not thread-safe. -type Pipeline struct { - *Client - - closed bool -} - -func (c *Client) Pipeline() *Pipeline { - return &Pipeline{ - Client: &Client{ - baseClient: &baseClient{ - opt: c.opt, - connPool: c.connPool, - - cmds: make([]Cmder, 0), - }, - }, - } -} - -func (c *Client) Pipelined(f func(*Pipeline) error) ([]Cmder, error) { - pc := c.Pipeline() - if err := f(pc); err != nil { - return nil, err - } - cmds, err := pc.Exec() - pc.Close() - return cmds, err -} - -func (c *Pipeline) Close() error { - c.closed = true - return nil -} - -func (c *Pipeline) Discard() error { - if c.closed { - return errClosed - } - c.cmds = c.cmds[:0] - return nil -} - -// Exec always returns list of commands and error of the first failed -// command if any. -func (c *Pipeline) Exec() ([]Cmder, error) { - if c.closed { - return nil, errClosed - } - - cmds := c.cmds - c.cmds = make([]Cmder, 0) - - if len(cmds) == 0 { - return []Cmder{}, nil - } - - cn, err := c.conn() - if err != nil { - setCmdsErr(cmds, err) - return cmds, err - } - - if err := c.execCmds(cn, cmds); err != nil { - c.freeConn(cn, err) - return cmds, err - } - - c.putConn(cn) - return cmds, nil -} - -func (c *Pipeline) execCmds(cn *conn, cmds []Cmder) error { - if err := c.writeCmd(cn, cmds...); err != nil { - setCmdsErr(cmds, err) - return err - } - - var firstCmdErr error - for _, cmd := range cmds { - if err := cmd.parseReply(cn.rd); err != nil { - if firstCmdErr == nil { - firstCmdErr = err - } - } - } - - return firstCmdErr -} diff --git a/Godeps/_workspace/src/gopkg.in/redis.v2/pool.go b/Godeps/_workspace/src/gopkg.in/redis.v2/pool.go deleted file mode 100644 index bca4d19..0000000 --- a/Godeps/_workspace/src/gopkg.in/redis.v2/pool.go +++ /dev/null @@ -1,405 +0,0 @@ -package redis - -import ( - "container/list" - "errors" - "log" - "net" - "sync" - "time" - - "gopkg.in/bufio.v1" -) - -var ( - errClosed = errors.New("redis: client is closed") - errRateLimited = errors.New("redis: you open connections too fast") -) - -var ( - zeroTime = time.Time{} -) - -type pool interface { - Get() (*conn, bool, error) - Put(*conn) error - Remove(*conn) error - Len() int - Size() int - Close() error - Filter(func(*conn) bool) -} - -//------------------------------------------------------------------------------ - -type conn struct { - netcn net.Conn - rd *bufio.Reader - buf []byte - - inUse bool - usedAt time.Time - - readTimeout time.Duration - writeTimeout time.Duration - - elem *list.Element -} - -func newConnFunc(dial func() (net.Conn, error)) func() (*conn, error) { - return func() (*conn, error) { - netcn, err := dial() - if err != nil { - return nil, err - } - cn := &conn{ - netcn: netcn, - buf: make([]byte, 0, 64), - } - cn.rd = bufio.NewReader(cn) - return cn, nil - } -} - -func (cn *conn) Read(b []byte) (int, error) { - if cn.readTimeout != 0 { - cn.netcn.SetReadDeadline(time.Now().Add(cn.readTimeout)) - } else { - cn.netcn.SetReadDeadline(zeroTime) - } - return cn.netcn.Read(b) -} - -func (cn *conn) Write(b []byte) (int, error) { - if cn.writeTimeout != 0 { - cn.netcn.SetWriteDeadline(time.Now().Add(cn.writeTimeout)) - } else { - cn.netcn.SetWriteDeadline(zeroTime) - } - return cn.netcn.Write(b) -} - -func (cn *conn) RemoteAddr() net.Addr { - return cn.netcn.RemoteAddr() -} - -func (cn *conn) Close() error { - return cn.netcn.Close() -} - -//------------------------------------------------------------------------------ - -type connPool struct { - dial func() (*conn, error) - rl *rateLimiter - - opt *options - - cond *sync.Cond - conns *list.List - - idleNum int - closed bool -} - -func newConnPool(dial func() (*conn, error), opt *options) *connPool { - return &connPool{ - dial: dial, - rl: newRateLimiter(time.Second, 2*opt.PoolSize), - - opt: opt, - - cond: sync.NewCond(&sync.Mutex{}), - conns: list.New(), - } -} - -func (p *connPool) new() (*conn, error) { - if !p.rl.Check() { - return nil, errRateLimited - } - return p.dial() -} - -func (p *connPool) Get() (*conn, bool, error) { - p.cond.L.Lock() - - if p.closed { - p.cond.L.Unlock() - return nil, false, errClosed - } - - if p.opt.IdleTimeout > 0 { - for el := p.conns.Front(); el != nil; el = el.Next() { - cn := el.Value.(*conn) - if cn.inUse { - break - } - if time.Since(cn.usedAt) > p.opt.IdleTimeout { - if err := p.remove(cn); err != nil { - log.Printf("remove failed: %s", err) - } - } - } - } - - for p.conns.Len() >= p.opt.PoolSize && p.idleNum == 0 { - p.cond.Wait() - } - - if p.idleNum > 0 { - elem := p.conns.Front() - cn := elem.Value.(*conn) - if cn.inUse { - panic("pool: precondition failed") - } - cn.inUse = true - p.conns.MoveToBack(elem) - p.idleNum-- - - p.cond.L.Unlock() - return cn, false, nil - } - - if p.conns.Len() < p.opt.PoolSize { - cn, err := p.new() - if err != nil { - p.cond.L.Unlock() - return nil, false, err - } - - cn.inUse = true - cn.elem = p.conns.PushBack(cn) - - p.cond.L.Unlock() - return cn, true, nil - } - - panic("not reached") -} - -func (p *connPool) Put(cn *conn) error { - if cn.rd.Buffered() != 0 { - b, _ := cn.rd.ReadN(cn.rd.Buffered()) - log.Printf("redis: connection has unread data: %q", b) - return p.Remove(cn) - } - - if p.opt.IdleTimeout > 0 { - cn.usedAt = time.Now() - } - - p.cond.L.Lock() - if p.closed { - p.cond.L.Unlock() - return errClosed - } - cn.inUse = false - p.conns.MoveToFront(cn.elem) - p.idleNum++ - p.cond.Signal() - p.cond.L.Unlock() - - return nil -} - -func (p *connPool) Remove(cn *conn) error { - p.cond.L.Lock() - if p.closed { - // Noop, connection is already closed. - p.cond.L.Unlock() - return nil - } - err := p.remove(cn) - p.cond.Signal() - p.cond.L.Unlock() - return err -} - -func (p *connPool) remove(cn *conn) error { - p.conns.Remove(cn.elem) - cn.elem = nil - if !cn.inUse { - p.idleNum-- - } - return cn.Close() -} - -// Len returns number of idle connections. -func (p *connPool) Len() int { - defer p.cond.L.Unlock() - p.cond.L.Lock() - return p.idleNum -} - -// Size returns number of connections in the pool. -func (p *connPool) Size() int { - defer p.cond.L.Unlock() - p.cond.L.Lock() - return p.conns.Len() -} - -func (p *connPool) Filter(f func(*conn) bool) { - p.cond.L.Lock() - for el, next := p.conns.Front(), p.conns.Front(); el != nil; el = next { - next = el.Next() - cn := el.Value.(*conn) - if !f(cn) { - p.remove(cn) - } - } - p.cond.L.Unlock() -} - -func (p *connPool) Close() error { - defer p.cond.L.Unlock() - p.cond.L.Lock() - if p.closed { - return nil - } - p.closed = true - p.rl.Close() - var retErr error - for { - e := p.conns.Front() - if e == nil { - break - } - if err := p.remove(e.Value.(*conn)); err != nil { - log.Printf("cn.Close failed: %s", err) - retErr = err - } - } - return retErr -} - -//------------------------------------------------------------------------------ - -type singleConnPool struct { - pool pool - - cnMtx sync.Mutex - cn *conn - - reusable bool - - closed bool -} - -func newSingleConnPool(pool pool, reusable bool) *singleConnPool { - return &singleConnPool{ - pool: pool, - reusable: reusable, - } -} - -func (p *singleConnPool) SetConn(cn *conn) { - p.cnMtx.Lock() - p.cn = cn - p.cnMtx.Unlock() -} - -func (p *singleConnPool) Get() (*conn, bool, error) { - defer p.cnMtx.Unlock() - p.cnMtx.Lock() - - if p.closed { - return nil, false, errClosed - } - if p.cn != nil { - return p.cn, false, nil - } - - cn, isNew, err := p.pool.Get() - if err != nil { - return nil, false, err - } - p.cn = cn - - return p.cn, isNew, nil -} - -func (p *singleConnPool) Put(cn *conn) error { - defer p.cnMtx.Unlock() - p.cnMtx.Lock() - if p.cn != cn { - panic("p.cn != cn") - } - if p.closed { - return errClosed - } - return nil -} - -func (p *singleConnPool) put() error { - err := p.pool.Put(p.cn) - p.cn = nil - return err -} - -func (p *singleConnPool) Remove(cn *conn) error { - defer p.cnMtx.Unlock() - p.cnMtx.Lock() - if p.cn == nil { - panic("p.cn == nil") - } - if p.cn != cn { - panic("p.cn != cn") - } - if p.closed { - return errClosed - } - return p.remove() -} - -func (p *singleConnPool) remove() error { - err := p.pool.Remove(p.cn) - p.cn = nil - return err -} - -func (p *singleConnPool) Len() int { - defer p.cnMtx.Unlock() - p.cnMtx.Lock() - if p.cn == nil { - return 0 - } - return 1 -} - -func (p *singleConnPool) Size() int { - defer p.cnMtx.Unlock() - p.cnMtx.Lock() - if p.cn == nil { - return 0 - } - return 1 -} - -func (p *singleConnPool) Filter(f func(*conn) bool) { - p.cnMtx.Lock() - if p.cn != nil { - if !f(p.cn) { - p.remove() - } - } - p.cnMtx.Unlock() -} - -func (p *singleConnPool) Close() error { - defer p.cnMtx.Unlock() - p.cnMtx.Lock() - if p.closed { - return nil - } - p.closed = true - var err error - if p.cn != nil { - if p.reusable { - err = p.put() - } else { - err = p.remove() - } - } - return err -} diff --git a/Godeps/_workspace/src/gopkg.in/redis.v2/pubsub.go b/Godeps/_workspace/src/gopkg.in/redis.v2/pubsub.go deleted file mode 100644 index 6ac130b..0000000 --- a/Godeps/_workspace/src/gopkg.in/redis.v2/pubsub.go +++ /dev/null @@ -1,134 +0,0 @@ -package redis - -import ( - "fmt" - "time" -) - -// Not thread-safe. -type PubSub struct { - *baseClient -} - -func (c *Client) PubSub() *PubSub { - return &PubSub{ - baseClient: &baseClient{ - opt: c.opt, - connPool: newSingleConnPool(c.connPool, false), - }, - } -} - -func (c *Client) Publish(channel, message string) *IntCmd { - req := NewIntCmd("PUBLISH", channel, message) - c.Process(req) - return req -} - -type Message struct { - Channel string - Payload string -} - -func (m *Message) String() string { - return fmt.Sprintf("Message<%s: %s>", m.Channel, m.Payload) -} - -type PMessage struct { - Channel string - Pattern string - Payload string -} - -func (m *PMessage) String() string { - return fmt.Sprintf("PMessage<%s: %s>", m.Channel, m.Payload) -} - -type Subscription struct { - Kind string - Channel string - Count int -} - -func (m *Subscription) String() string { - return fmt.Sprintf("%s: %s", m.Kind, m.Channel) -} - -func (c *PubSub) Receive() (interface{}, error) { - return c.ReceiveTimeout(0) -} - -func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { - cn, err := c.conn() - if err != nil { - return nil, err - } - cn.readTimeout = timeout - - cmd := NewSliceCmd() - if err := cmd.parseReply(cn.rd); err != nil { - return nil, err - } - - reply := cmd.Val() - - msgName := reply[0].(string) - switch msgName { - case "subscribe", "unsubscribe", "psubscribe", "punsubscribe": - return &Subscription{ - Kind: msgName, - Channel: reply[1].(string), - Count: int(reply[2].(int64)), - }, nil - case "message": - return &Message{ - Channel: reply[1].(string), - Payload: reply[2].(string), - }, nil - case "pmessage": - return &PMessage{ - Pattern: reply[1].(string), - Channel: reply[2].(string), - Payload: reply[3].(string), - }, nil - } - return nil, fmt.Errorf("redis: unsupported message name: %q", msgName) -} - -func (c *PubSub) subscribe(cmd string, channels ...string) error { - cn, err := c.conn() - if err != nil { - return err - } - - args := append([]string{cmd}, channels...) - req := NewSliceCmd(args...) - return c.writeCmd(cn, req) -} - -func (c *PubSub) Subscribe(channels ...string) error { - return c.subscribe("SUBSCRIBE", channels...) -} - -func (c *PubSub) PSubscribe(patterns ...string) error { - return c.subscribe("PSUBSCRIBE", patterns...) -} - -func (c *PubSub) unsubscribe(cmd string, channels ...string) error { - cn, err := c.conn() - if err != nil { - return err - } - - args := append([]string{cmd}, channels...) - req := NewSliceCmd(args...) - return c.writeCmd(cn, req) -} - -func (c *PubSub) Unsubscribe(channels ...string) error { - return c.unsubscribe("UNSUBSCRIBE", channels...) -} - -func (c *PubSub) PUnsubscribe(patterns ...string) error { - return c.unsubscribe("PUNSUBSCRIBE", patterns...) -} diff --git a/Godeps/_workspace/src/gopkg.in/redis.v2/rate_limit.go b/Godeps/_workspace/src/gopkg.in/redis.v2/rate_limit.go deleted file mode 100644 index 20d8512..0000000 --- a/Godeps/_workspace/src/gopkg.in/redis.v2/rate_limit.go +++ /dev/null @@ -1,53 +0,0 @@ -package redis - -import ( - "sync/atomic" - "time" -) - -type rateLimiter struct { - v int64 - - _closed int64 -} - -func newRateLimiter(limit time.Duration, bucketSize int) *rateLimiter { - rl := &rateLimiter{ - v: int64(bucketSize), - } - go rl.loop(limit, int64(bucketSize)) - return rl -} - -func (rl *rateLimiter) loop(limit time.Duration, bucketSize int64) { - for { - if rl.closed() { - break - } - if v := atomic.LoadInt64(&rl.v); v < bucketSize { - atomic.AddInt64(&rl.v, 1) - } - time.Sleep(limit) - } -} - -func (rl *rateLimiter) Check() bool { - for { - if v := atomic.LoadInt64(&rl.v); v > 0 { - if atomic.CompareAndSwapInt64(&rl.v, v, v-1) { - return true - } - } else { - return false - } - } -} - -func (rl *rateLimiter) Close() error { - atomic.StoreInt64(&rl._closed, 1) - return nil -} - -func (rl *rateLimiter) closed() bool { - return atomic.LoadInt64(&rl._closed) == 1 -} diff --git a/Godeps/_workspace/src/gopkg.in/redis.v2/rate_limit_test.go b/Godeps/_workspace/src/gopkg.in/redis.v2/rate_limit_test.go deleted file mode 100644 index 2f0d41a..0000000 --- a/Godeps/_workspace/src/gopkg.in/redis.v2/rate_limit_test.go +++ /dev/null @@ -1,31 +0,0 @@ -package redis - -import ( - "sync" - "testing" - "time" -) - -func TestRateLimiter(t *testing.T) { - var n = 100000 - if testing.Short() { - n = 1000 - } - rl := newRateLimiter(time.Minute, n) - - wg := &sync.WaitGroup{} - for i := 0; i < n; i++ { - wg.Add(1) - go func() { - if !rl.Check() { - panic("check failed") - } - wg.Done() - }() - } - wg.Wait() - - if rl.Check() && rl.Check() { - t.Fatal("check passed") - } -} diff --git a/Godeps/_workspace/src/gopkg.in/redis.v2/redis.go b/Godeps/_workspace/src/gopkg.in/redis.v2/redis.go deleted file mode 100644 index 0d15dc8..0000000 --- a/Godeps/_workspace/src/gopkg.in/redis.v2/redis.go +++ /dev/null @@ -1,231 +0,0 @@ -package redis - -import ( - "log" - "net" - "time" -) - -type baseClient struct { - connPool pool - opt *options - cmds []Cmder -} - -func (c *baseClient) writeCmd(cn *conn, cmds ...Cmder) error { - buf := cn.buf[:0] - for _, cmd := range cmds { - buf = appendArgs(buf, cmd.args()) - } - - _, err := cn.Write(buf) - return err -} - -func (c *baseClient) conn() (*conn, error) { - cn, isNew, err := c.connPool.Get() - if err != nil { - return nil, err - } - - if isNew { - if err := c.initConn(cn); err != nil { - c.removeConn(cn) - return nil, err - } - } - - return cn, nil -} - -func (c *baseClient) initConn(cn *conn) error { - if c.opt.Password == "" && c.opt.DB == 0 { - return nil - } - - pool := newSingleConnPool(c.connPool, false) - pool.SetConn(cn) - - // Client is not closed because we want to reuse underlying connection. - client := &Client{ - baseClient: &baseClient{ - opt: c.opt, - connPool: pool, - }, - } - - if c.opt.Password != "" { - if err := client.Auth(c.opt.Password).Err(); err != nil { - return err - } - } - - if c.opt.DB > 0 { - if err := client.Select(c.opt.DB).Err(); err != nil { - return err - } - } - - return nil -} - -func (c *baseClient) freeConn(cn *conn, ei error) error { - if cn.rd.Buffered() > 0 { - return c.connPool.Remove(cn) - } - if _, ok := ei.(redisError); ok { - return c.connPool.Put(cn) - } - return c.connPool.Remove(cn) -} - -func (c *baseClient) removeConn(cn *conn) { - if err := c.connPool.Remove(cn); err != nil { - log.Printf("pool.Remove failed: %s", err) - } -} - -func (c *baseClient) putConn(cn *conn) { - if err := c.connPool.Put(cn); err != nil { - log.Printf("pool.Put failed: %s", err) - } -} - -func (c *baseClient) Process(cmd Cmder) { - if c.cmds == nil { - c.run(cmd) - } else { - c.cmds = append(c.cmds, cmd) - } -} - -func (c *baseClient) run(cmd Cmder) { - cn, err := c.conn() - if err != nil { - cmd.setErr(err) - return - } - - if timeout := cmd.writeTimeout(); timeout != nil { - cn.writeTimeout = *timeout - } else { - cn.writeTimeout = c.opt.WriteTimeout - } - - if timeout := cmd.readTimeout(); timeout != nil { - cn.readTimeout = *timeout - } else { - cn.readTimeout = c.opt.ReadTimeout - } - - if err := c.writeCmd(cn, cmd); err != nil { - c.freeConn(cn, err) - cmd.setErr(err) - return - } - - if err := cmd.parseReply(cn.rd); err != nil { - c.freeConn(cn, err) - return - } - - c.putConn(cn) -} - -// Close closes the client, releasing any open resources. -func (c *baseClient) Close() error { - return c.connPool.Close() -} - -//------------------------------------------------------------------------------ - -type options struct { - Password string - DB int64 - - DialTimeout time.Duration - ReadTimeout time.Duration - WriteTimeout time.Duration - - PoolSize int - IdleTimeout time.Duration -} - -type Options struct { - Network string - Addr string - - // Dialer creates new network connection and has priority over - // Network and Addr options. - Dialer func() (net.Conn, error) - - Password string - DB int64 - - DialTimeout time.Duration - ReadTimeout time.Duration - WriteTimeout time.Duration - - PoolSize int - IdleTimeout time.Duration -} - -func (opt *Options) getPoolSize() int { - if opt.PoolSize == 0 { - return 10 - } - return opt.PoolSize -} - -func (opt *Options) getDialTimeout() time.Duration { - if opt.DialTimeout == 0 { - return 5 * time.Second - } - return opt.DialTimeout -} - -func (opt *Options) options() *options { - return &options{ - DB: opt.DB, - Password: opt.Password, - - DialTimeout: opt.getDialTimeout(), - ReadTimeout: opt.ReadTimeout, - WriteTimeout: opt.WriteTimeout, - - PoolSize: opt.getPoolSize(), - IdleTimeout: opt.IdleTimeout, - } -} - -type Client struct { - *baseClient -} - -func NewClient(clOpt *Options) *Client { - opt := clOpt.options() - dialer := clOpt.Dialer - if dialer == nil { - dialer = func() (net.Conn, error) { - return net.DialTimeout(clOpt.Network, clOpt.Addr, opt.DialTimeout) - } - } - return &Client{ - baseClient: &baseClient{ - opt: opt, - connPool: newConnPool(newConnFunc(dialer), opt), - }, - } -} - -// Deprecated. Use NewClient instead. -func NewTCPClient(opt *Options) *Client { - opt.Network = "tcp" - return NewClient(opt) -} - -// Deprecated. Use NewClient instead. -func NewUnixClient(opt *Options) *Client { - opt.Network = "unix" - return NewClient(opt) -} diff --git a/Godeps/_workspace/src/gopkg.in/redis.v2/redis_test.go b/Godeps/_workspace/src/gopkg.in/redis.v2/redis_test.go deleted file mode 100644 index 49f84d0..0000000 --- a/Godeps/_workspace/src/gopkg.in/redis.v2/redis_test.go +++ /dev/null @@ -1,3333 +0,0 @@ -package redis_test - -import ( - "bytes" - "fmt" - "io" - "net" - "sort" - "strconv" - "sync" - "testing" - "time" - - "gopkg.in/redis.v2" - - . "gopkg.in/check.v1" -) - -const redisAddr = ":6379" - -//------------------------------------------------------------------------------ - -func sortStrings(slice []string) []string { - sort.Strings(slice) - return slice -} - -//------------------------------------------------------------------------------ - -type RedisConnectorTest struct{} - -var _ = Suite(&RedisConnectorTest{}) - -func (t *RedisConnectorTest) TestShutdown(c *C) { - c.Skip("shutdowns server") - - client := redis.NewTCPClient(&redis.Options{ - Addr: redisAddr, - }) - - shutdown := client.Shutdown() - c.Check(shutdown.Err(), Equals, io.EOF) - c.Check(shutdown.Val(), Equals, "") - - ping := client.Ping() - c.Check(ping.Err(), ErrorMatches, "dial tcp :[0-9]+: connection refused") - c.Check(ping.Val(), Equals, "") -} - -func (t *RedisConnectorTest) TestNewTCPClient(c *C) { - client := redis.NewTCPClient(&redis.Options{ - Addr: redisAddr, - }) - ping := client.Ping() - c.Check(ping.Err(), IsNil) - c.Check(ping.Val(), Equals, "PONG") - c.Assert(client.Close(), IsNil) -} - -func (t *RedisConnectorTest) TestNewUnixClient(c *C) { - c.Skip("not available on Travis CI") - - client := redis.NewUnixClient(&redis.Options{ - Addr: "/tmp/redis.sock", - }) - ping := client.Ping() - c.Check(ping.Err(), IsNil) - c.Check(ping.Val(), Equals, "PONG") - c.Assert(client.Close(), IsNil) -} - -func (t *RedisConnectorTest) TestDialer(c *C) { - client := redis.NewClient(&redis.Options{ - Dialer: func() (net.Conn, error) { - return net.Dial("tcp", redisAddr) - }, - }) - ping := client.Ping() - c.Check(ping.Err(), IsNil) - c.Check(ping.Val(), Equals, "PONG") - c.Assert(client.Close(), IsNil) -} - -func (t *RedisConnectorTest) TestClose(c *C) { - client := redis.NewTCPClient(&redis.Options{ - Addr: redisAddr, - }) - c.Assert(client.Close(), IsNil) - - ping := client.Ping() - c.Assert(ping.Err(), Not(IsNil)) - c.Assert(ping.Err().Error(), Equals, "redis: client is closed") - - c.Assert(client.Close(), IsNil) -} - -func (t *RedisConnectorTest) TestPubSubClose(c *C) { - client := redis.NewTCPClient(&redis.Options{ - Addr: redisAddr, - }) - - pubsub := client.PubSub() - c.Assert(pubsub.Close(), IsNil) - - _, err := pubsub.Receive() - c.Assert(err, Not(IsNil)) - c.Assert(err.Error(), Equals, "redis: client is closed") - - ping := client.Ping() - c.Assert(ping.Err(), IsNil) - - c.Assert(client.Close(), IsNil) -} - -func (t *RedisConnectorTest) TestMultiClose(c *C) { - client := redis.NewTCPClient(&redis.Options{ - Addr: redisAddr, - }) - - multi := client.Multi() - c.Assert(multi.Close(), IsNil) - - _, err := multi.Exec(func() error { - multi.Ping() - return nil - }) - c.Assert(err, Not(IsNil)) - c.Assert(err.Error(), Equals, "redis: client is closed") - - ping := client.Ping() - c.Assert(ping.Err(), IsNil) - - c.Assert(client.Close(), IsNil) -} - -func (t *RedisConnectorTest) TestPipelineClose(c *C) { - client := redis.NewTCPClient(&redis.Options{ - Addr: redisAddr, - }) - - _, err := client.Pipelined(func(pipeline *redis.Pipeline) error { - c.Assert(pipeline.Close(), IsNil) - pipeline.Ping() - return nil - }) - c.Assert(err, Not(IsNil)) - c.Assert(err.Error(), Equals, "redis: client is closed") - - ping := client.Ping() - c.Assert(ping.Err(), IsNil) - - c.Assert(client.Close(), IsNil) -} - -func (t *RedisConnectorTest) TestIdleTimeout(c *C) { - client := redis.NewTCPClient(&redis.Options{ - Addr: redisAddr, - IdleTimeout: time.Nanosecond, - }) - for i := 0; i < 10; i++ { - c.Assert(client.Ping().Err(), IsNil) - } -} - -func (t *RedisConnectorTest) TestSelectDb(c *C) { - client1 := redis.NewTCPClient(&redis.Options{ - Addr: redisAddr, - DB: 1, - }) - c.Assert(client1.Set("key", "db1").Err(), IsNil) - - client2 := redis.NewTCPClient(&redis.Options{ - Addr: redisAddr, - DB: 2, - }) - c.Assert(client2.Get("key").Err(), Equals, redis.Nil) -} - -//------------------------------------------------------------------------------ - -type RedisConnPoolTest struct { - client *redis.Client -} - -var _ = Suite(&RedisConnPoolTest{}) - -func (t *RedisConnPoolTest) SetUpTest(c *C) { - t.client = redis.NewTCPClient(&redis.Options{ - Addr: redisAddr, - }) -} - -func (t *RedisConnPoolTest) TearDownTest(c *C) { - c.Assert(t.client.FlushDb().Err(), IsNil) - c.Assert(t.client.Close(), IsNil) -} - -func (t *RedisConnPoolTest) TestConnPoolMaxSize(c *C) { - wg := &sync.WaitGroup{} - for i := 0; i < 1000; i++ { - wg.Add(1) - go func() { - ping := t.client.Ping() - c.Assert(ping.Err(), IsNil) - c.Assert(ping.Val(), Equals, "PONG") - wg.Done() - }() - } - wg.Wait() - - c.Assert(t.client.Pool().Size(), Equals, 10) - c.Assert(t.client.Pool().Len(), Equals, 10) -} - -func (t *RedisConnPoolTest) TestConnPoolMaxSizeOnPipelineClient(c *C) { - const N = 1000 - - wg := &sync.WaitGroup{} - wg.Add(N) - for i := 0; i < N; i++ { - go func() { - pipeline := t.client.Pipeline() - ping := pipeline.Ping() - cmds, err := pipeline.Exec() - c.Assert(err, IsNil) - c.Assert(cmds, HasLen, 1) - c.Assert(ping.Err(), IsNil) - c.Assert(ping.Val(), Equals, "PONG") - - c.Assert(pipeline.Close(), IsNil) - - wg.Done() - }() - } - wg.Wait() - - c.Assert(t.client.Pool().Size(), Equals, 10) - c.Assert(t.client.Pool().Len(), Equals, 10) -} - -func (t *RedisConnPoolTest) TestConnPoolMaxSizeOnMultiClient(c *C) { - const N = 1000 - - wg := &sync.WaitGroup{} - wg.Add(N) - for i := 0; i < N; i++ { - go func() { - multi := t.client.Multi() - var ping *redis.StatusCmd - cmds, err := multi.Exec(func() error { - ping = multi.Ping() - return nil - }) - c.Assert(err, IsNil) - c.Assert(cmds, HasLen, 1) - c.Assert(ping.Err(), IsNil) - c.Assert(ping.Val(), Equals, "PONG") - - c.Assert(multi.Close(), IsNil) - - wg.Done() - }() - } - wg.Wait() - - c.Assert(t.client.Pool().Size(), Equals, 10) - c.Assert(t.client.Pool().Len(), Equals, 10) -} - -func (t *RedisConnPoolTest) TestConnPoolMaxSizeOnPubSub(c *C) { - const N = 10 - - wg := &sync.WaitGroup{} - wg.Add(N) - for i := 0; i < N; i++ { - go func() { - defer wg.Done() - pubsub := t.client.PubSub() - c.Assert(pubsub.Subscribe(), IsNil) - c.Assert(pubsub.Close(), IsNil) - }() - } - wg.Wait() - - c.Assert(t.client.Pool().Size(), Equals, 0) - c.Assert(t.client.Pool().Len(), Equals, 0) -} - -func (t *RedisConnPoolTest) TestConnPoolRemovesBrokenConn(c *C) { - cn, _, err := t.client.Pool().Get() - c.Assert(err, IsNil) - c.Assert(cn.Close(), IsNil) - c.Assert(t.client.Pool().Put(cn), IsNil) - - ping := t.client.Ping() - c.Assert(ping.Err().Error(), Equals, "use of closed network connection") - c.Assert(ping.Val(), Equals, "") - - ping = t.client.Ping() - c.Assert(ping.Err(), IsNil) - c.Assert(ping.Val(), Equals, "PONG") - - c.Assert(t.client.Pool().Size(), Equals, 1) - c.Assert(t.client.Pool().Len(), Equals, 1) -} - -func (t *RedisConnPoolTest) TestConnPoolReusesConn(c *C) { - for i := 0; i < 1000; i++ { - ping := t.client.Ping() - c.Assert(ping.Err(), IsNil) - c.Assert(ping.Val(), Equals, "PONG") - } - - c.Assert(t.client.Pool().Size(), Equals, 1) - c.Assert(t.client.Pool().Len(), Equals, 1) -} - -//------------------------------------------------------------------------------ - -type RedisTest struct { - client *redis.Client -} - -var _ = Suite(&RedisTest{}) - -func Test(t *testing.T) { TestingT(t) } - -func (t *RedisTest) SetUpTest(c *C) { - t.client = redis.NewTCPClient(&redis.Options{ - Addr: ":6379", - }) - - // This is much faster than Flushall. - c.Assert(t.client.Select(1).Err(), IsNil) - c.Assert(t.client.FlushDb().Err(), IsNil) - c.Assert(t.client.Select(0).Err(), IsNil) - c.Assert(t.client.FlushDb().Err(), IsNil) -} - -func (t *RedisTest) TearDownTest(c *C) { - c.Assert(t.client.Close(), IsNil) -} - -//------------------------------------------------------------------------------ - -func (t *RedisTest) TestCmdStringMethod(c *C) { - set := t.client.Set("foo", "bar") - c.Assert(set.String(), Equals, "SET foo bar: OK") - - get := t.client.Get("foo") - c.Assert(get.String(), Equals, "GET foo: bar") -} - -func (t *RedisTest) TestCmdStringMethodError(c *C) { - get2 := t.client.Get("key_does_not_exists") - c.Assert(get2.String(), Equals, "GET key_does_not_exists: redis: nil") -} - -func (t *RedisTest) TestRunWithouthCheckingErrVal(c *C) { - set := t.client.Set("key", "hello") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - get := t.client.Get("key") - c.Assert(get.Err(), IsNil) - c.Assert(get.Val(), Equals, "hello") - - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") -} - -func (t *RedisTest) TestGetSpecChars(c *C) { - set := t.client.Set("key", "hello1\r\nhello2\r\n") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - get := t.client.Get("key") - c.Assert(get.Err(), IsNil) - c.Assert(get.Val(), Equals, "hello1\r\nhello2\r\n") -} - -func (t *RedisTest) TestGetBigVal(c *C) { - val := string(bytes.Repeat([]byte{'*'}, 1<<16)) - - set := t.client.Set("key", val) - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - get := t.client.Get("key") - c.Assert(get.Err(), IsNil) - c.Assert(get.Val(), Equals, val) -} - -func (t *RedisTest) TestManyKeys(c *C) { - var n = 100000 - - for i := 0; i < n; i++ { - t.client.Set("keys.key"+strconv.Itoa(i), "hello"+strconv.Itoa(i)) - } - keys := t.client.Keys("keys.*") - c.Assert(keys.Err(), IsNil) - c.Assert(len(keys.Val()), Equals, n) -} - -func (t *RedisTest) TestManyKeys2(c *C) { - var n = 100000 - - keys := []string{"non-existent-key"} - for i := 0; i < n; i++ { - key := "keys.key" + strconv.Itoa(i) - t.client.Set(key, "hello"+strconv.Itoa(i)) - keys = append(keys, key) - } - keys = append(keys, "non-existent-key") - - mget := t.client.MGet(keys...) - c.Assert(mget.Err(), IsNil) - c.Assert(len(mget.Val()), Equals, n+2) - vals := mget.Val() - for i := 0; i < n; i++ { - c.Assert(vals[i+1], Equals, "hello"+strconv.Itoa(i)) - } - c.Assert(vals[0], Equals, nil) - c.Assert(vals[n+1], Equals, nil) -} - -func (t *RedisTest) TestStringCmdHelpers(c *C) { - set := t.client.Set("key", "10") - c.Assert(set.Err(), IsNil) - - n, err := t.client.Get("key").Int64() - c.Assert(err, IsNil) - c.Assert(n, Equals, int64(10)) - - un, err := t.client.Get("key").Uint64() - c.Assert(err, IsNil) - c.Assert(un, Equals, uint64(10)) - - f, err := t.client.Get("key").Float64() - c.Assert(err, IsNil) - c.Assert(f, Equals, float64(10)) -} - -//------------------------------------------------------------------------------ - -func (t *RedisTest) TestAuth(c *C) { - auth := t.client.Auth("password") - c.Assert(auth.Err(), ErrorMatches, "ERR Client sent AUTH, but no password is set") - c.Assert(auth.Val(), Equals, "") -} - -func (t *RedisTest) TestEcho(c *C) { - echo := t.client.Echo("hello") - c.Assert(echo.Err(), IsNil) - c.Assert(echo.Val(), Equals, "hello") -} - -func (t *RedisTest) TestPing(c *C) { - ping := t.client.Ping() - c.Assert(ping.Err(), IsNil) - c.Assert(ping.Val(), Equals, "PONG") -} - -func (t *RedisTest) TestSelect(c *C) { - sel := t.client.Select(1) - c.Assert(sel.Err(), IsNil) - c.Assert(sel.Val(), Equals, "OK") -} - -//------------------------------------------------------------------------------ - -func (t *RedisTest) TestCmdKeysDel(c *C) { - set := t.client.Set("key1", "Hello") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - set = t.client.Set("key2", "World") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - del := t.client.Del("key1", "key2", "key3") - c.Assert(del.Err(), IsNil) - c.Assert(del.Val(), Equals, int64(2)) -} - -func (t *RedisTest) TestCmdKeysDump(c *C) { - set := t.client.Set("key", "hello") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - dump := t.client.Dump("key") - c.Assert(dump.Err(), IsNil) - c.Assert(dump.Val(), Equals, "\x00\x05hello\x06\x00\xf5\x9f\xb7\xf6\x90a\x1c\x99") -} - -func (t *RedisTest) TestCmdKeysExists(c *C) { - set := t.client.Set("key1", "Hello") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - exists := t.client.Exists("key1") - c.Assert(exists.Err(), IsNil) - c.Assert(exists.Val(), Equals, true) - - exists = t.client.Exists("key2") - c.Assert(exists.Err(), IsNil) - c.Assert(exists.Val(), Equals, false) -} - -func (t *RedisTest) TestCmdKeysExpire(c *C) { - set := t.client.Set("key", "Hello") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - expire := t.client.Expire("key", 10*time.Second) - c.Assert(expire.Err(), IsNil) - c.Assert(expire.Val(), Equals, true) - - ttl := t.client.TTL("key") - c.Assert(ttl.Err(), IsNil) - c.Assert(ttl.Val(), Equals, 10*time.Second) - - set = t.client.Set("key", "Hello World") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - ttl = t.client.TTL("key") - c.Assert(ttl.Err(), IsNil) - c.Assert(ttl.Val() < 0, Equals, true) -} - -func (t *RedisTest) TestCmdKeysExpireAt(c *C) { - set := t.client.Set("key", "Hello") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - exists := t.client.Exists("key") - c.Assert(exists.Err(), IsNil) - c.Assert(exists.Val(), Equals, true) - - expireAt := t.client.ExpireAt("key", time.Now().Add(-time.Hour)) - c.Assert(expireAt.Err(), IsNil) - c.Assert(expireAt.Val(), Equals, true) - - exists = t.client.Exists("key") - c.Assert(exists.Err(), IsNil) - c.Assert(exists.Val(), Equals, false) -} - -func (t *RedisTest) TestCmdKeysKeys(c *C) { - mset := t.client.MSet("one", "1", "two", "2", "three", "3", "four", "4") - c.Assert(mset.Err(), IsNil) - c.Assert(mset.Val(), Equals, "OK") - - keys := t.client.Keys("*o*") - c.Assert(keys.Err(), IsNil) - c.Assert(sortStrings(keys.Val()), DeepEquals, []string{"four", "one", "two"}) - - keys = t.client.Keys("t??") - c.Assert(keys.Err(), IsNil) - c.Assert(keys.Val(), DeepEquals, []string{"two"}) - - keys = t.client.Keys("*") - c.Assert(keys.Err(), IsNil) - c.Assert( - sortStrings(keys.Val()), - DeepEquals, - []string{"four", "one", "three", "two"}, - ) -} - -func (t *RedisTest) TestCmdKeysMigrate(c *C) { - migrate := t.client.Migrate("localhost", "6380", "key", 0, 0) - c.Assert(migrate.Err(), IsNil) - c.Assert(migrate.Val(), Equals, "NOKEY") - - set := t.client.Set("key", "hello") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - migrate = t.client.Migrate("localhost", "6380", "key", 0, 0) - c.Assert(migrate.Err(), ErrorMatches, "IOERR error or timeout writing to target instance") - c.Assert(migrate.Val(), Equals, "") -} - -func (t *RedisTest) TestCmdKeysMove(c *C) { - move := t.client.Move("key", 1) - c.Assert(move.Err(), IsNil) - c.Assert(move.Val(), Equals, false) - - set := t.client.Set("key", "hello") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - move = t.client.Move("key", 1) - c.Assert(move.Err(), IsNil) - c.Assert(move.Val(), Equals, true) - - get := t.client.Get("key") - c.Assert(get.Err(), Equals, redis.Nil) - c.Assert(get.Val(), Equals, "") - - sel := t.client.Select(1) - c.Assert(sel.Err(), IsNil) - c.Assert(sel.Val(), Equals, "OK") - - get = t.client.Get("key") - c.Assert(get.Err(), IsNil) - c.Assert(get.Val(), Equals, "hello") -} - -func (t *RedisTest) TestCmdKeysObject(c *C) { - set := t.client.Set("key", "hello") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - refCount := t.client.ObjectRefCount("key") - c.Assert(refCount.Err(), IsNil) - c.Assert(refCount.Val(), Equals, int64(1)) - - enc := t.client.ObjectEncoding("key") - c.Assert(enc.Err(), IsNil) - c.Assert(enc.Val(), Equals, "raw") - - idleTime := t.client.ObjectIdleTime("key") - c.Assert(idleTime.Err(), IsNil) - c.Assert(idleTime.Val(), Equals, time.Duration(0)) -} - -func (t *RedisTest) TestCmdKeysPersist(c *C) { - set := t.client.Set("key", "Hello") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - expire := t.client.Expire("key", 10*time.Second) - c.Assert(expire.Err(), IsNil) - c.Assert(expire.Val(), Equals, true) - - ttl := t.client.TTL("key") - c.Assert(ttl.Err(), IsNil) - c.Assert(ttl.Val(), Equals, 10*time.Second) - - persist := t.client.Persist("key") - c.Assert(persist.Err(), IsNil) - c.Assert(persist.Val(), Equals, true) - - ttl = t.client.TTL("key") - c.Assert(ttl.Err(), IsNil) - c.Assert(ttl.Val() < 0, Equals, true) -} - -func (t *RedisTest) TestCmdKeysPExpire(c *C) { - set := t.client.Set("key", "Hello") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - expiration := 900 * time.Millisecond - pexpire := t.client.PExpire("key", expiration) - c.Assert(pexpire.Err(), IsNil) - c.Assert(pexpire.Val(), Equals, true) - - ttl := t.client.TTL("key") - c.Assert(ttl.Err(), IsNil) - c.Assert(ttl.Val(), Equals, time.Second) - - pttl := t.client.PTTL("key") - c.Assert(pttl.Err(), IsNil) - c.Assert(pttl.Val() <= expiration, Equals, true) - c.Assert(pttl.Val() >= expiration-time.Millisecond, Equals, true) -} - -func (t *RedisTest) TestCmdKeysPExpireAt(c *C) { - set := t.client.Set("key", "Hello") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - expiration := 900 * time.Millisecond - pexpireat := t.client.PExpireAt("key", time.Now().Add(expiration)) - c.Assert(pexpireat.Err(), IsNil) - c.Assert(pexpireat.Val(), Equals, true) - - ttl := t.client.TTL("key") - c.Assert(ttl.Err(), IsNil) - c.Assert(ttl.Val(), Equals, time.Second) - - pttl := t.client.PTTL("key") - c.Assert(pttl.Err(), IsNil) - c.Assert(pttl.Val() <= expiration, Equals, true) - c.Assert(pttl.Val() >= expiration-time.Millisecond, Equals, true) -} - -func (t *RedisTest) TestCmdKeysPTTL(c *C) { - set := t.client.Set("key", "Hello") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - expiration := time.Second - expire := t.client.Expire("key", expiration) - c.Assert(expire.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - pttl := t.client.PTTL("key") - c.Assert(pttl.Err(), IsNil) - c.Assert(pttl.Val() <= expiration, Equals, true) - c.Assert(pttl.Val() >= expiration-time.Millisecond, Equals, true) -} - -func (t *RedisTest) TestCmdKeysRandomKey(c *C) { - randomKey := t.client.RandomKey() - c.Assert(randomKey.Err(), Equals, redis.Nil) - c.Assert(randomKey.Val(), Equals, "") - - set := t.client.Set("key", "hello") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - randomKey = t.client.RandomKey() - c.Assert(randomKey.Err(), IsNil) - c.Assert(randomKey.Val(), Equals, "key") -} - -func (t *RedisTest) TestCmdKeysRename(c *C) { - set := t.client.Set("key", "hello") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - status := t.client.Rename("key", "key1") - c.Assert(status.Err(), IsNil) - c.Assert(status.Val(), Equals, "OK") - - get := t.client.Get("key1") - c.Assert(get.Err(), IsNil) - c.Assert(get.Val(), Equals, "hello") -} - -func (t *RedisTest) TestCmdKeysRenameNX(c *C) { - set := t.client.Set("key", "hello") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - renameNX := t.client.RenameNX("key", "key1") - c.Assert(renameNX.Err(), IsNil) - c.Assert(renameNX.Val(), Equals, true) - - get := t.client.Get("key1") - c.Assert(get.Err(), IsNil) - c.Assert(get.Val(), Equals, "hello") -} - -func (t *RedisTest) TestCmdKeysRestore(c *C) { - set := t.client.Set("key", "hello") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - dump := t.client.Dump("key") - c.Assert(dump.Err(), IsNil) - - del := t.client.Del("key") - c.Assert(del.Err(), IsNil) - - restore := t.client.Restore("key", 0, dump.Val()) - c.Assert(restore.Err(), IsNil) - c.Assert(restore.Val(), Equals, "OK") - - type_ := t.client.Type("key") - c.Assert(type_.Err(), IsNil) - c.Assert(type_.Val(), Equals, "string") - - lRange := t.client.Get("key") - c.Assert(lRange.Err(), IsNil) - c.Assert(lRange.Val(), Equals, "hello") -} - -func (t *RedisTest) TestCmdKeysSort(c *C) { - lPush := t.client.LPush("list", "1") - c.Assert(lPush.Err(), IsNil) - c.Assert(lPush.Val(), Equals, int64(1)) - lPush = t.client.LPush("list", "3") - c.Assert(lPush.Err(), IsNil) - c.Assert(lPush.Val(), Equals, int64(2)) - lPush = t.client.LPush("list", "2") - c.Assert(lPush.Err(), IsNil) - c.Assert(lPush.Val(), Equals, int64(3)) - - sort := t.client.Sort("list", redis.Sort{Offset: 0, Count: 2, Order: "ASC"}) - c.Assert(sort.Err(), IsNil) - c.Assert(sort.Val(), DeepEquals, []string{"1", "2"}) -} - -func (t *RedisTest) TestCmdKeysSortBy(c *C) { - lPush := t.client.LPush("list", "1") - c.Assert(lPush.Err(), IsNil) - c.Assert(lPush.Val(), Equals, int64(1)) - lPush = t.client.LPush("list", "3") - c.Assert(lPush.Err(), IsNil) - c.Assert(lPush.Val(), Equals, int64(2)) - lPush = t.client.LPush("list", "2") - c.Assert(lPush.Err(), IsNil) - c.Assert(lPush.Val(), Equals, int64(3)) - - set := t.client.Set("weight_1", "5") - c.Assert(set.Err(), IsNil) - set = t.client.Set("weight_2", "2") - c.Assert(set.Err(), IsNil) - set = t.client.Set("weight_3", "8") - c.Assert(set.Err(), IsNil) - - sort := t.client.Sort("list", redis.Sort{Offset: 0, Count: 2, Order: "ASC", By: "weight_*"}) - c.Assert(sort.Err(), IsNil) - c.Assert(sort.Val(), DeepEquals, []string{"2", "1"}) -} - -func (t *RedisTest) TestCmdKeysTTL(c *C) { - ttl := t.client.TTL("key") - c.Assert(ttl.Err(), IsNil) - c.Assert(ttl.Val() < 0, Equals, true) - - set := t.client.Set("key", "hello") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - expire := t.client.Expire("key", 60*time.Second) - c.Assert(expire.Err(), IsNil) - c.Assert(expire.Val(), Equals, true) - - ttl = t.client.TTL("key") - c.Assert(ttl.Err(), IsNil) - c.Assert(ttl.Val(), Equals, 60*time.Second) -} - -func (t *RedisTest) TestCmdKeysType(c *C) { - set := t.client.Set("key", "hello") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - type_ := t.client.Type("key") - c.Assert(type_.Err(), IsNil) - c.Assert(type_.Val(), Equals, "string") -} - -func (t *RedisTest) TestCmdScan(c *C) { - for i := 0; i < 1000; i++ { - set := t.client.Set(fmt.Sprintf("key%d", i), "hello") - c.Assert(set.Err(), IsNil) - } - - cursor, keys, err := t.client.Scan(0, "", 0).Result() - c.Assert(err, IsNil) - c.Assert(cursor > 0, Equals, true) - c.Assert(len(keys) > 0, Equals, true) -} - -func (t *RedisTest) TestCmdSScan(c *C) { - for i := 0; i < 1000; i++ { - sadd := t.client.SAdd("myset", fmt.Sprintf("member%d", i)) - c.Assert(sadd.Err(), IsNil) - } - - cursor, keys, err := t.client.SScan("myset", 0, "", 0).Result() - c.Assert(err, IsNil) - c.Assert(cursor > 0, Equals, true) - c.Assert(len(keys) > 0, Equals, true) -} - -func (t *RedisTest) TestCmdHScan(c *C) { - for i := 0; i < 1000; i++ { - sadd := t.client.HSet("myhash", fmt.Sprintf("key%d", i), "hello") - c.Assert(sadd.Err(), IsNil) - } - - cursor, keys, err := t.client.HScan("myhash", 0, "", 0).Result() - c.Assert(err, IsNil) - c.Assert(cursor > 0, Equals, true) - c.Assert(len(keys) > 0, Equals, true) -} - -func (t *RedisTest) TestCmdZScan(c *C) { - for i := 0; i < 1000; i++ { - sadd := t.client.ZAdd("myset", redis.Z{float64(i), fmt.Sprintf("member%d", i)}) - c.Assert(sadd.Err(), IsNil) - } - - cursor, keys, err := t.client.ZScan("myset", 0, "", 0).Result() - c.Assert(err, IsNil) - c.Assert(cursor > 0, Equals, true) - c.Assert(len(keys) > 0, Equals, true) -} - -//------------------------------------------------------------------------------ - -func (t *RedisTest) TestStringsAppend(c *C) { - exists := t.client.Exists("key") - c.Assert(exists.Err(), IsNil) - c.Assert(exists.Val(), Equals, false) - - append := t.client.Append("key", "Hello") - c.Assert(append.Err(), IsNil) - c.Assert(append.Val(), Equals, int64(5)) - - append = t.client.Append("key", " World") - c.Assert(append.Err(), IsNil) - c.Assert(append.Val(), Equals, int64(11)) - - get := t.client.Get("key") - c.Assert(get.Err(), IsNil) - c.Assert(get.Val(), Equals, "Hello World") -} - -func (t *RedisTest) TestStringsBitCount(c *C) { - set := t.client.Set("key", "foobar") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - bitCount := t.client.BitCount("key", nil) - c.Assert(bitCount.Err(), IsNil) - c.Assert(bitCount.Val(), Equals, int64(26)) - - bitCount = t.client.BitCount("key", &redis.BitCount{0, 0}) - c.Assert(bitCount.Err(), IsNil) - c.Assert(bitCount.Val(), Equals, int64(4)) - - bitCount = t.client.BitCount("key", &redis.BitCount{1, 1}) - c.Assert(bitCount.Err(), IsNil) - c.Assert(bitCount.Val(), Equals, int64(6)) -} - -func (t *RedisTest) TestStringsBitOpAnd(c *C) { - set := t.client.Set("key1", "1") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - set = t.client.Set("key2", "0") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - bitOpAnd := t.client.BitOpAnd("dest", "key1", "key2") - c.Assert(bitOpAnd.Err(), IsNil) - c.Assert(bitOpAnd.Val(), Equals, int64(1)) - - get := t.client.Get("dest") - c.Assert(get.Err(), IsNil) - c.Assert(get.Val(), Equals, "0") -} - -func (t *RedisTest) TestStringsBitOpOr(c *C) { - set := t.client.Set("key1", "1") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - set = t.client.Set("key2", "0") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - bitOpOr := t.client.BitOpOr("dest", "key1", "key2") - c.Assert(bitOpOr.Err(), IsNil) - c.Assert(bitOpOr.Val(), Equals, int64(1)) - - get := t.client.Get("dest") - c.Assert(get.Err(), IsNil) - c.Assert(get.Val(), Equals, "1") -} - -func (t *RedisTest) TestStringsBitOpXor(c *C) { - set := t.client.Set("key1", "\xff") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - set = t.client.Set("key2", "\x0f") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - bitOpXor := t.client.BitOpXor("dest", "key1", "key2") - c.Assert(bitOpXor.Err(), IsNil) - c.Assert(bitOpXor.Val(), Equals, int64(1)) - - get := t.client.Get("dest") - c.Assert(get.Err(), IsNil) - c.Assert(get.Val(), Equals, "\xf0") -} - -func (t *RedisTest) TestStringsBitOpNot(c *C) { - set := t.client.Set("key1", "\x00") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - bitOpNot := t.client.BitOpNot("dest", "key1") - c.Assert(bitOpNot.Err(), IsNil) - c.Assert(bitOpNot.Val(), Equals, int64(1)) - - get := t.client.Get("dest") - c.Assert(get.Err(), IsNil) - c.Assert(get.Val(), Equals, "\xff") -} - -func (t *RedisTest) TestStringsDecr(c *C) { - set := t.client.Set("key", "10") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - decr := t.client.Decr("key") - c.Assert(decr.Err(), IsNil) - c.Assert(decr.Val(), Equals, int64(9)) - - set = t.client.Set("key", "234293482390480948029348230948") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - decr = t.client.Decr("key") - c.Assert(decr.Err(), ErrorMatches, "ERR value is not an integer or out of range") - c.Assert(decr.Val(), Equals, int64(0)) -} - -func (t *RedisTest) TestStringsDecrBy(c *C) { - set := t.client.Set("key", "10") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - decrBy := t.client.DecrBy("key", 5) - c.Assert(decrBy.Err(), IsNil) - c.Assert(decrBy.Val(), Equals, int64(5)) -} - -func (t *RedisTest) TestStringsGet(c *C) { - get := t.client.Get("_") - c.Assert(get.Err(), Equals, redis.Nil) - c.Assert(get.Val(), Equals, "") - - set := t.client.Set("key", "hello") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - get = t.client.Get("key") - c.Assert(get.Err(), IsNil) - c.Assert(get.Val(), Equals, "hello") -} - -func (t *RedisTest) TestStringsGetBit(c *C) { - setBit := t.client.SetBit("key", 7, 1) - c.Assert(setBit.Err(), IsNil) - c.Assert(setBit.Val(), Equals, int64(0)) - - getBit := t.client.GetBit("key", 0) - c.Assert(getBit.Err(), IsNil) - c.Assert(getBit.Val(), Equals, int64(0)) - - getBit = t.client.GetBit("key", 7) - c.Assert(getBit.Err(), IsNil) - c.Assert(getBit.Val(), Equals, int64(1)) - - getBit = t.client.GetBit("key", 100) - c.Assert(getBit.Err(), IsNil) - c.Assert(getBit.Val(), Equals, int64(0)) -} - -func (t *RedisTest) TestStringsGetRange(c *C) { - set := t.client.Set("key", "This is a string") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - getRange := t.client.GetRange("key", 0, 3) - c.Assert(getRange.Err(), IsNil) - c.Assert(getRange.Val(), Equals, "This") - - getRange = t.client.GetRange("key", -3, -1) - c.Assert(getRange.Err(), IsNil) - c.Assert(getRange.Val(), Equals, "ing") - - getRange = t.client.GetRange("key", 0, -1) - c.Assert(getRange.Err(), IsNil) - c.Assert(getRange.Val(), Equals, "This is a string") - - getRange = t.client.GetRange("key", 10, 100) - c.Assert(getRange.Err(), IsNil) - c.Assert(getRange.Val(), Equals, "string") -} - -func (t *RedisTest) TestStringsGetSet(c *C) { - incr := t.client.Incr("key") - c.Assert(incr.Err(), IsNil) - c.Assert(incr.Val(), Equals, int64(1)) - - getSet := t.client.GetSet("key", "0") - c.Assert(getSet.Err(), IsNil) - c.Assert(getSet.Val(), Equals, "1") - - get := t.client.Get("key") - c.Assert(get.Err(), IsNil) - c.Assert(get.Val(), Equals, "0") -} - -func (t *RedisTest) TestStringsIncr(c *C) { - set := t.client.Set("key", "10") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - incr := t.client.Incr("key") - c.Assert(incr.Err(), IsNil) - c.Assert(incr.Val(), Equals, int64(11)) - - get := t.client.Get("key") - c.Assert(get.Err(), IsNil) - c.Assert(get.Val(), Equals, "11") -} - -func (t *RedisTest) TestStringsIncrBy(c *C) { - set := t.client.Set("key", "10") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - incrBy := t.client.IncrBy("key", 5) - c.Assert(incrBy.Err(), IsNil) - c.Assert(incrBy.Val(), Equals, int64(15)) -} - -func (t *RedisTest) TestIncrByFloat(c *C) { - set := t.client.Set("key", "10.50") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - incrByFloat := t.client.IncrByFloat("key", 0.1) - c.Assert(incrByFloat.Err(), IsNil) - c.Assert(incrByFloat.Val(), Equals, 10.6) - - set = t.client.Set("key", "5.0e3") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - incrByFloat = t.client.IncrByFloat("key", 2.0e2) - c.Assert(incrByFloat.Err(), IsNil) - c.Assert(incrByFloat.Val(), Equals, float64(5200)) -} - -func (t *RedisTest) TestIncrByFloatOverflow(c *C) { - incrByFloat := t.client.IncrByFloat("key", 996945661) - c.Assert(incrByFloat.Err(), IsNil) - c.Assert(incrByFloat.Val(), Equals, float64(996945661)) -} - -func (t *RedisTest) TestStringsMSetMGet(c *C) { - mSet := t.client.MSet("key1", "hello1", "key2", "hello2") - c.Assert(mSet.Err(), IsNil) - c.Assert(mSet.Val(), Equals, "OK") - - mGet := t.client.MGet("key1", "key2", "_") - c.Assert(mGet.Err(), IsNil) - c.Assert(mGet.Val(), DeepEquals, []interface{}{"hello1", "hello2", nil}) -} - -func (t *RedisTest) TestStringsMSetNX(c *C) { - mSetNX := t.client.MSetNX("key1", "hello1", "key2", "hello2") - c.Assert(mSetNX.Err(), IsNil) - c.Assert(mSetNX.Val(), Equals, true) - - mSetNX = t.client.MSetNX("key2", "hello1", "key3", "hello2") - c.Assert(mSetNX.Err(), IsNil) - c.Assert(mSetNX.Val(), Equals, false) -} - -func (t *RedisTest) TestStringsPSetEx(c *C) { - expiration := 50 * time.Millisecond - psetex := t.client.PSetEx("key", expiration, "hello") - c.Assert(psetex.Err(), IsNil) - c.Assert(psetex.Val(), Equals, "OK") - - pttl := t.client.PTTL("key") - c.Assert(pttl.Err(), IsNil) - c.Assert(pttl.Val() <= expiration, Equals, true) - c.Assert(pttl.Val() >= expiration-time.Millisecond, Equals, true) - - get := t.client.Get("key") - c.Assert(get.Err(), IsNil) - c.Assert(get.Val(), Equals, "hello") -} - -func (t *RedisTest) TestStringsSetGet(c *C) { - set := t.client.Set("key", "hello") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - get := t.client.Get("key") - c.Assert(get.Err(), IsNil) - c.Assert(get.Val(), Equals, "hello") -} - -func (t *RedisTest) TestStringsSetEx(c *C) { - setEx := t.client.SetEx("key", 10*time.Second, "hello") - c.Assert(setEx.Err(), IsNil) - c.Assert(setEx.Val(), Equals, "OK") - - ttl := t.client.TTL("key") - c.Assert(ttl.Err(), IsNil) - c.Assert(ttl.Val(), Equals, 10*time.Second) -} - -func (t *RedisTest) TestStringsSetNX(c *C) { - setNX := t.client.SetNX("key", "hello") - c.Assert(setNX.Err(), IsNil) - c.Assert(setNX.Val(), Equals, true) - - setNX = t.client.SetNX("key", "hello2") - c.Assert(setNX.Err(), IsNil) - c.Assert(setNX.Val(), Equals, false) - - get := t.client.Get("key") - c.Assert(get.Err(), IsNil) - c.Assert(get.Val(), Equals, "hello") -} - -func (t *RedisTest) TestStringsSetRange(c *C) { - set := t.client.Set("key", "Hello World") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - range_ := t.client.SetRange("key", 6, "Redis") - c.Assert(range_.Err(), IsNil) - c.Assert(range_.Val(), Equals, int64(11)) - - get := t.client.Get("key") - c.Assert(get.Err(), IsNil) - c.Assert(get.Val(), Equals, "Hello Redis") -} - -func (t *RedisTest) TestStringsStrLen(c *C) { - set := t.client.Set("key", "hello") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - strLen := t.client.StrLen("key") - c.Assert(strLen.Err(), IsNil) - c.Assert(strLen.Val(), Equals, int64(5)) - - strLen = t.client.StrLen("_") - c.Assert(strLen.Err(), IsNil) - c.Assert(strLen.Val(), Equals, int64(0)) -} - -//------------------------------------------------------------------------------ - -func (t *RedisTest) TestCmdHDel(c *C) { - hSet := t.client.HSet("hash", "key", "hello") - c.Assert(hSet.Err(), IsNil) - - hDel := t.client.HDel("hash", "key") - c.Assert(hDel.Err(), IsNil) - c.Assert(hDel.Val(), Equals, int64(1)) - - hDel = t.client.HDel("hash", "key") - c.Assert(hDel.Err(), IsNil) - c.Assert(hDel.Val(), Equals, int64(0)) -} - -func (t *RedisTest) TestCmdHExists(c *C) { - hSet := t.client.HSet("hash", "key", "hello") - c.Assert(hSet.Err(), IsNil) - - hExists := t.client.HExists("hash", "key") - c.Assert(hExists.Err(), IsNil) - c.Assert(hExists.Val(), Equals, true) - - hExists = t.client.HExists("hash", "key1") - c.Assert(hExists.Err(), IsNil) - c.Assert(hExists.Val(), Equals, false) -} - -func (t *RedisTest) TestCmdHGet(c *C) { - hSet := t.client.HSet("hash", "key", "hello") - c.Assert(hSet.Err(), IsNil) - - hGet := t.client.HGet("hash", "key") - c.Assert(hGet.Err(), IsNil) - c.Assert(hGet.Val(), Equals, "hello") - - hGet = t.client.HGet("hash", "key1") - c.Assert(hGet.Err(), Equals, redis.Nil) - c.Assert(hGet.Val(), Equals, "") -} - -func (t *RedisTest) TestCmdHGetAll(c *C) { - hSet := t.client.HSet("hash", "key1", "hello1") - c.Assert(hSet.Err(), IsNil) - hSet = t.client.HSet("hash", "key2", "hello2") - c.Assert(hSet.Err(), IsNil) - - hGetAll := t.client.HGetAll("hash") - c.Assert(hGetAll.Err(), IsNil) - c.Assert(hGetAll.Val(), DeepEquals, []string{"key1", "hello1", "key2", "hello2"}) -} - -func (t *RedisTest) TestCmdHGetAllMap(c *C) { - hSet := t.client.HSet("hash", "key1", "hello1") - c.Assert(hSet.Err(), IsNil) - hSet = t.client.HSet("hash", "key2", "hello2") - c.Assert(hSet.Err(), IsNil) - - hGetAll := t.client.HGetAllMap("hash") - c.Assert(hGetAll.Err(), IsNil) - c.Assert(hGetAll.Val(), DeepEquals, map[string]string{"key1": "hello1", "key2": "hello2"}) -} - -func (t *RedisTest) TestCmdHIncrBy(c *C) { - hSet := t.client.HSet("hash", "key", "5") - c.Assert(hSet.Err(), IsNil) - - hIncrBy := t.client.HIncrBy("hash", "key", 1) - c.Assert(hIncrBy.Err(), IsNil) - c.Assert(hIncrBy.Val(), Equals, int64(6)) - - hIncrBy = t.client.HIncrBy("hash", "key", -1) - c.Assert(hIncrBy.Err(), IsNil) - c.Assert(hIncrBy.Val(), Equals, int64(5)) - - hIncrBy = t.client.HIncrBy("hash", "key", -10) - c.Assert(hIncrBy.Err(), IsNil) - c.Assert(hIncrBy.Val(), Equals, int64(-5)) -} - -func (t *RedisTest) TestCmdHIncrByFloat(c *C) { - hSet := t.client.HSet("hash", "field", "10.50") - c.Assert(hSet.Err(), IsNil) - c.Assert(hSet.Val(), Equals, true) - - hIncrByFloat := t.client.HIncrByFloat("hash", "field", 0.1) - c.Assert(hIncrByFloat.Err(), IsNil) - c.Assert(hIncrByFloat.Val(), Equals, 10.6) - - hSet = t.client.HSet("hash", "field", "5.0e3") - c.Assert(hSet.Err(), IsNil) - c.Assert(hSet.Val(), Equals, false) - - hIncrByFloat = t.client.HIncrByFloat("hash", "field", 2.0e2) - c.Assert(hIncrByFloat.Err(), IsNil) - c.Assert(hIncrByFloat.Val(), Equals, float64(5200)) -} - -func (t *RedisTest) TestCmdHKeys(c *C) { - hkeys := t.client.HKeys("hash") - c.Assert(hkeys.Err(), IsNil) - c.Assert(hkeys.Val(), DeepEquals, []string{}) - - hset := t.client.HSet("hash", "key1", "hello1") - c.Assert(hset.Err(), IsNil) - hset = t.client.HSet("hash", "key2", "hello2") - c.Assert(hset.Err(), IsNil) - - hkeys = t.client.HKeys("hash") - c.Assert(hkeys.Err(), IsNil) - c.Assert(hkeys.Val(), DeepEquals, []string{"key1", "key2"}) -} - -func (t *RedisTest) TestCmdHLen(c *C) { - hSet := t.client.HSet("hash", "key1", "hello1") - c.Assert(hSet.Err(), IsNil) - hSet = t.client.HSet("hash", "key2", "hello2") - c.Assert(hSet.Err(), IsNil) - - hLen := t.client.HLen("hash") - c.Assert(hLen.Err(), IsNil) - c.Assert(hLen.Val(), Equals, int64(2)) -} - -func (t *RedisTest) TestCmdHMGet(c *C) { - hSet := t.client.HSet("hash", "key1", "hello1") - c.Assert(hSet.Err(), IsNil) - hSet = t.client.HSet("hash", "key2", "hello2") - c.Assert(hSet.Err(), IsNil) - - hMGet := t.client.HMGet("hash", "key1", "key2", "_") - c.Assert(hMGet.Err(), IsNil) - c.Assert(hMGet.Val(), DeepEquals, []interface{}{"hello1", "hello2", nil}) -} - -func (t *RedisTest) TestCmdHMSet(c *C) { - hMSet := t.client.HMSet("hash", "key1", "hello1", "key2", "hello2") - c.Assert(hMSet.Err(), IsNil) - c.Assert(hMSet.Val(), Equals, "OK") - - hGet := t.client.HGet("hash", "key1") - c.Assert(hGet.Err(), IsNil) - c.Assert(hGet.Val(), Equals, "hello1") - - hGet = t.client.HGet("hash", "key2") - c.Assert(hGet.Err(), IsNil) - c.Assert(hGet.Val(), Equals, "hello2") -} - -func (t *RedisTest) TestCmdHSet(c *C) { - hSet := t.client.HSet("hash", "key", "hello") - c.Assert(hSet.Err(), IsNil) - c.Assert(hSet.Val(), Equals, true) - - hGet := t.client.HGet("hash", "key") - c.Assert(hGet.Err(), IsNil) - c.Assert(hGet.Val(), Equals, "hello") -} - -func (t *RedisTest) TestCmdHSetNX(c *C) { - hSetNX := t.client.HSetNX("hash", "key", "hello") - c.Assert(hSetNX.Err(), IsNil) - c.Assert(hSetNX.Val(), Equals, true) - - hSetNX = t.client.HSetNX("hash", "key", "hello") - c.Assert(hSetNX.Err(), IsNil) - c.Assert(hSetNX.Val(), Equals, false) - - hGet := t.client.HGet("hash", "key") - c.Assert(hGet.Err(), IsNil) - c.Assert(hGet.Val(), Equals, "hello") -} - -func (t *RedisTest) TestCmdHVals(c *C) { - hSet := t.client.HSet("hash", "key1", "hello1") - c.Assert(hSet.Err(), IsNil) - hSet = t.client.HSet("hash", "key2", "hello2") - c.Assert(hSet.Err(), IsNil) - - hVals := t.client.HVals("hash") - c.Assert(hVals.Err(), IsNil) - c.Assert(hVals.Val(), DeepEquals, []string{"hello1", "hello2"}) -} - -//------------------------------------------------------------------------------ - -func (t *RedisTest) TestCmdListsBLPop(c *C) { - rPush := t.client.RPush("list1", "a", "b", "c") - c.Assert(rPush.Err(), IsNil) - - bLPop := t.client.BLPop(0, "list1", "list2") - c.Assert(bLPop.Err(), IsNil) - c.Assert(bLPop.Val(), DeepEquals, []string{"list1", "a"}) -} - -func (t *RedisTest) TestCmdListsBLPopBlocks(c *C) { - started := make(chan bool) - done := make(chan bool) - go func() { - started <- true - bLPop := t.client.BLPop(0, "list") - c.Assert(bLPop.Err(), IsNil) - c.Assert(bLPop.Val(), DeepEquals, []string{"list", "a"}) - done <- true - }() - <-started - - select { - case <-done: - c.Error("BLPop is not blocked") - case <-time.After(time.Second): - // ok - } - - rPush := t.client.RPush("list", "a") - c.Assert(rPush.Err(), IsNil) - - select { - case <-done: - // ok - case <-time.After(time.Second): - c.Error("BLPop is still blocked") - // ok - } -} - -func (t *RedisTest) TestCmdListsBLPopTimeout(c *C) { - bLPop := t.client.BLPop(1, "list1") - c.Assert(bLPop.Err(), Equals, redis.Nil) - c.Assert(bLPop.Val(), IsNil) -} - -func (t *RedisTest) TestCmdListsBRPop(c *C) { - rPush := t.client.RPush("list1", "a", "b", "c") - c.Assert(rPush.Err(), IsNil) - - bRPop := t.client.BRPop(0, "list1", "list2") - c.Assert(bRPop.Err(), IsNil) - c.Assert(bRPop.Val(), DeepEquals, []string{"list1", "c"}) -} - -func (t *RedisTest) TestCmdListsBRPopBlocks(c *C) { - started := make(chan bool) - done := make(chan bool) - go func() { - started <- true - brpop := t.client.BRPop(0, "list") - c.Assert(brpop.Err(), IsNil) - c.Assert(brpop.Val(), DeepEquals, []string{"list", "a"}) - done <- true - }() - <-started - - select { - case <-done: - c.Error("BRPop is not blocked") - case <-time.After(time.Second): - // ok - } - - rPush := t.client.RPush("list", "a") - c.Assert(rPush.Err(), IsNil) - - select { - case <-done: - // ok - case <-time.After(time.Second): - c.Error("BRPop is still blocked") - // ok - } -} - -func (t *RedisTest) TestCmdListsBRPopLPush(c *C) { - rPush := t.client.RPush("list1", "a", "b", "c") - c.Assert(rPush.Err(), IsNil) - - bRPopLPush := t.client.BRPopLPush("list1", "list2", 0) - c.Assert(bRPopLPush.Err(), IsNil) - c.Assert(bRPopLPush.Val(), Equals, "c") -} - -func (t *RedisTest) TestCmdListsLIndex(c *C) { - lPush := t.client.LPush("list", "World") - c.Assert(lPush.Err(), IsNil) - lPush = t.client.LPush("list", "Hello") - c.Assert(lPush.Err(), IsNil) - - lIndex := t.client.LIndex("list", 0) - c.Assert(lIndex.Err(), IsNil) - c.Assert(lIndex.Val(), Equals, "Hello") - - lIndex = t.client.LIndex("list", -1) - c.Assert(lIndex.Err(), IsNil) - c.Assert(lIndex.Val(), Equals, "World") - - lIndex = t.client.LIndex("list", 3) - c.Assert(lIndex.Err(), Equals, redis.Nil) - c.Assert(lIndex.Val(), Equals, "") -} - -func (t *RedisTest) TestCmdListsLInsert(c *C) { - rPush := t.client.RPush("list", "Hello") - c.Assert(rPush.Err(), IsNil) - rPush = t.client.RPush("list", "World") - c.Assert(rPush.Err(), IsNil) - - lInsert := t.client.LInsert("list", "BEFORE", "World", "There") - c.Assert(lInsert.Err(), IsNil) - c.Assert(lInsert.Val(), Equals, int64(3)) - - lRange := t.client.LRange("list", 0, -1) - c.Assert(lRange.Err(), IsNil) - c.Assert(lRange.Val(), DeepEquals, []string{"Hello", "There", "World"}) -} - -func (t *RedisTest) TestCmdListsLLen(c *C) { - lPush := t.client.LPush("list", "World") - c.Assert(lPush.Err(), IsNil) - lPush = t.client.LPush("list", "Hello") - c.Assert(lPush.Err(), IsNil) - - lLen := t.client.LLen("list") - c.Assert(lLen.Err(), IsNil) - c.Assert(lLen.Val(), Equals, int64(2)) -} - -func (t *RedisTest) TestCmdListsLPop(c *C) { - rPush := t.client.RPush("list", "one") - c.Assert(rPush.Err(), IsNil) - rPush = t.client.RPush("list", "two") - c.Assert(rPush.Err(), IsNil) - rPush = t.client.RPush("list", "three") - c.Assert(rPush.Err(), IsNil) - - lPop := t.client.LPop("list") - c.Assert(lPop.Err(), IsNil) - c.Assert(lPop.Val(), Equals, "one") - - lRange := t.client.LRange("list", 0, -1) - c.Assert(lRange.Err(), IsNil) - c.Assert(lRange.Val(), DeepEquals, []string{"two", "three"}) -} - -func (t *RedisTest) TestCmdListsLPush(c *C) { - lPush := t.client.LPush("list", "World") - c.Assert(lPush.Err(), IsNil) - lPush = t.client.LPush("list", "Hello") - c.Assert(lPush.Err(), IsNil) - - lRange := t.client.LRange("list", 0, -1) - c.Assert(lRange.Err(), IsNil) - c.Assert(lRange.Val(), DeepEquals, []string{"Hello", "World"}) -} - -func (t *RedisTest) TestCmdListsLPushX(c *C) { - lPush := t.client.LPush("list", "World") - c.Assert(lPush.Err(), IsNil) - - lPushX := t.client.LPushX("list", "Hello") - c.Assert(lPushX.Err(), IsNil) - c.Assert(lPushX.Val(), Equals, int64(2)) - - lPushX = t.client.LPushX("list2", "Hello") - c.Assert(lPushX.Err(), IsNil) - c.Assert(lPushX.Val(), Equals, int64(0)) - - lRange := t.client.LRange("list", 0, -1) - c.Assert(lRange.Err(), IsNil) - c.Assert(lRange.Val(), DeepEquals, []string{"Hello", "World"}) - - lRange = t.client.LRange("list2", 0, -1) - c.Assert(lRange.Err(), IsNil) - c.Assert(lRange.Val(), DeepEquals, []string{}) -} - -func (t *RedisTest) TestCmdListsLRange(c *C) { - rPush := t.client.RPush("list", "one") - c.Assert(rPush.Err(), IsNil) - rPush = t.client.RPush("list", "two") - c.Assert(rPush.Err(), IsNil) - rPush = t.client.RPush("list", "three") - c.Assert(rPush.Err(), IsNil) - - lRange := t.client.LRange("list", 0, 0) - c.Assert(lRange.Err(), IsNil) - c.Assert(lRange.Val(), DeepEquals, []string{"one"}) - - lRange = t.client.LRange("list", -3, 2) - c.Assert(lRange.Err(), IsNil) - c.Assert(lRange.Val(), DeepEquals, []string{"one", "two", "three"}) - - lRange = t.client.LRange("list", -100, 100) - c.Assert(lRange.Err(), IsNil) - c.Assert(lRange.Val(), DeepEquals, []string{"one", "two", "three"}) - - lRange = t.client.LRange("list", 5, 10) - c.Assert(lRange.Err(), IsNil) - c.Assert(lRange.Val(), DeepEquals, []string{}) -} - -func (t *RedisTest) TestCmdListsLRem(c *C) { - rPush := t.client.RPush("list", "hello") - c.Assert(rPush.Err(), IsNil) - rPush = t.client.RPush("list", "hello") - c.Assert(rPush.Err(), IsNil) - rPush = t.client.RPush("list", "key") - c.Assert(rPush.Err(), IsNil) - rPush = t.client.RPush("list", "hello") - c.Assert(rPush.Err(), IsNil) - - lRem := t.client.LRem("list", -2, "hello") - c.Assert(lRem.Err(), IsNil) - c.Assert(lRem.Val(), Equals, int64(2)) - - lRange := t.client.LRange("list", 0, -1) - c.Assert(lRange.Err(), IsNil) - c.Assert(lRange.Val(), DeepEquals, []string{"hello", "key"}) -} - -func (t *RedisTest) TestCmdListsLSet(c *C) { - rPush := t.client.RPush("list", "one") - c.Assert(rPush.Err(), IsNil) - rPush = t.client.RPush("list", "two") - c.Assert(rPush.Err(), IsNil) - rPush = t.client.RPush("list", "three") - c.Assert(rPush.Err(), IsNil) - - lSet := t.client.LSet("list", 0, "four") - c.Assert(lSet.Err(), IsNil) - c.Assert(lSet.Val(), Equals, "OK") - - lSet = t.client.LSet("list", -2, "five") - c.Assert(lSet.Err(), IsNil) - c.Assert(lSet.Val(), Equals, "OK") - - lRange := t.client.LRange("list", 0, -1) - c.Assert(lRange.Err(), IsNil) - c.Assert(lRange.Val(), DeepEquals, []string{"four", "five", "three"}) -} - -func (t *RedisTest) TestCmdListsLTrim(c *C) { - rPush := t.client.RPush("list", "one") - c.Assert(rPush.Err(), IsNil) - rPush = t.client.RPush("list", "two") - c.Assert(rPush.Err(), IsNil) - rPush = t.client.RPush("list", "three") - c.Assert(rPush.Err(), IsNil) - - lTrim := t.client.LTrim("list", 1, -1) - c.Assert(lTrim.Err(), IsNil) - c.Assert(lTrim.Val(), Equals, "OK") - - lRange := t.client.LRange("list", 0, -1) - c.Assert(lRange.Err(), IsNil) - c.Assert(lRange.Val(), DeepEquals, []string{"two", "three"}) -} - -func (t *RedisTest) TestCmdListsRPop(c *C) { - rPush := t.client.RPush("list", "one") - c.Assert(rPush.Err(), IsNil) - rPush = t.client.RPush("list", "two") - c.Assert(rPush.Err(), IsNil) - rPush = t.client.RPush("list", "three") - c.Assert(rPush.Err(), IsNil) - - rPop := t.client.RPop("list") - c.Assert(rPop.Err(), IsNil) - c.Assert(rPop.Val(), Equals, "three") - - lRange := t.client.LRange("list", 0, -1) - c.Assert(lRange.Err(), IsNil) - c.Assert(lRange.Val(), DeepEquals, []string{"one", "two"}) -} - -func (t *RedisTest) TestCmdListsRPopLPush(c *C) { - rPush := t.client.RPush("list", "one") - c.Assert(rPush.Err(), IsNil) - rPush = t.client.RPush("list", "two") - c.Assert(rPush.Err(), IsNil) - rPush = t.client.RPush("list", "three") - c.Assert(rPush.Err(), IsNil) - - rPopLPush := t.client.RPopLPush("list", "list2") - c.Assert(rPopLPush.Err(), IsNil) - c.Assert(rPopLPush.Val(), Equals, "three") - - lRange := t.client.LRange("list", 0, -1) - c.Assert(lRange.Err(), IsNil) - c.Assert(lRange.Val(), DeepEquals, []string{"one", "two"}) - - lRange = t.client.LRange("list2", 0, -1) - c.Assert(lRange.Err(), IsNil) - c.Assert(lRange.Val(), DeepEquals, []string{"three"}) -} - -func (t *RedisTest) TestCmdListsRPush(c *C) { - rPush := t.client.RPush("list", "Hello") - c.Assert(rPush.Err(), IsNil) - c.Assert(rPush.Val(), Equals, int64(1)) - - rPush = t.client.RPush("list", "World") - c.Assert(rPush.Err(), IsNil) - c.Assert(rPush.Val(), Equals, int64(2)) - - lRange := t.client.LRange("list", 0, -1) - c.Assert(lRange.Err(), IsNil) - c.Assert(lRange.Val(), DeepEquals, []string{"Hello", "World"}) -} - -func (t *RedisTest) TestCmdListsRPushX(c *C) { - rPush := t.client.RPush("list", "Hello") - c.Assert(rPush.Err(), IsNil) - c.Assert(rPush.Val(), Equals, int64(1)) - - rPushX := t.client.RPushX("list", "World") - c.Assert(rPushX.Err(), IsNil) - c.Assert(rPushX.Val(), Equals, int64(2)) - - rPushX = t.client.RPushX("list2", "World") - c.Assert(rPushX.Err(), IsNil) - c.Assert(rPushX.Val(), Equals, int64(0)) - - lRange := t.client.LRange("list", 0, -1) - c.Assert(lRange.Err(), IsNil) - c.Assert(lRange.Val(), DeepEquals, []string{"Hello", "World"}) - - lRange = t.client.LRange("list2", 0, -1) - c.Assert(lRange.Err(), IsNil) - c.Assert(lRange.Val(), DeepEquals, []string{}) -} - -//------------------------------------------------------------------------------ - -func (t *RedisTest) TestSAdd(c *C) { - sAdd := t.client.SAdd("set", "Hello") - c.Assert(sAdd.Err(), IsNil) - c.Assert(sAdd.Val(), Equals, int64(1)) - - sAdd = t.client.SAdd("set", "World") - c.Assert(sAdd.Err(), IsNil) - c.Assert(sAdd.Val(), Equals, int64(1)) - - sAdd = t.client.SAdd("set", "World") - c.Assert(sAdd.Err(), IsNil) - c.Assert(sAdd.Val(), Equals, int64(0)) - - sMembers := t.client.SMembers("set") - c.Assert(sMembers.Err(), IsNil) - c.Assert(sortStrings(sMembers.Val()), DeepEquals, []string{"Hello", "World"}) -} - -func (t *RedisTest) TestSCard(c *C) { - sAdd := t.client.SAdd("set", "Hello") - c.Assert(sAdd.Err(), IsNil) - c.Assert(sAdd.Val(), Equals, int64(1)) - - sAdd = t.client.SAdd("set", "World") - c.Assert(sAdd.Err(), IsNil) - c.Assert(sAdd.Val(), Equals, int64(1)) - - sCard := t.client.SCard("set") - c.Assert(sCard.Err(), IsNil) - c.Assert(sCard.Val(), Equals, int64(2)) -} - -func (t *RedisTest) TestSDiff(c *C) { - sAdd := t.client.SAdd("set1", "a") - c.Assert(sAdd.Err(), IsNil) - sAdd = t.client.SAdd("set1", "b") - c.Assert(sAdd.Err(), IsNil) - sAdd = t.client.SAdd("set1", "c") - c.Assert(sAdd.Err(), IsNil) - - sAdd = t.client.SAdd("set2", "c") - c.Assert(sAdd.Err(), IsNil) - sAdd = t.client.SAdd("set2", "d") - c.Assert(sAdd.Err(), IsNil) - sAdd = t.client.SAdd("set2", "e") - c.Assert(sAdd.Err(), IsNil) - - sDiff := t.client.SDiff("set1", "set2") - c.Assert(sDiff.Err(), IsNil) - c.Assert(sortStrings(sDiff.Val()), DeepEquals, []string{"a", "b"}) -} - -func (t *RedisTest) TestSDiffStore(c *C) { - sAdd := t.client.SAdd("set1", "a") - c.Assert(sAdd.Err(), IsNil) - sAdd = t.client.SAdd("set1", "b") - c.Assert(sAdd.Err(), IsNil) - sAdd = t.client.SAdd("set1", "c") - c.Assert(sAdd.Err(), IsNil) - - sAdd = t.client.SAdd("set2", "c") - c.Assert(sAdd.Err(), IsNil) - sAdd = t.client.SAdd("set2", "d") - c.Assert(sAdd.Err(), IsNil) - sAdd = t.client.SAdd("set2", "e") - c.Assert(sAdd.Err(), IsNil) - - sDiffStore := t.client.SDiffStore("set", "set1", "set2") - c.Assert(sDiffStore.Err(), IsNil) - c.Assert(sDiffStore.Val(), Equals, int64(2)) - - sMembers := t.client.SMembers("set") - c.Assert(sMembers.Err(), IsNil) - c.Assert(sortStrings(sMembers.Val()), DeepEquals, []string{"a", "b"}) -} - -func (t *RedisTest) TestSInter(c *C) { - sAdd := t.client.SAdd("set1", "a") - c.Assert(sAdd.Err(), IsNil) - sAdd = t.client.SAdd("set1", "b") - c.Assert(sAdd.Err(), IsNil) - sAdd = t.client.SAdd("set1", "c") - c.Assert(sAdd.Err(), IsNil) - - sAdd = t.client.SAdd("set2", "c") - c.Assert(sAdd.Err(), IsNil) - sAdd = t.client.SAdd("set2", "d") - c.Assert(sAdd.Err(), IsNil) - sAdd = t.client.SAdd("set2", "e") - c.Assert(sAdd.Err(), IsNil) - - sInter := t.client.SInter("set1", "set2") - c.Assert(sInter.Err(), IsNil) - c.Assert(sInter.Val(), DeepEquals, []string{"c"}) -} - -func (t *RedisTest) TestSInterStore(c *C) { - sAdd := t.client.SAdd("set1", "a") - c.Assert(sAdd.Err(), IsNil) - sAdd = t.client.SAdd("set1", "b") - c.Assert(sAdd.Err(), IsNil) - sAdd = t.client.SAdd("set1", "c") - c.Assert(sAdd.Err(), IsNil) - - sAdd = t.client.SAdd("set2", "c") - c.Assert(sAdd.Err(), IsNil) - sAdd = t.client.SAdd("set2", "d") - c.Assert(sAdd.Err(), IsNil) - sAdd = t.client.SAdd("set2", "e") - c.Assert(sAdd.Err(), IsNil) - - sInterStore := t.client.SInterStore("set", "set1", "set2") - c.Assert(sInterStore.Err(), IsNil) - c.Assert(sInterStore.Val(), Equals, int64(1)) - - sMembers := t.client.SMembers("set") - c.Assert(sMembers.Err(), IsNil) - c.Assert(sMembers.Val(), DeepEquals, []string{"c"}) -} - -func (t *RedisTest) TestIsMember(c *C) { - sAdd := t.client.SAdd("set", "one") - c.Assert(sAdd.Err(), IsNil) - - sIsMember := t.client.SIsMember("set", "one") - c.Assert(sIsMember.Err(), IsNil) - c.Assert(sIsMember.Val(), Equals, true) - - sIsMember = t.client.SIsMember("set", "two") - c.Assert(sIsMember.Err(), IsNil) - c.Assert(sIsMember.Val(), Equals, false) -} - -func (t *RedisTest) TestSMembers(c *C) { - sAdd := t.client.SAdd("set", "Hello") - c.Assert(sAdd.Err(), IsNil) - sAdd = t.client.SAdd("set", "World") - c.Assert(sAdd.Err(), IsNil) - - sMembers := t.client.SMembers("set") - c.Assert(sMembers.Err(), IsNil) - c.Assert(sortStrings(sMembers.Val()), DeepEquals, []string{"Hello", "World"}) -} - -func (t *RedisTest) TestSMove(c *C) { - sAdd := t.client.SAdd("set1", "one") - c.Assert(sAdd.Err(), IsNil) - sAdd = t.client.SAdd("set1", "two") - c.Assert(sAdd.Err(), IsNil) - - sAdd = t.client.SAdd("set2", "three") - c.Assert(sAdd.Err(), IsNil) - - sMove := t.client.SMove("set1", "set2", "two") - c.Assert(sMove.Err(), IsNil) - c.Assert(sMove.Val(), Equals, true) - - sMembers := t.client.SMembers("set1") - c.Assert(sMembers.Err(), IsNil) - c.Assert(sMembers.Val(), DeepEquals, []string{"one"}) - - sMembers = t.client.SMembers("set2") - c.Assert(sMembers.Err(), IsNil) - c.Assert(sortStrings(sMembers.Val()), DeepEquals, []string{"three", "two"}) -} - -func (t *RedisTest) TestSPop(c *C) { - sAdd := t.client.SAdd("set", "one") - c.Assert(sAdd.Err(), IsNil) - sAdd = t.client.SAdd("set", "two") - c.Assert(sAdd.Err(), IsNil) - sAdd = t.client.SAdd("set", "three") - c.Assert(sAdd.Err(), IsNil) - - sPop := t.client.SPop("set") - c.Assert(sPop.Err(), IsNil) - c.Assert(sPop.Val(), Not(Equals), "") - - sMembers := t.client.SMembers("set") - c.Assert(sMembers.Err(), IsNil) - c.Assert(sMembers.Val(), HasLen, 2) -} - -func (t *RedisTest) TestSRandMember(c *C) { - sAdd := t.client.SAdd("set", "one") - c.Assert(sAdd.Err(), IsNil) - sAdd = t.client.SAdd("set", "two") - c.Assert(sAdd.Err(), IsNil) - sAdd = t.client.SAdd("set", "three") - c.Assert(sAdd.Err(), IsNil) - - sRandMember := t.client.SRandMember("set") - c.Assert(sRandMember.Err(), IsNil) - c.Assert(sRandMember.Val(), Not(Equals), "") - - sMembers := t.client.SMembers("set") - c.Assert(sMembers.Err(), IsNil) - c.Assert(sMembers.Val(), HasLen, 3) -} - -func (t *RedisTest) TestSRem(c *C) { - sAdd := t.client.SAdd("set", "one") - c.Assert(sAdd.Err(), IsNil) - sAdd = t.client.SAdd("set", "two") - c.Assert(sAdd.Err(), IsNil) - sAdd = t.client.SAdd("set", "three") - c.Assert(sAdd.Err(), IsNil) - - sRem := t.client.SRem("set", "one") - c.Assert(sRem.Err(), IsNil) - c.Assert(sRem.Val(), Equals, int64(1)) - - sRem = t.client.SRem("set", "four") - c.Assert(sRem.Err(), IsNil) - c.Assert(sRem.Val(), Equals, int64(0)) - - sMembers := t.client.SMembers("set") - c.Assert(sMembers.Err(), IsNil) - c.Assert( - sortStrings(sMembers.Val()), - DeepEquals, - []string{"three", "two"}, - ) -} - -func (t *RedisTest) TestSUnion(c *C) { - sAdd := t.client.SAdd("set1", "a") - c.Assert(sAdd.Err(), IsNil) - sAdd = t.client.SAdd("set1", "b") - c.Assert(sAdd.Err(), IsNil) - sAdd = t.client.SAdd("set1", "c") - c.Assert(sAdd.Err(), IsNil) - - sAdd = t.client.SAdd("set2", "c") - c.Assert(sAdd.Err(), IsNil) - sAdd = t.client.SAdd("set2", "d") - c.Assert(sAdd.Err(), IsNil) - sAdd = t.client.SAdd("set2", "e") - c.Assert(sAdd.Err(), IsNil) - - sUnion := t.client.SUnion("set1", "set2") - c.Assert(sUnion.Err(), IsNil) - c.Assert(sUnion.Val(), HasLen, 5) -} - -func (t *RedisTest) TestSUnionStore(c *C) { - sAdd := t.client.SAdd("set1", "a") - c.Assert(sAdd.Err(), IsNil) - sAdd = t.client.SAdd("set1", "b") - c.Assert(sAdd.Err(), IsNil) - sAdd = t.client.SAdd("set1", "c") - c.Assert(sAdd.Err(), IsNil) - - sAdd = t.client.SAdd("set2", "c") - c.Assert(sAdd.Err(), IsNil) - sAdd = t.client.SAdd("set2", "d") - c.Assert(sAdd.Err(), IsNil) - sAdd = t.client.SAdd("set2", "e") - c.Assert(sAdd.Err(), IsNil) - - sUnionStore := t.client.SUnionStore("set", "set1", "set2") - c.Assert(sUnionStore.Err(), IsNil) - c.Assert(sUnionStore.Val(), Equals, int64(5)) - - sMembers := t.client.SMembers("set") - c.Assert(sMembers.Err(), IsNil) - c.Assert(sMembers.Val(), HasLen, 5) -} - -//------------------------------------------------------------------------------ - -func (t *RedisTest) TestZAdd(c *C) { - zAdd := t.client.ZAdd("zset", redis.Z{1, "one"}) - c.Assert(zAdd.Err(), IsNil) - c.Assert(zAdd.Val(), Equals, int64(1)) - - zAdd = t.client.ZAdd("zset", redis.Z{1, "uno"}) - c.Assert(zAdd.Err(), IsNil) - c.Assert(zAdd.Val(), Equals, int64(1)) - - zAdd = t.client.ZAdd("zset", redis.Z{2, "two"}) - c.Assert(zAdd.Err(), IsNil) - c.Assert(zAdd.Val(), Equals, int64(1)) - - zAdd = t.client.ZAdd("zset", redis.Z{3, "two"}) - c.Assert(zAdd.Err(), IsNil) - c.Assert(zAdd.Val(), Equals, int64(0)) - - val, err := t.client.ZRangeWithScores("zset", 0, -1).Result() - c.Assert(err, IsNil) - c.Assert(val, DeepEquals, []redis.Z{{1, "one"}, {1, "uno"}, {3, "two"}}) -} - -func (t *RedisTest) TestZCard(c *C) { - zAdd := t.client.ZAdd("zset", redis.Z{1, "one"}) - c.Assert(zAdd.Err(), IsNil) - zAdd = t.client.ZAdd("zset", redis.Z{2, "two"}) - c.Assert(zAdd.Err(), IsNil) - - zCard := t.client.ZCard("zset") - c.Assert(zCard.Err(), IsNil) - c.Assert(zCard.Val(), Equals, int64(2)) -} - -func (t *RedisTest) TestZCount(c *C) { - zAdd := t.client.ZAdd("zset", redis.Z{1, "one"}) - c.Assert(zAdd.Err(), IsNil) - zAdd = t.client.ZAdd("zset", redis.Z{2, "two"}) - c.Assert(zAdd.Err(), IsNil) - zAdd = t.client.ZAdd("zset", redis.Z{3, "three"}) - c.Assert(zAdd.Err(), IsNil) - - zCount := t.client.ZCount("zset", "-inf", "+inf") - c.Assert(zCount.Err(), IsNil) - c.Assert(zCount.Val(), Equals, int64(3)) - - zCount = t.client.ZCount("zset", "(1", "3") - c.Assert(zCount.Err(), IsNil) - c.Assert(zCount.Val(), Equals, int64(2)) -} - -func (t *RedisTest) TestZIncrBy(c *C) { - zAdd := t.client.ZAdd("zset", redis.Z{1, "one"}) - c.Assert(zAdd.Err(), IsNil) - zAdd = t.client.ZAdd("zset", redis.Z{2, "two"}) - c.Assert(zAdd.Err(), IsNil) - - zIncrBy := t.client.ZIncrBy("zset", 2, "one") - c.Assert(zIncrBy.Err(), IsNil) - c.Assert(zIncrBy.Val(), Equals, float64(3)) - - val, err := t.client.ZRangeWithScores("zset", 0, -1).Result() - c.Assert(err, IsNil) - c.Assert(val, DeepEquals, []redis.Z{{2, "two"}, {3, "one"}}) -} - -func (t *RedisTest) TestZInterStore(c *C) { - zAdd := t.client.ZAdd("zset1", redis.Z{1, "one"}) - c.Assert(zAdd.Err(), IsNil) - zAdd = t.client.ZAdd("zset1", redis.Z{2, "two"}) - c.Assert(zAdd.Err(), IsNil) - - zAdd = t.client.ZAdd("zset2", redis.Z{1, "one"}) - c.Assert(zAdd.Err(), IsNil) - zAdd = t.client.ZAdd("zset2", redis.Z{2, "two"}) - c.Assert(zAdd.Err(), IsNil) - zAdd = t.client.ZAdd("zset3", redis.Z{3, "two"}) - c.Assert(zAdd.Err(), IsNil) - - zInterStore := t.client.ZInterStore( - "out", redis.ZStore{Weights: []int64{2, 3}}, "zset1", "zset2") - c.Assert(zInterStore.Err(), IsNil) - c.Assert(zInterStore.Val(), Equals, int64(2)) - - val, err := t.client.ZRangeWithScores("out", 0, -1).Result() - c.Assert(err, IsNil) - c.Assert(val, DeepEquals, []redis.Z{{5, "one"}, {10, "two"}}) -} - -func (t *RedisTest) TestZRange(c *C) { - zAdd := t.client.ZAdd("zset", redis.Z{1, "one"}) - c.Assert(zAdd.Err(), IsNil) - zAdd = t.client.ZAdd("zset", redis.Z{2, "two"}) - c.Assert(zAdd.Err(), IsNil) - zAdd = t.client.ZAdd("zset", redis.Z{3, "three"}) - c.Assert(zAdd.Err(), IsNil) - - zRange := t.client.ZRange("zset", 0, -1) - c.Assert(zRange.Err(), IsNil) - c.Assert(zRange.Val(), DeepEquals, []string{"one", "two", "three"}) - - zRange = t.client.ZRange("zset", 2, 3) - c.Assert(zRange.Err(), IsNil) - c.Assert(zRange.Val(), DeepEquals, []string{"three"}) - - zRange = t.client.ZRange("zset", -2, -1) - c.Assert(zRange.Err(), IsNil) - c.Assert(zRange.Val(), DeepEquals, []string{"two", "three"}) -} - -func (t *RedisTest) TestZRangeWithScores(c *C) { - zAdd := t.client.ZAdd("zset", redis.Z{1, "one"}) - c.Assert(zAdd.Err(), IsNil) - zAdd = t.client.ZAdd("zset", redis.Z{2, "two"}) - c.Assert(zAdd.Err(), IsNil) - zAdd = t.client.ZAdd("zset", redis.Z{3, "three"}) - c.Assert(zAdd.Err(), IsNil) - - val, err := t.client.ZRangeWithScores("zset", 0, -1).Result() - c.Assert(err, IsNil) - c.Assert(val, DeepEquals, []redis.Z{{1, "one"}, {2, "two"}, {3, "three"}}) - - val, err = t.client.ZRangeWithScores("zset", 2, 3).Result() - c.Assert(err, IsNil) - c.Assert(val, DeepEquals, []redis.Z{{3, "three"}}) - - val, err = t.client.ZRangeWithScores("zset", -2, -1).Result() - c.Assert(err, IsNil) - c.Assert(val, DeepEquals, []redis.Z{{2, "two"}, {3, "three"}}) -} - -func (t *RedisTest) TestZRangeByScore(c *C) { - zAdd := t.client.ZAdd("zset", redis.Z{1, "one"}) - c.Assert(zAdd.Err(), IsNil) - zAdd = t.client.ZAdd("zset", redis.Z{2, "two"}) - c.Assert(zAdd.Err(), IsNil) - zAdd = t.client.ZAdd("zset", redis.Z{3, "three"}) - c.Assert(zAdd.Err(), IsNil) - - zRangeByScore := t.client.ZRangeByScore("zset", redis.ZRangeByScore{ - Min: "-inf", - Max: "+inf", - }) - c.Assert(zRangeByScore.Err(), IsNil) - c.Assert(zRangeByScore.Val(), DeepEquals, []string{"one", "two", "three"}) - - zRangeByScore = t.client.ZRangeByScore("zset", redis.ZRangeByScore{ - Min: "1", - Max: "2", - }) - c.Assert(zRangeByScore.Err(), IsNil) - c.Assert(zRangeByScore.Val(), DeepEquals, []string{"one", "two"}) - - zRangeByScore = t.client.ZRangeByScore("zset", redis.ZRangeByScore{ - Min: "(1", - Max: "2", - }) - c.Assert(zRangeByScore.Err(), IsNil) - c.Assert(zRangeByScore.Val(), DeepEquals, []string{"two"}) - - zRangeByScore = t.client.ZRangeByScore("zset", redis.ZRangeByScore{ - Min: "(1", - Max: "(2", - }) - c.Assert(zRangeByScore.Err(), IsNil) - c.Assert(zRangeByScore.Val(), DeepEquals, []string{}) -} - -func (t *RedisTest) TestZRangeByScoreWithScoresMap(c *C) { - zAdd := t.client.ZAdd("zset", redis.Z{1, "one"}) - c.Assert(zAdd.Err(), IsNil) - zAdd = t.client.ZAdd("zset", redis.Z{2, "two"}) - c.Assert(zAdd.Err(), IsNil) - zAdd = t.client.ZAdd("zset", redis.Z{3, "three"}) - c.Assert(zAdd.Err(), IsNil) - - val, err := t.client.ZRangeByScoreWithScores("zset", redis.ZRangeByScore{ - Min: "-inf", - Max: "+inf", - }).Result() - c.Assert(err, IsNil) - c.Assert(val, DeepEquals, []redis.Z{{1, "one"}, {2, "two"}, {3, "three"}}) - - val, err = t.client.ZRangeByScoreWithScores("zset", redis.ZRangeByScore{ - Min: "1", - Max: "2", - }).Result() - c.Assert(err, IsNil) - c.Assert(val, DeepEquals, []redis.Z{{1, "one"}, {2, "two"}}) - - val, err = t.client.ZRangeByScoreWithScores("zset", redis.ZRangeByScore{ - Min: "(1", - Max: "2", - }).Result() - c.Assert(err, IsNil) - c.Assert(val, DeepEquals, []redis.Z{{2, "two"}}) - - val, err = t.client.ZRangeByScoreWithScores("zset", redis.ZRangeByScore{ - Min: "(1", - Max: "(2", - }).Result() - c.Assert(err, IsNil) - c.Assert(val, DeepEquals, []redis.Z{}) -} - -func (t *RedisTest) TestZRank(c *C) { - zAdd := t.client.ZAdd("zset", redis.Z{1, "one"}) - c.Assert(zAdd.Err(), IsNil) - zAdd = t.client.ZAdd("zset", redis.Z{2, "two"}) - c.Assert(zAdd.Err(), IsNil) - zAdd = t.client.ZAdd("zset", redis.Z{3, "three"}) - c.Assert(zAdd.Err(), IsNil) - - zRank := t.client.ZRank("zset", "three") - c.Assert(zRank.Err(), IsNil) - c.Assert(zRank.Val(), Equals, int64(2)) - - zRank = t.client.ZRank("zset", "four") - c.Assert(zRank.Err(), Equals, redis.Nil) - c.Assert(zRank.Val(), Equals, int64(0)) -} - -func (t *RedisTest) TestZRem(c *C) { - zAdd := t.client.ZAdd("zset", redis.Z{1, "one"}) - c.Assert(zAdd.Err(), IsNil) - zAdd = t.client.ZAdd("zset", redis.Z{2, "two"}) - c.Assert(zAdd.Err(), IsNil) - zAdd = t.client.ZAdd("zset", redis.Z{3, "three"}) - c.Assert(zAdd.Err(), IsNil) - - zRem := t.client.ZRem("zset", "two") - c.Assert(zRem.Err(), IsNil) - c.Assert(zRem.Val(), Equals, int64(1)) - - val, err := t.client.ZRangeWithScores("zset", 0, -1).Result() - c.Assert(err, IsNil) - c.Assert(val, DeepEquals, []redis.Z{{1, "one"}, {3, "three"}}) -} - -func (t *RedisTest) TestZRemRangeByRank(c *C) { - zAdd := t.client.ZAdd("zset", redis.Z{1, "one"}) - c.Assert(zAdd.Err(), IsNil) - zAdd = t.client.ZAdd("zset", redis.Z{2, "two"}) - c.Assert(zAdd.Err(), IsNil) - zAdd = t.client.ZAdd("zset", redis.Z{3, "three"}) - c.Assert(zAdd.Err(), IsNil) - - zRemRangeByRank := t.client.ZRemRangeByRank("zset", 0, 1) - c.Assert(zRemRangeByRank.Err(), IsNil) - c.Assert(zRemRangeByRank.Val(), Equals, int64(2)) - - val, err := t.client.ZRangeWithScores("zset", 0, -1).Result() - c.Assert(err, IsNil) - c.Assert(val, DeepEquals, []redis.Z{{3, "three"}}) -} - -func (t *RedisTest) TestZRemRangeByScore(c *C) { - zAdd := t.client.ZAdd("zset", redis.Z{1, "one"}) - c.Assert(zAdd.Err(), IsNil) - zAdd = t.client.ZAdd("zset", redis.Z{2, "two"}) - c.Assert(zAdd.Err(), IsNil) - zAdd = t.client.ZAdd("zset", redis.Z{3, "three"}) - c.Assert(zAdd.Err(), IsNil) - - zRemRangeByScore := t.client.ZRemRangeByScore("zset", "-inf", "(2") - c.Assert(zRemRangeByScore.Err(), IsNil) - c.Assert(zRemRangeByScore.Val(), Equals, int64(1)) - - val, err := t.client.ZRangeWithScores("zset", 0, -1).Result() - c.Assert(err, IsNil) - c.Assert(val, DeepEquals, []redis.Z{{2, "two"}, {3, "three"}}) -} - -func (t *RedisTest) TestZRevRange(c *C) { - zAdd := t.client.ZAdd("zset", redis.Z{1, "one"}) - c.Assert(zAdd.Err(), IsNil) - zAdd = t.client.ZAdd("zset", redis.Z{2, "two"}) - c.Assert(zAdd.Err(), IsNil) - zAdd = t.client.ZAdd("zset", redis.Z{3, "three"}) - c.Assert(zAdd.Err(), IsNil) - - zRevRange := t.client.ZRevRange("zset", "0", "-1") - c.Assert(zRevRange.Err(), IsNil) - c.Assert(zRevRange.Val(), DeepEquals, []string{"three", "two", "one"}) - - zRevRange = t.client.ZRevRange("zset", "2", "3") - c.Assert(zRevRange.Err(), IsNil) - c.Assert(zRevRange.Val(), DeepEquals, []string{"one"}) - - zRevRange = t.client.ZRevRange("zset", "-2", "-1") - c.Assert(zRevRange.Err(), IsNil) - c.Assert(zRevRange.Val(), DeepEquals, []string{"two", "one"}) -} - -func (t *RedisTest) TestZRevRangeWithScoresMap(c *C) { - zAdd := t.client.ZAdd("zset", redis.Z{1, "one"}) - c.Assert(zAdd.Err(), IsNil) - zAdd = t.client.ZAdd("zset", redis.Z{2, "two"}) - c.Assert(zAdd.Err(), IsNil) - zAdd = t.client.ZAdd("zset", redis.Z{3, "three"}) - c.Assert(zAdd.Err(), IsNil) - - val, err := t.client.ZRevRangeWithScores("zset", "0", "-1").Result() - c.Assert(err, IsNil) - c.Assert(val, DeepEquals, []redis.Z{{3, "three"}, {2, "two"}, {1, "one"}}) - - val, err = t.client.ZRevRangeWithScores("zset", "2", "3").Result() - c.Assert(err, IsNil) - c.Assert(val, DeepEquals, []redis.Z{{1, "one"}}) - - val, err = t.client.ZRevRangeWithScores("zset", "-2", "-1").Result() - c.Assert(err, IsNil) - c.Assert(val, DeepEquals, []redis.Z{{2, "two"}, {1, "one"}}) -} - -func (t *RedisTest) TestZRevRangeByScore(c *C) { - zadd := t.client.ZAdd("zset", redis.Z{1, "one"}) - c.Assert(zadd.Err(), IsNil) - zadd = t.client.ZAdd("zset", redis.Z{2, "two"}) - c.Assert(zadd.Err(), IsNil) - zadd = t.client.ZAdd("zset", redis.Z{3, "three"}) - c.Assert(zadd.Err(), IsNil) - - vals, err := t.client.ZRevRangeByScore( - "zset", redis.ZRangeByScore{Max: "+inf", Min: "-inf"}).Result() - c.Assert(err, IsNil) - c.Assert(vals, DeepEquals, []string{"three", "two", "one"}) - - vals, err = t.client.ZRevRangeByScore( - "zset", redis.ZRangeByScore{Max: "2", Min: "(1"}).Result() - c.Assert(err, IsNil) - c.Assert(vals, DeepEquals, []string{"two"}) - - vals, err = t.client.ZRevRangeByScore( - "zset", redis.ZRangeByScore{Max: "(2", Min: "(1"}).Result() - c.Assert(err, IsNil) - c.Assert(vals, DeepEquals, []string{}) -} - -func (t *RedisTest) TestZRevRangeByScoreWithScores(c *C) { - zadd := t.client.ZAdd("zset", redis.Z{1, "one"}) - c.Assert(zadd.Err(), IsNil) - zadd = t.client.ZAdd("zset", redis.Z{2, "two"}) - c.Assert(zadd.Err(), IsNil) - zadd = t.client.ZAdd("zset", redis.Z{3, "three"}) - c.Assert(zadd.Err(), IsNil) - - vals, err := t.client.ZRevRangeByScoreWithScores( - "zset", redis.ZRangeByScore{Max: "+inf", Min: "-inf"}).Result() - c.Assert(err, IsNil) - c.Assert(vals, DeepEquals, []redis.Z{{3, "three"}, {2, "two"}, {1, "one"}}) -} - -func (t *RedisTest) TestZRevRangeByScoreWithScoresMap(c *C) { - zAdd := t.client.ZAdd("zset", redis.Z{1, "one"}) - c.Assert(zAdd.Err(), IsNil) - zAdd = t.client.ZAdd("zset", redis.Z{2, "two"}) - c.Assert(zAdd.Err(), IsNil) - zAdd = t.client.ZAdd("zset", redis.Z{3, "three"}) - c.Assert(zAdd.Err(), IsNil) - - val, err := t.client.ZRevRangeByScoreWithScores( - "zset", redis.ZRangeByScore{Max: "+inf", Min: "-inf"}).Result() - c.Assert(err, IsNil) - c.Assert(val, DeepEquals, []redis.Z{{3, "three"}, {2, "two"}, {1, "one"}}) - - val, err = t.client.ZRevRangeByScoreWithScores( - "zset", redis.ZRangeByScore{Max: "2", Min: "(1"}).Result() - c.Assert(err, IsNil) - c.Assert(val, DeepEquals, []redis.Z{{2, "two"}}) - - val, err = t.client.ZRevRangeByScoreWithScores( - "zset", redis.ZRangeByScore{Max: "(2", Min: "(1"}).Result() - c.Assert(err, IsNil) - c.Assert(val, DeepEquals, []redis.Z{}) -} - -func (t *RedisTest) TestZRevRank(c *C) { - zAdd := t.client.ZAdd("zset", redis.Z{1, "one"}) - c.Assert(zAdd.Err(), IsNil) - zAdd = t.client.ZAdd("zset", redis.Z{2, "two"}) - c.Assert(zAdd.Err(), IsNil) - zAdd = t.client.ZAdd("zset", redis.Z{3, "three"}) - c.Assert(zAdd.Err(), IsNil) - - zRevRank := t.client.ZRevRank("zset", "one") - c.Assert(zRevRank.Err(), IsNil) - c.Assert(zRevRank.Val(), Equals, int64(2)) - - zRevRank = t.client.ZRevRank("zset", "four") - c.Assert(zRevRank.Err(), Equals, redis.Nil) - c.Assert(zRevRank.Val(), Equals, int64(0)) -} - -func (t *RedisTest) TestZScore(c *C) { - zAdd := t.client.ZAdd("zset", redis.Z{1.001, "one"}) - c.Assert(zAdd.Err(), IsNil) - - zScore := t.client.ZScore("zset", "one") - c.Assert(zScore.Err(), IsNil) - c.Assert(zScore.Val(), Equals, float64(1.001)) -} - -func (t *RedisTest) TestZUnionStore(c *C) { - zAdd := t.client.ZAdd("zset1", redis.Z{1, "one"}) - c.Assert(zAdd.Err(), IsNil) - zAdd = t.client.ZAdd("zset1", redis.Z{2, "two"}) - c.Assert(zAdd.Err(), IsNil) - - zAdd = t.client.ZAdd("zset2", redis.Z{1, "one"}) - c.Assert(zAdd.Err(), IsNil) - zAdd = t.client.ZAdd("zset2", redis.Z{2, "two"}) - c.Assert(zAdd.Err(), IsNil) - zAdd = t.client.ZAdd("zset2", redis.Z{3, "three"}) - c.Assert(zAdd.Err(), IsNil) - - zUnionStore := t.client.ZUnionStore( - "out", redis.ZStore{Weights: []int64{2, 3}}, "zset1", "zset2") - c.Assert(zUnionStore.Err(), IsNil) - c.Assert(zUnionStore.Val(), Equals, int64(3)) - - val, err := t.client.ZRangeWithScores("out", 0, -1).Result() - c.Assert(err, IsNil) - c.Assert(val, DeepEquals, []redis.Z{{5, "one"}, {9, "three"}, {10, "two"}}) -} - -//------------------------------------------------------------------------------ - -func (t *RedisTest) TestPatternPubSub(c *C) { - pubsub := t.client.PubSub() - defer func() { - c.Assert(pubsub.Close(), IsNil) - }() - - c.Assert(pubsub.PSubscribe("mychannel*"), IsNil) - - pub := t.client.Publish("mychannel1", "hello") - c.Assert(pub.Err(), IsNil) - c.Assert(pub.Val(), Equals, int64(1)) - - c.Assert(pubsub.PUnsubscribe("mychannel*"), IsNil) - - { - msgi, err := pubsub.ReceiveTimeout(time.Second) - c.Assert(err, IsNil) - subscr := msgi.(*redis.Subscription) - c.Assert(subscr.Kind, Equals, "psubscribe") - c.Assert(subscr.Channel, Equals, "mychannel*") - c.Assert(subscr.Count, Equals, 1) - } - - { - msgi, err := pubsub.ReceiveTimeout(time.Second) - c.Assert(err, IsNil) - subscr := msgi.(*redis.PMessage) - c.Assert(subscr.Channel, Equals, "mychannel1") - c.Assert(subscr.Pattern, Equals, "mychannel*") - c.Assert(subscr.Payload, Equals, "hello") - } - - { - msgi, err := pubsub.ReceiveTimeout(time.Second) - c.Assert(err, IsNil) - subscr := msgi.(*redis.Subscription) - c.Assert(subscr.Kind, Equals, "punsubscribe") - c.Assert(subscr.Channel, Equals, "mychannel*") - c.Assert(subscr.Count, Equals, 0) - } - - { - msgi, err := pubsub.ReceiveTimeout(time.Second) - c.Assert(err.(net.Error).Timeout(), Equals, true) - c.Assert(msgi, IsNil) - } -} - -func (t *RedisTest) TestPubSub(c *C) { - pubsub := t.client.PubSub() - defer func() { - c.Assert(pubsub.Close(), IsNil) - }() - - c.Assert(pubsub.Subscribe("mychannel", "mychannel2"), IsNil) - - pub := t.client.Publish("mychannel", "hello") - c.Assert(pub.Err(), IsNil) - c.Assert(pub.Val(), Equals, int64(1)) - - pub = t.client.Publish("mychannel2", "hello2") - c.Assert(pub.Err(), IsNil) - c.Assert(pub.Val(), Equals, int64(1)) - - c.Assert(pubsub.Unsubscribe("mychannel", "mychannel2"), IsNil) - - { - msgi, err := pubsub.ReceiveTimeout(time.Second) - c.Assert(err, IsNil) - subscr := msgi.(*redis.Subscription) - c.Assert(subscr.Kind, Equals, "subscribe") - c.Assert(subscr.Channel, Equals, "mychannel") - c.Assert(subscr.Count, Equals, 1) - } - - { - msgi, err := pubsub.ReceiveTimeout(time.Second) - c.Assert(err, IsNil) - subscr := msgi.(*redis.Subscription) - c.Assert(subscr.Kind, Equals, "subscribe") - c.Assert(subscr.Channel, Equals, "mychannel2") - c.Assert(subscr.Count, Equals, 2) - } - - { - msgi, err := pubsub.ReceiveTimeout(time.Second) - c.Assert(err, IsNil) - subscr := msgi.(*redis.Message) - c.Assert(subscr.Channel, Equals, "mychannel") - c.Assert(subscr.Payload, Equals, "hello") - } - - { - msgi, err := pubsub.ReceiveTimeout(time.Second) - c.Assert(err, IsNil) - msg := msgi.(*redis.Message) - c.Assert(msg.Channel, Equals, "mychannel2") - c.Assert(msg.Payload, Equals, "hello2") - } - - { - msgi, err := pubsub.ReceiveTimeout(time.Second) - c.Assert(err, IsNil) - subscr := msgi.(*redis.Subscription) - c.Assert(subscr.Kind, Equals, "unsubscribe") - c.Assert(subscr.Channel, Equals, "mychannel") - c.Assert(subscr.Count, Equals, 1) - } - - { - msgi, err := pubsub.ReceiveTimeout(time.Second) - c.Assert(err, IsNil) - subscr := msgi.(*redis.Subscription) - c.Assert(subscr.Kind, Equals, "unsubscribe") - c.Assert(subscr.Channel, Equals, "mychannel2") - c.Assert(subscr.Count, Equals, 0) - } - - { - msgi, err := pubsub.ReceiveTimeout(time.Second) - c.Assert(err.(net.Error).Timeout(), Equals, true) - c.Assert(msgi, IsNil) - } -} - -func (t *RedisTest) TestPubSubChannels(c *C) { - channels, err := t.client.PubSubChannels("mychannel*").Result() - c.Assert(err, IsNil) - c.Assert(channels, HasLen, 0) - c.Assert(channels, Not(IsNil)) - - pubsub := t.client.PubSub() - defer pubsub.Close() - - c.Assert(pubsub.Subscribe("mychannel", "mychannel2"), IsNil) - - channels, err = t.client.PubSubChannels("mychannel*").Result() - c.Assert(err, IsNil) - c.Assert(sortStrings(channels), DeepEquals, []string{"mychannel", "mychannel2"}) - - channels, err = t.client.PubSubChannels("").Result() - c.Assert(err, IsNil) - c.Assert(channels, HasLen, 0) - - channels, err = t.client.PubSubChannels("*").Result() - c.Assert(err, IsNil) - c.Assert(len(channels) >= 2, Equals, true) -} - -func (t *RedisTest) TestPubSubNumSub(c *C) { - pubsub := t.client.PubSub() - defer pubsub.Close() - - c.Assert(pubsub.Subscribe("mychannel", "mychannel2"), IsNil) - - channels, err := t.client.PubSubNumSub("mychannel", "mychannel2", "mychannel3").Result() - c.Assert(err, IsNil) - c.Assert( - channels, - DeepEquals, - []interface{}{"mychannel", int64(1), "mychannel2", int64(1), "mychannel3", int64(0)}, - ) -} - -func (t *RedisTest) TestPubSubNumPat(c *C) { - num, err := t.client.PubSubNumPat().Result() - c.Assert(err, IsNil) - c.Assert(num, Equals, int64(0)) - - pubsub := t.client.PubSub() - defer pubsub.Close() - - c.Assert(pubsub.PSubscribe("mychannel*"), IsNil) - - num, err = t.client.PubSubNumPat().Result() - c.Assert(err, IsNil) - c.Assert(num, Equals, int64(1)) -} - -//------------------------------------------------------------------------------ - -func (t *RedisTest) TestPipeline(c *C) { - set := t.client.Set("key2", "hello2") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - pipeline := t.client.Pipeline() - defer func() { - c.Assert(pipeline.Close(), IsNil) - }() - - set = pipeline.Set("key1", "hello1") - get := pipeline.Get("key2") - incr := pipeline.Incr("key3") - getNil := pipeline.Get("key4") - - cmds, err := pipeline.Exec() - c.Assert(err, Equals, redis.Nil) - c.Assert(cmds, HasLen, 4) - - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - c.Assert(get.Err(), IsNil) - c.Assert(get.Val(), Equals, "hello2") - - c.Assert(incr.Err(), IsNil) - c.Assert(incr.Val(), Equals, int64(1)) - - c.Assert(getNil.Err(), Equals, redis.Nil) - c.Assert(getNil.Val(), Equals, "") -} - -func (t *RedisTest) TestPipelineDiscardQueued(c *C) { - pipeline := t.client.Pipeline() - - pipeline.Get("key") - pipeline.Discard() - cmds, err := pipeline.Exec() - c.Assert(err, IsNil) - c.Assert(cmds, HasLen, 0) - - c.Assert(pipeline.Close(), IsNil) -} - -func (t *RedisTest) TestPipelined(c *C) { - var get *redis.StringCmd - cmds, err := t.client.Pipelined(func(pipe *redis.Pipeline) error { - get = pipe.Get("foo") - return nil - }) - c.Assert(err, Equals, redis.Nil) - c.Assert(cmds, HasLen, 1) - c.Assert(cmds[0], Equals, get) - c.Assert(get.Err(), Equals, redis.Nil) - c.Assert(get.Val(), Equals, "") -} - -func (t *RedisTest) TestPipelineErrValNotSet(c *C) { - pipeline := t.client.Pipeline() - defer func() { - c.Assert(pipeline.Close(), IsNil) - }() - - get := pipeline.Get("key") - c.Assert(get.Err(), IsNil) - c.Assert(get.Val(), Equals, "") -} - -func (t *RedisTest) TestPipelineRunQueuedOnEmptyQueue(c *C) { - pipeline := t.client.Pipeline() - defer func() { - c.Assert(pipeline.Close(), IsNil) - }() - - cmds, err := pipeline.Exec() - c.Assert(err, IsNil) - c.Assert(cmds, HasLen, 0) -} - -// TODO: make thread safe? -func (t *RedisTest) TestPipelineIncr(c *C) { - const N = 20000 - key := "TestPipelineIncr" - - pipeline := t.client.Pipeline() - - wg := &sync.WaitGroup{} - wg.Add(N) - for i := 0; i < N; i++ { - pipeline.Incr(key) - wg.Done() - } - wg.Wait() - - cmds, err := pipeline.Exec() - c.Assert(err, IsNil) - c.Assert(len(cmds), Equals, 20000) - for _, cmd := range cmds { - if cmd.Err() != nil { - c.Errorf("got %v, expected nil", cmd.Err()) - } - } - - get := t.client.Get(key) - c.Assert(get.Err(), IsNil) - c.Assert(get.Val(), Equals, strconv.Itoa(N)) - - c.Assert(pipeline.Close(), IsNil) -} - -func (t *RedisTest) TestPipelineEcho(c *C) { - const N = 1000 - - wg := &sync.WaitGroup{} - wg.Add(N) - for i := 0; i < N; i++ { - go func(i int) { - pipeline := t.client.Pipeline() - - msg1 := "echo" + strconv.Itoa(i) - msg2 := "echo" + strconv.Itoa(i+1) - - echo1 := pipeline.Echo(msg1) - echo2 := pipeline.Echo(msg2) - - cmds, err := pipeline.Exec() - c.Assert(err, IsNil) - c.Assert(cmds, HasLen, 2) - - c.Assert(echo1.Err(), IsNil) - c.Assert(echo1.Val(), Equals, msg1) - - c.Assert(echo2.Err(), IsNil) - c.Assert(echo2.Val(), Equals, msg2) - - c.Assert(pipeline.Close(), IsNil) - - wg.Done() - }(i) - } - wg.Wait() -} - -//------------------------------------------------------------------------------ - -func (t *RedisTest) TestMultiExec(c *C) { - multi := t.client.Multi() - defer func() { - c.Assert(multi.Close(), IsNil) - }() - - var ( - set *redis.StatusCmd - get *redis.StringCmd - ) - cmds, err := multi.Exec(func() error { - set = multi.Set("key", "hello") - get = multi.Get("key") - return nil - }) - c.Assert(err, IsNil) - c.Assert(cmds, HasLen, 2) - - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - c.Assert(get.Err(), IsNil) - c.Assert(get.Val(), Equals, "hello") -} - -func (t *RedisTest) TestMultiExecDiscard(c *C) { - multi := t.client.Multi() - defer func() { - c.Assert(multi.Close(), IsNil) - }() - - cmds, err := multi.Exec(func() error { - multi.Set("key1", "hello1") - multi.Discard() - multi.Set("key2", "hello2") - return nil - }) - c.Assert(err, IsNil) - c.Assert(cmds, HasLen, 1) - - get := t.client.Get("key1") - c.Assert(get.Err(), Equals, redis.Nil) - c.Assert(get.Val(), Equals, "") - - get = t.client.Get("key2") - c.Assert(get.Err(), IsNil) - c.Assert(get.Val(), Equals, "hello2") -} - -func (t *RedisTest) TestMultiExecEmpty(c *C) { - multi := t.client.Multi() - defer func() { - c.Assert(multi.Close(), IsNil) - }() - - cmds, err := multi.Exec(func() error { return nil }) - c.Assert(err, IsNil) - c.Assert(cmds, HasLen, 0) - - ping := multi.Ping() - c.Check(ping.Err(), IsNil) - c.Check(ping.Val(), Equals, "PONG") -} - -func (t *RedisTest) TestMultiExecOnEmptyQueue(c *C) { - multi := t.client.Multi() - defer func() { - c.Assert(multi.Close(), IsNil) - }() - - cmds, err := multi.Exec(func() error { return nil }) - c.Assert(err, IsNil) - c.Assert(cmds, HasLen, 0) -} - -func (t *RedisTest) TestMultiExecIncr(c *C) { - multi := t.client.Multi() - defer func() { - c.Assert(multi.Close(), IsNil) - }() - - cmds, err := multi.Exec(func() error { - for i := int64(0); i < 20000; i++ { - multi.Incr("key") - } - return nil - }) - c.Assert(err, IsNil) - c.Assert(len(cmds), Equals, 20000) - for _, cmd := range cmds { - if cmd.Err() != nil { - c.Errorf("got %v, expected nil", cmd.Err()) - } - } - - get := t.client.Get("key") - c.Assert(get.Err(), IsNil) - c.Assert(get.Val(), Equals, "20000") -} - -func (t *RedisTest) transactionalIncr(c *C) ([]redis.Cmder, error) { - multi := t.client.Multi() - defer func() { - c.Assert(multi.Close(), IsNil) - }() - - watch := multi.Watch("key") - c.Assert(watch.Err(), IsNil) - c.Assert(watch.Val(), Equals, "OK") - - get := multi.Get("key") - c.Assert(get.Err(), IsNil) - c.Assert(get.Val(), Not(Equals), redis.Nil) - - v, err := strconv.ParseInt(get.Val(), 10, 64) - c.Assert(err, IsNil) - - return multi.Exec(func() error { - multi.Set("key", strconv.FormatInt(v+1, 10)) - return nil - }) -} - -func (t *RedisTest) TestWatchUnwatch(c *C) { - var n = 10000 - if testing.Short() { - n = 1000 - } - - set := t.client.Set("key", "0") - c.Assert(set.Err(), IsNil) - - wg := &sync.WaitGroup{} - for i := 0; i < n; i++ { - wg.Add(1) - go func() { - defer wg.Done() - for { - cmds, err := t.transactionalIncr(c) - if err == redis.TxFailedErr { - continue - } - c.Assert(err, IsNil) - c.Assert(cmds, HasLen, 1) - c.Assert(cmds[0].Err(), IsNil) - break - } - }() - } - wg.Wait() - - val, err := t.client.Get("key").Int64() - c.Assert(err, IsNil) - c.Assert(val, Equals, int64(n)) -} - -//------------------------------------------------------------------------------ - -func (t *RedisTest) TestRaceEcho(c *C) { - var n = 10000 - if testing.Short() { - n = 1000 - } - - wg := &sync.WaitGroup{} - wg.Add(n) - for i := 0; i < n; i++ { - go func(i int) { - msg := "echo" + strconv.Itoa(i) - echo := t.client.Echo(msg) - c.Assert(echo.Err(), IsNil) - c.Assert(echo.Val(), Equals, msg) - wg.Done() - }(i) - } - wg.Wait() -} - -func (t *RedisTest) TestRaceIncr(c *C) { - var n = 10000 - if testing.Short() { - n = 1000 - } - - wg := &sync.WaitGroup{} - wg.Add(n) - for i := 0; i < n; i++ { - go func() { - incr := t.client.Incr("TestRaceIncr") - if err := incr.Err(); err != nil { - panic(err) - } - wg.Done() - }() - } - wg.Wait() - - val, err := t.client.Get("TestRaceIncr").Result() - c.Assert(err, IsNil) - c.Assert(val, Equals, strconv.Itoa(n)) -} - -//------------------------------------------------------------------------------ - -func (t *RedisTest) TestCmdBgRewriteAOF(c *C) { - r := t.client.BgRewriteAOF() - c.Assert(r.Err(), IsNil) - c.Assert(r.Val(), Equals, "Background append only file rewriting started") -} - -func (t *RedisTest) TestCmdBgSave(c *C) { - // workaround for "ERR Can't BGSAVE while AOF log rewriting is in progress" - time.Sleep(time.Second) - - r := t.client.BgSave() - c.Assert(r.Err(), IsNil) - c.Assert(r.Val(), Equals, "Background saving started") -} - -func (t *RedisTest) TestCmdClientKill(c *C) { - r := t.client.ClientKill("1.1.1.1:1111") - c.Assert(r.Err(), ErrorMatches, "ERR No such client") - c.Assert(r.Val(), Equals, "") -} - -func (t *RedisTest) TestCmdConfigGet(c *C) { - r := t.client.ConfigGet("*") - c.Assert(r.Err(), IsNil) - c.Assert(len(r.Val()) > 0, Equals, true) -} - -func (t *RedisTest) TestCmdConfigResetStat(c *C) { - r := t.client.ConfigResetStat() - c.Assert(r.Err(), IsNil) - c.Assert(r.Val(), Equals, "OK") -} - -func (t *RedisTest) TestCmdConfigSet(c *C) { - configGet := t.client.ConfigGet("maxmemory") - c.Assert(configGet.Err(), IsNil) - c.Assert(configGet.Val(), HasLen, 2) - c.Assert(configGet.Val()[0], Equals, "maxmemory") - - configSet := t.client.ConfigSet("maxmemory", configGet.Val()[1].(string)) - c.Assert(configSet.Err(), IsNil) - c.Assert(configSet.Val(), Equals, "OK") -} - -func (t *RedisTest) TestCmdDbSize(c *C) { - dbSize := t.client.DbSize() - c.Assert(dbSize.Err(), IsNil) - c.Assert(dbSize.Val(), Equals, int64(0)) -} - -func (t *RedisTest) TestCmdFlushAll(c *C) { - // TODO -} - -func (t *RedisTest) TestCmdFlushDb(c *C) { - // TODO -} - -func (t *RedisTest) TestCmdInfo(c *C) { - info := t.client.Info() - c.Assert(info.Err(), IsNil) - c.Assert(info.Val(), Not(Equals), "") -} - -func (t *RedisTest) TestCmdLastSave(c *C) { - lastSave := t.client.LastSave() - c.Assert(lastSave.Err(), IsNil) - c.Assert(lastSave.Val(), Not(Equals), 0) -} - -func (t *RedisTest) TestCmdSave(c *C) { - save := t.client.Save() - c.Assert(save.Err(), IsNil) - c.Assert(save.Val(), Equals, "OK") -} - -func (t *RedisTest) TestSlaveOf(c *C) { - slaveOf := t.client.SlaveOf("localhost", "8888") - c.Assert(slaveOf.Err(), IsNil) - c.Assert(slaveOf.Val(), Equals, "OK") - - slaveOf = t.client.SlaveOf("NO", "ONE") - c.Assert(slaveOf.Err(), IsNil) - c.Assert(slaveOf.Val(), Equals, "OK") -} - -func (t *RedisTest) TestTime(c *C) { - time := t.client.Time() - c.Assert(time.Err(), IsNil) - c.Assert(time.Val(), HasLen, 2) -} - -//------------------------------------------------------------------------------ - -func (t *RedisTest) TestScriptingEval(c *C) { - eval := t.client.Eval( - "return {KEYS[1],KEYS[2],ARGV[1],ARGV[2]}", - []string{"key1", "key2"}, - []string{"first", "second"}, - ) - c.Assert(eval.Err(), IsNil) - c.Assert(eval.Val(), DeepEquals, []interface{}{"key1", "key2", "first", "second"}) - - eval = t.client.Eval( - "return redis.call('set',KEYS[1],'bar')", - []string{"foo"}, - []string{}, - ) - c.Assert(eval.Err(), IsNil) - c.Assert(eval.Val(), Equals, "OK") - - eval = t.client.Eval("return 10", []string{}, []string{}) - c.Assert(eval.Err(), IsNil) - c.Assert(eval.Val(), Equals, int64(10)) - - eval = t.client.Eval("return {1,2,{3,'Hello World!'}}", []string{}, []string{}) - c.Assert(eval.Err(), IsNil) - // DeepEquals can't compare nested slices. - c.Assert( - fmt.Sprintf("%#v", eval.Val()), - Equals, - `[]interface {}{1, 2, []interface {}{3, "Hello World!"}}`, - ) -} - -func (t *RedisTest) TestScriptingEvalSha(c *C) { - set := t.client.Set("foo", "bar") - c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") - - eval := t.client.Eval("return redis.call('get','foo')", nil, nil) - c.Assert(eval.Err(), IsNil) - c.Assert(eval.Val(), Equals, "bar") - - evalSha := t.client.EvalSha("6b1bf486c81ceb7edf3c093f4c48582e38c0e791", nil, nil) - c.Assert(evalSha.Err(), IsNil) - c.Assert(evalSha.Val(), Equals, "bar") - - evalSha = t.client.EvalSha("ffffffffffffffffffffffffffffffffffffffff", nil, nil) - c.Assert(evalSha.Err(), ErrorMatches, "NOSCRIPT No matching script. Please use EVAL.") - c.Assert(evalSha.Val(), Equals, nil) -} - -func (t *RedisTest) TestScriptingScriptExists(c *C) { - scriptLoad := t.client.ScriptLoad("return 1") - c.Assert(scriptLoad.Err(), IsNil) - c.Assert(scriptLoad.Val(), Equals, "e0e1f9fabfc9d4800c877a703b823ac0578ff8db") - - scriptExists := t.client.ScriptExists( - "e0e1f9fabfc9d4800c877a703b823ac0578ff8db", - "ffffffffffffffffffffffffffffffffffffffff", - ) - c.Assert(scriptExists.Err(), IsNil) - c.Assert(scriptExists.Val(), DeepEquals, []bool{true, false}) -} - -func (t *RedisTest) TestScriptingScriptFlush(c *C) { - scriptFlush := t.client.ScriptFlush() - c.Assert(scriptFlush.Err(), IsNil) - c.Assert(scriptFlush.Val(), Equals, "OK") -} - -func (t *RedisTest) TestScriptingScriptKill(c *C) { - scriptKill := t.client.ScriptKill() - c.Assert(scriptKill.Err(), ErrorMatches, ".*No scripts in execution right now.") - c.Assert(scriptKill.Val(), Equals, "") -} - -func (t *RedisTest) TestScriptingScriptLoad(c *C) { - scriptLoad := t.client.ScriptLoad("return redis.call('get','foo')") - c.Assert(scriptLoad.Err(), IsNil) - c.Assert(scriptLoad.Val(), Equals, "6b1bf486c81ceb7edf3c093f4c48582e38c0e791") -} - -func (t *RedisTest) TestScriptingNewScript(c *C) { - s := redis.NewScript("return 1") - run := s.Run(t.client, nil, nil) - c.Assert(run.Err(), IsNil) - c.Assert(run.Val(), Equals, int64(1)) -} - -func (t *RedisTest) TestScriptingEvalAndPipeline(c *C) { - pipeline := t.client.Pipeline() - s := redis.NewScript("return 1") - run := s.Eval(pipeline, nil, nil) - _, err := pipeline.Exec() - c.Assert(err, IsNil) - c.Assert(run.Err(), IsNil) - c.Assert(run.Val(), Equals, int64(1)) -} - -func (t *RedisTest) TestScriptingEvalShaAndPipeline(c *C) { - s := redis.NewScript("return 1") - c.Assert(s.Load(t.client).Err(), IsNil) - - pipeline := t.client.Pipeline() - run := s.Eval(pipeline, nil, nil) - _, err := pipeline.Exec() - c.Assert(err, IsNil) - c.Assert(run.Err(), IsNil) - c.Assert(run.Val(), Equals, int64(1)) -} - -//------------------------------------------------------------------------------ - -func (t *RedisTest) TestCmdDebugObject(c *C) { - { - debug := t.client.DebugObject("foo") - c.Assert(debug.Err(), Not(IsNil)) - c.Assert(debug.Err().Error(), Equals, "ERR no such key") - } - - { - t.client.Set("foo", "bar") - debug := t.client.DebugObject("foo") - c.Assert(debug.Err(), IsNil) - c.Assert(debug.Val(), FitsTypeOf, "") - c.Assert(debug.Val(), Not(Equals), "") - } -} - -//------------------------------------------------------------------------------ - -func BenchmarkRedisPing(b *testing.B) { - b.StopTimer() - client := redis.NewTCPClient(&redis.Options{ - Addr: redisAddr, - }) - b.StartTimer() - - for i := 0; i < b.N; i++ { - if err := client.Ping().Err(); err != nil { - panic(err) - } - } -} - -func BenchmarkRedisSet(b *testing.B) { - b.StopTimer() - client := redis.NewTCPClient(&redis.Options{ - Addr: redisAddr, - }) - b.StartTimer() - - for i := 0; i < b.N; i++ { - if err := client.Set("key", "hello").Err(); err != nil { - panic(err) - } - } -} - -func BenchmarkRedisGetNil(b *testing.B) { - b.StopTimer() - client := redis.NewTCPClient(&redis.Options{ - Addr: redisAddr, - }) - if err := client.FlushDb().Err(); err != nil { - b.Fatal(err) - } - b.StartTimer() - - for i := 0; i < b.N; i++ { - if err := client.Get("key").Err(); err != redis.Nil { - b.Fatal(err) - } - } -} - -func BenchmarkRedisGet(b *testing.B) { - b.StopTimer() - client := redis.NewTCPClient(&redis.Options{ - Addr: redisAddr, - }) - if err := client.Set("key", "hello").Err(); err != nil { - b.Fatal(err) - } - b.StartTimer() - - for i := 0; i < b.N; i++ { - if err := client.Get("key").Err(); err != nil { - b.Fatal(err) - } - } -} - -func BenchmarkRedisMGet(b *testing.B) { - b.StopTimer() - client := redis.NewTCPClient(&redis.Options{ - Addr: redisAddr, - }) - if err := client.MSet("key1", "hello1", "key2", "hello2").Err(); err != nil { - b.Fatal(err) - } - b.StartTimer() - - for i := 0; i < b.N; i++ { - if err := client.MGet("key1", "key2").Err(); err != nil { - b.Fatal(err) - } - } -} - -func BenchmarkSetExpire(b *testing.B) { - b.StopTimer() - client := redis.NewTCPClient(&redis.Options{ - Addr: redisAddr, - }) - b.StartTimer() - - for i := 0; i < b.N; i++ { - if err := client.Set("key", "hello").Err(); err != nil { - b.Fatal(err) - } - if err := client.Expire("key", time.Second).Err(); err != nil { - b.Fatal(err) - } - } -} - -func BenchmarkPipeline(b *testing.B) { - b.StopTimer() - client := redis.NewTCPClient(&redis.Options{ - Addr: redisAddr, - }) - b.StartTimer() - - for i := 0; i < b.N; i++ { - _, err := client.Pipelined(func(pipe *redis.Pipeline) error { - pipe.Set("key", "hello") - pipe.Expire("key", time.Second) - return nil - }) - if err != nil { - b.Fatal(err) - } - } -} diff --git a/Godeps/_workspace/src/gopkg.in/redis.v2/sentinel_test.go b/Godeps/_workspace/src/gopkg.in/redis.v2/sentinel_test.go deleted file mode 100644 index ede59bd..0000000 --- a/Godeps/_workspace/src/gopkg.in/redis.v2/sentinel_test.go +++ /dev/null @@ -1,185 +0,0 @@ -package redis_test - -import ( - "io/ioutil" - "os" - "os/exec" - "path/filepath" - "testing" - "text/template" - "time" - - "gopkg.in/redis.v2" -) - -func startRedis(port string) (*exec.Cmd, error) { - cmd := exec.Command("redis-server", "--port", port) - if false { - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - } - if err := cmd.Start(); err != nil { - return nil, err - } - return cmd, nil -} - -func startRedisSlave(port, slave string) (*exec.Cmd, error) { - cmd := exec.Command("redis-server", "--port", port, "--slaveof", "127.0.0.1", slave) - if false { - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - } - if err := cmd.Start(); err != nil { - return nil, err - } - return cmd, nil -} - -func startRedisSentinel(port, masterName, masterPort string) (*exec.Cmd, error) { - dir, err := ioutil.TempDir("", "sentinel") - if err != nil { - return nil, err - } - - sentinelConfFilepath := filepath.Join(dir, "sentinel.conf") - tpl, err := template.New("sentinel.conf").Parse(sentinelConf) - if err != nil { - return nil, err - } - - data := struct { - Port string - MasterName string - MasterPort string - }{ - Port: port, - MasterName: masterName, - MasterPort: masterPort, - } - if err := writeTemplateToFile(sentinelConfFilepath, tpl, data); err != nil { - return nil, err - } - - cmd := exec.Command("redis-server", sentinelConfFilepath, "--sentinel") - if true { - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - } - if err := cmd.Start(); err != nil { - return nil, err - } - - return cmd, nil -} - -func writeTemplateToFile(path string, t *template.Template, data interface{}) error { - f, err := os.Create(path) - if err != nil { - return err - } - defer f.Close() - return t.Execute(f, data) -} - -func TestSentinel(t *testing.T) { - masterName := "mymaster" - masterPort := "8123" - slavePort := "8124" - sentinelPort := "8125" - - masterCmd, err := startRedis(masterPort) - if err != nil { - t.Fatal(err) - } - defer masterCmd.Process.Kill() - - // Wait for master to start. - time.Sleep(200 * time.Millisecond) - - master := redis.NewTCPClient(&redis.Options{ - Addr: ":" + masterPort, - }) - if err := master.Ping().Err(); err != nil { - t.Fatal(err) - } - - slaveCmd, err := startRedisSlave(slavePort, masterPort) - if err != nil { - t.Fatal(err) - } - defer slaveCmd.Process.Kill() - - // Wait for slave to start. - time.Sleep(200 * time.Millisecond) - - slave := redis.NewTCPClient(&redis.Options{ - Addr: ":" + slavePort, - }) - if err := slave.Ping().Err(); err != nil { - t.Fatal(err) - } - - sentinelCmd, err := startRedisSentinel(sentinelPort, masterName, masterPort) - if err != nil { - t.Fatal(err) - } - defer sentinelCmd.Process.Kill() - - // Wait for sentinel to start. - time.Sleep(200 * time.Millisecond) - - sentinel := redis.NewTCPClient(&redis.Options{ - Addr: ":" + sentinelPort, - }) - if err := sentinel.Ping().Err(); err != nil { - t.Fatal(err) - } - defer sentinel.Shutdown() - - client := redis.NewFailoverClient(&redis.FailoverOptions{ - MasterName: masterName, - SentinelAddrs: []string{":" + sentinelPort}, - }) - - if err := client.Set("foo", "master").Err(); err != nil { - t.Fatal(err) - } - - val, err := master.Get("foo").Result() - if err != nil { - t.Fatal(err) - } - if val != "master" { - t.Fatalf(`got %q, expected "master"`, val) - } - - // Kill Redis master. - if err := masterCmd.Process.Kill(); err != nil { - t.Fatal(err) - } - if err := master.Ping().Err(); err == nil { - t.Fatalf("master was not killed") - } - - // Wait for Redis sentinel to elect new master. - time.Sleep(5 * time.Second) - - // Check that client picked up new master. - val, err = client.Get("foo").Result() - if err != nil { - t.Fatal(err) - } - if val != "master" { - t.Fatalf(`got %q, expected "master"`, val) - } -} - -var sentinelConf = ` -port {{ .Port }} - -sentinel monitor {{ .MasterName }} 127.0.0.1 {{ .MasterPort }} 1 -sentinel down-after-milliseconds {{ .MasterName }} 1000 -sentinel failover-timeout {{ .MasterName }} 2000 -sentinel parallel-syncs {{ .MasterName }} 1 -` diff --git a/Godeps/_workspace/src/gopkg.in/redis.v2/testdata/sentinel.conf b/Godeps/_workspace/src/gopkg.in/redis.v2/testdata/sentinel.conf deleted file mode 100644 index 3da90b3..0000000 --- a/Godeps/_workspace/src/gopkg.in/redis.v2/testdata/sentinel.conf +++ /dev/null @@ -1,6 +0,0 @@ -port 26379 - -sentinel monitor master 127.0.0.1 6379 1 -sentinel down-after-milliseconds master 2000 -sentinel failover-timeout master 5000 -sentinel parallel-syncs master 4 diff --git a/Godeps/_workspace/src/gopkg.in/redis.v3/.gitignore b/Godeps/_workspace/src/gopkg.in/redis.v3/.gitignore new file mode 100644 index 0000000..5959942 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/redis.v3/.gitignore @@ -0,0 +1,2 @@ +*.rdb +.test/ diff --git a/Godeps/_workspace/src/gopkg.in/redis.v3/.travis.yml b/Godeps/_workspace/src/gopkg.in/redis.v3/.travis.yml new file mode 100644 index 0000000..169ccd0 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/redis.v3/.travis.yml @@ -0,0 +1,17 @@ +language: go + +services: +- redis-server + +go: + - 1.3 + - 1.4 + +install: + - go get gopkg.in/bufio.v1 + - go get gopkg.in/bsm/ratelimit.v1 + - go get github.com/onsi/ginkgo + - go get github.com/onsi/gomega + - mkdir -p $HOME/gopath/src/gopkg.in + - mv $HOME/gopath/src/github.com/go-redis/redis $HOME/gopath/src/gopkg.in/redis.v3 + - cd $HOME/gopath/src/gopkg.in/redis.v3 diff --git a/Godeps/_workspace/src/gopkg.in/redis.v2/LICENSE b/Godeps/_workspace/src/gopkg.in/redis.v3/LICENSE similarity index 100% rename from Godeps/_workspace/src/gopkg.in/redis.v2/LICENSE rename to Godeps/_workspace/src/gopkg.in/redis.v3/LICENSE diff --git a/Godeps/_workspace/src/gopkg.in/redis.v3/Makefile b/Godeps/_workspace/src/gopkg.in/redis.v3/Makefile new file mode 100644 index 0000000..d3763d6 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/redis.v3/Makefile @@ -0,0 +1,17 @@ +all: testdeps + go test ./... -v=1 -cpu=1,2,4 + go test ./... -short -race + +test: testdeps + go test ./... -v=1 + +testdeps: .test/redis/src/redis-server + +.PHONY: all test testdeps + +.test/redis: + mkdir -p $@ + wget -qO- https://github.com/antirez/redis/archive/3.0.tar.gz | tar xvz --strip-components=1 -C $@ + +.test/redis/src/redis-server: .test/redis + cd $< && make all diff --git a/Godeps/_workspace/src/gopkg.in/redis.v3/README.md b/Godeps/_workspace/src/gopkg.in/redis.v3/README.md new file mode 100644 index 0000000..d554582 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/redis.v3/README.md @@ -0,0 +1,95 @@ +Redis client for Golang [![Build Status](https://travis-ci.org/go-redis/redis.png?branch=master)](https://travis-ci.org/go-redis/redis) +======================= + +Supports: + +- Redis 3 commands except QUIT, MONITOR, SLOWLOG and SYNC. +- [Pub/Sub](http://godoc.org/gopkg.in/redis.v3#PubSub). +- [Transactions](http://godoc.org/gopkg.in/redis.v3#Multi). +- [Pipelining](http://godoc.org/gopkg.in/redis.v3#Client.Pipeline). +- [Scripting](http://godoc.org/gopkg.in/redis.v3#Script). +- [Timeouts](http://godoc.org/gopkg.in/redis.v3#Options). +- [Redis Sentinel](http://godoc.org/gopkg.in/redis.v3#NewFailoverClient). +- [Redis Cluster](http://godoc.org/gopkg.in/redis.v3#NewClusterClient). +- [Ring](http://godoc.org/gopkg.in/redis.v3#NewRing). + +API docs: http://godoc.org/gopkg.in/redis.v3. +Examples: http://godoc.org/gopkg.in/redis.v3#pkg-examples. + +Installation +------------ + +Install: + + go get gopkg.in/redis.v3 + +Quickstart +---------- + +```go +func ExampleNewClient() { + client := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + Password: "", // no password set + DB: 0, // use default DB + }) + + pong, err := client.Ping().Result() + fmt.Println(pong, err) + // Output: PONG +} + +func ExampleClient() { + err := client.Set("key", "value", 0).Err() + if err != nil { + panic(err) + } + + val, err := client.Get("key").Result() + if err != nil { + panic(err) + } + fmt.Println("key", val) + + val2, err := client.Get("key2").Result() + if err == redis.Nil { + fmt.Println("key2 does not exists") + } else if err != nil { + panic(err) + } else { + fmt.Println("key2", val2) + } + // Output: key value + // key2 does not exists +} +``` + +Howto +----- + +Please go through [examples](http://godoc.org/gopkg.in/redis.v3#pkg-examples) to get an idea how to use this package. + +Look and feel +------------- + +Some corner cases: + + SET key value EX 10 NX + set, err := client.SetNX("key", "value", 10*time.Second).Result() + + SORT list LIMIT 0 2 ASC + vals, err := client.Sort("list", redis.Sort{Offset: 0, Count: 2, Order: "ASC"}).Result() + + ZRANGEBYSCORE zset -inf +inf WITHSCORES LIMIT 0 2 + vals, err := client.ZRangeByScoreWithScores("zset", redis.ZRangeByScore{ + Min: "-inf", + Max: "+inf", + Offset: 0, + Count: 2, + }).Result() + + ZINTERSTORE out 2 zset1 zset2 WEIGHTS 2 3 AGGREGATE SUM + vals, err := client.ZInterStore("out", redis.ZStore{Weights: []int64{2, 3}}, "zset1", "zset2").Result() + + EVAL "return {KEYS[1],ARGV[1]}" 1 "key" "hello" + vals, err := client.Eval("return {KEYS[1],ARGV[1]}", []string{"key"}, []string{"hello"}).Result() diff --git a/Godeps/_workspace/src/gopkg.in/redis.v3/cluster.go b/Godeps/_workspace/src/gopkg.in/redis.v3/cluster.go new file mode 100644 index 0000000..cbf00b2 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/redis.v3/cluster.go @@ -0,0 +1,343 @@ +package redis + +import ( + "log" + "math/rand" + "strings" + "sync" + "sync/atomic" + "time" +) + +type ClusterClient struct { + commandable + + addrs []string + slots [][]string + slotsMx sync.RWMutex // Protects slots and addrs. + + clients map[string]*Client + closed bool + clientsMx sync.RWMutex // Protects clients and closed. + + opt *ClusterOptions + + // Reports where slots reloading is in progress. + reloading uint32 +} + +// NewClusterClient returns a new Redis Cluster client as described in +// http://redis.io/topics/cluster-spec. +func NewClusterClient(opt *ClusterOptions) *ClusterClient { + client := &ClusterClient{ + addrs: opt.Addrs, + slots: make([][]string, hashSlots), + clients: make(map[string]*Client), + opt: opt, + } + client.commandable.process = client.process + client.reloadSlots() + go client.reaper() + return client +} + +// Close closes the cluster client, releasing any open resources. +// +// It is rare to Close a Client, as the Client is meant to be +// long-lived and shared between many goroutines. +func (c *ClusterClient) Close() error { + defer c.clientsMx.Unlock() + c.clientsMx.Lock() + + if c.closed { + return nil + } + c.closed = true + c.resetClients() + c.setSlots(nil) + return nil +} + +// getClient returns a Client for a given address. +func (c *ClusterClient) getClient(addr string) (*Client, error) { + if addr == "" { + return c.randomClient() + } + + c.clientsMx.RLock() + client, ok := c.clients[addr] + if ok { + c.clientsMx.RUnlock() + return client, nil + } + c.clientsMx.RUnlock() + + c.clientsMx.Lock() + if c.closed { + c.clientsMx.Unlock() + return nil, errClosed + } + + client, ok = c.clients[addr] + if !ok { + opt := c.opt.clientOptions() + opt.Addr = addr + client = NewClient(opt) + c.clients[addr] = client + } + c.clientsMx.Unlock() + + return client, nil +} + +func (c *ClusterClient) slotAddrs(slot int) []string { + c.slotsMx.RLock() + addrs := c.slots[slot] + c.slotsMx.RUnlock() + return addrs +} + +func (c *ClusterClient) slotMasterAddr(slot int) string { + addrs := c.slotAddrs(slot) + if len(addrs) > 0 { + return addrs[0] + } + return "" +} + +// randomClient returns a Client for the first live node. +func (c *ClusterClient) randomClient() (client *Client, err error) { + for i := 0; i < 10; i++ { + n := rand.Intn(len(c.addrs)) + client, err = c.getClient(c.addrs[n]) + if err != nil { + continue + } + err = client.ClusterInfo().Err() + if err == nil { + return client, nil + } + } + return nil, err +} + +func (c *ClusterClient) process(cmd Cmder) { + var ask bool + + slot := hashSlot(cmd.clusterKey()) + + addr := c.slotMasterAddr(slot) + client, err := c.getClient(addr) + if err != nil { + cmd.setErr(err) + return + } + + for attempt := 0; attempt <= c.opt.getMaxRedirects(); attempt++ { + if attempt > 0 { + cmd.reset() + } + + if ask { + pipe := client.Pipeline() + pipe.Process(NewCmd("ASKING")) + pipe.Process(cmd) + _, _ = pipe.Exec() + ask = false + } else { + client.Process(cmd) + } + + // If there is no (real) error, we are done! + err := cmd.Err() + if err == nil || err == Nil || err == TxFailedErr { + return + } + + // On network errors try random node. + if isNetworkError(err) { + client, err = c.randomClient() + if err != nil { + return + } + continue + } + + var moved bool + var addr string + moved, ask, addr = isMovedError(err) + if moved || ask { + if moved && c.slotMasterAddr(slot) != addr { + c.lazyReloadSlots() + } + client, err = c.getClient(addr) + if err != nil { + return + } + continue + } + + break + } +} + +// Closes all clients and returns last error if there are any. +func (c *ClusterClient) resetClients() (err error) { + for addr, client := range c.clients { + if e := client.Close(); e != nil { + err = e + } + delete(c.clients, addr) + } + return err +} + +func (c *ClusterClient) setSlots(slots []ClusterSlotInfo) { + c.slotsMx.Lock() + + seen := make(map[string]struct{}) + for _, addr := range c.addrs { + seen[addr] = struct{}{} + } + + for i := 0; i < hashSlots; i++ { + c.slots[i] = c.slots[i][:0] + } + for _, info := range slots { + for slot := info.Start; slot <= info.End; slot++ { + c.slots[slot] = info.Addrs + } + + for _, addr := range info.Addrs { + if _, ok := seen[addr]; !ok { + c.addrs = append(c.addrs, addr) + seen[addr] = struct{}{} + } + } + } + + c.slotsMx.Unlock() +} + +func (c *ClusterClient) reloadSlots() { + defer atomic.StoreUint32(&c.reloading, 0) + + client, err := c.randomClient() + if err != nil { + log.Printf("redis: randomClient failed: %s", err) + return + } + + slots, err := client.ClusterSlots().Result() + if err != nil { + log.Printf("redis: ClusterSlots failed: %s", err) + return + } + c.setSlots(slots) +} + +func (c *ClusterClient) lazyReloadSlots() { + if !atomic.CompareAndSwapUint32(&c.reloading, 0, 1) { + return + } + go c.reloadSlots() +} + +// reaper closes idle connections to the cluster. +func (c *ClusterClient) reaper() { + ticker := time.NewTicker(time.Minute) + defer ticker.Stop() + for _ = range ticker.C { + c.clientsMx.RLock() + + if c.closed { + c.clientsMx.RUnlock() + break + } + + for _, client := range c.clients { + pool := client.connPool + // pool.First removes idle connections from the pool and + // returns first non-idle connection. So just put returned + // connection back. + if cn := pool.First(); cn != nil { + pool.Put(cn) + } + } + + c.clientsMx.RUnlock() + } +} + +//------------------------------------------------------------------------------ + +// ClusterOptions are used to configure a cluster client and should be +// passed to NewClusterClient. +type ClusterOptions struct { + // A seed list of host:port addresses of cluster nodes. + Addrs []string + + // The maximum number of MOVED/ASK redirects to follow before + // giving up. + // Default is 16 + MaxRedirects int + + // Following options are copied from Options struct. + + Password string + + DialTimeout time.Duration + ReadTimeout time.Duration + WriteTimeout time.Duration + + PoolSize int + PoolTimeout time.Duration + IdleTimeout time.Duration +} + +func (opt *ClusterOptions) getMaxRedirects() int { + if opt.MaxRedirects == -1 { + return 0 + } + if opt.MaxRedirects == 0 { + return 16 + } + return opt.MaxRedirects +} + +func (opt *ClusterOptions) clientOptions() *Options { + return &Options{ + Password: opt.Password, + + DialTimeout: opt.DialTimeout, + ReadTimeout: opt.ReadTimeout, + WriteTimeout: opt.WriteTimeout, + + PoolSize: opt.PoolSize, + PoolTimeout: opt.PoolTimeout, + IdleTimeout: opt.IdleTimeout, + } +} + +//------------------------------------------------------------------------------ + +const hashSlots = 16384 + +func hashKey(key string) string { + if s := strings.IndexByte(key, '{'); s > -1 { + if e := strings.IndexByte(key[s+1:], '}'); e > 0 { + return key[s+1 : s+e+1] + } + } + return key +} + +// hashSlot returns a consistent slot number between 0 and 16383 +// for any given string key. +func hashSlot(key string) int { + key = hashKey(key) + if key == "" { + return rand.Intn(hashSlots) + } + return int(crc16sum(key)) % hashSlots +} diff --git a/Godeps/_workspace/src/gopkg.in/redis.v3/cluster_client_test.go b/Godeps/_workspace/src/gopkg.in/redis.v3/cluster_client_test.go new file mode 100644 index 0000000..c7f695d --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/redis.v3/cluster_client_test.go @@ -0,0 +1,81 @@ +package redis + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func (c *ClusterClient) SlotAddrs(slot int) []string { + return c.slotAddrs(slot) +} + +// SwapSlot swaps a slot's master/slave address +// for testing MOVED redirects +func (c *ClusterClient) SwapSlot(pos int) []string { + c.slotsMx.Lock() + defer c.slotsMx.Unlock() + c.slots[pos][0], c.slots[pos][1] = c.slots[pos][1], c.slots[pos][0] + return c.slots[pos] +} + +var _ = Describe("ClusterClient", func() { + var subject *ClusterClient + + var populate = func() { + subject.setSlots([]ClusterSlotInfo{ + {0, 4095, []string{"127.0.0.1:7000", "127.0.0.1:7004"}}, + {12288, 16383, []string{"127.0.0.1:7003", "127.0.0.1:7007"}}, + {4096, 8191, []string{"127.0.0.1:7001", "127.0.0.1:7005"}}, + {8192, 12287, []string{"127.0.0.1:7002", "127.0.0.1:7006"}}, + }) + } + + BeforeEach(func() { + subject = NewClusterClient(&ClusterOptions{ + Addrs: []string{"127.0.0.1:6379", "127.0.0.1:7003", "127.0.0.1:7006"}, + }) + }) + + AfterEach(func() { + subject.Close() + }) + + It("should initialize", func() { + Expect(subject.addrs).To(HaveLen(3)) + Expect(subject.slots).To(HaveLen(16384)) + }) + + It("should update slots cache", func() { + populate() + Expect(subject.slots[0]).To(Equal([]string{"127.0.0.1:7000", "127.0.0.1:7004"})) + Expect(subject.slots[4095]).To(Equal([]string{"127.0.0.1:7000", "127.0.0.1:7004"})) + Expect(subject.slots[4096]).To(Equal([]string{"127.0.0.1:7001", "127.0.0.1:7005"})) + Expect(subject.slots[8191]).To(Equal([]string{"127.0.0.1:7001", "127.0.0.1:7005"})) + Expect(subject.slots[8192]).To(Equal([]string{"127.0.0.1:7002", "127.0.0.1:7006"})) + Expect(subject.slots[12287]).To(Equal([]string{"127.0.0.1:7002", "127.0.0.1:7006"})) + Expect(subject.slots[12288]).To(Equal([]string{"127.0.0.1:7003", "127.0.0.1:7007"})) + Expect(subject.slots[16383]).To(Equal([]string{"127.0.0.1:7003", "127.0.0.1:7007"})) + Expect(subject.addrs).To(Equal([]string{ + "127.0.0.1:6379", + "127.0.0.1:7003", + "127.0.0.1:7006", + "127.0.0.1:7000", + "127.0.0.1:7004", + "127.0.0.1:7007", + "127.0.0.1:7001", + "127.0.0.1:7005", + "127.0.0.1:7002", + })) + }) + + It("should close", func() { + populate() + Expect(subject.Close()).NotTo(HaveOccurred()) + Expect(subject.clients).To(BeEmpty()) + Expect(subject.slots[0]).To(BeEmpty()) + Expect(subject.slots[8191]).To(BeEmpty()) + Expect(subject.slots[8192]).To(BeEmpty()) + Expect(subject.slots[16383]).To(BeEmpty()) + Expect(subject.Ping().Err().Error()).To(Equal("redis: client is closed")) + }) +}) diff --git a/Godeps/_workspace/src/gopkg.in/redis.v3/cluster_pipeline.go b/Godeps/_workspace/src/gopkg.in/redis.v3/cluster_pipeline.go new file mode 100644 index 0000000..2e11940 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/redis.v3/cluster_pipeline.go @@ -0,0 +1,123 @@ +package redis + +// ClusterPipeline is not thread-safe. +type ClusterPipeline struct { + commandable + + cmds []Cmder + cluster *ClusterClient + closed bool +} + +// Pipeline creates a new pipeline which is able to execute commands +// against multiple shards. +func (c *ClusterClient) Pipeline() *ClusterPipeline { + pipe := &ClusterPipeline{ + cluster: c, + cmds: make([]Cmder, 0, 10), + } + pipe.commandable.process = pipe.process + return pipe +} + +func (pipe *ClusterPipeline) process(cmd Cmder) { + pipe.cmds = append(pipe.cmds, cmd) +} + +// Discard resets the pipeline and discards queued commands. +func (pipe *ClusterPipeline) Discard() error { + if pipe.closed { + return errClosed + } + pipe.cmds = pipe.cmds[:0] + return nil +} + +func (pipe *ClusterPipeline) Exec() (cmds []Cmder, retErr error) { + if pipe.closed { + return nil, errClosed + } + if len(pipe.cmds) == 0 { + return []Cmder{}, nil + } + + cmds = pipe.cmds + pipe.cmds = make([]Cmder, 0, 10) + + cmdsMap := make(map[string][]Cmder) + for _, cmd := range cmds { + slot := hashSlot(cmd.clusterKey()) + addr := pipe.cluster.slotMasterAddr(slot) + cmdsMap[addr] = append(cmdsMap[addr], cmd) + } + + for attempt := 0; attempt <= pipe.cluster.opt.getMaxRedirects(); attempt++ { + failedCmds := make(map[string][]Cmder) + + for addr, cmds := range cmdsMap { + client, err := pipe.cluster.getClient(addr) + if err != nil { + setCmdsErr(cmds, err) + retErr = err + continue + } + + cn, err := client.conn() + if err != nil { + setCmdsErr(cmds, err) + retErr = err + continue + } + + failedCmds, err = pipe.execClusterCmds(cn, cmds, failedCmds) + if err != nil { + retErr = err + } + client.putConn(cn, err) + } + + cmdsMap = failedCmds + } + + return cmds, retErr +} + +// Close marks the pipeline as closed +func (pipe *ClusterPipeline) Close() error { + pipe.Discard() + pipe.closed = true + return nil +} + +func (pipe *ClusterPipeline) execClusterCmds( + cn *conn, cmds []Cmder, failedCmds map[string][]Cmder, +) (map[string][]Cmder, error) { + if err := cn.writeCmds(cmds...); err != nil { + setCmdsErr(cmds, err) + return failedCmds, err + } + + var firstCmdErr error + for i, cmd := range cmds { + err := cmd.parseReply(cn.rd) + if err == nil { + continue + } + if isNetworkError(err) { + cmd.reset() + failedCmds[""] = append(failedCmds[""], cmds[i:]...) + break + } else if moved, ask, addr := isMovedError(err); moved { + pipe.cluster.lazyReloadSlots() + cmd.reset() + failedCmds[addr] = append(failedCmds[addr], cmd) + } else if ask { + cmd.reset() + failedCmds[addr] = append(failedCmds[addr], NewCmd("ASKING"), cmd) + } else if firstCmdErr == nil { + firstCmdErr = err + } + } + + return failedCmds, firstCmdErr +} diff --git a/Godeps/_workspace/src/gopkg.in/redis.v3/cluster_test.go b/Godeps/_workspace/src/gopkg.in/redis.v3/cluster_test.go new file mode 100644 index 0000000..136340c --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/redis.v3/cluster_test.go @@ -0,0 +1,334 @@ +package redis_test + +import ( + "math/rand" + "net" + + "testing" + "time" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + + "gopkg.in/redis.v3" +) + +type clusterScenario struct { + ports []string + nodeIds []string + processes map[string]*redisProcess + clients map[string]*redis.Client +} + +func (s *clusterScenario) primary() *redis.Client { + return s.clients[s.ports[0]] +} + +func (s *clusterScenario) masters() []*redis.Client { + result := make([]*redis.Client, 3) + for pos, port := range s.ports[:3] { + result[pos] = s.clients[port] + } + return result +} + +func (s *clusterScenario) slaves() []*redis.Client { + result := make([]*redis.Client, 3) + for pos, port := range s.ports[3:] { + result[pos] = s.clients[port] + } + return result +} + +func (s *clusterScenario) clusterClient(opt *redis.ClusterOptions) *redis.ClusterClient { + addrs := make([]string, len(s.ports)) + for i, port := range s.ports { + addrs[i] = net.JoinHostPort("127.0.0.1", port) + } + if opt == nil { + opt = &redis.ClusterOptions{} + } + opt.Addrs = addrs + return redis.NewClusterClient(opt) +} + +func startCluster(scenario *clusterScenario) error { + // Start processes, connect individual clients + for pos, port := range scenario.ports { + process, err := startRedis(port, "--cluster-enabled", "yes") + if err != nil { + return err + } + + client := redis.NewClient(&redis.Options{Addr: "127.0.0.1:" + port}) + info, err := client.ClusterNodes().Result() + if err != nil { + return err + } + + scenario.processes[port] = process + scenario.clients[port] = client + scenario.nodeIds[pos] = info[:40] + } + + // Meet cluster nodes + for _, client := range scenario.clients { + err := client.ClusterMeet("127.0.0.1", scenario.ports[0]).Err() + if err != nil { + return err + } + } + + // Bootstrap masters + slots := []int{0, 5000, 10000, 16384} + for pos, client := range scenario.masters() { + err := client.ClusterAddSlotsRange(slots[pos], slots[pos+1]-1).Err() + if err != nil { + return err + } + } + + // Bootstrap slaves + for pos, client := range scenario.slaves() { + masterId := scenario.nodeIds[pos] + + // Wait for masters + err := waitForSubstring(func() string { + return client.ClusterNodes().Val() + }, masterId, 10*time.Second) + if err != nil { + return err + } + + err = client.ClusterReplicate(masterId).Err() + if err != nil { + return err + } + + // Wait for slaves + err = waitForSubstring(func() string { + return scenario.primary().ClusterNodes().Val() + }, "slave "+masterId, 10*time.Second) + if err != nil { + return err + } + } + + // Wait for cluster state to turn OK + for _, client := range scenario.clients { + err := waitForSubstring(func() string { + return client.ClusterInfo().Val() + }, "cluster_state:ok", 10*time.Second) + if err != nil { + return err + } + } + + return nil +} + +func stopCluster(scenario *clusterScenario) error { + for _, client := range scenario.clients { + if err := client.Close(); err != nil { + return err + } + } + for _, process := range scenario.processes { + if err := process.Close(); err != nil { + return err + } + } + return nil +} + +//------------------------------------------------------------------------------ + +var _ = Describe("Cluster", func() { + Describe("HashSlot", func() { + + It("should calculate hash slots", func() { + tests := []struct { + key string + slot int + }{ + {"123456789", 12739}, + {"{}foo", 9500}, + {"foo{}", 5542}, + {"foo{}{bar}", 8363}, + {"", 10503}, + {"", 5176}, + {string([]byte{83, 153, 134, 118, 229, 214, 244, 75, 140, 37, 215, 215}), 5463}, + } + rand.Seed(100) + + for _, test := range tests { + Expect(redis.HashSlot(test.key)).To(Equal(test.slot), "for %s", test.key) + } + }) + + It("should extract keys from tags", func() { + tests := []struct { + one, two string + }{ + {"foo{bar}", "bar"}, + {"{foo}bar", "foo"}, + {"{user1000}.following", "{user1000}.followers"}, + {"foo{{bar}}zap", "{bar"}, + {"foo{bar}{zap}", "bar"}, + } + + for _, test := range tests { + Expect(redis.HashSlot(test.one)).To(Equal(redis.HashSlot(test.two)), "for %s <-> %s", test.one, test.two) + } + }) + + }) + + Describe("Commands", func() { + + It("should CLUSTER SLOTS", func() { + res, err := cluster.primary().ClusterSlots().Result() + Expect(err).NotTo(HaveOccurred()) + Expect(res).To(HaveLen(3)) + Expect(res).To(ConsistOf([]redis.ClusterSlotInfo{ + {0, 4999, []string{"127.0.0.1:8220", "127.0.0.1:8223"}}, + {5000, 9999, []string{"127.0.0.1:8221", "127.0.0.1:8224"}}, + {10000, 16383, []string{"127.0.0.1:8222", "127.0.0.1:8225"}}, + })) + }) + + It("should CLUSTER NODES", func() { + res, err := cluster.primary().ClusterNodes().Result() + Expect(err).NotTo(HaveOccurred()) + Expect(len(res)).To(BeNumerically(">", 400)) + }) + + It("should CLUSTER INFO", func() { + res, err := cluster.primary().ClusterInfo().Result() + Expect(err).NotTo(HaveOccurred()) + Expect(res).To(ContainSubstring("cluster_known_nodes:6")) + }) + + }) + + Describe("Client", func() { + var client *redis.ClusterClient + + BeforeEach(func() { + client = cluster.clusterClient(nil) + }) + + AfterEach(func() { + for _, client := range cluster.masters() { + Expect(client.FlushDb().Err()).NotTo(HaveOccurred()) + } + Expect(client.Close()).NotTo(HaveOccurred()) + }) + + It("should GET/SET/DEL", func() { + val, err := client.Get("A").Result() + Expect(err).To(Equal(redis.Nil)) + Expect(val).To(Equal("")) + + val, err = client.Set("A", "VALUE", 0).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal("OK")) + + val, err = client.Get("A").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal("VALUE")) + + cnt, err := client.Del("A").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(cnt).To(Equal(int64(1))) + }) + + It("should follow redirects", func() { + Expect(client.Set("A", "VALUE", 0).Err()).NotTo(HaveOccurred()) + + slot := redis.HashSlot("A") + Expect(client.SwapSlot(slot)).To(Equal([]string{"127.0.0.1:8224", "127.0.0.1:8221"})) + + val, err := client.Get("A").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal("VALUE")) + + Eventually(func() []string { + return client.SlotAddrs(slot) + }, "5s").Should(Equal([]string{"127.0.0.1:8221", "127.0.0.1:8224"})) + }) + + It("should perform multi-pipelines", func() { + slot := redis.HashSlot("A") + Expect(client.SlotAddrs(slot)).To(Equal([]string{"127.0.0.1:8221", "127.0.0.1:8224"})) + Expect(client.SwapSlot(slot)).To(Equal([]string{"127.0.0.1:8224", "127.0.0.1:8221"})) + + pipe := client.Pipeline() + defer pipe.Close() + + keys := []string{"A", "B", "C", "D", "E", "F", "G"} + for i, key := range keys { + pipe.Set(key, key+"_value", 0) + pipe.Expire(key, time.Duration(i+1)*time.Hour) + } + for _, key := range keys { + pipe.Get(key) + pipe.TTL(key) + } + + cmds, err := pipe.Exec() + Expect(err).NotTo(HaveOccurred()) + Expect(cmds).To(HaveLen(28)) + Expect(cmds[14].(*redis.StringCmd).Val()).To(Equal("A_value")) + Expect(cmds[15].(*redis.DurationCmd).Val()).To(BeNumerically("~", 1*time.Hour, time.Second)) + Expect(cmds[20].(*redis.StringCmd).Val()).To(Equal("D_value")) + Expect(cmds[21].(*redis.DurationCmd).Val()).To(BeNumerically("~", 4*time.Hour, time.Second)) + Expect(cmds[26].(*redis.StringCmd).Val()).To(Equal("G_value")) + Expect(cmds[27].(*redis.DurationCmd).Val()).To(BeNumerically("~", 7*time.Hour, time.Second)) + }) + + It("should return error when there are no attempts left", func() { + client = cluster.clusterClient(&redis.ClusterOptions{ + MaxRedirects: -1, + }) + + slot := redis.HashSlot("A") + Expect(client.SwapSlot(slot)).To(Equal([]string{"127.0.0.1:8224", "127.0.0.1:8221"})) + + err := client.Get("A").Err() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("MOVED")) + }) + }) +}) + +//------------------------------------------------------------------------------ + +func BenchmarkRedisClusterPing(b *testing.B) { + if testing.Short() { + b.Skip("skipping in short mode") + } + + cluster := &clusterScenario{ + ports: []string{"8220", "8221", "8222", "8223", "8224", "8225"}, + nodeIds: make([]string, 6), + processes: make(map[string]*redisProcess, 6), + clients: make(map[string]*redis.Client, 6), + } + if err := startCluster(cluster); err != nil { + b.Fatal(err) + } + defer stopCluster(cluster) + client := cluster.clusterClient(nil) + defer client.Close() + + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if err := client.Ping().Err(); err != nil { + b.Fatal(err) + } + } + }) +} diff --git a/Godeps/_workspace/src/gopkg.in/redis.v2/command.go b/Godeps/_workspace/src/gopkg.in/redis.v3/command.go similarity index 59% rename from Godeps/_workspace/src/gopkg.in/redis.v2/command.go rename to Godeps/_workspace/src/gopkg.in/redis.v3/command.go index d7c76cf..dab9fc3 100644 --- a/Godeps/_workspace/src/gopkg.in/redis.v2/command.go +++ b/Godeps/_workspace/src/gopkg.in/redis.v3/command.go @@ -1,6 +1,7 @@ package redis import ( + "bytes" "fmt" "strconv" "strings" @@ -21,20 +22,24 @@ var ( _ Cmder = (*StringSliceCmd)(nil) _ Cmder = (*BoolSliceCmd)(nil) _ Cmder = (*StringStringMapCmd)(nil) + _ Cmder = (*StringIntMapCmd)(nil) _ Cmder = (*ZSliceCmd)(nil) _ Cmder = (*ScanCmd)(nil) + _ Cmder = (*ClusterSlotCmd)(nil) ) type Cmder interface { - args() []string + args() []interface{} parseReply(*bufio.Reader) error setErr(error) + reset() writeTimeout() *time.Duration readTimeout() *time.Duration + clusterKey() string Err() error - String() string + fmt.Stringer } func setCmdsErr(cmds []Cmder, e error) { @@ -43,13 +48,28 @@ func setCmdsErr(cmds []Cmder, e error) { } } +func resetCmds(cmds []Cmder) { + for _, cmd := range cmds { + cmd.reset() + } +} + func cmdString(cmd Cmder, val interface{}) string { - s := strings.Join(cmd.args(), " ") + var ss []string + for _, arg := range cmd.args() { + ss = append(ss, fmt.Sprint(arg)) + } + s := strings.Join(ss, " ") if err := cmd.Err(); err != nil { return s + ": " + err.Error() } if val != nil { - return s + ": " + fmt.Sprint(val) + switch vv := val.(type) { + case []byte: + return s + ": " + string(vv) + default: + return s + ": " + fmt.Sprint(val) + } } return s @@ -58,17 +78,13 @@ func cmdString(cmd Cmder, val interface{}) string { //------------------------------------------------------------------------------ type baseCmd struct { - _args []string + _args []interface{} err error - _writeTimeout, _readTimeout *time.Duration -} + _clusterKeyPos int -func newBaseCmd(args ...string) *baseCmd { - return &baseCmd{ - _args: args, - } + _writeTimeout, _readTimeout *time.Duration } func (cmd *baseCmd) Err() error { @@ -78,7 +94,7 @@ func (cmd *baseCmd) Err() error { return nil } -func (cmd *baseCmd) args() []string { +func (cmd *baseCmd) args() []interface{} { return cmd._args } @@ -94,6 +110,13 @@ func (cmd *baseCmd) writeTimeout() *time.Duration { return cmd._writeTimeout } +func (cmd *baseCmd) clusterKey() string { + if cmd._clusterKeyPos > 0 && cmd._clusterKeyPos < len(cmd._args) { + return fmt.Sprint(cmd._args[cmd._clusterKeyPos]) + } + return "" +} + func (cmd *baseCmd) setWriteTimeout(d time.Duration) { cmd._writeTimeout = &d } @@ -105,15 +128,18 @@ func (cmd *baseCmd) setErr(e error) { //------------------------------------------------------------------------------ type Cmd struct { - *baseCmd + baseCmd val interface{} } -func NewCmd(args ...string) *Cmd { - return &Cmd{ - baseCmd: newBaseCmd(args...), - } +func NewCmd(args ...interface{}) *Cmd { + return &Cmd{baseCmd: baseCmd{_args: args}} +} + +func (cmd *Cmd) reset() { + cmd.val = nil + cmd.err = nil } func (cmd *Cmd) Val() interface{} { @@ -130,21 +156,29 @@ func (cmd *Cmd) String() string { func (cmd *Cmd) parseReply(rd *bufio.Reader) error { cmd.val, cmd.err = parseReply(rd, parseSlice) + // Convert to string to preserve old behaviour. + // TODO: remove in v4 + if v, ok := cmd.val.([]byte); ok { + cmd.val = string(v) + } return cmd.err } //------------------------------------------------------------------------------ type SliceCmd struct { - *baseCmd + baseCmd val []interface{} } -func NewSliceCmd(args ...string) *SliceCmd { - return &SliceCmd{ - baseCmd: newBaseCmd(args...), - } +func NewSliceCmd(args ...interface{}) *SliceCmd { + return &SliceCmd{baseCmd: baseCmd{_args: args, _clusterKeyPos: 1}} +} + +func (cmd *SliceCmd) reset() { + cmd.val = nil + cmd.err = nil } func (cmd *SliceCmd) Val() []interface{} { @@ -172,15 +206,22 @@ func (cmd *SliceCmd) parseReply(rd *bufio.Reader) error { //------------------------------------------------------------------------------ type StatusCmd struct { - *baseCmd + baseCmd val string } -func NewStatusCmd(args ...string) *StatusCmd { - return &StatusCmd{ - baseCmd: newBaseCmd(args...), - } +func NewStatusCmd(args ...interface{}) *StatusCmd { + return &StatusCmd{baseCmd: baseCmd{_args: args, _clusterKeyPos: 1}} +} + +func newKeylessStatusCmd(args ...interface{}) *StatusCmd { + return &StatusCmd{baseCmd: baseCmd{_args: args}} +} + +func (cmd *StatusCmd) reset() { + cmd.val = "" + cmd.err = nil } func (cmd *StatusCmd) Val() string { @@ -201,22 +242,25 @@ func (cmd *StatusCmd) parseReply(rd *bufio.Reader) error { cmd.err = err return err } - cmd.val = v.(string) + cmd.val = string(v.([]byte)) return nil } //------------------------------------------------------------------------------ type IntCmd struct { - *baseCmd + baseCmd val int64 } -func NewIntCmd(args ...string) *IntCmd { - return &IntCmd{ - baseCmd: newBaseCmd(args...), - } +func NewIntCmd(args ...interface{}) *IntCmd { + return &IntCmd{baseCmd: baseCmd{_args: args, _clusterKeyPos: 1}} +} + +func (cmd *IntCmd) reset() { + cmd.val = 0 + cmd.err = nil } func (cmd *IntCmd) Val() int64 { @@ -244,19 +288,24 @@ func (cmd *IntCmd) parseReply(rd *bufio.Reader) error { //------------------------------------------------------------------------------ type DurationCmd struct { - *baseCmd + baseCmd val time.Duration precision time.Duration } -func NewDurationCmd(precision time.Duration, args ...string) *DurationCmd { +func NewDurationCmd(precision time.Duration, args ...interface{}) *DurationCmd { return &DurationCmd{ - baseCmd: newBaseCmd(args...), precision: precision, + baseCmd: baseCmd{_args: args, _clusterKeyPos: 1}, } } +func (cmd *DurationCmd) reset() { + cmd.val = 0 + cmd.err = nil +} + func (cmd *DurationCmd) Val() time.Duration { return cmd.val } @@ -282,15 +331,18 @@ func (cmd *DurationCmd) parseReply(rd *bufio.Reader) error { //------------------------------------------------------------------------------ type BoolCmd struct { - *baseCmd + baseCmd val bool } -func NewBoolCmd(args ...string) *BoolCmd { - return &BoolCmd{ - baseCmd: newBaseCmd(args...), - } +func NewBoolCmd(args ...interface{}) *BoolCmd { + return &BoolCmd{baseCmd: baseCmd{_args: args, _clusterKeyPos: 1}} +} + +func (cmd *BoolCmd) reset() { + cmd.val = false + cmd.err = nil } func (cmd *BoolCmd) Val() bool { @@ -305,35 +357,57 @@ func (cmd *BoolCmd) String() string { return cmdString(cmd, cmd.val) } +var ok = []byte("OK") + func (cmd *BoolCmd) parseReply(rd *bufio.Reader) error { v, err := parseReply(rd, nil) + // `SET key value NX` returns nil when key already exists. + if err == Nil { + cmd.val = false + return nil + } if err != nil { cmd.err = err return err } - cmd.val = v.(int64) == 1 - return nil + switch vv := v.(type) { + case int64: + cmd.val = vv == 1 + return nil + case []byte: + cmd.val = bytes.Equal(vv, ok) + return nil + default: + return fmt.Errorf("got %T, wanted int64 or string") + } } //------------------------------------------------------------------------------ type StringCmd struct { - *baseCmd + baseCmd - val string + val []byte } -func NewStringCmd(args ...string) *StringCmd { - return &StringCmd{ - baseCmd: newBaseCmd(args...), - } +func NewStringCmd(args ...interface{}) *StringCmd { + return &StringCmd{baseCmd: baseCmd{_args: args, _clusterKeyPos: 1}} +} + +func (cmd *StringCmd) reset() { + cmd.val = nil + cmd.err = nil } func (cmd *StringCmd) Val() string { - return cmd.val + return bytesToString(cmd.val) } func (cmd *StringCmd) Result() (string, error) { + return cmd.Val(), cmd.err +} + +func (cmd *StringCmd) Bytes() ([]byte, error) { return cmd.val, cmd.err } @@ -341,21 +415,28 @@ func (cmd *StringCmd) Int64() (int64, error) { if cmd.err != nil { return 0, cmd.err } - return strconv.ParseInt(cmd.val, 10, 64) + return strconv.ParseInt(cmd.Val(), 10, 64) } func (cmd *StringCmd) Uint64() (uint64, error) { if cmd.err != nil { return 0, cmd.err } - return strconv.ParseUint(cmd.val, 10, 64) + return strconv.ParseUint(cmd.Val(), 10, 64) } func (cmd *StringCmd) Float64() (float64, error) { if cmd.err != nil { return 0, cmd.err } - return strconv.ParseFloat(cmd.val, 64) + return strconv.ParseFloat(cmd.Val(), 64) +} + +func (cmd *StringCmd) Scan(val interface{}) error { + if cmd.err != nil { + return cmd.err + } + return scan(cmd.val, val) } func (cmd *StringCmd) String() string { @@ -368,22 +449,27 @@ func (cmd *StringCmd) parseReply(rd *bufio.Reader) error { cmd.err = err return err } - cmd.val = v.(string) + b := v.([]byte) + cmd.val = make([]byte, len(b)) + copy(cmd.val, b) return nil } //------------------------------------------------------------------------------ type FloatCmd struct { - *baseCmd + baseCmd val float64 } -func NewFloatCmd(args ...string) *FloatCmd { - return &FloatCmd{ - baseCmd: newBaseCmd(args...), - } +func NewFloatCmd(args ...interface{}) *FloatCmd { + return &FloatCmd{baseCmd: baseCmd{_args: args, _clusterKeyPos: 1}} +} + +func (cmd *FloatCmd) reset() { + cmd.val = 0 + cmd.err = nil } func (cmd *FloatCmd) Val() float64 { @@ -400,22 +486,26 @@ func (cmd *FloatCmd) parseReply(rd *bufio.Reader) error { cmd.err = err return err } - cmd.val, cmd.err = strconv.ParseFloat(v.(string), 64) + b := v.([]byte) + cmd.val, cmd.err = strconv.ParseFloat(bytesToString(b), 64) return cmd.err } //------------------------------------------------------------------------------ type StringSliceCmd struct { - *baseCmd + baseCmd val []string } -func NewStringSliceCmd(args ...string) *StringSliceCmd { - return &StringSliceCmd{ - baseCmd: newBaseCmd(args...), - } +func NewStringSliceCmd(args ...interface{}) *StringSliceCmd { + return &StringSliceCmd{baseCmd: baseCmd{_args: args, _clusterKeyPos: 1}} +} + +func (cmd *StringSliceCmd) reset() { + cmd.val = nil + cmd.err = nil } func (cmd *StringSliceCmd) Val() []string { @@ -443,15 +533,18 @@ func (cmd *StringSliceCmd) parseReply(rd *bufio.Reader) error { //------------------------------------------------------------------------------ type BoolSliceCmd struct { - *baseCmd + baseCmd val []bool } -func NewBoolSliceCmd(args ...string) *BoolSliceCmd { - return &BoolSliceCmd{ - baseCmd: newBaseCmd(args...), - } +func NewBoolSliceCmd(args ...interface{}) *BoolSliceCmd { + return &BoolSliceCmd{baseCmd: baseCmd{_args: args, _clusterKeyPos: 1}} +} + +func (cmd *BoolSliceCmd) reset() { + cmd.val = nil + cmd.err = nil } func (cmd *BoolSliceCmd) Val() []bool { @@ -479,15 +572,18 @@ func (cmd *BoolSliceCmd) parseReply(rd *bufio.Reader) error { //------------------------------------------------------------------------------ type StringStringMapCmd struct { - *baseCmd + baseCmd val map[string]string } -func NewStringStringMapCmd(args ...string) *StringStringMapCmd { - return &StringStringMapCmd{ - baseCmd: newBaseCmd(args...), - } +func NewStringStringMapCmd(args ...interface{}) *StringStringMapCmd { + return &StringStringMapCmd{baseCmd: baseCmd{_args: args, _clusterKeyPos: 1}} +} + +func (cmd *StringStringMapCmd) reset() { + cmd.val = nil + cmd.err = nil } func (cmd *StringStringMapCmd) Val() map[string]string { @@ -514,16 +610,58 @@ func (cmd *StringStringMapCmd) parseReply(rd *bufio.Reader) error { //------------------------------------------------------------------------------ +type StringIntMapCmd struct { + baseCmd + + val map[string]int64 +} + +func NewStringIntMapCmd(args ...interface{}) *StringIntMapCmd { + return &StringIntMapCmd{baseCmd: baseCmd{_args: args, _clusterKeyPos: 1}} +} + +func (cmd *StringIntMapCmd) Val() map[string]int64 { + return cmd.val +} + +func (cmd *StringIntMapCmd) Result() (map[string]int64, error) { + return cmd.val, cmd.err +} + +func (cmd *StringIntMapCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *StringIntMapCmd) reset() { + cmd.val = nil + cmd.err = nil +} + +func (cmd *StringIntMapCmd) parseReply(rd *bufio.Reader) error { + v, err := parseReply(rd, parseStringIntMap) + if err != nil { + cmd.err = err + return err + } + cmd.val = v.(map[string]int64) + return nil +} + +//------------------------------------------------------------------------------ + type ZSliceCmd struct { - *baseCmd + baseCmd val []Z } -func NewZSliceCmd(args ...string) *ZSliceCmd { - return &ZSliceCmd{ - baseCmd: newBaseCmd(args...), - } +func NewZSliceCmd(args ...interface{}) *ZSliceCmd { + return &ZSliceCmd{baseCmd: baseCmd{_args: args, _clusterKeyPos: 1}} +} + +func (cmd *ZSliceCmd) reset() { + cmd.val = nil + cmd.err = nil } func (cmd *ZSliceCmd) Val() []Z { @@ -551,16 +689,20 @@ func (cmd *ZSliceCmd) parseReply(rd *bufio.Reader) error { //------------------------------------------------------------------------------ type ScanCmd struct { - *baseCmd + baseCmd cursor int64 keys []string } -func NewScanCmd(args ...string) *ScanCmd { - return &ScanCmd{ - baseCmd: newBaseCmd(args...), - } +func NewScanCmd(args ...interface{}) *ScanCmd { + return &ScanCmd{baseCmd: baseCmd{_args: args, _clusterKeyPos: 1}} +} + +func (cmd *ScanCmd) reset() { + cmd.cursor = 0 + cmd.keys = nil + cmd.err = nil } func (cmd *ScanCmd) Val() (int64, []string) { @@ -595,3 +737,47 @@ func (cmd *ScanCmd) parseReply(rd *bufio.Reader) error { return nil } + +//------------------------------------------------------------------------------ + +type ClusterSlotInfo struct { + Start, End int + Addrs []string +} + +type ClusterSlotCmd struct { + baseCmd + + val []ClusterSlotInfo +} + +func NewClusterSlotCmd(args ...interface{}) *ClusterSlotCmd { + return &ClusterSlotCmd{baseCmd: baseCmd{_args: args, _clusterKeyPos: 1}} +} + +func (cmd *ClusterSlotCmd) Val() []ClusterSlotInfo { + return cmd.val +} + +func (cmd *ClusterSlotCmd) Result() ([]ClusterSlotInfo, error) { + return cmd.Val(), cmd.Err() +} + +func (cmd *ClusterSlotCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *ClusterSlotCmd) reset() { + cmd.val = nil + cmd.err = nil +} + +func (cmd *ClusterSlotCmd) parseReply(rd *bufio.Reader) error { + v, err := parseReply(rd, parseClusterSlotInfoSlice) + if err != nil { + cmd.err = err + return err + } + cmd.val = v.([]ClusterSlotInfo) + return nil +} diff --git a/Godeps/_workspace/src/gopkg.in/redis.v3/command_test.go b/Godeps/_workspace/src/gopkg.in/redis.v3/command_test.go new file mode 100644 index 0000000..1218724 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/redis.v3/command_test.go @@ -0,0 +1,178 @@ +package redis_test + +import ( + "bytes" + "strconv" + "sync" + "testing" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + + "gopkg.in/redis.v3" +) + +var _ = Describe("Command", func() { + var client *redis.Client + + BeforeEach(func() { + client = redis.NewClient(&redis.Options{ + Addr: redisAddr, + }) + }) + + AfterEach(func() { + Expect(client.FlushDb().Err()).NotTo(HaveOccurred()) + Expect(client.Close()).NotTo(HaveOccurred()) + }) + + It("should implement Stringer", func() { + set := client.Set("foo", "bar", 0) + Expect(set.String()).To(Equal("SET foo bar: OK")) + + get := client.Get("foo") + Expect(get.String()).To(Equal("GET foo: bar")) + }) + + It("should have correct val/err states", func() { + set := client.Set("key", "hello", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + get := client.Get("key") + Expect(get.Err()).NotTo(HaveOccurred()) + Expect(get.Val()).To(Equal("hello")) + + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + }) + + It("should escape special chars", func() { + set := client.Set("key", "hello1\r\nhello2\r\n", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + get := client.Get("key") + Expect(get.Err()).NotTo(HaveOccurred()) + Expect(get.Val()).To(Equal("hello1\r\nhello2\r\n")) + }) + + It("should handle big vals", func() { + val := string(bytes.Repeat([]byte{'*'}, 1<<16)) + set := client.Set("key", val, 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + get := client.Get("key") + Expect(get.Err()).NotTo(HaveOccurred()) + Expect(get.Val()).To(Equal(val)) + }) + + It("should handle many keys #1", func() { + const n = 100000 + for i := 0; i < n; i++ { + client.Set("keys.key"+strconv.Itoa(i), "hello"+strconv.Itoa(i), 0) + } + keys := client.Keys("keys.*") + Expect(keys.Err()).NotTo(HaveOccurred()) + Expect(len(keys.Val())).To(Equal(n)) + }) + + It("should handle many keys #2", func() { + const n = 100000 + + keys := []string{"non-existent-key"} + for i := 0; i < n; i++ { + key := "keys.key" + strconv.Itoa(i) + client.Set(key, "hello"+strconv.Itoa(i), 0) + keys = append(keys, key) + } + keys = append(keys, "non-existent-key") + + mget := client.MGet(keys...) + Expect(mget.Err()).NotTo(HaveOccurred()) + Expect(len(mget.Val())).To(Equal(n + 2)) + vals := mget.Val() + for i := 0; i < n; i++ { + Expect(vals[i+1]).To(Equal("hello" + strconv.Itoa(i))) + } + Expect(vals[0]).To(BeNil()) + Expect(vals[n+1]).To(BeNil()) + }) + + It("should convert strings via helpers", func() { + set := client.Set("key", "10", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + + n, err := client.Get("key").Int64() + Expect(err).NotTo(HaveOccurred()) + Expect(n).To(Equal(int64(10))) + + un, err := client.Get("key").Uint64() + Expect(err).NotTo(HaveOccurred()) + Expect(un).To(Equal(uint64(10))) + + f, err := client.Get("key").Float64() + Expect(err).NotTo(HaveOccurred()) + Expect(f).To(Equal(float64(10))) + }) + + It("Cmd should return string", func() { + cmd := redis.NewCmd("PING") + client.Process(cmd) + Expect(cmd.Err()).NotTo(HaveOccurred()) + Expect(cmd.Val()).To(Equal("PONG")) + }) + + Describe("races", func() { + var C, N = 10, 1000 + if testing.Short() { + N = 100 + } + + It("should echo", func() { + wg := &sync.WaitGroup{} + for i := 0; i < C; i++ { + wg.Add(1) + + go func(i int) { + defer GinkgoRecover() + defer wg.Done() + + for j := 0; j < N; j++ { + msg := "echo" + strconv.Itoa(i) + echo := client.Echo(msg) + Expect(echo.Err()).NotTo(HaveOccurred()) + Expect(echo.Val()).To(Equal(msg)) + } + }(i) + } + wg.Wait() + }) + + It("should incr", func() { + key := "TestIncrFromGoroutines" + wg := &sync.WaitGroup{} + for i := 0; i < C; i++ { + wg.Add(1) + + go func() { + defer GinkgoRecover() + defer wg.Done() + + for j := 0; j < N; j++ { + err := client.Incr(key).Err() + Expect(err).NotTo(HaveOccurred()) + } + }() + } + wg.Wait() + + val, err := client.Get(key).Int64() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal(int64(C * N))) + }) + + }) + +}) diff --git a/Godeps/_workspace/src/gopkg.in/redis.v3/commands.go b/Godeps/_workspace/src/gopkg.in/redis.v3/commands.go new file mode 100644 index 0000000..7f10a2f --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/redis.v3/commands.go @@ -0,0 +1,1564 @@ +package redis + +import ( + "io" + "log" + "strconv" + "time" +) + +func formatInt(i int64) string { + return strconv.FormatInt(i, 10) +} + +func formatUint(i uint64) string { + return strconv.FormatUint(i, 10) +} + +func formatFloat(f float64) string { + return strconv.FormatFloat(f, 'f', -1, 64) +} + +func readTimeout(timeout time.Duration) time.Duration { + if timeout == 0 { + return 0 + } + return timeout + time.Second +} + +func usePrecise(dur time.Duration) bool { + return dur < time.Second || dur%time.Second != 0 +} + +func formatMs(dur time.Duration) string { + if dur > 0 && dur < time.Millisecond { + log.Printf( + "redis: specified duration is %s, but minimal supported value is %s", + dur, time.Millisecond, + ) + } + return formatInt(int64(dur / time.Millisecond)) +} + +func formatSec(dur time.Duration) string { + if dur > 0 && dur < time.Second { + log.Printf( + "redis: specified duration is %s, but minimal supported value is %s", + dur, time.Second, + ) + } + return formatInt(int64(dur / time.Second)) +} + +type commandable struct { + process func(cmd Cmder) +} + +func (c *commandable) Process(cmd Cmder) { + c.process(cmd) +} + +//------------------------------------------------------------------------------ + +func (c *commandable) Auth(password string) *StatusCmd { + cmd := newKeylessStatusCmd("AUTH", password) + c.Process(cmd) + return cmd +} + +func (c *commandable) Echo(message string) *StringCmd { + cmd := NewStringCmd("ECHO", message) + cmd._clusterKeyPos = 0 + c.Process(cmd) + return cmd +} + +func (c *commandable) Ping() *StatusCmd { + cmd := newKeylessStatusCmd("PING") + c.Process(cmd) + return cmd +} + +func (c *commandable) Quit() *StatusCmd { + panic("not implemented") +} + +func (c *commandable) Select(index int64) *StatusCmd { + cmd := newKeylessStatusCmd("SELECT", formatInt(index)) + c.Process(cmd) + return cmd +} + +//------------------------------------------------------------------------------ + +func (c *commandable) Del(keys ...string) *IntCmd { + args := make([]interface{}, 1+len(keys)) + args[0] = "DEL" + for i, key := range keys { + args[1+i] = key + } + cmd := NewIntCmd(args...) + c.Process(cmd) + return cmd +} + +func (c *commandable) Dump(key string) *StringCmd { + cmd := NewStringCmd("DUMP", key) + c.Process(cmd) + return cmd +} + +func (c *commandable) Exists(key string) *BoolCmd { + cmd := NewBoolCmd("EXISTS", key) + c.Process(cmd) + return cmd +} + +func (c *commandable) Expire(key string, expiration time.Duration) *BoolCmd { + cmd := NewBoolCmd("EXPIRE", key, formatSec(expiration)) + c.Process(cmd) + return cmd +} + +func (c *commandable) ExpireAt(key string, tm time.Time) *BoolCmd { + cmd := NewBoolCmd("EXPIREAT", key, formatInt(tm.Unix())) + c.Process(cmd) + return cmd +} + +func (c *commandable) Keys(pattern string) *StringSliceCmd { + cmd := NewStringSliceCmd("KEYS", pattern) + c.Process(cmd) + return cmd +} + +func (c *commandable) Migrate(host, port, key string, db int64, timeout time.Duration) *StatusCmd { + cmd := NewStatusCmd( + "MIGRATE", + host, + port, + key, + formatInt(db), + formatMs(timeout), + ) + cmd._clusterKeyPos = 3 + cmd.setReadTimeout(readTimeout(timeout)) + c.Process(cmd) + return cmd +} + +func (c *commandable) Move(key string, db int64) *BoolCmd { + cmd := NewBoolCmd("MOVE", key, formatInt(db)) + c.Process(cmd) + return cmd +} + +func (c *commandable) ObjectRefCount(keys ...string) *IntCmd { + args := make([]interface{}, 2+len(keys)) + args[0] = "OBJECT" + args[1] = "REFCOUNT" + for i, key := range keys { + args[2+i] = key + } + cmd := NewIntCmd(args...) + cmd._clusterKeyPos = 2 + c.Process(cmd) + return cmd +} + +func (c *commandable) ObjectEncoding(keys ...string) *StringCmd { + args := make([]interface{}, 2+len(keys)) + args[0] = "OBJECT" + args[1] = "ENCODING" + for i, key := range keys { + args[2+i] = key + } + cmd := NewStringCmd(args...) + cmd._clusterKeyPos = 2 + c.Process(cmd) + return cmd +} + +func (c *commandable) ObjectIdleTime(keys ...string) *DurationCmd { + args := make([]interface{}, 2+len(keys)) + args[0] = "OBJECT" + args[1] = "IDLETIME" + for i, key := range keys { + args[2+i] = key + } + cmd := NewDurationCmd(time.Second, args...) + cmd._clusterKeyPos = 2 + c.Process(cmd) + return cmd +} + +func (c *commandable) Persist(key string) *BoolCmd { + cmd := NewBoolCmd("PERSIST", key) + c.Process(cmd) + return cmd +} + +func (c *commandable) PExpire(key string, expiration time.Duration) *BoolCmd { + cmd := NewBoolCmd("PEXPIRE", key, formatMs(expiration)) + c.Process(cmd) + return cmd +} + +func (c *commandable) PExpireAt(key string, tm time.Time) *BoolCmd { + cmd := NewBoolCmd( + "PEXPIREAT", + key, + formatInt(tm.UnixNano()/int64(time.Millisecond)), + ) + c.Process(cmd) + return cmd +} + +func (c *commandable) PTTL(key string) *DurationCmd { + cmd := NewDurationCmd(time.Millisecond, "PTTL", key) + c.Process(cmd) + return cmd +} + +func (c *commandable) RandomKey() *StringCmd { + cmd := NewStringCmd("RANDOMKEY") + c.Process(cmd) + return cmd +} + +func (c *commandable) Rename(key, newkey string) *StatusCmd { + cmd := NewStatusCmd("RENAME", key, newkey) + c.Process(cmd) + return cmd +} + +func (c *commandable) RenameNX(key, newkey string) *BoolCmd { + cmd := NewBoolCmd("RENAMENX", key, newkey) + c.Process(cmd) + return cmd +} + +func (c *commandable) Restore(key string, ttl time.Duration, value string) *StatusCmd { + cmd := NewStatusCmd( + "RESTORE", + key, + formatMs(ttl), + value, + ) + c.Process(cmd) + return cmd +} + +func (c *commandable) RestoreReplace(key string, ttl time.Duration, value string) *StatusCmd { + cmd := NewStatusCmd( + "RESTORE", + key, + formatMs(ttl), + value, + "REPLACE", + ) + c.Process(cmd) + return cmd +} + +type Sort struct { + By string + Offset, Count float64 + Get []string + Order string + IsAlpha bool + Store string +} + +func (c *commandable) Sort(key string, sort Sort) *StringSliceCmd { + args := []interface{}{"SORT", key} + if sort.By != "" { + args = append(args, "BY", sort.By) + } + if sort.Offset != 0 || sort.Count != 0 { + args = append(args, "LIMIT", formatFloat(sort.Offset), formatFloat(sort.Count)) + } + for _, get := range sort.Get { + args = append(args, "GET", get) + } + if sort.Order != "" { + args = append(args, sort.Order) + } + if sort.IsAlpha { + args = append(args, "ALPHA") + } + if sort.Store != "" { + args = append(args, "STORE", sort.Store) + } + cmd := NewStringSliceCmd(args...) + c.Process(cmd) + return cmd +} + +func (c *commandable) TTL(key string) *DurationCmd { + cmd := NewDurationCmd(time.Second, "TTL", key) + c.Process(cmd) + return cmd +} + +func (c *commandable) Type(key string) *StatusCmd { + cmd := NewStatusCmd("TYPE", key) + c.Process(cmd) + return cmd +} + +func (c *commandable) Scan(cursor int64, match string, count int64) *ScanCmd { + args := []interface{}{"SCAN", formatInt(cursor)} + if match != "" { + args = append(args, "MATCH", match) + } + if count > 0 { + args = append(args, "COUNT", formatInt(count)) + } + cmd := NewScanCmd(args...) + c.Process(cmd) + return cmd +} + +func (c *commandable) SScan(key string, cursor int64, match string, count int64) *ScanCmd { + args := []interface{}{"SSCAN", key, formatInt(cursor)} + if match != "" { + args = append(args, "MATCH", match) + } + if count > 0 { + args = append(args, "COUNT", formatInt(count)) + } + cmd := NewScanCmd(args...) + c.Process(cmd) + return cmd +} + +func (c *commandable) HScan(key string, cursor int64, match string, count int64) *ScanCmd { + args := []interface{}{"HSCAN", key, formatInt(cursor)} + if match != "" { + args = append(args, "MATCH", match) + } + if count > 0 { + args = append(args, "COUNT", formatInt(count)) + } + cmd := NewScanCmd(args...) + c.Process(cmd) + return cmd +} + +func (c *commandable) ZScan(key string, cursor int64, match string, count int64) *ScanCmd { + args := []interface{}{"ZSCAN", key, formatInt(cursor)} + if match != "" { + args = append(args, "MATCH", match) + } + if count > 0 { + args = append(args, "COUNT", formatInt(count)) + } + cmd := NewScanCmd(args...) + c.Process(cmd) + return cmd +} + +//------------------------------------------------------------------------------ + +func (c *commandable) Append(key, value string) *IntCmd { + cmd := NewIntCmd("APPEND", key, value) + c.Process(cmd) + return cmd +} + +type BitCount struct { + Start, End int64 +} + +func (c *commandable) BitCount(key string, bitCount *BitCount) *IntCmd { + args := []interface{}{"BITCOUNT", key} + if bitCount != nil { + args = append( + args, + formatInt(bitCount.Start), + formatInt(bitCount.End), + ) + } + cmd := NewIntCmd(args...) + c.Process(cmd) + return cmd +} + +func (c *commandable) bitOp(op, destKey string, keys ...string) *IntCmd { + args := make([]interface{}, 3+len(keys)) + args[0] = "BITOP" + args[1] = op + args[2] = destKey + for i, key := range keys { + args[3+i] = key + } + cmd := NewIntCmd(args...) + c.Process(cmd) + return cmd +} + +func (c *commandable) BitOpAnd(destKey string, keys ...string) *IntCmd { + return c.bitOp("AND", destKey, keys...) +} + +func (c *commandable) BitOpOr(destKey string, keys ...string) *IntCmd { + return c.bitOp("OR", destKey, keys...) +} + +func (c *commandable) BitOpXor(destKey string, keys ...string) *IntCmd { + return c.bitOp("XOR", destKey, keys...) +} + +func (c *commandable) BitOpNot(destKey string, key string) *IntCmd { + return c.bitOp("NOT", destKey, key) +} + +func (c *commandable) BitPos(key string, bit int64, pos ...int64) *IntCmd { + args := make([]interface{}, 3+len(pos)) + args[0] = "BITPOS" + args[1] = key + args[2] = formatInt(bit) + switch len(pos) { + case 0: + case 1: + args[3] = formatInt(pos[0]) + case 2: + args[3] = formatInt(pos[0]) + args[4] = formatInt(pos[1]) + default: + panic("too many arguments") + } + cmd := NewIntCmd(args...) + c.Process(cmd) + return cmd +} + +func (c *commandable) Decr(key string) *IntCmd { + cmd := NewIntCmd("DECR", key) + c.Process(cmd) + return cmd +} + +func (c *commandable) DecrBy(key string, decrement int64) *IntCmd { + cmd := NewIntCmd("DECRBY", key, formatInt(decrement)) + c.Process(cmd) + return cmd +} + +func (c *commandable) Get(key string) *StringCmd { + cmd := NewStringCmd("GET", key) + c.Process(cmd) + return cmd +} + +func (c *commandable) GetBit(key string, offset int64) *IntCmd { + cmd := NewIntCmd("GETBIT", key, formatInt(offset)) + c.Process(cmd) + return cmd +} + +func (c *commandable) GetRange(key string, start, end int64) *StringCmd { + cmd := NewStringCmd( + "GETRANGE", + key, + formatInt(start), + formatInt(end), + ) + c.Process(cmd) + return cmd +} + +func (c *commandable) GetSet(key string, value interface{}) *StringCmd { + cmd := NewStringCmd("GETSET", key, value) + c.Process(cmd) + return cmd +} + +func (c *commandable) Incr(key string) *IntCmd { + cmd := NewIntCmd("INCR", key) + c.Process(cmd) + return cmd +} + +func (c *commandable) IncrBy(key string, value int64) *IntCmd { + cmd := NewIntCmd("INCRBY", key, formatInt(value)) + c.Process(cmd) + return cmd +} + +func (c *commandable) IncrByFloat(key string, value float64) *FloatCmd { + cmd := NewFloatCmd("INCRBYFLOAT", key, formatFloat(value)) + c.Process(cmd) + return cmd +} + +func (c *commandable) MGet(keys ...string) *SliceCmd { + args := make([]interface{}, 1+len(keys)) + args[0] = "MGET" + for i, key := range keys { + args[1+i] = key + } + cmd := NewSliceCmd(args...) + c.Process(cmd) + return cmd +} + +func (c *commandable) MSet(pairs ...string) *StatusCmd { + args := make([]interface{}, 1+len(pairs)) + args[0] = "MSET" + for i, pair := range pairs { + args[1+i] = pair + } + cmd := NewStatusCmd(args...) + c.Process(cmd) + return cmd +} + +func (c *commandable) MSetNX(pairs ...string) *BoolCmd { + args := make([]interface{}, 1+len(pairs)) + args[0] = "MSETNX" + for i, pair := range pairs { + args[1+i] = pair + } + cmd := NewBoolCmd(args...) + c.Process(cmd) + return cmd +} + +func (c *commandable) Set(key string, value interface{}, expiration time.Duration) *StatusCmd { + args := make([]interface{}, 3, 5) + args[0] = "SET" + args[1] = key + args[2] = value + if expiration > 0 { + if usePrecise(expiration) { + args = append(args, "PX", formatMs(expiration)) + } else { + args = append(args, "EX", formatSec(expiration)) + } + } + cmd := NewStatusCmd(args...) + c.Process(cmd) + return cmd +} + +func (c *commandable) SetBit(key string, offset int64, value int) *IntCmd { + cmd := NewIntCmd( + "SETBIT", + key, + formatInt(offset), + formatInt(int64(value)), + ) + c.Process(cmd) + return cmd +} + +func (c *commandable) SetNX(key string, value interface{}, expiration time.Duration) *BoolCmd { + var cmd *BoolCmd + if expiration == 0 { + // Use old `SETNX` to support old Redis versions. + cmd = NewBoolCmd("SETNX", key, value) + } else { + if usePrecise(expiration) { + cmd = NewBoolCmd("SET", key, value, "PX", formatMs(expiration), "NX") + } else { + cmd = NewBoolCmd("SET", key, value, "EX", formatSec(expiration), "NX") + } + } + c.Process(cmd) + return cmd +} + +func (c *Client) SetXX(key string, value interface{}, expiration time.Duration) *BoolCmd { + var cmd *BoolCmd + if usePrecise(expiration) { + cmd = NewBoolCmd("SET", key, value, "PX", formatMs(expiration), "XX") + } else { + cmd = NewBoolCmd("SET", key, value, "EX", formatSec(expiration), "XX") + } + c.Process(cmd) + return cmd +} + +func (c *commandable) SetRange(key string, offset int64, value string) *IntCmd { + cmd := NewIntCmd("SETRANGE", key, formatInt(offset), value) + c.Process(cmd) + return cmd +} + +func (c *commandable) StrLen(key string) *IntCmd { + cmd := NewIntCmd("STRLEN", key) + c.Process(cmd) + return cmd +} + +//------------------------------------------------------------------------------ + +func (c *commandable) HDel(key string, fields ...string) *IntCmd { + args := make([]interface{}, 2+len(fields)) + args[0] = "HDEL" + args[1] = key + for i, field := range fields { + args[2+i] = field + } + cmd := NewIntCmd(args...) + c.Process(cmd) + return cmd +} + +func (c *commandable) HExists(key, field string) *BoolCmd { + cmd := NewBoolCmd("HEXISTS", key, field) + c.Process(cmd) + return cmd +} + +func (c *commandable) HGet(key, field string) *StringCmd { + cmd := NewStringCmd("HGET", key, field) + c.Process(cmd) + return cmd +} + +func (c *commandable) HGetAll(key string) *StringSliceCmd { + cmd := NewStringSliceCmd("HGETALL", key) + c.Process(cmd) + return cmd +} + +func (c *commandable) HGetAllMap(key string) *StringStringMapCmd { + cmd := NewStringStringMapCmd("HGETALL", key) + c.Process(cmd) + return cmd +} + +func (c *commandable) HIncrBy(key, field string, incr int64) *IntCmd { + cmd := NewIntCmd("HINCRBY", key, field, formatInt(incr)) + c.Process(cmd) + return cmd +} + +func (c *commandable) HIncrByFloat(key, field string, incr float64) *FloatCmd { + cmd := NewFloatCmd("HINCRBYFLOAT", key, field, formatFloat(incr)) + c.Process(cmd) + return cmd +} + +func (c *commandable) HKeys(key string) *StringSliceCmd { + cmd := NewStringSliceCmd("HKEYS", key) + c.Process(cmd) + return cmd +} + +func (c *commandable) HLen(key string) *IntCmd { + cmd := NewIntCmd("HLEN", key) + c.Process(cmd) + return cmd +} + +func (c *commandable) HMGet(key string, fields ...string) *SliceCmd { + args := make([]interface{}, 2+len(fields)) + args[0] = "HMGET" + args[1] = key + for i, field := range fields { + args[2+i] = field + } + cmd := NewSliceCmd(args...) + c.Process(cmd) + return cmd +} + +func (c *commandable) HMSet(key, field, value string, pairs ...string) *StatusCmd { + args := make([]interface{}, 4+len(pairs)) + args[0] = "HMSET" + args[1] = key + args[2] = field + args[3] = value + for i, pair := range pairs { + args[4+i] = pair + } + cmd := NewStatusCmd(args...) + c.Process(cmd) + return cmd +} + +func (c *commandable) HSet(key, field, value string) *BoolCmd { + cmd := NewBoolCmd("HSET", key, field, value) + c.Process(cmd) + return cmd +} + +func (c *commandable) HSetNX(key, field, value string) *BoolCmd { + cmd := NewBoolCmd("HSETNX", key, field, value) + c.Process(cmd) + return cmd +} + +func (c *commandable) HVals(key string) *StringSliceCmd { + cmd := NewStringSliceCmd("HVALS", key) + c.Process(cmd) + return cmd +} + +//------------------------------------------------------------------------------ + +func (c *commandable) BLPop(timeout time.Duration, keys ...string) *StringSliceCmd { + args := make([]interface{}, 2+len(keys)) + args[0] = "BLPOP" + for i, key := range keys { + args[1+i] = key + } + args[len(args)-1] = formatSec(timeout) + cmd := NewStringSliceCmd(args...) + cmd.setReadTimeout(readTimeout(timeout)) + c.Process(cmd) + return cmd +} + +func (c *commandable) BRPop(timeout time.Duration, keys ...string) *StringSliceCmd { + args := make([]interface{}, 2+len(keys)) + args[0] = "BRPOP" + for i, key := range keys { + args[1+i] = key + } + args[len(args)-1] = formatSec(timeout) + cmd := NewStringSliceCmd(args...) + cmd.setReadTimeout(readTimeout(timeout)) + c.Process(cmd) + return cmd +} + +func (c *commandable) BRPopLPush(source, destination string, timeout time.Duration) *StringCmd { + cmd := NewStringCmd( + "BRPOPLPUSH", + source, + destination, + formatSec(timeout), + ) + cmd.setReadTimeout(readTimeout(timeout)) + c.Process(cmd) + return cmd +} + +func (c *commandable) LIndex(key string, index int64) *StringCmd { + cmd := NewStringCmd("LINDEX", key, formatInt(index)) + c.Process(cmd) + return cmd +} + +func (c *commandable) LInsert(key, op, pivot, value string) *IntCmd { + cmd := NewIntCmd("LINSERT", key, op, pivot, value) + c.Process(cmd) + return cmd +} + +func (c *commandable) LLen(key string) *IntCmd { + cmd := NewIntCmd("LLEN", key) + c.Process(cmd) + return cmd +} + +func (c *commandable) LPop(key string) *StringCmd { + cmd := NewStringCmd("LPOP", key) + c.Process(cmd) + return cmd +} + +func (c *commandable) LPush(key string, values ...string) *IntCmd { + args := make([]interface{}, 2+len(values)) + args[0] = "LPUSH" + args[1] = key + for i, value := range values { + args[2+i] = value + } + cmd := NewIntCmd(args...) + c.Process(cmd) + return cmd +} + +func (c *commandable) LPushX(key, value string) *IntCmd { + cmd := NewIntCmd("LPUSHX", key, value) + c.Process(cmd) + return cmd +} + +func (c *commandable) LRange(key string, start, stop int64) *StringSliceCmd { + cmd := NewStringSliceCmd( + "LRANGE", + key, + formatInt(start), + formatInt(stop), + ) + c.Process(cmd) + return cmd +} + +func (c *commandable) LRem(key string, count int64, value string) *IntCmd { + cmd := NewIntCmd("LREM", key, formatInt(count), value) + c.Process(cmd) + return cmd +} + +func (c *commandable) LSet(key string, index int64, value string) *StatusCmd { + cmd := NewStatusCmd("LSET", key, formatInt(index), value) + c.Process(cmd) + return cmd +} + +func (c *commandable) LTrim(key string, start, stop int64) *StatusCmd { + cmd := NewStatusCmd( + "LTRIM", + key, + formatInt(start), + formatInt(stop), + ) + c.Process(cmd) + return cmd +} + +func (c *commandable) RPop(key string) *StringCmd { + cmd := NewStringCmd("RPOP", key) + c.Process(cmd) + return cmd +} + +func (c *commandable) RPopLPush(source, destination string) *StringCmd { + cmd := NewStringCmd("RPOPLPUSH", source, destination) + c.Process(cmd) + return cmd +} + +func (c *commandable) RPush(key string, values ...string) *IntCmd { + args := make([]interface{}, 2+len(values)) + args[0] = "RPUSH" + args[1] = key + for i, value := range values { + args[2+i] = value + } + cmd := NewIntCmd(args...) + c.Process(cmd) + return cmd +} + +func (c *commandable) RPushX(key string, value string) *IntCmd { + cmd := NewIntCmd("RPUSHX", key, value) + c.Process(cmd) + return cmd +} + +//------------------------------------------------------------------------------ + +func (c *commandable) SAdd(key string, members ...string) *IntCmd { + args := make([]interface{}, 2+len(members)) + args[0] = "SADD" + args[1] = key + for i, member := range members { + args[2+i] = member + } + cmd := NewIntCmd(args...) + c.Process(cmd) + return cmd +} + +func (c *commandable) SCard(key string) *IntCmd { + cmd := NewIntCmd("SCARD", key) + c.Process(cmd) + return cmd +} + +func (c *commandable) SDiff(keys ...string) *StringSliceCmd { + args := make([]interface{}, 1+len(keys)) + args[0] = "SDIFF" + for i, key := range keys { + args[1+i] = key + } + cmd := NewStringSliceCmd(args...) + c.Process(cmd) + return cmd +} + +func (c *commandable) SDiffStore(destination string, keys ...string) *IntCmd { + args := make([]interface{}, 2+len(keys)) + args[0] = "SDIFFSTORE" + args[1] = destination + for i, key := range keys { + args[2+i] = key + } + cmd := NewIntCmd(args...) + c.Process(cmd) + return cmd +} + +func (c *commandable) SInter(keys ...string) *StringSliceCmd { + args := make([]interface{}, 1+len(keys)) + args[0] = "SINTER" + for i, key := range keys { + args[1+i] = key + } + cmd := NewStringSliceCmd(args...) + c.Process(cmd) + return cmd +} + +func (c *commandable) SInterStore(destination string, keys ...string) *IntCmd { + args := make([]interface{}, 2+len(keys)) + args[0] = "SINTERSTORE" + args[1] = destination + for i, key := range keys { + args[2+i] = key + } + cmd := NewIntCmd(args...) + c.Process(cmd) + return cmd +} + +func (c *commandable) SIsMember(key, member string) *BoolCmd { + cmd := NewBoolCmd("SISMEMBER", key, member) + c.Process(cmd) + return cmd +} + +func (c *commandable) SMembers(key string) *StringSliceCmd { + cmd := NewStringSliceCmd("SMEMBERS", key) + c.Process(cmd) + return cmd +} + +func (c *commandable) SMove(source, destination, member string) *BoolCmd { + cmd := NewBoolCmd("SMOVE", source, destination, member) + c.Process(cmd) + return cmd +} + +func (c *commandable) SPop(key string) *StringCmd { + cmd := NewStringCmd("SPOP", key) + c.Process(cmd) + return cmd +} + +func (c *commandable) SRandMember(key string) *StringCmd { + cmd := NewStringCmd("SRANDMEMBER", key) + c.Process(cmd) + return cmd +} + +func (c *commandable) SRem(key string, members ...string) *IntCmd { + args := make([]interface{}, 2+len(members)) + args[0] = "SREM" + args[1] = key + for i, member := range members { + args[2+i] = member + } + cmd := NewIntCmd(args...) + c.Process(cmd) + return cmd +} + +func (c *commandable) SUnion(keys ...string) *StringSliceCmd { + args := make([]interface{}, 1+len(keys)) + args[0] = "SUNION" + for i, key := range keys { + args[1+i] = key + } + cmd := NewStringSliceCmd(args...) + c.Process(cmd) + return cmd +} + +func (c *commandable) SUnionStore(destination string, keys ...string) *IntCmd { + args := make([]interface{}, 2+len(keys)) + args[0] = "SUNIONSTORE" + args[1] = destination + for i, key := range keys { + args[2+i] = key + } + cmd := NewIntCmd(args...) + c.Process(cmd) + return cmd +} + +//------------------------------------------------------------------------------ + +// Sorted set member. +type Z struct { + Score float64 + Member interface{} +} + +// Sorted set store operation. +type ZStore struct { + Weights []int64 + // Can be SUM, MIN or MAX. + Aggregate string +} + +func (c *commandable) ZAdd(key string, members ...Z) *IntCmd { + args := make([]interface{}, 2+2*len(members)) + args[0] = "ZADD" + args[1] = key + for i, m := range members { + args[2+2*i] = formatFloat(m.Score) + args[2+2*i+1] = m.Member + } + cmd := NewIntCmd(args...) + c.Process(cmd) + return cmd +} + +func (c *commandable) ZCard(key string) *IntCmd { + cmd := NewIntCmd("ZCARD", key) + c.Process(cmd) + return cmd +} + +func (c *commandable) ZCount(key, min, max string) *IntCmd { + cmd := NewIntCmd("ZCOUNT", key, min, max) + c.Process(cmd) + return cmd +} + +func (c *commandable) ZIncrBy(key string, increment float64, member string) *FloatCmd { + cmd := NewFloatCmd("ZINCRBY", key, formatFloat(increment), member) + c.Process(cmd) + return cmd +} + +func (c *commandable) ZInterStore( + destination string, + store ZStore, + keys ...string, +) *IntCmd { + args := make([]interface{}, 3+len(keys)) + args[0] = "ZINTERSTORE" + args[1] = destination + args[2] = strconv.Itoa(len(keys)) + for i, key := range keys { + args[3+i] = key + } + if len(store.Weights) > 0 { + args = append(args, "WEIGHTS") + for _, weight := range store.Weights { + args = append(args, formatInt(weight)) + } + } + if store.Aggregate != "" { + args = append(args, "AGGREGATE", store.Aggregate) + } + cmd := NewIntCmd(args...) + c.Process(cmd) + return cmd +} + +func (c *commandable) zRange(key string, start, stop int64, withScores bool) *StringSliceCmd { + args := []interface{}{ + "ZRANGE", + key, + formatInt(start), + formatInt(stop), + } + if withScores { + args = append(args, "WITHSCORES") + } + cmd := NewStringSliceCmd(args...) + c.Process(cmd) + return cmd +} + +func (c *commandable) ZRange(key string, start, stop int64) *StringSliceCmd { + return c.zRange(key, start, stop, false) +} + +func (c *commandable) ZRangeWithScores(key string, start, stop int64) *ZSliceCmd { + args := []interface{}{ + "ZRANGE", + key, + formatInt(start), + formatInt(stop), + "WITHSCORES", + } + cmd := NewZSliceCmd(args...) + c.Process(cmd) + return cmd +} + +type ZRangeByScore struct { + Min, Max string + Offset, Count int64 +} + +func (c *commandable) zRangeByScore(key string, opt ZRangeByScore, withScores bool) *StringSliceCmd { + args := []interface{}{"ZRANGEBYSCORE", key, opt.Min, opt.Max} + if withScores { + args = append(args, "WITHSCORES") + } + if opt.Offset != 0 || opt.Count != 0 { + args = append( + args, + "LIMIT", + formatInt(opt.Offset), + formatInt(opt.Count), + ) + } + cmd := NewStringSliceCmd(args...) + c.Process(cmd) + return cmd +} + +func (c *commandable) ZRangeByScore(key string, opt ZRangeByScore) *StringSliceCmd { + return c.zRangeByScore(key, opt, false) +} + +func (c *commandable) ZRangeByScoreWithScores(key string, opt ZRangeByScore) *ZSliceCmd { + args := []interface{}{"ZRANGEBYSCORE", key, opt.Min, opt.Max, "WITHSCORES"} + if opt.Offset != 0 || opt.Count != 0 { + args = append( + args, + "LIMIT", + formatInt(opt.Offset), + formatInt(opt.Count), + ) + } + cmd := NewZSliceCmd(args...) + c.Process(cmd) + return cmd +} + +func (c *commandable) ZRank(key, member string) *IntCmd { + cmd := NewIntCmd("ZRANK", key, member) + c.Process(cmd) + return cmd +} + +func (c *commandable) ZRem(key string, members ...string) *IntCmd { + args := make([]interface{}, 2+len(members)) + args[0] = "ZREM" + args[1] = key + for i, member := range members { + args[2+i] = member + } + cmd := NewIntCmd(args...) + c.Process(cmd) + return cmd +} + +func (c *commandable) ZRemRangeByRank(key string, start, stop int64) *IntCmd { + cmd := NewIntCmd( + "ZREMRANGEBYRANK", + key, + formatInt(start), + formatInt(stop), + ) + c.Process(cmd) + return cmd +} + +func (c *commandable) ZRemRangeByScore(key, min, max string) *IntCmd { + cmd := NewIntCmd("ZREMRANGEBYSCORE", key, min, max) + c.Process(cmd) + return cmd +} + +func (c *commandable) ZRevRange(key string, start, stop int64) *StringSliceCmd { + cmd := NewStringSliceCmd("ZREVRANGE", key, formatInt(start), formatInt(stop)) + c.Process(cmd) + return cmd +} + +func (c *commandable) ZRevRangeWithScores(key string, start, stop int64) *ZSliceCmd { + cmd := NewZSliceCmd("ZREVRANGE", key, formatInt(start), formatInt(stop), "WITHSCORES") + c.Process(cmd) + return cmd +} + +func (c *commandable) ZRevRangeByScore(key string, opt ZRangeByScore) *StringSliceCmd { + args := []interface{}{"ZREVRANGEBYSCORE", key, opt.Max, opt.Min} + if opt.Offset != 0 || opt.Count != 0 { + args = append( + args, + "LIMIT", + formatInt(opt.Offset), + formatInt(opt.Count), + ) + } + cmd := NewStringSliceCmd(args...) + c.Process(cmd) + return cmd +} + +func (c *commandable) ZRevRangeByScoreWithScores(key string, opt ZRangeByScore) *ZSliceCmd { + args := []interface{}{"ZREVRANGEBYSCORE", key, opt.Max, opt.Min, "WITHSCORES"} + if opt.Offset != 0 || opt.Count != 0 { + args = append( + args, + "LIMIT", + formatInt(opt.Offset), + formatInt(opt.Count), + ) + } + cmd := NewZSliceCmd(args...) + c.Process(cmd) + return cmd +} + +func (c *commandable) ZRevRank(key, member string) *IntCmd { + cmd := NewIntCmd("ZREVRANK", key, member) + c.Process(cmd) + return cmd +} + +func (c *commandable) ZScore(key, member string) *FloatCmd { + cmd := NewFloatCmd("ZSCORE", key, member) + c.Process(cmd) + return cmd +} + +func (c *commandable) ZUnionStore(dest string, store ZStore, keys ...string) *IntCmd { + args := make([]interface{}, 3+len(keys)) + args[0] = "ZUNIONSTORE" + args[1] = dest + args[2] = strconv.Itoa(len(keys)) + for i, key := range keys { + args[3+i] = key + } + if len(store.Weights) > 0 { + args = append(args, "WEIGHTS") + for _, weight := range store.Weights { + args = append(args, formatInt(weight)) + } + } + if store.Aggregate != "" { + args = append(args, "AGGREGATE", store.Aggregate) + } + cmd := NewIntCmd(args...) + c.Process(cmd) + return cmd +} + +//------------------------------------------------------------------------------ + +func (c *commandable) BgRewriteAOF() *StatusCmd { + cmd := NewStatusCmd("BGREWRITEAOF") + cmd._clusterKeyPos = 0 + c.Process(cmd) + return cmd +} + +func (c *commandable) BgSave() *StatusCmd { + cmd := NewStatusCmd("BGSAVE") + cmd._clusterKeyPos = 0 + c.Process(cmd) + return cmd +} + +func (c *commandable) ClientKill(ipPort string) *StatusCmd { + cmd := NewStatusCmd("CLIENT", "KILL", ipPort) + cmd._clusterKeyPos = 0 + c.Process(cmd) + return cmd +} + +func (c *commandable) ClientList() *StringCmd { + cmd := NewStringCmd("CLIENT", "LIST") + cmd._clusterKeyPos = 0 + c.Process(cmd) + return cmd +} + +func (c *commandable) ClientPause(dur time.Duration) *BoolCmd { + cmd := NewBoolCmd("CLIENT", "PAUSE", formatMs(dur)) + cmd._clusterKeyPos = 0 + c.Process(cmd) + return cmd +} + +func (c *commandable) ConfigGet(parameter string) *SliceCmd { + cmd := NewSliceCmd("CONFIG", "GET", parameter) + cmd._clusterKeyPos = 0 + c.Process(cmd) + return cmd +} + +func (c *commandable) ConfigResetStat() *StatusCmd { + cmd := NewStatusCmd("CONFIG", "RESETSTAT") + cmd._clusterKeyPos = 0 + c.Process(cmd) + return cmd +} + +func (c *commandable) ConfigSet(parameter, value string) *StatusCmd { + cmd := NewStatusCmd("CONFIG", "SET", parameter, value) + cmd._clusterKeyPos = 0 + c.Process(cmd) + return cmd +} + +func (c *commandable) DbSize() *IntCmd { + cmd := NewIntCmd("DBSIZE") + cmd._clusterKeyPos = 0 + c.Process(cmd) + return cmd +} + +func (c *commandable) FlushAll() *StatusCmd { + cmd := newKeylessStatusCmd("FLUSHALL") + c.Process(cmd) + return cmd +} + +func (c *commandable) FlushDb() *StatusCmd { + cmd := newKeylessStatusCmd("FLUSHDB") + c.Process(cmd) + return cmd +} + +func (c *commandable) Info() *StringCmd { + cmd := NewStringCmd("INFO") + cmd._clusterKeyPos = 0 + c.Process(cmd) + return cmd +} + +func (c *commandable) LastSave() *IntCmd { + cmd := NewIntCmd("LASTSAVE") + cmd._clusterKeyPos = 0 + c.Process(cmd) + return cmd +} + +func (c *commandable) Save() *StatusCmd { + cmd := newKeylessStatusCmd("SAVE") + c.Process(cmd) + return cmd +} + +func (c *commandable) shutdown(modifier string) *StatusCmd { + var args []interface{} + if modifier == "" { + args = []interface{}{"SHUTDOWN"} + } else { + args = []interface{}{"SHUTDOWN", modifier} + } + cmd := newKeylessStatusCmd(args...) + c.Process(cmd) + if err := cmd.Err(); err != nil { + if err == io.EOF { + // Server quit as expected. + cmd.err = nil + } + } else { + // Server did not quit. String reply contains the reason. + cmd.err = errorf(cmd.val) + cmd.val = "" + } + return cmd +} + +func (c *commandable) Shutdown() *StatusCmd { + return c.shutdown("") +} + +func (c *commandable) ShutdownSave() *StatusCmd { + return c.shutdown("SAVE") +} + +func (c *commandable) ShutdownNoSave() *StatusCmd { + return c.shutdown("NOSAVE") +} + +func (c *commandable) SlaveOf(host, port string) *StatusCmd { + cmd := newKeylessStatusCmd("SLAVEOF", host, port) + c.Process(cmd) + return cmd +} + +func (c *commandable) SlowLog() { + panic("not implemented") +} + +func (c *commandable) Sync() { + panic("not implemented") +} + +func (c *commandable) Time() *StringSliceCmd { + cmd := NewStringSliceCmd("TIME") + cmd._clusterKeyPos = 0 + c.Process(cmd) + return cmd +} + +//------------------------------------------------------------------------------ + +func (c *commandable) Eval(script string, keys []string, args []string) *Cmd { + cmdArgs := make([]interface{}, 3+len(keys)+len(args)) + cmdArgs[0] = "EVAL" + cmdArgs[1] = script + cmdArgs[2] = strconv.Itoa(len(keys)) + for i, key := range keys { + cmdArgs[3+i] = key + } + pos := 3 + len(keys) + for i, arg := range args { + cmdArgs[pos+i] = arg + } + cmd := NewCmd(cmdArgs...) + if len(keys) > 0 { + cmd._clusterKeyPos = 3 + } + c.Process(cmd) + return cmd +} + +func (c *commandable) EvalSha(sha1 string, keys []string, args []string) *Cmd { + cmdArgs := make([]interface{}, 3+len(keys)+len(args)) + cmdArgs[0] = "EVALSHA" + cmdArgs[1] = sha1 + cmdArgs[2] = strconv.Itoa(len(keys)) + for i, key := range keys { + cmdArgs[3+i] = key + } + pos := 3 + len(keys) + for i, arg := range args { + cmdArgs[pos+i] = arg + } + cmd := NewCmd(cmdArgs...) + if len(keys) > 0 { + cmd._clusterKeyPos = 3 + } + c.Process(cmd) + return cmd +} + +func (c *commandable) ScriptExists(scripts ...string) *BoolSliceCmd { + args := make([]interface{}, 2+len(scripts)) + args[0] = "SCRIPT" + args[1] = "EXISTS" + for i, script := range scripts { + args[2+i] = script + } + cmd := NewBoolSliceCmd(args...) + cmd._clusterKeyPos = 0 + c.Process(cmd) + return cmd +} + +func (c *commandable) ScriptFlush() *StatusCmd { + cmd := newKeylessStatusCmd("SCRIPT", "FLUSH") + c.Process(cmd) + return cmd +} + +func (c *commandable) ScriptKill() *StatusCmd { + cmd := newKeylessStatusCmd("SCRIPT", "KILL") + c.Process(cmd) + return cmd +} + +func (c *commandable) ScriptLoad(script string) *StringCmd { + cmd := NewStringCmd("SCRIPT", "LOAD", script) + cmd._clusterKeyPos = 0 + c.Process(cmd) + return cmd +} + +//------------------------------------------------------------------------------ + +func (c *commandable) DebugObject(key string) *StringCmd { + cmd := NewStringCmd("DEBUG", "OBJECT", key) + cmd._clusterKeyPos = 2 + c.Process(cmd) + return cmd +} + +//------------------------------------------------------------------------------ + +func (c *commandable) PubSubChannels(pattern string) *StringSliceCmd { + args := []interface{}{"PUBSUB", "CHANNELS"} + if pattern != "*" { + args = append(args, pattern) + } + cmd := NewStringSliceCmd(args...) + cmd._clusterKeyPos = 0 + c.Process(cmd) + return cmd +} + +func (c *commandable) PubSubNumSub(channels ...string) *StringIntMapCmd { + args := make([]interface{}, 2+len(channels)) + args[0] = "PUBSUB" + args[1] = "NUMSUB" + for i, channel := range channels { + args[2+i] = channel + } + cmd := NewStringIntMapCmd(args...) + cmd._clusterKeyPos = 0 + c.Process(cmd) + return cmd +} + +func (c *commandable) PubSubNumPat() *IntCmd { + cmd := NewIntCmd("PUBSUB", "NUMPAT") + cmd._clusterKeyPos = 0 + c.Process(cmd) + return cmd +} + +//------------------------------------------------------------------------------ + +func (c *commandable) ClusterSlots() *ClusterSlotCmd { + cmd := NewClusterSlotCmd("CLUSTER", "slots") + cmd._clusterKeyPos = 0 + c.Process(cmd) + return cmd +} + +func (c *commandable) ClusterNodes() *StringCmd { + cmd := NewStringCmd("CLUSTER", "nodes") + cmd._clusterKeyPos = 0 + c.Process(cmd) + return cmd +} + +func (c *commandable) ClusterMeet(host, port string) *StatusCmd { + cmd := newKeylessStatusCmd("CLUSTER", "meet", host, port) + c.Process(cmd) + return cmd +} + +func (c *commandable) ClusterReplicate(nodeID string) *StatusCmd { + cmd := newKeylessStatusCmd("CLUSTER", "replicate", nodeID) + c.Process(cmd) + return cmd +} + +func (c *commandable) ClusterInfo() *StringCmd { + cmd := NewStringCmd("CLUSTER", "info") + cmd._clusterKeyPos = 0 + c.Process(cmd) + return cmd +} + +func (c *commandable) ClusterFailover() *StatusCmd { + cmd := newKeylessStatusCmd("CLUSTER", "failover") + c.Process(cmd) + return cmd +} + +func (c *commandable) ClusterAddSlots(slots ...int) *StatusCmd { + args := make([]interface{}, 2+len(slots)) + args[0] = "CLUSTER" + args[1] = "ADDSLOTS" + for i, num := range slots { + args[2+i] = strconv.Itoa(num) + } + cmd := newKeylessStatusCmd(args...) + c.Process(cmd) + return cmd +} + +func (c *commandable) ClusterAddSlotsRange(min, max int) *StatusCmd { + size := max - min + 1 + slots := make([]int, size) + for i := 0; i < size; i++ { + slots[i] = min + i + } + return c.ClusterAddSlots(slots...) +} diff --git a/Godeps/_workspace/src/gopkg.in/redis.v3/commands_test.go b/Godeps/_workspace/src/gopkg.in/redis.v3/commands_test.go new file mode 100644 index 0000000..6cd9fc0 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/redis.v3/commands_test.go @@ -0,0 +1,2423 @@ +package redis_test + +import ( + "encoding/json" + "fmt" + "reflect" + "strconv" + "sync" + "testing" + "time" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + + "gopkg.in/redis.v3" +) + +var _ = Describe("Commands", func() { + var client *redis.Client + + BeforeEach(func() { + client = redis.NewClient(&redis.Options{ + Addr: redisAddr, + ReadTimeout: 500 * time.Millisecond, + WriteTimeout: 500 * time.Millisecond, + PoolTimeout: 30 * time.Second, + }) + }) + + AfterEach(func() { + Expect(client.FlushDb().Err()).NotTo(HaveOccurred()) + Expect(client.Close()).NotTo(HaveOccurred()) + }) + + //------------------------------------------------------------------------------ + + Describe("server", func() { + + It("should Auth", func() { + auth := client.Auth("password") + Expect(auth.Err()).To(MatchError("ERR Client sent AUTH, but no password is set")) + Expect(auth.Val()).To(Equal("")) + }) + + It("should Echo", func() { + echo := client.Echo("hello") + Expect(echo.Err()).NotTo(HaveOccurred()) + Expect(echo.Val()).To(Equal("hello")) + }) + + It("should Ping", func() { + ping := client.Ping() + Expect(ping.Err()).NotTo(HaveOccurred()) + Expect(ping.Val()).To(Equal("PONG")) + }) + + It("should Select", func() { + sel := client.Select(1) + Expect(sel.Err()).NotTo(HaveOccurred()) + Expect(sel.Val()).To(Equal("OK")) + }) + + It("should BgRewriteAOF", func() { + r := client.BgRewriteAOF() + Expect(r.Err()).NotTo(HaveOccurred()) + Expect(r.Val()).To(ContainSubstring("Background append only file rewriting")) + }) + + It("should BgSave", func() { + // workaround for "ERR Can't BGSAVE while AOF log rewriting is in progress" + Eventually(func() string { + return client.BgSave().Val() + }, "10s").Should(Equal("Background saving started")) + }) + + It("should ClientKill", func() { + r := client.ClientKill("1.1.1.1:1111") + Expect(r.Err()).To(MatchError("ERR No such client")) + Expect(r.Val()).To(Equal("")) + }) + + It("should ClientPause", func() { + err := client.ClientPause(time.Second).Err() + Expect(err).NotTo(HaveOccurred()) + + Consistently(func() error { + return client.Ping().Err() + }, "400ms").Should(HaveOccurred()) // pause time - read timeout + + Eventually(func() error { + return client.Ping().Err() + }, "1s").ShouldNot(HaveOccurred()) + }) + + It("should ConfigGet", func() { + r := client.ConfigGet("*") + Expect(r.Err()).NotTo(HaveOccurred()) + Expect(r.Val()).NotTo(BeEmpty()) + }) + + It("should ConfigResetStat", func() { + r := client.ConfigResetStat() + Expect(r.Err()).NotTo(HaveOccurred()) + Expect(r.Val()).To(Equal("OK")) + }) + + It("should ConfigSet", func() { + configGet := client.ConfigGet("maxmemory") + Expect(configGet.Err()).NotTo(HaveOccurred()) + Expect(configGet.Val()).To(HaveLen(2)) + Expect(configGet.Val()[0]).To(Equal("maxmemory")) + + configSet := client.ConfigSet("maxmemory", configGet.Val()[1].(string)) + Expect(configSet.Err()).NotTo(HaveOccurred()) + Expect(configSet.Val()).To(Equal("OK")) + }) + + It("should DbSize", func() { + dbSize := client.DbSize() + Expect(dbSize.Err()).NotTo(HaveOccurred()) + Expect(dbSize.Val()).To(Equal(int64(0))) + }) + + It("should Info", func() { + info := client.Info() + Expect(info.Err()).NotTo(HaveOccurred()) + Expect(info.Val()).NotTo(Equal("")) + }) + + It("should LastSave", func() { + lastSave := client.LastSave() + Expect(lastSave.Err()).NotTo(HaveOccurred()) + Expect(lastSave.Val()).NotTo(Equal(0)) + }) + + It("should Save", func() { + // workaround for "ERR Background save already in progress" + Eventually(func() string { + return client.Save().Val() + }, "10s").Should(Equal("OK")) + }) + + It("should SlaveOf", func() { + slaveOf := client.SlaveOf("localhost", "8888") + Expect(slaveOf.Err()).NotTo(HaveOccurred()) + Expect(slaveOf.Val()).To(Equal("OK")) + + slaveOf = client.SlaveOf("NO", "ONE") + Expect(slaveOf.Err()).NotTo(HaveOccurred()) + Expect(slaveOf.Val()).To(Equal("OK")) + }) + + It("should Time", func() { + time := client.Time() + Expect(time.Err()).NotTo(HaveOccurred()) + Expect(time.Val()).To(HaveLen(2)) + }) + + }) + + //------------------------------------------------------------------------------ + + Describe("debugging", func() { + + It("should DebugObject", func() { + debug := client.DebugObject("foo") + Expect(debug.Err()).To(HaveOccurred()) + Expect(debug.Err().Error()).To(Equal("ERR no such key")) + + client.Set("foo", "bar", 0) + debug = client.DebugObject("foo") + Expect(debug.Err()).NotTo(HaveOccurred()) + Expect(debug.Val()).To(ContainSubstring(`serializedlength:4`)) + }) + + }) + + //------------------------------------------------------------------------------ + + Describe("keys", func() { + + It("should Del", func() { + set := client.Set("key1", "Hello", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + set = client.Set("key2", "World", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + del := client.Del("key1", "key2", "key3") + Expect(del.Err()).NotTo(HaveOccurred()) + Expect(del.Val()).To(Equal(int64(2))) + }) + + It("should Dump", func() { + set := client.Set("key", "hello", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + dump := client.Dump("key") + Expect(dump.Err()).NotTo(HaveOccurred()) + Expect(dump.Val()).To(Equal("\x00\x05hello\x06\x00\xf5\x9f\xb7\xf6\x90a\x1c\x99")) + }) + + It("should Exists", func() { + set := client.Set("key1", "Hello", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + exists := client.Exists("key1") + Expect(exists.Err()).NotTo(HaveOccurred()) + Expect(exists.Val()).To(Equal(true)) + + exists = client.Exists("key2") + Expect(exists.Err()).NotTo(HaveOccurred()) + Expect(exists.Val()).To(Equal(false)) + }) + + It("should Expire", func() { + set := client.Set("key", "Hello", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + expire := client.Expire("key", 10*time.Second) + Expect(expire.Err()).NotTo(HaveOccurred()) + Expect(expire.Val()).To(Equal(true)) + + ttl := client.TTL("key") + Expect(ttl.Err()).NotTo(HaveOccurred()) + Expect(ttl.Val()).To(Equal(10 * time.Second)) + + set = client.Set("key", "Hello World", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + ttl = client.TTL("key") + Expect(ttl.Err()).NotTo(HaveOccurred()) + Expect(ttl.Val() < 0).To(Equal(true)) + }) + + It("should ExpireAt", func() { + set := client.Set("key", "Hello", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + exists := client.Exists("key") + Expect(exists.Err()).NotTo(HaveOccurred()) + Expect(exists.Val()).To(Equal(true)) + + expireAt := client.ExpireAt("key", time.Now().Add(-time.Hour)) + Expect(expireAt.Err()).NotTo(HaveOccurred()) + Expect(expireAt.Val()).To(Equal(true)) + + exists = client.Exists("key") + Expect(exists.Err()).NotTo(HaveOccurred()) + Expect(exists.Val()).To(Equal(false)) + }) + + It("should Keys", func() { + mset := client.MSet("one", "1", "two", "2", "three", "3", "four", "4") + Expect(mset.Err()).NotTo(HaveOccurred()) + Expect(mset.Val()).To(Equal("OK")) + + keys := client.Keys("*o*") + Expect(keys.Err()).NotTo(HaveOccurred()) + Expect(keys.Val()).To(ConsistOf([]string{"four", "one", "two"})) + + keys = client.Keys("t??") + Expect(keys.Err()).NotTo(HaveOccurred()) + Expect(keys.Val()).To(Equal([]string{"two"})) + + keys = client.Keys("*") + Expect(keys.Err()).NotTo(HaveOccurred()) + Expect(keys.Val()).To(ConsistOf([]string{"four", "one", "three", "two"})) + }) + + It("should Migrate", func() { + migrate := client.Migrate("localhost", redisSecondaryPort, "key", 0, 0) + Expect(migrate.Err()).NotTo(HaveOccurred()) + Expect(migrate.Val()).To(Equal("NOKEY")) + + set := client.Set("key", "hello", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + migrate = client.Migrate("localhost", redisSecondaryPort, "key", 0, 0) + Expect(migrate.Err()).To(MatchError("IOERR error or timeout writing to target instance")) + Expect(migrate.Val()).To(Equal("")) + }) + + It("should Move", func() { + move := client.Move("key", 1) + Expect(move.Err()).NotTo(HaveOccurred()) + Expect(move.Val()).To(Equal(false)) + + set := client.Set("key", "hello", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + move = client.Move("key", 1) + Expect(move.Err()).NotTo(HaveOccurred()) + Expect(move.Val()).To(Equal(true)) + + get := client.Get("key") + Expect(get.Err()).To(Equal(redis.Nil)) + Expect(get.Val()).To(Equal("")) + + sel := client.Select(1) + Expect(sel.Err()).NotTo(HaveOccurred()) + Expect(sel.Val()).To(Equal("OK")) + + get = client.Get("key") + Expect(get.Err()).NotTo(HaveOccurred()) + Expect(get.Val()).To(Equal("hello")) + Expect(client.FlushDb().Err()).NotTo(HaveOccurred()) + Expect(client.Select(0).Err()).NotTo(HaveOccurred()) + }) + + It("should Object", func() { + set := client.Set("key", "hello", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + refCount := client.ObjectRefCount("key") + Expect(refCount.Err()).NotTo(HaveOccurred()) + Expect(refCount.Val()).To(Equal(int64(1))) + + err := client.ObjectEncoding("key").Err() + Expect(err).NotTo(HaveOccurred()) + + idleTime := client.ObjectIdleTime("key") + Expect(idleTime.Err()).NotTo(HaveOccurred()) + Expect(idleTime.Val()).To(Equal(time.Duration(0))) + }) + + It("should Persist", func() { + set := client.Set("key", "Hello", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + expire := client.Expire("key", 10*time.Second) + Expect(expire.Err()).NotTo(HaveOccurred()) + Expect(expire.Val()).To(Equal(true)) + + ttl := client.TTL("key") + Expect(ttl.Err()).NotTo(HaveOccurred()) + Expect(ttl.Val()).To(Equal(10 * time.Second)) + + persist := client.Persist("key") + Expect(persist.Err()).NotTo(HaveOccurred()) + Expect(persist.Val()).To(Equal(true)) + + ttl = client.TTL("key") + Expect(ttl.Err()).NotTo(HaveOccurred()) + Expect(ttl.Val() < 0).To(Equal(true)) + }) + + It("should PExpire", func() { + set := client.Set("key", "Hello", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + expiration := 900 * time.Millisecond + pexpire := client.PExpire("key", expiration) + Expect(pexpire.Err()).NotTo(HaveOccurred()) + Expect(pexpire.Val()).To(Equal(true)) + + ttl := client.TTL("key") + Expect(ttl.Err()).NotTo(HaveOccurred()) + Expect(ttl.Val()).To(Equal(time.Second)) + + pttl := client.PTTL("key") + Expect(pttl.Err()).NotTo(HaveOccurred()) + Expect(pttl.Val()).To(BeNumerically("~", expiration, 10*time.Millisecond)) + }) + + It("should PExpireAt", func() { + set := client.Set("key", "Hello", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + expiration := 900 * time.Millisecond + pexpireat := client.PExpireAt("key", time.Now().Add(expiration)) + Expect(pexpireat.Err()).NotTo(HaveOccurred()) + Expect(pexpireat.Val()).To(Equal(true)) + + ttl := client.TTL("key") + Expect(ttl.Err()).NotTo(HaveOccurred()) + Expect(ttl.Val()).To(Equal(time.Second)) + + pttl := client.PTTL("key") + Expect(pttl.Err()).NotTo(HaveOccurred()) + Expect(pttl.Val()).To(BeNumerically("~", expiration, 10*time.Millisecond)) + }) + + It("should PTTL", func() { + set := client.Set("key", "Hello", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + expiration := time.Second + expire := client.Expire("key", expiration) + Expect(expire.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + pttl := client.PTTL("key") + Expect(pttl.Err()).NotTo(HaveOccurred()) + Expect(pttl.Val()).To(BeNumerically("~", expiration, 10*time.Millisecond)) + }) + + It("should RandomKey", func() { + randomKey := client.RandomKey() + Expect(randomKey.Err()).To(Equal(redis.Nil)) + Expect(randomKey.Val()).To(Equal("")) + + set := client.Set("key", "hello", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + randomKey = client.RandomKey() + Expect(randomKey.Err()).NotTo(HaveOccurred()) + Expect(randomKey.Val()).To(Equal("key")) + }) + + It("should Rename", func() { + set := client.Set("key", "hello", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + status := client.Rename("key", "key1") + Expect(status.Err()).NotTo(HaveOccurred()) + Expect(status.Val()).To(Equal("OK")) + + get := client.Get("key1") + Expect(get.Err()).NotTo(HaveOccurred()) + Expect(get.Val()).To(Equal("hello")) + }) + + It("should RenameNX", func() { + set := client.Set("key", "hello", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + renameNX := client.RenameNX("key", "key1") + Expect(renameNX.Err()).NotTo(HaveOccurred()) + Expect(renameNX.Val()).To(Equal(true)) + + get := client.Get("key1") + Expect(get.Err()).NotTo(HaveOccurred()) + Expect(get.Val()).To(Equal("hello")) + }) + + It("should Restore", func() { + err := client.Set("key", "hello", 0).Err() + Expect(err).NotTo(HaveOccurred()) + + dump := client.Dump("key") + Expect(dump.Err()).NotTo(HaveOccurred()) + + err = client.Del("key").Err() + Expect(err).NotTo(HaveOccurred()) + + restore, err := client.Restore("key", 0, dump.Val()).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(restore).To(Equal("OK")) + + type_, err := client.Type("key").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(type_).To(Equal("string")) + + val, err := client.Get("key").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal("hello")) + }) + + It("should RestoreReplace", func() { + err := client.Set("key", "hello", 0).Err() + Expect(err).NotTo(HaveOccurred()) + + dump := client.Dump("key") + Expect(dump.Err()).NotTo(HaveOccurred()) + + restore, err := client.RestoreReplace("key", 0, dump.Val()).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(restore).To(Equal("OK")) + + type_, err := client.Type("key").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(type_).To(Equal("string")) + + val, err := client.Get("key").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal("hello")) + }) + + It("should Sort", func() { + lPush := client.LPush("list", "1") + Expect(lPush.Err()).NotTo(HaveOccurred()) + Expect(lPush.Val()).To(Equal(int64(1))) + lPush = client.LPush("list", "3") + Expect(lPush.Err()).NotTo(HaveOccurred()) + Expect(lPush.Val()).To(Equal(int64(2))) + lPush = client.LPush("list", "2") + Expect(lPush.Err()).NotTo(HaveOccurred()) + Expect(lPush.Val()).To(Equal(int64(3))) + + sort := client.Sort("list", redis.Sort{Offset: 0, Count: 2, Order: "ASC"}) + Expect(sort.Err()).NotTo(HaveOccurred()) + Expect(sort.Val()).To(Equal([]string{"1", "2"})) + }) + + It("should TTL", func() { + ttl := client.TTL("key") + Expect(ttl.Err()).NotTo(HaveOccurred()) + Expect(ttl.Val() < 0).To(Equal(true)) + + set := client.Set("key", "hello", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + expire := client.Expire("key", 60*time.Second) + Expect(expire.Err()).NotTo(HaveOccurred()) + Expect(expire.Val()).To(Equal(true)) + + ttl = client.TTL("key") + Expect(ttl.Err()).NotTo(HaveOccurred()) + Expect(ttl.Val()).To(Equal(60 * time.Second)) + }) + + It("should Type", func() { + set := client.Set("key", "hello", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + type_ := client.Type("key") + Expect(type_.Err()).NotTo(HaveOccurred()) + Expect(type_.Val()).To(Equal("string")) + }) + + }) + + //------------------------------------------------------------------------------ + + Describe("scanning", func() { + + It("should Scan", func() { + for i := 0; i < 1000; i++ { + set := client.Set(fmt.Sprintf("key%d", i), "hello", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + } + + cursor, keys, err := client.Scan(0, "", 0).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(cursor > 0).To(Equal(true)) + Expect(len(keys) > 0).To(Equal(true)) + }) + + It("should SScan", func() { + for i := 0; i < 1000; i++ { + sadd := client.SAdd("myset", fmt.Sprintf("member%d", i)) + Expect(sadd.Err()).NotTo(HaveOccurred()) + } + + cursor, keys, err := client.SScan("myset", 0, "", 0).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(cursor > 0).To(Equal(true)) + Expect(len(keys) > 0).To(Equal(true)) + }) + + It("should HScan", func() { + for i := 0; i < 1000; i++ { + sadd := client.HSet("myhash", fmt.Sprintf("key%d", i), "hello") + Expect(sadd.Err()).NotTo(HaveOccurred()) + } + + cursor, keys, err := client.HScan("myhash", 0, "", 0).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(cursor > 0).To(Equal(true)) + Expect(len(keys) > 0).To(Equal(true)) + }) + + It("should ZScan", func() { + for i := 0; i < 1000; i++ { + sadd := client.ZAdd("myset", redis.Z{float64(i), fmt.Sprintf("member%d", i)}) + Expect(sadd.Err()).NotTo(HaveOccurred()) + } + + cursor, keys, err := client.ZScan("myset", 0, "", 0).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(cursor > 0).To(Equal(true)) + Expect(len(keys) > 0).To(Equal(true)) + }) + + }) + + //------------------------------------------------------------------------------ + + Describe("strings", func() { + + It("should Append", func() { + exists := client.Exists("key") + Expect(exists.Err()).NotTo(HaveOccurred()) + Expect(exists.Val()).To(Equal(false)) + + append := client.Append("key", "Hello") + Expect(append.Err()).NotTo(HaveOccurred()) + Expect(append.Val()).To(Equal(int64(5))) + + append = client.Append("key", " World") + Expect(append.Err()).NotTo(HaveOccurred()) + Expect(append.Val()).To(Equal(int64(11))) + + get := client.Get("key") + Expect(get.Err()).NotTo(HaveOccurred()) + Expect(get.Val()).To(Equal("Hello World")) + }) + + It("should BitCount", func() { + set := client.Set("key", "foobar", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + bitCount := client.BitCount("key", nil) + Expect(bitCount.Err()).NotTo(HaveOccurred()) + Expect(bitCount.Val()).To(Equal(int64(26))) + + bitCount = client.BitCount("key", &redis.BitCount{0, 0}) + Expect(bitCount.Err()).NotTo(HaveOccurred()) + Expect(bitCount.Val()).To(Equal(int64(4))) + + bitCount = client.BitCount("key", &redis.BitCount{1, 1}) + Expect(bitCount.Err()).NotTo(HaveOccurred()) + Expect(bitCount.Val()).To(Equal(int64(6))) + }) + + It("should BitOpAnd", func() { + set := client.Set("key1", "1", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + set = client.Set("key2", "0", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + bitOpAnd := client.BitOpAnd("dest", "key1", "key2") + Expect(bitOpAnd.Err()).NotTo(HaveOccurred()) + Expect(bitOpAnd.Val()).To(Equal(int64(1))) + + get := client.Get("dest") + Expect(get.Err()).NotTo(HaveOccurred()) + Expect(get.Val()).To(Equal("0")) + }) + + It("should BitOpOr", func() { + set := client.Set("key1", "1", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + set = client.Set("key2", "0", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + bitOpOr := client.BitOpOr("dest", "key1", "key2") + Expect(bitOpOr.Err()).NotTo(HaveOccurred()) + Expect(bitOpOr.Val()).To(Equal(int64(1))) + + get := client.Get("dest") + Expect(get.Err()).NotTo(HaveOccurred()) + Expect(get.Val()).To(Equal("1")) + }) + + It("should BitOpXor", func() { + set := client.Set("key1", "\xff", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + set = client.Set("key2", "\x0f", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + bitOpXor := client.BitOpXor("dest", "key1", "key2") + Expect(bitOpXor.Err()).NotTo(HaveOccurred()) + Expect(bitOpXor.Val()).To(Equal(int64(1))) + + get := client.Get("dest") + Expect(get.Err()).NotTo(HaveOccurred()) + Expect(get.Val()).To(Equal("\xf0")) + }) + + It("should BitOpNot", func() { + set := client.Set("key1", "\x00", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + bitOpNot := client.BitOpNot("dest", "key1") + Expect(bitOpNot.Err()).NotTo(HaveOccurred()) + Expect(bitOpNot.Val()).To(Equal(int64(1))) + + get := client.Get("dest") + Expect(get.Err()).NotTo(HaveOccurred()) + Expect(get.Val()).To(Equal("\xff")) + }) + + It("should BitPos", func() { + err := client.Set("mykey", "\xff\xf0\x00", 0).Err() + Expect(err).NotTo(HaveOccurred()) + + pos, err := client.BitPos("mykey", 0).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(pos).To(Equal(int64(12))) + + pos, err = client.BitPos("mykey", 1).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(pos).To(Equal(int64(0))) + + pos, err = client.BitPos("mykey", 0, 2).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(pos).To(Equal(int64(16))) + + pos, err = client.BitPos("mykey", 1, 2).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(pos).To(Equal(int64(-1))) + + pos, err = client.BitPos("mykey", 0, -1).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(pos).To(Equal(int64(16))) + + pos, err = client.BitPos("mykey", 1, -1).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(pos).To(Equal(int64(-1))) + + pos, err = client.BitPos("mykey", 0, 2, 1).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(pos).To(Equal(int64(-1))) + + pos, err = client.BitPos("mykey", 0, 0, -3).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(pos).To(Equal(int64(-1))) + + pos, err = client.BitPos("mykey", 0, 0, 0).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(pos).To(Equal(int64(-1))) + }) + + It("should Decr", func() { + set := client.Set("key", "10", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + decr := client.Decr("key") + Expect(decr.Err()).NotTo(HaveOccurred()) + Expect(decr.Val()).To(Equal(int64(9))) + + set = client.Set("key", "234293482390480948029348230948", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + decr = client.Decr("key") + Expect(decr.Err()).To(MatchError("ERR value is not an integer or out of range")) + Expect(decr.Val()).To(Equal(int64(0))) + }) + + It("should DecrBy", func() { + set := client.Set("key", "10", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + decrBy := client.DecrBy("key", 5) + Expect(decrBy.Err()).NotTo(HaveOccurred()) + Expect(decrBy.Val()).To(Equal(int64(5))) + }) + + It("should Get", func() { + get := client.Get("_") + Expect(get.Err()).To(Equal(redis.Nil)) + Expect(get.Val()).To(Equal("")) + + set := client.Set("key", "hello", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + get = client.Get("key") + Expect(get.Err()).NotTo(HaveOccurred()) + Expect(get.Val()).To(Equal("hello")) + }) + + It("should GetBit", func() { + setBit := client.SetBit("key", 7, 1) + Expect(setBit.Err()).NotTo(HaveOccurred()) + Expect(setBit.Val()).To(Equal(int64(0))) + + getBit := client.GetBit("key", 0) + Expect(getBit.Err()).NotTo(HaveOccurred()) + Expect(getBit.Val()).To(Equal(int64(0))) + + getBit = client.GetBit("key", 7) + Expect(getBit.Err()).NotTo(HaveOccurred()) + Expect(getBit.Val()).To(Equal(int64(1))) + + getBit = client.GetBit("key", 100) + Expect(getBit.Err()).NotTo(HaveOccurred()) + Expect(getBit.Val()).To(Equal(int64(0))) + }) + + It("should GetRange", func() { + set := client.Set("key", "This is a string", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + getRange := client.GetRange("key", 0, 3) + Expect(getRange.Err()).NotTo(HaveOccurred()) + Expect(getRange.Val()).To(Equal("This")) + + getRange = client.GetRange("key", -3, -1) + Expect(getRange.Err()).NotTo(HaveOccurred()) + Expect(getRange.Val()).To(Equal("ing")) + + getRange = client.GetRange("key", 0, -1) + Expect(getRange.Err()).NotTo(HaveOccurred()) + Expect(getRange.Val()).To(Equal("This is a string")) + + getRange = client.GetRange("key", 10, 100) + Expect(getRange.Err()).NotTo(HaveOccurred()) + Expect(getRange.Val()).To(Equal("string")) + }) + + It("should GetSet", func() { + incr := client.Incr("key") + Expect(incr.Err()).NotTo(HaveOccurred()) + Expect(incr.Val()).To(Equal(int64(1))) + + getSet := client.GetSet("key", "0") + Expect(getSet.Err()).NotTo(HaveOccurred()) + Expect(getSet.Val()).To(Equal("1")) + + get := client.Get("key") + Expect(get.Err()).NotTo(HaveOccurred()) + Expect(get.Val()).To(Equal("0")) + }) + + It("should Incr", func() { + set := client.Set("key", "10", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + incr := client.Incr("key") + Expect(incr.Err()).NotTo(HaveOccurred()) + Expect(incr.Val()).To(Equal(int64(11))) + + get := client.Get("key") + Expect(get.Err()).NotTo(HaveOccurred()) + Expect(get.Val()).To(Equal("11")) + }) + + It("should IncrBy", func() { + set := client.Set("key", "10", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + incrBy := client.IncrBy("key", 5) + Expect(incrBy.Err()).NotTo(HaveOccurred()) + Expect(incrBy.Val()).To(Equal(int64(15))) + }) + + It("should IncrByFloat", func() { + set := client.Set("key", "10.50", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + incrByFloat := client.IncrByFloat("key", 0.1) + Expect(incrByFloat.Err()).NotTo(HaveOccurred()) + Expect(incrByFloat.Val()).To(Equal(10.6)) + + set = client.Set("key", "5.0e3", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + incrByFloat = client.IncrByFloat("key", 2.0e2) + Expect(incrByFloat.Err()).NotTo(HaveOccurred()) + Expect(incrByFloat.Val()).To(Equal(float64(5200))) + }) + + It("should IncrByFloatOverflow", func() { + incrByFloat := client.IncrByFloat("key", 996945661) + Expect(incrByFloat.Err()).NotTo(HaveOccurred()) + Expect(incrByFloat.Val()).To(Equal(float64(996945661))) + }) + + It("should MSetMGet", func() { + mSet := client.MSet("key1", "hello1", "key2", "hello2") + Expect(mSet.Err()).NotTo(HaveOccurred()) + Expect(mSet.Val()).To(Equal("OK")) + + mGet := client.MGet("key1", "key2", "_") + Expect(mGet.Err()).NotTo(HaveOccurred()) + Expect(mGet.Val()).To(Equal([]interface{}{"hello1", "hello2", nil})) + }) + + It("should MSetNX", func() { + mSetNX := client.MSetNX("key1", "hello1", "key2", "hello2") + Expect(mSetNX.Err()).NotTo(HaveOccurred()) + Expect(mSetNX.Val()).To(Equal(true)) + + mSetNX = client.MSetNX("key2", "hello1", "key3", "hello2") + Expect(mSetNX.Err()).NotTo(HaveOccurred()) + Expect(mSetNX.Val()).To(Equal(false)) + }) + + It("should Set with expiration", func() { + err := client.Set("key", "hello", 100*time.Millisecond).Err() + Expect(err).NotTo(HaveOccurred()) + + val, err := client.Get("key").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal("hello")) + + Eventually(func() error { + return client.Get("foo").Err() + }, "1s", "100ms").Should(Equal(redis.Nil)) + }) + + It("should SetGet", func() { + set := client.Set("key", "hello", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + get := client.Get("key") + Expect(get.Err()).NotTo(HaveOccurred()) + Expect(get.Val()).To(Equal("hello")) + }) + + It("should SetNX", func() { + setNX := client.SetNX("key", "hello", 0) + Expect(setNX.Err()).NotTo(HaveOccurred()) + Expect(setNX.Val()).To(Equal(true)) + + setNX = client.SetNX("key", "hello2", 0) + Expect(setNX.Err()).NotTo(HaveOccurred()) + Expect(setNX.Val()).To(Equal(false)) + + get := client.Get("key") + Expect(get.Err()).NotTo(HaveOccurred()) + Expect(get.Val()).To(Equal("hello")) + }) + + It("should SetNX with expiration", func() { + isSet, err := client.SetNX("key", "hello", time.Second).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(isSet).To(Equal(true)) + + isSet, err = client.SetNX("key", "hello2", time.Second).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(isSet).To(Equal(false)) + + val, err := client.Get("key").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal("hello")) + }) + + It("should SetXX", func() { + isSet, err := client.SetXX("key", "hello2", time.Second).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(isSet).To(Equal(false)) + + err = client.Set("key", "hello", time.Second).Err() + Expect(err).NotTo(HaveOccurred()) + + isSet, err = client.SetXX("key", "hello2", time.Second).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(isSet).To(Equal(true)) + + val, err := client.Get("key").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal("hello2")) + }) + + It("should SetRange", func() { + set := client.Set("key", "Hello World", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + range_ := client.SetRange("key", 6, "Redis") + Expect(range_.Err()).NotTo(HaveOccurred()) + Expect(range_.Val()).To(Equal(int64(11))) + + get := client.Get("key") + Expect(get.Err()).NotTo(HaveOccurred()) + Expect(get.Val()).To(Equal("Hello Redis")) + }) + + It("should StrLen", func() { + set := client.Set("key", "hello", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + strLen := client.StrLen("key") + Expect(strLen.Err()).NotTo(HaveOccurred()) + Expect(strLen.Val()).To(Equal(int64(5))) + + strLen = client.StrLen("_") + Expect(strLen.Err()).NotTo(HaveOccurred()) + Expect(strLen.Val()).To(Equal(int64(0))) + }) + + }) + + //------------------------------------------------------------------------------ + + Describe("hashes", func() { + + It("should HDel", func() { + hSet := client.HSet("hash", "key", "hello") + Expect(hSet.Err()).NotTo(HaveOccurred()) + + hDel := client.HDel("hash", "key") + Expect(hDel.Err()).NotTo(HaveOccurred()) + Expect(hDel.Val()).To(Equal(int64(1))) + + hDel = client.HDel("hash", "key") + Expect(hDel.Err()).NotTo(HaveOccurred()) + Expect(hDel.Val()).To(Equal(int64(0))) + }) + + It("should HExists", func() { + hSet := client.HSet("hash", "key", "hello") + Expect(hSet.Err()).NotTo(HaveOccurred()) + + hExists := client.HExists("hash", "key") + Expect(hExists.Err()).NotTo(HaveOccurred()) + Expect(hExists.Val()).To(Equal(true)) + + hExists = client.HExists("hash", "key1") + Expect(hExists.Err()).NotTo(HaveOccurred()) + Expect(hExists.Val()).To(Equal(false)) + }) + + It("should HGet", func() { + hSet := client.HSet("hash", "key", "hello") + Expect(hSet.Err()).NotTo(HaveOccurred()) + + hGet := client.HGet("hash", "key") + Expect(hGet.Err()).NotTo(HaveOccurred()) + Expect(hGet.Val()).To(Equal("hello")) + + hGet = client.HGet("hash", "key1") + Expect(hGet.Err()).To(Equal(redis.Nil)) + Expect(hGet.Val()).To(Equal("")) + }) + + It("should HGetAll", func() { + hSet := client.HSet("hash", "key1", "hello1") + Expect(hSet.Err()).NotTo(HaveOccurred()) + hSet = client.HSet("hash", "key2", "hello2") + Expect(hSet.Err()).NotTo(HaveOccurred()) + + hGetAll := client.HGetAll("hash") + Expect(hGetAll.Err()).NotTo(HaveOccurred()) + Expect(hGetAll.Val()).To(Equal([]string{"key1", "hello1", "key2", "hello2"})) + }) + + It("should HGetAllMap", func() { + hSet := client.HSet("hash", "key1", "hello1") + Expect(hSet.Err()).NotTo(HaveOccurred()) + hSet = client.HSet("hash", "key2", "hello2") + Expect(hSet.Err()).NotTo(HaveOccurred()) + + hGetAll := client.HGetAllMap("hash") + Expect(hGetAll.Err()).NotTo(HaveOccurred()) + Expect(hGetAll.Val()).To(Equal(map[string]string{"key1": "hello1", "key2": "hello2"})) + }) + + It("should HIncrBy", func() { + hSet := client.HSet("hash", "key", "5") + Expect(hSet.Err()).NotTo(HaveOccurred()) + + hIncrBy := client.HIncrBy("hash", "key", 1) + Expect(hIncrBy.Err()).NotTo(HaveOccurred()) + Expect(hIncrBy.Val()).To(Equal(int64(6))) + + hIncrBy = client.HIncrBy("hash", "key", -1) + Expect(hIncrBy.Err()).NotTo(HaveOccurred()) + Expect(hIncrBy.Val()).To(Equal(int64(5))) + + hIncrBy = client.HIncrBy("hash", "key", -10) + Expect(hIncrBy.Err()).NotTo(HaveOccurred()) + Expect(hIncrBy.Val()).To(Equal(int64(-5))) + }) + + It("should HIncrByFloat", func() { + hSet := client.HSet("hash", "field", "10.50") + Expect(hSet.Err()).NotTo(HaveOccurred()) + Expect(hSet.Val()).To(Equal(true)) + + hIncrByFloat := client.HIncrByFloat("hash", "field", 0.1) + Expect(hIncrByFloat.Err()).NotTo(HaveOccurred()) + Expect(hIncrByFloat.Val()).To(Equal(10.6)) + + hSet = client.HSet("hash", "field", "5.0e3") + Expect(hSet.Err()).NotTo(HaveOccurred()) + Expect(hSet.Val()).To(Equal(false)) + + hIncrByFloat = client.HIncrByFloat("hash", "field", 2.0e2) + Expect(hIncrByFloat.Err()).NotTo(HaveOccurred()) + Expect(hIncrByFloat.Val()).To(Equal(float64(5200))) + }) + + It("should HKeys", func() { + hkeys := client.HKeys("hash") + Expect(hkeys.Err()).NotTo(HaveOccurred()) + Expect(hkeys.Val()).To(Equal([]string{})) + + hset := client.HSet("hash", "key1", "hello1") + Expect(hset.Err()).NotTo(HaveOccurred()) + hset = client.HSet("hash", "key2", "hello2") + Expect(hset.Err()).NotTo(HaveOccurred()) + + hkeys = client.HKeys("hash") + Expect(hkeys.Err()).NotTo(HaveOccurred()) + Expect(hkeys.Val()).To(Equal([]string{"key1", "key2"})) + }) + + It("should HLen", func() { + hSet := client.HSet("hash", "key1", "hello1") + Expect(hSet.Err()).NotTo(HaveOccurred()) + hSet = client.HSet("hash", "key2", "hello2") + Expect(hSet.Err()).NotTo(HaveOccurred()) + + hLen := client.HLen("hash") + Expect(hLen.Err()).NotTo(HaveOccurred()) + Expect(hLen.Val()).To(Equal(int64(2))) + }) + + It("should HMGet", func() { + hSet := client.HSet("hash", "key1", "hello1") + Expect(hSet.Err()).NotTo(HaveOccurred()) + hSet = client.HSet("hash", "key2", "hello2") + Expect(hSet.Err()).NotTo(HaveOccurred()) + + hMGet := client.HMGet("hash", "key1", "key2", "_") + Expect(hMGet.Err()).NotTo(HaveOccurred()) + Expect(hMGet.Val()).To(Equal([]interface{}{"hello1", "hello2", nil})) + }) + + It("should HMSet", func() { + hMSet := client.HMSet("hash", "key1", "hello1", "key2", "hello2") + Expect(hMSet.Err()).NotTo(HaveOccurred()) + Expect(hMSet.Val()).To(Equal("OK")) + + hGet := client.HGet("hash", "key1") + Expect(hGet.Err()).NotTo(HaveOccurred()) + Expect(hGet.Val()).To(Equal("hello1")) + + hGet = client.HGet("hash", "key2") + Expect(hGet.Err()).NotTo(HaveOccurred()) + Expect(hGet.Val()).To(Equal("hello2")) + }) + + It("should HSet", func() { + hSet := client.HSet("hash", "key", "hello") + Expect(hSet.Err()).NotTo(HaveOccurred()) + Expect(hSet.Val()).To(Equal(true)) + + hGet := client.HGet("hash", "key") + Expect(hGet.Err()).NotTo(HaveOccurred()) + Expect(hGet.Val()).To(Equal("hello")) + }) + + It("should HSetNX", func() { + hSetNX := client.HSetNX("hash", "key", "hello") + Expect(hSetNX.Err()).NotTo(HaveOccurred()) + Expect(hSetNX.Val()).To(Equal(true)) + + hSetNX = client.HSetNX("hash", "key", "hello") + Expect(hSetNX.Err()).NotTo(HaveOccurred()) + Expect(hSetNX.Val()).To(Equal(false)) + + hGet := client.HGet("hash", "key") + Expect(hGet.Err()).NotTo(HaveOccurred()) + Expect(hGet.Val()).To(Equal("hello")) + }) + + It("should HVals", func() { + hSet := client.HSet("hash", "key1", "hello1") + Expect(hSet.Err()).NotTo(HaveOccurred()) + hSet = client.HSet("hash", "key2", "hello2") + Expect(hSet.Err()).NotTo(HaveOccurred()) + + hVals := client.HVals("hash") + Expect(hVals.Err()).NotTo(HaveOccurred()) + Expect(hVals.Val()).To(Equal([]string{"hello1", "hello2"})) + }) + + }) + + //------------------------------------------------------------------------------ + + Describe("lists", func() { + + It("should BLPop", func() { + rPush := client.RPush("list1", "a", "b", "c") + Expect(rPush.Err()).NotTo(HaveOccurred()) + + bLPop := client.BLPop(0, "list1", "list2") + Expect(bLPop.Err()).NotTo(HaveOccurred()) + Expect(bLPop.Val()).To(Equal([]string{"list1", "a"})) + }) + + It("should BLPopBlocks", func() { + started := make(chan bool) + done := make(chan bool) + go func() { + defer GinkgoRecover() + + started <- true + bLPop := client.BLPop(0, "list") + Expect(bLPop.Err()).NotTo(HaveOccurred()) + Expect(bLPop.Val()).To(Equal([]string{"list", "a"})) + done <- true + }() + <-started + + select { + case <-done: + Fail("BLPop is not blocked") + case <-time.After(time.Second): + // ok + } + + rPush := client.RPush("list", "a") + Expect(rPush.Err()).NotTo(HaveOccurred()) + + select { + case <-done: + // ok + case <-time.After(time.Second): + Fail("BLPop is still blocked") + } + }) + + It("should BLPop timeout", func() { + bLPop := client.BLPop(time.Second, "list1") + Expect(bLPop.Val()).To(BeNil()) + Expect(bLPop.Err()).To(Equal(redis.Nil)) + }) + + It("should BRPop", func() { + rPush := client.RPush("list1", "a", "b", "c") + Expect(rPush.Err()).NotTo(HaveOccurred()) + + bRPop := client.BRPop(0, "list1", "list2") + Expect(bRPop.Err()).NotTo(HaveOccurred()) + Expect(bRPop.Val()).To(Equal([]string{"list1", "c"})) + }) + + It("should BRPop blocks", func() { + started := make(chan bool) + done := make(chan bool) + go func() { + defer GinkgoRecover() + + started <- true + brpop := client.BRPop(0, "list") + Expect(brpop.Err()).NotTo(HaveOccurred()) + Expect(brpop.Val()).To(Equal([]string{"list", "a"})) + done <- true + }() + <-started + + select { + case <-done: + Fail("BRPop is not blocked") + case <-time.After(time.Second): + // ok + } + + rPush := client.RPush("list", "a") + Expect(rPush.Err()).NotTo(HaveOccurred()) + + select { + case <-done: + // ok + case <-time.After(time.Second): + Fail("BRPop is still blocked") + // ok + } + }) + + It("should BRPopLPush", func() { + rPush := client.RPush("list1", "a", "b", "c") + Expect(rPush.Err()).NotTo(HaveOccurred()) + + bRPopLPush := client.BRPopLPush("list1", "list2", 0) + Expect(bRPopLPush.Err()).NotTo(HaveOccurred()) + Expect(bRPopLPush.Val()).To(Equal("c")) + }) + + It("should LIndex", func() { + lPush := client.LPush("list", "World") + Expect(lPush.Err()).NotTo(HaveOccurred()) + lPush = client.LPush("list", "Hello") + Expect(lPush.Err()).NotTo(HaveOccurred()) + + lIndex := client.LIndex("list", 0) + Expect(lIndex.Err()).NotTo(HaveOccurred()) + Expect(lIndex.Val()).To(Equal("Hello")) + + lIndex = client.LIndex("list", -1) + Expect(lIndex.Err()).NotTo(HaveOccurred()) + Expect(lIndex.Val()).To(Equal("World")) + + lIndex = client.LIndex("list", 3) + Expect(lIndex.Err()).To(Equal(redis.Nil)) + Expect(lIndex.Val()).To(Equal("")) + }) + + It("should LInsert", func() { + rPush := client.RPush("list", "Hello") + Expect(rPush.Err()).NotTo(HaveOccurred()) + rPush = client.RPush("list", "World") + Expect(rPush.Err()).NotTo(HaveOccurred()) + + lInsert := client.LInsert("list", "BEFORE", "World", "There") + Expect(lInsert.Err()).NotTo(HaveOccurred()) + Expect(lInsert.Val()).To(Equal(int64(3))) + + lRange := client.LRange("list", 0, -1) + Expect(lRange.Err()).NotTo(HaveOccurred()) + Expect(lRange.Val()).To(Equal([]string{"Hello", "There", "World"})) + }) + + It("should LLen", func() { + lPush := client.LPush("list", "World") + Expect(lPush.Err()).NotTo(HaveOccurred()) + lPush = client.LPush("list", "Hello") + Expect(lPush.Err()).NotTo(HaveOccurred()) + + lLen := client.LLen("list") + Expect(lLen.Err()).NotTo(HaveOccurred()) + Expect(lLen.Val()).To(Equal(int64(2))) + }) + + It("should LPop", func() { + rPush := client.RPush("list", "one") + Expect(rPush.Err()).NotTo(HaveOccurred()) + rPush = client.RPush("list", "two") + Expect(rPush.Err()).NotTo(HaveOccurred()) + rPush = client.RPush("list", "three") + Expect(rPush.Err()).NotTo(HaveOccurred()) + + lPop := client.LPop("list") + Expect(lPop.Err()).NotTo(HaveOccurred()) + Expect(lPop.Val()).To(Equal("one")) + + lRange := client.LRange("list", 0, -1) + Expect(lRange.Err()).NotTo(HaveOccurred()) + Expect(lRange.Val()).To(Equal([]string{"two", "three"})) + }) + + It("should LPush", func() { + lPush := client.LPush("list", "World") + Expect(lPush.Err()).NotTo(HaveOccurred()) + lPush = client.LPush("list", "Hello") + Expect(lPush.Err()).NotTo(HaveOccurred()) + + lRange := client.LRange("list", 0, -1) + Expect(lRange.Err()).NotTo(HaveOccurred()) + Expect(lRange.Val()).To(Equal([]string{"Hello", "World"})) + }) + + It("should LPushX", func() { + lPush := client.LPush("list", "World") + Expect(lPush.Err()).NotTo(HaveOccurred()) + + lPushX := client.LPushX("list", "Hello") + Expect(lPushX.Err()).NotTo(HaveOccurred()) + Expect(lPushX.Val()).To(Equal(int64(2))) + + lPushX = client.LPushX("list2", "Hello") + Expect(lPushX.Err()).NotTo(HaveOccurred()) + Expect(lPushX.Val()).To(Equal(int64(0))) + + lRange := client.LRange("list", 0, -1) + Expect(lRange.Err()).NotTo(HaveOccurred()) + Expect(lRange.Val()).To(Equal([]string{"Hello", "World"})) + + lRange = client.LRange("list2", 0, -1) + Expect(lRange.Err()).NotTo(HaveOccurred()) + Expect(lRange.Val()).To(Equal([]string{})) + }) + + It("should LRange", func() { + rPush := client.RPush("list", "one") + Expect(rPush.Err()).NotTo(HaveOccurred()) + rPush = client.RPush("list", "two") + Expect(rPush.Err()).NotTo(HaveOccurred()) + rPush = client.RPush("list", "three") + Expect(rPush.Err()).NotTo(HaveOccurred()) + + lRange := client.LRange("list", 0, 0) + Expect(lRange.Err()).NotTo(HaveOccurred()) + Expect(lRange.Val()).To(Equal([]string{"one"})) + + lRange = client.LRange("list", -3, 2) + Expect(lRange.Err()).NotTo(HaveOccurred()) + Expect(lRange.Val()).To(Equal([]string{"one", "two", "three"})) + + lRange = client.LRange("list", -100, 100) + Expect(lRange.Err()).NotTo(HaveOccurred()) + Expect(lRange.Val()).To(Equal([]string{"one", "two", "three"})) + + lRange = client.LRange("list", 5, 10) + Expect(lRange.Err()).NotTo(HaveOccurred()) + Expect(lRange.Val()).To(Equal([]string{})) + }) + + It("should LRem", func() { + rPush := client.RPush("list", "hello") + Expect(rPush.Err()).NotTo(HaveOccurred()) + rPush = client.RPush("list", "hello") + Expect(rPush.Err()).NotTo(HaveOccurred()) + rPush = client.RPush("list", "key") + Expect(rPush.Err()).NotTo(HaveOccurred()) + rPush = client.RPush("list", "hello") + Expect(rPush.Err()).NotTo(HaveOccurred()) + + lRem := client.LRem("list", -2, "hello") + Expect(lRem.Err()).NotTo(HaveOccurred()) + Expect(lRem.Val()).To(Equal(int64(2))) + + lRange := client.LRange("list", 0, -1) + Expect(lRange.Err()).NotTo(HaveOccurred()) + Expect(lRange.Val()).To(Equal([]string{"hello", "key"})) + }) + + It("should LSet", func() { + rPush := client.RPush("list", "one") + Expect(rPush.Err()).NotTo(HaveOccurred()) + rPush = client.RPush("list", "two") + Expect(rPush.Err()).NotTo(HaveOccurred()) + rPush = client.RPush("list", "three") + Expect(rPush.Err()).NotTo(HaveOccurred()) + + lSet := client.LSet("list", 0, "four") + Expect(lSet.Err()).NotTo(HaveOccurred()) + Expect(lSet.Val()).To(Equal("OK")) + + lSet = client.LSet("list", -2, "five") + Expect(lSet.Err()).NotTo(HaveOccurred()) + Expect(lSet.Val()).To(Equal("OK")) + + lRange := client.LRange("list", 0, -1) + Expect(lRange.Err()).NotTo(HaveOccurred()) + Expect(lRange.Val()).To(Equal([]string{"four", "five", "three"})) + }) + + It("should LTrim", func() { + rPush := client.RPush("list", "one") + Expect(rPush.Err()).NotTo(HaveOccurred()) + rPush = client.RPush("list", "two") + Expect(rPush.Err()).NotTo(HaveOccurred()) + rPush = client.RPush("list", "three") + Expect(rPush.Err()).NotTo(HaveOccurred()) + + lTrim := client.LTrim("list", 1, -1) + Expect(lTrim.Err()).NotTo(HaveOccurred()) + Expect(lTrim.Val()).To(Equal("OK")) + + lRange := client.LRange("list", 0, -1) + Expect(lRange.Err()).NotTo(HaveOccurred()) + Expect(lRange.Val()).To(Equal([]string{"two", "three"})) + }) + + It("should RPop", func() { + rPush := client.RPush("list", "one") + Expect(rPush.Err()).NotTo(HaveOccurred()) + rPush = client.RPush("list", "two") + Expect(rPush.Err()).NotTo(HaveOccurred()) + rPush = client.RPush("list", "three") + Expect(rPush.Err()).NotTo(HaveOccurred()) + + rPop := client.RPop("list") + Expect(rPop.Err()).NotTo(HaveOccurred()) + Expect(rPop.Val()).To(Equal("three")) + + lRange := client.LRange("list", 0, -1) + Expect(lRange.Err()).NotTo(HaveOccurred()) + Expect(lRange.Val()).To(Equal([]string{"one", "two"})) + }) + + It("should RPopLPush", func() { + rPush := client.RPush("list", "one") + Expect(rPush.Err()).NotTo(HaveOccurred()) + rPush = client.RPush("list", "two") + Expect(rPush.Err()).NotTo(HaveOccurred()) + rPush = client.RPush("list", "three") + Expect(rPush.Err()).NotTo(HaveOccurred()) + + rPopLPush := client.RPopLPush("list", "list2") + Expect(rPopLPush.Err()).NotTo(HaveOccurred()) + Expect(rPopLPush.Val()).To(Equal("three")) + + lRange := client.LRange("list", 0, -1) + Expect(lRange.Err()).NotTo(HaveOccurred()) + Expect(lRange.Val()).To(Equal([]string{"one", "two"})) + + lRange = client.LRange("list2", 0, -1) + Expect(lRange.Err()).NotTo(HaveOccurred()) + Expect(lRange.Val()).To(Equal([]string{"three"})) + }) + + It("should RPush", func() { + rPush := client.RPush("list", "Hello") + Expect(rPush.Err()).NotTo(HaveOccurred()) + Expect(rPush.Val()).To(Equal(int64(1))) + + rPush = client.RPush("list", "World") + Expect(rPush.Err()).NotTo(HaveOccurred()) + Expect(rPush.Val()).To(Equal(int64(2))) + + lRange := client.LRange("list", 0, -1) + Expect(lRange.Err()).NotTo(HaveOccurred()) + Expect(lRange.Val()).To(Equal([]string{"Hello", "World"})) + }) + + It("should RPushX", func() { + rPush := client.RPush("list", "Hello") + Expect(rPush.Err()).NotTo(HaveOccurred()) + Expect(rPush.Val()).To(Equal(int64(1))) + + rPushX := client.RPushX("list", "World") + Expect(rPushX.Err()).NotTo(HaveOccurred()) + Expect(rPushX.Val()).To(Equal(int64(2))) + + rPushX = client.RPushX("list2", "World") + Expect(rPushX.Err()).NotTo(HaveOccurred()) + Expect(rPushX.Val()).To(Equal(int64(0))) + + lRange := client.LRange("list", 0, -1) + Expect(lRange.Err()).NotTo(HaveOccurred()) + Expect(lRange.Val()).To(Equal([]string{"Hello", "World"})) + + lRange = client.LRange("list2", 0, -1) + Expect(lRange.Err()).NotTo(HaveOccurred()) + Expect(lRange.Val()).To(Equal([]string{})) + }) + + }) + + //------------------------------------------------------------------------------ + + Describe("sets", func() { + + It("should SAdd", func() { + sAdd := client.SAdd("set", "Hello") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + Expect(sAdd.Val()).To(Equal(int64(1))) + + sAdd = client.SAdd("set", "World") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + Expect(sAdd.Val()).To(Equal(int64(1))) + + sAdd = client.SAdd("set", "World") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + Expect(sAdd.Val()).To(Equal(int64(0))) + + sMembers := client.SMembers("set") + Expect(sMembers.Err()).NotTo(HaveOccurred()) + Expect(sMembers.Val()).To(ConsistOf([]string{"Hello", "World"})) + }) + + It("should SCard", func() { + sAdd := client.SAdd("set", "Hello") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + Expect(sAdd.Val()).To(Equal(int64(1))) + + sAdd = client.SAdd("set", "World") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + Expect(sAdd.Val()).To(Equal(int64(1))) + + sCard := client.SCard("set") + Expect(sCard.Err()).NotTo(HaveOccurred()) + Expect(sCard.Val()).To(Equal(int64(2))) + }) + + It("should SDiff", func() { + sAdd := client.SAdd("set1", "a") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + sAdd = client.SAdd("set1", "b") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + sAdd = client.SAdd("set1", "c") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + + sAdd = client.SAdd("set2", "c") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + sAdd = client.SAdd("set2", "d") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + sAdd = client.SAdd("set2", "e") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + + sDiff := client.SDiff("set1", "set2") + Expect(sDiff.Err()).NotTo(HaveOccurred()) + Expect(sDiff.Val()).To(ConsistOf([]string{"a", "b"})) + }) + + It("should SDiffStore", func() { + sAdd := client.SAdd("set1", "a") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + sAdd = client.SAdd("set1", "b") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + sAdd = client.SAdd("set1", "c") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + + sAdd = client.SAdd("set2", "c") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + sAdd = client.SAdd("set2", "d") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + sAdd = client.SAdd("set2", "e") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + + sDiffStore := client.SDiffStore("set", "set1", "set2") + Expect(sDiffStore.Err()).NotTo(HaveOccurred()) + Expect(sDiffStore.Val()).To(Equal(int64(2))) + + sMembers := client.SMembers("set") + Expect(sMembers.Err()).NotTo(HaveOccurred()) + Expect(sMembers.Val()).To(ConsistOf([]string{"a", "b"})) + }) + + It("should SInter", func() { + sAdd := client.SAdd("set1", "a") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + sAdd = client.SAdd("set1", "b") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + sAdd = client.SAdd("set1", "c") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + + sAdd = client.SAdd("set2", "c") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + sAdd = client.SAdd("set2", "d") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + sAdd = client.SAdd("set2", "e") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + + sInter := client.SInter("set1", "set2") + Expect(sInter.Err()).NotTo(HaveOccurred()) + Expect(sInter.Val()).To(Equal([]string{"c"})) + }) + + It("should SInterStore", func() { + sAdd := client.SAdd("set1", "a") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + sAdd = client.SAdd("set1", "b") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + sAdd = client.SAdd("set1", "c") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + + sAdd = client.SAdd("set2", "c") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + sAdd = client.SAdd("set2", "d") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + sAdd = client.SAdd("set2", "e") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + + sInterStore := client.SInterStore("set", "set1", "set2") + Expect(sInterStore.Err()).NotTo(HaveOccurred()) + Expect(sInterStore.Val()).To(Equal(int64(1))) + + sMembers := client.SMembers("set") + Expect(sMembers.Err()).NotTo(HaveOccurred()) + Expect(sMembers.Val()).To(Equal([]string{"c"})) + }) + + It("should IsMember", func() { + sAdd := client.SAdd("set", "one") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + + sIsMember := client.SIsMember("set", "one") + Expect(sIsMember.Err()).NotTo(HaveOccurred()) + Expect(sIsMember.Val()).To(Equal(true)) + + sIsMember = client.SIsMember("set", "two") + Expect(sIsMember.Err()).NotTo(HaveOccurred()) + Expect(sIsMember.Val()).To(Equal(false)) + }) + + It("should SMembers", func() { + sAdd := client.SAdd("set", "Hello") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + sAdd = client.SAdd("set", "World") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + + sMembers := client.SMembers("set") + Expect(sMembers.Err()).NotTo(HaveOccurred()) + Expect(sMembers.Val()).To(ConsistOf([]string{"Hello", "World"})) + }) + + It("should SMove", func() { + sAdd := client.SAdd("set1", "one") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + sAdd = client.SAdd("set1", "two") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + + sAdd = client.SAdd("set2", "three") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + + sMove := client.SMove("set1", "set2", "two") + Expect(sMove.Err()).NotTo(HaveOccurred()) + Expect(sMove.Val()).To(Equal(true)) + + sMembers := client.SMembers("set1") + Expect(sMembers.Err()).NotTo(HaveOccurred()) + Expect(sMembers.Val()).To(Equal([]string{"one"})) + + sMembers = client.SMembers("set2") + Expect(sMembers.Err()).NotTo(HaveOccurred()) + Expect(sMembers.Val()).To(ConsistOf([]string{"three", "two"})) + }) + + It("should SPop", func() { + sAdd := client.SAdd("set", "one") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + sAdd = client.SAdd("set", "two") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + sAdd = client.SAdd("set", "three") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + + sPop := client.SPop("set") + Expect(sPop.Err()).NotTo(HaveOccurred()) + Expect(sPop.Val()).NotTo(Equal("")) + + sMembers := client.SMembers("set") + Expect(sMembers.Err()).NotTo(HaveOccurred()) + Expect(sMembers.Val()).To(HaveLen(2)) + }) + + It("should SRandMember", func() { + sAdd := client.SAdd("set", "one") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + sAdd = client.SAdd("set", "two") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + sAdd = client.SAdd("set", "three") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + + sRandMember := client.SRandMember("set") + Expect(sRandMember.Err()).NotTo(HaveOccurred()) + Expect(sRandMember.Val()).NotTo(Equal("")) + + sMembers := client.SMembers("set") + Expect(sMembers.Err()).NotTo(HaveOccurred()) + Expect(sMembers.Val()).To(HaveLen(3)) + }) + + It("should SRem", func() { + sAdd := client.SAdd("set", "one") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + sAdd = client.SAdd("set", "two") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + sAdd = client.SAdd("set", "three") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + + sRem := client.SRem("set", "one") + Expect(sRem.Err()).NotTo(HaveOccurred()) + Expect(sRem.Val()).To(Equal(int64(1))) + + sRem = client.SRem("set", "four") + Expect(sRem.Err()).NotTo(HaveOccurred()) + Expect(sRem.Val()).To(Equal(int64(0))) + + sMembers := client.SMembers("set") + Expect(sMembers.Err()).NotTo(HaveOccurred()) + Expect(sMembers.Val()).To(ConsistOf([]string{"three", "two"})) + }) + + It("should SUnion", func() { + sAdd := client.SAdd("set1", "a") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + sAdd = client.SAdd("set1", "b") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + sAdd = client.SAdd("set1", "c") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + + sAdd = client.SAdd("set2", "c") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + sAdd = client.SAdd("set2", "d") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + sAdd = client.SAdd("set2", "e") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + + sUnion := client.SUnion("set1", "set2") + Expect(sUnion.Err()).NotTo(HaveOccurred()) + Expect(sUnion.Val()).To(HaveLen(5)) + }) + + It("should SUnionStore", func() { + sAdd := client.SAdd("set1", "a") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + sAdd = client.SAdd("set1", "b") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + sAdd = client.SAdd("set1", "c") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + + sAdd = client.SAdd("set2", "c") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + sAdd = client.SAdd("set2", "d") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + sAdd = client.SAdd("set2", "e") + Expect(sAdd.Err()).NotTo(HaveOccurred()) + + sUnionStore := client.SUnionStore("set", "set1", "set2") + Expect(sUnionStore.Err()).NotTo(HaveOccurred()) + Expect(sUnionStore.Val()).To(Equal(int64(5))) + + sMembers := client.SMembers("set") + Expect(sMembers.Err()).NotTo(HaveOccurred()) + Expect(sMembers.Val()).To(HaveLen(5)) + }) + + }) + + //------------------------------------------------------------------------------ + + Describe("sorted sets", func() { + + It("should ZAdd", func() { + added, err := client.ZAdd("zset", redis.Z{1, "one"}).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(added).To(Equal(int64(1))) + + added, err = client.ZAdd("zset", redis.Z{1, "uno"}).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(added).To(Equal(int64(1))) + + added, err = client.ZAdd("zset", redis.Z{2, "two"}).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(added).To(Equal(int64(1))) + + added, err = client.ZAdd("zset", redis.Z{3, "two"}).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(added).To(Equal(int64(0))) + + val, err := client.ZRangeWithScores("zset", 0, -1).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal([]redis.Z{{1, "one"}, {1, "uno"}, {3, "two"}})) + }) + + It("should ZAdd bytes", func() { + added, err := client.ZAdd("zset", redis.Z{1, []byte("one")}).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(added).To(Equal(int64(1))) + + added, err = client.ZAdd("zset", redis.Z{1, []byte("uno")}).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(added).To(Equal(int64(1))) + + added, err = client.ZAdd("zset", redis.Z{2, []byte("two")}).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(added).To(Equal(int64(1))) + + added, err = client.ZAdd("zset", redis.Z{3, []byte("two")}).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(added).To(Equal(int64(0))) + + val, err := client.ZRangeWithScores("zset", 0, -1).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal([]redis.Z{{1, "one"}, {1, "uno"}, {3, "two"}})) + }) + + It("should ZCard", func() { + zAdd := client.ZAdd("zset", redis.Z{1, "one"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + zAdd = client.ZAdd("zset", redis.Z{2, "two"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + + zCard := client.ZCard("zset") + Expect(zCard.Err()).NotTo(HaveOccurred()) + Expect(zCard.Val()).To(Equal(int64(2))) + }) + + It("should ZCount", func() { + zAdd := client.ZAdd("zset", redis.Z{1, "one"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + zAdd = client.ZAdd("zset", redis.Z{2, "two"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + zAdd = client.ZAdd("zset", redis.Z{3, "three"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + + zCount := client.ZCount("zset", "-inf", "+inf") + Expect(zCount.Err()).NotTo(HaveOccurred()) + Expect(zCount.Val()).To(Equal(int64(3))) + + zCount = client.ZCount("zset", "(1", "3") + Expect(zCount.Err()).NotTo(HaveOccurred()) + Expect(zCount.Val()).To(Equal(int64(2))) + }) + + It("should ZIncrBy", func() { + zAdd := client.ZAdd("zset", redis.Z{1, "one"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + zAdd = client.ZAdd("zset", redis.Z{2, "two"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + + zIncrBy := client.ZIncrBy("zset", 2, "one") + Expect(zIncrBy.Err()).NotTo(HaveOccurred()) + Expect(zIncrBy.Val()).To(Equal(float64(3))) + + val, err := client.ZRangeWithScores("zset", 0, -1).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal([]redis.Z{{2, "two"}, {3, "one"}})) + }) + + It("should ZInterStore", func() { + zAdd := client.ZAdd("zset1", redis.Z{1, "one"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + zAdd = client.ZAdd("zset1", redis.Z{2, "two"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + + zAdd = client.ZAdd("zset2", redis.Z{1, "one"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + zAdd = client.ZAdd("zset2", redis.Z{2, "two"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + zAdd = client.ZAdd("zset3", redis.Z{3, "two"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + + zInterStore := client.ZInterStore( + "out", redis.ZStore{Weights: []int64{2, 3}}, "zset1", "zset2") + Expect(zInterStore.Err()).NotTo(HaveOccurred()) + Expect(zInterStore.Val()).To(Equal(int64(2))) + + val, err := client.ZRangeWithScores("out", 0, -1).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal([]redis.Z{{5, "one"}, {10, "two"}})) + }) + + It("should ZRange", func() { + zAdd := client.ZAdd("zset", redis.Z{1, "one"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + zAdd = client.ZAdd("zset", redis.Z{2, "two"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + zAdd = client.ZAdd("zset", redis.Z{3, "three"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + + zRange := client.ZRange("zset", 0, -1) + Expect(zRange.Err()).NotTo(HaveOccurred()) + Expect(zRange.Val()).To(Equal([]string{"one", "two", "three"})) + + zRange = client.ZRange("zset", 2, 3) + Expect(zRange.Err()).NotTo(HaveOccurred()) + Expect(zRange.Val()).To(Equal([]string{"three"})) + + zRange = client.ZRange("zset", -2, -1) + Expect(zRange.Err()).NotTo(HaveOccurred()) + Expect(zRange.Val()).To(Equal([]string{"two", "three"})) + }) + + It("should ZRangeWithScores", func() { + zAdd := client.ZAdd("zset", redis.Z{1, "one"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + zAdd = client.ZAdd("zset", redis.Z{2, "two"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + zAdd = client.ZAdd("zset", redis.Z{3, "three"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + + val, err := client.ZRangeWithScores("zset", 0, -1).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal([]redis.Z{{1, "one"}, {2, "two"}, {3, "three"}})) + + val, err = client.ZRangeWithScores("zset", 2, 3).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal([]redis.Z{{3, "three"}})) + + val, err = client.ZRangeWithScores("zset", -2, -1).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal([]redis.Z{{2, "two"}, {3, "three"}})) + }) + + It("should ZRangeByScore", func() { + zAdd := client.ZAdd("zset", redis.Z{1, "one"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + zAdd = client.ZAdd("zset", redis.Z{2, "two"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + zAdd = client.ZAdd("zset", redis.Z{3, "three"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + + zRangeByScore := client.ZRangeByScore("zset", redis.ZRangeByScore{ + Min: "-inf", + Max: "+inf", + }) + Expect(zRangeByScore.Err()).NotTo(HaveOccurred()) + Expect(zRangeByScore.Val()).To(Equal([]string{"one", "two", "three"})) + + zRangeByScore = client.ZRangeByScore("zset", redis.ZRangeByScore{ + Min: "1", + Max: "2", + }) + Expect(zRangeByScore.Err()).NotTo(HaveOccurred()) + Expect(zRangeByScore.Val()).To(Equal([]string{"one", "two"})) + + zRangeByScore = client.ZRangeByScore("zset", redis.ZRangeByScore{ + Min: "(1", + Max: "2", + }) + Expect(zRangeByScore.Err()).NotTo(HaveOccurred()) + Expect(zRangeByScore.Val()).To(Equal([]string{"two"})) + + zRangeByScore = client.ZRangeByScore("zset", redis.ZRangeByScore{ + Min: "(1", + Max: "(2", + }) + Expect(zRangeByScore.Err()).NotTo(HaveOccurred()) + Expect(zRangeByScore.Val()).To(Equal([]string{})) + }) + + It("should ZRangeByScoreWithScoresMap", func() { + zAdd := client.ZAdd("zset", redis.Z{1, "one"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + zAdd = client.ZAdd("zset", redis.Z{2, "two"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + zAdd = client.ZAdd("zset", redis.Z{3, "three"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + + val, err := client.ZRangeByScoreWithScores("zset", redis.ZRangeByScore{ + Min: "-inf", + Max: "+inf", + }).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal([]redis.Z{{1, "one"}, {2, "two"}, {3, "three"}})) + + val, err = client.ZRangeByScoreWithScores("zset", redis.ZRangeByScore{ + Min: "1", + Max: "2", + }).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal([]redis.Z{{1, "one"}, {2, "two"}})) + + val, err = client.ZRangeByScoreWithScores("zset", redis.ZRangeByScore{ + Min: "(1", + Max: "2", + }).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal([]redis.Z{{2, "two"}})) + + val, err = client.ZRangeByScoreWithScores("zset", redis.ZRangeByScore{ + Min: "(1", + Max: "(2", + }).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal([]redis.Z{})) + }) + + It("should ZRank", func() { + zAdd := client.ZAdd("zset", redis.Z{1, "one"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + zAdd = client.ZAdd("zset", redis.Z{2, "two"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + zAdd = client.ZAdd("zset", redis.Z{3, "three"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + + zRank := client.ZRank("zset", "three") + Expect(zRank.Err()).NotTo(HaveOccurred()) + Expect(zRank.Val()).To(Equal(int64(2))) + + zRank = client.ZRank("zset", "four") + Expect(zRank.Err()).To(Equal(redis.Nil)) + Expect(zRank.Val()).To(Equal(int64(0))) + }) + + It("should ZRem", func() { + zAdd := client.ZAdd("zset", redis.Z{1, "one"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + zAdd = client.ZAdd("zset", redis.Z{2, "two"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + zAdd = client.ZAdd("zset", redis.Z{3, "three"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + + zRem := client.ZRem("zset", "two") + Expect(zRem.Err()).NotTo(HaveOccurred()) + Expect(zRem.Val()).To(Equal(int64(1))) + + val, err := client.ZRangeWithScores("zset", 0, -1).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal([]redis.Z{{1, "one"}, {3, "three"}})) + }) + + It("should ZRemRangeByRank", func() { + zAdd := client.ZAdd("zset", redis.Z{1, "one"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + zAdd = client.ZAdd("zset", redis.Z{2, "two"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + zAdd = client.ZAdd("zset", redis.Z{3, "three"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + + zRemRangeByRank := client.ZRemRangeByRank("zset", 0, 1) + Expect(zRemRangeByRank.Err()).NotTo(HaveOccurred()) + Expect(zRemRangeByRank.Val()).To(Equal(int64(2))) + + val, err := client.ZRangeWithScores("zset", 0, -1).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal([]redis.Z{{3, "three"}})) + }) + + It("should ZRemRangeByScore", func() { + zAdd := client.ZAdd("zset", redis.Z{1, "one"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + zAdd = client.ZAdd("zset", redis.Z{2, "two"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + zAdd = client.ZAdd("zset", redis.Z{3, "three"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + + zRemRangeByScore := client.ZRemRangeByScore("zset", "-inf", "(2") + Expect(zRemRangeByScore.Err()).NotTo(HaveOccurred()) + Expect(zRemRangeByScore.Val()).To(Equal(int64(1))) + + val, err := client.ZRangeWithScores("zset", 0, -1).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal([]redis.Z{{2, "two"}, {3, "three"}})) + }) + + It("should ZRevRange", func() { + zAdd := client.ZAdd("zset", redis.Z{1, "one"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + zAdd = client.ZAdd("zset", redis.Z{2, "two"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + zAdd = client.ZAdd("zset", redis.Z{3, "three"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + + zRevRange := client.ZRevRange("zset", 0, -1) + Expect(zRevRange.Err()).NotTo(HaveOccurred()) + Expect(zRevRange.Val()).To(Equal([]string{"three", "two", "one"})) + + zRevRange = client.ZRevRange("zset", 2, 3) + Expect(zRevRange.Err()).NotTo(HaveOccurred()) + Expect(zRevRange.Val()).To(Equal([]string{"one"})) + + zRevRange = client.ZRevRange("zset", -2, -1) + Expect(zRevRange.Err()).NotTo(HaveOccurred()) + Expect(zRevRange.Val()).To(Equal([]string{"two", "one"})) + }) + + It("should ZRevRangeWithScoresMap", func() { + zAdd := client.ZAdd("zset", redis.Z{1, "one"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + zAdd = client.ZAdd("zset", redis.Z{2, "two"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + zAdd = client.ZAdd("zset", redis.Z{3, "three"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + + val, err := client.ZRevRangeWithScores("zset", 0, -1).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal([]redis.Z{{3, "three"}, {2, "two"}, {1, "one"}})) + + val, err = client.ZRevRangeWithScores("zset", 2, 3).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal([]redis.Z{{1, "one"}})) + + val, err = client.ZRevRangeWithScores("zset", -2, -1).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal([]redis.Z{{2, "two"}, {1, "one"}})) + }) + + It("should ZRevRangeByScore", func() { + zadd := client.ZAdd("zset", redis.Z{1, "one"}) + Expect(zadd.Err()).NotTo(HaveOccurred()) + zadd = client.ZAdd("zset", redis.Z{2, "two"}) + Expect(zadd.Err()).NotTo(HaveOccurred()) + zadd = client.ZAdd("zset", redis.Z{3, "three"}) + Expect(zadd.Err()).NotTo(HaveOccurred()) + + vals, err := client.ZRevRangeByScore( + "zset", redis.ZRangeByScore{Max: "+inf", Min: "-inf"}).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(vals).To(Equal([]string{"three", "two", "one"})) + + vals, err = client.ZRevRangeByScore( + "zset", redis.ZRangeByScore{Max: "2", Min: "(1"}).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(vals).To(Equal([]string{"two"})) + + vals, err = client.ZRevRangeByScore( + "zset", redis.ZRangeByScore{Max: "(2", Min: "(1"}).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(vals).To(Equal([]string{})) + }) + + It("should ZRevRangeByScoreWithScores", func() { + zadd := client.ZAdd("zset", redis.Z{1, "one"}) + Expect(zadd.Err()).NotTo(HaveOccurred()) + zadd = client.ZAdd("zset", redis.Z{2, "two"}) + Expect(zadd.Err()).NotTo(HaveOccurred()) + zadd = client.ZAdd("zset", redis.Z{3, "three"}) + Expect(zadd.Err()).NotTo(HaveOccurred()) + + vals, err := client.ZRevRangeByScoreWithScores( + "zset", redis.ZRangeByScore{Max: "+inf", Min: "-inf"}).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(vals).To(Equal([]redis.Z{{3, "three"}, {2, "two"}, {1, "one"}})) + }) + + It("should ZRevRangeByScoreWithScoresMap", func() { + zAdd := client.ZAdd("zset", redis.Z{1, "one"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + zAdd = client.ZAdd("zset", redis.Z{2, "two"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + zAdd = client.ZAdd("zset", redis.Z{3, "three"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + + val, err := client.ZRevRangeByScoreWithScores( + "zset", redis.ZRangeByScore{Max: "+inf", Min: "-inf"}).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal([]redis.Z{{3, "three"}, {2, "two"}, {1, "one"}})) + + val, err = client.ZRevRangeByScoreWithScores( + "zset", redis.ZRangeByScore{Max: "2", Min: "(1"}).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal([]redis.Z{{2, "two"}})) + + val, err = client.ZRevRangeByScoreWithScores( + "zset", redis.ZRangeByScore{Max: "(2", Min: "(1"}).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal([]redis.Z{})) + }) + + It("should ZRevRank", func() { + zAdd := client.ZAdd("zset", redis.Z{1, "one"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + zAdd = client.ZAdd("zset", redis.Z{2, "two"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + zAdd = client.ZAdd("zset", redis.Z{3, "three"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + + zRevRank := client.ZRevRank("zset", "one") + Expect(zRevRank.Err()).NotTo(HaveOccurred()) + Expect(zRevRank.Val()).To(Equal(int64(2))) + + zRevRank = client.ZRevRank("zset", "four") + Expect(zRevRank.Err()).To(Equal(redis.Nil)) + Expect(zRevRank.Val()).To(Equal(int64(0))) + }) + + It("should ZScore", func() { + zAdd := client.ZAdd("zset", redis.Z{1.001, "one"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + + zScore := client.ZScore("zset", "one") + Expect(zScore.Err()).NotTo(HaveOccurred()) + Expect(zScore.Val()).To(Equal(float64(1.001))) + }) + + It("should ZUnionStore", func() { + zAdd := client.ZAdd("zset1", redis.Z{1, "one"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + zAdd = client.ZAdd("zset1", redis.Z{2, "two"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + + zAdd = client.ZAdd("zset2", redis.Z{1, "one"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + zAdd = client.ZAdd("zset2", redis.Z{2, "two"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + zAdd = client.ZAdd("zset2", redis.Z{3, "three"}) + Expect(zAdd.Err()).NotTo(HaveOccurred()) + + zUnionStore := client.ZUnionStore( + "out", redis.ZStore{Weights: []int64{2, 3}}, "zset1", "zset2") + Expect(zUnionStore.Err()).NotTo(HaveOccurred()) + Expect(zUnionStore.Val()).To(Equal(int64(3))) + + val, err := client.ZRangeWithScores("out", 0, -1).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal([]redis.Z{{5, "one"}, {9, "three"}, {10, "two"}})) + }) + + }) + + //------------------------------------------------------------------------------ + + Describe("watch/unwatch", func() { + + It("should WatchUnwatch", func() { + var C, N = 10, 1000 + if testing.Short() { + N = 100 + } + + err := client.Set("key", "0", 0).Err() + Expect(err).NotTo(HaveOccurred()) + + wg := &sync.WaitGroup{} + for i := 0; i < C; i++ { + wg.Add(1) + + go func() { + defer GinkgoRecover() + defer wg.Done() + + multi := client.Multi() + defer multi.Close() + + for j := 0; j < N; j++ { + val, err := multi.Watch("key").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal("OK")) + + val, err = multi.Get("key").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).NotTo(Equal(redis.Nil)) + + num, err := strconv.ParseInt(val, 10, 64) + Expect(err).NotTo(HaveOccurred()) + + cmds, err := multi.Exec(func() error { + multi.Set("key", strconv.FormatInt(num+1, 10), 0) + return nil + }) + if err == redis.TxFailedErr { + j-- + continue + } + Expect(err).NotTo(HaveOccurred()) + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].Err()).NotTo(HaveOccurred()) + } + }() + } + wg.Wait() + + val, err := client.Get("key").Int64() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal(int64(C * N))) + }) + + }) + + Describe("marshaling/unmarshaling", func() { + + type convTest struct { + value interface{} + wanted string + dest interface{} + } + + convTests := []convTest{ + {nil, "", nil}, + {"hello", "hello", new(string)}, + {[]byte("hello"), "hello", new([]byte)}, + {int(1), "1", new(int)}, + {int8(1), "1", new(int8)}, + {int16(1), "1", new(int16)}, + {int32(1), "1", new(int32)}, + {int64(1), "1", new(int64)}, + {uint(1), "1", new(uint)}, + {uint8(1), "1", new(uint8)}, + {uint16(1), "1", new(uint16)}, + {uint32(1), "1", new(uint32)}, + {uint64(1), "1", new(uint64)}, + {float32(1.0), "1", new(float32)}, + {float64(1.0), "1", new(float64)}, + {true, "1", new(bool)}, + {false, "0", new(bool)}, + } + + It("should convert to string", func() { + for _, test := range convTests { + err := client.Set("key", test.value, 0).Err() + Expect(err).NotTo(HaveOccurred()) + + s, err := client.Get("key").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(s).To(Equal(test.wanted)) + + if test.dest == nil { + continue + } + + err = client.Get("key").Scan(test.dest) + Expect(err).NotTo(HaveOccurred()) + Expect(deref(test.dest)).To(Equal(test.value)) + } + }) + + }) + + Describe("json marshaling/unmarshaling", func() { + BeforeEach(func() { + value := &numberStruct{Number: 42} + err := client.Set("key", value, 0).Err() + Expect(err).NotTo(HaveOccurred()) + }) + + It("should marshal custom values using json", func() { + s, err := client.Get("key").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(s).To(Equal(`{"Number":42}`)) + }) + + It("should scan custom values using json", func() { + value := &numberStruct{} + err := client.Get("key").Scan(value) + Expect(err).To(BeNil()) + Expect(value.Number).To(Equal(42)) + }) + + }) + +}) + +type numberStruct struct { + Number int +} + +func (s *numberStruct) MarshalBinary() ([]byte, error) { + return json.Marshal(s) +} + +func (s *numberStruct) UnmarshalBinary(b []byte) error { + return json.Unmarshal(b, s) +} + +func deref(viface interface{}) interface{} { + v := reflect.ValueOf(viface) + for v.Kind() == reflect.Ptr { + v = v.Elem() + } + return v.Interface() +} diff --git a/Godeps/_workspace/src/gopkg.in/redis.v3/conn.go b/Godeps/_workspace/src/gopkg.in/redis.v3/conn.go new file mode 100644 index 0000000..9dc2ede --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/redis.v3/conn.go @@ -0,0 +1,104 @@ +package redis + +import ( + "net" + "time" + + "gopkg.in/bufio.v1" +) + +var ( + zeroTime = time.Time{} +) + +type conn struct { + netcn net.Conn + rd *bufio.Reader + buf []byte + + usedAt time.Time + ReadTimeout time.Duration + WriteTimeout time.Duration +} + +func newConnDialer(opt *Options) func() (*conn, error) { + dialer := opt.getDialer() + return func() (*conn, error) { + netcn, err := dialer() + if err != nil { + return nil, err + } + cn := &conn{ + netcn: netcn, + buf: make([]byte, 0, 64), + } + cn.rd = bufio.NewReader(cn) + return cn, cn.init(opt) + } +} + +func (cn *conn) init(opt *Options) error { + if opt.Password == "" && opt.DB == 0 { + return nil + } + + // Use connection to connect to Redis. + pool := newSingleConnPoolConn(cn) + + // Client is not closed because we want to reuse underlying connection. + client := newClient(opt, pool) + + if opt.Password != "" { + if err := client.Auth(opt.Password).Err(); err != nil { + return err + } + } + + if opt.DB > 0 { + if err := client.Select(opt.DB).Err(); err != nil { + return err + } + } + + return nil +} + +func (cn *conn) writeCmds(cmds ...Cmder) error { + buf := cn.buf[:0] + for _, cmd := range cmds { + var err error + buf, err = appendArgs(buf, cmd.args()) + if err != nil { + return err + } + } + + _, err := cn.Write(buf) + return err +} + +func (cn *conn) Read(b []byte) (int, error) { + if cn.ReadTimeout != 0 { + cn.netcn.SetReadDeadline(time.Now().Add(cn.ReadTimeout)) + } else { + cn.netcn.SetReadDeadline(zeroTime) + } + return cn.netcn.Read(b) +} + +func (cn *conn) Write(b []byte) (int, error) { + if cn.WriteTimeout != 0 { + cn.netcn.SetWriteDeadline(time.Now().Add(cn.WriteTimeout)) + } else { + cn.netcn.SetWriteDeadline(zeroTime) + } + return cn.netcn.Write(b) +} + +func (cn *conn) RemoteAddr() net.Addr { + return cn.netcn.RemoteAddr() +} + +func (cn *conn) Close() error { + return cn.netcn.Close() +} diff --git a/Godeps/_workspace/src/gopkg.in/redis.v3/crc16.go b/Godeps/_workspace/src/gopkg.in/redis.v3/crc16.go new file mode 100644 index 0000000..a7f3b56 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/redis.v3/crc16.go @@ -0,0 +1,47 @@ +package redis + +// CRC16 implementation according to CCITT standards. +// Copyright 2001-2010 Georges Menie (www.menie.org) +// Copyright 2013 The Go Authors. All rights reserved. +// http://redis.io/topics/cluster-spec#appendix-a-crc16-reference-implementation-in-ansi-c +var crc16tab = [256]uint16{ + 0x0000, 0x1021, 0x2042, 0x3063, 0x4084, 0x50a5, 0x60c6, 0x70e7, + 0x8108, 0x9129, 0xa14a, 0xb16b, 0xc18c, 0xd1ad, 0xe1ce, 0xf1ef, + 0x1231, 0x0210, 0x3273, 0x2252, 0x52b5, 0x4294, 0x72f7, 0x62d6, + 0x9339, 0x8318, 0xb37b, 0xa35a, 0xd3bd, 0xc39c, 0xf3ff, 0xe3de, + 0x2462, 0x3443, 0x0420, 0x1401, 0x64e6, 0x74c7, 0x44a4, 0x5485, + 0xa56a, 0xb54b, 0x8528, 0x9509, 0xe5ee, 0xf5cf, 0xc5ac, 0xd58d, + 0x3653, 0x2672, 0x1611, 0x0630, 0x76d7, 0x66f6, 0x5695, 0x46b4, + 0xb75b, 0xa77a, 0x9719, 0x8738, 0xf7df, 0xe7fe, 0xd79d, 0xc7bc, + 0x48c4, 0x58e5, 0x6886, 0x78a7, 0x0840, 0x1861, 0x2802, 0x3823, + 0xc9cc, 0xd9ed, 0xe98e, 0xf9af, 0x8948, 0x9969, 0xa90a, 0xb92b, + 0x5af5, 0x4ad4, 0x7ab7, 0x6a96, 0x1a71, 0x0a50, 0x3a33, 0x2a12, + 0xdbfd, 0xcbdc, 0xfbbf, 0xeb9e, 0x9b79, 0x8b58, 0xbb3b, 0xab1a, + 0x6ca6, 0x7c87, 0x4ce4, 0x5cc5, 0x2c22, 0x3c03, 0x0c60, 0x1c41, + 0xedae, 0xfd8f, 0xcdec, 0xddcd, 0xad2a, 0xbd0b, 0x8d68, 0x9d49, + 0x7e97, 0x6eb6, 0x5ed5, 0x4ef4, 0x3e13, 0x2e32, 0x1e51, 0x0e70, + 0xff9f, 0xefbe, 0xdfdd, 0xcffc, 0xbf1b, 0xaf3a, 0x9f59, 0x8f78, + 0x9188, 0x81a9, 0xb1ca, 0xa1eb, 0xd10c, 0xc12d, 0xf14e, 0xe16f, + 0x1080, 0x00a1, 0x30c2, 0x20e3, 0x5004, 0x4025, 0x7046, 0x6067, + 0x83b9, 0x9398, 0xa3fb, 0xb3da, 0xc33d, 0xd31c, 0xe37f, 0xf35e, + 0x02b1, 0x1290, 0x22f3, 0x32d2, 0x4235, 0x5214, 0x6277, 0x7256, + 0xb5ea, 0xa5cb, 0x95a8, 0x8589, 0xf56e, 0xe54f, 0xd52c, 0xc50d, + 0x34e2, 0x24c3, 0x14a0, 0x0481, 0x7466, 0x6447, 0x5424, 0x4405, + 0xa7db, 0xb7fa, 0x8799, 0x97b8, 0xe75f, 0xf77e, 0xc71d, 0xd73c, + 0x26d3, 0x36f2, 0x0691, 0x16b0, 0x6657, 0x7676, 0x4615, 0x5634, + 0xd94c, 0xc96d, 0xf90e, 0xe92f, 0x99c8, 0x89e9, 0xb98a, 0xa9ab, + 0x5844, 0x4865, 0x7806, 0x6827, 0x18c0, 0x08e1, 0x3882, 0x28a3, + 0xcb7d, 0xdb5c, 0xeb3f, 0xfb1e, 0x8bf9, 0x9bd8, 0xabbb, 0xbb9a, + 0x4a75, 0x5a54, 0x6a37, 0x7a16, 0x0af1, 0x1ad0, 0x2ab3, 0x3a92, + 0xfd2e, 0xed0f, 0xdd6c, 0xcd4d, 0xbdaa, 0xad8b, 0x9de8, 0x8dc9, + 0x7c26, 0x6c07, 0x5c64, 0x4c45, 0x3ca2, 0x2c83, 0x1ce0, 0x0cc1, + 0xef1f, 0xff3e, 0xcf5d, 0xdf7c, 0xaf9b, 0xbfba, 0x8fd9, 0x9ff8, + 0x6e17, 0x7e36, 0x4e55, 0x5e74, 0x2e93, 0x3eb2, 0x0ed1, 0x1ef0, +} + +func crc16sum(key string) (crc uint16) { + for i := 0; i < len(key); i++ { + crc = (crc << 8) ^ crc16tab[(byte(crc>>8)^key[i])&0x00ff] + } + return +} diff --git a/Godeps/_workspace/src/gopkg.in/redis.v3/crc16_test.go b/Godeps/_workspace/src/gopkg.in/redis.v3/crc16_test.go new file mode 100644 index 0000000..a6b3416 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/redis.v3/crc16_test.go @@ -0,0 +1,25 @@ +package redis + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("CRC16", func() { + + // http://redis.io/topics/cluster-spec#keys-distribution-model + It("should calculate CRC16", func() { + tests := []struct { + s string + n uint16 + }{ + {"123456789", 0x31C3}, + {string([]byte{83, 153, 134, 118, 229, 214, 244, 75, 140, 37, 215, 215}), 21847}, + } + + for _, test := range tests { + Expect(crc16sum(test.s)).To(Equal(test.n), "for %s", test.s) + } + }) + +}) diff --git a/Godeps/_workspace/src/gopkg.in/redis.v2/doc.go b/Godeps/_workspace/src/gopkg.in/redis.v3/doc.go similarity index 100% rename from Godeps/_workspace/src/gopkg.in/redis.v2/doc.go rename to Godeps/_workspace/src/gopkg.in/redis.v3/doc.go diff --git a/Godeps/_workspace/src/gopkg.in/redis.v3/error.go b/Godeps/_workspace/src/gopkg.in/redis.v3/error.go new file mode 100644 index 0000000..9e5d973 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/redis.v3/error.go @@ -0,0 +1,63 @@ +package redis + +import ( + "fmt" + "io" + "net" + "strings" +) + +// Redis nil reply, .e.g. when key does not exist. +var Nil = errorf("redis: nil") + +// Redis transaction failed. +var TxFailedErr = errorf("redis: transaction failed") + +type redisError struct { + s string +} + +func errorf(s string, args ...interface{}) redisError { + return redisError{s: fmt.Sprintf(s, args...)} +} + +func (err redisError) Error() string { + return err.s +} + +func isNetworkError(err error) bool { + if _, ok := err.(net.Error); ok || err == io.EOF { + return true + } + return false +} + +func isMovedError(err error) (moved bool, ask bool, addr string) { + if _, ok := err.(redisError); !ok { + return + } + + parts := strings.SplitN(err.Error(), " ", 3) + if len(parts) != 3 { + return + } + + switch parts[0] { + case "MOVED": + moved = true + addr = parts[2] + case "ASK": + ask = true + addr = parts[2] + } + + return +} + +// shouldRetry reports whether failed command should be retried. +func shouldRetry(err error) bool { + if err == nil { + return false + } + return isNetworkError(err) +} diff --git a/Godeps/_workspace/src/gopkg.in/redis.v3/example_test.go b/Godeps/_workspace/src/gopkg.in/redis.v3/example_test.go new file mode 100644 index 0000000..25869be --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/redis.v3/example_test.go @@ -0,0 +1,239 @@ +package redis_test + +import ( + "fmt" + "strconv" + "sync" + "time" + + "gopkg.in/redis.v3" +) + +var client *redis.Client + +func init() { + client = redis.NewClient(&redis.Options{ + Addr: ":6379", + }) + client.FlushDb() +} + +func ExampleNewClient() { + client := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + Password: "", // no password set + DB: 0, // use default DB + }) + + pong, err := client.Ping().Result() + fmt.Println(pong, err) + // Output: PONG +} + +func ExampleNewFailoverClient() { + // See http://redis.io/topics/sentinel for instructions how to + // setup Redis Sentinel. + client := redis.NewFailoverClient(&redis.FailoverOptions{ + MasterName: "master", + SentinelAddrs: []string{":26379"}, + }) + client.Ping() +} + +func ExampleNewClusterClient() { + // See http://redis.io/topics/cluster-tutorial for instructions + // how to setup Redis Cluster. + client := redis.NewClusterClient(&redis.ClusterOptions{ + Addrs: []string{":7000", ":7001", ":7002", ":7003", ":7004", ":7005"}, + }) + client.Ping() +} + +func ExampleNewRing() { + client := redis.NewRing(&redis.RingOptions{ + Addrs: map[string]string{ + "shard1": ":7000", + "shard2": ":7001", + "shard3": ":7002", + }, + }) + client.Ping() +} + +func ExampleClient() { + err := client.Set("key", "value", 0).Err() + if err != nil { + panic(err) + } + + val, err := client.Get("key").Result() + if err != nil { + panic(err) + } + fmt.Println("key", val) + + val2, err := client.Get("key2").Result() + if err == redis.Nil { + fmt.Println("key2 does not exists") + } else if err != nil { + panic(err) + } else { + fmt.Println("key2", val2) + } + // Output: key value + // key2 does not exists +} + +func ExampleClient_Incr() { + if err := client.Incr("counter").Err(); err != nil { + panic(err) + } + + n, err := client.Get("counter").Int64() + fmt.Println(n, err) + // Output: 1 +} + +func ExampleClient_Pipelined() { + var incr *redis.IntCmd + _, err := client.Pipelined(func(pipe *redis.Pipeline) error { + incr = pipe.Incr("counter1") + pipe.Expire("counter1", time.Hour) + return nil + }) + fmt.Println(incr.Val(), err) + // Output: 1 +} + +func ExamplePipeline() { + pipe := client.Pipeline() + defer pipe.Close() + + incr := pipe.Incr("counter2") + pipe.Expire("counter2", time.Hour) + _, err := pipe.Exec() + fmt.Println(incr.Val(), err) + // Output: 1 +} + +func ExampleMulti() { + // Transactionally increments key using GET and SET commands. + incr := func(tx *redis.Multi, key string) error { + err := tx.Watch(key).Err() + if err != nil { + return err + } + + n, err := tx.Get(key).Int64() + if err != nil && err != redis.Nil { + return err + } + + _, err = tx.Exec(func() error { + tx.Set(key, strconv.FormatInt(n+1, 10), 0) + return nil + }) + return err + } + + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + tx := client.Multi() + defer tx.Close() + + for { + err := incr(tx, "counter3") + if err == redis.TxFailedErr { + // Retry. + continue + } else if err != nil { + panic(err) + } + break + } + }() + } + wg.Wait() + + n, err := client.Get("counter3").Int64() + fmt.Println(n, err) + // Output: 10 +} + +func ExamplePubSub() { + pubsub, err := client.Subscribe("mychannel") + if err != nil { + panic(err) + } + defer pubsub.Close() + + err = client.Publish("mychannel", "hello").Err() + if err != nil { + panic(err) + } + + for i := 0; i < 4; i++ { + msgi, err := pubsub.ReceiveTimeout(100 * time.Millisecond) + if err != nil { + err := pubsub.Ping("") + if err != nil { + panic(err) + } + continue + } + + switch msg := msgi.(type) { + case *redis.Subscription: + fmt.Println(msg.Kind, msg.Channel) + case *redis.Message: + fmt.Println(msg.Channel, msg.Payload) + case *redis.Pong: + fmt.Println(msg) + default: + panic(fmt.Sprintf("unknown message: %#v", msgi)) + } + } + + // Output: subscribe mychannel + // mychannel hello + // Pong +} + +func ExampleScript() { + IncrByXX := redis.NewScript(` + if redis.call("GET", KEYS[1]) ~= false then + return redis.call("INCRBY", KEYS[1], ARGV[1]) + end + return false + `) + + n, err := IncrByXX.Run(client, []string{"xx_counter"}, []string{"2"}).Result() + fmt.Println(n, err) + + err = client.Set("xx_counter", "40", 0).Err() + if err != nil { + panic(err) + } + + n, err = IncrByXX.Run(client, []string{"xx_counter"}, []string{"2"}).Result() + fmt.Println(n, err) + + // Output: redis: nil + // 42 +} + +func Example_customCommand() { + Get := func(client *redis.Client, key string) *redis.StringCmd { + cmd := redis.NewStringCmd("GET", key) + client.Process(cmd) + return cmd + } + + v, err := Get(client, "key_does_not_exist").Result() + fmt.Printf("%q %s", v, err) + // Output: "" redis: nil +} diff --git a/Godeps/_workspace/src/gopkg.in/redis.v3/export_test.go b/Godeps/_workspace/src/gopkg.in/redis.v3/export_test.go new file mode 100644 index 0000000..f468729 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/redis.v3/export_test.go @@ -0,0 +1,15 @@ +package redis + +import "net" + +func (c *baseClient) Pool() pool { + return c.connPool +} + +func (cn *conn) SetNetConn(netcn net.Conn) { + cn.netcn = netcn +} + +func HashSlot(key string) int { + return hashSlot(key) +} diff --git a/Godeps/_workspace/src/gopkg.in/redis.v3/internal/consistenthash/consistenthash.go b/Godeps/_workspace/src/gopkg.in/redis.v3/internal/consistenthash/consistenthash.go new file mode 100644 index 0000000..a9c56f0 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/redis.v3/internal/consistenthash/consistenthash.go @@ -0,0 +1,81 @@ +/* +Copyright 2013 Google Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package consistenthash provides an implementation of a ring hash. +package consistenthash + +import ( + "hash/crc32" + "sort" + "strconv" +) + +type Hash func(data []byte) uint32 + +type Map struct { + hash Hash + replicas int + keys []int // Sorted + hashMap map[int]string +} + +func New(replicas int, fn Hash) *Map { + m := &Map{ + replicas: replicas, + hash: fn, + hashMap: make(map[int]string), + } + if m.hash == nil { + m.hash = crc32.ChecksumIEEE + } + return m +} + +// Returns true if there are no items available. +func (m *Map) IsEmpty() bool { + return len(m.keys) == 0 +} + +// Adds some keys to the hash. +func (m *Map) Add(keys ...string) { + for _, key := range keys { + for i := 0; i < m.replicas; i++ { + hash := int(m.hash([]byte(strconv.Itoa(i) + key))) + m.keys = append(m.keys, hash) + m.hashMap[hash] = key + } + } + sort.Ints(m.keys) +} + +// Gets the closest item in the hash to the provided key. +func (m *Map) Get(key string) string { + if m.IsEmpty() { + return "" + } + + hash := int(m.hash([]byte(key))) + + // Binary search for appropriate replica. + idx := sort.Search(len(m.keys), func(i int) bool { return m.keys[i] >= hash }) + + // Means we have cycled back to the first replica. + if idx == len(m.keys) { + idx = 0 + } + + return m.hashMap[m.keys[idx]] +} diff --git a/Godeps/_workspace/src/gopkg.in/redis.v3/internal/consistenthash/consistenthash_test.go b/Godeps/_workspace/src/gopkg.in/redis.v3/internal/consistenthash/consistenthash_test.go new file mode 100644 index 0000000..1a37fd7 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/redis.v3/internal/consistenthash/consistenthash_test.go @@ -0,0 +1,110 @@ +/* +Copyright 2013 Google Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package consistenthash + +import ( + "fmt" + "strconv" + "testing" +) + +func TestHashing(t *testing.T) { + + // Override the hash function to return easier to reason about values. Assumes + // the keys can be converted to an integer. + hash := New(3, func(key []byte) uint32 { + i, err := strconv.Atoi(string(key)) + if err != nil { + panic(err) + } + return uint32(i) + }) + + // Given the above hash function, this will give replicas with "hashes": + // 2, 4, 6, 12, 14, 16, 22, 24, 26 + hash.Add("6", "4", "2") + + testCases := map[string]string{ + "2": "2", + "11": "2", + "23": "4", + "27": "2", + } + + for k, v := range testCases { + if hash.Get(k) != v { + t.Errorf("Asking for %s, should have yielded %s", k, v) + } + } + + // Adds 8, 18, 28 + hash.Add("8") + + // 27 should now map to 8. + testCases["27"] = "8" + + for k, v := range testCases { + if hash.Get(k) != v { + t.Errorf("Asking for %s, should have yielded %s", k, v) + } + } + +} + +func TestConsistency(t *testing.T) { + hash1 := New(1, nil) + hash2 := New(1, nil) + + hash1.Add("Bill", "Bob", "Bonny") + hash2.Add("Bob", "Bonny", "Bill") + + if hash1.Get("Ben") != hash2.Get("Ben") { + t.Errorf("Fetching 'Ben' from both hashes should be the same") + } + + hash2.Add("Becky", "Ben", "Bobby") + + if hash1.Get("Ben") != hash2.Get("Ben") || + hash1.Get("Bob") != hash2.Get("Bob") || + hash1.Get("Bonny") != hash2.Get("Bonny") { + t.Errorf("Direct matches should always return the same entry") + } + +} + +func BenchmarkGet8(b *testing.B) { benchmarkGet(b, 8) } +func BenchmarkGet32(b *testing.B) { benchmarkGet(b, 32) } +func BenchmarkGet128(b *testing.B) { benchmarkGet(b, 128) } +func BenchmarkGet512(b *testing.B) { benchmarkGet(b, 512) } + +func benchmarkGet(b *testing.B, shards int) { + + hash := New(50, nil) + + var buckets []string + for i := 0; i < shards; i++ { + buckets = append(buckets, fmt.Sprintf("shard-%d", i)) + } + + hash.Add(buckets...) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + hash.Get(buckets[i&(shards-1)]) + } +} diff --git a/Godeps/_workspace/src/gopkg.in/redis.v3/main_test.go b/Godeps/_workspace/src/gopkg.in/redis.v3/main_test.go new file mode 100644 index 0000000..c4b5a59 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/redis.v3/main_test.go @@ -0,0 +1,250 @@ +package redis_test + +import ( + "fmt" + "net" + "os" + "os/exec" + "path/filepath" + "strings" + "sync/atomic" + "testing" + "time" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + + "gopkg.in/redis.v3" +) + +const ( + redisPort = "6380" + redisAddr = ":" + redisPort + redisSecondaryPort = "6381" +) + +const ( + ringShard1Port = "6390" + ringShard2Port = "6391" +) + +const ( + sentinelName = "mymaster" + sentinelMasterPort = "8123" + sentinelSlave1Port = "8124" + sentinelSlave2Port = "8125" + sentinelPort = "8126" +) + +var ( + redisMain *redisProcess + ringShard1, ringShard2 *redisProcess + sentinelMaster, sentinelSlave1, sentinelSlave2, sentinel *redisProcess +) + +var cluster = &clusterScenario{ + ports: []string{"8220", "8221", "8222", "8223", "8224", "8225"}, + nodeIds: make([]string, 6), + processes: make(map[string]*redisProcess, 6), + clients: make(map[string]*redis.Client, 6), +} + +var _ = BeforeSuite(func() { + var err error + + redisMain, err = startRedis(redisPort) + Expect(err).NotTo(HaveOccurred()) + + ringShard1, err = startRedis(ringShard1Port) + Expect(err).NotTo(HaveOccurred()) + + ringShard2, err = startRedis(ringShard2Port) + Expect(err).NotTo(HaveOccurred()) + + sentinelMaster, err = startRedis(sentinelMasterPort) + Expect(err).NotTo(HaveOccurred()) + + sentinel, err = startSentinel(sentinelPort, sentinelName, sentinelMasterPort) + Expect(err).NotTo(HaveOccurred()) + + sentinelSlave1, err = startRedis( + sentinelSlave1Port, "--slaveof", "127.0.0.1", sentinelMasterPort) + Expect(err).NotTo(HaveOccurred()) + + sentinelSlave2, err = startRedis( + sentinelSlave2Port, "--slaveof", "127.0.0.1", sentinelMasterPort) + Expect(err).NotTo(HaveOccurred()) + + Expect(startCluster(cluster)).NotTo(HaveOccurred()) +}) + +var _ = AfterSuite(func() { + Expect(redisMain.Close()).NotTo(HaveOccurred()) + + Expect(ringShard1.Close()).NotTo(HaveOccurred()) + Expect(ringShard2.Close()).NotTo(HaveOccurred()) + + Expect(sentinel.Close()).NotTo(HaveOccurred()) + Expect(sentinelSlave1.Close()).NotTo(HaveOccurred()) + Expect(sentinelSlave2.Close()).NotTo(HaveOccurred()) + Expect(sentinelMaster.Close()).NotTo(HaveOccurred()) + + Expect(stopCluster(cluster)).NotTo(HaveOccurred()) +}) + +func TestGinkgoSuite(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "gopkg.in/redis.v3") +} + +//------------------------------------------------------------------------------ + +// Replaces ginkgo's Eventually. +func waitForSubstring(fn func() string, substr string, timeout time.Duration) error { + var s string + + found := make(chan struct{}) + var exit int32 + go func() { + for atomic.LoadInt32(&exit) == 0 { + s = fn() + if strings.Contains(s, substr) { + found <- struct{}{} + return + } + time.Sleep(timeout / 100) + } + }() + + select { + case <-found: + return nil + case <-time.After(timeout): + atomic.StoreInt32(&exit, 1) + } + return fmt.Errorf("%q does not contain %q", s, substr) +} + +func execCmd(name string, args ...string) (*os.Process, error) { + cmd := exec.Command(name, args...) + if testing.Verbose() { + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + } + return cmd.Process, cmd.Start() +} + +func connectTo(port string) (client *redis.Client, err error) { + client = redis.NewClient(&redis.Options{ + Addr: ":" + port, + }) + + deadline := time.Now().Add(3 * time.Second) + for time.Now().Before(deadline) { + if err = client.Ping().Err(); err == nil { + return client, nil + } + time.Sleep(250 * time.Millisecond) + } + + return nil, err +} + +type redisProcess struct { + *os.Process + *redis.Client +} + +func (p *redisProcess) Close() error { + p.Client.Close() + return p.Kill() +} + +var ( + redisServerBin, _ = filepath.Abs(filepath.Join(".test", "redis", "src", "redis-server")) + redisServerConf, _ = filepath.Abs(filepath.Join(".test", "redis.conf")) +) + +func redisDir(port string) (string, error) { + dir, err := filepath.Abs(filepath.Join(".test", "instances", port)) + if err != nil { + return "", err + } else if err = os.RemoveAll(dir); err != nil { + return "", err + } else if err = os.MkdirAll(dir, 0775); err != nil { + return "", err + } + return dir, nil +} + +func startRedis(port string, args ...string) (*redisProcess, error) { + dir, err := redisDir(port) + if err != nil { + return nil, err + } + if err = exec.Command("cp", "-f", redisServerConf, dir).Run(); err != nil { + return nil, err + } + + baseArgs := []string{filepath.Join(dir, "redis.conf"), "--port", port, "--dir", dir} + process, err := execCmd(redisServerBin, append(baseArgs, args...)...) + if err != nil { + return nil, err + } + + client, err := connectTo(port) + if err != nil { + process.Kill() + return nil, err + } + return &redisProcess{process, client}, err +} + +func startSentinel(port, masterName, masterPort string) (*redisProcess, error) { + dir, err := redisDir(port) + if err != nil { + return nil, err + } + process, err := execCmd(redisServerBin, os.DevNull, "--sentinel", "--port", port, "--dir", dir) + if err != nil { + return nil, err + } + client, err := connectTo(port) + if err != nil { + process.Kill() + return nil, err + } + for _, cmd := range []*redis.StatusCmd{ + redis.NewStatusCmd("SENTINEL", "MONITOR", masterName, "127.0.0.1", masterPort, "1"), + redis.NewStatusCmd("SENTINEL", "SET", masterName, "down-after-milliseconds", "500"), + redis.NewStatusCmd("SENTINEL", "SET", masterName, "failover-timeout", "1000"), + redis.NewStatusCmd("SENTINEL", "SET", masterName, "parallel-syncs", "1"), + } { + client.Process(cmd) + if err := cmd.Err(); err != nil { + process.Kill() + return nil, err + } + } + return &redisProcess{process, client}, nil +} + +//------------------------------------------------------------------------------ + +type badNetConn struct { + net.TCPConn +} + +var _ net.Conn = &badNetConn{} + +func newBadNetConn() net.Conn { + return &badNetConn{} +} + +func (badNetConn) Read([]byte) (int, error) { + return 0, net.UnknownNetworkError("badNetConn") +} + +func (badNetConn) Write([]byte) (int, error) { + return 0, net.UnknownNetworkError("badNetConn") +} diff --git a/Godeps/_workspace/src/gopkg.in/redis.v2/multi.go b/Godeps/_workspace/src/gopkg.in/redis.v3/multi.go similarity index 72% rename from Godeps/_workspace/src/gopkg.in/redis.v2/multi.go rename to Godeps/_workspace/src/gopkg.in/redis.v3/multi.go index bff38df..63ecdd5 100644 --- a/Godeps/_workspace/src/gopkg.in/redis.v2/multi.go +++ b/Godeps/_workspace/src/gopkg.in/redis.v3/multi.go @@ -3,42 +3,63 @@ package redis import ( "errors" "fmt" + "log" ) var errDiscard = errors.New("redis: Discard can be used only inside Exec") -// Not thread-safe. +// Multi implements Redis transactions as described in +// http://redis.io/topics/transactions. type Multi struct { - *Client + commandable + + base *baseClient + cmds []Cmder } func (c *Client) Multi() *Multi { - return &Multi{ - Client: &Client{ - baseClient: &baseClient{ - opt: c.opt, - connPool: newSingleConnPool(c.connPool, true), - }, + multi := &Multi{ + base: &baseClient{ + opt: c.opt, + connPool: newSingleConnPool(c.connPool, true), }, } + multi.commandable.process = multi.process + return multi +} + +func (c *Multi) process(cmd Cmder) { + if c.cmds == nil { + c.base.process(cmd) + } else { + c.cmds = append(c.cmds, cmd) + } } func (c *Multi) Close() error { if err := c.Unwatch().Err(); err != nil { - return err + log.Printf("redis: Unwatch failed: %s", err) } - return c.Client.Close() + return c.base.Close() } func (c *Multi) Watch(keys ...string) *StatusCmd { - args := append([]string{"WATCH"}, keys...) + args := make([]interface{}, 1+len(keys)) + args[0] = "WATCH" + for i, key := range keys { + args[1+i] = key + } cmd := NewStatusCmd(args...) c.Process(cmd) return cmd } func (c *Multi) Unwatch(keys ...string) *StatusCmd { - args := append([]string{"UNWATCH"}, keys...) + args := make([]interface{}, 1+len(keys)) + args[0] = "UNWATCH" + for i, key := range keys { + args[1+i] = key + } cmd := NewStatusCmd(args...) c.Process(cmd) return cmd @@ -69,24 +90,19 @@ func (c *Multi) Exec(f func() error) ([]Cmder, error) { return []Cmder{}, nil } - cn, err := c.conn() + cn, err := c.base.conn() if err != nil { setCmdsErr(cmds[1:len(cmds)-1], err) return cmds[1 : len(cmds)-1], err } err = c.execCmds(cn, cmds) - if err != nil { - c.freeConn(cn, err) - return cmds[1 : len(cmds)-1], err - } - - c.putConn(cn) - return cmds[1 : len(cmds)-1], nil + c.base.putConn(cn, err) + return cmds[1 : len(cmds)-1], err } func (c *Multi) execCmds(cn *conn, cmds []Cmder) error { - err := c.writeCmd(cn, cmds...) + err := cn.writeCmds(cmds...) if err != nil { setCmdsErr(cmds[1:len(cmds)-1], err) return err diff --git a/Godeps/_workspace/src/gopkg.in/redis.v3/multi_test.go b/Godeps/_workspace/src/gopkg.in/redis.v3/multi_test.go new file mode 100644 index 0000000..b481a52 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/redis.v3/multi_test.go @@ -0,0 +1,122 @@ +package redis_test + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + + "gopkg.in/redis.v3" +) + +var _ = Describe("Multi", func() { + var client *redis.Client + + BeforeEach(func() { + client = redis.NewClient(&redis.Options{ + Addr: redisAddr, + }) + }) + + AfterEach(func() { + Expect(client.FlushDb().Err()).NotTo(HaveOccurred()) + Expect(client.Close()).NotTo(HaveOccurred()) + }) + + It("should exec", func() { + multi := client.Multi() + defer func() { + Expect(multi.Close()).NotTo(HaveOccurred()) + }() + + var ( + set *redis.StatusCmd + get *redis.StringCmd + ) + cmds, err := multi.Exec(func() error { + set = multi.Set("key", "hello", 0) + get = multi.Get("key") + return nil + }) + Expect(err).NotTo(HaveOccurred()) + Expect(cmds).To(HaveLen(2)) + + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + Expect(get.Err()).NotTo(HaveOccurred()) + Expect(get.Val()).To(Equal("hello")) + }) + + It("should discard", func() { + multi := client.Multi() + defer func() { + Expect(multi.Close()).NotTo(HaveOccurred()) + }() + + cmds, err := multi.Exec(func() error { + multi.Set("key1", "hello1", 0) + multi.Discard() + multi.Set("key2", "hello2", 0) + return nil + }) + Expect(err).NotTo(HaveOccurred()) + Expect(cmds).To(HaveLen(1)) + + get := client.Get("key1") + Expect(get.Err()).To(Equal(redis.Nil)) + Expect(get.Val()).To(Equal("")) + + get = client.Get("key2") + Expect(get.Err()).NotTo(HaveOccurred()) + Expect(get.Val()).To(Equal("hello2")) + }) + + It("should exec empty", func() { + multi := client.Multi() + defer func() { + Expect(multi.Close()).NotTo(HaveOccurred()) + }() + + cmds, err := multi.Exec(func() error { return nil }) + Expect(err).NotTo(HaveOccurred()) + Expect(cmds).To(HaveLen(0)) + + ping := multi.Ping() + Expect(ping.Err()).NotTo(HaveOccurred()) + Expect(ping.Val()).To(Equal("PONG")) + }) + + It("should exec empty queue", func() { + multi := client.Multi() + defer func() { + Expect(multi.Close()).NotTo(HaveOccurred()) + }() + + cmds, err := multi.Exec(func() error { return nil }) + Expect(err).NotTo(HaveOccurred()) + Expect(cmds).To(HaveLen(0)) + }) + + It("should exec bulks", func() { + multi := client.Multi() + defer func() { + Expect(multi.Close()).NotTo(HaveOccurred()) + }() + + cmds, err := multi.Exec(func() error { + for i := int64(0); i < 20000; i++ { + multi.Incr("key") + } + return nil + }) + Expect(err).NotTo(HaveOccurred()) + Expect(len(cmds)).To(Equal(20000)) + for _, cmd := range cmds { + Expect(cmd.Err()).NotTo(HaveOccurred()) + } + + get := client.Get("key") + Expect(get.Err()).NotTo(HaveOccurred()) + Expect(get.Val()).To(Equal("20000")) + }) + +}) diff --git a/Godeps/_workspace/src/gopkg.in/redis.v3/parser.go b/Godeps/_workspace/src/gopkg.in/redis.v3/parser.go new file mode 100644 index 0000000..32646ff --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/redis.v3/parser.go @@ -0,0 +1,529 @@ +package redis + +import ( + "errors" + "fmt" + "net" + "strconv" + + "gopkg.in/bufio.v1" +) + +type multiBulkParser func(rd *bufio.Reader, n int64) (interface{}, error) + +var ( + errReaderTooSmall = errors.New("redis: reader is too small") +) + +//------------------------------------------------------------------------------ + +// Copy of encoding.BinaryMarshaler. +type binaryMarshaler interface { + MarshalBinary() (data []byte, err error) +} + +// Copy of encoding.BinaryUnmarshaler. +type binaryUnmarshaler interface { + UnmarshalBinary(data []byte) error +} + +func appendString(b []byte, s string) []byte { + b = append(b, '$') + b = strconv.AppendUint(b, uint64(len(s)), 10) + b = append(b, '\r', '\n') + b = append(b, s...) + b = append(b, '\r', '\n') + return b +} + +func appendBytes(b, bb []byte) []byte { + b = append(b, '$') + b = strconv.AppendUint(b, uint64(len(bb)), 10) + b = append(b, '\r', '\n') + b = append(b, bb...) + b = append(b, '\r', '\n') + return b +} + +func appendArg(b []byte, val interface{}) ([]byte, error) { + switch v := val.(type) { + case nil: + b = appendString(b, "") + case string: + b = appendString(b, v) + case []byte: + b = appendBytes(b, v) + case int: + b = appendString(b, formatInt(int64(v))) + case int8: + b = appendString(b, formatInt(int64(v))) + case int16: + b = appendString(b, formatInt(int64(v))) + case int32: + b = appendString(b, formatInt(int64(v))) + case int64: + b = appendString(b, formatInt(v)) + case uint: + b = appendString(b, formatUint(uint64(v))) + case uint8: + b = appendString(b, formatUint(uint64(v))) + case uint16: + b = appendString(b, formatUint(uint64(v))) + case uint32: + b = appendString(b, formatUint(uint64(v))) + case uint64: + b = appendString(b, formatUint(v)) + case float32: + b = appendString(b, formatFloat(float64(v))) + case float64: + b = appendString(b, formatFloat(v)) + case bool: + if v { + b = appendString(b, "1") + } else { + b = appendString(b, "0") + } + default: + if bm, ok := val.(binaryMarshaler); ok { + bb, err := bm.MarshalBinary() + if err != nil { + return nil, err + } + b = appendBytes(b, bb) + } else { + err := fmt.Errorf( + "redis: can't marshal %T (consider implementing BinaryMarshaler)", val) + return nil, err + } + } + return b, nil +} + +func appendArgs(b []byte, args []interface{}) ([]byte, error) { + b = append(b, '*') + b = strconv.AppendUint(b, uint64(len(args)), 10) + b = append(b, '\r', '\n') + for _, arg := range args { + var err error + b, err = appendArg(b, arg) + if err != nil { + return nil, err + } + } + return b, nil +} + +func scan(b []byte, val interface{}) error { + switch v := val.(type) { + case nil: + return errorf("redis: Scan(nil)") + case *string: + *v = bytesToString(b) + return nil + case *[]byte: + *v = b + return nil + case *int: + var err error + *v, err = strconv.Atoi(bytesToString(b)) + return err + case *int8: + n, err := strconv.ParseInt(bytesToString(b), 10, 8) + if err != nil { + return err + } + *v = int8(n) + return nil + case *int16: + n, err := strconv.ParseInt(bytesToString(b), 10, 16) + if err != nil { + return err + } + *v = int16(n) + return nil + case *int32: + n, err := strconv.ParseInt(bytesToString(b), 10, 16) + if err != nil { + return err + } + *v = int32(n) + return nil + case *int64: + n, err := strconv.ParseInt(bytesToString(b), 10, 64) + if err != nil { + return err + } + *v = n + return nil + case *uint: + n, err := strconv.ParseUint(bytesToString(b), 10, 64) + if err != nil { + return err + } + *v = uint(n) + return nil + case *uint8: + n, err := strconv.ParseUint(bytesToString(b), 10, 8) + if err != nil { + return err + } + *v = uint8(n) + return nil + case *uint16: + n, err := strconv.ParseUint(bytesToString(b), 10, 16) + if err != nil { + return err + } + *v = uint16(n) + return nil + case *uint32: + n, err := strconv.ParseUint(bytesToString(b), 10, 32) + if err != nil { + return err + } + *v = uint32(n) + return nil + case *uint64: + n, err := strconv.ParseUint(bytesToString(b), 10, 64) + if err != nil { + return err + } + *v = n + return nil + case *float32: + n, err := strconv.ParseFloat(bytesToString(b), 32) + if err != nil { + return err + } + *v = float32(n) + return err + case *float64: + var err error + *v, err = strconv.ParseFloat(bytesToString(b), 64) + return err + case *bool: + *v = len(b) == 1 && b[0] == '1' + return nil + default: + if bu, ok := val.(binaryUnmarshaler); ok { + return bu.UnmarshalBinary(b) + } + err := fmt.Errorf( + "redis: can't unmarshal %T (consider implementing BinaryUnmarshaler)", val) + return err + } +} + +//------------------------------------------------------------------------------ + +func readLine(rd *bufio.Reader) ([]byte, error) { + line, isPrefix, err := rd.ReadLine() + if err != nil { + return line, err + } + if isPrefix { + return line, errReaderTooSmall + } + return line, nil +} + +func readN(rd *bufio.Reader, n int) ([]byte, error) { + b, err := rd.ReadN(n) + if err == bufio.ErrBufferFull { + tmp := make([]byte, n) + r := copy(tmp, b) + b = tmp + + for { + nn, err := rd.Read(b[r:]) + r += nn + if r >= n { + // Ignore error if we read enough. + break + } + if err != nil { + return nil, err + } + } + } else if err != nil { + return nil, err + } + return b, nil +} + +//------------------------------------------------------------------------------ + +func parseReq(rd *bufio.Reader) ([]string, error) { + line, err := readLine(rd) + if err != nil { + return nil, err + } + + if line[0] != '*' { + return []string{string(line)}, nil + } + numReplies, err := strconv.ParseInt(string(line[1:]), 10, 64) + if err != nil { + return nil, err + } + + args := make([]string, 0, numReplies) + for i := int64(0); i < numReplies; i++ { + line, err = readLine(rd) + if err != nil { + return nil, err + } + if line[0] != '$' { + return nil, fmt.Errorf("redis: expected '$', but got %q", line) + } + + argLen, err := strconv.ParseInt(string(line[1:]), 10, 32) + if err != nil { + return nil, err + } + + arg, err := readN(rd, int(argLen)+2) + if err != nil { + return nil, err + } + args = append(args, string(arg[:argLen])) + } + return args, nil +} + +//------------------------------------------------------------------------------ + +func parseReply(rd *bufio.Reader, p multiBulkParser) (interface{}, error) { + line, err := readLine(rd) + if err != nil { + return nil, err + } + + switch line[0] { + case '-': + return nil, errorf(string(line[1:])) + case '+': + return line[1:], nil + case ':': + v, err := strconv.ParseInt(bytesToString(line[1:]), 10, 64) + if err != nil { + return nil, err + } + return v, nil + case '$': + if len(line) == 3 && line[1] == '-' && line[2] == '1' { + return nil, Nil + } + + replyLen, err := strconv.Atoi(string(line[1:])) + if err != nil { + return nil, err + } + + b, err := readN(rd, replyLen+2) + if err != nil { + return nil, err + } + return b[:replyLen], nil + case '*': + if len(line) == 3 && line[1] == '-' && line[2] == '1' { + return nil, Nil + } + + repliesNum, err := strconv.ParseInt(bytesToString(line[1:]), 10, 64) + if err != nil { + return nil, err + } + + return p(rd, repliesNum) + } + return nil, fmt.Errorf("redis: can't parse %q", line) +} + +func parseSlice(rd *bufio.Reader, n int64) (interface{}, error) { + vals := make([]interface{}, 0, n) + for i := int64(0); i < n; i++ { + v, err := parseReply(rd, parseSlice) + if err == Nil { + vals = append(vals, nil) + } else if err != nil { + return nil, err + } else { + switch vv := v.(type) { + case []byte: + vals = append(vals, string(vv)) + default: + vals = append(vals, v) + } + } + } + return vals, nil +} + +func parseStringSlice(rd *bufio.Reader, n int64) (interface{}, error) { + vals := make([]string, 0, n) + for i := int64(0); i < n; i++ { + viface, err := parseReply(rd, nil) + if err != nil { + return nil, err + } + v, ok := viface.([]byte) + if !ok { + return nil, fmt.Errorf("got %T, expected string", viface) + } + vals = append(vals, string(v)) + } + return vals, nil +} + +func parseBoolSlice(rd *bufio.Reader, n int64) (interface{}, error) { + vals := make([]bool, 0, n) + for i := int64(0); i < n; i++ { + viface, err := parseReply(rd, nil) + if err != nil { + return nil, err + } + v, ok := viface.(int64) + if !ok { + return nil, fmt.Errorf("got %T, expected int64", viface) + } + vals = append(vals, v == 1) + } + return vals, nil +} + +func parseStringStringMap(rd *bufio.Reader, n int64) (interface{}, error) { + m := make(map[string]string, n/2) + for i := int64(0); i < n; i += 2 { + keyiface, err := parseReply(rd, nil) + if err != nil { + return nil, err + } + key, ok := keyiface.([]byte) + if !ok { + return nil, fmt.Errorf("got %T, expected string", keyiface) + } + + valueiface, err := parseReply(rd, nil) + if err != nil { + return nil, err + } + value, ok := valueiface.([]byte) + if !ok { + return nil, fmt.Errorf("got %T, expected string", valueiface) + } + + m[string(key)] = string(value) + } + return m, nil +} + +func parseStringIntMap(rd *bufio.Reader, n int64) (interface{}, error) { + m := make(map[string]int64, n/2) + for i := int64(0); i < n; i += 2 { + keyiface, err := parseReply(rd, nil) + if err != nil { + return nil, err + } + key, ok := keyiface.([]byte) + if !ok { + return nil, fmt.Errorf("got %T, expected string", keyiface) + } + + valueiface, err := parseReply(rd, nil) + if err != nil { + return nil, err + } + switch value := valueiface.(type) { + case int64: + m[string(key)] = value + case string: + m[string(key)], err = strconv.ParseInt(value, 10, 64) + if err != nil { + return nil, fmt.Errorf("got %v, expected number", value) + } + default: + return nil, fmt.Errorf("got %T, expected number or string", valueiface) + } + } + return m, nil +} + +func parseZSlice(rd *bufio.Reader, n int64) (interface{}, error) { + zz := make([]Z, n/2) + for i := int64(0); i < n; i += 2 { + z := &zz[i/2] + + memberiface, err := parseReply(rd, nil) + if err != nil { + return nil, err + } + member, ok := memberiface.([]byte) + if !ok { + return nil, fmt.Errorf("got %T, expected string", memberiface) + } + z.Member = string(member) + + scoreiface, err := parseReply(rd, nil) + if err != nil { + return nil, err + } + scoreb, ok := scoreiface.([]byte) + if !ok { + return nil, fmt.Errorf("got %T, expected string", scoreiface) + } + score, err := strconv.ParseFloat(bytesToString(scoreb), 64) + if err != nil { + return nil, err + } + z.Score = score + } + return zz, nil +} + +func parseClusterSlotInfoSlice(rd *bufio.Reader, n int64) (interface{}, error) { + infos := make([]ClusterSlotInfo, 0, n) + for i := int64(0); i < n; i++ { + viface, err := parseReply(rd, parseSlice) + if err != nil { + return nil, err + } + + item, ok := viface.([]interface{}) + if !ok { + return nil, fmt.Errorf("got %T, expected []interface{}", viface) + } else if len(item) < 3 { + return nil, fmt.Errorf("got %v, expected {int64, int64, string...}", item) + } + + start, ok := item[0].(int64) + if !ok || start < 0 || start > hashSlots { + return nil, fmt.Errorf("got %v, expected {int64, int64, string...}", item) + } + end, ok := item[1].(int64) + if !ok || end < 0 || end > hashSlots { + return nil, fmt.Errorf("got %v, expected {int64, int64, string...}", item) + } + + info := ClusterSlotInfo{int(start), int(end), make([]string, len(item)-2)} + for n, ipair := range item[2:] { + pair, ok := ipair.([]interface{}) + if !ok || len(pair) != 2 { + return nil, fmt.Errorf("got %v, expected []interface{host, port}", viface) + } + + ip, ok := pair[0].(string) + if !ok || len(ip) < 1 { + return nil, fmt.Errorf("got %v, expected IP PORT pair", pair) + } + port, ok := pair[1].(int64) + if !ok || port < 1 { + return nil, fmt.Errorf("got %v, expected IP PORT pair", pair) + } + + info.Addrs[n] = net.JoinHostPort(ip, strconv.FormatInt(port, 10)) + } + infos = append(infos, info) + } + return infos, nil +} diff --git a/Godeps/_workspace/src/gopkg.in/redis.v2/parser_test.go b/Godeps/_workspace/src/gopkg.in/redis.v3/parser_test.go similarity index 95% rename from Godeps/_workspace/src/gopkg.in/redis.v2/parser_test.go rename to Godeps/_workspace/src/gopkg.in/redis.v3/parser_test.go index 1b9e158..b71305a 100644 --- a/Godeps/_workspace/src/gopkg.in/redis.v2/parser_test.go +++ b/Godeps/_workspace/src/gopkg.in/redis.v3/parser_test.go @@ -47,7 +47,7 @@ func benchmarkParseReply(b *testing.B, reply string, p multiBulkParser, wanterr func BenchmarkAppendArgs(b *testing.B) { buf := make([]byte, 0, 64) - args := []string{"hello", "world", "foo", "bar"} + args := []interface{}{"hello", "world", "foo", "bar"} for i := 0; i < b.N; i++ { appendArgs(buf, args) } diff --git a/Godeps/_workspace/src/gopkg.in/redis.v3/pipeline.go b/Godeps/_workspace/src/gopkg.in/redis.v3/pipeline.go new file mode 100644 index 0000000..8981cb5 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/redis.v3/pipeline.go @@ -0,0 +1,113 @@ +package redis + +// Pipeline implements pipelining as described in +// http://redis.io/topics/pipelining. +// +// Pipeline is not thread-safe. +type Pipeline struct { + commandable + + client *baseClient + + cmds []Cmder + closed bool +} + +func (c *Client) Pipeline() *Pipeline { + pipe := &Pipeline{ + client: c.baseClient, + cmds: make([]Cmder, 0, 10), + } + pipe.commandable.process = pipe.process + return pipe +} + +func (c *Client) Pipelined(fn func(*Pipeline) error) ([]Cmder, error) { + pipe := c.Pipeline() + if err := fn(pipe); err != nil { + return nil, err + } + cmds, err := pipe.Exec() + pipe.Close() + return cmds, err +} + +func (pipe *Pipeline) process(cmd Cmder) { + pipe.cmds = append(pipe.cmds, cmd) +} + +func (pipe *Pipeline) Close() error { + pipe.Discard() + pipe.closed = true + return nil +} + +// Discard resets the pipeline and discards queued commands. +func (pipe *Pipeline) Discard() error { + if pipe.closed { + return errClosed + } + pipe.cmds = pipe.cmds[:0] + return nil +} + +// Exec always returns list of commands and error of the first failed +// command if any. +func (pipe *Pipeline) Exec() (cmds []Cmder, retErr error) { + if pipe.closed { + return nil, errClosed + } + if len(pipe.cmds) == 0 { + return pipe.cmds, nil + } + + cmds = pipe.cmds + pipe.cmds = make([]Cmder, 0, 10) + + failedCmds := cmds + for i := 0; i <= pipe.client.opt.MaxRetries; i++ { + cn, err := pipe.client.conn() + if err != nil { + setCmdsErr(failedCmds, err) + return cmds, err + } + + if i > 0 { + resetCmds(failedCmds) + } + failedCmds, err = execCmds(cn, failedCmds) + pipe.client.putConn(cn, err) + if err != nil && retErr == nil { + retErr = err + } + if len(failedCmds) == 0 { + break + } + } + + return cmds, retErr +} + +func execCmds(cn *conn, cmds []Cmder) ([]Cmder, error) { + if err := cn.writeCmds(cmds...); err != nil { + setCmdsErr(cmds, err) + return cmds, err + } + + var firstCmdErr error + var failedCmds []Cmder + for _, cmd := range cmds { + err := cmd.parseReply(cn.rd) + if err == nil { + continue + } + if firstCmdErr == nil { + firstCmdErr = err + } + if shouldRetry(err) { + failedCmds = append(failedCmds, cmd) + } + } + + return failedCmds, firstCmdErr +} diff --git a/Godeps/_workspace/src/gopkg.in/redis.v3/pipeline_test.go b/Godeps/_workspace/src/gopkg.in/redis.v3/pipeline_test.go new file mode 100644 index 0000000..ddf7480 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/redis.v3/pipeline_test.go @@ -0,0 +1,153 @@ +package redis_test + +import ( + "strconv" + "sync" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + + "gopkg.in/redis.v3" +) + +var _ = Describe("Pipelining", func() { + var client *redis.Client + + BeforeEach(func() { + client = redis.NewClient(&redis.Options{ + Addr: redisAddr, + }) + }) + + AfterEach(func() { + Expect(client.FlushDb().Err()).NotTo(HaveOccurred()) + Expect(client.Close()).NotTo(HaveOccurred()) + }) + + It("should pipeline", func() { + set := client.Set("key2", "hello2", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + pipeline := client.Pipeline() + set = pipeline.Set("key1", "hello1", 0) + get := pipeline.Get("key2") + incr := pipeline.Incr("key3") + getNil := pipeline.Get("key4") + + cmds, err := pipeline.Exec() + Expect(err).To(Equal(redis.Nil)) + Expect(cmds).To(HaveLen(4)) + Expect(pipeline.Close()).NotTo(HaveOccurred()) + + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + Expect(get.Err()).NotTo(HaveOccurred()) + Expect(get.Val()).To(Equal("hello2")) + + Expect(incr.Err()).NotTo(HaveOccurred()) + Expect(incr.Val()).To(Equal(int64(1))) + + Expect(getNil.Err()).To(Equal(redis.Nil)) + Expect(getNil.Val()).To(Equal("")) + }) + + It("should discard", func() { + pipeline := client.Pipeline() + + pipeline.Get("key") + pipeline.Discard() + cmds, err := pipeline.Exec() + Expect(err).NotTo(HaveOccurred()) + Expect(cmds).To(HaveLen(0)) + Expect(pipeline.Close()).NotTo(HaveOccurred()) + }) + + It("should support block style", func() { + var get *redis.StringCmd + cmds, err := client.Pipelined(func(pipe *redis.Pipeline) error { + get = pipe.Get("foo") + return nil + }) + Expect(err).To(Equal(redis.Nil)) + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0]).To(Equal(get)) + Expect(get.Err()).To(Equal(redis.Nil)) + Expect(get.Val()).To(Equal("")) + }) + + It("should handle vals/err", func() { + pipeline := client.Pipeline() + + get := pipeline.Get("key") + Expect(get.Err()).NotTo(HaveOccurred()) + Expect(get.Val()).To(Equal("")) + Expect(pipeline.Close()).NotTo(HaveOccurred()) + }) + + It("should pipeline with empty queue", func() { + pipeline := client.Pipeline() + cmds, err := pipeline.Exec() + Expect(err).NotTo(HaveOccurred()) + Expect(cmds).To(HaveLen(0)) + Expect(pipeline.Close()).NotTo(HaveOccurred()) + }) + + It("should increment correctly", func() { + const N = 20000 + key := "TestPipelineIncr" + pipeline := client.Pipeline() + for i := 0; i < N; i++ { + pipeline.Incr(key) + } + + cmds, err := pipeline.Exec() + Expect(err).NotTo(HaveOccurred()) + Expect(pipeline.Close()).NotTo(HaveOccurred()) + + Expect(len(cmds)).To(Equal(20000)) + for _, cmd := range cmds { + Expect(cmd.Err()).NotTo(HaveOccurred()) + } + + get := client.Get(key) + Expect(get.Err()).NotTo(HaveOccurred()) + Expect(get.Val()).To(Equal(strconv.Itoa(N))) + }) + + It("should PipelineEcho", func() { + const N = 1000 + + wg := &sync.WaitGroup{} + wg.Add(N) + for i := 0; i < N; i++ { + go func(i int) { + defer GinkgoRecover() + defer wg.Done() + + pipeline := client.Pipeline() + + msg1 := "echo" + strconv.Itoa(i) + msg2 := "echo" + strconv.Itoa(i+1) + + echo1 := pipeline.Echo(msg1) + echo2 := pipeline.Echo(msg2) + + cmds, err := pipeline.Exec() + Expect(err).NotTo(HaveOccurred()) + Expect(cmds).To(HaveLen(2)) + + Expect(echo1.Err()).NotTo(HaveOccurred()) + Expect(echo1.Val()).To(Equal(msg1)) + + Expect(echo2.Err()).NotTo(HaveOccurred()) + Expect(echo2.Val()).To(Equal(msg2)) + + Expect(pipeline.Close()).NotTo(HaveOccurred()) + }(i) + } + wg.Wait() + }) + +}) diff --git a/Godeps/_workspace/src/gopkg.in/redis.v3/pool.go b/Godeps/_workspace/src/gopkg.in/redis.v3/pool.go new file mode 100644 index 0000000..71ac456 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/redis.v3/pool.go @@ -0,0 +1,442 @@ +package redis + +import ( + "errors" + "fmt" + "log" + "sync" + "sync/atomic" + "time" + + "gopkg.in/bsm/ratelimit.v1" +) + +var ( + errClosed = errors.New("redis: client is closed") + errPoolTimeout = errors.New("redis: connection pool timeout") +) + +type pool interface { + First() *conn + Get() (*conn, error) + Put(*conn) error + Remove(*conn) error + Len() int + FreeLen() int + Close() error +} + +type connList struct { + cns []*conn + mx sync.Mutex + len int32 // atomic + size int32 +} + +func newConnList(size int) *connList { + return &connList{ + cns: make([]*conn, 0, size), + size: int32(size), + } +} + +func (l *connList) Len() int { + return int(atomic.LoadInt32(&l.len)) +} + +// Reserve reserves place in the list and returns true on success. The +// caller must add or remove connection if place was reserved. +func (l *connList) Reserve() bool { + len := atomic.AddInt32(&l.len, 1) + reserved := len <= l.size + if !reserved { + atomic.AddInt32(&l.len, -1) + } + return reserved +} + +// Add adds connection to the list. The caller must reserve place first. +func (l *connList) Add(cn *conn) { + l.mx.Lock() + l.cns = append(l.cns, cn) + l.mx.Unlock() +} + +// Remove closes connection and removes it from the list. +func (l *connList) Remove(cn *conn) error { + defer l.mx.Unlock() + l.mx.Lock() + + if cn == nil { + atomic.AddInt32(&l.len, -1) + return nil + } + + for i, c := range l.cns { + if c == cn { + l.cns = append(l.cns[:i], l.cns[i+1:]...) + atomic.AddInt32(&l.len, -1) + return cn.Close() + } + } + + if l.closed() { + return nil + } + panic("conn not found in the list") +} + +func (l *connList) Replace(cn, newcn *conn) error { + defer l.mx.Unlock() + l.mx.Lock() + + for i, c := range l.cns { + if c == cn { + l.cns[i] = newcn + return cn.Close() + } + } + + if l.closed() { + return newcn.Close() + } + panic("conn not found in the list") +} + +func (l *connList) Close() (retErr error) { + l.mx.Lock() + for _, c := range l.cns { + if err := c.Close(); err != nil { + retErr = err + } + } + l.cns = nil + atomic.StoreInt32(&l.len, 0) + l.mx.Unlock() + return retErr +} + +func (l *connList) closed() bool { + return l.cns == nil +} + +type connPool struct { + dialer func() (*conn, error) + + rl *ratelimit.RateLimiter + opt *Options + conns *connList + freeConns chan *conn + + _closed int32 + + lastDialErr error +} + +func newConnPool(opt *Options) *connPool { + p := &connPool{ + dialer: newConnDialer(opt), + + rl: ratelimit.New(2*opt.getPoolSize(), time.Second), + opt: opt, + conns: newConnList(opt.getPoolSize()), + freeConns: make(chan *conn, opt.getPoolSize()), + } + if p.opt.getIdleTimeout() > 0 { + go p.reaper() + } + return p +} + +func (p *connPool) closed() bool { + return atomic.LoadInt32(&p._closed) == 1 +} + +func (p *connPool) isIdle(cn *conn) bool { + return p.opt.getIdleTimeout() > 0 && time.Since(cn.usedAt) > p.opt.getIdleTimeout() +} + +// First returns first non-idle connection from the pool or nil if +// there are no connections. +func (p *connPool) First() *conn { + for { + select { + case cn := <-p.freeConns: + if p.isIdle(cn) { + p.conns.Remove(cn) + continue + } + return cn + default: + return nil + } + } + panic("not reached") +} + +// wait waits for free non-idle connection. It returns nil on timeout. +func (p *connPool) wait() *conn { + deadline := time.After(p.opt.getPoolTimeout()) + for { + select { + case cn := <-p.freeConns: + if p.isIdle(cn) { + p.Remove(cn) + continue + } + return cn + case <-deadline: + return nil + } + } + panic("not reached") +} + +// Establish a new connection +func (p *connPool) new() (*conn, error) { + if p.rl.Limit() { + err := fmt.Errorf( + "redis: you open connections too fast (last error: %v)", + p.lastDialErr, + ) + return nil, err + } + + cn, err := p.dialer() + if err != nil { + p.lastDialErr = err + return nil, err + } + + return cn, nil +} + +// Get returns existed connection from the pool or creates a new one. +func (p *connPool) Get() (*conn, error) { + if p.closed() { + return nil, errClosed + } + + // Fetch first non-idle connection, if available. + if cn := p.First(); cn != nil { + return cn, nil + } + + // Try to create a new one. + if p.conns.Reserve() { + cn, err := p.new() + if err != nil { + p.conns.Remove(nil) + return nil, err + } + p.conns.Add(cn) + return cn, nil + } + + // Otherwise, wait for the available connection. + if cn := p.wait(); cn != nil { + return cn, nil + } + + return nil, errPoolTimeout +} + +func (p *connPool) Put(cn *conn) error { + if cn.rd.Buffered() != 0 { + b, _ := cn.rd.ReadN(cn.rd.Buffered()) + log.Printf("redis: connection has unread data: %q", b) + return p.Remove(cn) + } + if p.opt.getIdleTimeout() > 0 { + cn.usedAt = time.Now() + } + p.freeConns <- cn + return nil +} + +func (p *connPool) Remove(cn *conn) error { + // Replace existing connection with new one and unblock waiter. + newcn, err := p.new() + if err != nil { + log.Printf("redis: new failed: %s", err) + return p.conns.Remove(cn) + } + err = p.conns.Replace(cn, newcn) + p.freeConns <- newcn + return err +} + +// Len returns total number of connections. +func (p *connPool) Len() int { + return p.conns.Len() +} + +// FreeLen returns number of free connections. +func (p *connPool) FreeLen() int { + return len(p.freeConns) +} + +func (p *connPool) Close() (retErr error) { + if !atomic.CompareAndSwapInt32(&p._closed, 0, 1) { + return errClosed + } + // Wait for app to free connections, but don't close them immediately. + for i := 0; i < p.Len(); i++ { + if cn := p.wait(); cn == nil { + break + } + } + // Close all connections. + if err := p.conns.Close(); err != nil { + retErr = err + } + return retErr +} + +func (p *connPool) reaper() { + ticker := time.NewTicker(time.Minute) + defer ticker.Stop() + + for _ = range ticker.C { + if p.closed() { + break + } + + // pool.First removes idle connections from the pool and + // returns first non-idle connection. So just put returned + // connection back. + if cn := p.First(); cn != nil { + p.Put(cn) + } + } +} + +//------------------------------------------------------------------------------ + +type singleConnPool struct { + pool pool + reusable bool + + cn *conn + closed bool + mx sync.Mutex +} + +func newSingleConnPool(pool pool, reusable bool) *singleConnPool { + return &singleConnPool{ + pool: pool, + reusable: reusable, + } +} + +func newSingleConnPoolConn(cn *conn) *singleConnPool { + return &singleConnPool{ + cn: cn, + } +} + +func (p *singleConnPool) First() *conn { + p.mx.Lock() + cn := p.cn + p.mx.Unlock() + return cn +} + +func (p *singleConnPool) Get() (*conn, error) { + defer p.mx.Unlock() + p.mx.Lock() + + if p.closed { + return nil, errClosed + } + if p.cn != nil { + return p.cn, nil + } + + cn, err := p.pool.Get() + if err != nil { + return nil, err + } + p.cn = cn + + return p.cn, nil +} + +func (p *singleConnPool) put() (err error) { + if p.pool != nil { + err = p.pool.Put(p.cn) + } + p.cn = nil + return err +} + +func (p *singleConnPool) Put(cn *conn) error { + defer p.mx.Unlock() + p.mx.Lock() + if p.cn != cn { + panic("p.cn != cn") + } + if p.closed { + return errClosed + } + return nil +} + +func (p *singleConnPool) remove() (err error) { + if p.pool != nil { + err = p.pool.Remove(p.cn) + } + p.cn = nil + return err +} + +func (p *singleConnPool) Remove(cn *conn) error { + defer p.mx.Unlock() + p.mx.Lock() + if p.cn == nil { + panic("p.cn == nil") + } + if p.cn != cn { + panic("p.cn != cn") + } + if p.closed { + return errClosed + } + return p.remove() +} + +func (p *singleConnPool) Len() int { + defer p.mx.Unlock() + p.mx.Lock() + if p.cn == nil { + return 0 + } + return 1 +} + +func (p *singleConnPool) FreeLen() int { + defer p.mx.Unlock() + p.mx.Lock() + if p.cn == nil { + return 1 + } + return 0 +} + +func (p *singleConnPool) Close() error { + defer p.mx.Unlock() + p.mx.Lock() + if p.closed { + return errClosed + } + p.closed = true + var err error + if p.cn != nil { + if p.reusable { + err = p.put() + } else { + err = p.remove() + } + } + return err +} diff --git a/Godeps/_workspace/src/gopkg.in/redis.v3/pool_test.go b/Godeps/_workspace/src/gopkg.in/redis.v3/pool_test.go new file mode 100644 index 0000000..bff892c --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/redis.v3/pool_test.go @@ -0,0 +1,203 @@ +package redis_test + +import ( + "sync" + "testing" + "time" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + + "gopkg.in/redis.v3" +) + +var _ = Describe("Pool", func() { + var client *redis.Client + + var perform = func(n int, cb func()) { + wg := &sync.WaitGroup{} + for i := 0; i < n; i++ { + wg.Add(1) + go func() { + defer GinkgoRecover() + defer wg.Done() + + cb() + }() + } + wg.Wait() + } + + BeforeEach(func() { + client = redis.NewClient(&redis.Options{ + Addr: redisAddr, + PoolSize: 10, + }) + }) + + AfterEach(func() { + Expect(client.FlushDb().Err()).NotTo(HaveOccurred()) + Expect(client.Close()).NotTo(HaveOccurred()) + }) + + It("should respect max size", func() { + perform(1000, func() { + val, err := client.Ping().Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal("PONG")) + }) + + pool := client.Pool() + Expect(pool.Len()).To(BeNumerically("<=", 10)) + Expect(pool.FreeLen()).To(BeNumerically("<=", 10)) + Expect(pool.Len()).To(Equal(pool.FreeLen())) + }) + + It("should respect max on multi", func() { + perform(1000, func() { + var ping *redis.StatusCmd + + multi := client.Multi() + cmds, err := multi.Exec(func() error { + ping = multi.Ping() + return nil + }) + Expect(err).NotTo(HaveOccurred()) + Expect(cmds).To(HaveLen(1)) + Expect(ping.Err()).NotTo(HaveOccurred()) + Expect(ping.Val()).To(Equal("PONG")) + Expect(multi.Close()).NotTo(HaveOccurred()) + }) + + pool := client.Pool() + Expect(pool.Len()).To(BeNumerically("<=", 10)) + Expect(pool.FreeLen()).To(BeNumerically("<=", 10)) + Expect(pool.Len()).To(Equal(pool.FreeLen())) + }) + + It("should respect max on pipelines", func() { + perform(1000, func() { + pipe := client.Pipeline() + ping := pipe.Ping() + cmds, err := pipe.Exec() + Expect(err).NotTo(HaveOccurred()) + Expect(cmds).To(HaveLen(1)) + Expect(ping.Err()).NotTo(HaveOccurred()) + Expect(ping.Val()).To(Equal("PONG")) + Expect(pipe.Close()).NotTo(HaveOccurred()) + }) + + pool := client.Pool() + Expect(pool.Len()).To(BeNumerically("<=", 10)) + Expect(pool.FreeLen()).To(BeNumerically("<=", 10)) + Expect(pool.Len()).To(Equal(pool.FreeLen())) + }) + + It("should respect max on pubsub", func() { + perform(10, func() { + pubsub := client.PubSub() + Expect(pubsub.Subscribe()).NotTo(HaveOccurred()) + Expect(pubsub.Close()).NotTo(HaveOccurred()) + }) + + pool := client.Pool() + Expect(pool.Len()).To(BeNumerically("<=", 10)) + Expect(pool.FreeLen()).To(BeNumerically("<=", 10)) + Expect(pool.Len()).To(Equal(pool.FreeLen())) + }) + + It("should remove broken connections", func() { + cn, err := client.Pool().Get() + Expect(err).NotTo(HaveOccurred()) + Expect(cn.Close()).NotTo(HaveOccurred()) + Expect(client.Pool().Put(cn)).NotTo(HaveOccurred()) + + err = client.Ping().Err() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("use of closed network connection")) + + val, err := client.Ping().Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal("PONG")) + + pool := client.Pool() + Expect(pool.Len()).To(Equal(1)) + Expect(pool.FreeLen()).To(Equal(1)) + }) + + It("should reuse connections", func() { + for i := 0; i < 100; i++ { + val, err := client.Ping().Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal("PONG")) + } + + pool := client.Pool() + Expect(pool.Len()).To(Equal(1)) + Expect(pool.FreeLen()).To(Equal(1)) + }) + + It("should unblock client when connection is removed", func() { + pool := client.Pool() + + // Reserve one connection. + cn, err := client.Pool().Get() + Expect(err).NotTo(HaveOccurred()) + + // Reserve the rest of connections. + for i := 0; i < 9; i++ { + _, err := client.Pool().Get() + Expect(err).NotTo(HaveOccurred()) + } + + var ping *redis.StatusCmd + started := make(chan bool, 1) + done := make(chan bool, 1) + go func() { + started <- true + ping = client.Ping() + done <- true + }() + <-started + + // Check that Ping is blocked. + select { + case <-done: + panic("Ping is not blocked") + default: + // ok + } + + Expect(pool.Remove(cn)).NotTo(HaveOccurred()) + + // Check that Ping is unblocked. + select { + case <-done: + // ok + case <-time.After(time.Second): + panic("Ping is not unblocked") + } + Expect(ping.Err()).NotTo(HaveOccurred()) + }) +}) + +func BenchmarkPool(b *testing.B) { + client := benchRedisClient() + defer client.Close() + + pool := client.Pool() + + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + conn, err := pool.Get() + if err != nil { + b.Fatalf("no error expected on pool.Get but received: %s", err.Error()) + } + if err = pool.Put(conn); err != nil { + b.Fatalf("no error expected on pool.Put but received: %s", err.Error()) + } + } + }) +} diff --git a/Godeps/_workspace/src/gopkg.in/redis.v3/pubsub.go b/Godeps/_workspace/src/gopkg.in/redis.v3/pubsub.go new file mode 100644 index 0000000..1f4f5b6 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/redis.v3/pubsub.go @@ -0,0 +1,190 @@ +package redis + +import ( + "fmt" + "time" +) + +// Posts a message to the given channel. +func (c *Client) Publish(channel, message string) *IntCmd { + req := NewIntCmd("PUBLISH", channel, message) + c.Process(req) + return req +} + +// PubSub implements Pub/Sub commands as described in +// http://redis.io/topics/pubsub. +type PubSub struct { + *baseClient +} + +// Deprecated. Use Subscribe/PSubscribe instead. +func (c *Client) PubSub() *PubSub { + return &PubSub{ + baseClient: &baseClient{ + opt: c.opt, + connPool: newSingleConnPool(c.connPool, false), + }, + } +} + +// Subscribes the client to the specified channels. +func (c *Client) Subscribe(channels ...string) (*PubSub, error) { + pubsub := c.PubSub() + return pubsub, pubsub.Subscribe(channels...) +} + +// Subscribes the client to the given patterns. +func (c *Client) PSubscribe(channels ...string) (*PubSub, error) { + pubsub := c.PubSub() + return pubsub, pubsub.PSubscribe(channels...) +} + +func (c *PubSub) Ping(payload string) error { + cn, err := c.conn() + if err != nil { + return err + } + + args := []interface{}{"PING"} + if payload != "" { + args = append(args, payload) + } + cmd := NewCmd(args...) + return cn.writeCmds(cmd) +} + +// Message received after a successful subscription to channel. +type Subscription struct { + // Can be "subscribe", "unsubscribe", "psubscribe" or "punsubscribe". + Kind string + // Channel name we have subscribed to. + Channel string + // Number of channels we are currently subscribed to. + Count int +} + +func (m *Subscription) String() string { + return fmt.Sprintf("%s: %s", m.Kind, m.Channel) +} + +// Message received as result of a PUBLISH command issued by another client. +type Message struct { + Channel string + Payload string +} + +func (m *Message) String() string { + return fmt.Sprintf("Message<%s: %s>", m.Channel, m.Payload) +} + +// Message matching a pattern-matching subscription received as result +// of a PUBLISH command issued by another client. +type PMessage struct { + Channel string + Pattern string + Payload string +} + +func (m *PMessage) String() string { + return fmt.Sprintf("PMessage<%s: %s>", m.Channel, m.Payload) +} + +// Pong received as result of a PING command issued by another client. +type Pong struct { + Payload string +} + +func (p *Pong) String() string { + if p.Payload != "" { + return fmt.Sprintf("Pong<%s>", p.Payload) + } + return "Pong" +} + +// Returns a message as a Subscription, Message, PMessage, Pong or +// error. See PubSub example for details. +func (c *PubSub) Receive() (interface{}, error) { + return c.ReceiveTimeout(0) +} + +func newMessage(reply []interface{}) (interface{}, error) { + switch kind := reply[0].(string); kind { + case "subscribe", "unsubscribe", "psubscribe", "punsubscribe": + return &Subscription{ + Kind: kind, + Channel: reply[1].(string), + Count: int(reply[2].(int64)), + }, nil + case "message": + return &Message{ + Channel: reply[1].(string), + Payload: reply[2].(string), + }, nil + case "pmessage": + return &PMessage{ + Pattern: reply[1].(string), + Channel: reply[2].(string), + Payload: reply[3].(string), + }, nil + case "pong": + return &Pong{ + Payload: reply[1].(string), + }, nil + default: + return nil, fmt.Errorf("redis: unsupported pubsub notification: %q", kind) + } +} + +// ReceiveTimeout acts like Receive but returns an error if message +// is not received in time. +func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { + cn, err := c.conn() + if err != nil { + return nil, err + } + cn.ReadTimeout = timeout + + cmd := NewSliceCmd() + if err := cmd.parseReply(cn.rd); err != nil { + return nil, err + } + return newMessage(cmd.Val()) +} + +func (c *PubSub) subscribe(cmd string, channels ...string) error { + cn, err := c.conn() + if err != nil { + return err + } + + args := make([]interface{}, 1+len(channels)) + args[0] = cmd + for i, channel := range channels { + args[1+i] = channel + } + req := NewSliceCmd(args...) + return cn.writeCmds(req) +} + +// Subscribes the client to the specified channels. +func (c *PubSub) Subscribe(channels ...string) error { + return c.subscribe("SUBSCRIBE", channels...) +} + +// Subscribes the client to the given patterns. +func (c *PubSub) PSubscribe(patterns ...string) error { + return c.subscribe("PSUBSCRIBE", patterns...) +} + +// Unsubscribes the client from the given channels, or from all of +// them if none is given. +func (c *PubSub) Unsubscribe(channels ...string) error { + return c.subscribe("UNSUBSCRIBE", channels...) +} + +// Unsubscribes the client from the given patterns, or from all of +// them if none is given. +func (c *PubSub) PUnsubscribe(patterns ...string) error { + return c.subscribe("PUNSUBSCRIBE", patterns...) +} diff --git a/Godeps/_workspace/src/gopkg.in/redis.v3/pubsub_test.go b/Godeps/_workspace/src/gopkg.in/redis.v3/pubsub_test.go new file mode 100644 index 0000000..ac1d629 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/redis.v3/pubsub_test.go @@ -0,0 +1,230 @@ +package redis_test + +import ( + "net" + "time" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + + "gopkg.in/redis.v3" +) + +var _ = Describe("PubSub", func() { + var client *redis.Client + + BeforeEach(func() { + client = redis.NewClient(&redis.Options{ + Addr: redisAddr, + }) + Expect(client.FlushDb().Err()).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + Expect(client.Close()).NotTo(HaveOccurred()) + }) + + It("should support pattern matching", func() { + pubsub, err := client.PSubscribe("mychannel*") + Expect(err).NotTo(HaveOccurred()) + defer pubsub.Close() + + n, err := client.Publish("mychannel1", "hello").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(n).To(Equal(int64(1))) + + Expect(pubsub.PUnsubscribe("mychannel*")).NotTo(HaveOccurred()) + + { + msgi, err := pubsub.ReceiveTimeout(time.Second) + Expect(err).NotTo(HaveOccurred()) + subscr := msgi.(*redis.Subscription) + Expect(subscr.Kind).To(Equal("psubscribe")) + Expect(subscr.Channel).To(Equal("mychannel*")) + Expect(subscr.Count).To(Equal(1)) + } + + { + msgi, err := pubsub.ReceiveTimeout(time.Second) + Expect(err).NotTo(HaveOccurred()) + subscr := msgi.(*redis.PMessage) + Expect(subscr.Channel).To(Equal("mychannel1")) + Expect(subscr.Pattern).To(Equal("mychannel*")) + Expect(subscr.Payload).To(Equal("hello")) + } + + { + msgi, err := pubsub.ReceiveTimeout(time.Second) + Expect(err).NotTo(HaveOccurred()) + subscr := msgi.(*redis.Subscription) + Expect(subscr.Kind).To(Equal("punsubscribe")) + Expect(subscr.Channel).To(Equal("mychannel*")) + Expect(subscr.Count).To(Equal(0)) + } + + { + msgi, err := pubsub.ReceiveTimeout(time.Second) + Expect(err.(net.Error).Timeout()).To(Equal(true)) + Expect(msgi).NotTo(HaveOccurred()) + } + }) + + It("should pub/sub channels", func() { + channels, err := client.PubSubChannels("mychannel*").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(channels).To(BeEmpty()) + + pubsub, err := client.Subscribe("mychannel", "mychannel2") + Expect(err).NotTo(HaveOccurred()) + defer pubsub.Close() + + channels, err = client.PubSubChannels("mychannel*").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(channels).To(ConsistOf([]string{"mychannel", "mychannel2"})) + + channels, err = client.PubSubChannels("").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(channels).To(BeEmpty()) + + channels, err = client.PubSubChannels("*").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(len(channels)).To(BeNumerically(">=", 2)) + }) + + It("should return the numbers of subscribers", func() { + pubsub, err := client.Subscribe("mychannel", "mychannel2") + Expect(err).NotTo(HaveOccurred()) + defer pubsub.Close() + + channels, err := client.PubSubNumSub("mychannel", "mychannel2", "mychannel3").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(channels).To(Equal(map[string]int64{ + "mychannel": 1, + "mychannel2": 1, + "mychannel3": 0, + })) + }) + + It("should return the numbers of subscribers by pattern", func() { + num, err := client.PubSubNumPat().Result() + Expect(err).NotTo(HaveOccurred()) + Expect(num).To(Equal(int64(0))) + + pubsub, err := client.PSubscribe("*") + Expect(err).NotTo(HaveOccurred()) + defer pubsub.Close() + + num, err = client.PubSubNumPat().Result() + Expect(err).NotTo(HaveOccurred()) + Expect(num).To(Equal(int64(1))) + }) + + It("should pub/sub", func() { + pubsub, err := client.Subscribe("mychannel", "mychannel2") + Expect(err).NotTo(HaveOccurred()) + defer pubsub.Close() + + n, err := client.Publish("mychannel", "hello").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(n).To(Equal(int64(1))) + + n, err = client.Publish("mychannel2", "hello2").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(n).To(Equal(int64(1))) + + Expect(pubsub.Unsubscribe("mychannel", "mychannel2")).NotTo(HaveOccurred()) + + { + msgi, err := pubsub.ReceiveTimeout(time.Second) + Expect(err).NotTo(HaveOccurred()) + subscr := msgi.(*redis.Subscription) + Expect(subscr.Kind).To(Equal("subscribe")) + Expect(subscr.Channel).To(Equal("mychannel")) + Expect(subscr.Count).To(Equal(1)) + } + + { + msgi, err := pubsub.ReceiveTimeout(time.Second) + Expect(err).NotTo(HaveOccurred()) + subscr := msgi.(*redis.Subscription) + Expect(subscr.Kind).To(Equal("subscribe")) + Expect(subscr.Channel).To(Equal("mychannel2")) + Expect(subscr.Count).To(Equal(2)) + } + + { + msgi, err := pubsub.ReceiveTimeout(time.Second) + Expect(err).NotTo(HaveOccurred()) + subscr := msgi.(*redis.Message) + Expect(subscr.Channel).To(Equal("mychannel")) + Expect(subscr.Payload).To(Equal("hello")) + } + + { + msgi, err := pubsub.ReceiveTimeout(time.Second) + Expect(err).NotTo(HaveOccurred()) + msg := msgi.(*redis.Message) + Expect(msg.Channel).To(Equal("mychannel2")) + Expect(msg.Payload).To(Equal("hello2")) + } + + { + msgi, err := pubsub.ReceiveTimeout(time.Second) + Expect(err).NotTo(HaveOccurred()) + subscr := msgi.(*redis.Subscription) + Expect(subscr.Kind).To(Equal("unsubscribe")) + Expect(subscr.Channel).To(Equal("mychannel")) + Expect(subscr.Count).To(Equal(1)) + } + + { + msgi, err := pubsub.ReceiveTimeout(time.Second) + Expect(err).NotTo(HaveOccurred()) + subscr := msgi.(*redis.Subscription) + Expect(subscr.Kind).To(Equal("unsubscribe")) + Expect(subscr.Channel).To(Equal("mychannel2")) + Expect(subscr.Count).To(Equal(0)) + } + + { + msgi, err := pubsub.ReceiveTimeout(time.Second) + Expect(err.(net.Error).Timeout()).To(Equal(true)) + Expect(msgi).NotTo(HaveOccurred()) + } + }) + + It("should ping/pong", func() { + pubsub, err := client.Subscribe("mychannel") + Expect(err).NotTo(HaveOccurred()) + defer pubsub.Close() + + _, err = pubsub.ReceiveTimeout(time.Second) + Expect(err).NotTo(HaveOccurred()) + + err = pubsub.Ping("") + Expect(err).NotTo(HaveOccurred()) + + msgi, err := pubsub.ReceiveTimeout(time.Second) + Expect(err).NotTo(HaveOccurred()) + pong := msgi.(*redis.Pong) + Expect(pong.Payload).To(Equal("")) + }) + + It("should ping/pong with payload", func() { + pubsub, err := client.Subscribe("mychannel") + Expect(err).NotTo(HaveOccurred()) + defer pubsub.Close() + + _, err = pubsub.ReceiveTimeout(time.Second) + Expect(err).NotTo(HaveOccurred()) + + err = pubsub.Ping("hello") + Expect(err).NotTo(HaveOccurred()) + + msgi, err := pubsub.ReceiveTimeout(time.Second) + Expect(err).NotTo(HaveOccurred()) + pong := msgi.(*redis.Pong) + Expect(pong.Payload).To(Equal("hello")) + }) + +}) diff --git a/Godeps/_workspace/src/gopkg.in/redis.v3/redis.go b/Godeps/_workspace/src/gopkg.in/redis.v3/redis.go new file mode 100644 index 0000000..a6e12f5 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/redis.v3/redis.go @@ -0,0 +1,192 @@ +package redis + +import ( + "fmt" + "log" + "net" + "time" +) + +type baseClient struct { + connPool pool + opt *Options +} + +func (c *baseClient) String() string { + return fmt.Sprintf("Redis<%s db:%d>", c.opt.Addr, c.opt.DB) +} + +func (c *baseClient) conn() (*conn, error) { + return c.connPool.Get() +} + +func (c *baseClient) putConn(cn *conn, ei error) { + var err error + if cn.rd.Buffered() > 0 { + err = c.connPool.Remove(cn) + } else if ei == nil { + err = c.connPool.Put(cn) + } else if _, ok := ei.(redisError); ok { + err = c.connPool.Put(cn) + } else { + err = c.connPool.Remove(cn) + } + if err != nil { + log.Printf("redis: putConn failed: %s", err) + } +} + +func (c *baseClient) process(cmd Cmder) { + for i := 0; i <= c.opt.MaxRetries; i++ { + if i > 0 { + cmd.reset() + } + + cn, err := c.conn() + if err != nil { + cmd.setErr(err) + return + } + + if timeout := cmd.writeTimeout(); timeout != nil { + cn.WriteTimeout = *timeout + } else { + cn.WriteTimeout = c.opt.WriteTimeout + } + + if timeout := cmd.readTimeout(); timeout != nil { + cn.ReadTimeout = *timeout + } else { + cn.ReadTimeout = c.opt.ReadTimeout + } + + if err := cn.writeCmds(cmd); err != nil { + c.putConn(cn, err) + cmd.setErr(err) + if shouldRetry(err) { + continue + } + return + } + + err = cmd.parseReply(cn.rd) + c.putConn(cn, err) + if shouldRetry(err) { + continue + } + + return + } +} + +// Close closes the client, releasing any open resources. +func (c *baseClient) Close() error { + return c.connPool.Close() +} + +//------------------------------------------------------------------------------ + +type Options struct { + // The network type, either tcp or unix. + // Default is tcp. + Network string + // host:port address. + Addr string + + // Dialer creates new network connection and has priority over + // Network and Addr options. + Dialer func() (net.Conn, error) + + // An optional password. Must match the password specified in the + // requirepass server configuration option. + Password string + // A database to be selected after connecting to server. + DB int64 + + // The maximum number of retries before giving up. + // Default is to not retry failed commands. + MaxRetries int + + // Sets the deadline for establishing new connections. If reached, + // dial will fail with a timeout. + DialTimeout time.Duration + // Sets the deadline for socket reads. If reached, commands will + // fail with a timeout instead of blocking. + ReadTimeout time.Duration + // Sets the deadline for socket writes. If reached, commands will + // fail with a timeout instead of blocking. + WriteTimeout time.Duration + + // The maximum number of socket connections. + // Default is 10 connections. + PoolSize int + // Specifies amount of time client waits for connection if all + // connections are busy before returning an error. + // Default is 5 seconds. + PoolTimeout time.Duration + // Specifies amount of time after which client closes idle + // connections. Should be less than server's timeout. + // Default is to not close idle connections. + IdleTimeout time.Duration +} + +func (opt *Options) getNetwork() string { + if opt.Network == "" { + return "tcp" + } + return opt.Network +} + +func (opt *Options) getDialer() func() (net.Conn, error) { + if opt.Dialer == nil { + opt.Dialer = func() (net.Conn, error) { + return net.DialTimeout(opt.getNetwork(), opt.Addr, opt.getDialTimeout()) + } + } + return opt.Dialer +} + +func (opt *Options) getPoolSize() int { + if opt.PoolSize == 0 { + return 10 + } + return opt.PoolSize +} + +func (opt *Options) getDialTimeout() time.Duration { + if opt.DialTimeout == 0 { + return 5 * time.Second + } + return opt.DialTimeout +} + +func (opt *Options) getPoolTimeout() time.Duration { + if opt.PoolTimeout == 0 { + return 1 * time.Second + } + return opt.PoolTimeout +} + +func (opt *Options) getIdleTimeout() time.Duration { + return opt.IdleTimeout +} + +//------------------------------------------------------------------------------ + +type Client struct { + *baseClient + commandable +} + +func newClient(opt *Options, pool pool) *Client { + base := &baseClient{opt: opt, connPool: pool} + return &Client{ + baseClient: base, + commandable: commandable{process: base.process}, + } +} + +func NewClient(opt *Options) *Client { + pool := newConnPool(opt) + return newClient(opt, pool) +} diff --git a/Godeps/_workspace/src/gopkg.in/redis.v3/redis_test.go b/Godeps/_workspace/src/gopkg.in/redis.v3/redis_test.go new file mode 100644 index 0000000..b1a2547 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/redis.v3/redis_test.go @@ -0,0 +1,365 @@ +package redis_test + +import ( + "bytes" + "net" + "testing" + "time" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + + "gopkg.in/redis.v3" +) + +var _ = Describe("Client", func() { + var client *redis.Client + + BeforeEach(func() { + client = redis.NewClient(&redis.Options{ + Addr: redisAddr, + }) + }) + + AfterEach(func() { + client.Close() + }) + + It("should Stringer", func() { + Expect(client.String()).To(Equal("Redis<:6380 db:0>")) + }) + + It("should ping", func() { + val, err := client.Ping().Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal("PONG")) + }) + + It("should support custom dialers", func() { + custom := redis.NewClient(&redis.Options{ + Dialer: func() (net.Conn, error) { + return net.Dial("tcp", redisAddr) + }, + }) + + val, err := custom.Ping().Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal("PONG")) + Expect(custom.Close()).NotTo(HaveOccurred()) + }) + + It("should close", func() { + Expect(client.Close()).NotTo(HaveOccurred()) + err := client.Ping().Err() + Expect(err).To(HaveOccurred()) + Expect(err).To(MatchError("redis: client is closed")) + }) + + It("should close pubsub without closing the connection", func() { + pubsub := client.PubSub() + Expect(pubsub.Close()).NotTo(HaveOccurred()) + + _, err := pubsub.Receive() + Expect(err).To(HaveOccurred()) + Expect(err).To(MatchError("redis: client is closed")) + Expect(client.Ping().Err()).NotTo(HaveOccurred()) + }) + + It("should close multi without closing the connection", func() { + multi := client.Multi() + Expect(multi.Close()).NotTo(HaveOccurred()) + + _, err := multi.Exec(func() error { + multi.Ping() + return nil + }) + Expect(err).To(HaveOccurred()) + Expect(err).To(MatchError("redis: client is closed")) + Expect(client.Ping().Err()).NotTo(HaveOccurred()) + }) + + It("should close pipeline without closing the connection", func() { + pipeline := client.Pipeline() + Expect(pipeline.Close()).NotTo(HaveOccurred()) + + pipeline.Ping() + _, err := pipeline.Exec() + Expect(err).To(HaveOccurred()) + Expect(err).To(MatchError("redis: client is closed")) + Expect(client.Ping().Err()).NotTo(HaveOccurred()) + }) + + It("should close pubsub when client is closed", func() { + pubsub := client.PubSub() + Expect(client.Close()).NotTo(HaveOccurred()) + Expect(pubsub.Close()).NotTo(HaveOccurred()) + }) + + It("should close multi when client is closed", func() { + multi := client.Multi() + Expect(client.Close()).NotTo(HaveOccurred()) + Expect(multi.Close()).NotTo(HaveOccurred()) + }) + + It("should close pipeline when client is closed", func() { + pipeline := client.Pipeline() + Expect(client.Close()).NotTo(HaveOccurred()) + Expect(pipeline.Close()).NotTo(HaveOccurred()) + }) + + It("should support idle-timeouts", func() { + idle := redis.NewClient(&redis.Options{ + Addr: redisAddr, + IdleTimeout: 100 * time.Microsecond, + }) + defer idle.Close() + + Expect(idle.Ping().Err()).NotTo(HaveOccurred()) + time.Sleep(time.Millisecond) + Expect(idle.Ping().Err()).NotTo(HaveOccurred()) + }) + + It("should support DB selection", func() { + db1 := redis.NewClient(&redis.Options{ + Addr: redisAddr, + DB: 1, + }) + defer db1.Close() + + Expect(db1.Get("key").Err()).To(Equal(redis.Nil)) + Expect(db1.Set("key", "value", 0).Err()).NotTo(HaveOccurred()) + + Expect(client.Get("key").Err()).To(Equal(redis.Nil)) + Expect(db1.Get("key").Val()).To(Equal("value")) + Expect(db1.FlushDb().Err()).NotTo(HaveOccurred()) + }) + + It("should support DB selection with read timeout (issue #135)", func() { + for i := 0; i < 100; i++ { + db1 := redis.NewClient(&redis.Options{ + Addr: redisAddr, + DB: 1, + ReadTimeout: time.Nanosecond, + }) + + err := db1.Ping().Err() + Expect(err).To(HaveOccurred()) + Expect(err.(net.Error).Timeout()).To(BeTrue()) + } + }) + + It("should retry command on network error", func() { + Expect(client.Close()).NotTo(HaveOccurred()) + + client = redis.NewClient(&redis.Options{ + Addr: redisAddr, + MaxRetries: 1, + }) + + // Put bad connection in the pool. + cn, err := client.Pool().Get() + Expect(err).NotTo(HaveOccurred()) + cn.SetNetConn(newBadNetConn()) + Expect(client.Pool().Put(cn)).NotTo(HaveOccurred()) + + err = client.Ping().Err() + Expect(err).NotTo(HaveOccurred()) + }) +}) + +//------------------------------------------------------------------------------ + +func benchRedisClient() *redis.Client { + client := redis.NewClient(&redis.Options{ + Addr: ":6379", + }) + if err := client.FlushDb().Err(); err != nil { + panic(err) + } + return client +} + +func BenchmarkRedisPing(b *testing.B) { + client := benchRedisClient() + defer client.Close() + + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if err := client.Ping().Err(); err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkRedisSet(b *testing.B) { + client := benchRedisClient() + defer client.Close() + + value := string(bytes.Repeat([]byte{'1'}, 10000)) + + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if err := client.Set("key", value, 0).Err(); err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkRedisGetNil(b *testing.B) { + client := benchRedisClient() + defer client.Close() + + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if err := client.Get("key").Err(); err != redis.Nil { + b.Fatal(err) + } + } + }) +} + +func benchmarkRedisSetGet(b *testing.B, size int) { + client := benchRedisClient() + defer client.Close() + + value := string(bytes.Repeat([]byte{'1'}, size)) + + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if err := client.Set("key", value, 0).Err(); err != nil { + b.Fatal(err) + } + + got, err := client.Get("key").Result() + if err != nil { + b.Fatal(err) + } + if got != value { + b.Fatalf("got != value") + } + } + }) +} + +func BenchmarkRedisSetGet64Bytes(b *testing.B) { + benchmarkRedisSetGet(b, 64) +} + +func BenchmarkRedisSetGet1KB(b *testing.B) { + benchmarkRedisSetGet(b, 1024) +} + +func BenchmarkRedisSetGet10KB(b *testing.B) { + benchmarkRedisSetGet(b, 10*1024) +} + +func BenchmarkRedisSetGet1MB(b *testing.B) { + benchmarkRedisSetGet(b, 1024*1024) +} + +func BenchmarkRedisSetGetBytes(b *testing.B) { + client := benchRedisClient() + defer client.Close() + + value := bytes.Repeat([]byte{'1'}, 10000) + + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if err := client.Set("key", value, 0).Err(); err != nil { + b.Fatal(err) + } + + got, err := client.Get("key").Bytes() + if err != nil { + b.Fatal(err) + } + if !bytes.Equal(got, value) { + b.Fatalf("got != value") + } + } + }) +} + +func BenchmarkRedisMGet(b *testing.B) { + client := benchRedisClient() + defer client.Close() + + if err := client.MSet("key1", "hello1", "key2", "hello2").Err(); err != nil { + b.Fatal(err) + } + + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if err := client.MGet("key1", "key2").Err(); err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkSetExpire(b *testing.B) { + client := benchRedisClient() + defer client.Close() + + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if err := client.Set("key", "hello", 0).Err(); err != nil { + b.Fatal(err) + } + if err := client.Expire("key", time.Second).Err(); err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkPipeline(b *testing.B) { + client := benchRedisClient() + defer client.Close() + + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, err := client.Pipelined(func(pipe *redis.Pipeline) error { + pipe.Set("key", "hello", 0) + pipe.Expire("key", time.Second) + return nil + }) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkZAdd(b *testing.B) { + client := benchRedisClient() + defer client.Close() + + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if err := client.ZAdd("key", redis.Z{float64(1), "hello"}).Err(); err != nil { + b.Fatal(err) + } + } + }) +} diff --git a/Godeps/_workspace/src/gopkg.in/redis.v3/ring.go b/Godeps/_workspace/src/gopkg.in/redis.v3/ring.go new file mode 100644 index 0000000..4b20e7a --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/redis.v3/ring.go @@ -0,0 +1,349 @@ +package redis + +import ( + "errors" + "fmt" + "log" + "sync" + "time" + + "gopkg.in/redis.v3/internal/consistenthash" +) + +var ( + errRingShardsDown = errors.New("redis: all ring shards are down") +) + +// RingOptions are used to configure a ring client and should be +// passed to NewRing. +type RingOptions struct { + // A map of name => host:port addresses of ring shards. + Addrs map[string]string + + // Following options are copied from Options struct. + + DB int64 + Password string + + MaxRetries int + + DialTimeout time.Duration + ReadTimeout time.Duration + WriteTimeout time.Duration + + PoolSize int + PoolTimeout time.Duration + IdleTimeout time.Duration +} + +func (opt *RingOptions) clientOptions() *Options { + return &Options{ + DB: opt.DB, + Password: opt.Password, + + DialTimeout: opt.DialTimeout, + ReadTimeout: opt.ReadTimeout, + WriteTimeout: opt.WriteTimeout, + + PoolSize: opt.PoolSize, + PoolTimeout: opt.PoolTimeout, + IdleTimeout: opt.IdleTimeout, + } +} + +type ringShard struct { + Client *Client + down int +} + +func (shard *ringShard) String() string { + var state string + if shard.IsUp() { + state = "up" + } else { + state = "down" + } + return fmt.Sprintf("%s is %s", shard.Client, state) +} + +func (shard *ringShard) IsDown() bool { + const threshold = 5 + return shard.down >= threshold +} + +func (shard *ringShard) IsUp() bool { + return !shard.IsDown() +} + +// Vote votes to set shard state and returns true if state was changed. +func (shard *ringShard) Vote(up bool) bool { + if up { + changed := shard.IsDown() + shard.down = 0 + return changed + } + + if shard.IsDown() { + return false + } + + shard.down++ + return shard.IsDown() +} + +// Ring is a Redis client that uses constistent hashing to distribute +// keys across multiple Redis servers (shards). +// +// It monitors the state of each shard and removes dead shards from +// the ring. When shard comes online it is added back to the ring. This +// gives you maximum availability and partition tolerance, but no +// consistency between different shards or even clients. Each client +// uses shards that are available to the client and does not do any +// coordination when shard state is changed. +// +// Ring should be used when you use multiple Redis servers for caching +// and can tolerate losing data when one of the servers dies. +// Otherwise you should use Redis Cluster. +type Ring struct { + commandable + + opt *RingOptions + nreplicas int + + mx sync.RWMutex + hash *consistenthash.Map + shards map[string]*ringShard + + closed bool +} + +func NewRing(opt *RingOptions) *Ring { + const nreplicas = 100 + ring := &Ring{ + opt: opt, + nreplicas: nreplicas, + + hash: consistenthash.New(nreplicas, nil), + shards: make(map[string]*ringShard), + } + ring.commandable.process = ring.process + for name, addr := range opt.Addrs { + clopt := opt.clientOptions() + clopt.Addr = addr + ring.addClient(name, NewClient(clopt)) + } + go ring.heartbeat() + return ring +} + +func (ring *Ring) addClient(name string, cl *Client) { + ring.mx.Lock() + ring.hash.Add(name) + ring.shards[name] = &ringShard{Client: cl} + ring.mx.Unlock() +} + +func (ring *Ring) getClient(key string) (*Client, error) { + ring.mx.RLock() + + if ring.closed { + return nil, errClosed + } + + name := ring.hash.Get(hashKey(key)) + if name == "" { + ring.mx.RUnlock() + return nil, errRingShardsDown + } + + cl := ring.shards[name].Client + ring.mx.RUnlock() + return cl, nil +} + +func (ring *Ring) process(cmd Cmder) { + cl, err := ring.getClient(cmd.clusterKey()) + if err != nil { + cmd.setErr(err) + return + } + cl.baseClient.process(cmd) +} + +// rebalance removes dead shards from the ring. +func (ring *Ring) rebalance() { + defer ring.mx.Unlock() + ring.mx.Lock() + + ring.hash = consistenthash.New(ring.nreplicas, nil) + for name, shard := range ring.shards { + if shard.IsUp() { + ring.hash.Add(name) + } + } +} + +// heartbeat monitors state of each shard in the ring. +func (ring *Ring) heartbeat() { + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + for _ = range ticker.C { + var rebalance bool + + ring.mx.RLock() + + if ring.closed { + ring.mx.RUnlock() + break + } + + for _, shard := range ring.shards { + err := shard.Client.Ping().Err() + if shard.Vote(err == nil || err == errPoolTimeout) { + log.Printf("redis: ring shard state changed: %s", shard) + rebalance = true + } + } + + ring.mx.RUnlock() + + if rebalance { + ring.rebalance() + } + } +} + +// Close closes the ring client, releasing any open resources. +// +// It is rare to Close a Client, as the Client is meant to be +// long-lived and shared between many goroutines. +func (ring *Ring) Close() (retErr error) { + defer ring.mx.Unlock() + ring.mx.Lock() + + if ring.closed { + return nil + } + ring.closed = true + + for _, shard := range ring.shards { + if err := shard.Client.Close(); err != nil { + retErr = err + } + } + ring.hash = nil + ring.shards = nil + + return retErr +} + +// RingPipeline creates a new pipeline which is able to execute commands +// against multiple shards. +type RingPipeline struct { + commandable + + ring *Ring + + cmds []Cmder + closed bool +} + +func (ring *Ring) Pipeline() *RingPipeline { + pipe := &RingPipeline{ + ring: ring, + cmds: make([]Cmder, 0, 10), + } + pipe.commandable.process = pipe.process + return pipe +} + +func (ring *Ring) Pipelined(fn func(*RingPipeline) error) ([]Cmder, error) { + pipe := ring.Pipeline() + if err := fn(pipe); err != nil { + return nil, err + } + cmds, err := pipe.Exec() + pipe.Close() + return cmds, err +} + +func (pipe *RingPipeline) process(cmd Cmder) { + pipe.cmds = append(pipe.cmds, cmd) +} + +// Discard resets the pipeline and discards queued commands. +func (pipe *RingPipeline) Discard() error { + if pipe.closed { + return errClosed + } + pipe.cmds = pipe.cmds[:0] + return nil +} + +// Exec always returns list of commands and error of the first failed +// command if any. +func (pipe *RingPipeline) Exec() (cmds []Cmder, retErr error) { + if pipe.closed { + return nil, errClosed + } + if len(pipe.cmds) == 0 { + return pipe.cmds, nil + } + + cmds = pipe.cmds + pipe.cmds = make([]Cmder, 0, 10) + + cmdsMap := make(map[string][]Cmder) + for _, cmd := range cmds { + name := pipe.ring.hash.Get(hashKey(cmd.clusterKey())) + if name == "" { + cmd.setErr(errRingShardsDown) + if retErr == nil { + retErr = errRingShardsDown + } + continue + } + cmdsMap[name] = append(cmdsMap[name], cmd) + } + + for i := 0; i <= pipe.ring.opt.MaxRetries; i++ { + failedCmdsMap := make(map[string][]Cmder) + + for name, cmds := range cmdsMap { + client := pipe.ring.shards[name].Client + cn, err := client.conn() + if err != nil { + setCmdsErr(cmds, err) + if retErr == nil { + retErr = err + } + continue + } + + if i > 0 { + resetCmds(cmds) + } + failedCmds, err := execCmds(cn, cmds) + client.putConn(cn, err) + if err != nil && retErr == nil { + retErr = err + } + if len(failedCmds) > 0 { + failedCmdsMap[name] = failedCmds + } + } + + if len(failedCmdsMap) == 0 { + break + } + cmdsMap = failedCmdsMap + } + + return cmds, retErr +} + +func (pipe *RingPipeline) Close() error { + pipe.Discard() + pipe.closed = true + return nil +} diff --git a/Godeps/_workspace/src/gopkg.in/redis.v3/ring_test.go b/Godeps/_workspace/src/gopkg.in/redis.v3/ring_test.go new file mode 100644 index 0000000..5b52b32 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/redis.v3/ring_test.go @@ -0,0 +1,164 @@ +package redis_test + +import ( + "crypto/rand" + "fmt" + "time" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + + "gopkg.in/redis.v3" +) + +var _ = Describe("Redis ring", func() { + var ring *redis.Ring + + setRingKeys := func() { + for i := 0; i < 100; i++ { + err := ring.Set(fmt.Sprintf("key%d", i), "value", 0).Err() + Expect(err).NotTo(HaveOccurred()) + } + } + + BeforeEach(func() { + ring = redis.NewRing(&redis.RingOptions{ + Addrs: map[string]string{ + "ringShardOne": ":" + ringShard1Port, + "ringShardTwo": ":" + ringShard2Port, + }, + }) + + // Shards should not have any keys. + Expect(ringShard1.FlushDb().Err()).NotTo(HaveOccurred()) + Expect(ringShard1.Info().Val()).NotTo(ContainSubstring("keys=")) + + Expect(ringShard2.FlushDb().Err()).NotTo(HaveOccurred()) + Expect(ringShard2.Info().Val()).NotTo(ContainSubstring("keys=")) + }) + + AfterEach(func() { + Expect(ring.Close()).NotTo(HaveOccurred()) + }) + + It("uses both shards", func() { + setRingKeys() + + // Both shards should have some keys now. + Expect(ringShard1.Info().Val()).To(ContainSubstring("keys=57")) + Expect(ringShard2.Info().Val()).To(ContainSubstring("keys=43")) + }) + + It("uses one shard when other shard is down", func() { + // Stop ringShard2. + Expect(ringShard2.Close()).NotTo(HaveOccurred()) + + // Ring needs 5 * heartbeat time to detect that node is down. + // Give it more to be sure. + heartbeat := 100 * time.Millisecond + time.Sleep(5*heartbeat + heartbeat) + + setRingKeys() + + // RingShard1 should have all keys. + Expect(ringShard1.Info().Val()).To(ContainSubstring("keys=100")) + + // Start ringShard2. + var err error + ringShard2, err = startRedis(ringShard2Port) + Expect(err).NotTo(HaveOccurred()) + + // Wait for ringShard2 to come up. + Eventually(func() error { + return ringShard2.Ping().Err() + }, "1s").ShouldNot(HaveOccurred()) + + // Ring needs heartbeat time to detect that node is up. + // Give it more to be sure. + time.Sleep(heartbeat + heartbeat) + + setRingKeys() + + // RingShard2 should have its keys. + Expect(ringShard2.Info().Val()).To(ContainSubstring("keys=43")) + }) + + It("supports hash tags", func() { + for i := 0; i < 100; i++ { + err := ring.Set(fmt.Sprintf("key%d{tag}", i), "value", 0).Err() + Expect(err).NotTo(HaveOccurred()) + } + + Expect(ringShard1.Info().Val()).ToNot(ContainSubstring("keys=")) + Expect(ringShard2.Info().Val()).To(ContainSubstring("keys=100")) + }) + + Describe("pipelining", func() { + It("returns an error when all shards are down", func() { + ring := redis.NewRing(&redis.RingOptions{}) + _, err := ring.Pipelined(func(pipe *redis.RingPipeline) error { + pipe.Ping() + return nil + }) + Expect(err).To(MatchError("redis: all ring shards are down")) + }) + + It("uses both shards", func() { + pipe := ring.Pipeline() + for i := 0; i < 100; i++ { + err := pipe.Set(fmt.Sprintf("key%d", i), "value", 0).Err() + Expect(err).NotTo(HaveOccurred()) + } + cmds, err := pipe.Exec() + Expect(err).NotTo(HaveOccurred()) + Expect(cmds).To(HaveLen(100)) + Expect(pipe.Close()).NotTo(HaveOccurred()) + + for _, cmd := range cmds { + Expect(cmd.Err()).NotTo(HaveOccurred()) + Expect(cmd.(*redis.StatusCmd).Val()).To(Equal("OK")) + } + + // Both shards should have some keys now. + Expect(ringShard1.Info().Val()).To(ContainSubstring("keys=57")) + Expect(ringShard2.Info().Val()).To(ContainSubstring("keys=43")) + }) + + It("is consistent with ring", func() { + var keys []string + for i := 0; i < 100; i++ { + key := make([]byte, 64) + _, err := rand.Read(key) + Expect(err).NotTo(HaveOccurred()) + keys = append(keys, string(key)) + } + + _, err := ring.Pipelined(func(pipe *redis.RingPipeline) error { + for _, key := range keys { + pipe.Set(key, "value", 0).Err() + } + return nil + }) + Expect(err).NotTo(HaveOccurred()) + + for _, key := range keys { + val, err := ring.Get(key).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal("value")) + } + }) + + It("supports hash tags", func() { + _, err := ring.Pipelined(func(pipe *redis.RingPipeline) error { + for i := 0; i < 100; i++ { + pipe.Set(fmt.Sprintf("key%d{tag}", i), "value", 0).Err() + } + return nil + }) + Expect(err).NotTo(HaveOccurred()) + + Expect(ringShard1.Info().Val()).ToNot(ContainSubstring("keys=")) + Expect(ringShard2.Info().Val()).To(ContainSubstring("keys=100")) + }) + }) +}) diff --git a/Godeps/_workspace/src/gopkg.in/redis.v2/script.go b/Godeps/_workspace/src/gopkg.in/redis.v3/script.go similarity index 93% rename from Godeps/_workspace/src/gopkg.in/redis.v2/script.go rename to Godeps/_workspace/src/gopkg.in/redis.v3/script.go index 96c35f5..3f22f46 100644 --- a/Godeps/_workspace/src/gopkg.in/redis.v2/script.go +++ b/Godeps/_workspace/src/gopkg.in/redis.v3/script.go @@ -43,7 +43,7 @@ func (s *Script) EvalSha(c scripter, keys []string, args []string) *Cmd { return c.EvalSha(s.hash, keys, args) } -func (s *Script) Run(c *Client, keys []string, args []string) *Cmd { +func (s *Script) Run(c scripter, keys []string, args []string) *Cmd { r := s.EvalSha(c, keys, args) if err := r.Err(); err != nil && strings.HasPrefix(err.Error(), "NOSCRIPT ") { return s.Eval(c, keys, args) diff --git a/Godeps/_workspace/src/gopkg.in/redis.v2/sentinel.go b/Godeps/_workspace/src/gopkg.in/redis.v3/sentinel.go similarity index 76% rename from Godeps/_workspace/src/gopkg.in/redis.v2/sentinel.go rename to Godeps/_workspace/src/gopkg.in/redis.v3/sentinel.go index d3ffeca..82d9bc9 100644 --- a/Godeps/_workspace/src/gopkg.in/redis.v2/sentinel.go +++ b/Godeps/_workspace/src/gopkg.in/redis.v3/sentinel.go @@ -11,49 +11,47 @@ import ( //------------------------------------------------------------------------------ +// FailoverOptions are used to configure a failover client and should +// be passed to NewFailoverClient. type FailoverOptions struct { - MasterName string + // The master name. + MasterName string + // A seed list of host:port addresses of sentinel nodes. SentinelAddrs []string + // Following options are copied from Options struct. + Password string DB int64 - PoolSize int - DialTimeout time.Duration ReadTimeout time.Duration WriteTimeout time.Duration - IdleTimeout time.Duration -} -func (opt *FailoverOptions) getPoolSize() int { - if opt.PoolSize == 0 { - return 10 - } - return opt.PoolSize + PoolSize int + PoolTimeout time.Duration + IdleTimeout time.Duration } -func (opt *FailoverOptions) getDialTimeout() time.Duration { - if opt.DialTimeout == 0 { - return 5 * time.Second - } - return opt.DialTimeout -} +func (opt *FailoverOptions) options() *Options { + return &Options{ + Addr: "FailoverClient", -func (opt *FailoverOptions) options() *options { - return &options{ DB: opt.DB, Password: opt.Password, - DialTimeout: opt.getDialTimeout(), + DialTimeout: opt.DialTimeout, ReadTimeout: opt.ReadTimeout, WriteTimeout: opt.WriteTimeout, - PoolSize: opt.getPoolSize(), + PoolSize: opt.PoolSize, + PoolTimeout: opt.PoolTimeout, IdleTimeout: opt.IdleTimeout, } } +// NewFailoverClient returns a Redis client with automatic failover +// capabilities using Redis Sentinel. func NewFailoverClient(failoverOpt *FailoverOptions) *Client { opt := failoverOpt.options() failover := &sentinelFailover{ @@ -62,32 +60,24 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client { opt: opt, } - return &Client{ - baseClient: &baseClient{ - opt: opt, - connPool: failover.Pool(), - }, - } + return newClient(opt, failover.Pool()) } //------------------------------------------------------------------------------ type sentinelClient struct { + commandable *baseClient } -func newSentinel(clOpt *Options) *sentinelClient { - opt := clOpt.options() - opt.Password = "" - opt.DB = 0 - dialer := func() (net.Conn, error) { - return net.DialTimeout("tcp", clOpt.Addr, opt.DialTimeout) +func newSentinel(opt *Options) *sentinelClient { + base := &baseClient{ + opt: opt, + connPool: newConnPool(opt), } return &sentinelClient{ - baseClient: &baseClient{ - opt: opt, - connPool: newConnPool(newConnFunc(dialer), opt), - }, + baseClient: base, + commandable: commandable{process: base.process}, } } @@ -116,7 +106,7 @@ type sentinelFailover struct { masterName string sentinelAddrs []string - opt *options + opt *Options pool pool poolOnce sync.Once @@ -135,7 +125,8 @@ func (d *sentinelFailover) dial() (net.Conn, error) { func (d *sentinelFailover) Pool() pool { d.poolOnce.Do(func() { - d.pool = newConnPool(newConnFunc(d.dial), d.opt) + d.opt.Dialer = d.dial + d.pool = newConnPool(d.opt) }) return d.pool } @@ -161,14 +152,12 @@ func (d *sentinelFailover) MasterAddr() (string, error) { sentinel := newSentinel(&Options{ Addr: sentinelAddr, - DB: d.opt.DB, - Password: d.opt.Password, - DialTimeout: d.opt.DialTimeout, ReadTimeout: d.opt.ReadTimeout, WriteTimeout: d.opt.WriteTimeout, PoolSize: d.opt.PoolSize, + PoolTimeout: d.opt.PoolTimeout, IdleTimeout: d.opt.IdleTimeout, }) masterAddr, err := sentinel.GetMasterAddrByName(d.masterName).Result() @@ -220,6 +209,34 @@ func (d *sentinelFailover) discoverSentinels(sentinel *sentinelClient) { } } +// closeOldConns closes connections to the old master after failover switch. +func (d *sentinelFailover) closeOldConns(newMaster string) { + // Good connections that should be put back to the pool. They + // can't be put immediately, because pool.First will return them + // again on next iteration. + cnsToPut := make([]*conn, 0) + + for { + cn := d.pool.First() + if cn == nil { + break + } + if cn.RemoteAddr().String() != newMaster { + log.Printf( + "redis-sentinel: closing connection to the old master %s", + cn.RemoteAddr(), + ) + d.pool.Remove(cn) + } else { + cnsToPut = append(cnsToPut, cn) + } + } + + for _, cn := range cnsToPut { + d.pool.Put(cn) + } +} + func (d *sentinelFailover) listen() { var pubsub *PubSub for { @@ -255,16 +272,8 @@ func (d *sentinelFailover) listen() { "redis-sentinel: new %q addr is %s", d.masterName, addr, ) - d.pool.Filter(func(cn *conn) bool { - if cn.RemoteAddr().String() != addr { - log.Printf( - "redis-sentinel: closing connection to old master %s", - cn.RemoteAddr(), - ) - return false - } - return true - }) + + d.closeOldConns(addr) default: log.Printf("redis-sentinel: unsupported message: %s", msg) } diff --git a/Godeps/_workspace/src/gopkg.in/redis.v3/sentinel_test.go b/Godeps/_workspace/src/gopkg.in/redis.v3/sentinel_test.go new file mode 100644 index 0000000..14dcf83 --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/redis.v3/sentinel_test.go @@ -0,0 +1,74 @@ +package redis_test + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + + "gopkg.in/redis.v3" +) + +var _ = Describe("Sentinel", func() { + var client *redis.Client + + BeforeEach(func() { + client = redis.NewFailoverClient(&redis.FailoverOptions{ + MasterName: sentinelName, + SentinelAddrs: []string{":" + sentinelPort}, + }) + }) + + AfterEach(func() { + Expect(client.Close()).NotTo(HaveOccurred()) + }) + + It("should facilitate failover", func() { + // Set value on master, verify + err := client.Set("foo", "master", 0).Err() + Expect(err).NotTo(HaveOccurred()) + + val, err := sentinelMaster.Get("foo").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal("master")) + + // Wait until replicated + Eventually(func() string { + return sentinelSlave1.Get("foo").Val() + }, "1s", "100ms").Should(Equal("master")) + Eventually(func() string { + return sentinelSlave2.Get("foo").Val() + }, "1s", "100ms").Should(Equal("master")) + + // Wait until slaves are picked up by sentinel. + Eventually(func() string { + return sentinel.Info().Val() + }, "10s", "100ms").Should(ContainSubstring("slaves=2")) + + // Kill master. + sentinelMaster.Shutdown() + Eventually(func() error { + return sentinelMaster.Ping().Err() + }, "5s", "100ms").Should(HaveOccurred()) + + // Wait for Redis sentinel to elect new master. + Eventually(func() string { + return sentinelSlave1.Info().Val() + sentinelSlave2.Info().Val() + }, "30s", "1s").Should(ContainSubstring("role:master")) + + // Check that client picked up new master. + Eventually(func() error { + return client.Get("foo").Err() + }, "5s", "100ms").ShouldNot(HaveOccurred()) + }) + + It("supports DB selection", func() { + Expect(client.Close()).NotTo(HaveOccurred()) + + client = redis.NewFailoverClient(&redis.FailoverOptions{ + MasterName: sentinelName, + SentinelAddrs: []string{":" + sentinelPort}, + DB: 1, + }) + err := client.Ping().Err() + Expect(err).NotTo(HaveOccurred()) + }) +}) diff --git a/Godeps/_workspace/src/gopkg.in/redis.v3/unsafe.go b/Godeps/_workspace/src/gopkg.in/redis.v3/unsafe.go new file mode 100644 index 0000000..1c4d55f --- /dev/null +++ b/Godeps/_workspace/src/gopkg.in/redis.v3/unsafe.go @@ -0,0 +1,12 @@ +package redis + +import ( + "reflect" + "unsafe" +) + +func bytesToString(b []byte) string { + bytesHeader := (*reflect.SliceHeader)(unsafe.Pointer(&b)) + strHeader := reflect.StringHeader{bytesHeader.Data, bytesHeader.Len} + return *(*string)(unsafe.Pointer(&strHeader)) +} diff --git a/README.md b/README.md index 7be79f7..1777e17 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,51 @@ -# dockyard +# Dockyard A image hub for rkt & docker and other container engine. -## `runtime.conf` +# How to compile Dockyard application + +Make sure Go has been installed and env has been set. + +Clone code which Dockyard is depended into directory: + +```bash +git clone https://github.com/containerops/dockyard.git $GOPATH/src/github.com/containerops/dockyard +git clone https://github.com/containerops/crew.git $GOPATH/src/github.com/containerops/crew +git clone https://github.com/containerops/wrench.git $GOPATH/src/github.com/containerops/wrench +git clone https://github.com/containerops/ameba.git $GOPATH/src/github.com/containerops/ameba ``` + +Then exec commands in each project directory as below,it will download the third dependent packages automatically: + +```bash +cd $GOPATH/src/github.com/containerops/dockyard +go get + +cd $GOPATH/src/github.com/containerops/crew +go get + +cd $GOPATH/src/github.com/containerops/wrench +go get + +cd $GOPATH/src/github.com/containerops/ameba +go get +``` + +Finally,enter Dockyard directory and build: +```bash +cd $GOPATH/src/github.com/containerops/dockyard +go build +``` + + +# Dockyard runtime configuration + +Please add a runtime config file named `runtime.conf` under `dockyard/conf` before starting `dockyard` service. + +## `runtime.conf` Example + +```ini runmode = dev listenmode = https @@ -13,4 +54,130 @@ httpskeyfile = cert/containerops/containerops.key [log] filepath = log/containerops-log -``` \ No newline at end of file + +[db] +uri = localhost:6379 +passwd = containerops +db = 8 + +[dockyard] +path = data +domains = containerops.me +registry = 0.9 +distribution = registry/2.0 +standalone = true +``` + +* runmode: application run mode must be `dev` or `prod`. +* listenmode: support `http` and `https` protocol. +* httpscertfile: specify user own https certificate file by this parameter. +* httpskeyfile: specify user own https key file by this parameter. +* [log] filepath: specify where Dockyard logs are stored. +* [db] uri: Dockyard database provider is `redis`,`IP` and `Port` would be specified before `redis` boots. +* [db] passwd: specify the password to login and access db. +* [db] db: specify db area number to use. +* [dockyard] path: specify where `Docker` and `Rocket` image files are stored. +* [dockyard] domains: registry server name or IP. +* [dockyard] registry: specify the version of Docker V1 protocol. +* [dockyard] distribution: specify the version of Docker V2 protocol. +* [dockyard] standalone: must be `true` or `false`,specify run mode whether do authorization checks or not. + + +# Nginx configuration + +It's a Nginx config example. You can change **client_max_body_size** what limited upload file size. + +You should copy `containerops.me` keys from `cert/containerops.me` to `/etc/nginx`, then run **Dockyard** with `http` mode and listen on `127.0.0.1:9911`. + +```nginx +upstream dockyard_upstream { + server 127.0.0.1:9911; +} + +server { + listen 80; + server_name containerops.me; + rewrite ^/(.*)$ https://containerops.me/$1 permanent; +} + +server { + listen 443; + + server_name containerops.me; + + access_log /var/log/nginx/containerops-me.log; + error_log /var/log/nginx/containerops-me-errror.log; + + ssl on; + ssl_certificate /etc/nginx/containerops.me.crt; + ssl_certificate_key /etc/nginx/containerops.me.key; + + client_max_body_size 1024m; + chunked_transfer_encoding on; + + proxy_redirect off; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + proxy_set_header X-Forwarded-Proto $scheme; + proxy_set_header Host $http_host; + proxy_set_header X-NginX-Proxy true; + proxy_set_header Connection ""; + proxy_http_version 1.1; + + location / { + proxy_pass http://dockyard_upstream; + } +} +``` + + +# How to run + +Run directly: + +```bash +./dockyard web --address 0.0.0.0 --port 80 +``` + +Run behind Nginx: + +```bash +./dockyard web --address 127.0.0.1 --port 9911 +``` + + +# How to use + +1. Add **containerops.me** in your `hosts` file like `192.168.1.66 containerops.me` with IP which run `dockyard`. +2. Then `push` with `docker push containerops.me/somebody/ubuntu`. +3. You could `pull` with `docker pull -a containerops.me/somebody/ubuntu`. +4. Work fun! + + +# Reporting issues + +Please submit issue at https://github.com/containerops/dockyard/issues + + +# Maintainers + +* Meaglith Ma https://twitter.com/genedna +* Leo Meng https://github.com/fivestarsky + + +# Licensing + +Dockyard is licensed under the MIT License. + + +# Todo in the feature + +1. Support Docker V1/V2 protocol conversion. +2. Support Rocket **CAS**. +3. More relative pages. + + +# We are working on other projects of Dockyard related + +* [Vessel](https://github.com/dockercn/vessel): Continuous Integration Service Core Of ContainerOps. +* [Rudder](https://github.com/dockercn/rudder): Rtk & Docker api client. diff --git a/backend/aliyun.go b/backend/aliyun.go new file mode 100644 index 0000000..0e09998 --- /dev/null +++ b/backend/aliyun.go @@ -0,0 +1,282 @@ +package backend + +import ( + "bytes" + "crypto/hmac" + "crypto/sha1" + "encoding/base64" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "os" + "sort" + "strconv" + "strings" + "time" + + "github.com/astaxie/beego/config" +) + +var ( + g_aliEndpoint string + g_aliBucket string + g_aliAccessKeyID string + g_aliAccessKeySecret string +) + +func init() { + + gopath := os.Getenv("GOPATH") + if gopath == "" { + fmt.Errorf("read env GOPATH fail") + os.Exit(1) + } + err := aligetconfig(gopath + "/src/github.com/containerops/dockyard/conf/runtime.conf") + if err != nil { + fmt.Errorf("read config file conf/runtime.conf fail:" + err.Error()) + os.Exit(1) + } + g_injector.Bind("alicloudsave", alicloudsave) +} + +func aligetconfig(conffile string) (err error) { + conf, err := config.NewConfig("ini", conffile) + if err != nil { + return err + } + + g_aliEndpoint = conf.String("alicloud::endpoint") + if g_aliEndpoint == "" { + return errors.New("read config file's endpoint failed!") + } + + g_aliBucket = conf.String("alicloud::bucket") + if g_aliBucket == "" { + return errors.New("read config file's bucket failed!") + } + + g_aliAccessKeyID = conf.String("alicloud::accessKeyID") + if g_aliAccessKeyID == "" { + return errors.New("read config file's accessKeyID failed!") + } + + g_aliAccessKeySecret = conf.String("alicloud::accessKeysecret") + if g_aliAccessKeySecret == "" { + return errors.New("read config file's accessKeysecret failed!") + } + return nil +} + +func alicloudsave(file string) (url string, err error) { + + client := NewClient(g_aliAccessKeyID, g_aliAccessKeySecret) + + bucket := NewBucket(g_aliBucket, g_aliEndpoint, client) + + var key string + //get the filename from the file , eg,get "1.txt" from /home/liugenping/1.txt + for _, key = range strings.Split(file, "/") { + + } + opath := "/" + g_aliBucket + "/" + key + url = "http://" + g_aliEndpoint + opath + + headers := map[string]string{} + + err = bucket.PutFile(key, file, headers) + + if nil != err { + return "", err + } else { + return url, nil + } +} + +var resourceQSWhitelist []string = []string{ + "acl", + "group", + "uploadId", + "partNumber", + "uploads", + "logging", + "response-content-type", + "response-content-language", + "response-expires", + "reponse-cache-control", + "response-content-disposition", + "response-content-encoding", +} + +// Holds OSS client informations +type Client struct { + accessKeyId string + accessKeySecret string + *http.Client +} + +// Holds OSS bucket informations +type Bucket struct { + name string + region string + *Client +} + +// Initialize a new client and sets access_key_id and access_key_secret. +func NewClient(accessKeyId, accessKeySecret string) *Client { + return &Client{accessKeyId, accessKeySecret, new(http.Client)} +} + +// Initialize a new OSS bucket with the given `name`, `region` and `*Client`. +func NewBucket(name string, region string, client *Client) *Bucket { + return &Bucket{name, region, client} +} + +// PUT the given `content` as `object`. +func (b *Bucket) Put(object string, content io.Reader, headers map[string]string) error { + buffer := new(bytes.Buffer) + io.Copy(buffer, content) + + header := make(http.Header) + header.Set("Content-Type", http.DetectContentType(buffer.Bytes())) + header.Set("Content-Length", strconv.Itoa(buffer.Len())) + + for key, val := range headers { + header.Set(key, val) + } + + resp, err := b.do("PUT", b.name, string(b.region), object, header, buffer) + if err != nil { + return err + } + + if resp.StatusCode != 200 { + err = errors.New(resp.Status) + return err + } + + return nil +} + +// PUT the file at `filepath` to `object`. +func (b *Bucket) PutFile(object, filepath string, headers map[string]string) error { + file, err := os.Open(filepath) + if err != nil { + return err + } + defer file.Close() + + return b.Put(object, file, headers) +} + +func (c *Client) do(method, bucket, region, object string, header http.Header, body io.Reader) (*http.Response, error) { + object = strings.Trim(object, "/") + req, err := http.NewRequest(method, fmt.Sprintf("http://%s.%s/%s", bucket, region, object), body) + if err != nil { + return nil, err + } + + if header == nil { + header = make(http.Header) + } + header.Set("Date", time.Now().UTC().Format(http.TimeFormat)) + + resource := fmt.Sprintf("/%s/%s", bucket, object) + header.Set("Authorization", c.authorization(method, header, resource)) + + req.Header = header + + return c.Do(req) +} + +// Return an "Authorization" header value in the form of "OSS " + Access Key Id + ":" + Signature +// +// Signature: +// +// base64(hmac-sha1(Access Key Secret + "\n" +// + VERB + "\n" +// + CONTENT-MD5 + "\n" +// + CONTENT-TYPE + "\n" +// + DATE + "\n" +// + CanonicalizedossHeaders +// + CanonicalizedResource)) +func (c *Client) authorization(verb string, header http.Header, resource string) string { + params := []string{ + verb, + header.Get("Content-MD5"), + header.Get("Content-Type"), + header.Get("Date"), + } + + signatureStr := strings.Join(params, "\n") + "\n" + + canonicalizedHeaders := c.canonicalizeHeaders(header) + canonicalizedResource := c.canonicalizeResource(resource) + + if canonicalizedHeaders != "" { + signatureStr += canonicalizedHeaders + } + signatureStr += canonicalizedResource + + h := hmac.New(sha1.New, []byte(c.accessKeySecret)) + h.Write([]byte(signatureStr)) + + signedStr := strings.TrimSpace(base64.StdEncoding.EncodeToString(h.Sum(nil))) + + return "OSS " + c.accessKeyId + ":" + signedStr +} + +// Generate `CanonicalizedossHeaders` +// +// Spec: +// - ignore none x-oss- headers +// - lowercase fields +// - sort lexicographically +// - trim whitespace between field and value +// - join with newline +func (c *Client) canonicalizeHeaders(header http.Header) string { + ossHeaders := []string{} + canonicalizedHeaders := "" + + for k, _ := range header { + field := strings.ToLower(k) + + if strings.HasPrefix(field, "x-oss-") { + ossHeaders = append(ossHeaders, field) + } + } + + sort.Strings(ossHeaders) + + for _, k := range ossHeaders { + canonicalizedHeaders += k + ":" + header.Get(k) + "\n" + } + + return canonicalizedHeaders +} + +// Generate `CanonicalizedResource` +// +// Spec: +// - ignore non sub-resource +// - ignore non override headers +// - sort lexicographically +func (c *Client) canonicalizeResource(resource string) string { + u, _ := url.Parse(resource) + + queryies := u.Query() + query := url.Values{} + + sort.Strings(resourceQSWhitelist) + for _, q := range resourceQSWhitelist { + val := queryies.Get(q) + if val != "" { + query.Add(q, val) + } + } + + u.RawQuery = query.Encode() + + return u.String() +} diff --git a/backend/aliyun_test.go b/backend/aliyun_test.go new file mode 100644 index 0000000..5c03af3 --- /dev/null +++ b/backend/aliyun_test.go @@ -0,0 +1,27 @@ +package backend + +import ( + "net/http" + "os" + "testing" +) + +func Test_alicloudsave(t *testing.T) { + + var gopath string + gopath = os.Getenv("GOPATH") + if gopath == "" { + t.Error("read env GOPATH fail") + return + } + file := gopath + "/src/github.com/containerops/dockyard/backend/aliyun.go" + url, err := alicloudsave(file) + if err != nil { + t.Error(err) + return + } + _, err = http.Get(url) + if err != nil { + t.Error(err) + } +} diff --git a/backend/amazons3cloud.go b/backend/amazons3cloud.go new file mode 100644 index 0000000..fdae5b2 --- /dev/null +++ b/backend/amazons3cloud.go @@ -0,0 +1,151 @@ +package backend + +import ( + "crypto/hmac" + "crypto/sha1" + "encoding/base64" + "errors" + "fmt" + "net/http" + "os" + "sort" + "strings" + "time" + + "github.com/astaxie/beego/config" +) + +var ( + g_amazons3Endpoint string + g_amazons3Bucket string + g_amazons3AccessKeyID string + g_amazons3AccessKeySecret string +) + +func init() { + + gopath := os.Getenv("GOPATH") + if gopath == "" { + fmt.Errorf("read env GOPATH fail") + os.Exit(1) + } + err := amazons3getconfig(gopath + "/src/github.com/containerops/dockyard/conf/runtime.conf") + if err != nil { + fmt.Errorf("read config file conf/runtime.conf fail:" + err.Error()) + os.Exit(1) + } + g_injector.Bind("amazons3cloudsave", amazons3cloudsave) +} + +func amazons3getconfig(conffile string) (err error) { + conf, err := config.NewConfig("ini", conffile) + if err != nil { + return err + } + + g_amazons3Endpoint = conf.String("amazons3cloud::endpoint") + if g_amazons3Endpoint == "" { + return errors.New("read config file's endpoint failed!") + } + + g_amazons3Bucket = conf.String("amazons3cloud::bucket") + if g_amazons3Bucket == "" { + return errors.New("read config file's bucket failed!") + } + + g_amazons3AccessKeyID = conf.String("amazons3cloud::accessKeyID") + if g_amazons3AccessKeyID == "" { + return errors.New("read config file's accessKeyID failed!") + } + + g_amazons3AccessKeySecret = conf.String("amazons3cloud::accessKeysecret") + if g_amazons3AccessKeySecret == "" { + return errors.New("read config file's accessKeysecret failed!") + } + return nil +} + +func amazons3cloudsave(file string) (url string, err error) { + + var key string + //get the filename from the file , eg,get "1.txt" from /home/liugenping/1.txt + for _, key = range strings.Split(file, "/") { + + } + + fin, err := os.Open(file) + if err != nil { + return "", err + } + defer fin.Close() + var fi os.FileInfo + fi, err = fin.Stat() + if err != nil { + return "", err + } + filesize := fi.Size() + + requstUrl := "http://" + g_amazons3Bucket + "." + g_amazons3Endpoint + "/" + key + r, _ := http.NewRequest("PUT", requstUrl, fin) + r.ContentLength = int64(filesize) + r.Header.Set("Date", time.Now().UTC().Format(http.TimeFormat)) + r.Header.Set("X-Amz-Acl", "public-read") + + amazons3Sign(r, key, g_amazons3AccessKeyID, g_amazons3AccessKeySecret) + _, err = http.DefaultClient.Do(r) + if err != nil { + return "", err + } + + url = "http://" + g_amazons3Endpoint + "/" + g_amazons3Bucket + "/" + key + return url, nil + +} + +func amazons3Sign(r *http.Request, key string, accessKeyId string, accessKeySecret string) { + + plainText := amazons3cloudMakePlainText(r, key) + h := hmac.New(sha1.New, []byte(accessKeySecret)) + h.Write([]byte(plainText)) + sign := base64.StdEncoding.EncodeToString(h.Sum(nil)) + r.Header.Set("Authorization", "AWS "+accessKeyId+":"+sign) +} + +func amazons3cloudMakePlainText(r *http.Request, key string) (plainText string) { + + plainText = r.Method + "\n" + plainText += r.Header.Get("content-md5") + "\n" + plainText += r.Header.Get("content-type") + "\n" + if _, ok := r.Header["X-Amz-Date"]; !ok { + plainText += r.Header.Get("date") + "\n" + } + + amzHeader := getAmzHeaders(r) + if amzHeader != "" { + plainText += amzHeader + "\n" + } + + plainText += "/" + g_amazons3Bucket + "/" + key + return +} + +func getAmzHeaders(r *http.Request) (amzHeader string) { + var keys []string + for k, _ := range r.Header { + if strings.HasPrefix(strings.ToLower(k), "x-amz-") { + keys = append(keys, k) + } + } + + sort.Strings(keys) + var a []string + for _, k := range keys { + v := r.Header[k] + a = append(a, strings.ToLower(k)+":"+strings.Join(v, ",")) + } + for _, h := range a { + + return h + } + return "" +} diff --git a/backend/amazons3cloud_test.go b/backend/amazons3cloud_test.go new file mode 100644 index 0000000..f62c865 --- /dev/null +++ b/backend/amazons3cloud_test.go @@ -0,0 +1,23 @@ +package backend + +import ( + "os" + "testing" +) + +func Test_amazons3cloudsave(t *testing.T) { + + var gopath string + gopath = os.Getenv("GOPATH") + if gopath == "" { + t.Error("read env GOPATH fail") + return + } + file := gopath + "/src/github.com/containerops/dockyard/backend/amazons3cloud_test.go" + url, err := amazons3cloudsave(file) + if err != nil { + t.Error(err) + return + } + t.Log(url) +} diff --git a/backend/backend.go b/backend/backend.go index 2480943..06ed4a6 100644 --- a/backend/backend.go +++ b/backend/backend.go @@ -1 +1,160 @@ package backend + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "reflect" + "sync" + + "github.com/astaxie/beego/config" +) + +type In struct { + Key string `json:"key"` + Uploadfile string `json:"uploadfile"` +} + +type OutSuccess struct { + Key string `json:"key"` + Uploadfile string `json:"uploadfile"` + Downloadurl string `json:"downloadurl"` +} + +type ShareChannel struct { + In chan string + OutSuccess chan string + OutFailure chan string + ExitFlag bool + waitGroup *sync.WaitGroup +} + +var channelSize = 200 + +//reflect struct +var g_injector = NewInjector(50) +var g_driver string + +func init() { + gopath := os.Getenv("GOPATH") + if gopath == "" { + fmt.Println("read env GOPATH fail") + os.Exit(1) + } + conf, err := config.NewConfig("ini", gopath+"/src/github.com/containerops/dockyard/conf/runtime.conf") + if err != nil { + fmt.Println(fmt.Errorf("read conf/runtime.conf fail: %v", err).Error()) + os.Exit(1) + } + + g_driver = conf.String("backend::driver") + if g_driver == "" { + fmt.Println("read config file's dirver failed!") + os.Exit(1) + } +} + +func NewShareChannel() *ShareChannel { + return &ShareChannel{make(chan string, channelSize), + make(chan string, channelSize), + make(chan string, channelSize), false, new(sync.WaitGroup)} +} +func (sc *ShareChannel) PutIn(jsonObj string) { + sc.In <- jsonObj +} + +func (sc *ShareChannel) getIn() (jsonObj string) { + return <-sc.In +} + +func (sc *ShareChannel) putOutSuccess(jsonObj string) { + sc.OutSuccess <- jsonObj +} + +func (sc *ShareChannel) GutOutSuccess() (jsonObj string) { + return <-sc.OutSuccess +} + +func (sc *ShareChannel) putOutFailure(jsonObj string) { + sc.OutFailure <- jsonObj +} + +func (sc *ShareChannel) GutOutFailure() (jsonObj string) { + return <-sc.OutFailure +} + +func (sc *ShareChannel) Open() { + sc.waitGroup.Add(1) + go func() { + for !sc.ExitFlag { + obj := sc.getIn() + outJson, err := Save(obj) + if nil != err { + //fmt.Println(err) + sc.putOutFailure(obj) + } else { + sc.putOutSuccess(outJson) + } + } + sc.waitGroup.Done() + }() +} + +func (sc *ShareChannel) Close() { + sc.ExitFlag = true + sc.waitGroup.Wait() + + for f := true; f; { + select { + case obj := <-sc.In: + outJson, err := Save(obj) + if nil != err { + //fmt.Println(err) + sc.putOutFailure(obj) + } else { + sc.putOutSuccess(outJson) + } + default: + f = false + } + } + + close(sc.In) + //close(sc.OutSuccess) + //close(sc.OutFail) +} + +func Save(jsonIn string) (jsonOut string, err error) { + + var url string + var rt []reflect.Value + in := In{} + var jsonTempOut []byte + + err = json.Unmarshal([]byte(jsonIn), &in) + if nil != err { + return "", err + } + + rt, err = g_injector.Call(g_driver+"save", in.Uploadfile) + if nil != err { + return "", err + } + + if !rt[1].IsNil() { + errstr := rt[1].MethodByName("Error").Call(nil)[0].String() + if errstr != "" { + return "", errors.New(errstr) + } + + } + url = rt[0].String() + + outSuccess := &OutSuccess{Key: in.Key, Uploadfile: in.Uploadfile, Downloadurl: url} + jsonTempOut, err = json.Marshal(outSuccess) + if err != nil { + return "", err + } + return string(jsonTempOut), nil +} diff --git a/backend/backend_test.go b/backend/backend_test.go new file mode 100644 index 0000000..4c4e3db --- /dev/null +++ b/backend/backend_test.go @@ -0,0 +1,44 @@ +package backend + +import ( + "encoding/json" + "os" + "testing" +) + +func Test_backend_put(t *testing.T) { + + gopath := os.Getenv("GOPATH") + if gopath == "" { + t.Error("read env GOPATH fail") + return + } + file := gopath + "/src/github.com/containerops/dockyard/backend/backend_test.go" + + in := &In{Key: "asdf8976485r32r613879rwegfuiwet739ruwef", Uploadfile: file} + jsonIn, err := json.Marshal(in) + if err != nil { + t.Error(err) + return + } + + sc := NewShareChannel() + sc.Open() + + for i := 0; i < 2; i++ { + sc.PutIn(string(jsonIn)) + } + sc.Close() + + for f := true; f; { + select { + case obj := <-sc.OutSuccess: + t.Log(obj) + case obj := <-sc.OutFailure: + t.Error(obj) + default: + f = false + } + } + +} diff --git a/backend/googlecloud.go b/backend/googlecloud.go new file mode 100644 index 0000000..8dff672 --- /dev/null +++ b/backend/googlecloud.go @@ -0,0 +1,109 @@ +package backend + +import ( + "fmt" + "io/ioutil" + "log" + "os" + "strings" + + "github.com/astaxie/beego/config" + "github.com/google/google-api-go-client/storage/v1" + "golang.org/x/oauth2" + "golang.org/x/oauth2/google" + "golang.org/x/oauth2/jwt" + //"github.com/google/google-api-go-client/storage/v1" +) + +var ( + projectID string + bucket string + scope string + privateKey []byte + clientEmail string +) + +func init() { + + gopath := os.Getenv("GOPATH") + if gopath == "" { + fmt.Errorf("read env GOPATH fail") + os.Exit(1) + } + //Reading config file named conf/runtime.conf for backend + conf, err := config.NewConfig("ini", gopath+"/src/github.com/containerops/dockyard/conf/runtime.conf") + if err != nil { + log.Fatalf("GCS reading conf/runtime.conf err %v", err) + } + + if projectID = conf.String("Googlecloud::projectid"); projectID == "" { + log.Fatalf("GCS reading conf/runtime.conf, GCS get projectID err, is nil") + } + + //Get config var for jsonKeyFile, bucketName, projectID, which should be used later in oauth and get obj + if bucket = conf.String("Googlecloud::bucket"); bucket == "" { + log.Fatalf("GCS reading conf/runtime.conf, GCS get bucket err, is nil") + } + + if scope = conf.String("Googlecloud::scope"); scope == "" { + log.Fatalf("GCS reading conf/runtime.conf, GCS get privateKey err, is nil") + } + + var privateKeyFile string + if privateKeyFile = conf.String("Googlecloud::privatekey"); privateKeyFile == "" { + log.Fatalf("GCS reading conf/runtime.conf, GCS get privateKey err, is nil") + } + privateKey, err = ioutil.ReadFile(gopath + "/src/github.com/containerops/dockyard/conf/" + privateKeyFile) + if err != nil { + log.Fatal(err) + } + + if clientEmail = conf.String("Googlecloud::clientemail"); clientEmail == "" { + log.Fatalf("GCS reading conf/runtime.conf, GCS get clientEmail err, is nil") + } + + g_injector.Bind("googlecloudsave", googlecloudsave) +} + +func googlecloudsave(file string) (url string, err error) { + + s := []string{scope} + + conf := jwt.Config{ + Email: clientEmail, + PrivateKey: privateKey, + Scopes: s, + TokenURL: google.JWTTokenURL, + } + + //new storage service and token, we dont need context here + client := conf.Client(oauth2.NoContext) + gcsToken, err := conf.TokenSource(oauth2.NoContext).Token() + service, err := storage.New(client) + if err != nil { + log.Fatalf("GCS unable to create storage service: %v", err) + } + + //Split filename as a objectName + var objectName string + for _, objectName = range strings.Split(file, "/") { + } + object := &storage.Object{Name: objectName} + + // Insert an object into a bucket. + fileDes, err := os.Open(file) + if err != nil { + log.Fatalf("Error opening %q: %v", file, err) + } + objs, err := service.Objects.Insert(bucket, object).Media(fileDes).Do() + if err != nil { + log.Fatalf("GCS Objects.Insert failed: %v", err) + } + retUrl := objs.MediaLink + "&access_token=" + gcsToken.AccessToken + + if err != nil { + return "", err + } else { + return retUrl, nil + } +} diff --git a/backend/googlecloud_test.go b/backend/googlecloud_test.go new file mode 100644 index 0000000..2aeffb3 --- /dev/null +++ b/backend/googlecloud_test.go @@ -0,0 +1,78 @@ +package backend + +import ( + "io" + "io/ioutil" + "net/http" + "os" + "strings" + "testing" +) + +var ( + upFileName string = "/tmp/gcs_test.txt" + downFileName string = "/tmp/new_gcs_test.txt" + fileContent string = "Just for test gcs.\n Congratulations! U are sucess." + //retUrl_tmp string +) + +func newTestFile(t *testing.T) (f *os.File, err error) { + file, err := os.Create(upFileName) + if err != nil { + t.Error(err) + } + + ret, err := file.WriteString(fileContent) + if err != nil { + t.Error(err) + t.Fatalf("GCS_TEST Write String ret = %v", ret) + } + if err != nil { + return nil, err + } else { + return file, nil + } +} + +// Unit Test for gcs +func TestGcssave(t *testing.T) { + file, err := newTestFile(t) + if err != nil { + t.Error(err) + } + + retUrl, err := googlecloudsave(upFileName) + if err != nil { + t.Error(err) + } + //retUrl_tmp = retUrl + + resp, err := http.Get(retUrl) + if err != nil { + t.Error(err) + } + defer resp.Body.Close() + + // Open file for writing + nFile, err := os.Create(downFileName) + if err != nil { + t.Error(err) + } + + // Use io.Copy to copy a file from URL to a locald disk + _, err = io.Copy(nFile, resp.Body) + if err != nil { + t.Error(err) + } + + buf, err := ioutil.ReadFile(downFileName) + if err != nil { + t.Error(err) + } + file.Close() + + isEqual := strings.EqualFold(fileContent, string(buf)) + if !isEqual { + t.Fatalf("Testing fail, content of uploadFile is not the same as the content of downloadFile") + } +} diff --git a/backend/injector.go b/backend/injector.go new file mode 100644 index 0000000..0948872 --- /dev/null +++ b/backend/injector.go @@ -0,0 +1,35 @@ +package backend + +import ( + "errors" + "reflect" +) + +type Injector map[string]reflect.Value + +func NewInjector(size int) Injector { + return make(Injector, size) + +} + +func (inj Injector) Bind(name string, fn interface{}) { + v := reflect.ValueOf(fn) + inj[name] = v +} + +func (inj Injector) Call(name string, params ...interface{}) (result []reflect.Value, err error) { + if _, ok := inj[name]; !ok { + err = errors.New(name + " does not exist.") + return + } + if len(params) != inj[name].Type().NumIn() { + err = errors.New("The number of params is not adapted.") + return + } + in := make([]reflect.Value, len(params)) + for k, param := range params { + in[k] = reflect.ValueOf(param) + } + result = inj[name].Call(in) + return +} diff --git a/backend/qiniu.go b/backend/qiniu.go new file mode 100644 index 0000000..aaa3c22 --- /dev/null +++ b/backend/qiniu.go @@ -0,0 +1,92 @@ +package backend + +import ( + "errors" + "fmt" + "os" + "strings" + + "github.com/astaxie/beego/config" + "github.com/qiniu/api/conf" + "github.com/qiniu/api/io" + "github.com/qiniu/api/rs" +) + +var ( + g_qiniuEndpoint string + g_qiniuBucket string + g_qiniuAccessKeyID string + g_qiniuAccessKeySecret string +) + +func init() { + + gopath := os.Getenv("GOPATH") + if gopath == "" { + fmt.Errorf("read env GOPATH fail") + os.Exit(1) + } + err := qiniugetconfig(gopath + "/src/github.com/containerops/dockyard/conf/runtime.conf") + if err != nil { + fmt.Errorf("read config file conf/runtime.conf fail:" + err.Error()) + os.Exit(0) + } + + conf.ACCESS_KEY = g_qiniuAccessKeyID + conf.SECRET_KEY = g_qiniuAccessKeySecret + + g_injector.Bind("qiniucloudsave", qiniucloudsave) +} + +func qiniugetconfig(conffile string) (err error) { + var conf config.ConfigContainer + conf, err = config.NewConfig("ini", conffile) + if err != nil { + return err + } + + g_qiniuEndpoint = conf.String("qiniucloud::endpoint") + if g_qiniuEndpoint == "" { + return errors.New("read config file's endpoint failed!") + } + + g_qiniuBucket = conf.String("qiniucloud::bucket") + if g_qiniuBucket == "" { + return errors.New("read config file's bucket failed!") + } + + g_qiniuAccessKeyID = conf.String("qiniucloud::accessKeyID") + if g_qiniuAccessKeyID == "" { + return errors.New("read config file's accessKeyID failed!") + } + + g_qiniuAccessKeySecret = conf.String("qiniucloud::accessKeysecret") + if g_qiniuAccessKeySecret == "" { + return errors.New("read config file's accessKeysecret failed!") + } + return nil +} + +func qiniucloudsave(file string) (url string, err error) { + + var key string + //get the filename from the file , eg,get "1.txt" from /home/liugenping/1.txt + for _, key = range strings.Split(file, "/") { + + } + + url = "http://" + g_qiniuEndpoint + "/" + key + + putPolicy := rs.PutPolicy{Scope: g_qiniuBucket} + uptoken := putPolicy.Token(nil) + + var ret io.PutRet + var extra = &io.PutExtra{} + err = io.PutFile(nil, &ret, uptoken, key, file, extra) + if err != nil { + return "", err + } else { + return url, nil + } + +} diff --git a/backend/qiniu_test.go b/backend/qiniu_test.go new file mode 100644 index 0000000..6d9b60a --- /dev/null +++ b/backend/qiniu_test.go @@ -0,0 +1,40 @@ +package backend + +import ( + "net/http" + "os" + "testing" + + "github.com/qiniu/api/conf" +) + +/* +func expect(t *testing.T, a interface{}, b interface{}) { + if a != b { + t.Errorf("Expected %v (type %v) - Got %v (type %v)", b, reflect.TypeOf(b), a, reflect.TypeOf(a)) + } +} +*/ +func Test_qiniucloudsave(t *testing.T) { + + var gopath string + gopath = os.Getenv("GOPATH") + if gopath == "" { + t.Error("read env GOPATH fail") + return + } + + conf.ACCESS_KEY = g_qiniuAccessKeyID + conf.SECRET_KEY = g_qiniuAccessKeySecret + + file := gopath + "/src/github.com/containerops/dockyard/backend/qiniu.go" + url, err := qiniucloudsave(file) + if err != nil { + t.Error(err) + return + } + _, err = http.Get(url) + if err != nil { + t.Error(err) + } +} diff --git a/backend/tencentcloud.go b/backend/tencentcloud.go new file mode 100644 index 0000000..3ffcfa4 --- /dev/null +++ b/backend/tencentcloud.go @@ -0,0 +1,192 @@ +package backend + +import ( + "crypto/hmac" + "crypto/sha1" + "encoding/base64" + "errors" + "fmt" + "net/http" + "net/url" + "os" + "sort" + "strings" + "time" + + "github.com/astaxie/beego/config" +) + +var ( + g_tencentEndpoint string + g_tencentAccessID string + g_tencentBucket string + g_tencentAccessKeyID string + g_tencentAccessKeySecret string +) + +func init() { + + gopath := os.Getenv("GOPATH") + if gopath == "" { + fmt.Errorf("read env GOPATH fail") + os.Exit(1) + } + err := tencentgetconfig(gopath + "/src/github.com/containerops/dockyard/conf/runtime.conf") + if err != nil { + fmt.Errorf("read config file conf/runtime.conf fail:" + err.Error()) + os.Exit(0) + } + + // 用户定义变量 + /* + g_tencentEndpoint = "cosapi.myqcloud.com" + + g_tencentAccessID = "11000464" + g_tencentBucket = "test" + g_tencentAccessKeySecret = "4ceCa4wNP10c40QPPDgXdfx5MhvuCBWG" + g_tencentAccessKeyID = "AKIDBxM1SkbDzdEtLED1KeQhW8HjW5qRu2R5" + */ + + g_injector.Bind("tencentcloudsave", tencentcloudsave) + +} + +func tencentgetconfig(conffile string) (err error) { + var conf config.ConfigContainer + conf, err = config.NewConfig("ini", conffile) + if err != nil { + return err + } + + g_tencentEndpoint = conf.String("tencentcloud::endpoint") + if g_tencentEndpoint == "" { + return errors.New("read config file's endpoint failed!") + } + + g_tencentAccessID = conf.String("tencentcloud::accessID") + if g_tencentAccessID == "" { + return errors.New("read config file's accessID failed!") + } + + g_tencentBucket = conf.String("tencentcloud::bucket") + if g_tencentBucket == "" { + return errors.New("read config file's bucket failed!") + } + + g_tencentAccessKeyID = conf.String("tencentcloud::accessKeyID") + if g_tencentAccessKeyID == "" { + return errors.New("read config file's accessKeyID failed!") + } + + g_tencentAccessKeySecret = conf.String("tencentcloud::accessKeysecret") + if g_tencentAccessKeySecret == "" { + return errors.New("read config file's accessKeysecret failed!") + } + return nil +} + +func makePlainText(api string, params map[string]interface{}) (plainText string) { + // sort + keys := make([]string, 0, len(params)) + for k, _ := range params { + keys = append(keys, k) + } + sort.Strings(keys) + + var plainParms string + for i := range keys { + k := keys[i] + plainParms += "&" + fmt.Sprintf("%v", k) + "=" + fmt.Sprintf("%v", params[k]) + } + if api != "" { + plainText = "/" + api + "&" + plainParms[1:] + } else { + plainText = plainParms[1:] + } + + plainText = url.QueryEscape(plainText) + + return plainText +} + +func sign(plainText string, secretKey string) (sign string) { + hmacObj := hmac.New(sha1.New, []byte(secretKey)) + hmacObj.Write([]byte(plainText)) + sign = base64.StdEncoding.EncodeToString(hmacObj.Sum(nil)) + return +} + +func tencentcloudsave(file string) (url string, err error) { + + var key string + //get the filename from the file , eg,get "1.txt" from /home/liugenping/1.txt + for _, key = range strings.Split(file, "/") { + + } + + fin, err := os.Open(file) + if err != nil { + return "", err + } + defer fin.Close() + var fi os.FileInfo + fi, err = fin.Stat() + if err != nil { + return "", err + } + filesize := fi.Size() + + params := map[string]interface{}{} + fileName := key + + ////// + api := "api/cos_upload" + params["accessId"] = g_tencentAccessID + params["bucketId"] = g_tencentBucket + params["secretId"] = g_tencentAccessKeyID + params["cosFile"] = fileName + params["path"] = "/" + time := fmt.Sprintf("%v", time.Now().Unix()) + params["time"] = time + ///// + + uploadPlainText := makePlainText(api, params) + + sign := sign(uploadPlainText, g_tencentAccessKeySecret) + + var requstUrl string + requstUrl = "http://" + g_tencentEndpoint + "/" + api + "?bucketId=" + g_tencentBucket + "&cosFile=" + fileName + "&path=%2F" + "&accessId=" + g_tencentAccessID + "&secretId=" + g_tencentAccessKeyID + "&time=" + time + "&sign=" + sign + + req, _ := http.NewRequest("POST", requstUrl, fin) + req.Body = fin + req.ContentLength = filesize + client := &http.Client{} + _, err = client.Do(req) + if err != nil { + return "", err + } + + downloadUrl := tencentGetDownloadUrl(fileName) + + return downloadUrl, nil +} + +func tencentGetDownloadUrl(fileName string) (downloadUrl string) { + + params := map[string]interface{}{} + + ////// + params["accessId"] = g_tencentAccessID + params["bucketId"] = g_tencentBucket + params["secretId"] = g_tencentAccessKeyID + params["path"] = "/" + fileName + time := fmt.Sprintf("%v", time.Now().Unix()) + params["time"] = time + + downloadPlainText := makePlainText("", params) + sign := sign(downloadPlainText, g_tencentAccessKeySecret) + url := "cos.myqcloud.com/" + g_tencentAccessID + "/" + g_tencentBucket + "/" + fileName + "?" + "secretId=" + g_tencentAccessKeyID + "&time=" + time + url += "&sign=" + sign + return url + +} diff --git a/backend/tencentcloud_test.go b/backend/tencentcloud_test.go new file mode 100644 index 0000000..e963776 --- /dev/null +++ b/backend/tencentcloud_test.go @@ -0,0 +1,23 @@ +package backend + +import ( + "os" + "testing" +) + +func Test_tencentcloudsave(t *testing.T) { + + var gopath string + gopath = os.Getenv("GOPATH") + if gopath == "" { + t.Error("read env GOPATH fail") + return + } + file := gopath + "/src/github.com/containerops/dockyard/backend/tencentcloud_test.go" + url, err := tencentcloudsave(file) + if err != nil { + t.Error(err) + return + } + t.Log(url) +} diff --git a/backend/upyun.go b/backend/upyun.go new file mode 100644 index 0000000..e919e0f --- /dev/null +++ b/backend/upyun.go @@ -0,0 +1,102 @@ +package backend + +import ( + "errors" + "fmt" + "os" + "strings" + + "github.com/astaxie/beego/config" + "github.com/upyun/go-sdk/upyun" +) + +var ( + g_upEndpoint string + g_upBucket string + g_upUser string + g_upPasswd string +) + +func init() { + + gopath := os.Getenv("GOPATH") + if gopath == "" { + fmt.Errorf("read env GOPATH fail") + os.Exit(1) + } + err := upgetconfig(gopath + "/src/github.com/containerops/dockyard/conf/runtime.conf") + if err != nil { + fmt.Errorf("read config file conf/runtime.conf fail:" + err.Error()) + os.Exit(1) + } + + g_injector.Bind("upcloudsave", upcloudsave) +} + +func upgetconfig(conffile string) (err error) { + var conf config.ConfigContainer + conf, err = config.NewConfig("ini", conffile) + if err != nil { + return err + } + + g_upEndpoint = conf.String("upCloud::endpoint") + if g_upEndpoint == "" { + return errors.New("read config file's endpoint failed!") + } + + g_upBucket = conf.String("upCloud::bucket") + if g_upBucket == "" { + return errors.New("read config file's bucket failed!") + } + + g_upUser = conf.String("upCloud::user") + if g_upUser == "" { + return errors.New("read config file's user failed!") + } + + g_upPasswd = conf.String("upCloud::passwd") + if g_upPasswd == "" { + return errors.New("read config file's passwd failed!") + + } + return nil +} + +func upcloudsave(file string) (url string, err error) { + + var key string + //get the filename from the file , eg,get "1.txt" from "/home/liugenping/1.txt" + for _, key = range strings.Split(file, "/") { + + } + opath := "/" + g_upBucket + "/" + key + url = "http://" + g_upEndpoint + opath + + var u *upyun.UpYun + u = upyun.NewUpYun(g_upBucket, g_upUser, g_upPasswd) + if nil == u { + return "", errors.New("UpYun.NewUpYun Fail") + } + + /* Endpoint list: + Auto = "v0.api.upyun.com" + Telecom = "v1.api.upyun.com" + Cnc = "v2.api.upyun.com" + Ctt = "v3.api.upyun.com" + */ + u.SetEndpoint(g_upEndpoint) + + fin, err := os.Open(file) + if err != nil { + return "", err + } + defer fin.Close() + + _, err = u.Put(key, fin, false, "") + if err != nil { + return "", err + } + return url, nil + +} diff --git a/backend/upyun_test.go b/backend/upyun_test.go new file mode 100644 index 0000000..08886fc --- /dev/null +++ b/backend/upyun_test.go @@ -0,0 +1,28 @@ +package backend + +import ( + "net/http" + "os" + "testing" +) + +func Test_upcloudsave(t *testing.T) { + + var gopath string + gopath = os.Getenv("GOPATH") + if gopath == "" { + t.Error("read env GOPATH fail") + return + } + + file := gopath + "/src/github.com/containerops/dockyard/backend/upyun.go" + url, err := upcloudsave(file) + if err != nil { + t.Error(err) + return + } + _, err = http.Get(url) + if err != nil { + t.Error(err) + } +} diff --git a/cmd/backend.go b/cmd/backend.go deleted file mode 100644 index 6847e55..0000000 --- a/cmd/backend.go +++ /dev/null @@ -1,17 +0,0 @@ -package cmd - -import ( - "github.com/codegangsta/cli" -) - -var CmdBackend = cli.Command{ - Name: "backend", - Usage: "处理 dockyard 的后端存储服务", - Description: "dockyard 支持使用一个或多个存储服务, 国内服务支持七牛、又拍、阿里云和腾讯云,国外服务支持亚马逊和谷歌云服务。", - Action: runBackend, - Flags: []cli.Flag{}, -} - -func runBackend(c *cli.Context) { - -} diff --git a/cmd/database.go b/cmd/database.go deleted file mode 100644 index 30e5649..0000000 --- a/cmd/database.go +++ /dev/null @@ -1,15 +0,0 @@ -package cmd - -import ( - "github.com/codegangsta/cli" -) - -var CmdDatabase = cli.Command{ - Name: "db", - Usage: "处理 dockyard 程序的数据库创建、备份和恢复等数据库维护", - Description: "dockyard 使用 RebornDB 和 Redis 处理持久化数据", - Action: runDatabase, - Flags: []cli.Flag{}, -} - -func runDatabase(c *cli.Context) {} diff --git a/cmd/web.go b/cmd/web.go index 2f612d2..2552db4 100644 --- a/cmd/web.go +++ b/cmd/web.go @@ -7,30 +7,39 @@ import ( "net/http" "os" - "github.com/codegangsta/cli" - "github.com/Unknwon/macaron" + "github.com/codegangsta/cli" - "github.com/containerops/dockyard/setting" - "github.com/containerops/dockyard/utils" "github.com/containerops/dockyard/web" + "github.com/containerops/wrench/setting" + "github.com/containerops/wrench/utils" ) var CmdWeb = cli.Command{ Name: "web", - Usage: "启动 dockyard 的 Web 服务", - Description: "dockyard 提供 Docker 镜像仓库存储服务。", + Usage: "start dockyard web service", + Description: "dockyard is the module of handler docker and rkt image.", Action: runWeb, Flags: []cli.Flag{ cli.StringFlag{ Name: "address", Value: "0.0.0.0", - Usage: "Web 服务监听的 IP,默认 0.0.0.0;如果使用 Unix Socket 模式是 sock 文件的路径", + Usage: "web service listen ip, default is 0.0.0.0; if listen with Unix Socket, the value is sock file path.", }, cli.IntFlag{ Name: "port", Value: 80, - Usage: "Web 服务监听的端口,默认 80", + Usage: "web service listen at port 80; if run with https will be 443.", + }, + cli.StringFlag{ + Name: "backend", + Value: "", + Usage: "Start object storage service in the backend of service.", + }, + cli.StringFlag{ + Name: "driver", + Value: "", + Usage: "Backend object storage driver, like s3, qiniu...", }, }, } @@ -45,30 +54,28 @@ func runWeb(c *cli.Context) { case "http": listenaddr := fmt.Sprintf("%s:%d", c.String("address"), c.Int("port")) if err := http.ListenAndServe(listenaddr, m); err != nil { - fmt.Printf("启动 dockyard 的 HTTP 服务错误: %v", err) + fmt.Printf("start dockyard http service error: %v", err.Error()) } break case "https": - //HTTPS 强制使用 443 端口 listenaddr := fmt.Sprintf("%s:443", c.String("address")) server := &http.Server{Addr: listenaddr, TLSConfig: &tls.Config{MinVersion: tls.VersionTLS10}, Handler: m} if err := server.ListenAndServeTLS(setting.HttpsCertFile, setting.HttpsKeyFile); err != nil { - fmt.Printf("启动 dockyard 的 HTTPS 服务错误: %v", err) + fmt.Printf("start dockyard https service error: %v", err.Error()) } break case "unix": listenaddr := fmt.Sprintf("%s", c.String("address")) - //如果存在 Unix Socket 文件就删除 - if utils.Exist(listenaddr) { + if utils.IsFileExist(listenaddr) { os.Remove(listenaddr) } if listener, err := net.Listen("unix", listenaddr); err != nil { - fmt.Printf("启动 dockyard 的 Unix Socket 监听错误: %v", err) + fmt.Printf("start dockyard unix socket error: %v", err.Error()) } else { server := &http.Server{Handler: m} if err := server.Serve(listener); err != nil { - fmt.Printf("启动 dockyard 的 Unix Socket 监听错误: %v", err) + fmt.Printf("start dockyard unix socket error: %v", err.Error()) } } break diff --git a/conf/dockyard.conf b/conf/containerops.conf similarity index 71% rename from conf/dockyard.conf rename to conf/containerops.conf index cb91b24..725bafc 100644 --- a/conf/dockyard.conf +++ b/conf/containerops.conf @@ -1,7 +1,7 @@ appname = dockyard usage = Rocket & Docker Repository Pubic / Private Service version = 0.0.1 -author = Meaglith Ma -email = genedna@gmail.com +author = Marvell Ma +email = bin.ma@huawei.com include runtime.conf diff --git a/static/.gitkeep b/external/.gitkeep similarity index 100% rename from static/.gitkeep rename to external/.gitkeep diff --git a/handler/blob.go b/handler/blob.go new file mode 100644 index 0000000..ccb6350 --- /dev/null +++ b/handler/blob.go @@ -0,0 +1,182 @@ +package handler + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "os" + "strings" + "time" + + "github.com/Unknwon/macaron" + "github.com/astaxie/beego/logs" + "github.com/satori/go.uuid" + + "github.com/containerops/dockyard/models" + "github.com/containerops/dockyard/module" + "github.com/containerops/wrench/setting" + "github.com/containerops/wrench/utils" +) + +func HeadBlobsV2Handler(ctx *macaron.Context, log *logs.BeeLogger) (int, []byte) { + digest := ctx.Params(":digest") + tarsum := strings.Split(digest, ":")[1] + + i := new(models.Image) + if has, _ := i.HasTarsum(tarsum); has == false { + log.Info("[REGISTRY API V2] Tarsum not found: %v", tarsum) + + result, _ := json.Marshal(map[string]string{"message": "Tarsum not found"}) + return http.StatusNotFound, result + } + + ctx.Resp.Header().Set("Content-Type", "application/x-gzip") + ctx.Resp.Header().Set("Docker-Content-Digest", digest) + ctx.Resp.Header().Set("Content-Length", fmt.Sprint(i.Size)) + + result, _ := json.Marshal(map[string]string{}) + return http.StatusOK, result +} + +func PostBlobsV2Handler(ctx *macaron.Context, log *logs.BeeLogger) (int, []byte) { + namespace := ctx.Params(":namespace") + repository := ctx.Params(":repository") + + uuid := utils.MD5(uuid.NewV4().String()) + state := utils.MD5(fmt.Sprintf("%s/%s/%s", namespace, repository, time.Now().UnixNano()/int64(time.Millisecond))) + random := fmt.Sprintf("https://%s/v2/%s/%s/blobs/uploads/%s?_state=%s", + setting.Domains, + namespace, + repository, + uuid, + state) + + ctx.Resp.Header().Set("Docker-Upload-Uuid", uuid) + ctx.Resp.Header().Set("Location", random) + ctx.Resp.Header().Set("Range", "0-0") + + result, _ := json.Marshal(map[string]string{}) + return http.StatusAccepted, result +} + +func PatchBlobsV2Handler(ctx *macaron.Context, log *logs.BeeLogger) (int, []byte) { + namespace := ctx.Params(":namespace") + repository := ctx.Params(":repository") + + desc := ctx.Params(":uuid") + uuid := strings.Split(desc, "?")[0] + + imagePathTmp := fmt.Sprintf("%v/%v", setting.ImagePath, uuid) + layerfileTmp := fmt.Sprintf("%v/%v/layer", setting.ImagePath, uuid) + + //saving specific tarsum every times is in order to split the same tarsum in HEAD handler + if !utils.IsDirExist(imagePathTmp) { + os.MkdirAll(imagePathTmp, os.ModePerm) + } + + if _, err := os.Stat(layerfileTmp); err == nil { + os.Remove(layerfileTmp) + } + + data, _ := ioutil.ReadAll(ctx.Req.Request.Body) + if err := ioutil.WriteFile(layerfileTmp, data, 0777); err != nil { + log.Error("[REGISTRY API V2] Save layerfile failed: %v", err.Error()) + + result, _ := json.Marshal(map[string]string{"message": "Save layerfile failed"}) + return http.StatusBadRequest, result + } + + state := utils.MD5(fmt.Sprintf("%s/%s/%s", namespace, repository, time.Now().UnixNano()/int64(time.Millisecond))) + random := fmt.Sprintf("https://%s/v2/%s/%s/blobs/uploads/%s?_state=%s", + setting.Domains, + namespace, + repository, + uuid, + state) + + ctx.Resp.Header().Set("Docker-Upload-Uuid", uuid) + ctx.Resp.Header().Set("Location", random) + ctx.Resp.Header().Set("Range", fmt.Sprintf("0-%v", len(data)-1)) + + result, _ := json.Marshal(map[string]string{}) + return http.StatusAccepted, result +} +func PutBlobsV2Handler(ctx *macaron.Context, log *logs.BeeLogger) (int, []byte) { + desc := ctx.Params(":uuid") + uuid := strings.Split(desc, "?")[0] + + digest := ctx.Query("digest") + tarsum := strings.Split(digest, ":")[1] + + imagePathTmp := fmt.Sprintf("%v/%v", setting.ImagePath, uuid) + layerfileTmp := fmt.Sprintf("%v/%v/layer", setting.ImagePath, uuid) + imagePath := fmt.Sprintf("%v/tarsum/%v", setting.ImagePath, tarsum) + layerfile := fmt.Sprintf("%v/tarsum/%v/layer", setting.ImagePath, tarsum) + layerlen, err := modules.CopyImgLayer(imagePathTmp, layerfileTmp, imagePath, layerfile, ctx.Req.Request.Body) + if err != nil { + log.Error("[REGISTRY API V2] Save layerfile failed: %v", err.Error()) + + result, _ := json.Marshal(map[string]string{"message": "Save layerfile failed"}) + return http.StatusBadRequest, result + } + + //saving specific tarsum every times is in order to split the same tarsum in HEAD handler + i := new(models.Image) + i.Path, i.Size = layerfile, int64(layerlen) + if err := i.PutTarsum(tarsum); err != nil { + log.Error("[REGISTRY API V2] Save tarsum failed: %v", err.Error()) + + result, _ := json.Marshal(map[string]string{"message": "Save tarsum failed"}) + return http.StatusBadRequest, result + } + + random := fmt.Sprintf("https://%s/v2/%s/%s/blobs/%s", + setting.Domains, + ctx.Params(":namespace"), + ctx.Params(":repository"), + digest) + + ctx.Resp.Header().Set("Docker-Content-Digest", digest) + ctx.Resp.Header().Set("Location", random) + + result, _ := json.Marshal(map[string]string{}) + return http.StatusCreated, result +} + +func GetBlobsV2Handler(ctx *macaron.Context, log *logs.BeeLogger) (int, []byte) { + digest := ctx.Params(":digest") + + tarsum := strings.Split(digest, ":")[1] + + i := new(models.Image) + has, _ := i.HasTarsum(tarsum) + if has == false { + log.Error("[REGISTRY API V2] Digest not found: %v", tarsum) + + result, _ := json.Marshal(map[string]string{"message": "Digest not found"}) + return http.StatusNotFound, result + } + + layerfile := i.Path + if _, err := os.Stat(layerfile); err != nil { + log.Error("[REGISTRY API V2] File path is invalid: %v", err.Error()) + + result, _ := json.Marshal(map[string]string{"message": "File path is invalid"}) + return http.StatusBadRequest, result + } + + file, err := ioutil.ReadFile(layerfile) + if err != nil { + log.Error("[REGISTRY API V2] Read file failed: %v", err.Error()) + + result, _ := json.Marshal(map[string]string{"message": "Read file failed"}) + return http.StatusBadRequest, result + } + + ctx.Resp.Header().Set("Content-Type", "application/x-gzip") + ctx.Resp.Header().Set("Docker-Content-Digest", digest) + ctx.Resp.Header().Set("Content-Length", fmt.Sprint(len(file))) + + return http.StatusOK, file +} diff --git a/handler/image.go b/handler/image.go new file mode 100644 index 0000000..864fab1 --- /dev/null +++ b/handler/image.go @@ -0,0 +1,191 @@ +package handler + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "os" + "strings" + + "github.com/Unknwon/macaron" + "github.com/astaxie/beego/logs" + + "github.com/containerops/dockyard/models" + "github.com/containerops/wrench/setting" + "github.com/containerops/wrench/utils" +) + +func GetImageAncestryV1Handler(ctx *macaron.Context, log *logs.BeeLogger) (int, []byte) { + imageId := ctx.Params(":imageId") + + i := new(models.Image) + if has, _, err := i.Has(imageId); err != nil { + log.Error("[REGISTRY API V1] Read Image Ancestry Error: %v", err.Error()) + + result, _ := json.Marshal(map[string]string{"message": "Read Image Ancestry Error"}) + return http.StatusBadRequest, result + } else if has == false { + log.Error("[REGISTRY API V1] Read Image None: %v", err.Error()) + + result, _ := json.Marshal(map[string]string{"message": "Read Image None"}) + return http.StatusNotFound, result + } + + ctx.Resp.Header().Set("Content-Length", fmt.Sprint(len(i.Ancestry))) + + return http.StatusOK, []byte(i.Ancestry) +} + +func GetImageJSONV1Handler(ctx *macaron.Context, log *logs.BeeLogger) (int, []byte) { + var jsonInfo string + var payload string + var err error + + imageId := ctx.Params(":imageId") + + i := new(models.Image) + if jsonInfo, err = i.GetJSON(imageId); err != nil { + log.Error("[REGISTRY API V1] Search Image JSON Error: %v", err.Error()) + + result, _ := json.Marshal(map[string]string{"message": "Search Image JSON Error"}) + return http.StatusNotFound, result + } + + if payload, err = i.GetChecksumPayload(imageId); err != nil { + log.Error("[REGISTRY API V1] Search Image Checksum Error: %v", err.Error()) + + result, _ := json.Marshal(map[string]string{"message": "Search Image Checksum Error"}) + return http.StatusNotFound, result + } + + ctx.Resp.Header().Set("X-Docker-Checksum-Payload", fmt.Sprintf("sha256:%v", payload)) + ctx.Resp.Header().Set("X-Docker-Size", fmt.Sprint(i.Size)) + ctx.Resp.Header().Set("Content-Length", fmt.Sprint(len(jsonInfo))) + + return http.StatusOK, []byte(jsonInfo) +} + +func GetImageLayerV1Handler(ctx *macaron.Context, log *logs.BeeLogger) (int, []byte) { + imageId := ctx.Params(":imageId") + + i := new(models.Image) + if has, _, err := i.Has(imageId); err != nil { + log.Error("[REGISTRY API V1] Read Image Layer File Status Error: %v", err.Error()) + + result, _ := json.Marshal(map[string]string{"message": "Read Image Layer file Error"}) + return http.StatusBadRequest, result + } else if has == false { + log.Error("[REGISTRY API V1] Read Image None Error") + + result, _ := json.Marshal(map[string]string{"message": "Read Image None"}) + return http.StatusNotFound, result + } + + layerfile := i.Path + if _, err := os.Stat(layerfile); err != nil { + log.Error("[REGISTRY API V1] Read Image file state error: %v", err.Error()) + + result, _ := json.Marshal(map[string]string{"message": "Read Image file state error"}) + return http.StatusBadRequest, result + } + + file, err := ioutil.ReadFile(layerfile) + if err != nil { + log.Error("[REGISTRY API V1] Read Image file error: %v", err.Error()) + + result, _ := json.Marshal(map[string]string{"message": "Read Image file error"}) + return http.StatusBadRequest, result + } + + ctx.Resp.Header().Set("Content-Type", "application/octet-stream") + ctx.Resp.Header().Set("Content-Length", fmt.Sprint(len(file))) + + return http.StatusOK, file +} + +func PutImageJSONV1Handler(ctx *macaron.Context, log *logs.BeeLogger) (int, []byte) { + imageId := ctx.Params(":imageId") + + info, err := ctx.Req.Body().String() + if err != nil { + log.Error("[REGISTRY API V1] Get request body error: %v", err.Error()) + + result, _ := json.Marshal(map[string]string{"message": "Put V1 image JSON failed,request body is empty"}) + return http.StatusBadRequest, result + } + + i := new(models.Image) + if err := i.PutJSON(imageId, info, setting.APIVERSION_V1); err != nil { + log.Error("[REGISTRY API V1] Put Image JSON Error: %v", err.Error()) + + result, _ := json.Marshal(map[string]string{"message": "Put Image JSON Error"}) + return http.StatusBadRequest, result + } + + result, _ := json.Marshal(map[string]string{}) + return http.StatusOK, result +} + +func PutImageLayerv1Handler(ctx *macaron.Context, log *logs.BeeLogger) (int, []byte) { + imageId := ctx.Params(":imageId") + + basePath := setting.ImagePath + imagePath := fmt.Sprintf("%v/images/%v", basePath, imageId) + layerfile := fmt.Sprintf("%v/images/%v/layer", basePath, imageId) + + if !utils.IsDirExist(imagePath) { + os.MkdirAll(imagePath, os.ModePerm) + } + + if _, err := os.Stat(layerfile); err == nil { + os.Remove(layerfile) + } + + data, _ := ioutil.ReadAll(ctx.Req.Request.Body) + if err := ioutil.WriteFile(layerfile, data, 0777); err != nil { + log.Error("[REGISTRY API V1] Put Image Layer File Error: %v", err.Error()) + + result, _ := json.Marshal(map[string]string{"message": "Put Image Layer File Error"}) + return http.StatusBadRequest, result + } + + i := new(models.Image) + if err := i.PutLayer(imageId, layerfile, true, int64(len(data))); err != nil { + log.Error("[REGISTRY API V1] Put Image Layer File Data Error: %v", err.Error()) + + result, _ := json.Marshal(map[string]string{"message": "Put Image Layer File Data Error"}) + return http.StatusBadRequest, result + } + + result, _ := json.Marshal(map[string]string{}) + return http.StatusOK, result +} + +func PutImageChecksumV1Handler(ctx *macaron.Context, log *logs.BeeLogger) (int, []byte) { + imageId := ctx.Params(":imageId") + + checksum := strings.Split(ctx.Req.Header.Get("X-Docker-Checksum"), ":")[1] + payload := strings.Split(ctx.Req.Header.Get("X-Docker-Checksum-Payload"), ":")[1] + + log.Debug("[REGISTRY API V1] Image Checksum : %v", checksum) + log.Debug("[REGISTRY API V1] Image Payload: %v", payload) + + i := new(models.Image) + if err := i.PutChecksum(imageId, checksum, true, payload); err != nil { + log.Error("[REGISTRY API V1] Put Image Checksum & Payload Error: %v", err.Error()) + + result, _ := json.Marshal(map[string]string{"message": "Put Image Checksum & Payload Error"}) + return http.StatusBadRequest, result + } + + if err := i.PutAncestry(imageId); err != nil { + log.Error("[REGISTRY API V1] Put Image Ancestry Error: %v", err.Error()) + + result, _ := json.Marshal(map[string]string{"message": "Put Image Ancestry Error"}) + return http.StatusBadRequest, result + } + + result, _ := json.Marshal(map[string]string{}) + return http.StatusOK, result +} diff --git a/handler/manifests.go b/handler/manifests.go new file mode 100644 index 0000000..ea89966 --- /dev/null +++ b/handler/manifests.go @@ -0,0 +1,119 @@ +package handler + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + + "github.com/Unknwon/macaron" + "github.com/astaxie/beego/logs" + + "github.com/containerops/dockyard/models" + "github.com/containerops/dockyard/module" + "github.com/containerops/wrench/setting" + "github.com/containerops/wrench/utils" +) + +func PutManifestsV2Handler(ctx *macaron.Context, log *logs.BeeLogger) (int, []byte) { + namespace := ctx.Params(":namespace") + repository := ctx.Params(":repository") + + agent := ctx.Req.Header.Get("User-Agent") + + repo := new(models.Repository) + if err := repo.Put(namespace, repository, "", agent, setting.APIVERSION_V2); err != nil { + log.Error("[REGISTRY API V2] Save repository failed: %v", err.Error()) + + result, _ := json.Marshal(map[string]string{"message": err.Error()}) + return http.StatusBadRequest, result + } + + manifest, _ := ioutil.ReadAll(ctx.Req.Request.Body) + if err := modules.ParseManifest(manifest); err != nil { + log.Error("[REGISTRY API V2] Decode Manifest Error: %v", err.Error()) + + result, _ := json.Marshal(map[string]string{"message": "Manifest converted failed"}) + return http.StatusBadRequest, result + } + + digest, err := utils.DigestManifest(manifest) + if err != nil { + log.Error("[REGISTRY API V2] Get manifest digest failed: %v", err.Error()) + + result, _ := json.Marshal(map[string]string{"message": "Get manifest digest failed"}) + return http.StatusBadRequest, result + } + + random := fmt.Sprintf("https://%v/v2/%v/%v/manifests/%v", + setting.Domains, + namespace, + repository, + digest) + + ctx.Resp.Header().Set("Docker-Content-Digest", digest) + ctx.Resp.Header().Set("Location", random) + + result, _ := json.Marshal(map[string]string{}) + return http.StatusAccepted, result +} + +func GetTagsListV2Handler(ctx *macaron.Context, log *logs.BeeLogger) (int, []byte) { + namespace := ctx.Params(":namespace") + repository := ctx.Params(":repository") + + r := new(models.Repository) + if has, _, err := r.Has(namespace, repository); err != nil || has == false { + log.Error("[REGISTRY API V2] Repository not found: %v", repository) + + result, _ := json.Marshal(map[string]string{"message": "Repository not found"}) + return http.StatusNotFound, result + } + + data := map[string]interface{}{} + tags := []string{} + + data["name"] = fmt.Sprintf("%s/%s", namespace, repository) + + for _, value := range r.Tags { + t := new(models.Tag) + if err := t.GetByKey(value); err != nil { + log.Error("[REGISTRY API V2] Tag not found: %v", err.Error()) + + result, _ := json.Marshal(map[string]string{"message": "Tag not found"}) + return http.StatusNotFound, result + } + + tags = append(tags, t.Name) + } + + data["tags"] = tags + + result, _ := json.Marshal(data) + return http.StatusOK, result +} + +func GetManifestsV2Handler(ctx *macaron.Context, log *logs.BeeLogger) (int, []byte) { + t := new(models.Tag) + + if err := t.Get(ctx.Params(":namespace"), ctx.Params(":repository"), ctx.Params(":tag")); err != nil { + log.Error("[REGISTRY API V2] Manifest not found: %v", err.Error()) + + result, _ := json.Marshal(map[string]string{"message": "Manifest not found"}) + return http.StatusNotFound, result + } + + digest, err := utils.DigestManifest([]byte(t.Manifest)) + if err != nil { + log.Error("[REGISTRY API V2] Get manifest digest failed: %v", err.Error()) + + result, _ := json.Marshal(map[string]string{"message": "Get manifest digest failed"}) + return http.StatusBadRequest, result + } + + ctx.Resp.Header().Set("Content-Type", "application/json; charset=utf-8") + ctx.Resp.Header().Set("Docker-Content-Digest", digest) + ctx.Resp.Header().Set("Content-Length", fmt.Sprint(len(t.Manifest))) + + return http.StatusOK, []byte(t.Manifest) +} diff --git a/handler/ping.go b/handler/ping.go new file mode 100644 index 0000000..2134c41 --- /dev/null +++ b/handler/ping.go @@ -0,0 +1,24 @@ +package handler + +import ( + "encoding/json" + "net/http" + + "github.com/Unknwon/macaron" + "github.com/astaxie/beego/logs" +) + +func GetPingV1Handler(ctx *macaron.Context, log *logs.BeeLogger) (int, []byte) { + result, _ := json.Marshal(map[string]string{}) + + return http.StatusOK, result +} + +func GetPingV2Handler(ctx *macaron.Context) (int, []byte) { + + ctx.Resp.Header().Set("Content-Type", "application/json; charset=utf-8") + + result, _ := json.Marshal(map[string]string{}) + + return http.StatusOK, result +} diff --git a/handler/repository.go b/handler/repository.go new file mode 100644 index 0000000..6dd5a08 --- /dev/null +++ b/handler/repository.go @@ -0,0 +1,177 @@ +package handler + +import ( + "encoding/json" + "fmt" + "net/http" + "regexp" + + "github.com/Unknwon/macaron" + "github.com/astaxie/beego/logs" + + "github.com/containerops/dockyard/models" + "github.com/containerops/wrench/db" + "github.com/containerops/wrench/setting" + "github.com/containerops/wrench/utils" +) + +func PutTagV1Handler(ctx *macaron.Context, log *logs.BeeLogger) (int, []byte) { + namespace := ctx.Params(":namespace") + repository := ctx.Params(":repository") + tag := ctx.Params(":tag") + + bodystr, _ := ctx.Req.Body().String() + log.Debug("[REGISTRY API V1] Repository Tag : %v", bodystr) + + r, _ := regexp.Compile(`"([[:alnum:]]+)"`) + imageIds := r.FindStringSubmatch(bodystr) + + repo := new(models.Repository) + if err := repo.PutTag(imageIds[1], namespace, repository, tag); err != nil { + log.Error("[REGISTRY API V1] Put repository tag error: %v", err.Error()) + + result, _ := json.Marshal(map[string]string{"message": err.Error()}) + return http.StatusBadRequest, result + } + + result, _ := json.Marshal(map[string]string{}) + return http.StatusOK, result +} + +func PutRepositoryImagesV1Handler(ctx *macaron.Context, log *logs.BeeLogger) (int, []byte) { + namespace := ctx.Params(":namespace") + repository := ctx.Params(":repository") + + r := new(models.Repository) + if err := r.PutImages(namespace, repository); err != nil { + log.Error("[REGISTRY API V1] Put images error: %v", err.Error()) + + result, _ := json.Marshal(map[string]string{"message": "Put V1 images error"}) + return http.StatusBadRequest, result + } + + if ctx.Req.Header.Get("X-Docker-Token") == "true" { + username, _, _ := utils.DecodeBasicAuth(ctx.Req.Header.Get("Authorization")) + token := fmt.Sprintf("Token signature=%v,repository=\"%v/%v\",access=%v", + utils.MD5(username), + namespace, + repository, + "write") + + ctx.Resp.Header().Set("X-Docker-Token", token) + ctx.Resp.Header().Set("WWW-Authenticate", token) + } + + result, _ := json.Marshal(map[string]string{}) + return http.StatusNoContent, result +} + +func GetRepositoryImagesV1Handler(ctx *macaron.Context, log *logs.BeeLogger) (int, []byte) { + namespace := ctx.Params(":namespace") + repository := ctx.Params(":repository") + + repo := new(models.Repository) + if has, _, err := repo.Has(namespace, repository); err != nil { + log.Error("[REGISTRY API V1] Read repository json error: %v", err.Error()) + + result, _ := json.Marshal(map[string]string{"message": "Get V1 repository images failed,wrong name or repository"}) + return http.StatusBadRequest, result + } else if has == false { + log.Error("[REGISTRY API V1] Read repository no found, %v/%v", namespace, repository) + + result, _ := json.Marshal(map[string]string{"message": "Get V1 repository images failed,repository no found"}) + return http.StatusNotFound, result + } + + repo.Download += 1 + + if err := repo.Save(); err != nil { + log.Error("[REGISTRY API V1] Update download count error: %v", err.Error()) + result, _ := json.Marshal(map[string]string{"message": "Save V1 repository failed"}) + return http.StatusBadRequest, result + } + + username, _, _ := utils.DecodeBasicAuth(ctx.Req.Header.Get("Authorization")) + token := fmt.Sprintf("Token signature=%v,repository=\"%v/%v\",access=%v", + utils.MD5(username), + namespace, + repository, + "read") + + ctx.Resp.Header().Set("X-Docker-Token", token) + ctx.Resp.Header().Set("WWW-Authenticate", token) + ctx.Resp.Header().Set("Content-Length", fmt.Sprint(len(repo.JSON))) + + return http.StatusOK, []byte(repo.JSON) +} + +func GetTagV1Handler(ctx *macaron.Context, log *logs.BeeLogger) (int, []byte) { + namespace := ctx.Params(":namespace") + repository := ctx.Params(":repository") + + repo := new(models.Repository) + if has, _, err := repo.Has(namespace, repository); err != nil { + log.Error("[REGISTRY API V1] Read repository json error: %v", err.Error()) + + result, _ := json.Marshal(map[string]string{"message": "Get V1 tag failed,wrong name or repository"}) + return http.StatusBadRequest, result + } else if has == false { + log.Error("[REGISTRY API V1] Read repository no found. %v/%v", namespace, repository) + + result, _ := json.Marshal(map[string]string{"message": "Get V1 tag failed,read repository no found"}) + return http.StatusNotFound, result + } + + tag := map[string]string{} + + for _, value := range repo.Tags { + t := new(models.Tag) + if err := db.Get(t, value); err != nil { + log.Error(fmt.Sprintf("[REGISTRY API V1] %s/%s Tags is not exist", namespace, repository)) + + result, _ := json.Marshal(map[string]string{"message": fmt.Sprintf("%s/%s Tags is not exist", namespace, repository)}) + return http.StatusNotFound, result + } + + tag[t.Name] = t.ImageId + } + + result, _ := json.Marshal(tag) + return http.StatusOK, result +} + +func PutRepositoryV1Handler(ctx *macaron.Context, log *logs.BeeLogger) (int, []byte) { + username, _, _ := utils.DecodeBasicAuth(ctx.Req.Header.Get("Authorization")) + + namespace := ctx.Params(":namespace") + repository := ctx.Params(":repository") + + body, err := ctx.Req.Body().String() + if err != nil { + log.Error("[REGISTRY API V1] Get request body error: %v", err.Error()) + result, _ := json.Marshal(map[string]string{"message": "Put V1 repository failed,request body is empty"}) + return http.StatusBadRequest, result + } + + r := new(models.Repository) + if err := r.Put(namespace, repository, body, ctx.Req.Header.Get("User-Agent"), setting.APIVERSION_V1); err != nil { + log.Error("[REGISTRY API V1] Put repository error: %v", err.Error()) + + result, _ := json.Marshal(map[string]string{"message": err.Error()}) + return http.StatusBadRequest, result + } + + if ctx.Req.Header.Get("X-Docker-Token") == "true" { + token := fmt.Sprintf("Token signature=%v,repository=\"%v/%v\",access=%v", + utils.MD5(username), + namespace, + repository, + "write") + + ctx.Resp.Header().Set("X-Docker-Token", token) + ctx.Resp.Header().Set("WWW-Authenticate", token) + } + + result, _ := json.Marshal(map[string]string{}) + return http.StatusOK, result +} diff --git a/handler/users.go b/handler/users.go new file mode 100644 index 0000000..94f8229 --- /dev/null +++ b/handler/users.go @@ -0,0 +1,19 @@ +package handler + +import ( + "encoding/json" + "net/http" + + "github.com/Unknwon/macaron" + "github.com/astaxie/beego/logs" +) + +func GetUsersV1Handler(ctx *macaron.Context, log *logs.BeeLogger) (int, []byte) { + result, _ := json.Marshal(map[string]string{}) + return http.StatusOK, result +} + +func PostUsersV1Handler(ctx *macaron.Context, log *logs.BeeLogger) (int, []byte) { + result, _ := json.Marshal(map[string]string{}) + return http.StatusUnauthorized, result +} diff --git a/main.go b/main.go index f3c5b0c..1b911e0 100644 --- a/main.go +++ b/main.go @@ -7,7 +7,7 @@ import ( "github.com/codegangsta/cli" "github.com/containerops/dockyard/cmd" - "github.com/containerops/dockyard/setting" + "github.com/containerops/wrench/setting" ) func init() { @@ -15,6 +15,8 @@ func init() { } func main() { + setting.SetConfig("conf/containerops.conf") + app := cli.NewApp() app.Name = setting.AppName @@ -25,8 +27,6 @@ func main() { app.Commands = []cli.Command{ cmd.CmdWeb, - cmd.CmdBackend, - cmd.CmdDatabase, } app.Flags = append(app.Flags, []cli.Flag{}...) diff --git a/middleware/header.go b/middleware/header.go new file mode 100644 index 0000000..573774c --- /dev/null +++ b/middleware/header.go @@ -0,0 +1,26 @@ +package middleware + +import ( + "fmt" + "strings" + + "github.com/Unknwon/macaron" + + "github.com/containerops/wrench/setting" +) + +func setRespHeaders() macaron.Handler { + return func(ctx *macaron.Context) { + if flag := strings.Contains(ctx.Req.RequestURI, "v1"); flag == true { + ctx.Resp.Header().Set("Content-Type", "application/json") + ctx.Resp.Header().Set("X-Docker-Registry-Standalone", setting.Standalone) + ctx.Resp.Header().Set("X-Docker-Registry-Version", setting.RegistryVersion) + ctx.Resp.Header().Set("X-Docker-Registry-Config", setting.RunMode) + ctx.Resp.Header().Set("X-Docker-Endpoints", setting.Domains) + } else if flag == false { + ctx.Resp.Header().Set("Content-Type", "text/plain; charset=utf-8") + ctx.Resp.Header().Set("WWW-Authenticate", fmt.Sprintf("Basic realm=\"%v\"", setting.Domains)) + ctx.Resp.Header().Set("Docker-Distribution-Api-Version", setting.DistributionVersion) + } + } +} diff --git a/middleware/logger.go b/middleware/logger.go index 4f80335..ad96e5b 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -4,33 +4,29 @@ import ( "fmt" "github.com/Unknwon/macaron" - "github.com/astaxie/beego/logs" - - "github.com/containerops/dockyard/setting" ) -var Log *logs.BeeLogger //全局日志对象 +var Log *logs.BeeLogger -func init() { +func InitLog(runmode, path string) { Log = logs.NewLogger(10000) - if setting.RunMode == "dev" { + if runmode == "dev" { Log.SetLogger("console", "") } - Log.SetLogger("file", fmt.Sprintf("{\"filename\":\"%s\"}", setting.LogPath)) + Log.SetLogger("file", fmt.Sprintf("{\"filename\":\"%s\"}", path)) } -func logger() macaron.Handler { +func logger(runmode string) macaron.Handler { return func(ctx *macaron.Context) { - //在调试阶段为了便于阅读控制台的信息,输出空行和分隔符区分多个访问的日志 - if setting.RunMode == "dev" { + if runmode == "dev" { Log.Trace("") Log.Trace("----------------------------------------------------------------------------------") } - //默认输出 Request 的 Method、 URI 和 Header 的信息 + Log.Trace("[%s] [%s]", ctx.Req.Method, ctx.Req.RequestURI) Log.Trace("[Header] %v", ctx.Req.Header) } diff --git a/middleware/middleware.go b/middleware/middleware.go index 557b8fc..6dc6396 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -3,20 +3,25 @@ package middleware import ( "github.com/Unknwon/macaron" - _ "github.com/macaron-contrib/session/redis" + "github.com/containerops/wrench/setting" ) func SetMiddlewares(m *macaron.Macaron) { - //设置静态文件目录,静态文件的访问不进行日志输出 - m.Use(macaron.Static("static", macaron.StaticOptions{ + //Set static file directory,static file access without log output + m.Use(macaron.Static("external", macaron.StaticOptions{ Expires: func() string { return "max-age=0" }, })) - //设置全局 Logger + InitLog(setting.RunMode, setting.LogPath) + + //Set global Logger m.Map(Log) - //设置 logger 的 Handler 函数,处理所有 Request 的日志输出 - m.Use(logger()) + //Set logger handler function, deal with all the Request log output + m.Use(logger(setting.RunMode)) + + //Set the response header info + m.Use(setRespHeaders()) - //设置 panic 的 Recovery + //Set recovery handler to returns a middleware that recovers from any panics m.Use(macaron.Recovery()) } diff --git a/models/image.go b/models/image.go new file mode 100644 index 0000000..230f616 --- /dev/null +++ b/models/image.go @@ -0,0 +1,206 @@ +package models + +import ( + "encoding/json" + "fmt" + "time" + + "gopkg.in/redis.v3" + + "github.com/containerops/wrench/db" +) + +type Image struct { + ImageId string `json:"imageid"` // + JSON string `json:"json"` // + Ancestry string `json:"ancestry"` // + Checksum string `json:"checksum"` // tarsum+sha256 + Payload string `json:"payload"` // sha256 + URL string `json:"url"` // + Backend string `json:"backend"` // + Path string `json:"path"` // + Sign string `json:"sign"` // + Size int64 `json:"size"` // + Uploaded bool `json:"uploaded"` // + Checksumed bool `json:"checksumed"` // + Encrypted bool `json:"encrypted"` // + Created int64 `json:"created"` // + Updated int64 `json:"updated"` // + Memo []string `json:"memo"` // + Version int64 `json:"version"` // +} + +func (i *Image) Has(image string) (bool, string, error) { + if key := db.Key("image", image); len(key) <= 0 { + return false, "", fmt.Errorf("Invalid image key") + } else { + if err := db.Get(i, key); err != nil { + if err == redis.Nil { + return false, "", nil + } else { + return false, "", err + } + } + + return true, key, nil + } +} + +func (i *Image) Save() error { + key := db.Key("image", i.ImageId) + + if err := db.Save(i, key); err != nil { + return err + } + + if _, err := db.Client.HSet(db.GLOBAL_IMAGE_INDEX, i.ImageId, key).Result(); err != nil { + return err + } + + return nil +} + +func (i *Image) GetJSON(imageId string) (string, error) { + if has, _, err := i.Has(imageId); err != nil { + return "", err + } else if has == false { + return "", fmt.Errorf("Image not found") + } else if !i.Checksumed || !i.Uploaded { + return "", fmt.Errorf("Image JSON not found") + } else { + return i.JSON, nil + } +} + +func (i *Image) GetChecksumPayload(imageId string) (string, error) { + if has, _, err := i.Has(imageId); err != nil { + return "", err + } else if has == false { + return "", fmt.Errorf("Image not found") + } else if !i.Checksumed || !i.Uploaded { + return "", fmt.Errorf("Image JSON not found") + } else { + return i.Payload, nil + } +} + +func (i *Image) PutJSON(imageId, json string, version int64) error { + if has, _, err := i.Has(imageId); err != nil { + return err + } else if has == false { + i.ImageId = imageId + i.JSON = json + i.Created = time.Now().UnixNano() / int64(time.Millisecond) + i.Version = version + + if err = i.Save(); err != nil { + return err + } + } else { + i.ImageId, i.JSON = imageId, json + i.Uploaded, i.Checksumed, i.Encrypted, i.Size, i.Updated, i.Version = + false, false, false, 0, time.Now().UnixNano()/int64(time.Millisecond), version + + if err := i.Save(); err != nil { + return err + } + } + + return nil +} + +func (i *Image) PutChecksum(imageId string, checksum string, checksumed bool, payload string) error { + if has, _, err := i.Has(imageId); err != nil { + return err + } else if has == false { + return fmt.Errorf("Image not found") + } else { + if err := i.PutAncestry(imageId); err != nil { + + return err + } + + i.Checksum, i.Checksumed, i.Payload, i.Updated = checksum, checksumed, payload, time.Now().UnixNano()/int64(time.Millisecond) + + if err = i.Save(); err != nil { + return err + } + } + + return nil +} + +func (i *Image) PutAncestry(imageId string) error { + if has, _, err := i.Has(imageId); err != nil { + return err + } else if has == false { + return fmt.Errorf("Image not found") + } + + var imageJSONMap map[string]interface{} + var imageAncestry []string + + if err := json.Unmarshal([]byte(i.JSON), &imageJSONMap); err != nil { + return err + } + + if value, has := imageJSONMap["parent"]; has == true { + parentImage := new(Image) + parentHas, _, err := parentImage.Has(value.(string)) + if err != nil { + return err + } + + if !parentHas { + return fmt.Errorf("Parent image not found") + } + + var parentAncestry []string + json.Unmarshal([]byte(parentImage.Ancestry), &parentAncestry) + imageAncestry = append(imageAncestry, imageId) + imageAncestry = append(imageAncestry, parentAncestry...) + } else { + imageAncestry = append(imageAncestry, imageId) + } + + ancestryJSON, _ := json.Marshal(imageAncestry) + i.Ancestry = string(ancestryJSON) + + if err := i.Save(); err != nil { + return err + } + + return nil +} + +func (i *Image) PutLayer(imageId string, path string, uploaded bool, size int64) error { + if has, _, err := i.Has(imageId); err != nil { + return err + } else if has == false { + return fmt.Errorf("Image not found") + } else { + i.Path, i.Uploaded, i.Size, i.Updated = path, uploaded, size, time.Now().UnixNano()/int64(time.Millisecond) + + if err := i.Save(); err != nil { + return err + } + } + + return nil +} + +func (i *Image) HasTarsum(tarsum string) (bool, error) { + if err := db.Get(i, db.Key("tarsum", tarsum)); err != nil { + return false, err + } + + return true, nil +} + +func (i *Image) PutTarsum(tarsum string) error { + if err := db.Save(i, db.Key("tarsum", tarsum)); err != nil { + return err + } + + return nil +} diff --git a/models/repo.go b/models/repo.go new file mode 100644 index 0000000..aa3d484 --- /dev/null +++ b/models/repo.go @@ -0,0 +1,274 @@ +package models + +import ( + "encoding/json" + "fmt" + "time" + + "gopkg.in/redis.v3" + + "github.com/containerops/wrench/db" + "github.com/containerops/wrench/setting" +) + +type Repository struct { + Repository string `json:"repository"` // + Namespace string `json:"namespace"` // + NamespaceType bool `json:"namespacetype"` // + Organization string `json:"organization"` // + Tags []string `json:"tags"` // + Starts []string `json:"starts"` // + Comments []string `json:"comments"` // + Short string `json:"short"` // + Description string `json:"description"` // + JSON string `json:"json"` // + Dockerfile string `json:"dockerfile"` // + Agent string `json:"agent"` // + Links string `json:"links"` // + Size int64 `json:"size"` // + Download int64 `json:"download"` // + Uploaded bool `json:"uploaded"` // + Checksum string `json:"checksum"` // + Checksumed bool `json:"checksumed"` // + Icon string `json:"icon"` // + Sign string `json:"sign"` // + Privated bool `json:"privated"` // + Clear string `json:"clear"` // + Cleared bool `json:"cleared"` // + Encrypted bool `json:"encrypted"` // + Created int64 `json:"created"` // + Updated int64 `json:"updated"` // + Version int64 `json:"version"` // + Memo []string `json:"memo"` // +} + +type Tag struct { + Name string `json:"name"` // + ImageId string `json:"imageid"` // + Namespace string `json:"namespace"` // + Repository string `json:"repository"` // + Sign string `json:"sign"` // + Manifest string `json:"manifest"` // + Memo []string `json:"memo"` // +} + +func (r *Repository) Has(namespace, repository string) (bool, string, error) { + if key := db.Key("repository", namespace, repository); len(key) <= 0 { + return false, "", fmt.Errorf("Invalid repository key") + } else { + if err := db.Get(r, key); err != nil { + if err == redis.Nil { + return false, "", nil + } else { + return false, "", err + } + } + + return true, key, nil + } +} + +func (r *Repository) Save() error { + key := db.Key("repository", r.Namespace, r.Repository) + + if err := db.Save(r, key); err != nil { + return err + } + + if _, err := db.Client.HSet(db.GLOBAL_REPOSITORY_INDEX, (fmt.Sprintf("%s/%s", r.Namespace, r.Repository)), key).Result(); err != nil { + return err + } + + return nil +} + +func (t *Tag) Save() error { + key := db.Key("tag", t.Namespace, t.Repository, t.Name) + + if err := db.Save(t, key); err != nil { + return err + } + + if _, err := db.Client.HSet(db.GLOBAL_TAG_INDEX, (fmt.Sprintf("%s/%s/%s:%s", t.Namespace, t.Repository, t.Name, t.ImageId)), key).Result(); err != nil { + return err + } + + return nil +} + +func (t *Tag) Get(namespace, repository, tag string) error { + key := db.Key("tag", namespace, repository, tag) + + if err := db.Get(t, key); err != nil { + return err + } + + return nil +} + +func (t *Tag) GetByKey(key string) error { + if err := db.Get(t, key); err != nil { + return err + } + + return nil +} + +func (r *Repository) Put(namespace, repository, json, agent string, version int64) error { + if has, _, err := r.Has(namespace, repository); err != nil { + return err + } else if has == false { + r.Created = time.Now().UnixNano() / int64(time.Millisecond) + } + + r.Namespace, r.Repository, r.JSON, r.Agent, r.Version = + namespace, repository, json, agent, version + + r.Updated = time.Now().UnixNano() / int64(time.Millisecond) + r.Checksumed, r.Uploaded, r.Cleared, r.Encrypted = false, false, false, false + r.Size, r.Download = 0, 0 + + if err := r.Save(); err != nil { + return err + } + + return nil +} + +func (r *Repository) PutImages(namespace, repository string) error { + if _, _, err := r.Has(namespace, repository); err != nil { + return err + } + + r.Checksumed, r.Uploaded, r.Updated = true, true, time.Now().Unix() + + if err := r.Save(); err != nil { + return fmt.Errorf("[REGISTRY API V1] Update Uploaded flag error") + } + + return nil +} + +func (r *Repository) PutTag(imageId, namespace, repository, tag string) error { + if has, _, err := r.Has(namespace, repository); err != nil { + return err + } else if has == false { + return fmt.Errorf("Repository not found") + } + + i := new(Image) + if has, _, err := i.Has(imageId); err != nil { + return err + } else if has == false { + return fmt.Errorf("Tag's image not found") + } + + t := new(Tag) + t.Name, t.ImageId, t.Namespace, t.Repository = tag, imageId, namespace, repository + + if err := t.Save(); err != nil { + return err + } + + has := false + for _, value := range r.Tags { + if value == db.Key("tag", t.Namespace, t.Repository, t.Name) { + has = true + } + } + + if !has { + r.Tags = append(r.Tags, db.Key("tag", t.Namespace, t.Repository, t.Name)) + } + + if err := r.Save(); err != nil { + return err + } + + return nil +} + +func (r *Repository) PutJSONFromManifests(image map[string]string, namespace, repository string) error { + if has, _, err := r.Has(namespace, repository); err != nil { + return err + } else if has == false { + r.Created = time.Now().UnixNano() / int64(time.Millisecond) + r.JSON = "" + } + + r.Namespace, r.Repository, r.Version = namespace, repository, setting.APIVERSION_V2 + + r.Updated = time.Now().UnixNano() / int64(time.Millisecond) + r.Checksumed, r.Uploaded, r.Cleared, r.Encrypted = true, true, true, false + r.Size, r.Download = 0, 0 + + if len(r.JSON) == 0 { + if data, err := json.Marshal([]map[string]string{image}); err != nil { + return err + } else { + r.JSON = string(data) + } + + } else { + var ids []map[string]string + + if err := json.Unmarshal([]byte(r.JSON), &ids); err != nil { + return err + } + + has := false + for _, v := range ids { + if v["id"] == image["id"] { + has = true + } + } + + if has == false { + ids = append(ids, image) + } + + if data, err := json.Marshal(ids); err != nil { + return err + } else { + r.JSON = string(data) + } + } + + if err := r.Save(); err != nil { + return err + } + + return nil +} + +func (r *Repository) PutTagFromManifests(image, namespace, repository, tag, manifests string) error { + if has, _, err := r.Has(namespace, repository); err != nil { + return err + } else if has == false { + return fmt.Errorf("Repository not found") + } + + t := new(Tag) + t.Name, t.ImageId, t.Namespace, t.Repository, t.Manifest = tag, image, namespace, repository, manifests + + if err := t.Save(); err != nil { + return err + } + + has := false + for _, v := range r.Tags { + if v == db.Key("tag", t.Namespace, t.Repository, t.Name) { + has = true + } + } + + if has == false { + r.Tags = append(r.Tags, db.Key("tag", t.Namespace, t.Repository, t.Name)) + } + + if err := r.Save(); err != nil { + return err + } + + return nil +} diff --git a/module/module.go b/module/module.go index 11174d7..a88b516 100644 --- a/module/module.go +++ b/module/module.go @@ -1 +1,79 @@ package modules + +import ( + "encoding/json" + "io" + "io/ioutil" + "os" + "strings" + + "github.com/containerops/dockyard/models" + "github.com/containerops/wrench/utils" +) + +func ParseManifest(data []byte) error { + + var manifest map[string]interface{} + if err := json.Unmarshal(data, &manifest); err != nil { + return err + } + + tag := manifest["tag"] + namespace, repository := strings.Split(manifest["name"].(string), "/")[0], strings.Split(manifest["name"].(string), "/")[1] + + for k := len(manifest["history"].([]interface{})) - 1; k >= 0; k-- { + v := manifest["history"].([]interface{})[k] + compatibility := v.(map[string]interface{})["v1Compatibility"].(string) + + var image map[string]interface{} + if err := json.Unmarshal([]byte(compatibility), &image); err != nil { + return err + } + + i := map[string]string{} + r := new(models.Repository) + + if k == 0 { + i["Tag"] = tag.(string) + } + i["id"] = image["id"].(string) + + if err := r.PutJSONFromManifests(i, namespace, repository); err != nil { + return err + } + + if k == 0 { + if err := r.PutTagFromManifests(image["id"].(string), namespace, repository, tag.(string), string(data)); err != nil { + return err + } + } + } + + return nil +} + +func CopyImgLayer(srcPath, srcFile, dstPath, dstFile string, resp io.Reader) (int, error) { + if !utils.IsDirExist(dstPath) { + os.MkdirAll(dstPath, os.ModePerm) + } + + if utils.IsFileExist(dstFile) { + os.Remove(dstFile) + } + + var data []byte + if _, err := os.Stat(srcFile); err == nil { + data, _ = ioutil.ReadFile(srcFile) + if err := ioutil.WriteFile(dstFile, data, 0777); err != nil { + return 0, err + } + os.RemoveAll(srcPath) + } else { + data, _ = ioutil.ReadAll(resp) + if err := ioutil.WriteFile(dstFile, data, 0777); err != nil { + return 0, err + } + } + + return len(data), nil +} diff --git a/router/router.go b/router/router.go index 217d3ea..b58fcf0 100644 --- a/router/router.go +++ b/router/router.go @@ -2,12 +2,46 @@ package router import ( "github.com/Unknwon/macaron" + + "github.com/containerops/dockyard/handler" ) func SetRouters(m *macaron.Macaron) { //Docker Registry & Hub V1 API + m.Group("/v1", func() { + m.Get("/_ping", handler.GetPingV1Handler) - //Docker Registry & Hub V2 API + m.Get("/users", handler.GetUsersV1Handler) + m.Post("/users", handler.PostUsersV1Handler) + + m.Group("/repositories", func() { + m.Put("/:namespace/:repository/tags/:tag", handler.PutTagV1Handler) + m.Put("/:namespace/:repository/images", handler.PutRepositoryImagesV1Handler) + m.Get("/:namespace/:repository/images", handler.GetRepositoryImagesV1Handler) + m.Get("/:namespace/:repository/tags", handler.GetTagV1Handler) + m.Put("/:namespace/:repository", handler.PutRepositoryV1Handler) + }) - //Rocket Image Discovery + m.Group("/images", func() { + m.Get("/:imageId/ancestry", handler.GetImageAncestryV1Handler) + m.Get("/:imageId/json", handler.GetImageJSONV1Handler) + m.Get("/:imageId/layer", handler.GetImageLayerV1Handler) + m.Put("/:imageId/json", handler.PutImageJSONV1Handler) + m.Put("/:imageId/layer", handler.PutImageLayerv1Handler) + m.Put("/:imageId/checksum", handler.PutImageChecksumV1Handler) + }) + }) + + //Docker Registry & Hub V2 API + m.Group("/v2", func() { + m.Get("/", handler.GetPingV2Handler) + m.Head("/:namespace/:repository/blobs/:digest", handler.HeadBlobsV2Handler) + m.Post("/:namespace/:repository/blobs/uploads", handler.PostBlobsV2Handler) + m.Patch("/:namespace/:repository/blobs/uploads/:uuid", handler.PatchBlobsV2Handler) + m.Put("/:namespace/:repository/blobs/uploads/:uuid", handler.PutBlobsV2Handler) + m.Get("/:namespace/:repository/blobs/:digest", handler.GetBlobsV2Handler) + m.Put("/:namespace/:repository/manifests/:tag", handler.PutManifestsV2Handler) + m.Get("/:namespace/:repository/tags/list", handler.GetTagsListV2Handler) + m.Get("/:namespace/:repository/manifests/:tag", handler.GetManifestsV2Handler) + }) } diff --git a/setting/.gitignore b/setting/.gitignore deleted file mode 100644 index e69de29..0000000 diff --git a/setting/setting.go b/setting/setting.go deleted file mode 100644 index eb8fd14..0000000 --- a/setting/setting.go +++ /dev/null @@ -1,70 +0,0 @@ -package setting - -import ( - "fmt" - - "github.com/astaxie/beego/config" -) - -var ( - conf config.ConfigContainer - AppName string - Usage string - Version string - Author string - Email string - RunMode string - ListenMode string - HttpsCertFile string - HttpsKeyFile string - LogPath string -) - -func init() { - var err error - - conf, err = config.NewConfig("ini", "conf/dockyard.conf") - if err != nil { - fmt.Errorf("读取配置文件 conf/dockyard.conf 错误: %v", err) - } - - if appname := conf.String("appname"); appname != "" { - AppName = appname - } - - if usage := conf.String("usage"); usage != "" { - Usage = usage - } - - if version := conf.String("version"); version != "" { - Version = version - } - - if author := conf.String("author"); author != "" { - Author = author - } - - if email := conf.String("email"); email != "" { - Email = email - } - - if runmode := conf.String("runmode"); runmode != "" { - RunMode = runmode - } - - if listenmode := conf.String("listenmode"); listenmode != "" { - ListenMode = listenmode - } - - if httpscertfile := conf.String("httpscertfile"); httpscertfile != "" { - HttpsCertFile = httpscertfile - } - - if httpskeyfile := conf.String("httpskeyfile"); httpskeyfile != "" { - HttpsKeyFile = httpskeyfile - } - - if logpath := conf.String("log::filepath"); logpath != "" { - LogPath = logpath - } -} diff --git a/utils/utils.go b/utils/utils.go deleted file mode 100644 index 066bd8c..0000000 --- a/utils/utils.go +++ /dev/null @@ -1,12 +0,0 @@ -package utils - -import ( - "os" -) - -// 检查文件或目录是否存在 -// 如果由 filename 指定的文件或目录存在则返回 true,否则返回 false -func Exist(filename string) bool { - _, err := os.Stat(filename) - return err == nil || os.IsExist(err) -} diff --git a/web/web.go b/web/web.go index acc0be7..f2bd1f8 100644 --- a/web/web.go +++ b/web/web.go @@ -1,17 +1,25 @@ package web import ( + "fmt" + "github.com/Unknwon/macaron" "github.com/containerops/dockyard/middleware" "github.com/containerops/dockyard/router" + "github.com/containerops/wrench/db" + "github.com/containerops/wrench/setting" ) func SetDockyardMacaron(m *macaron.Macaron) { - //设置 Setting + //Setting Database + if err := db.InitDB(setting.DBURI, setting.DBPasswd, setting.DBDB); err != nil { + fmt.Printf("Connect Database Error %s", err.Error()) + } - //设置 Middleware + //Setting Middleware middleware.SetMiddlewares(m) - //设置 Router + + //Setting Router router.SetRouters(m) }