创建旅游景点图数据库（Neo4j）技术验证

写在前面

本章主要实践内容：
（1）Neo4j 知识图谱建库：直接使用导航 POI 中公园、景点两类 CSV 文件建库。
（2）PostgreSQL 建库：将携程 POI 导入 tripdata 库的 poibaseinfo 表，之后再将导航 POI 中的公园、景点也导入该表。

基础数据建库

python3源代码

以下代码实现了 CSV 数据向知识图谱（KG）的初始导入；如果是增量更新，代码需要相应调整。
另外，星级、旅游时间均为随机生成，不具备任何真实性。

import csv
from py2neo import *
import random
import geohash

def importCSV2NeoKG(graph, csvPath, csvType):
    """Import one navigation-POI CSV file into the Neo4j knowledge graph.

    Each row is read by position: column 0 is the POI name, column 1 its
    district (区县), column 2 its address, and columns 4 and 3 are passed to
    ``geohash.encode`` as latitude and longitude (6-character geohash).
    NOTE(review): confirm this column layout against the CSV files. Every
    non-blank row is imported, including the first one, so the CSV is
    presumably headerless -- confirm.

    Graph written (all districts hang under the single city "北京"):
      (quxian {name})-[属于]->(chengshi {name: "北京"})
      (csvType {name, quyu, jianjie, dizhi, zuobiao})-[位于]->(quxian)
      (csvType)-[旅游时间]->(traveltime {time})   random season, not real data
      (csvType)-[星级]->(Stars {star})            random rating, not real data

    A POI that already exists with the same label, name and district is
    skipped, so re-running the import does not duplicate POI nodes.

    :param graph: connected py2neo ``Graph``.
    :param csvPath: path of the UTF-8 encoded CSV file.
    :param csvType: node label used for the POIs, e.g. "gongyuan" or "jingdian".
    """
    # Read-only lookup helper for existing nodes.
    node_Match = NodeMatcher(graph)

    seasons = ["春季", "夏季", "秋季", "冬季"]
    stars = ["A", "AA", "AAA", "AAAA", "AAAAA"]

    with open(csvPath, "r", encoding="utf-8", newline="") as f:
        datas = list(csv.reader(f))

    print("csv连接成功", len(datas))

    # Keep the first row for each (name, district) pair. A set catches repeats
    # anywhere in the file, not only on adjacent lines. Blank lines (which
    # csv.reader yields as []) are dropped instead of raising IndexError.
    newDatas = []
    seen = set()
    for data in datas:
        if not data:
            continue
        key = (data[0], data[1])
        if key in seen:
            continue
        seen.add(key)
        newDatas.append(data)
    print("去除csv中重复记录")

    # MERGE is idempotent, so merge the city node unconditionally. The old
    # `match(...) == None` test could never be true (match() returns a
    # NodeMatch object, not None), so the city only reached the database via
    # the relationship merge below, which passes the "quxian"/"name" primary
    # label/key. Merging it here binds nodeCity_new first; py2neo then treats
    # it as an existing node in the relationship merges.
    # NOTE(review): verify the bound-node behaviour for the installed py2neo.
    nodeCity_new = Node("chengshi", name="北京")
    graph.merge(nodeCity_new, "chengshi", "name")

    for data in newDatas:
        name, quxian, dizhi = data[0], data[1], data[2]

        nodeQu_new = Node("quxian", name=quxian)
        graph.merge(Relationship(nodeQu_new, "属于", nodeCity_new), "quxian", "name")

        # Skip POIs already in the graph (same label, name and district).
        if node_Match.match(csvType, name=name).where(quyu=quxian).first() is not None:
            continue

        # NOTE(review): jianjie (introduction) currently just repeats the name.
        nodeJingdian = Node(csvType, name=name, quyu=quxian,
                            jianjie=name,
                            dizhi=dizhi,
                            zuobiao=geohash.encode(data[4], data[3], 6))
        graph.create(nodeJingdian)
        graph.create(Relationship(nodeJingdian, "位于", nodeQu_new))

        # Randomly chosen placeholder attributes. They are merged on their
        # value, so all POIs with the same season/rating share one node.
        nodeTime = Node("traveltime", time=random.choice(seasons))
        graph.merge(Relationship(nodeJingdian, "旅游时间", nodeTime), "traveltime", "time")

        nodeAAA = Node("Stars", star=random.choice(stars))
        graph.merge(Relationship(nodeJingdian, "星级", nodeAAA), "Stars", "star")

if __name__ == '__main__':
    # Local Neo4j over Bolt. NOTE(review): the password looks like a
    # placeholder -- set the real one before running.
    graph = Graph("bolt://localhost:7687", auth=("neo4j", "neo4j?"))
    print("neo4j连接成功")

    # (CSV file, node label, completion message), imported in this order.
    imports = (
        ("公园2050101Attr.csv", "gongyuan", "gongyuan ok"),
        ("景点2050201and20600102Attr.csv", "jingdian", "jingdian ok"),
    )
    for csv_file, label, done_message in imports:
        importCSV2NeoKG(graph, csv_file, label)
        print(done_message)

坐标编码使用了 geohash。尝试安装过几种 geohash 库，均出现错误，最后直接复制其源代码，保存为 geohash.py 文件。
geohash.py 代码如下：

from __future__ import division
from collections import namedtuple
from builtins import range
import decimal
import math

base32 = '0123456789bcdefghjkmnpqrstuvwxyz'


def _indexes(geohash):
    """Yield the 0-31 alphabet position of each character of *geohash*.

    This is a generator, so validation happens lazily while iterating.
    Raises ValueError('Invalid geohash') for an empty geohash or for any
    character outside ``base32``. Lower-casing is the caller's job.
    """
    if not geohash:
        raise ValueError('Invalid geohash')

    for symbol in geohash:
        position = base32.find(symbol)
        if position < 0:
            raise ValueError('Invalid geohash')
        yield position


def _fixedpoint(num, bound_max, bound_min):
    """
    Return *num* as a Decimal rounded to ``floor(2 - log10(range))`` decimal
    places, where ``range = bound_max - bound_min``.

    The smaller the cell, the more decimal places are kept. For example, a
    range of about 0.00137 (a 7-character geohash cell) keeps 4 places.

    Params
    ------
    num: A number
    bound_max: max bound, e.g max latitude of a geohash cell(NE)
    bound_min: min bound, e.g min latitude of a geohash cell(SW)

    Returns
    -------
    A decimal. If the range is zero or negative (log10 undefined), 12 decimal
    places are kept. The caller's decimal context is left untouched.
    """
    value = decimal.Decimal(num)
    try:
        places = math.floor(2 - math.log10(bound_max - bound_min))
    except ValueError:
        places = 12
    # The Decimal constructor is exact and ignores the context precision, so
    # setting getcontext().prec never rounded anything. It also changed the
    # precision for every later Decimal operation in the thread. Round
    # explicitly with quantize inside a local context instead, giving it
    # enough digits that the quantized result always fits.
    with decimal.localcontext() as ctx:
        ctx.prec = max(ctx.prec, value.adjusted() + places + 2)
        return value.quantize(decimal.Decimal((0, (1,), -places)))


def bounds(geohash):
    """
    Return the south-west and north-east corners of a geohash cell::

            |      .| NE
            |    .  |
            |  .    |
         SW |.      |

    :param geohash: string, the cell whose bounds are wanted. It is
        lower-cased before decoding.

    :returns: Bounds(sw, ne) namedtuple, where ``sw`` is SouthWest(lat, lon)
        and ``ne`` is NorthEast(lat, lon).

    :raises ValueError: from _indexes, for an empty geohash or a character
        outside the geohash alphabet.
    """
    lon_span = [-180, 180]   # [min, max]
    lat_span = [-90, 90]     # [min, max]
    spans = (lon_span, lat_span)
    bit_count = 0

    # Each character carries 5 bits, most significant first. Bits alternate
    # between the two spans, starting with longitude. A 1 keeps the upper
    # half of the span, a 0 keeps the lower half.
    for index in _indexes(geohash.lower()):
        for shift in range(4, -1, -1):
            span = spans[bit_count % 2]
            mid = (span[0] + span[1]) / 2
            if (index >> shift) & 1:
                span[0] = mid
            else:
                span[1] = mid
            bit_count += 1

    SouthWest = namedtuple('SouthWest', ['lat', 'lon'])
    NorthEast = namedtuple('NorthEast', ['lat', 'lon'])
    Bounds = namedtuple('Bounds', ['sw', 'ne'])
    return Bounds(SouthWest(lat_span[0], lon_span[0]),
                  NorthEast(lat_span[1], lon_span[1]))


def decode(geohash):
    """
    Decode a geohash to the centre point of its cell.

    :param geohash: string, the geohash to decode (passed to bounds()).

    :returns: Point(lat, lon) namedtuple. Each coordinate is the cell-midpoint
        value passed through _fixedpoint together with that axis's cell
        extent, which decides how many digits are kept.
    """
    south_west, north_east = bounds(geohash)
    lat_lo, lon_lo = south_west
    lat_hi, lon_hi = north_east

    Point = namedtuple('Point', ['lat', 'lon'])
    return Point(
        _fixedpoint((lat_lo + lat_hi) / 2, lat_hi, lat_lo),
        _fixedpoint((lon_lo + lon_hi) / 2, lon_hi, lon_lo),
    )


def encode(lat, lon, precision):
    """
    Encode a latitude/longitude pair as a geohash string.

    :param lat: latitude, a number or a string convertible to decimal.Decimal.
        Passing a string avoids floating-point surprises.
    :param lon: longitude, same rules as *lat*.
    :param precision: number of geohash characters to produce (1 to 12 are
        the usual levels). A value of 0 or less yields ''.

    :returns: the geohash as a string.

    >>> encode('70.2995', '-27.9993', 7)
    'gkkpfve'
    """
    lat_value = decimal.Decimal(lat)
    lon_value = decimal.Decimal(lon)

    # Current cell extents. Every bit halves one of them.
    lon_low, lon_high = -180, 180
    lat_low, lat_high = -90, 90

    chars = []
    code = 0             # value of the character being assembled
    bits_used = 0        # bits collected for it so far (5 per character)
    on_longitude = True  # bits alternate lon, lat, lon, ... starting with lon

    while len(chars) < precision:
        if on_longitude:
            mid = (lon_low + lon_high) / 2
            upper = lon_value >= mid
            if upper:
                lon_low = mid
            else:
                lon_high = mid
        else:
            mid = (lat_low + lat_high) / 2
            upper = lat_value >= mid
            if upper:
                lat_low = mid
            else:
                lat_high = mid
        code = code * 2 + (1 if upper else 0)
        on_longitude = not on_longitude
        bits_used += 1

        if bits_used == 5:
            chars.append(base32[code])
            code = 0
            bits_used = 0

    return ''.join(chars)


def adjacent(geohash, direction):
    """
    Determines adjacent cell in given direction.

    :param geohash: cell to which adjacent cell is required. It is lower-cased
        first, consistent with bounds().
    :param direction: direction from geohash, exactly one of 'n', 's', 'e', 'w'

    :returns: geohash of adjacent cell (lower case)

    :raises ValueError: if geohash is empty or contains a character outside
        the geohash alphabet, or if direction is not one of n, s, e, w.

    >>> adjacent('gcpuyph', 'n')
    'gcpuypk'
    """
    if not geohash:
        raise ValueError('Invalid geohash')
    geohash = geohash.lower()
    if any(char not in base32 for char in geohash):
        raise ValueError('Invalid geohash')
    # Membership in a tuple, not a string: `direction in ('nsew')` was a
    # substring test that accepted '' or 'ns' and then hit a KeyError below.
    if direction not in ('n', 's', 'e', 'w'):
        raise ValueError('Invalid direction')

    neighbour = {
        'n': ['p0r21436x8zb9dcf5h7kjnmqesgutwvy',
              'bc01fg45238967deuvhjyznpkmstqrwx'],
        's': ['14365h7k9dcfesgujnmqp0r2twvyx8zb',
              '238967debc01fg45kmstqrwxuvhjyznp'],
        'e': ['bc01fg45238967deuvhjyznpkmstqrwx',
              'p0r21436x8zb9dcf5h7kjnmqesgutwvy'],
        'w': ['238967debc01fg45kmstqrwxuvhjyznp',
              '14365h7k9dcfesgujnmqp0r2twvyx8zb'],
    }

    border = {
        'n': ['prxz',     'bcfguvyz'],
        's': ['028b',     '0145hjnp'],
        'e': ['bcfguvyz', 'prxz'],
        'w': ['0145hjnp', '028b'],
    }

    last_char = geohash[-1]
    parent = geohash[:-1]  # parent is hash without last char

    # Whether the hash length is odd or even decides how the last character's
    # 5 bits interleave longitude and latitude, hence which table to use.
    typ = len(geohash) % 2

    # check for edge-cases which don't share common prefix
    if last_char in border[direction][typ] and parent:
        parent = adjacent(parent, direction)

    index = neighbour[direction][typ].index(last_char)
    return parent + base32[index]


def neighbours(geohash):
    """
    Return the 8 cells surrounding a geohash::

        | nw | n | ne |
        |  w | * | e  |
        | sw | s | se |

    The diagonal cells are found by stepping east/west from the north and
    south neighbours.

    :param geohash: string, the centre cell

    :returns: Neighbours namedtuple with fields n, ne, e, se, s, sw, w, nw
        (in that order), each a geohash string.

    >>> neighbours('gcpuyph').n
    'gcpuypk'
    """
    Neighbours = namedtuple('Neighbours',
                            ['n', 'ne', 'e', 'se', 's', 'sw', 'w', 'nw'])
    north = adjacent(geohash, 'n')
    south = adjacent(geohash, 's')
    return Neighbours(
        n=north,
        ne=adjacent(north, 'e'),
        e=adjacent(geohash, 'e'),
        se=adjacent(south, 'e'),
        s=south,
        sw=adjacent(south, 'w'),
        w=adjacent(geohash, 'w'),
        nw=adjacent(north, 'w'),
    )

KG效果

在命令行中启动 Neo4j（Windows 下）：
neo4j.bat console

KG入库效率优化方案

上文使用的是 py2neo 的基本方法（逐条 create / merge）。经本人实测，当节点数量达到 3～5 万时，入库开始明显变慢，耗时以小时计。

经搜索，找到了另一种方法：先把全部 Node、Relationship 组装成一个 Subgraph，再在单个事务中一次性创建。
采用这种方法创建 50 万个节点和 50 万个关系（流程包括构建节点与关系、追加到列表、批量入库），全过程 4 分钟以内完成。测试环境为 VM 虚拟机。
代码如下：

from py2neo import Graph, Subgraph, Node, Relationship
from progressbar import *
import datetime

def batch_create(graph, nodes_list, relations_list):
    """Create all given nodes and relationships in a single transaction.

    :param graph: connected py2neo ``Graph``.
    :param nodes_list: Node objects to create.
    :param relations_list: Relationship objects to create.
    """
    # Build the subgraph first, so no transaction is opened if that fails.
    bundle = Subgraph(nodes_list, relations_list)
    transaction = graph.begin()
    transaction.create(bundle)
    graph.commit(transaction)


if __name__ == '__main__':
    # Connect to the local Neo4j instance.
    graph = Graph("bolt://localhost:7687", auth=("neo4j", "neo4j?"))

    total = 500000  # number of synthetic district nodes (one relationship each)

    # One shared city node; every generated district points at it.
    city = Node("chengshi", name="北京")
    nodes_list = [city]
    relations_list = []

    widgets = ['CSV导入KG进度: ', Percentage(), ' ', Bar('#'), ' ', Timer(), ' ', ETA(), ' ']
    bar = ProgressBar(widgets=widgets, maxval=total)
    bar.start()
    for i in range(total):
        bar.update(i + 1)
        district = Node("quxian", name="Test{0}".format(i))
        nodes_list.append(district)
        relations_list.append(Relationship(district, "属于", city))
    bar.finish()

    # Timestamps before and after, to measure the batch insert.
    print("current_time:    " + str(datetime.datetime.now()))
    batch_create(graph, nodes_list, relations_list)
    print("current_time:    " + str(datetime.datetime.now()))

    print("batch ok")

PostGreSQL建库

PostgreSQL 建库：将携程 POI 导入 tripdata 库的 poibaseinfo 表，之后再将导航 POI 中的公园、景点也导入该表。

携程 POI 导入代码（psycopg2_004.py）：

import psycopg2
import csv
import random
import geohash
from progressbar import *

def importCtripCSV2PG(cur, csvpath, csvcity, csvprovice):
    """Insert the rows of a Ctrip-crawler POI CSV into table ``poibaseinfo``.

    Blank rows and the header row (first column "名称") are skipped. Each
    remaining row is inserted with the column mapping below, plus a
    7-character geohash computed from columns 5 and 4 (passed to
    ``geohash.encode`` as latitude and longitude), the given city/province,
    and the country "中国". Rows that PostgreSQL rejects are reported with
    print() and skipped; the caller is responsible for committing.

    :param cur: open psycopg2 cursor.
    :param csvpath: path of the UTF-8 CSV file.
    :param csvcity: value stored in the ``city`` column.
    :param csvprovice: value stored in the ``province`` column.
    """
    # utf-8-sig also strips a leading BOM, which would otherwise hide the
    # "名称" header. newline="" is what the csv module expects.
    with open(csvpath, "r", encoding="utf-8-sig", newline="") as f:
        datas = list(csv.reader(f))

    print("csv datas number = {}".format(len(datas)))
    print("")

    widgets = ['爬虫数据导入PG进度: ', Percentage(), ' ', Bar('#'), ' ', Timer(), ' ', ETA(), ' ']
    bar = ProgressBar(widgets=widgets, maxval=len(datas))
    bar.start()

    sCol = "namec,namee,tags,brief,ticket,ticketmin,ticketadult,ticketchild,ticketold,ticketstudent,scores,scorenumber,opentime,spendtime,introduceinfo,salesinfo,x,y,geos,photourl,city,province,contry"
    # Placeholders let psycopg2 quote every value. Previously only a few
    # columns had their quotes escaped by hand, so an apostrophe in e.g.
    # tags or brief broke the statement.
    placeholders = ",".join(["%s"] * len(sCol.split(",")))
    sql = "insert into poibaseinfo({}) values ({})".format(sCol, placeholders)

    # Inside a (non-autocommit) transaction, one failed INSERT aborts the
    # whole transaction: every later statement fails and the final commit
    # rolls everything back. A savepoint per row confines a failure to that
    # row. Savepoints are only valid inside a transaction block, hence the
    # autocommit check (in autocommit each INSERT is already independent).
    use_savepoint = not cur.connection.autocommit

    for i, data in enumerate(datas):
        bar.update(i + 1)

        if not data or data[0] == "名称":
            continue

        geoxy_encode = geohash.encode(data[5], data[4], 7)

        # Order matches sCol.
        params = (
            data[0], data[1], data[6], data[7], data[8], data[9],
            data[13], data[16], data[14], data[15], data[10], data[11],
            data[18], data[17], data[19], data[20],
            data[5], data[4], geoxy_encode, data[12],
            csvcity, csvprovice, "中国",
        )

        if use_savepoint:
            cur.execute("SAVEPOINT poi_row")
        try:
            cur.execute(sql, params)
        except psycopg2.Error as e:
            print(e)
            if use_savepoint:
                cur.execute("ROLLBACK TO SAVEPOINT poi_row")
        if use_savepoint:
            cur.execute("RELEASE SAVEPOINT poi_row")

    bar.finish()
    
    
    
if __name__ == '__main__':

    # NOTE(review): credentials are placeholders -- fill in before running.
    user = "postgres"
    pwd = "你的密码"
    port = "5432"
    hostname = "127.0.0.1"

    conn = psycopg2.connect(database="tripdata", user=user, password=pwd, host=hostname, port=port)
    print(conn)

    # Close the cursor and connection even if the import fails. Closing
    # without commit() discards the uncommitted inserts.
    try:
        with conn.cursor() as cur:
            # Only the column metadata is needed, so fetch no rows.
            cur.execute("select * from poibaseinfo limit 0")
            cols = cur.description
            print("PG cols number = {}".format(len(cols)))

            # CSV file -> PG
            csvPath = "pois_bj_ctrip.csv"
            importCtripCSV2PG(cur, csvPath, "北京", "北京")

            # TODO: import the other CSV files.

            conn.commit()
    finally:
        conn.close()

    print("ok")
03-11 09:24