#! /usr/bin/python3

import re
import sys
import itertools
import subprocess

# Names exported by a wildcard import of this module: the diff classes and
# the element-kind tags stored in BlockDiff.el.
__all__ = [
    "Diff", "FunctionDiff", "BlockDiff",
    "Context", "Left", "Right", "Replace",
]

# A diff of named elements (either functions or basic blocks)
class NamedDiff:
    """Diff over a collection of named elements.

    Three pieces of state are tracked:

    * ``created`` -- set of elements that were created,
    * ``removed`` -- set of elements that were removed,
    * ``edited``  -- dict mapping a name to the nested diff of that element.
    """

    def __init__(self):
        self.created = set()
        self.removed = set()
        self.edited = dict()

    def create(self, name):
        """Record *name* as created."""
        self.created.add(name)

    def remove(self, name):
        """Record *name* as removed."""
        self.removed.add(name)

    def edit(self, name, diff):
        """Record *name* as edited, with *diff* describing the change."""
        self.edited[name] = diff

    def normalize(self):
        """Normalize every nested diff that supports normalization.

        Nested diffs without a ``normalize`` method are skipped, so that a
        tree containing such leaves can still be normalized as a whole.
        """
        for subdiff in self.edited.values():
            normalize = getattr(subdiff, "normalize", None)
            if callable(normalize):
                normalize()

    @staticmethod
    def _element_size(element):
        """Return the size contributed by one created/removed element.

        Elements exposing a ``size()`` method contribute that value.  Plain
        names (what ``create``/``remove`` normally receive) carry no size
        information and are counted as 1.
        NOTE(review): counting a bare name as 1 is a choice -- confirm it is
        the intended weight for created/removed elements.
        """
        size = getattr(element, "size", None)
        return size() if callable(size) else 1

    def size(self):
        """Return the total size of the diff.

        This is the sum of the sizes of the created elements, the removed
        elements and every nested diff stored in ``edited``.
        """
        created = sum(self._element_size(c) for c in self.created)
        removed = sum(self._element_size(r) for r in self.removed)
        edited = sum(subdiff.size() for subdiff in self.edited.values())
        return created + removed + edited

# Diff of named functions
class Diff(NamedDiff):
    """Top-level diff whose named elements are functions."""

    @staticmethod
    def from_llvm_diff_output(output):
        """Build a Diff from llvm-diff report text.

        The text is handed to ``diff_from_output`` and its result returned.
        """
        parsed = diff_from_output(output)
        return parsed

    @staticmethod
    def from_llvm_diff(f1, f2):
        """Run ``llvm-diff f1 f2`` and build a Diff from what it reports.

        The report is read from the child's standard error stream and
        decoded as UTF-8 before being parsed.
        """
        command = ["llvm-diff", f1, f2]
        completed = subprocess.run(command, stderr=subprocess.PIPE)
        report = completed.stderr.decode('utf-8')
        return Diff.from_llvm_diff_output(report)

    def print(self):
        """Print created, removed, then edited functions to standard output.

        Each edited function is followed by its own nested report.
        """
        for name in self.created:
            print(f'created {name}')
        for name in self.removed:
            print(f'removed {name}')
        for name in self.edited:
            print(f'edited {name}')
            self.edited[name].print()

# Diff of named basic blocks
class FunctionDiff(NamedDiff):
    """Diff of one function, whose named elements are basic blocks.

    On top of the inherited created/removed/edited bookkeeping, two running
    totals accumulate sizes reported through ``add_size_created`` and
    ``add_size_removed``.
    """

    def __init__(self):
        super().__init__()
        self.size_created = 0
        self.size_removed = 0

    def add_size_created(self, size):
        """Add *size* to the running total of created size."""
        self.size_created += size

    def add_size_removed(self, size):
        """Add *size* to the running total of removed size."""
        self.size_removed += size

    def print(self):
        """Print removed blocks, then edited blocks with their nested report."""
        for name in self.removed:
            print(f'  removed {name}')
        for name in self.edited:
            print(f'  edited {name}')
            self.edited[name].print()

    def size(self):
        """Return the inherited size plus both accumulated size totals."""
        return NamedDiff.size(self) + self.size_created + self.size_removed

# Tags identifying the kind of each element stored in BlockDiff.el
Context = 0
Left = 1
Right = 2
Replace = 3

# A basic block diff
class BlockDiff:
    """Diff of a single basic block, stored as an ordered list of elements.

    ``el`` holds tuples whose first item is one of the module-level tags:

    * ``(Context, start, end)``            -- a context marker,
    * ``(Left, lines)``                    -- lines reported as removed,
    * ``(Right, lines)``                   -- lines reported as added,
    * ``(Replace, left_lines, right_lines)`` -- removed lines together with
      the added lines that take their place.

    Adjacent left and right runs are merged into a single Replace element,
    whichever side is added first.
    """

    def __init__(self):
        self.el = []

    def add_context(self, start, end):
        """Append a context marker carrying *start* and *end*."""
        self.el.append((Context, start, end))

    def add_left(self, code):
        """Append the left-side line *code* to the current run.

        A new Left run is opened at the start or after a context marker; a
        trailing Right run is upgraded to a Replace; otherwise the line
        extends the left side of the trailing Left/Replace element.
        """
        if not self.el or self.el[-1][0] == Context:
            self.el.append((Left, []))
        elif self.el[-1][0] == Right:
            self.el[-1] = (Replace, [], self.el[-1][1])

        # Left lines live at index 1 for both Left and Replace elements.
        self.el[-1][1].append(code)

    def add_right(self, code):
        """Append the right-side line *code* to the current run.

        Mirror image of ``add_left``: a new Right run is opened at the start
        or after a context marker; a trailing Left run is upgraded to a
        Replace; otherwise the line extends the right side of the trailing
        Right/Replace element.
        """
        if not self.el or self.el[-1][0] == Context:
            self.el.append((Right, []))
        elif self.el[-1][0] == Left:
            # Keep the accumulated left lines (index 1, not the tag at 0).
            self.el[-1] = (Replace, self.el[-1][1], [])

        # Right lines live at index 2 in a Replace, at index 1 in a Right.
        last = self.el[-1]
        right_lines = last[2] if last[0] == Replace else last[1]
        right_lines.append(code)

    def print(self):
        """Print a one-line summary of every element to standard output."""
        for el in self.el:
            kind = el[0]
            if kind == Context:
                print(f'    context {el[1]}..{el[2]}')
            elif kind == Left:
                print(f'    removed {len(el[1])} lines')
            elif kind == Right:
                print(f'    added {len(el[1])} lines')
            elif kind == Replace:
                print(f'    replaced {len(el[1])} lines with {len(el[2])} new')

    def size(self):
        """Return the number of left and right lines; context counts as 0."""
        total = 0
        for el in self.el:
            kind = el[0]
            if kind == Left or kind == Right:
                total += len(el[1])
            elif kind == Replace:
                total += len(el[1]) + len(el[2])
        return total

# Build a Diff from the textual output of llvm-diff
def diff_from_output(llvm_diff_output):
    """Parse llvm-diff report text into a Diff.

    The text is processed line by line; blank lines are ignored.  The
    recognised line shapes are:

    * ``in function NAME:``            -- starts a FunctionDiff for NAME,
    * ``  left|right block ...: N``    -- adds N to the current function's
      created (left) or removed (right) size total,
    * ``  in block NAME:`` or ``  in block .../NAME:`` -- starts a BlockDiff
      for NAME in the current function,
    * ``    --- A, B``                 -- context marker in the current block,
    * ``    < CODE`` / ``    > CODE``  -- left / right line in the current
      block.

    Lines that match none of these, or that appear before the function or
    block they belong to, are reported on standard error and skipped.  The
    line number in those messages is 0-based.

    Returns the populated Diff.
    """
    the_diff = Diff()

    # Current function and block
    func = None
    block = None

    re_func   = re.compile(r'^in function (.+):$')
    re_unique = re.compile(r'^  (left|right) block [^:]*: (\d+)$')
    re_block1 = re.compile(r'^  in block ([^/]+):$')
    re_block2 = re.compile(r'^  in block .+/(.+):$')
    re_ctx    = re.compile(r'^    --- (\d+), (\d+)$')
    re_inst   = re.compile(r'^    (<|>)\s*(.+)$')

    def report_missing(lineno, func, block):
        # Shared message for lines that arrive without their enclosing
        # function and/or block header.
        print(f"error: couldn't find all info at lineno {lineno} " +
            f"(func={func} block={block})", file=sys.stderr)

    for lineno, line in enumerate(llvm_diff_output.split("\n")):
        if not line.strip():
            continue

        # Catch new functions
        m = re_func.match(line)
        if m is not None:
            func = FunctionDiff()
            the_diff.edit(m[1], func)
            # A block never carries over from the previous function.
            block = None
            continue

        # Catch unique blocks
        m = re_unique.match(line)
        if m is not None:
            if func is None:
                report_missing(lineno, func, block)
                continue
            size = int(m[2])
            # NOTE(review): "left" feeds size_created here, whereas
            # BlockDiff.print labels Left lines as "removed" -- confirm this
            # mapping is intended.  The total size is unaffected either way.
            if m[1] == "left":
                func.add_size_created(size)
            else:
                func.add_size_removed(size)
            continue

        # Catch new blocks (plain name first, then the "prefix/name" form)
        m = re_block1.match(line) or re_block2.match(line)
        if m is not None:
            if func is None:
                report_missing(lineno, func, block)
                continue
            block = BlockDiff()
            func.edit(m[1], block)
            continue

        # Catch context separators
        m = re_ctx.match(line)
        if m is not None:
            if block is None:
                report_missing(lineno, func, block)
                continue
            block.add_context(int(m[1]), int(m[2]))
            continue

        # Catch instructions
        m = re_inst.match(line)
        if m is None:
            print(f"error: couldn't interpret '{line}'", file=sys.stderr)
            continue

        if func is None or block is None:
            report_missing(lineno, func, block)
            continue

        if m[1] == "<":
            block.add_left(m[2])
        else:
            block.add_right(m[2])

    return the_diff
