#!/usr/bin/python3

"""Export Foomuuri statistics to Prometheus."""

import argparse
import json
import pathlib
import re
import subprocess
import time
from collections import Counter
from prometheus_client import REGISTRY
from prometheus_client import start_http_server
from prometheus_client.core import GaugeMetricFamily
from prometheus_client.core import CounterMetricFamily
from prometheus_client.registry import Collector


class MonitorCollector(Collector):
    """Export Foomuuri Monitor statistics as Prometheus gauges.

    The statistics file is read and parsed as JSON again on every call to
    collect().
    """
    # pylint: disable=too-few-public-methods

    def __init__(self, args):
        """Store path of the statistics file from ``args.statistics_file``."""
        super().__init__()
        self._stat_file = pathlib.Path(args.statistics_file)

    @staticmethod
    def _new_gauge(name, documentation):
        """Return empty gauge family labelled by ``type`` and ``name``."""
        return GaugeMetricFamily(name, documentation,
                                 labels=['type', 'name'])

    def collect(self):
        """Yield state, packet loss and ping latency gauges.

        Nothing is yielded if the statistics file can't be read or its
        content is not valid JSON.
        """
        try:
            content = self._stat_file.read_text(encoding='utf-8')
            stats = json.loads(content)
        except (OSError, ValueError):
            return

        # State of every entry (target or group)
        up_family = self._new_gauge('foomuuri_monitor_up',
                                    'Target connectivity status.')
        for name, entry in stats.items():
            up_family.add_metric([entry['type'], name], entry['state'])
        yield up_family

        # Packet loss: share of samples that are None. Entries without
        # any samples are left out.
        loss_family = self._new_gauge('foomuuri_monitor_packet_loss_ratio',
                                      'Average ping packet loss.')
        for name, entry in stats.items():
            if entry['time']:
                lost = len([sample for sample in entry['time']
                            if sample is None])
                loss_family.add_metric([entry['type'], name],
                                       lost / len(entry['time']))
        yield loss_family

        # Latency: mean of non-None samples, each divided by 1000
        # (presumably milliseconds to seconds). Entries without any
        # such sample are left out.
        rtt_family = self._new_gauge('foomuuri_monitor_ping_seconds',
                                     'Average network round trip time.')
        for name, entry in stats.items():
            seconds = [sample / 1000
                       for sample in entry['time'] if sample is not None]
            if seconds:
                rtt_family.add_metric([entry['type'], name],
                                      sum(seconds) / len(seconds))
        yield rtt_family


class RulesetCollector(Collector):
    """Collect set size and named counter metrics from nftables ruleset.

    Data is read by running ``nft --json list table inet foomuuri`` on
    every call to collect().
    """
    # pylint: disable=too-few-public-methods

    # Address family suffixes removed from set names, so that "name_4" and
    # "name_6" are reported as single "name" metric.
    _FAMILY_SUFFIXES = ('_4', '_6')

    def __init__(self, args):
        """Compile include/exclude filters.

        ``args`` must provide regular expressions ``set_include``,
        ``set_exclude``, ``counter_include`` and ``counter_exclude``.
        Raises re.error if any of them is invalid.
        """
        super().__init__()
        self._set_include = re.compile(args.set_include)
        self._set_exclude = re.compile(args.set_exclude)
        self._counter_include = re.compile(args.counter_include)
        self._counter_exclude = re.compile(args.counter_exclude)

    @classmethod
    def _base_set_name(cls, name):
        """Return set name without its "_4" / "_6" suffix.

        Name without such suffix is returned unchanged instead of
        blindly cutting off its last two characters.
        """
        if name.endswith(cls._FAMILY_SUFFIXES):
            return name[:-2]
        return name

    def collect(self):
        """Yield set size gauge and named counter bytes/packets counters.

        Nothing is yielded if ``nft`` can't be executed or its output is
        not valid JSON (for example empty output after nft failure).
        """
        try:
            result = subprocess.run(
                ['nft', '--json', 'list', 'table', 'inet', 'foomuuri'],
                stdout=subprocess.PIPE, check=False, encoding='utf-8')
            nftdata = json.loads(result.stdout)
        except (OSError, ValueError):
            return

        # Single pass over ruleset objects. Set sizes are summed per base
        # name to merge IPv4 and IPv6 to single value. Filters are matched
        # against base name for sets and full name for counters.
        sets = Counter()
        counters = {}
        for item in nftdata.get('nftables', []):
            if 'set' in item:
                name = self._base_set_name(item['set']['name'])
                if (
                        self._set_include.search(name) and
                        not self._set_exclude.search(name)
                ):
                    # Empty set has no "elem" key
                    sets[name] += len(item['set'].get('elem', []))
            if 'counter' in item:
                name = item['counter']['name']
                if (
                        self._counter_include.search(name) and
                        not self._counter_exclude.search(name)
                ):
                    counters[name] = {
                        'bytes': item['counter']['bytes'],
                        'packets': item['counter']['packets'],
                    }

        # Set size
        g = GaugeMetricFamily('foomuuri_set_elements',
                              'Number of elements in set.',
                              labels=['name'])
        for name, value in sets.items():
            g.add_metric([name], value)
        yield g

        # Named counters
        g = CounterMetricFamily('foomuuri_counter_bytes_total',
                                'Counter bytes value.',
                                labels=['name'])
        for name, value in counters.items():
            g.add_metric([name], value['bytes'])
        yield g

        g = CounterMetricFamily('foomuuri_counter_packets_total',
                                'Counter packets value.',
                                labels=['name'])
        for name, value in counters.items():
            g.add_metric([name], value['packets'])
        yield g


def _build_parser():
    """Return argument parser with all exporter command line options."""
    parser = argparse.ArgumentParser(
        description='Export Foomuuri statistics to Prometheus.')

    # HTTP listener
    parser.add_argument(
        '--address', default='::',
        help='listen address (default: ::)')
    parser.add_argument(
        '--port', type=int, default=11041,
        help='listen port number (default: 11041)')
    parser.add_argument(
        '--tls-certificate', metavar='FILENAME',
        help='TLS certificate file name')
    parser.add_argument(
        '--tls-key', metavar='FILENAME',
        help='TLS key file name')

    # Collector selection
    parser.add_argument(
        '--no-monitor-statistics', action='store_true',
        help='disable Foomuuri Monitor statistics')
    parser.add_argument(
        '--no-ruleset-statistics', action='store_true',
        help='disable ruleset statistics')
    parser.add_argument(
        '--statistics-file', metavar='FILENAME',
        default='/var/lib/foomuuri/monitor.statistics',
        help='Foomuuri Monitor statistics file name')

    # Name filters: include defaults to '.' (all), exclude defaults to
    # '$^' (nothing)
    regexp_options = (
        ('--set-include', '.',
         'set names to be included to ruleset size statistics'),
        ('--set-exclude', '$^',
         'set names to be excluded from ruleset size statistics'),
        ('--counter-include', '.',
         'counter names to be included to ruleset traffic statistics'),
        ('--counter-exclude', '$^',
         'counter names to be excluded from ruleset traffic statistics'),
    )
    for option, default, help_text in regexp_options:
        parser.add_argument(option, metavar='REGEXP', default=default,
                            help=help_text)
    return parser


def main():
    """Parse command line, register collectors and run exporter."""
    args = _build_parser().parse_args()

    # Register every collector that was not disabled on command line
    collectors = (
        (args.no_monitor_statistics, MonitorCollector),
        (args.no_ruleset_statistics, RulesetCollector),
    )
    for disabled, collector_class in collectors:
        if not disabled:
            REGISTRY.register(collector_class(args))

    # Start HTTP server, then idle forever to keep the process running
    start_http_server(port=args.port, addr=args.address,
                      certfile=args.tls_certificate,
                      keyfile=args.tls_key)
    while True:
        time.sleep(1)


# Run the exporter only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
