""" LookupManager.py

Module that wraps the DNS resolver.
"""
__copyright__ = "Copyright (c) 2002-2005 Free Software Foundation, Inc."
__license__ = """ GNU General Public License

This program is free software; you can redistribute it and/or modify it under the
terms of the GNU General Public License as published by the Free Software
Foundation; either version 2 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT
ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.

You should have received a copy of the GNU General Public License along with
this program; if not, write to the Free Software Foundation, Inc., 59 Temple
Place - Suite 330, Boston, MA 02111-1307, USA. """


import time
import socket
import random
import error
import Config
try:
    import adns
    import ADNS
    # Both ADNS modules imported cleanly, so the ADNS back end is usable.
    # (This flag was previously set to False here as well, which made the
    # ADNS code path unreachable.)
    _have_adns = True
except ImportError:
    error.log("No ADNS library found, using synchronous name lookups.")
    _have_adns = False

# Upper bound on how many times ADNSLookupManager re-submits a lookup that
# came back without an address.
MAX_TIMES = 3

class NameFormatException(Exception):
    """Raised when a host name is unusable for a lookup (e.g. not ASCII)."""

class CBWrapper:
    """Bundle a pending lookup's host name, user callback and user data.

    The lookup managers pass an instance of this class through their
    resolver back end and call cb() once the lookup has finished.
    """
    def __init__(self, name, callback, data, times = 0):
        self.name = name
        self.data = data
        self.callback = callback
        # Retry counter (compared against MAX_TIMES by the ADNS manager).
        # Honour the caller-supplied starting value instead of always
        # resetting it to 0.
        self.times = times

    def cb(self, ip):
        """Invoke the user callback as callback(name, ip, data)."""
        self.callback(self.name, ip, self.data)

# This is the threaded version: lookups are done in a worker thread,
# and ThreadLookupManager communicates with it via two Queue objects.
import Queue
import threading

class LookupThread(threading.Thread):
    """Daemon worker thread that performs blocking host name lookups.

    It repeatedly takes a (name, cbw) pair from namequeue, resolves name
    with socket.gethostbyname() and puts (cbw, ip) on ipqueue. ip is None
    when the name could not be resolved.
    """
    def __init__(self, namequeue, ipqueue, group=None, target=None,
                 name=None, *args, **kwargs):
        threading.Thread.__init__(self, group, target, name, args, kwargs)
        self._namequeue = namequeue
        self._ipqueue = ipqueue
        # run() never returns, so the thread must not keep the interpreter
        # alive at shutdown.
        self.setDaemon(True)

    def run(self):
        while True:
            name, cbw = self._namequeue.get(True)
            try:
                ip = socket.gethostbyname(name)
            except (socket.error, UnicodeError):
                # UnicodeError comes from IDNA-encoding malformed names
                # (e.g. empty labels such as "a..b"). Letting it escape
                # would kill this thread, and every later lookup would then
                # wait forever for an answer.
                ip = None
            self._ipqueue.put((cbw, ip))

class ThreadLookupManager:
    """Lookup manager that resolves names on a background LookupThread.

    lookup() answers immediately from the cache when it has a fresh entry,
    otherwise it queues the request for the worker thread. poll() must be
    called periodically to deliver finished lookups to their callbacks.
    """
    NameFormatException = NameFormatException
    # Class attribute, shared by all instances: name -> (ip, expiry time).
    namecache = {}

    def __init__(self, mintime=86400, maxtime=259200):
        # A successful answer is cached for a random number of seconds in
        # [mintime, maxtime].
        self.mintime = mintime
        self.maxtime = maxtime
        self._namequeue = Queue.Queue(0)
        self._ipqueue = Queue.Queue(0)
        self._queryengine = LookupThread(
            namequeue=self._namequeue, ipqueue=self._ipqueue)
        self._queryengine.start()

    def _lookup(self, name, cbw):
        """Call cbw.cb() from a fresh cache entry, or queue name for lookup."""
        cached = self.namecache.get(name)
        if cached is not None:
            ip, expiretime = cached
            if time.time() < expiretime:
                cbw.cb(ip)
                return
        self._namequeue.put((name, cbw))

    def lookup(self, name, callback, data=None):
        """Resolve name; callback(name, ip, data) is called later or now."""
        cbw = CBWrapper(name, callback, data)
        self._lookup(name, cbw)

    def poll(self, timeout=0.1):
        """Deliver finished lookups to their callbacks.

        Stops when no more results are waiting or once more than timeout
        seconds have elapsed since the call started.
        """
        t1 = time.time()
        while True:
            try:
                cbw, ip = self._ipqueue.get_nowait()
            except Queue.Empty:
                break
            if ip is not None:
                # Remember successful answers so _lookup() can short-circuit
                # later requests. Previously nothing ever filled the cache.
                # Failures are not cached, matching the ADNS manager.
                self.namecache[cbw.name] = (
                    ip, int(time.time()) + random.randint(self.mintime,
                                                          self.maxtime))
            cbw.cb(ip)
            if time.time() - t1 > timeout:
                break
            
# The ADNS version
if _have_adns:
    class MyQE(ADNS.QueryEngine):
        """ADNS query engine with a helper for submitting A-record queries."""
        def lookup_a_record(self, name, callback, extra):
            self.submit(name, adns.rr.A, callback = callback, extra = extra)

    class ADNSLookupManager:
        """Lookup manager that resolves names through the ADNS query engine.

        poll() must be called periodically to run the query engine, which
        hands each answer to adns_callback().
        """
        # Exposed like on the other lookup managers, so callers can catch
        # manager.NameFormatException regardless of the back end.
        NameFormatException = NameFormatException
        # Class attributes, shared by all instances.
        namecache = {}          # name -> (ip, expiry time)
        queryengine = MyQE()

        def __init__(self, mintime=86400, maxtime=259200):
            # A successful answer is cached for a random number of seconds
            # in [mintime, maxtime].
            self.mintime = mintime
            self.maxtime = maxtime

        def _lookup(self, name, callback, cbw):
            """Call cbw.cb() from a fresh cache entry, or submit an A query.

            Raises NameFormatException if name is not ASCII.
            """
            try:
                name = name.encode('ascii')
            except UnicodeError:
                raise NameFormatException("Host names must be ASCII")
            cached = self.namecache.get(name)
            if cached is not None:
                ip, expiretime = cached
                if time.time() < expiretime:
                    cbw.cb(ip)
                    return
            self.queryengine.lookup_a_record(name, callback, cbw)

        def lookup(self, name, callback, data=None):
            """Resolve name; callback(name, ip, data) is called later or now."""
            cbw = CBWrapper(name, callback, data)
            self._lookup(name, self.adns_callback, cbw)

        def adns_callback(self, answer, qname, rr, flags, cbwrapper):
            """Handle the ADNS answer for the query made on cbwrapper's behalf.

            This code treats answer[1] as a CNAME target and answer[3] as the
            list of addresses. NOTE(review): confirm this layout against the
            ADNS bindings.
            """
            now = int(time.time())
            ips = answer[3]
            ip = None
            if answer[1] and not ips:
                # we got a cname even though we asked for an a record
                self._lookup(answer[1], self.adns_callback, cbwrapper)
                return
            elif ips:
                ip = ips[0]
                entry = (ip, now + random.randint(self.mintime, self.maxtime))
                self.namecache[qname] = entry
                # After a CNAME was followed, qname is the alias target.
                # Cache under the name the caller asked for as well, or every
                # later lookup of that name would miss the cache.
                self.namecache[cbwrapper.name] = entry
            else:
                # No address: re-submit the original name, up to MAX_TIMES.
                if cbwrapper.times < MAX_TIMES:
                    cbwrapper.times += 1
                    self._lookup(cbwrapper.name, self.adns_callback, cbwrapper)
                    return
            cbwrapper.cb(ip)

        def poll(self, timeout=0.1):
            """Run the ADNS query engine for up to timeout seconds."""
            self.queryengine.run(timeout)

# Blocking version: use normal lookups, no threads
class BlockingLookupManager:
    """Lookup manager that resolves names synchronously inside lookup()."""
    NameFormatException = NameFormatException
    # Class attribute, shared by all instances: name -> (ip, expiry time).
    namecache = {}

    def __init__(self, mintime=86400, maxtime=259200):
        # A successful answer is cached for a random number of seconds in
        # [mintime, maxtime].
        self.mintime = mintime
        self.maxtime = maxtime

    def _lookup(self, name):
        """Return the address for name from the cache or the resolver.

        Returns None when socket.gethostbyname() fails.
        """
        cached = self.namecache.get(name)
        if cached is not None:
            ip, expiretime = cached
            if time.time() < expiretime:
                return ip
        try:
            ip = socket.gethostbyname(name)
        except socket.error:
            # Failures are not cached (matching the ADNS manager). Caching
            # them for mintime..maxtime seconds would let one transient
            # resolver error make the host unresolvable for days.
            return None
        self.namecache[name] = (ip, int(time.time()) +
                                random.randint(self.mintime, self.maxtime))
        return ip

    def lookup(self, name, callback, data=None):
        """Resolve name and call callback(name, ip, data) before returning.

        Raises NameFormatException if name is not ASCII.
        """
        try:
            name = name.encode('ascii')
        except UnicodeError:
            raise NameFormatException("Host names must be ASCII")
        ip = self._lookup(name)
        callback(name, ip, data)

    def poll(self, timeout=0.1):
        """Do nothing: lookups already complete inside lookup()."""
        pass

lookupmanager_instance = None

def get_instance():
    """Return the process-wide lookup manager, creating it on first use.

    The threaded manager is used when the configuration enables threads.
    Otherwise the ADNS manager is used if ADNS is available, and the
    blocking manager as a last resort.
    """
    global lookupmanager_instance
    if lookupmanager_instance is not None:
        return lookupmanager_instance
    if Config.get_instance().use_threads:
        manager_class = ThreadLookupManager
    elif _have_adns:
        manager_class = ADNSLookupManager
    else:
        manager_class = BlockingLookupManager
    lookupmanager_instance = manager_class()
    return lookupmanager_instance



# syntax highlighted by Code2HTML, v. 0.9.1