# -*- coding: utf-8 -*-
"""
Mixing matrices and assortativity coefficients

See the following article for definitions and general discussion::

  @article{newman-2003-mixing,
    title   = {Mixing patterns in networks},
    author  = {M. E. J. Newman},
    journal = {Phys. Rev. E},
    volume  = {67},
    number  = {2},
    pages   = {026126},
    month   = {Feb},
    year    = {2003},
  }

"""
#    Copyright (C) 2006 by 
#    Aric Hagberg <hagberg@lanl.gov>
#    Dan Schult <dschult@colgate.edu>
#    Pieter Swart <swart@lanl.gov>
#    Distributed under the terms of the GNU Lesser General Public License
#    http://www.gnu.org/copyleft/lesser.html
__author__ = """Aric Hagberg (hagberg@lanl.gov)"""


def degree_assortativity(G):
    """Return the degree assortativity coefficient of the graph G.

    Degree assortativity is the Pearson correlation coefficient of
    degree-degree correlations.  It is computed here from the degree
    (remaining-degree) mixing matrix of G.

    Note: this function is deliberately not named ``assortativity``.
    That module-level name belongs to ``assortativity(M, **kwds)``, which
    computes Newman's coefficient from a mixing matrix; a graph-based
    function with the same name would be silently replaced by it.
    """
    return Pearson_correlation(degree_mixing_matrix(G))


def categorical_mixing_matrix(G,pfunction,rows=None, **kwds):
    """Mixing matrix for categorical properties.

    Returns the mixing matrix and, optionally, the list of row and
    column labels.

    G
        Graph; its nodes() and neighbors(v) methods are used.
    pfunction
        User-supplied function that returns the (hashable) category of
        a vertex, e.g. prop = pfunction(v).
    rows
        Optional sequence of categories giving the row/column labels in
        order.  Vertices whose category is not listed are skipped, and a
        repeated entry is ignored after its first occurrence.  If None,
        every category found in the data is used, indexed in dictionary
        iteration order (pass labels=True to recover that order).
    labels (keyword)
        If True, return (matrix, labels) where labels[i] is the category
        of row/column i; otherwise return only the matrix.

    Entry a[i,j] counts the pairs (v, u), u a neighbor of v, with u in
    category i and v in category j.  The matrix is normalized so the sum
    of the elements is one; if no such pair exists the all-zero matrix
    is returned unnormalized instead of dividing by zero.
    """
    import Numeric as N
    labels=kwds.get("labels",False)

    # call the user-supplied function once per vertex
    data={}
    for v in G.nodes():
        data[v]=pfunction(v)

    # Assign each category a row/column index.  A repeated category keeps
    # its first index: re-indexing it would produce an index >= n and
    # make the matrix assignment below raise IndexError.
    if rows is None:
        candidates=data.values()
    else:
        candidates=rows
    categories={}
    for c in candidates:
        if c not in categories:
            categories[c]=len(categories)

    # an n x n mixing matrix, n = number of categories
    n=len(categories)
    a=N.zeros((n,n)).astype(N.Float)
    total=0

    # Skip vertices whose category was not chosen (only possible when the
    # caller supplied rows).  Indices are ints >= 0, so None means "absent".
    for v in data:
        vi=categories.get(data[v])
        if vi is None:
            continue
        for u in G.neighbors(v):
            ui=categories.get(data[u])
            if ui is not None:
                a[ui,vi]+=1.0
                total+=1

    # normalize by the number of counted pairs (guard the empty case)
    if total>0:
        a=a/total

    if labels is True:
        # categories listed in row/column index order
        return a,sorted(categories,key=categories.get)
    return a

def degree_mixing_matrix(G,**kwds):
    """Degree mixing matrix (i.e. mixing by the degree of the vertices).

    An important special case of scalar_mixing_matrix in which the
    property of a vertex v is its "remaining degree", G.degree(v)-1.

    Keyword arguments are passed through to scalar_mixing_matrix.
    """
    def remaining_degree(v):
        return G.degree(v)-1
    return scalar_mixing_matrix(G,remaining_degree,**kwds)

def scalar_mixing_matrix(G,pfunction,**kwds):
    """Mixing matrix for scalar (non-negative integer) properties.

    pfunction must return the scalar data, a non-negative integer, when
    given a vertex as an argument -- for example the vertex degree
    G.degree(v), or any user-defined function.

    The rows and columns are indexed by the values [0,..,max(data)],
    where data are the values returned for all vertices of G.  If you
    want to specify the values see categorical_mixing_matrix.

    Entry a[i,j] counts the pairs (v, u), u a neighbor of v, with
    pfunction(u) == i and pfunction(v) == j.  The matrix is normalized
    so the sum of the elements is one; if G has no edges the all-zero
    matrix is returned unnormalized instead of dividing by zero.

    Extra keyword arguments are accepted and ignored.

    Raises ValueError if G has no vertices, or if a vertex that has a
    neighbor gets a negative value.  Without that check, a negative
    value would silently index from the end of the matrix.  Vertices
    without neighbors may have any value (e.g. the remaining degree -1
    of an isolated vertex).
    """
    import Numeric as N
    # call the user-supplied function once per vertex
    data={}
    for v in G.nodes():
        data[v]=pfunction(v)
    if not data:
        raise ValueError("scalar_mixing_matrix: graph has no vertices")

    # An (n x n) matrix with n = max(data)+1.  Clamp at 0 so that a graph
    # whose values are all negative (e.g. only isolated vertices in a
    # degree mixing matrix) still gets a valid, empty shape.
    n=max(max(data.values())+1,0)
    a=N.zeros((n,n)).astype(N.Float)

    total=0
    for v in G.nodes():
        vdata=data[v]
        for u in G.neighbors(v):
            udata=data[u]
            if udata<0 or vdata<0:
                raise ValueError("scalar_mixing_matrix: negative property "
                                 "value at edge (%r, %r)" % (v,u))
            a[udata,vdata]+=1.0
            total+=1

    # normalize by the number of counted pairs (guard the edgeless case)
    if total>0:
        a=a/total
    return a


def scalar_mixing_function(G,function,**kwds):
    """Mixing function for scalar properties (not implemented).

    Intended as the continuous version of scalar_mixing_matrix: function
    must return the scalar data when given a vertex as an argument.

    Raises NotImplementedError always, so that callers fail loudly
    instead of silently receiving None.
    """
    raise NotImplementedError("scalar_mixing_function is not implemented")


def assortativity(M,**kwds):
    """Return the assortativity coefficient r of the mixing matrix M.

        r = (Tr M - s) / (1 - s),    s = sum_i a_i*b_i,

    where a and b are the column and row sums of M (s also equals the
    sum of the elements of M*M).  M is assumed to be a square mixing
    matrix whose elements sum to one, such as one returned by
    categorical_mixing_matrix or scalar_mixing_matrix.

    With the keyword rmin=True the pair (r, rmin) is returned, where
    rmin = -s / (1 - s) is the minimum assortativity for M.

    When s equals one to within rounding (all the weight on a single
    category, e.g. the degree mixing matrix of a regular graph) the
    denominator would vanish; 1 is used in its place instead.
    """
    import Numeric as N
    want_rmin=kwds.get("rmin",False)

    trM=N.trace(M)
    a=N.add.reduce(M,0)   # column sums
    b=N.add.reduce(M,1)   # row sums
    s=N.add.reduce(a*b)

    factor=1.0-s
    # Hack to avoid blowing up for regular graphs (s == 1).  Compare with
    # a tolerance: rounding can leave s a few ulps away from one, which
    # would otherwise give a near-zero denominator and a huge r.
    if abs(factor)<=1e-12:
        factor=1.0

    r=(trM-s)/factor
    if want_rmin is True:
        return r,-s/factor
    return r

def Pearson_correlation(M):
    """Return the Pearson correlation coefficient of the mixing matrix M.

    M[i,j] is treated as the joint probability that the row variable
    takes the value x[i] and the column variable the value y[j].  For now
    the values are assumed to be x[i] = i+1 and y[j] = j+1 (shifting all
    values by the same constant does not change the coefficient).  M may
    be rectangular and is assumed to be normalized so its elements sum
    to one.

    If either marginal has zero variance (e.g. the degree mixing matrix
    of a regular graph) the coefficient is undefined; the unnormalized
    covariance is returned instead of dividing by zero.
    """
    import Numeric as N
    # TODO: take the property values as an argument instead of assuming
    # x[i] = i+1 and y[j] = j+1.
    (m,n)=M.shape
    x=N.arange(m).astype(N.Float)+1   # row values, length m
    y=N.arange(n).astype(N.Float)+1   # column values, length n

    rowsum=N.add.reduce(M,1)   # marginal distribution of the row variable
    colsum=N.add.reduce(M,0)   # marginal distribution of the column variable
    meanx=N.add.reduce(rowsum*x)
    meany=N.add.reduce(colsum*y)
    # clamp at zero: rounding can make a zero variance slightly negative,
    # which would make sqrt fail or return nan
    varx=max(N.add.reduce(x*x*rowsum)-meanx*meanx,0.0)
    vary=max(N.add.reduce(y*y*colsum)-meany*meany,0.0)
    sdx=N.sqrt(varx)
    sdy=N.sqrt(vary)

    # covariance = sum_ij x_i*y_j*(M_ij - rowsum_i*colsum_j)
    cov=0.0
    for i in range(m):
        for j in range(n):
            cov+=x[i]*y[j]*(M[i,j]-rowsum[i]*colsum[j])

    # hack to avoid blowing up for regular graphs (zero variance)
    if sdx*sdy==0:
        return cov
    return cov/(sdx*sdy)

if __name__ == '__main__':
    # Demo.  Every print call takes a single (possibly pre-formatted)
    # argument, so the output is identical under Python 2 and the block
    # also parses under Python 3.
    import networkx as NX
    import Numeric as N   # used by the commented-out examples below

    print("Scalar Data, by degree")
    G = NX.barbell_graph(5,2)
    print(G.edges())
    M = degree_mixing_matrix(G)

    print(M)
    print("Pearson: %s" % (Pearson_correlation(M),))
    print("assortativity: %s" % (assortativity(M,rmin=True),))

    # Other example matrices; to try one, build it as `a`, then set M=a
    # and repeat the three prints above.
    #
    # surprising example
    #     n=3
    #     a = N.zeros((n,n)).astype(N.Float)
    #     a[0,0]=10
    #     a[1,1]=1
    #     a[2,2]=10
    #     a[0,2]=0
    #     a[2,0]=0
    #     sum=N.add.reduce(N.add.reduce(a))
    #     a=a/sum
    #
    # appendix A example
    #     n=3
    #     a = N.zeros((n,n)).astype(N.Float)
    #     a[0,0]=50
    #     a[1,1]=50
    #     a[2,2]=2
    #     a[0,1]=50
    #     a[1,0]=50
    #     sum=N.add.reduce(N.add.reduce(a))
    #     a=a/sum
    #
    # sexual partnership data
    #     n=4
    #     a = N.zeros((n,n)).astype(N.Float)
    #     a[0,0]=0.259
    #     a[1,0]=0.016
    #     a[2,0]=0.035
    #     a[3,0]=0.013
    #
    #     a[0,1]=0.012
    #     a[1,1]=0.158
    #     a[2,1]=0.058
    #     a[3,1]=0.019
    #
    #     a[0,2]=0.013
    #     a[1,2]=0.023
    #     a[2,2]=0.306
    #     a[3,2]=0.035
    #
    #     a[0,3]=0.005
    #     a[1,3]=0.007
    #     a[2,3]=0.024
    #     a[3,3]=0.017
    #
    #     sum=N.add.reduce(N.add.reduce(a))
    #     a=a/sum

    # test categorical mixing
    print("")
    print("Categorical Data")
    G1 = NX.XGraph()
    for (u, v) in [("A","B"), ("A","C"), ("B","D"), ("D","E"),
                   ("A","E"), ("E","F"), ("F","G")]:
        G1.add_edge(u, v)

    s = {"A": "F", "B": "F", "C": "M", "D": "M",
         "E": "O", "F": "M", "G": "M"}

    a,labels = categorical_mixing_matrix(G1,s.__getitem__,labels=True)
    print(labels)
    print(a)
    print(assortativity(a,rmin=True))
