#    Copyright (c) 2003, Nullcube Pty Ltd 
#    All rights reserved.
#
#    Redistribution and use in source and binary forms, with or without
#    modification, are permitted provided that the following conditions are met:
#
#    *   Redistributions of source code must retain the above copyright notice, this
#        list of conditions and the following disclaimer.
#    *   Redistributions in binary form must reproduce the above copyright notice,
#        this list of conditions and the following disclaimer in the documentation
#        and/or other materials provided with the distribution.
#    *   Neither the name of Nullcube nor the names of its contributors may be used to
#        endorse or promote products derived from this software without specific
#        prior written permission.
#
#    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
#    ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
#    WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
#    DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
#    ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
#    (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
#    LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
#    ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
#    (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
#    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

def multiord(x):
    """
        Like ord(), but takes multiple characters: interpret x as a
        big-endian base-256 number (one digit per character, most
        significant first) and return its integer value. An empty string
        yields 0.
    """
    value, weight = 0, 1
    # Walk from the least significant (rightmost) character upwards.
    for ch in reversed(x):
        value += ord(ch) * weight
        weight *= 256
    return value


def multichar(a, width):
    """
        Like chr(), but takes a large integer that could fill many bytes,
        and returns a string: the big-endian base-256 representation of a,
        one character per byte.

        The return string is left-padded with "\\0" characters so that it is
        exactly "width" characters long.

        Raises ValueError if a is negative, or if it needs more than
        "width" bytes.
    """
    if a < 0:
        # Floor division never reaches 0 for a negative number, so without
        # this guard the digit loop below would never terminate.
        raise ValueError("Number must be non-negative.")
    chars = []
    while a:
        # divmod is floor division on both Python 2 and 3 (plain "/" is
        # true division on Python 3).
        a, digit = divmod(a, 256)
        chars.append(chr(digit))
    if len(chars) > width:
        raise ValueError("Number too wide for width.")
    # Digits were produced least-significant first; append+reverse avoids
    # the quadratic cost of repeated insert(0, ...).
    chars.reverse()
    return "\0" * (width - len(chars)) + "".join(chars)


def cksum16(data):
    """
        Calculates the 16-bit Internet checksum (RFC 1071) across data.

        This is the one's complement of the one's complement sum of the
        big-endian 16-bit words of data. It is a one's-complement sum, not a
        CRC. An odd trailing byte is treated as if padded with a zero byte.
        Empty data yields 0xFFFF.
    """
    total = 0
    # Sum every complete big-endian 16-bit word.
    for i in range(0, len(data) - 1, 2):
        total += (ord(data[i]) << 8) + ord(data[i+1])
    # A trailing odd byte forms the high half of a final, zero-padded word.
    if len(data) % 2:
        total += ord(data[-1]) << 8
    # Fold carries back into the low 16 bits (end-around carry).
    while total >> 16:
        total = (total & 0xFFFF) + (total >> 16)
    return ~total & 0xFFFF


def isStringLike(anobj):
    """
        Return 1 if anobj can be concatenated with the empty string '' (i.e.
        it behaves like a string), otherwise 0.
    """
    try:
        anobj + ''
    except Exception:
        # Ordinary errors (typically TypeError) mean "not string-like";
        # KeyboardInterrupt and SystemExit are deliberately not swallowed.
        return 0
    return 1


def findLongestSubsequence(seq, value):
    """
        Find the longest contiguous run of elements of seq equal to "value".

        Returns (start, end), the inclusive indices of the first longest run
        that is at least two elements long. Returns (0, 0) when no such run
        exists, so a truthy "end" means a run was found. Single-element runs
        are never reported.
    """
    best = (0, 0)
    start = None  # index where the current run began, or None outside a run
    for i, item in enumerate(seq):
        if item == value:
            if start is None:
                start = i
            # Strict ">" keeps the first run on a tie, and a run must reach
            # two elements (span 1) to beat the initial (0, 0).
            if (i - start) > (best[1] - best[0]):
                best = (start, i)
        else:
            start = None
    return best


#
# Manipulation of addresses
#
def getBlocks(addr):
    """
        Get the 16-bit hexadecimal blocks from a ":"-delimited address definition.
        Applicable to Ethernet and IPv6 addresses.

        Empty components (e.g. either side of a "::") are skipped, so any
        abbreviation handling is left to the caller. Raises ValueError if a
        component is not valid hexadecimal or lies outside 0..0xffff.
    """
    nums = []
    for block in addr.split(":"):
        if not block:
            continue
        num = int(block, 16)
        # int() accepts a leading "-", so the lower bound must be checked
        # explicitly; negative values are not valid address blocks.
        if num < 0 or num > 0xffff:
            raise ValueError("Malformed address.")
        nums.append(num)
    return nums
            

def ipToBytes(addr):
    """
        Converts a standard dotted-quad IPv4 address to 4 bytes.

        Raises ValueError unless addr has exactly four "."-separated decimal
        components, each in the range 0..255.
    """
    nums = addr.split(".")
    if len(nums) != 4:
        raise ValueError("Mal-formed IP address.")
    ret = []
    for i in nums:
        num = int(i)  # raises ValueError for non-numeric components
        if not 0 <= num <= 255:
            raise ValueError("Mal-formed IP address.")
        ret.append(chr(num))
    return "".join(ret)


def ipToStr(bytes):
    """
        Converts a sequence of 4 bytes (characters) to a dotted-quad IPv4
        address string.

        Raises ValueError if the sequence is not exactly 4 long.
    """
    if len(bytes) != 4:
        raise ValueError("IP Address must have 4 bytes.")
    return ".".join([str(ord(c)) for c in bytes])

def isIPAddr(addr):
    """
        Return True if ipToBytes() accepts addr as an IPv4 address, and
        False if it rejects it with ValueError.
    """
    try:
        ipToBytes(addr)
    except ValueError:
        return False
    return True

def ipFromPrefix(prefix):
    """
        Produce an IPv4 address (netmask) string from a prefix length,
        e.g. 24 -> "255.255.255.0".

        Raises ValueError unless 0 <= prefix <= 32.
    """
    if (prefix > 32) or (prefix < 0):
        raise ValueError("Prefix must be between 0 and 32.")
    # divmod keeps integer semantics on Python 3, where "/" is true division.
    full, rem = divmod(prefix, 8)
    addr = "\xff" * full
    if rem:
        # Partial octet: keep only its top `rem` bits set.
        addr += chr((0xff << (8 - rem)) & 0xff)
    addr += "\0" * (4 - len(addr))
    return ipToStr(addr)


def ip6ToBytes(addr):
    """
        Converts a standard IPv6 address to 16 bytes.

        A single "::" may stand for one or more consecutive all-zero groups.
        Raises ValueError if the address is malformed: more than one "::",
        a "::" that replaces no groups, empty groups (e.g. a stray leading or
        trailing ":"), invalid hex, or a group count other than eight.
    """
    def _groups(part):
        # Parse one "::"-free part. An empty part is legitimate only beside
        # "::" (e.g. "::1"); getBlocks() would otherwise silently skip empty
        # components such as a trailing ":".
        if not part:
            return []
        if "" in part.split(":"):
            raise ValueError("Mal-formed IPv6 address.")
        return getBlocks(part)

    if "::" in addr:
        if addr.count("::") > 1:
            raise ValueError("Mal-formed IPv6 address: only one :: abbreviation allowed.")
        first, second = addr.split("::")
        first = _groups(first)
        second = _groups(second)
        padlen = 8 - len(first) - len(second)
        if padlen < 1:
            # "::" must abbreviate at least one zero group.
            raise ValueError("Mal-formed IPv6 address.")
        nums = first + [0]*padlen + second
    else:
        nums = _groups(addr)
    if len(nums) != 8:
        raise ValueError("Mal-formed IPv6 address.")
    return "".join([multichar(i, 2) for i in nums])


def ip6ToStr(addr):
    """
        Converts a standard 16-byte IPv6 address to a human-readable string.

        Each 16-bit group is written in lowercase hex without leading zeros.
        The run of zero groups chosen by findLongestSubsequence() is
        compressed to "::".

        Raises ValueError if addr is not exactly 16 bytes long.
    """
    if len(addr) != 16:
        raise ValueError("IPv6 address must have 16 bytes: %s"%repr(addr))
    groups = ["%x" % multiord(addr[i:i+2]) for i in range(0, 16, 2)]
    start, finish = findLongestSubsequence(groups, "0")
    # findLongestSubsequence returns (0, 0) when there is nothing to
    # compress, so a zero "finish" means no "::" abbreviation.
    if finish:
        return ":".join(groups[:start]) + "::" + ":".join(groups[finish+1:])
    return ":".join(groups)


def isIP6Addr(addr):
    """
        Return True if ip6ToBytes() accepts addr as an IPv6 address, and
        False if it rejects it with ValueError.
    """
    try:
        ip6ToBytes(addr)
    except ValueError:
        return False
    else:
        return True


def ip6FromPrefix(prefix):
    """
        Produce an IPv6 address (netmask) string from a prefix length,
        e.g. 64 -> "ffff:ffff:ffff:ffff::".

        Raises ValueError unless 0 <= prefix <= 128.
    """
    if (prefix > 128) or (prefix < 0):
        raise ValueError("Prefix must be between 0 and 128.")
    # divmod keeps integer semantics on Python 3, where "/" is true division.
    full, rem = divmod(prefix, 8)
    addr = "\xff" * full
    if rem:
        # Partial byte: keep only its top `rem` bits set.
        addr += chr((0xff << (8 - rem)) & 0xff)
    addr += "\0" * (16 - len(addr))
    return ip6ToStr(addr)


def ethToBytes(addr):
    """
        Converts an Ethernet address (of the format xx:xx:xx:xx:xx:xx) to 6
        bytes.

        Raises ValueError unless addr contains exactly six hex components,
        each in the range 0..0xff.
    """
    nums = getBlocks(addr)
    if len(nums) != 6:
        raise ValueError("Malformed Ethernet address.")
    # getBlocks() admits 16-bit values, but each Ethernet octet is one byte.
    for num in nums:
        if not 0 <= num <= 0xff:
            raise ValueError("Malformed Ethernet address.")
    return "".join([chr(i) for i in nums])


def ethToStr(addr):
    """
        Converts a binary Ethernet address into a string of the format xx:xx:xx:xx:xx:xx.

        Each byte is rendered as two lowercase hex digits. Raises ValueError
        if addr is not exactly 6 bytes long.
    """
    if len(addr) != 6:
        raise ValueError("Ethernet address must have 6 bytes.")
    return ":".join(["%02x" % ord(c) for c in addr])


class DoubleAssociation(dict):
    """
        A double-association is a broadminded dictionary - it goes both ways:
        storing key -> value also stores value -> key.

        The rather simple implementation below requires the keys and values to
        be two disjoint sets. That is, if a given value is both a key and a
        value in a DoubleAssociation, you get unexpected behaviour. Values must
        be hashable, since each value is also stored as a key.
    """
    # FIXME:
    #   Still not entirely complete:
    #       - Re-assigning an existing key leaves the old reverse mapping in place.
    #       - popitem(), copy() and fromkeys() use plain dict behaviour.
    def __init__(self, idict=None):
        dict.__init__(self)
        if idict:
            for k, v in idict.items():
                self[k] = v

    def __setitem__(self, key, value):
        dict.__setitem__(self, key, value)
        dict.__setitem__(self, value, key)

    def __delitem__(self, key):
        """Remove key and its reverse association. Raises KeyError if missing."""
        value = dict.__getitem__(self, key)
        dict.__delitem__(self, key)
        # With overlapping key/value sets the partner may since have been
        # re-associated elsewhere; only remove it if it still points at key.
        if value in self and dict.__getitem__(self, value) == key:
            dict.__delitem__(self, value)

    def pop(self, key, *default):
        """Like dict.pop(), but removes both directions of the association."""
        if len(default) > 1:
            raise TypeError("pop expected at most 2 arguments, got %d" % (1 + len(default)))
        if key in self:
            value = dict.__getitem__(self, key)
            del self[key]
            return value
        # Missing key: let dict return the default or raise KeyError.
        return dict.pop(self, key, *default)

    def setdefault(self, key, default=None):
        """Like dict.setdefault(), but a newly stored pair goes both ways."""
        if key not in self:
            self[key] = default
        return dict.__getitem__(self, key)

    def update(self, other=(), **kwargs):
        """Like dict.update(), but every pair is stored in both directions."""
        if hasattr(other, "keys"):
            for k in other.keys():
                self[k] = other[k]
        else:
            for k, v in other:
                self[k] = v
        for k, v in kwargs.items():
            self[k] = v
