import matplotlib.pyplot as plt
import networkx as nx
from pysat.solvers import Glucose3
from pysat.card import *


# Auxiliary function that draws the graph with its node colors and saves the picture to an image file.
def draw_colored_graph(nodes, edges, node_colors, filename):
    """Draw an undirected graph with one color per node and save it to a file.

    Args:
        nodes: Node identifiers. Together with any new endpoints found in
            ``edges``, they determine the graph's node order, which is the
            order in which ``node_colors`` is applied.
        edges: Iterable of ``(u, v)`` pairs. An endpoint not listed in
            ``nodes`` is added to the graph as an extra node.
        node_colors: Matplotlib color specifications, one per graph node.
        filename: Destination passed to Matplotlib's ``savefig``.

    Raises:
        ValueError: If ``len(node_colors)`` differs from the number of nodes
            in the resulting graph.
    """
    G = nx.Graph()
    G.add_nodes_from(nodes)
    G.add_edges_from(edges)

    # One color is required for every node, including endpoints that only
    # appear in ``edges``.
    if len(node_colors) != G.number_of_nodes():
        raise ValueError("The length of node_colors must match the number of nodes.")

    # Draw on a fresh figure and always close it. Drawing on the implicit
    # current figure would pile earlier drawings into later images and keep
    # every figure alive for the life of the process.
    fig = plt.figure()
    try:
        pos = nx.circular_layout(G)  # positions for all nodes
        nx.draw(G, pos, node_color=node_colors, with_labels=True,
                font_weight='bold', node_size=2000)
        fig.savefig(filename)
    finally:
        plt.close(fig)

# Auxiliary function that maps the model returned by the SAT solver back to the graph problem (one color per node).
def print_model(nodes, colors, model, var_map=None):
    """Print and return the color that a SAT model assigns to each node.

    Args:
        nodes: Node identifiers; the result follows this order.
        colors: Candidate colors for every node.
        model: List of signed integer literals, as returned by the solver's
            ``get_model()``. A positive literal ``v`` means variable ``v`` is
            true. ``None`` means the formula was unsatisfiable.
        var_map: Mapping ``(node, color) -> variable id``. When omitted, the
            module-level ``vs`` table built by this script is used.

    Returns:
        A list with one color per entry of ``nodes``. A node with no true
        color variable gets ``'gray'``.

    Raises:
        ValueError: If ``model`` is ``None``.
    """
    if model is None:
        raise ValueError("The model is empty. Formula is unsatisfiable.")
    if var_map is None:
        var_map = vs  # fall back to this script's global encoding table

    # Collect the true variables once. This uses the model that was passed
    # in, not a global solver, and does not depend on literal positions.
    true_vars = {lit for lit in model if lit > 0}

    node_colors = []
    for node in nodes:
        assigned = 'gray'
        for color in colors:
            if var_map[(node, color)] in true_vars:
                print("node=", node, "is assigned color=", color)
                # If several color variables were true, the last one wins.
                assigned = color
        node_colors.append(assigned)

    return node_colors


# Solver that will receive the clauses, and the CNF formula they are collected in.
s = Glucose3()
cnf = CNF()

# Problem instance: an undirected graph and the palette of available colors.
nodes = ["A", "B", "C", "D", "E"]
edges = [("A", "E"), ("A", "C"), ("B", "E"), ("B", "C"), ("C", "D"), ("D", "E"), ("B", "D")]
# With only two colors the formula is unsatisfiable (B, C and D form a triangle):
# colors = ['yellow','red']
colors = ['yellow','red','cyan']

# Number one Boolean variable per (node, color) pair, consecutively from 1,
# node-major then color-minor; variable (node, color) true means "node gets color".
pairs = [(node, color) for node in nodes for color in colors]
vs = {pair: var_id for var_id, pair in enumerate(pairs, start=1)}
# Next unused variable id.
nb_var = len(pairs) + 1


# Adjacent vertices may not share a color: (-x[u,c] OR -x[v,c]) for every edge and color.
for n1, n2 in edges:
    for color in colors:
        cnf.append([-vs[(n1, color)], -vs[(n2, color)]])

# Each vertex gets at least one color: one clause x[v,c1] OR x[v,c2] OR ... per vertex.
for node in nodes:
    cnf.append([vs[(node, color)] for color in colors])

# Each vertex gets at most one color, using the pairwise at-most-one encoding.
for node in nodes:
    node_vars = [vs[(node, color)] for color in colors]
    for clause in CardEnc.atmost(node_vars, encoding=EncType.pairwise):
        cnf.append(clause)

try:
    s.append_formula(cnf.clauses)
    # Write the formula to a file; this is only useful for inspecting it.
    cnf.to_file('color.cnf')

    # Solve once and fetch the model once, then reuse both values.
    is_sat = s.solve()
    model = s.get_model()
    print("SAT result= ", is_sat)
    print("SAT assignment= ", model)

    # Create an image of the graph with its colors, only when a model exists.
    if model is not None:
        node_colors = print_model(nodes, colors, model)
        filename = 'graph_color.png'
        draw_colored_graph(nodes, edges, node_colors, filename)
finally:
    # Release the native solver instance; it is not needed past this point.
    s.delete()


