Skip to content

Commit

Permalink
Remove attempt to resolve IP address to server_name
Browse files Browse the repository at this point in the history
Instead we just rely on the server_name passed in to the adjustment.
  • Loading branch information
digitalresistor committed Nov 27, 2020
1 parent 90148c9 commit 2f2972e
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 80 deletions.
57 changes: 6 additions & 51 deletions src/waitress/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,47 +241,14 @@ def __init__(
self.bind_server_socket()

self.effective_host, self.effective_port = self.getsockname()
self.server_name = self.get_server_name(self.effective_host)
self.server_name = adj.server_name
self.active_channels = {}
if _start:
self.accept_connections()

def bind_server_socket(self):
raise NotImplementedError # pragma: no cover

def get_server_name(self, ip):
"""Given an IP or hostname, try to determine the server name."""

if not ip:
raise ValueError("Requires an IP to get the server name")

server_name = str(ip)

# If we are bound to all IP's, just return the current hostname, only
# fall-back to "localhost" if we fail to get the hostname
if server_name == "0.0.0.0" or server_name == "::":
try:
return str(self.socketmod.gethostname())
except (OSError, UnicodeDecodeError): # pragma: no cover
# We also deal with UnicodeDecodeError in case of Windows with
# non-ascii hostname
return "localhost"

# Now let's try and convert the IP address to a proper hostname
try:
server_name = self.socketmod.gethostbyaddr(server_name)[0]
except (OSError, UnicodeDecodeError): # pragma: no cover
# We also deal with UnicodeDecodeError in case of Windows with
# non-ascii hostname
pass

# If it contains an IPv6 literal, make sure to surround it with
# brackets
if ":" in server_name and "[" not in server_name:
server_name = "[{}]".format(server_name)

return server_name

def getsockname(self):
raise NotImplementedError # pragma: no cover

Expand Down Expand Up @@ -391,20 +358,11 @@ def bind_server_socket(self):
self.bind(sockaddr)

def getsockname(self):
try:
return self.socketmod.getnameinfo(
self.socket.getsockname(), self.socketmod.NI_NUMERICSERV
)
except: # pragma: no cover
# This only happens on Linux because a DNS issue is considered a
# temporary failure that will raise (even when NI_NAMEREQD is not
# set). Instead we try again, but this time we just ask for the
# numerichost and the numericserv (port) and return those. It is
# better than nothing.
return self.socketmod.getnameinfo(
self.socket.getsockname(),
self.socketmod.NI_NUMERICHOST | self.socketmod.NI_NUMERICSERV,
)
# Return the IP address, port as numeric
return self.socketmod.getnameinfo(
self.socket.getsockname(),
self.socketmod.NI_NUMERICHOST | self.socketmod.NI_NUMERICSERV,
)

def set_socket_options(self, conn):
for (level, optname, value) in self.adj.socket_options:
Expand Down Expand Up @@ -451,9 +409,6 @@ def getsockname(self):
def fix_addr(self, addr):
return ("localhost", None)

def get_server_name(self, ip):
return "localhost"


# Compatibility alias.
WSGIServer = TcpWSGIServer
29 changes: 0 additions & 29 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,35 +113,6 @@ def test_ctor_start_false(self):
inst = self._makeOneWithMap(_start=False)
self.assertEqual(inst.accepting, False)

def test_get_server_name_empty(self):
inst = self._makeOneWithMap(_start=False)
self.assertRaises(ValueError, inst.get_server_name, "")

def test_get_server_name_with_ip(self):
inst = self._makeOneWithMap(_start=False)
result = inst.get_server_name("127.0.0.1")
self.assertTrue(result)

def test_get_server_name_with_hostname(self):
inst = self._makeOneWithMap(_start=False)
result = inst.get_server_name("fred.flintstone.com")
self.assertEqual(result, "fred.flintstone.com")

def test_get_server_name_0000(self):
inst = self._makeOneWithMap(_start=False)
result = inst.get_server_name("0.0.0.0")
self.assertTrue(len(result) != 0)

def test_get_server_name_double_colon(self):
inst = self._makeOneWithMap(_start=False)
result = inst.get_server_name("::")
self.assertTrue(len(result) != 0)

def test_get_server_name_ipv6(self):
inst = self._makeOneWithMap(_start=False)
result = inst.get_server_name("2001:DB8::ffff")
self.assertEqual("[2001:DB8::ffff]", result)

def test_get_server_multi(self):
inst = self._makeOneWithMulti()
self.assertEqual(inst.__class__.__name__, "MultiSocketServer")
Expand Down

0 comments on commit 2f2972e

Please sign in to comment.