diff --git a/control.go b/control.go index 4e26e86d4..f132e3f76 100644 --- a/control.go +++ b/control.go @@ -244,9 +244,9 @@ func (c *controlConn) discoverProtocol(hosts []*HostInfo) (int, error) { return 0, err } -func (c *controlConn) connect(hosts []*HostInfo) error { +func (c *controlConn) connect(hosts []*HostInfo) (newHosts []*HostInfo, partitionerName string, errOut error) { if len(hosts) == 0 { - return errors.New("control: no endpoints specified") + return hosts, "", errors.New("control: no endpoints specified") } // shuffle endpoints so not all drivers will connect to the same initial @@ -264,16 +264,18 @@ func (c *controlConn) connect(hosts []*HostInfo) error { c.session.logger.Printf("gocql: unable to dial control conn %v:%v: %v\n", host.ConnectAddress(), host.Port(), err) continue } - err = c.setupConn(conn) + + newHosts, partitionerName, err = c.connectSetupConn(conn, hosts) if err == nil { break } - c.session.logger.Printf("gocql: unable setup control conn %v:%v: %v\n", host.ConnectAddress(), host.Port(), err) + err = fmt.Errorf("%v:%v: %v", host.ConnectAddress(), host.Port(), err) + c.session.logger.Printf("gocql: unable setup control conn %v\n", err) conn.Close() conn = nil } if conn == nil { - return fmt.Errorf("unable to connect to initial hosts: %v", err) + return hosts, "", fmt.Errorf("unable to connect to initial hosts: %v", err) } // we could fetch the initial ring here and update initial host data. So that @@ -281,7 +283,7 @@ func (c *controlConn) connect(hosts []*HostInfo) error { go c.heartBeat() - return nil + return newHosts, partitionerName, nil } type connHost struct { @@ -289,6 +291,32 @@ type connHost struct { host *HostInfo } +func (c *controlConn) connectSetupConn(conn *Conn, hosts []*HostInfo, +) (newHosts []*HostInfo, partitionerName string, errOut error) { + err := c.setupConn(conn) + if err != nil { + return hosts, "", err + } + + if c.session.cfg.DisableInitialHostLookup { + return hosts, "", nil + } + + allHosts, partitionerName, err := c.session.hostSource.GetHosts() + if err != nil { + return hosts, "", err + } + c.session.policy.SetPartitioner(partitionerName) + filteredHosts := make([]*HostInfo, 0, len(allHosts)) + for _, host := range allHosts { + if !c.session.cfg.filterHost(host) { + filteredHosts = append(filteredHosts, host) + } + } + + return filteredHosts, partitionerName, nil +} + func (c *controlConn) setupConn(conn *Conn) error { // we need up-to-date host info for the filterHost call below iter := conn.querySystemLocal(context.TODO()) @@ -322,6 +350,7 @@ func (c *controlConn) setupConn(conn *Conn) error { // TODO(martin-sucha): Trigger pool refill for all hosts, like in reconnectDownedHosts? go c.session.startPoolFill(host) } + return nil } diff --git a/session.go b/session.go index 3422a2c2c..a4771c573 100644 --- a/session.go +++ b/session.go @@ -228,26 +228,14 @@ func (s *Session) init() error { s.connCfg.ProtoVersion = proto } - if err := s.control.connect(hosts); err != nil { + newHosts, partitionerName, err := s.control.connect(hosts) + if err != nil { return err } - - if !s.cfg.DisableInitialHostLookup { - var partitioner string - newHosts, partitioner, err := s.hostSource.GetHosts() - if err != nil { - return err - } - s.policy.SetPartitioner(partitioner) - filteredHosts := make([]*HostInfo, 0, len(newHosts)) - for _, host := range newHosts { - if !s.cfg.filterHost(host) { - filteredHosts = append(filteredHosts, host) - } - } - - hosts = filteredHosts + if partitionerName != "" { + s.policy.SetPartitioner(partitionerName) } + hosts = newHosts } for _, host := range hosts {