import numpy as np
import matplotlib.pyplot as plt
import networkx as nx


def plot_clusters(X, title='data', show=False, save=False):
    """
    Plot 2-dimensional samples as points in a figure named ``title``.

    All points are drawn with the same marker style; no per-cluster
    colouring is applied.

    :param X: (num_samples, 2) matrix of 2-dimensional samples
    :param title: figure name; also the file stem used when ``save`` is True
    :param show: If True, call plt.show()
    :param save: If True, save the current figure to ``<title>.png``
    :return: None
    """
    plt.figure(title)
    plt.plot(X[..., 0], X[..., 1], 'o')
    # Save BEFORE showing: plt.show() may close the figure (e.g. with
    # non-interactive backends), after which savefig writes an empty canvas.
    if save:
        plt.savefig(title + '.png')
    if show:
        plt.show()


def plot_edges_and_points(X, W, title='', show=False, save=False):
    """
    Draw the graph defined by ``W`` on top of the 2-D points ``X``.

    :param X: (n, 2) array of coordinates; row i is the position of node i
    :param W: (n, n) adjacency/weight matrix; each nonzero entry becomes an
        edge of an undirected graph
    :param title: plot title; also the file stem used when ``save`` is True
    :param show: If True, call plt.show()
    :param save: If True, save the current figure to ``<title>.png``
    :return: None
    """
    # nx.from_numpy_matrix was removed in networkx 3.0; from_numpy_array is
    # its drop-in replacement for 2-D array input.
    G = nx.from_numpy_array(np.asarray(W))
    # Explicit node -> (x, y) mapping, the layout format networkx expects.
    pos = {i: X[i] for i in range(len(X))}
    nx.draw_networkx_edges(G, pos)
    # One vectorized scatter instead of one artist per point.
    plt.scatter(X[:, 0], X[:, 1], color='b')
    plt.title(title)
    plt.axis('equal')
    # Save BEFORE showing: plt.show() may close the figure, after which
    # savefig would write an empty canvas.
    if save:
        plt.savefig(title + '.png')
    if show:
        plt.show()


def blobs(num_samples, n_blobs=1, blob_var=0.15, return_labels=False):
    """
    Create ``n_blobs`` Gaussian blobs with centres evenly spaced on the unit circle.

    Samples are split into contiguous blocks of ``num_samples // n_blobs``
    rows; the last blob also receives any remainder. Blob ``i`` is centred
    at angle ``2*pi*i/n_blobs``.

    :param num_samples: number of samples to create in the dataset
    :param n_blobs:     how many separate blobs to create (must be >= 1)
    :param blob_var:    standard deviation of the isotropic Gaussian noise
                        around each centre (it scales ``np.random.randn``,
                        so despite the name it is not a variance)
    :param return_labels: if True, also return the blob index of each sample
    :return: X,  (num_samples, 2) matrix of 2-dimensional samples;
             when ``return_labels`` is True, the tuple (X, Y) where
             Y is the (num_samples, ) int32 vector of "true" cluster assignment
    :raises ValueError: if ``n_blobs`` < 1
    """
    if n_blobs < 1:
        raise ValueError("n_blobs must be >= 1, got {}".format(n_blobs))

    # data array
    X = np.zeros((num_samples, 2))
    # indices of the true clusters
    Y = np.zeros(num_samples, dtype=np.int32)

    block_size = num_samples // n_blobs
    for i in range(n_blobs):
        start_index = i * block_size
        # The last blob absorbs the remainder of the integer division.
        end_index = num_samples if i == n_blobs - 1 else (i + 1) * block_size
        nn = end_index - start_index
        angle = 2 * np.pi * i / n_blobs
        # x-noise is drawn before y-noise for each blob, so seeded results
        # match the previous implementation exactly.
        X[start_index:end_index, 0] = np.cos(angle) + blob_var * np.random.randn(nn)
        X[start_index:end_index, 1] = np.sin(angle) + blob_var * np.random.randn(nn)
        Y[start_index:end_index] = i

    if return_labels:
        return X, Y
    return X

def uniform(num_samples):
    """
    Draw points uniformly at random from the unit square [0, 1) x [0, 1).

    :param num_samples: number of points to draw
    :return: (num_samples, 2) float array; column 0 is x, column 1 is y
    """
    # The whole x column is drawn before the y column (RNG consumption order).
    xs = np.random.random(num_samples)
    ys = np.random.random(num_samples)
    return np.column_stack((xs, ys))

if __name__ == '__main__':
    # Demo: random points plus a sparse random 0/1 adjacency matrix in which
    # each entry is 1 with probability 0.001.
    num_points = 300
    points = uniform(num_points)
    adjacency = np.random.binomial(1, .001, size=(num_points, num_points))
    plot_clusters(points, show=True)
    plot_edges_and_points(points, adjacency, title="bonjour", show=True)

