#!/usr/bin/ruby
#
# DNS Balance --- 動的負荷分散を行なう DNS サーバ
#
# By: YOKOTA Hiroshi <yokota@netlab.is.tsukuba.ac.jp>

# $Id: dns_balance.rb,v 1.25 2003/06/13 22:07:27 elca Exp $

# DNS Balance の存在するパス名
$LOAD_PATH.unshift("%%PREFIX%%/etc/%%PORTNAME%%", "%%PREFIX%%/lib/%%PORTNAME%%")
$LOAD_PATH.freeze

require 'socket'
require 'thread'
require 'optparse'

require 'datatype.rb'
require 'multilog.rb'
require 'log_writer.rb'
require 'util.rb'
require 'as_search.rb'

require 'namespace.rb'
require 'addrdb.rb'

#####################################################################
# User-defined exceptions raised while a query is being handled.

# The query asks for an operation this server does not provide.
class DnsNotImplementedError < StandardError
end

# The reply is too long for a single packet.
class DnsTruncatedError < StandardError
end

# The query cannot be answered.
class DnsNoQueryError < StandardError
end

# No address is available for the requested name.
class DnsNoMoreResourceError < StandardError
end

# The name is known but the requested record type is not served.
class DnsNoErrorError < StandardError
end

# Do not resolve peer addresses back to host names.
Socket.do_not_reverse_lookup = true

###################################################################
# 関数

# Extract the question name, type and class from a raw DNS query packet.
#
# packet:: the query as a string of bytes
#
# Returns [q, q_type, q_class]: q is the name in wire format, up to and
# including its first NUL byte; q_type and q_class are the two bytes that
# follow it, each as a string.
# Returns [nil, nil, nil] after logging when the question count is not 1
# or when the packet carries no question at all.
def parse_packet(packet)
  # Six 2-byte header fields followed by the rest of the packet.
  # Only the question count (third field) is examined.
  (_id, _flags, num_q, _ans_rr, _ort_rr, _add_rr, str) =
    packet.unpack("a2 a2 a2 a2 a2 a2 a*")

  if num_q != "\0\1"
    ML.log("Q must be 1")
    return [nil, nil, nil]
  end

  if str.empty?
    ML.log("empty question")
    return [nil, nil, nil]
  end

  # The name ends at the first NUL byte.  When there is none the whole
  # remainder is taken as the name and type/class come back empty.
  nul_at = str.index("\0")
  qlen = (nul_at || str.length) + 1

  return str.unpack("a#{qlen} a2 a2")
end

# Turn the peer description +cli+ (family, port, host name, IP address,
# in that order) into a hash keyed by "family", "port", "fqdn" and "addr".
def get_client_data(cli)
  family, port, fqdn, ipaddr = cli

  info = {}
  info["family"] = family
  info["port"]   = port
  info["fqdn"]   = fqdn
  info["addr"]   = ipaddr
  info
end

# Choose the namespace used to answer +name+ for the client at +addrstr+,
# so that the reply can depend on the client's IP address.
#
# Candidates are tried in this order and the first one whose table in
# $addr_db contains +name+ wins:
#   1. a custom namespace mapped in $namespace_db from the client address
#      or one of its masked forms,
#   2. a namespace named after the client address or one of its masked forms,
#   3. the namespace returned by as_search, only when the "as" option is on.
# Returns "default" when nothing matches.
def select_namespace(addrstr, name)
  candidates = [addrstr] + ip_masklist(addrstr)

  # 1. custom namespace
  candidates.each do |net|
    custom = $namespace_db[net]
    next if custom.nil?

    records = $addr_db[custom]
    next if records.nil? || records[name].nil?

    return custom
  end

  # 2. address number namespace
  candidates.each do |net|
    records = $addr_db[net]
    return net unless records.nil? || records[name].nil?
  end

  # 3. AS namespace
  if OPT["as"]
    # RFC1918 private addresses do not belong to any AS.  204.152.184.0/21
    # and the loopback address are skipped by the same test.
    no_as = ip_mask(addrstr,  8) == "10.0.0.0"      ||
            ip_mask(addrstr, 12) == "172.16.0.0"    ||
            ip_mask(addrstr, 16) == "192.168.0.0"   ||
            ip_mask(addrstr, 21) == "204.152.184.0" ||
            addrstr              == "127.0.0.1"

    unless no_as
      as = as_search(addrstr)
      unless as.nil?
        records = $addr_db[as]
        return as unless records.nil? || records[name].nil?
      end
    end
  end

  "default"
end

# Build the table used for weighted random selection.
#
# Each entry of $addr_db[namespace][name] contributes a weight of
# 10000 minus its badness (entry[1]); badness is capped at 10000.
# Returns [total, limits] where limits[k] is the running sum of the
# weights of entries 0..k and total is the sum of all weights.
def make_rand_array(namespace, name)
  total = 0

  limits = $addr_db[namespace][name].map do |entry|
    total += 10000 - min(10000, entry[1])
  end

  [total, limits]
end

# Draw +size+ entries for +name+ in +namespace+ by weighted random choice.
#
# Returns an array of +size+ indices into the table that make_rand_array
# was built from (the same index may appear more than once), or [] when
# the total weight is 0, i.e. every host has a badness of 10000.
def select_rand_array(namespace, name, size)
  (rnd_max, rnd_limits) = make_rand_array(namespace, name)

  return [] if rnd_max == 0

  arr = []
  size.times {
    # rand(rnd_max) is in 0...rnd_max and entry j owns the half-open range
    # from the previous limit up to, but not including, its own limit.
    # The comparison must be strict: with "<=" an entry of weight 0 at the
    # head of the table was chosen whenever the draw was 0, and each
    # boundary value went to the wrong entry.
    rnd = rand(rnd_max)
    rnd_limits.each_with_index { |limit, j|
      if rnd < limit
        arr.push(j)
        break
      end
    }
  }

  return arr
end

# Validate a parsed question.
#
# Raises DnsNotImplementedError for a zone-transfer (AXFR) request and
# DnsNoQueryError when the class is neither INET nor ANY or when the name
# contains a forbidden character.  Every rejection is logged first.
# Returns nil when the question is acceptable.
def check_packet(q, q_type, q_class)
  # Built only when something is actually logged.
  detail = lambda { q.dump + ":" + q_type.dump + ":" + q_class.dump }

  # zone transfers are not provided
  if q_type == DnsType::AXFR
    ML.log("AXFR: " + detail.call)
    raise DnsNotImplementedError
  end

  # only the Internet class is served
  unless q_class == DnsClass::INET || q_class == DnsClass::ANY
    ML.log("noIP: " + detail.call)
    raise DnsNoQueryError
  end

  # the name contains a character that may not be used
  # NOTE(review): if q is the raw wire-format name, a label length byte can
  # coincide with one of these characters -- confirm against the caller.
  unless (q =~ /[()<>@,;:\\\"\.\[\]]/).nil?
    ML.log("char: " + detail.call)
    raise DnsNoQueryError
  end

  nil
end

# Check that the requested record type can be served.
#
# A and ANY queries pass and nil is returned.  For any other type the
# request is logged and rejected: DnsNoErrorError when +namespace+ has an
# entry for the name (the record exists, only its type differs),
# DnsNoQueryError otherwise.
def check_type(q, q_type, q_class, namespace)
  return nil if q_type == DnsType::A || q_type == DnsType::ANY

  records = $addr_db[namespace]
  known = !records.nil? && !records[dnsstr_to_str(q).downcase].nil?

  detail = q.dump + ":" + q_type.dump + ":" + q_class.dump

  if known
    # the record exists but its type differs
    ML.log("noT: " + detail)
    raise DnsNoErrorError
  end

  # only A/ANY records are served
  ML.log("noA: " + detail)
  raise DnsNoQueryError
end

######################################################################
# main

srand

# Command-line options, keyed by option letter ("i", "l", "p") or "as".
# The hash is frozen once parsing is done.
OPT = {}
OptionParser.new do |parser|
  parser.on("-i ADDR", String, "Listen IP address (default:0.0.0.0)") { |addr| OPT["i"] = addr }
  parser.on("--as", "Enable AS namespace") { OPT["as"] = true }
  parser.on("-l LOGFILE", String, "Print log to LOGFILE") { |path| OPT["l"] = path }
  parser.on("-p PIDFILE", String, "Record PID to PIDFILE") { |path| OPT["p"] = path }

  # Help goes to standard error and ends the program with status 111.
  parser.on_tail("-h", "--help", "Show this help message and exit") do
    STDERR.printf("%s", parser.to_s)
    exit(111)
  end

  parser.parse!
end
OPT.freeze

# Detach from the terminal and keep running in the background.
exit! if fork      # the original process ends; the child carries on
Process::setsid    # start a new session
exit! if fork      # the session leader ends; the grandchild is the daemon

# Point the standard streams at /dev/null instead of closing them.  With
# descriptors 0-2 closed, the next files or sockets opened would be given
# those numbers, and any stray write to STDOUT/STDERR would raise IOError.
STDIN.reopen("/dev/null")
STDOUT.reopen("/dev/null", "w")
STDERR.reopen("/dev/null", "w")

# Record the process id when -p PIDFILE was given.
$pidfile = nil
if OPT["p"]
  $pidfile = OPT["p"]
  File::open($pidfile, 'w') { |f| f.puts $$ }
end

# Open the log file once when -l LOGFILE was given.  The same handle is
# closed by the signal handler and handed to the logger below.
$logout = nil
if OPT["l"]
  $logout = File::open(OPT["l"], 'a+')
  $logout.sync = true
end

# Remove the PID file and close the log on exit or on a signal.
# 0 is the EXIT pseudo-signal, so this handler runs again when one of the
# other handlers calls exit; the cleanup therefore has to be safe to repeat.
# NOTE(review): the remaining numbers name different signals on different
# platforms -- confirm them against the target system.
[0, 2, 3, 5, 10, 13, 15].each do |sig|
  trap(sig) {
    pidfile, $pidfile = $pidfile, nil
    File::unlink(pidfile) if pidfile
    $logout.close if $logout && !$logout.closed?
    exit
  }
end

ML = MultiLog.new
if $logout
  ML.open($logout)
else
  ML.open
end

ML.log("start")


#
# Dynamic update of the address database
#
Thread.start do
  addr_file = "%%ETCDIR%%" + "/addr"

  loop do
    if File.readable?(addr_file)
      begin
        load("addr")

        ML.log("reload")
      rescue Exception
        # Every failure is caught, including SyntaxError, which is not a
        # StandardError, so that a broken file never stops the reloads.
        ML.log("reload failed")
      end
    end

    #if test(?r, "%%ETCDIR%%" + "/addr-once")
    #  Thread.exit
    #end

    #p $addr_db
    sleep(300) # reload every 5 minutes
  end
end

# UDP socket that receives the queries: bound to the address given with
# -i, or to every local address when the option is absent.
gs = UDPSocket.new
sockaddr = OPT["i"]
sockaddr = Socket::INADDR_ANY if sockaddr.nil?
gs.bind(sockaddr, Service::Domain)

#
# Main loop
#
# NOTE: the byte arithmetic on r[2] and r[3] below relies on String#[]
# with an integer index returning the byte value as a number.

# Build a reply that carries no resource records: the 12-byte header of the
# query, the echoed question when one was parsed, and the response code.
#
# authoritative:: true sets the authoritative-answer flag, false clears it
# rcode::         value stored in the fourth header byte
#
# +q+ may be nil (the question could not be parsed); the reply then holds
# the header alone and its question count is set to 0.
def make_empty_reply(packet, q, q_type, q_class, authoritative, rcode)
  r = packet[0,12]
  if q == nil
    r[4,2] = "\0\0"
  else
    r += q + q_type.to_s + q_class.to_s
  end

  r[2] |= 0x80              # answer
  if authoritative
    r[2] |= 0x04            # authenticated
  else
    r[2] &= ~0x04           # not authenticated
  end
  r[3] = rcode
  r[6,6] = "\0\0\0\0\0\0"   # no answer, authority or additional records
  return r
end

loop {
  (packet, client) = gs.recvfrom(1024)

  # No reply can be built for a datagram that lacks a complete header.
  if packet.length < 12
    ML.log("short packet")
    next
  end

  Thread.start {
    $SAFE = 2
    client_data = get_client_data(client)
    begin
      (q, q_type, q_class) = parse_packet(packet)
      check_packet(q, q_type, q_class) # -> NoQuery, NotImpl

      name      = dnsstr_to_str(q).downcase
      namespace = select_namespace(client_data["addr"], name)

      check_type(q, q_type, q_class, namespace)

      # An unknown name gives nil here; the NoMethodError raised by .size
      # is answered as NoQuery below.
      if $addr_db[namespace][name].size > 1
        size = 8
      else
        size = 1
      end
      a_array = select_rand_array(namespace, name, size) #.uniq

      if a_array.size == 0
        raise DnsNoMoreResourceError
      end

      # build the reply
      r = packet[0,12] + q + q_type + q_class
      r[2] |= 0x84  # answer & authenticated
      r[3] = 0      # no error

      # TTL choice: when a single address is returned use a long TTL.
      # It is the same for every answer, so it is chosen once.
      if a_array.size == 1
        ttl = "\0\0\x0e\x10" # 1 hour
      else
        ttl = "\0\0\0\x3c"   # 60 seconds
      end

      ans_addrs = []
      a_array.each {
        |i|
        addr = $addr_db[namespace][name][i][0]
        ans_addrs.push(addr)   # for the log

        # One answer record; the name is a pointer to offset 0x000c.
        r += "\xc0\x0c" + DnsType::A + DnsClass::INET + ttl + "\0\4" + addr.pack("CCCC")
      }

      # set the number of answers
      r[6,2] = [a_array.size].pack("n")
      r[8,4] = "\0\0\0\0"

      # cut the reply when it is too long
      if r.length > 512
        raise DnsTruncatedError
      end

      status = "ok"

    rescue DnsNotImplementedError
      r = make_empty_reply(packet, q, q_type, q_class, false, 0x04) # not implemented error
      status = "NotImpl"

    rescue DnsTruncatedError
      # too long: cut the reply and raise the truncation flag
      r = r[0,512]
      r[2] |= 0x02
      status = "Truncated"

    rescue DnsNoErrorError
      r = make_empty_reply(packet, q, q_type, q_class, true, 0) # no error
      status = "NoError"

    rescue DnsNoQueryError,DnsNoMoreResourceError,NoMethodError,StandardError
      # Also reached with q == nil when the question could not be parsed;
      # make_empty_reply copes with that.
      r = make_empty_reply(packet, q, q_type, q_class, true, 0x03) # name error
      status = "NoQuery"

    rescue
      # Should never be reached: StandardError is already handled above.
      r = make_empty_reply(packet, q, q_type, q_class, false, 0x05) # query refused error
      status = "other"
    end

    #print packet.dump, "\n"
    #print r.dump, "\n"
    #p q

    gs.send(r, 0, client_data["addr"], client_data["port"])

    logger(ML, client_data["addr"], status, name, namespace, ans_addrs)

  }
}

# end


# syntax highlighted by Code2HTML, v. 0.9.1