Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changed to allow IP ranges to be specified. #92

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
.idea/
DS_Store/
build/
dist/
docs/.build/
Expand Down
5 changes: 2 additions & 3 deletions example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def HandleAuthPacket(self, pkt):
self.SendReplyPacket(pkt.fd, reply)

def HandleAcctPacket(self, pkt):

print("Received an accounting request")
print("Attributes: ")
for attr in pkt.keys():
Expand All @@ -35,7 +34,6 @@ def HandleAcctPacket(self, pkt):
self.SendReplyPacket(pkt.fd, reply)

def HandleCoaPacket(self, pkt):

print("Received an coa request")
print("Attributes: ")
for attr in pkt.keys():
Expand All @@ -45,7 +43,6 @@ def HandleCoaPacket(self, pkt):
self.SendReplyPacket(pkt.fd, reply)

def HandleDisconnectPacket(self, pkt):

print("Received an disconnect request")
print("Attributes: ")
for attr in pkt.keys():
Expand All @@ -56,13 +53,15 @@ def HandleDisconnectPacket(self, pkt):
reply.code = 45
self.SendReplyPacket(pkt.fd, reply)


if __name__ == '__main__':

# create server and read dictionary
srv = FakeServer(dict=dictionary.Dictionary("dictionary"), coa_enabled=True)

# add clients (address, secret, name)
srv.hosts["127.0.0.1"] = server.RemoteHost("127.0.0.1", b"Kah3choteereethiejeimaeziecumi", "localhost")
srv.hosts["::1"] = server.RemoteHost("::1", b"Kah3choteereethiejeimaeziecumi", "localhost")
srv.BindToAddress("")

# start server
Expand Down
27 changes: 19 additions & 8 deletions pyrad/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,10 @@ def __init__(self, server, authport=1812, acctport=1813,
self._socket = None
self.retries = 3
self.timeout = 5
self._poll = select.poll()
if hasattr(select, 'poll'):
self._poll = select.poll()
else:
self._kqueue = select.kqueue()

def bind(self, addr):
"""Bind socket to an address.
Expand All @@ -73,15 +76,20 @@ def _SocketOpen(self):
except:
family = socket.AF_INET
if not self._socket:
self._socket = socket.socket(family,
socket.SOCK_DGRAM)
self._socket.setsockopt(socket.SOL_SOCKET,
socket.SO_REUSEADDR, 1)
self._poll.register(self._socket, select.POLLIN)
self._socket = socket.socket(family, socket.SOCK_DGRAM)
self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if hasattr(select, 'poll'):
self._poll.register(self._socket, select.POLLIN)
else:
ev = select.kevent(self._socket,
filter=select.KQ_FILTER_READ,
flags=select.KQ_EV_ADD | select.KQ_EV_ENABLE)
self._kqueue.control([ev], 0, 0)

def _CloseSocket(self):
if self._socket:
self._poll.unregister(self._socket)
if hasattr(select, 'poll'):
self._poll.unregister(self._socket)
self._socket.close()
self._socket = None

Expand Down Expand Up @@ -148,7 +156,10 @@ def _SendPacket(self, pkt, port):
self._socket.sendto(pkt.RequestPacket(), (self.server, port))

while now < waitto:
ready = self._poll.poll((waitto - now) * 1000)
if hasattr(select, 'poll'):
ready = self._poll.poll((waitto - now) * 1000)
else:
ready = self._kqueue.control([], 1, (waitto - now))

if ready:
rawreply = self._socket.recv(4096)
Expand Down
86 changes: 67 additions & 19 deletions pyrad/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pyrad import host
from pyrad import packet
import logging
from netaddr import IPNetwork, IPAddress


logger = logging.getLogger('pyrad')
Expand All @@ -18,7 +19,7 @@ class RemoteHost:
def __init__(self, address, secret, name, authport=1812, acctport=1813, coaport=3799):
"""Constructor.

:param address: IP address
:param address: IP address, CIDR or hostname of client(s)
:type address: string
:param secret: RADIUS secret
:type secret: string
Expand All @@ -31,7 +32,11 @@ def __init__(self, address, secret, name, authport=1812, acctport=1813, coaport=
:param coaport: port used for CoA packets
:type coaport: integer
"""
self.address = address
try:
# allow addresses to map to hostnames
self.address = socket.gethostbyname(address)
except socket.gaierror:
self.address = address
self.secret = secret
self.authport = authport
self.acctport = acctport
Expand Down Expand Up @@ -192,19 +197,32 @@ def HandleDisconnectPacket(self, pkt):
:type pkt: Packet class instance
"""

def _get_remote_host(self, ip):
"""Checks if an IP is in an IP range.
:param ip: an IPv4 or IPv6 address
:type ip: String
"""
all_hosts = None
ip_addr = IPAddress(ip)
for ref, host in self.hosts.items():
if ip_addr in IPNetwork(host.address):
return host
elif host.address in ['0.0.0.0', '0.0.0.0/0', '0/0']:
all_hosts = host
return all_hosts

def _AddSecret(self, pkt):
"""Add secret to packets received and raise ServerPacketError
for unknown hosts.

:param pkt: packet to process
:type pkt: Packet class instance
"""
if pkt.source[0] in self.hosts:
pkt.secret = self.hosts[pkt.source[0]].secret
elif '0.0.0.0' in self.hosts:
pkt.secret = self.hosts['0.0.0.0'].secret
remote_host = self._get_remote_host(pkt.source[0])
if remote_host:
pkt.secret = remote_host.secret
else:
raise ServerPacketError('Received packet from unknown host')
raise ServerPacketError('Received packet from unknown host {}'.format(pkt.source[0]))

def _HandleAuthPacket(self, pkt):
"""Process a packet received on the authentication port.
Expand Down Expand Up @@ -247,7 +265,6 @@ def _HandleCoaPacket(self, pkt):
:type pkt: Packet class instance
"""
self._AddSecret(pkt)
pkt.secret = self.hosts[pkt.source[0]].secret
if pkt.code == packet.CoARequest:
self.HandleCoaPacket(pkt)
elif pkt.code == packet.DisconnectRequest:
Expand Down Expand Up @@ -276,7 +293,13 @@ def _PrepareSockets(self):
"""
for fd in self.authfds + self.acctfds + self.coafds:
self._fdmap[fd.fileno()] = fd
self._poll.register(fd.fileno(), select.POLLIN | select.POLLPRI | select.POLLERR)
if hasattr(self, '_poll'):
self._poll.register(fd.fileno(), select.POLLIN | select.POLLPRI | select.POLLERR)
else:
ev = select.kevent(fd.fileno(),
filter=select.KQ_FILTER_READ,
flags=select.KQ_EV_ADD | select.KQ_EV_ENABLE)
self._kqueue.control([ev], 0, 0)
if self.auth_enabled:
self._realauthfds = list(map(lambda x: x.fileno(), self.authfds))
if self.acct_enabled:
Expand Down Expand Up @@ -321,16 +344,7 @@ def _ProcessInput(self, fd):
else:
raise ServerPacketError('Received packet for unknown handler')

def Run(self):
"""Main loop.
This method is the main loop for a RADIUS server. It waits
for packets to arrive via the network and calls other methods
to process them.
"""
self._poll = select.poll()
self._fdmap = {}
self._PrepareSockets()

def _poll_run(self):
while True:
for (fd, event) in self._poll.poll():
if event == select.POLLIN:
Expand All @@ -343,3 +357,37 @@ def Run(self):
logger.info('Received a broken packet: ' + str(err))
else:
logger.error('Unexpected event in server main loop')

def _kqueue_run(self):
while True:
revents = self._kqueue.control([], 1, None)
for event in revents:
if event.filter == select.KQ_FILTER_READ:
try:
fd = event.ident
fdo = self._fdmap[fd]
self._ProcessInput(fdo)
except ServerPacketError as err:
logger.info('Dropping packet: ' + str(err))
except packet.PacketError as err:
logger.info('Received a broken packet: ' + str(err))
else:
logger.error('Unexpected event in server main loop')


def Run(self):
"""Main loop.
This method is the main loop for a RADIUS server. It waits
for packets to arrive via the network and calls other methods
to process them.
"""
self._fdmap = {}

if hasattr(select, 'poll'):
self._poll = select.poll()
self._PrepareSockets()
self._poll_run()
else:
self._kqueue = select.kqueue()
self._PrepareSockets()
self._kqueue_run()
4 changes: 4 additions & 0 deletions pyrad/tests/testClient.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ def testIgnorePacketError(self):
self.assertRaises(Timeout, self.client._SendPacket, packet, 432)

def testValidReply(self):
# TODO: work out how to make test work for BSD
if not hasattr(select, 'POLLIN'): return
self.client.retries = 1
self.client.timeout = 1
self.client._socket = MockSocket(1, 2, six.b("valid reply"))
Expand All @@ -157,6 +159,8 @@ def testValidReply(self):
self.failUnless(reply is packet.reply)

def testInvalidReply(self):
# TODO: work out how to make test work for BSD
if not hasattr(select, 'POLLIN'): return
self.client.retries = 1
self.client.timeout = 1
self.client._socket = MockSocket(1, 2, six.b("invalid reply"))
Expand Down
14 changes: 6 additions & 8 deletions pyrad/tests/testServer.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,14 +149,13 @@ def testPrepareSocketAcctFds(self):
class AuthPacketHandlingTests(unittest.TestCase):
def setUp(self):
self.server = Server()
self.server.hosts['host'] = TrivialObject()
self.server.hosts['host'].secret = 'supersecret'
self.server.hosts['host'] = RemoteHost('127.0.0.1', 'supersecret', 'name')
self.packet = TrivialObject()
self.packet.code = AccessRequest
self.packet.source = ('host', 'port')
self.packet.source = ('127.0.0.1', 'port')

def testHandleAuthPacketUnknownHost(self):
self.packet.source = ('stranger', 'port')
self.packet.source = ('127.0.0.2', 'port')
try:
self.server._HandleAuthPacket(self.packet)
except ServerPacketError as e:
Expand Down Expand Up @@ -188,14 +187,13 @@ def HandleAuthPacket(self, pkt):
class AcctPacketHandlingTests(unittest.TestCase):
def setUp(self):
self.server = Server()
self.server.hosts['host'] = TrivialObject()
self.server.hosts['host'].secret = 'supersecret'
self.server.hosts['host'] = RemoteHost('10.0.0.0/24', 'supersecret', 'name')
self.packet = TrivialObject()
self.packet.code = AccountingRequest
self.packet.source = ('host', 'port')
self.packet.source = ('10.0.0.1', 'port')

def testHandleAcctPacketUnknownHost(self):
self.packet.source = ('stranger', 'port')
self.packet.source = ('10.1.0.1', 'port')
try:
self.server._HandleAcctPacket(self.packet)
except ServerPacketError as e:
Expand Down