# Copyright (c) 2019, Matt Pewsey

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from matplotlib.patches import Rectangle

__all__ = ['SpatialHash']


class SpatialHash(object):
    """
    A class representing a spatial hash structure for efficient distance
    queries.

    Parameters
    ----------
    points : list
        A list of points of shape (N, D).
    grid : float
        The width of each spatial hash grid element. For 2D spaces, this value
        represents the width and height of each square spatial hash partition.

    Examples
    --------
    The below example uses the :meth:`.query_point` and :meth:`.query_range`
    methods to search for points within the specified offset of a point and
    within the specified offset of a range, respectively. The found points
    are shown in green.

    .. plot:: ../examples/survey/
        :include-source:
    """
    def __init__(self, points, grid):
        self._grid = grid
        self._dict = {}
        self._add_points(points)

    def __repr__(self):
        return '{}({}, grid={!r})'.format(type(self).__name__,
                                          self.points.shape, self._grid)

    def _check_shape(self, point):
        """
        Checks that the input point conforms to the hash dimensionality.

        Parameters
        ----------
        point : array
            An array of shape (D,).

        Raises
        ------
        ValueError
            If the length of the point differs from the hash dimension.
        """
        if point.shape[0] != self._dim:
            raise ValueError('Point is {}D but should be {}D.'
                             .format(point.shape[0], self._dim))

    def _add_points(self, points):
        """
        Adds the input list of points to the spatial hash.

        Parameters
        ----------
        points : list
            A list of points of shape (N, D).
        """
        points = np.asarray(points)
        self._dim = points.shape[1]
        self.points = points

        # Each bucket holds the indices (into ``self.points``) of the points
        # falling in that grid cell.
        odict = self._dict
        for i, key in enumerate(self._multi_hash(points, norm=True)):
            odict.setdefault(key, []).append(i)

    def _multi_hash(self, points, norm):
        """
        Returns an iterator of dictionary keys corresponding to the input
        points.

        Parameters
        ----------
        points : list
            A list of points of shape (N, D).
        norm : bool
            If True, normalizes the points to their grid index. Otherwise,
            assumes that the input points are grid indices.
        """
        cells = np.asarray(points)
        if norm:
            cells = cells // self._grid
        # Keys are exact integer tuples rather than ``hash(tuple(...))``:
        # distinct cells can share a Python hash (in CPython
        # ``hash(-1) == hash(-2)``), which would silently merge buckets.
        return (tuple(row) for row in cells.astype(np.int64).tolist())

    def _hash(self, point, norm):
        """
        Returns the dictionary key corresponding to the input point.

        Parameters
        ----------
        point : list
            A list of shape (D,).
        norm : bool
            If True, normalizes the point to its grid index. Otherwise,
            assumes that the input point is a grid index.
        """
        cell = np.asarray(point)
        if norm:
            cell = cell // self._grid
        return tuple(cell.astype(np.int64).tolist())

    def multi_get(self, points, norm=True):
        """
        Returns the point indices corresponding to the input array of points.

        Parameters
        ----------
        points : list
            A list of points of shape (N, D).
        norm : bool
            If True, normalizes the points to their grid index. Otherwise,
            assumes that the input points are grid indices.

        Returns
        -------
        array
            Sorted, unique integer indices into ``self.points``.
        """
        result = []
        odict = self._dict
        for key in self._multi_hash(points, norm):
            result.extend(odict.get(key, ()))
        return np.asarray(np.unique(result), dtype='int')

    def get(self, point, norm=True):
        """
        Returns the point indices corresponding to the same hash as the input
        point.

        Parameters
        ----------
        point : list
            A list of shape (D,).
        norm : bool
            If True, normalizes the point to its grid index. Otherwise,
            assumes that the input point is a grid index.

        Returns
        -------
        list
            A new list of indices into ``self.points``; modifying it does not
            affect the hash.
        """
        # Copy so callers cannot mutate the internal bucket.
        return list(self._dict.get(self._hash(point, norm), ()))

    def _cell_grid(self, lower, upper):
        """
        Returns an integer array of shape (M, D) containing the grid index of
        every cell in the inclusive box spanned by the coordinates ``lower``
        and ``upper``.
        """
        bounds = np.column_stack([lower, upper]) // self._grid
        axes = [np.arange(lo, hi + 1) for lo, hi in bounds]
        return np.array(np.meshgrid(*axes), dtype='int').T.reshape(-1, self._dim)

    def _query_point_hash(self, point, ro, ri):
        """
        Returns the indices of candidate points whose cells may lie within
        the inner radius ``ri`` and outer radius ``ro`` of ``point``.
        """
        # Any point in a cell is within one cell diagonal of the cell's lower
        # corner, so widen the radial band by that amount.
        diag = self._grid * self._dim**0.5
        ri = max(ri - diag, 0)
        ro = ro + diag

        cells = self._cell_grid(point - ro, point + ro)

        # Filter cells by the distance from the point to the cell corner.
        dist = np.linalg.norm(point - self._grid * cells, axis=1)
        keep = (dist <= ro)
        if ri != 0:
            keep &= (dist >= ri)

        return self.multi_get(cells[keep], norm=False)

    def query_point(self, point, ro, ri=0):
        """
        Returns an array of point indices for all points contained within the
        specified inner and outer radii from the input point.

        Parameters
        ----------
        point : list
            A list of shape (D,).
        ro : float
            The outer radius beyond which points will be excluded.
        ri : float
            The inner radius before which points will be excluded.

        Returns
        -------
        array
            Indices into ``self.points``, sorted by increasing distance.
        """
        point = np.asarray(point)
        self._check_shape(point)

        # Candidate points from the hash, then exact distance filtering.
        result = self._query_point_hash(point, ro, ri)
        dist = np.linalg.norm(self.points[result] - point, axis=1)

        keep = (dist <= ro)
        if ri != 0:
            keep &= (dist >= ri)

        return result[keep][dist[keep].argsort()]

    def _query_range_hash(self, a, b, ro, ri, u, length):
        """
        Returns the indices of candidate points whose cells may lie within
        the offsets ``ri``/``ro`` of the segment from ``b`` towards ``a``.
        ``u`` is the unit vector from ``b`` to ``a`` and ``length`` is the
        segment length.
        """
        # Any point in a cell is within one cell diagonal of the cell's lower
        # corner, so widen the offsets and projection limits by that amount.
        diag = self._grid * self._dim**0.5
        ri = max(ri - diag, 0)
        ro = ro + diag

        # Bounding box of the segment expanded by the offset. The endpoints
        # are compared element-wise because ``a`` may exceed ``b`` in any
        # coordinate.
        lower = np.minimum(a, b) - ro
        upper = np.maximum(a, b) + ro
        cells = self._cell_grid(lower, upper)

        # Filter cells by projection along the segment and by offset from it.
        v = self._grid * cells - b
        proj = np.dot(v, u)
        dist = np.linalg.norm(v - proj.reshape(-1, 1) * u, axis=1)
        del v

        keep = (proj >= -diag) & (proj <= length + diag) & (dist <= ro)
        if ri != 0:
            keep &= (dist >= ri)

        return self.multi_get(cells[keep], norm=False)

    def query_range(self, a, b, ro, ri=0):
        """
        Returns an array of point indices for all points along the specified
        range within the inner and outer offsets.

        Parameters
        ----------
        a : list
            The starting point for the range. The point should be of shape
            (D,).
        b : list
            The ending point for the range. The point should be of shape (D,).
        ro : float
            The outer offset beyond which points will be excluded.
        ri : float
            The inner offset before which points will be excluded.

        Returns
        -------
        array
            Indices into ``self.points``, sorted by increasing offset from
            the range. If ``a`` equals ``b``, the result of
            :meth:`.query_point` about ``a`` is returned.
        """
        a = np.asarray(a)
        b = np.asarray(b)
        self._check_shape(a)
        self._check_shape(b)

        # Unit vector along the range.
        u = a - b
        length = np.linalg.norm(u)
        if length == 0:
            return self.query_point(a, ro, ri)
        u = u / length

        # Candidate points from the hash, then exact projection/offset test.
        result = self._query_range_hash(a, b, ro, ri, u, length)
        v = self.points[result] - b
        proj = np.dot(v, u)
        dist = np.linalg.norm(v - proj.reshape(-1, 1) * u, axis=1)

        keep = (proj >= 0) & (proj <= length) & (dist <= ro)
        if ri != 0:
            keep &= (dist >= ri)

        return result[keep][dist[keep].argsort()]

    def _plot_limits(self):
        """
        Returns ``(lim, ticks)`` for a new plot. The limits span the minimum
        to maximum coordinate over all dimensions. They are snapped to the
        grid and padded by one cell on each side, with one tick per cell.
        """
        lim = np.array([self.points.min(), self.points.max()])
        lim = self._grid * (lim // self._grid + [-1, 2])
        ticks = np.arange(lim[0], lim[1] + self._grid, self._grid)
        return lim, ticks

    def _cell_corners(self):
        """
        Returns the set of lower corner coordinates (as tuples) of all
        occupied hash cells.
        """
        corners = self._grid * (self.points // self._grid)
        return set(map(tuple, corners))

    def _plot_1d(self, ax, sym):
        """
        Creates a 1D plot.

        Parameters
        ----------
        ax : :class:`matplotlib.axes.Axes`
            The axes to which the plot will be added. If None, a new figure
            and axes will be created.
        sym : dict
            A dictionary of plot symbols with any of the following keys:

                * points: Point symbols, default is 'r.'
                * hash: Hash region color, default is 'b'
        """
        if ax is None:
            lim, ticks = self._plot_limits()
            fig = plt.figure()
            ax = fig.add_subplot(111,
                xlim=lim,
                ylim=self._grid * np.array([-0.5, 0.5]),
                xticks=ticks,
                yticks=[0],
            )
            ax.grid(True, which='major', alpha=0.4)

        # Hash regions drawn as shaded squares centred on the x-axis.
        if sym['hash'] is not None:
            y = -0.5 * self._grid
            for corner in self._cell_corners():
                rect = Rectangle((corner[0], y), self._grid, self._grid,
                                 color=sym['hash'], alpha=0.2, zorder=1)
                ax.add_artist(rect)

        if sym['points'] is not None:
            x = self.points
            ax.plot(x[:, 0], np.zeros(x.shape[0]), sym['points'])

        return ax

    def _plot_2d(self, ax, sym):
        """
        Creates a 2D plot.

        Parameters
        ----------
        ax : :class:`matplotlib.axes.Axes`
            The axes to which the plot will be added. If None, a new figure
            and axes will be created.
        sym : dict
            A dictionary of plot symbols with any of the following keys:

                * points: Point symbols, default is 'r.'
                * hash: Hash region color, default is 'b'
        """
        if ax is None:
            lim, ticks = self._plot_limits()
            fig = plt.figure()
            ax = fig.add_subplot(111,
                xlim=lim,
                ylim=lim,
                xticks=ticks,
                yticks=ticks,
                aspect='equal',
            )
            ax.grid(True, which='major', alpha=0.4)

        # Hash regions drawn as unfilled squares.
        if sym['hash'] is not None:
            for corner in self._cell_corners():
                rect = Rectangle(corner, self._grid, self._grid,
                                 color=sym['hash'], fill=False, zorder=1)
                ax.add_artist(rect)

        if sym['points'] is not None:
            x = self.points
            ax.plot(x[:, 0], x[:, 1], sym['points'])

        return ax

    def _plot_3d(self, ax, sym):
        """
        Creates a 3D plot.

        Parameters
        ----------
        ax : :class:`matplotlib.axes.Axes`
            The axes to which the plot will be added. If None, a new figure
            and axes will be created.
        sym : dict
            A dictionary of plot symbols with any of the following keys:

                * points: Point symbols, default is 'r.'
                * hash: Hash region color, default is 'b'
        """
        if ax is None:
            lim, ticks = self._plot_limits()
            fig = plt.figure()
            ax = fig.add_subplot(111,
                projection='3d',
                xlim=lim,
                ylim=lim,
                zlim=lim,
                xticks=ticks,
                yticks=ticks,
                zticks=ticks,
            )
            # Some matplotlib releases raise NotImplementedError for an equal
            # aspect on 3D axes; keep the default aspect there.
            try:
                ax.set_aspect('equal')
            except NotImplementedError:
                pass

        # Hash regions drawn as translucent cubes.
        if sym['hash'] is not None:
            cube = self._grid * np.array([
                [0, 0, 0], [0, 1, 0], [1, 1, 0], [1, 0, 0],
                [0, 0, 1], [0, 1, 1], [1, 1, 1], [1, 0, 1],
            ])
            # Vertex indices of the six cube faces.
            faces = [[0, 1, 2, 3], [0, 1, 5, 4], [4, 5, 6, 7],
                     [5, 1, 2, 6], [6, 7, 3, 2], [4, 7, 3, 0]]

            for corner in self._cell_corners():
                verts = cube + corner
                polys = [[verts[i] for i in face] for face in faces]
                poly = Poly3DCollection(polys, alpha=0.05)
                poly.set_facecolor(sym['hash'])
                poly.set_edgecolor(sym['hash'])
                ax.add_collection(poly)

        if sym['points'] is not None:
            x = self.points
            ax.plot(x[:, 0], x[:, 1], x[:, 2], sym['points'])

        return ax

    def plot(self, ax=None, symbols=None):
        """
        Creates a plot of the spatial hash. Cannot create plots for hashes
        greater than 3 dimensions.

        Parameters
        ----------
        ax : :class:`matplotlib.axes.Axes`
            The axes to which the plot will be added. If None, a new figure
            and axes will be created.
        symbols : dict
            A dictionary of plot symbols with any of the following keys:

                * points: Point symbols, default is 'r.'
                * hash: Hash region color, default is 'b'

        Raises
        ------
        ValueError
            If the hash has more than 3 dimensions.

        Examples
        --------
        .. plot:: ../examples/survey/
            :include-source:
        """
        sym = dict(points='r.', hash='b')
        if symbols is not None:
            sym.update(symbols)

        plotters = {1: self._plot_1d, 2: self._plot_2d, 3: self._plot_3d}
        if self._dim not in plotters:
            raise ValueError('Hash is {}D but plot only supports 1D, 2D, or 3D.'
                             .format(self._dim))
        return plotters[self._dim](ax, sym)