Transport Analytics Training Series - Last Revision: October 2022

The Nearest Neighbour Heuristic

Never one to give up easily, Santa decided to roll up his sleeves and put together an algorithm that could find a rough plan for the journey without waiting so long.

What do you think about the results?

This notebook provides a simple implementation of the Nearest Neighbour Heuristic. The algorithm performs much faster than our implementation of the MTZ model. The solution it provides is a valid one, since it does not contain any subtours and respects all the constraints of our TSP model.

However, as a greedy heuristic, the results leave much to be desired... Scroll down and you will see what we mean!

In [1]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import collections
import seaborn as sns; sns.set()
from sklearn.datasets import make_blobs
import pandas as pd

%matplotlib inline

# Base size from which the label, title and tick text sizes are derived;
# also reused by the plotting cells to size the scatter markers.
plot_size = 16
# Default figure width and height handed to matplotlib's 'figure.figsize'.
plot_width = 10
plot_height = 5

# Notebook-wide matplotlib defaults so every figure shares the same look.
params = {
    'legend.fontsize': 'large',
    'figure.figsize': (plot_width, plot_height),
    'axes.labelsize': plot_size,
    'axes.titlesize': plot_size,
    'xtick.labelsize': plot_size * 0.75,
    'ytick.labelsize': plot_size * 0.75,
    'axes.titlepad': 25,
}
# Applying the settings once is sufficient; repeating the update is a no-op.
plt.rcParams.update(params)
In [2]:
# Number of cities (nodes) in the generated instance.
num_cities = 14

# One label per city: 'Node 0', 'Node 1', ... (0-based).
cities_names = ['Node {}'.format(idx) for idx in range(num_cities)]
In [3]:
# Bounding box (min, max) passed to make_blobs for the cluster centre.
center_box = (100, 200)

# Generate the city coordinates as a single blob of points. The fixed
# random_state keeps the instance reproducible; the second return value
# (cluster labels) is not needed and is discarded.
cities_coord, _ = make_blobs(
    n_samples=num_cities,
    centers=1,
    cluster_std=20,
    center_box=center_box,
    random_state=10,
)

# Look-up table from city name to its coordinate row.
cities_coord_dict = dict(zip(cities_names, cities_coord))
In [4]:
# Scatter the cities. No `c=` values are supplied, so a colormap would be
# ignored by matplotlib (with a warning); it is therefore not passed.
plt.scatter(cities_coord[:, 0],
            cities_coord[:, 1],
            s=plot_size * 2)

# Label every point, slightly offset so the text does not cover the marker.
# NOTE: the labels are 1-based (i + 1) whereas the names in `cities_names`
# are 0-based, so 'Node 6' is drawn as "7".
for i, point in enumerate(cities_coord):
    plt.annotate(i + 1, (point[0] + 1, point[1] + 0.5))
In [5]:
# City where the tour starts (and to which the route finally returns):
# the entry at index 6 of `cities_names`.
start_node = cities_names[6]
# Bare expression so the notebook displays the chosen name.
start_node
Out[5]:
'Node 6'
In [6]:
from scipy.spatial import distance

# Pairwise Euclidean distances between all cities (one row/column per city).
dist_matrix = distance.cdist(cities_coord, cities_coord, 'euclidean')

# A city must never be picked as its own nearest neighbour, so put infinity
# on the diagonal. Only the diagonal is touched: distinct cities that happen
# to share coordinates keep their genuine zero distance. `np.inf` is used
# because the `np.Inf` alias was removed in NumPy 2.0.
np.fill_diagonal(dist_matrix, np.inf)

# Bare expression so the notebook displays the matrix dimensions.
dist_matrix.shape
Out[6]:
(14, 14)
In [7]:
# Wrap the distance matrix in a DataFrame labelled by city name on both axes.
df = pd.DataFrame(data=dist_matrix, columns=cities_names, index=cities_names)

# Persist the matrix; index=False writes the column header but no row labels.
df.to_csv(path_or_buf='nearest_neighbour.csv', index=False)

# Bare expression so the notebook renders the table.
df
Out[7]:
Node 0 Node 1 Node 2 Node 3 Node 4 Node 5 Node 6 Node 7 Node 8 Node 9 Node 10 Node 11 Node 12 Node 13
Node 0 inf 28.856187 15.699552 7.732602 46.397038 22.144679 23.873272 28.179447 20.448168 38.646409 50.267460 37.969171 76.139522 16.703630
Node 1 28.856187 inf 13.181848 31.509097 31.171604 7.704782 51.458387 30.896988 31.959957 16.471773 34.701714 23.643818 50.554715 43.198549
Node 2 15.699552 13.181848 inf 19.232318 36.621664 6.772074 38.803820 26.570343 22.711135 24.592380 39.510767 27.999713 62.153336 30.383238
Node 3 7.732602 31.509097 19.232318 inf 43.277932 26.002709 20.019733 22.532181 28.178039 43.630036 56.776046 35.337774 75.316976 21.214200
Node 4 46.397038 31.171604 36.621664 43.277932 inf 36.293663 60.031842 23.760875 59.332085 45.612566 64.275853 8.665196 35.760830 63.086706
Node 5 22.144679 7.704782 6.772074 26.002709 36.293663 inf 45.528377 30.725986 24.532734 18.036024 34.048241 28.047317 58.178126 35.680801
Node 6 23.873272 51.458387 38.803820 20.019733 60.031842 45.528377 inf 36.640480 39.614060 62.506215 73.455586 53.025378 93.974081 22.297529
Node 7 28.179447 30.896988 26.570343 22.532181 23.760875 30.725986 36.640480 inf 46.357641 47.216233 64.636059 18.189215 59.016978 43.717685
Node 8 20.448168 31.959957 22.711135 28.178039 59.332085 24.532734 39.614060 46.357641 inf 32.029531 36.031763 50.709054 82.453697 19.593126
Node 9 38.646409 16.471773 24.592380 43.630036 45.612566 18.036024 62.506215 47.216233 32.029531 inf 18.703091 39.098604 55.865019 48.897490
Node 10 50.267460 34.701714 39.510767 56.776046 64.275853 34.048241 73.455586 64.636059 36.031763 18.703091 inf 57.780461 70.505610 55.519599
Node 11 37.969171 23.643818 27.999713 35.337774 8.665196 28.047317 53.025378 18.189215 50.709054 39.098604 57.780461 inf 41.164498 54.672800
Node 12 76.139522 50.554715 62.153336 75.316976 35.760830 58.178126 93.974081 59.016978 82.453697 55.865019 70.505610 41.164498 inf 92.277455
Node 13 16.703630 43.198549 30.383238 21.214200 63.086706 35.680801 22.297529 43.717685 19.593126 48.897490 55.519599 54.672800 92.277455 inf
In [8]:
def plot_solution(_N, _path, _cities_coord, _cities_coord_dict):
    """Draw the cities and the (possibly partial) route between them.

    Args:
        _N: Not used by the function; kept so existing calls keep working.
        _path: Sequence of city names in visiting order.
        _cities_coord: 2-D array with one (x, y) row per city.
        _cities_coord_dict: Mapping from city name to its (x, y) coordinates.

    The last leg of ``_path`` is drawn in red and all earlier legs in black.
    The figure is displayed with ``plt.show()``; nothing is returned.
    """
    # No `c=` values are given, so a colormap would be ignored (with a
    # warning) and is not passed. The high zorder keeps the markers on top
    # of the arrows.
    plt.scatter(_cities_coord[:, 0],
                _cities_coord[:, 1],
                s=plot_size * 2,
                zorder=10000)

    # Index of the final leg, which is highlighted in red.
    last_leg = len(_path) - 2
    for leg, (src, dst) in enumerate(zip(_path, _path[1:])):
        x0, y0 = _cities_coord_dict[src][0], _cities_coord_dict[src][1]
        x1, y1 = _cities_coord_dict[dst][0], _cities_coord_dict[dst][1]
        colour = 'red' if leg == last_leg else 'black'
        plt.arrow(x0, y0, x1 - x0, y1 - y0, color=colour)

    # Label the points from the coordinates that were passed in rather than
    # from notebook globals, so the function works for any coordinate array.
    # Labels are 1-based (row index + 1).
    for idx in range(len(_cities_coord)):
        plt.annotate(idx + 1,
                     (_cities_coord[idx, 0] + 1, _cities_coord[idx, 1] + 0.5))
    plt.show()
In [9]:
# Map each city name to its row/column index in dist_matrix.
name_to_index = {name: idx for idx, name in enumerate(cities_names)}
# Only the start city counts as visited initially. Names are compared by
# value (==): identity (`is`) is not guaranteed to hold for equal strings.
visited = [name == start_node for name in cities_names]
route = [start_node]
curr = start_node

while not all(visited):

    # Mask already-visited cities with infinity on a fresh array. Indexing
    # dist_matrix yields a view, so writing into it directly would
    # permanently overwrite the distance matrix and break any later use
    # (e.g. re-running this cell or computing the tour length).
    row = np.where(visited, np.inf, dist_matrix[name_to_index[curr]])

    # Greedy step: move to the closest city that has not been visited yet.
    closest = int(np.argmin(row))

    visited[closest] = True
    curr = cities_names[closest]
    route.append(curr)

    # Show the partial tour after every step.
    plot_solution(cities_names, route, cities_coord, cities_coord_dict)

# Close the tour by returning to the start city and show the final result.
route.append(start_node)
plot_solution(cities_names, route, cities_coord, cities_coord_dict)