Skip to content

Commit

Permalink
Merge pull request #511 from yahoo/leewyang_server_port
Browse files Browse the repository at this point in the history
support range for TFOS_SERVER_PORT
  • Loading branch information
leewyang authored Mar 16, 2020
2 parents b937411 + 8952fff commit ce19cbd
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 10 deletions.
40 changes: 34 additions & 6 deletions tensorflowonspark/reservation.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,14 +188,42 @@ def _listen(self, sock):
return addr

def get_server_ip(self):
return os.getenv(TFOS_SERVER_HOST) if os.getenv(TFOS_SERVER_HOST) else util.get_ip_address()
"""Returns the value of TFOS_SERVER_HOST environment variable (if set), otherwise defaults to current host/IP."""
return os.getenv(TFOS_SERVER_HOST, util.get_ip_address())

def get_server_ports(self):
"""Returns a list of target ports as defined in the TFOS_SERVER_PORT environment (if set), otherwise defaults to 0 (any port).
TFOS_SERVER_PORT should be either a single port number or a range, e.g. '8888' or '9997-9999'
"""
port_string = os.getenv(TFOS_SERVER_PORT, "0")
if '-' not in port_string:
return [int(port_string)]
else:
ports = port_string.split('-')
if len(ports) != 2:
raise Exception("Invalid TFOS_SERVER_PORT: {}".format(port_string))
return list(range(int(ports[0]), int(ports[1]) + 1))

def start_listening_socket(self):
port_number = int(os.getenv(TFOS_SERVER_PORT)) if os.getenv(TFOS_SERVER_PORT) else 0
server_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
server_sock.bind(('', port_number))
server_sock.listen(10)
"""Starts the registration server socket listener."""
port_list = self.get_server_ports()
for port in port_list:
try:
server_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
server_sock.bind(('', port))
server_sock.listen(10)
logger.info("Reservation server binding to port {}".format(port))
break
except Exception as e:
logger.warn("Unable to bind to port {}, error {}".format(port, e))
server_sock = None
pass

if not server_sock:
raise Exception("Reservation server unable to bind to any ports, port_list = {}".format(port_list))

return server_sock

def stop(self):
Expand Down
30 changes: 26 additions & 4 deletions test/test_reservation.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,25 +51,47 @@ def test_reservation_server(self):
time.sleep(1)
self.assertEqual(s.done, True)

def test_reservation_enviroment_exists_get_server_ip_return_environment_value(self):
def test_reservation_environment_exists_get_server_ip_return_environment_value(self):
tfos_server = Server(5)
with mock.patch.dict(os.environ, {'TFOS_SERVER_HOST': 'my_host_ip'}):
assert tfos_server.get_server_ip() == "my_host_ip"

def test_reservation_enviroment_not_exists_get_server_ip_return_actual_host_ip(self):
def test_reservation_environment_not_exists_get_server_ip_return_actual_host_ip(self):
tfos_server = Server(5)
assert tfos_server.get_server_ip() == util.get_ip_address()

def test_reservation_enviroment_exists_start_listening_socket_return_socket_listening_to_environment_port_value(self):
def test_reservation_environment_exists_start_listening_socket_return_socket_listening_to_environment_port_value(self):
tfos_server = Server(1)
with mock.patch.dict(os.environ, {'TFOS_SERVER_PORT': '9999'}):
assert tfos_server.start_listening_socket().getsockname()[1] == 9999

def test_reservation_enviroment_not_exists_start_listening_socket_return_socket(self):
def test_reservation_environment_not_exists_start_listening_socket_return_socket(self):
tfos_server = Server(1)
print(tfos_server.start_listening_socket().getsockname()[1])
assert type(tfos_server.start_listening_socket().getsockname()[1]) == int

def test_reservation_environment_exists_port_spec(self):
tfos_server = Server(1)
with mock.patch.dict(os.environ, {'TFOS_SERVER_PORT': '9999'}):
self.assertEqual(tfos_server.get_server_ports(), [9999])

with mock.patch.dict(os.environ, {'TFOS_SERVER_PORT': '9997-9999'}):
self.assertEqual(tfos_server.get_server_ports(), [9997, 9998, 9999])

def test_reservation_environment_exists_start_listening_socket_return_socket_listening_to_environment_port_range(self):
tfos_server1 = Server(1)
tfos_server2 = Server(1)
tfos_server3 = Server(1)
with mock.patch.dict(os.environ, {'TFOS_SERVER_PORT': '9998-9999'}):
s1 = tfos_server1.start_listening_socket()
self.assertEqual(s1.getsockname()[1], 9998)
s2 = tfos_server2.start_listening_socket()
self.assertEqual(s2.getsockname()[1], 9999)
with self.assertRaises(Exception):
tfos_server3.start_listening_socket()
tfos_server1.stop()
tfos_server2.stop()

def test_reservation_server_multi(self):
"""Test reservation server, expecting multiple reservations"""
num_clients = 4
Expand Down

0 comments on commit ce19cbd

Please sign in to comment.