import networkx as nx
import numpy as np

def get_graph_weights(graph):
    """
    Collect the 'weight' attributes of the nodes and edges of a graph.

    Args:
        graph (nx.Graph): graph whose node and edge weights are read
    Return:
        node weights as dict{nid: node_weight}, and edge weights as
        dict{(nid1, nid2): edge_weight}. Every weighted edge is stored under
        both (nid1, nid2) and (nid2, nid1), so lookups succeed regardless of
        endpoint order.
    """
    node_weights = nx.get_node_attributes(graph, 'weight')
    raw_edge_weights = nx.get_edge_attributes(graph, 'weight')

    # Start from the edges as reported, then add the reversed key of each one.
    edge_weights = raw_edge_weights.copy()
    edge_weights.update(
        {(key[1], key[0]): weight for key, weight in raw_edge_weights.items()}
    )
    return node_weights, edge_weights


def generate_graph_data(node_num, edge_num, weight_range=10):
    """
    Generate random weights for the nodes and edges of a simple graph
    (no self-loops, no duplicated edges).

    Args:
        node_num (int): number of nodes; nodes are identified by 0 .. node_num - 1
        edge_num (int): requested number of edges; capped at
            node_num * (node_num - 1) // 2, the most edges a simple undirected
            graph on node_num nodes can have
        weight_range (int): node and edge weights are drawn uniformly from
            range(weight_range); None falls back to 10
    Return:
        nodes (set of tuple(nid, node_weight)), edges (set of tuple(nid1, nid2, edge_weight))
    """
    if weight_range is None:
        weight_range = 10

    nodes = set()
    for i in range(node_num):
        ndw = np.random.choice(range(weight_range))
        nodes.add((i, ndw))

    # Fewer than two nodes admit no edge. The guard also stops the formula from
    # turning positive for a negative node_num. Integer division keeps cnt an int.
    max_edges = node_num * (node_num - 1) // 2 if node_num > 1 else 0
    cnt = min(edge_num, max_edges)

    edges = set()
    used_pairs = set()  # unordered endpoint pairs already taken, for O(1) duplicate checks
    while cnt > 0:
        u = np.random.randint(node_num)
        v = np.random.randint(node_num)
        if u == v:  # without self loop
            continue
        pair = frozenset((u, v))
        if pair in used_pairs:  # without duplicated edges (in either direction)
            continue
        used_pairs.add(pair)
        edw = np.random.choice(range(weight_range))
        edges.add((u, v, edw))
        cnt -= 1
    return nodes, edges


def generate_weighted_graph(nodes, edges, weight_range=10):
    """
    Build an undirected nx.Graph whose nodes and edges carry a 'weight' attribute.

    The input form is chosen by the argument types:
      * nodes and edges are BOTH lists: nodes holds node ids and edges holds
        (nid1, nid2) pairs. Every node and edge gets a random weight drawn
        uniformly from range(weight_range).
      * otherwise (e.g. the sets returned by generate_graph_data): nodes holds
        (nid, weight) tuples and edges holds (nid1, nid2, weight) tuples, and
        those weights are used as given.

    Args:
        nodes (list/set): node ids, or tuple(nid, weight) elements
        edges (list/set): (nid1, nid2) pairs, or tuple(nid1, nid2, edge_weight) elements
        weight_range (int): exclusive upper bound of the random weights used
            in the list form; None falls back to 10
    Returns:
        g (nx.Graph): graph generated by args
    """
    if weight_range is None:
        # Same fallback generate_graph_data uses, so None is accepted by both helpers.
        weight_range = 10

    g = nx.Graph()
    if isinstance(nodes, list) and isinstance(edges, list):
        choices = range(weight_range)  # loop-invariant candidate weights
        for v in nodes:
            g.add_node(v, weight=np.random.choice(choices))

        # NOTE: add_edge also creates any endpoint missing from `nodes`;
        # such nodes get no 'weight' attribute.
        for e in edges:
            g.add_edge(e[0], e[1], weight=np.random.choice(choices))
    else:
        for item in nodes:
            g.add_node(item[0], weight=item[1])

        for item in edges:
            g.add_edge(item[0], item[1], weight=item[2])
    return g