#! /usr/bin/env python3

import sys
import delta
from delta import Diff
import itertools

help_message = f"""
usage: {sys.argv[0]} <diff1> <diff2>
       {sys.argv[0]} <source> <variant1> <variant2>

This script computes a diff intersection by listing all the lines from <source>
file which have been modified in both <variant1> and <variant2>. It does so by
filtering the output of an annotated llvm-diff.

If only two files are provided, they are assumed to be the output of [llvm-diff
<source> <variant1>] and [llvm-diff <source> <variant2>], respectively.

This script assumes a modified version of llvm-diff that outputs before the ">"
and "<" marks a integer that uniquely identified an instruction within its
basic block.

Produces a diff conflict log.
""".strip()

#
#  Get command-line arguments
#

args = sys.argv[1:]

if "-h" in args or "--help" in args or "-?" in args:
    print(help_message)
    sys.exit(0)

if len(args) not in [2, 3]:
    print(help_message, file=sys.stderr)
    sys.exit(1)

#
#  Get the diff between input files
#

# If two arguments are provided, they are diff files.
if len(args) == 2:
    d1 = open(args[0]).read()
    d1 = Diff.from_llvm_diff_output(d1)

    d2 = open(args[1]).read()
    d2 = Diff.from_llvm_diff_output(d2)

# Otherwise, compute the diffs now.
else:
    d1 = Diff.from_llvm_diff(args[0], args[1])
    d2 = Diff.from_llvm_diff(args[0], args[2])

#
#  Intersect diffs
#  Since llvm-diff does not provide information about deleted blocks (sigh),
#  just look for conflicts in commonly-edited blocks.
#

def conflict(d1, d2):
    f1 = set(d1.edited.keys())
    f2 = set(d2.edited.keys())

    for f in f1.intersection(f2):
        conflict_fun(d1.edited[f], d2.edited[f])

def conflict_fun(d1, d2):
    b1 = set(d1.edited.keys())
    b2 = set(d2.edited.keys())

    for b in b1.intersection(b2):
        conflict_block(d1.edited[b], d2.edited[b])

def conflict_block(d1, d2):

    Left, Right = delta.Left, delta.Right
    Context, Replace = delta.Context, delta.Replace

    def length(el):
        if el[0] == Context:
            return el[2] - el[1] + 1
        if el[0] == Left or el[0] == Replace:
            return len(el[1])
        if el[0] == Right:
            return 0

    # Yield all elements that intersect
    def intersections(list1, list2):
        s1, e1 = 0, 0
        s2, e2 = 0, 0

        E1, E2 = None, None
        i1, i2 = 0, 0

        while i1 < len(list1) or i2 < len(list2):
            # Choose the source of the next element
            if i1 == len(list1):
                source = 2
            elif i2 == len(list2):
                source = 1
            else:
                source = 1 if e1 <= e2 else 2

            # Get that element and update start/ends

            if source == 1:
                E1 = list1[i1]
                i1 += 1

                s1 = e1
                e1 = s1 + length(E1)

            else:
                E2 = list2[i2]
                i2 += 1

                s2 = e2
                e2 = s2 + length(E2)

            # If there is an intersection, yield a pair

            if E1[0] == Context or E2[0] == Context: continue
            if e1 < s2 or e2 < s1: continue

            yield (s1, e1, s2, e2, E1, E2)

    for (s1, e1, s2, e2, E1, E2) in intersections(d1.el, d2.el):
#        print(s1, e1, s2, e2, E1, E2)

        if E1[0] == Right and E2[0] == Right:
            assert(s1 == e1 and s2 == e2 and s1 == s2)
            if E1[1] != E2[1]:
                print(f"insertion conflict after {s1}")

        if E1[0] == Replace and E2[0] == Right and s2 >= s1 and e2 <= e1:
            print(f"insertion after {s2}, a position being replaced")
        if E2[0] == Replace and E1[0] == Right and s1 >= s2 and e1 <= e2:
            print(f"insertion after {s1}, a position being replaced")

        if E1[0] == Replace and E2[0] in [Left, Replace] \
            and not (e2 <= s1 or e1 <= s2):
            print(f"removing {s2}..{e2}, which is being replaced")

        elif E2[0] == Replace and E1[0] in [Left, Replace] \
            and not (e1 <= s2 or e2 <= s1):
            print(f"removing {s1}..{e1}, which is being replaced")

    return

def merge(v1, v2):
    # Take all left lines that exist in both variants v1 and v2
    v1 = { id: inst for (id, direction, inst) in v1 if direction == '<' }
    v2 = { id: inst for (id, direction, inst) in v2 if direction == '<' }

    common = set(v1.keys()).intersection(set(v2.keys()))

    # Sanity check: all instructions should be equal
    for id in common:
        if v1[id] != v2[id]:
            print(f"error: ambiguous numbering", file=sys.stderr)
            print(f"{id} < {v1[id]}")
            print(f"{id} < {v2[id]}")

    return { id: v1[id] for id in common }

conflict(d1, d2)
