import math
import json
import random
from datetime import datetime
import sys
from termcolor import colored, cprint

# MARK: Alert
def alert(text):
	"""Print *text* as an alert: white on a red background, blinking."""
	message = colored(text, color='white', on_color='on_red', attrs=['blink'])
	print(message)

# MARK: Global variables

(NW,NE,E,SE,SW,W) = ("NW","NE","E","SE","SW","W")
# The six directions, listed in the order a clockwise turn visits them.
directions = [NW,NE,E,SE,SW,W]
# One clockwise / counter-clockwise step for each direction.  Built with
# comprehensions so no loop variable leaks into the module namespace (and
# hence into `from ... import *`).
cw_dir = {d: directions[(idx + 1) % len(directions)] for idx, d in enumerate(directions)}
ccw_dir = {d: directions[(idx - 1) % len(directions)] for idx, d in enumerate(directions)}
# Swap east/west components (NE<->NW, SE<->SW, E<->W).
flip_H = {E:W, W:E, NE:NW, NW:NE, SE:SW, SW:SE}
# Each direction mapped to its opposite.
opposite_directions = {E:W, W:E, NE:SW, SW:NE, NW:SE, SE:NW }
# Swap north/south components (NE<->SE, NW<->SW); E and W are fixed.
rfH = {E:E, W:W, NE:SE, SW:NW, NW:SW, SE:NE }

def cw(d,k=1):
	"""Return direction *d* turned clockwise by *k* steps (k is taken mod 6).

	Raises KeyError if *d* is not a direction and a non-zero turn is requested.
	"""
	turned = d
	for _step in range(k % 6):
		turned = cw_dir[turned]
	return turned

def ccw(d,k=1):
	"""Return direction *d* turned counter-clockwise by *k* steps (k is taken mod 6).

	Raises KeyError if *d* is not a direction and a non-zero turn is requested.
	"""
	turned = d
	for _step in range(k % 6):
		turned = ccw_dir[turned]
	return turned

def rotate_cw(conformation, start = 1, k = 1):
	"""Turn entries start, start+2, start+4, ... of *conformation* clockwise by *k*.

	The list is modified in place; nothing is returned.
	"""
	stride = 2
	for pos in range(start, len(conformation), stride):
		conformation[pos] = cw(conformation[pos], k)

def rotate_ccw(conformation, start = 1, k = 1):
	"""Turn entries start, start+2, start+4, ... of *conformation* counter-clockwise by *k*.

	The list is modified in place; nothing is returned.
	"""
	stride = 2
	for pos in range(start, len(conformation), stride):
		conformation[pos] = ccw(conformation[pos], k)

def rotated_cw_copy(conformation, start = 1, k = 1):
	"""Return a new list equal to *conformation* after rotate_cw(…, start, k).

	The input is left untouched.
	"""
	rotated = list(conformation)
	rotate_cw(rotated, start, k)
	return rotated

def rotated_ccw_copy(conformation, start = 1, k = 1):
	"""Return a new list equal to *conformation* after rotate_ccw(…, start, k).

	The input is left untouched.
	"""
	rotated = list(conformation)
	rotate_ccw(rotated, start, k)
	return rotated

# MARK: Oritatami system

class OS:
	"""An oritatami system: attraction rule, delay, transcript and seed.

	The class also hosts helper functions (static methods, normally called as
	``OS.name(...)``) for building compact transcripts, editing the rule and
	transforming conformations.
	"""

	def __init__(self):
		# NOTE: default name (also used, sanitised, as the .os filename) is
		# spelled "Oritami"; left as-is so existing output filenames stay stable.
		self.name = "Oritami system"
		self.rule = set()  # attraction pairs; add_attraction stores them ordered
		self.delay = 3
		self.categoryColors = {}
		self.color_index = 0  # next index consumed by colorize()

		self.transcriptPrefix = []
		self.transcriptPeriod = []
		self.compactTranscriptPrefix = []
		self.compactTranscriptPeriod = []
		self.transcriptPeriodCount = 0 # <= 0 is ∞
		self.seedConformation = []

	def save_as_json(self, timestamped = True):
		"""Write the system as JSON to "<name>.os" in the current directory.

		The filename is ``self.name`` with spaces turned into underscores and
		colons removed.  When *timestamped* is true the current time (H:M:S)
		is appended to the ``name`` field inside the JSON, not to the filename.
		A compact transcript prefix/period is saved instead of the expanded
		one whenever it is non-empty.
		"""
		name = self.name
		filename = name.replace(" ", "_").replace(":","")

		if timestamped:
			name += " {:%H:%M:%S}".format(datetime.now())

		os_json = {
			"name": name,
			"delay": self.delay,
			"seedConformation": self.seedConformation,
			"periodCount": self.transcriptPeriodCount,
			"categoryColors": self.categoryColors,
			"rule": self.rule_as_list(),
		}

		if self.compactTranscriptPrefix:
			os_json["compactTranscriptPrefix"] = self.compactTranscriptPrefix
		else:
			os_json["transcriptPrefix"] = self.transcriptPrefix

		if self.compactTranscriptPeriod:
			os_json["compactTranscriptPeriod"] = self.compactTranscriptPeriod
		else:
			os_json["transcriptPeriod"] = self.transcriptPeriod

		filename_os = "{}.os".format(filename)

		with open(filename_os, 'w') as output:
			json.dump(os_json, output)
		print("OS saved as {}".format(filename_os))

	def rule_as_list(self):
		"""Return the rule as a list of two-element lists (JSON-friendly)."""
		return [ [item[0], item[1]] for item in self.rule ]

	def rule_as_dict(self):
		"""Return {bead_type: [partner, ...]} listing each attraction from both ends.

		A self-attraction (b, b) lists b once under b.
		"""
		rule_dict = {}
		for (b1, b2) in self.rule:
			rule_dict.setdefault(b1, []).append(b2)
			if b1 != b2:
				rule_dict.setdefault(b2, []).append(b1)
		return rule_dict

	def pretty_print_rule(self):
		"""Print each bead type (sorted) followed by its sorted partners."""
		rule_dict = self.rule_as_dict()
		for b in sorted(rule_dict):
			print("{} ︎❤︎️ {}\n".format(b, ', '.join(sorted(rule_dict[b]))))

	# MARK: compact transcript tools
	# A compact transcript is a nested list.  A plain string is a bead type,
	# or the name of a declared variable.  A list whose first item is one of
	# the keywords below is an instruction; any other list is concatenated.
	# (_loop is never emitted: loop() encodes itself as an unfold.)
	_declare = "declare"
	_unfold = "unfold"
	_loop = "loop"
	_repeat = "repeat"
	_mark = "mark"

	@staticmethod
	def declare(name, seq):
		"""Return an instruction binding variable *name* to the expansion of *seq*.

		The instruction itself expands to nothing; later occurrences of the
		string *name* expand to the bound value.
		"""
		if isinstance(seq, str): seq = [seq]
		return [OS._declare, name] + list(seq)

	@staticmethod
	def mark(name = ""):
		"""Return a ``["mark", name]`` item.

		NOTE(review): _expand has no special case for marks, so expanding one
		yields the two strings as if they were bead types — confirm intent.
		"""
		return [OS._mark, name]

	@staticmethod
	def unfold(i,j,seq):
		"""Return an instruction for items i..j-1 of the expansion of *seq*, indexed cyclically.

		Prints an error and exits the process (non-zero status) if j < i.
		"""
		if j < i:
			print("/!\\ UNFOLD ERROR /!\\")
			sys.exit(-1)  # non-zero so the failure is visible to callers/shells
		if isinstance(seq, str): seq = [seq]
		return [OS._unfold, i, j] + list(seq)

	@staticmethod
	def repeat(i,j,seq):
		"""Return an instruction repeating the expansion of *seq* (j - i) times.

		Prints an error and exits the process (non-zero status) if j < i.
		"""
		if j < i:
			print("/!\\ REPEAT ERROR /!\\")
			sys.exit(-1)
		if isinstance(seq, str): seq = [seq]
		return [OS._repeat, i, j] + list(seq)

	@staticmethod
	def loop(i, seq):
		"""Return unfold(0, i, seq): the first *i* items of *seq*'s expansion, cycled.

		Prints an error and exits the process (non-zero status) if i < 0.
		"""
		if i < 0:
			print("/!\\ LOOP ERROR /!\\")
			sys.exit(-1)
		if isinstance(seq, str): seq = [seq]
		return [OS._unfold, 0, i] + list(seq)

	@staticmethod
	def expand(seq):
		"""Expand a compact transcript into a flat list of bead types.

		Returns ``(flat_list, variables)`` where *variables* maps each declared
		name to its expanded value.
		"""
		variables = dict()
		result = OS._expand(seq, variables)
		return (result, variables)

	@staticmethod
	def _expand(seq, variables):
		"""Recursively expand *seq*, recording declarations into *variables*."""
		# Strings first: seq[0] of a string is one character (never a keyword)
		# and would raise IndexError on "".
		if isinstance(seq, str):
			if seq in variables:
				return list(variables[seq])  # copy so callers cannot alias the binding
			return [seq]
		if not seq:
			return []
		head = seq[0]
		if head == OS._declare:
			variables[seq[1]] = OS._expand(seq[2:], variables)
			return []
		if head == OS._unfold:
			i, j = seq[1], seq[2]
			result = OS._expand(seq[3:], variables)
			n = len(result)
			return [result[k % n] for k in range(i, j)]
		if head == OS._repeat:
			i, j = seq[1], seq[2]
			return OS._expand(seq[3:], variables) * (j - i)
		result = []
		for x in seq:
			result += OS._expand(x, variables)
		return result

	@staticmethod
	def append(seq, item):
		"""Append *item* to compact transcript *seq* in place, run-length compressing.

		Two equal consecutive items become ``repeat(0, 2, item)``; each further
		equal item increments that repeat's upper bound instead of appending.
		"""
		if not seq:
			seq.append(item)
			return
		last = seq[-1]
		if last == item:
			seq[-1] = OS.repeat(0, 2, item)
			return
		# Body of repeat(0, 2, item): the item's own elements for a list item,
		# otherwise the item alone.
		body = list(item) if isinstance(item, list) else [item]
		if isinstance(last, list) and last[:1] == [OS._repeat] and last[3:] == body:
			last[2] += 1
			return
		seq.append(item)

	# MARK: Bead type sequence tools

	@staticmethod
	def val(l,i):
		"""Return ``l[i]`` with *i* taken modulo ``len(l)`` (cyclic indexing; *l* non-empty)."""
		return l[i % len(l)]

	def colorize(self, name):
		"""Give category *name* a colour entry unless it already has one.

		Uses the golden-ratio sequence fmod(index * φ, 1), stored as the first
		``hsba`` component (presumably the hue — confirm against the consumer).
		"""
		if name in self.categoryColors:
			return

		φ = (1+math.sqrt(5))/2
		self.categoryColors[name] = {"hsba": [math.fmod(self.color_index*φ,1)]}
		self.color_index += 1

	# MARK: Rule tools

	@staticmethod
	def ordered_pair(b1, b2):
		"""Return (b1, b2) sorted so the smaller bead type comes first."""
		return (b1, b2) if b1 <= b2 else (b2, b1)

	def add_attraction(self, b1, b2, verbose = False):
		"""Add the attraction between *b1* and *b2* (stored as an ordered pair)."""
		a = OS.ordered_pair(b1,b2)
		if a not in self.rule:
			if verbose: print("{} + {}".format(a[0], a[1]))
			self.rule.add(a)

	def add_attraction_pairs(self, beads1, beads2, pairs, verbose = False):
		"""Attract beads1[i] to beads2[j] for each (i, j) in *pairs* (indices cyclic).

		*pairs* may be a single (i, j) tuple or an iterable of them.
		"""
		if isinstance(pairs, tuple):
			pairs = [pairs]

		for (i,j) in pairs:
			self.add_attraction(OS.val(beads1, i), OS.val(beads2, j), verbose)

	def add_attraction_pairs_classes(self, beads1, beads2, pairs, verbose = False):
		"""Like add_attraction_pairs, but entries may be lists of bead types.

		Every bead type of the selected beads1 entry is attracted to every bead
		type of the selected beads2 entry.
		"""
		if isinstance(pairs, tuple):
			pairs = [pairs]

		for (i,j) in pairs:
			c1 = OS.val(beads1, i)
			c2 = OS.val(beads2, j)
			if isinstance(c1, str): c1 = [c1]
			if isinstance(c2, str): c2 = [c2]
			for b1 in c1:
				for b2 in c2:
					self.add_attraction(b1, b2, verbose)

	def del_attraction(self, b1, b2, verbose = False):
		"""Remove the attraction between *b1* and *b2* if present."""
		a = OS.ordered_pair(b1,b2)
		if a in self.rule:
			if verbose: print("{} x {}".format(a[0], a[1]))
			self.rule.remove(a)

	def del_all_attractions_between_one_bead_and_a_bead_list(self, b, l, verbose = False):
		"""Remove every attraction between bead *b* and any bead in *l*."""
		# Collect first: the rule set must not change while being iterated.
		to_be_deleted = [
			(u, v) for (u, v) in self.rule
			if (b == u and v in l) or (b == v and u in l)
		]
		for (u,v) in to_be_deleted:
			self.del_attraction(u, v, verbose)

	@staticmethod
	def indices(list_start_number):
		"""Concatenate range(start, start + nb) for each (start, nb) pair."""
		return [k for (start, nb) in list_start_number for k in range(start, start + nb)]

	@staticmethod
	def beadTypeCategory(name, indices):
		"""Return ["<name><i>", ...] for i in *indices* (an int n means range(n))."""
		if isinstance(indices, int):
			indices = range(indices)
		return ["{}{}".format(name, i) for i in indices]

	# MARK: SEED check

	def check_seedConformation(self):
		"""Check that every odd position of the seed conformation is a direction.

		Prints the seed and the offending entry for each bad position; even
		positions are not checked.  Returns True when no problem is found.
		"""
		success = True
		seed = self.seedConformation
		for i in range(1, len(seed), 2):
			if seed[i] not in directions:
				print(seed)
				print("seed error: [{}]{}".format(i, seed[i]))
				success = False
		return success

	def fail_if_seed_is_bad(self):
		"""Show an alert and exit with status -1 if the seed conformation is invalid."""
		if not self.check_seedConformation():
			# Raw string: in a normal literal the "\" at the end of the first art
			# line would join it to the next line and "\ " would be an invalid escape.
			alert(
r"""
  /!\
 /!!!\ SEED FAILED
/_____\	
""")
			sys.exit(-1)

	# MARK: rotate

	@staticmethod
	def apply_dict(element, dict):
		"""Return dict[element] if *element* is a key, else *element* unchanged.

		(The parameter name shadows the builtin; kept for keyword compatibility.)
		"""
		if element in dict:
			return dict[element]
		else:
			return element

	@staticmethod
	def cw(element):
		"""Turn a direction one step clockwise; other values pass through."""
		return OS.apply_dict(element, cw_dir)

	@staticmethod
	def ccw(element):
		"""Turn a direction one step counter-clockwise; other values pass through."""
		return OS.apply_dict(element, ccw_dir)

	@staticmethod
	def fH(element):
		"""Map a direction through the module-level flip_H table; other values pass through."""
		return OS.apply_dict(element, flip_H)

	@staticmethod
	def opp(element):
		"""Map a direction to its opposite; other values pass through."""
		return OS.apply_dict(element, opposite_directions)

	@staticmethod
	def rotate_cw(conformation):
		"""Return a copy with every direction turned one step clockwise."""
		return [OS.cw(x) for x in conformation]

	@staticmethod
	def rotate_ccw(conformation):
		"""Return a copy with every direction turned one step counter-clockwise."""
		return [OS.ccw(x) for x in conformation]

	@staticmethod
	def flip_H(conformation):
		"""Return a transformed copy: directions mapped through rfH, order
		reversed, then the last item moved to the front."""
		output = [OS.apply_dict(x, rfH) for x in conformation]
		output.reverse()
		return output[-1:] + output[:-1]

	@staticmethod
	def opposite(conformation):
		"""Return a copy with every direction replaced by its opposite."""
		return [OS.opp(x) for x in conformation]
		
