# Copyright (c) 2001-2006 Twisted Matrix Laboratories.
# See LICENSE for details.
import random
from zope.interface import implements
from twisted.internet import error, interfaces
from twisted.names import client
class _SRVConnector_ClientFactoryWrapper:
    """Proxy a client factory so it reports the owning connector.

    startedConnecting() is passed through to the wrapped factory, but with
    the owning connector substituted for the low-level one.
    clientConnectionFailed()/clientConnectionLost() are not passed to the
    wrapped factory at all; they are routed to the owning connector's
    connectionFailed()/connectionLost().  Any other attribute lookup that
    the wrapper itself cannot satisfy is delegated to the wrapped factory.
    """

    def __init__(self, connector, wrappedFactory):
        # connector: the object whose connectionFailed/connectionLost get the
        # notifications; wrappedFactory: the factory being proxied.
        self.__connector, self.__wrappedFactory = connector, wrappedFactory

    def __getattr__(self, key):
        # Only invoked for names not found on the wrapper itself.
        target = self.__wrappedFactory
        return getattr(target, key)

    def startedConnecting(self, connector):
        # The low-level connector argument is deliberately ignored.
        owner = self.__connector
        self.__wrappedFactory.startedConnecting(owner)

    def clientConnectionFailed(self, connector, reason):
        owner = self.__connector
        owner.connectionFailed(reason)

    def clientConnectionLost(self, connector, reason):
        owner = self.__connector
        owner.connectionLost(reason)
class SRVConnector:
    """A connector that looks up DNS SRV records. See RFC2782.

    connect() resolves C{_<service>._<protocol>.<domain>} with
    C{client.lookupService}, picks one of the returned targets (see
    pickServer), and connects to it by calling the reactor method named by
    C{connectFuncName}.  The factory is handed to that method wrapped in
    _SRVConnector_ClientFactoryWrapper, so the factory's callbacks receive
    this SRVConnector rather than the underlying connector.
    """

    implements(interfaces.IConnector)

    # Set by stopConnecting() when no underlying connector exists yet; the
    # next _reallyConnect() (run after the DNS lookup) then clears it and
    # returns without connecting.
    stopAfterDNS = 0

    def __init__(self, reactor, service, domain, factory,
                 protocol='tcp', connectFuncName='connectTCP',
                 connectFuncArgs=(),
                 connectFuncKwArgs=None,
                 ):
        """
        @param reactor: Object providing the method named by
            C{connectFuncName}.
        @param service: Service label of the SRV query; the leading
            underscore is added by connect().
        @param domain: Domain to query.  If C{None}, connect() fails
            immediately with error.DNSLookupError.
        @param factory: Client factory for the connection.
        @param protocol: Protocol label of the SRV query; the leading
            underscore is added by connect().
        @param connectFuncName: Name of the reactor method used to connect,
            called as C{method(host, port, wrappedFactory, *connectFuncArgs,
            **connectFuncKwArgs)}.
        @param connectFuncArgs: Extra positional arguments for that method.
        @param connectFuncKwArgs: Extra keyword arguments for that method;
            C{None} (the default) means no extra keyword arguments.
        """
        if connectFuncKwArgs is None:
            # Fresh dict per instance rather than one shared mutable default.
            connectFuncKwArgs = {}
        self.reactor = reactor
        self.service = service
        self.domain = domain
        self.factory = factory

        self.protocol = protocol
        self.connectFuncName = connectFuncName
        self.connectFuncArgs = connectFuncArgs
        self.connectFuncKwArgs = connectFuncKwArgs

        # Connector returned by the reactor for the current target, if any.
        self.connector = None
        # Records (priority, weight, host, port) not yet tried in this round.
        self.servers = None
        # list of servers already used in this round
        self.orderedServers = None

    def connect(self):
        """Start connection to remote server.

        Calls the factory's doStart() and startedConnecting() on every call.
        If no untried records remain (including before the first lookup), a
        new SRV lookup is made and a server is picked from its result; any
        failure in that chain is reported through connectionFailed().  If
        records remain but no underlying connector exists yet, a server is
        picked directly; otherwise the existing underlying connector is asked
        to connect again.
        """
        self.factory.doStart()
        self.factory.startedConnecting(self)

        if not self.servers:
            if self.domain is None:
                self.connectionFailed(
                    error.DNSLookupError("Domain is not defined."))
                return
            d = client.lookupService('_%s._%s.%s' % (self.service,
                                                     self.protocol,
                                                     self.domain))
            d.addCallback(self._cbGotServers)
            d.addCallback(lambda ignored: self._reallyConnect())
            # Catches errors from the lookup, _cbGotServers and _reallyConnect.
            d.addErrback(self.connectionFailed)
        elif self.connector is None:
            self._reallyConnect()
        else:
            self.connector.connect()

    def _cbGotServers(self, result):
        """Store the SRV answers from a lookupService() result.

        @param result: C{(answers, authority, additional)} tuple delivered by
            the lookup Deferred; only C{answers} is used.
        @raise error.DNSLookupError: If the single answer's target is the
            root name, which RFC 2782 uses to say the service is decidedly
            not available.
        """
        answers, auth, add = result
        # str() is used (as when storing targets below) so the check does not
        # depend on how the target object compares to a plain string.
        # NOTE(review): the root name is assumed to render as '.' or '' --
        # confirm against twisted.names.dns.Name.
        if len(answers) == 1 and str(answers[0].payload.target) in ('.', ''):
            # decidedly not available
            raise error.DNSLookupError("Service %s not available for domain %s."
                                       % (repr(self.service), repr(self.domain)))

        # Everything starts as "already used", so the first pickServer() call
        # begins a fresh round containing all records.
        self.servers = []
        self.orderedServers = [
            (a.payload.priority, a.payload.weight,
             str(a.payload.target), a.payload.port)
            for a in answers]

    def _serverCmp(self, a, b):
        """Order records by priority, then by weight (cmp-style result)."""
        if a[0] != b[0]:
            return cmp(a[0], b[0])
        else:
            return cmp(a[1], b[1])

    def pickServer(self):
        """Select the next server to try and return C{(host, port)}.

        Requires a completed lookup (both record lists initialised).  With no
        SRV records at all, falls back to C{(domain, service)}, passing the
        service name as the port.  Otherwise, when the current round is
        exhausted a new round starts with every record.  Within a round the
        record is chosen from the lowest-priority group by weighted random
        selection (RFC 2782), and is then moved to the "used" list.
        """
        assert self.servers is not None
        assert self.orderedServers is not None

        if not self.servers and not self.orderedServers:
            # no SRV record, fall back..
            return self.domain, self.service

        if not self.servers and self.orderedServers:
            # start new round
            self.servers = self.orderedServers
            self.orderedServers = []

        assert self.servers

        # Same ordering as _serverCmp: by priority, then ascending weight,
        # which puts zero-weight records first as RFC 2782 asks.
        self.servers.sort(key=lambda server: (server[0], server[1]))
        minPriority = self.servers[0][0]

        # After sorting, the lowest-priority group is a prefix of the list;
        # record each member's index with the running weight sum.
        runningSums = []
        weightSum = 0
        for index, server in enumerate(self.servers):
            if server[0] != minPriority:
                break
            weightSum += server[1]
            runningSums.append((index, weightSum))

        # Pick the first record whose running sum reaches the random value
        # (0..weightSum inclusive), so heavier records are chosen more often.
        rand = random.randint(0, weightSum)
        for index, runningSum in runningSums:
            if runningSum >= rand:
                chosen = self.servers[index]
                del self.servers[index]
                self.orderedServers.append(chosen)

                p, w, host, port = chosen
                return host, port

        # Unreachable: the last running sum equals weightSum >= rand.
        raise RuntimeError('Impossible %s pickServer result.' % self.__class__.__name__)

    def _reallyConnect(self):
        """Pick a server and start the underlying connection attempt."""
        if self.stopAfterDNS:
            # stopConnecting() was called before a connector existed.
            self.stopAfterDNS = 0
            return

        self.host, self.port = self.pickServer()
        assert self.host is not None, 'Must have a host to connect to.'
        assert self.port is not None, 'Must have a port to connect to.'

        connectFunc = getattr(self.reactor, self.connectFuncName)
        self.connector = connectFunc(
            self.host, self.port,
            _SRVConnector_ClientFactoryWrapper(self, self.factory),
            *self.connectFuncArgs, **self.connectFuncKwArgs)

    def stopConnecting(self):
        """Stop attempting to connect."""
        if self.connector:
            self.connector.stopConnecting()
        else:
            # No underlying connector yet: cancel the post-lookup connect.
            self.stopAfterDNS = 1

    def disconnect(self):
        """Disconnect whatever our state is."""
        if self.connector is not None:
            self.connector.disconnect()
        else:
            self.stopConnecting()

    def getDestination(self):
        """Return the underlying connector's destination (must be connected)."""
        assert self.connector
        return self.connector.getDestination()

    def connectionFailed(self, reason):
        """Report a failed attempt to the factory, then call its doStop()."""
        self.factory.clientConnectionFailed(self, reason)
        self.factory.doStop()

    def connectionLost(self, reason):
        """Report a lost connection to the factory, then call its doStop()."""
        self.factory.clientConnectionLost(self, reason)
        self.factory.doStop()
# syntax highlighted by Code2HTML, v. 0.9.1