ML | K-means++ Algorithm

Last Updated : 11 Mar, 2025
Clustering is one of the most common tasks in machine learning, where we group similar data points together. K-Means is one of the simplest and most popular clustering algorithms, but it has one major drawback: the random initialization of cluster centers often leads to poor clustering results. Some clusters may end up with no points, or multiple centroids may land inside the same natural cluster.

For example, consider the images shown below, produced using standard K-means clustering.

Here we can see that the clusters are not forming properly. To solve this problem, K-means++ was introduced; it improves the way initial cluster centers are selected, making the clustering results more accurate and convergence faster.

This is how the clustering should have looked:

K-means++ is an improved version of the K-means algorithm that chooses better starting points instead of selecting them purely at random. The key idea is to pick initial cluster centers that are spread out across the data, which helps the algorithm converge faster and produce better clustering results.
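In practice, you rarely need to implement this selection yourself: scikit-learn's KMeans uses k-means++ initialization by default through its init parameter. A minimal usage sketch is shown below (the toy data is made up purely for illustration):

import numpy as np
from sklearn.cluster import KMeans

X = np.random.rand(200, 2)  # toy data, purely for illustration
model = KMeans(n_clusters=4, init='k-means++', n_init=10, random_state=42)
labels = model.fit_predict(X)
print(model.cluster_centers_)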

How the K-means++ Algorithm Works

The K-means++ algorithm works in two steps:

1. Initialization Step:

  • Choose the first cluster center randomly from the data points.
  • For each remaining center, select the next one from the data points with probability proportional to the squared distance between each point and its closest already-selected center. This is known as D² sampling; a sketch of the weighted selection appears after this list.

2. Clustering Step:

After selecting the initial centers, K-means++ performs clustering the same way as standard K-means:

  • Assign each data point to the nearest cluster center.
  • Recalculate each cluster center as the average of all points assigned to it.
  • Repeat until the cluster centers stop changing or a fixed number of iterations is reached.
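
The weighted selection from the initialization step can be written compactly in NumPy. Below is a minimal sketch of D² sampling; the function name kmeans_pp_init and its interface are illustrative choices, not part of any library:

import numpy as np

def kmeans_pp_init(data, k, rng=None):
    if rng is None:
        rng = np.random.default_rng()
    # Step 1: pick the first center uniformly at random.
    centers = [data[rng.integers(data.shape[0])]]
    for _ in range(k - 1):
        # Squared distance from every point to its nearest chosen center.
        diffs = data[:, None, :] - np.array(centers)[None, :, :]
        d2 = (diffs ** 2).sum(axis=2).min(axis=1)
        # Step 2: sample the next center with probability proportional to D^2.
        centers.append(data[rng.choice(data.shape[0], p=d2 / d2.sum())])
    return np.array(centers)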

Implementation in Python

Let's understand how K-means++ initializes centroids step by step using the following implementation:

  1. Dataset Creation: Four separate Gaussian clusters are generated with different means and covariances to simulate different groupings in the data.
  2. Plot Function: Visualizes the dataset and selected centroids.
  3. Distance Function: Calculates the Euclidean distance between two points.
  4. Initialize Function:
    • The first centroid is selected randomly.
    • Each subsequent centroid is the point whose distance to its nearest already-selected centroid is largest (a deterministic simplification of the probabilistic D² sampling described above).
    • This process continues until all k centroids are selected.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import sys

mean_01 = np.array([0.0, 0.0])
cov_01 = np.array([[1, 0.3], [0.3, 1]])
dist_01 = np.random.multivariate_normal(mean_01, cov_01, 100)

mean_02 = np.array([6.0, 7.0])
cov_02 = np.array([[1.5, 0.3], [0.3, 1]])
dist_02 = np.random.multivariate_normal(mean_02, cov_02, 100)

mean_03 = np.array([7.0, -5.0])
cov_03 = np.array([[1.2, 0.5], 
                   [0.5, 1]]) 
dist_03 = np.random.multivariate_normal(mean_03, cov_03, 100)

mean_04 = np.array([2.0, -7.0])
cov_04 = np.array([[1.2, 0.5], [0.5, 1.3]])
dist_04 = np.random.multivariate_normal(mean_04, cov_04, 100)

data = np.vstack((dist_01, dist_02, dist_03, dist_04))
np.random.shuffle(data)

def plot(data, centroids):
    plt.scatter(data[:, 0], data[:, 1], marker='.',
                color='gray', label='data points')
    plt.scatter(centroids[:-1, 0], centroids[:-1, 1],
                color='black', label='previously selected centroids')
    plt.scatter(centroids[-1, 0], centroids[-1, 1],
                color='red', label='next centroid')
    plt.title('Selecting centroid %d' % centroids.shape[0])

    plt.legend()
    plt.xlim(-5, 12)
    plt.ylim(-10, 15)
    plt.show()

def distance(p1, p2):
    return np.sqrt(np.sum((p1 - p2)**2))

# initialization algorithm
def initialize(data, k):
    '''
    Initializes the centroids for K-means++.
    inputs:
        data - numpy array of data points having shape (400, 2)
        k - number of clusters
    '''
    centroids = []
    centroids.append(data[np.random.randint(
        data.shape[0]), :])
    plot(data, np.array(centroids))

    for c_id in range(k - 1):

        dist = []
        for i in range(data.shape[0]):
            point = data[i, :]
            d = sys.maxsize

            for j in range(len(centroids)):
                temp_dist = distance(point, centroids[j])
                d = min(d, temp_dist)
            dist.append(d)

        dist = np.array(dist)
        next_centroid = data[np.argmax(dist), :]
        centroids.append(next_centroid)
        plot(data, np.array(centroids))
    return centroids

centroids = initialize(data, k=4)

Output: 

The first plot shows the dataset with the first randomly selected centroid (in red). No black points are visible since only one centroid has been selected so far.

The second centroid is then selected: the point farthest from the first centroid. The first centroid turns black and the new centroid is marked in red.

The third centroid is selected. The two previously selected centroids are shown in black while the newly selected centroid is in red.

The final centroid is selected completing the initialization. Three previously selected centroids are in black and the last selected centroid is in red.
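
Note that the code above only performs the initialization. To run the full clustering step with these centroids, one convenient option (an assumption on our part, not part of the original code) is to pass them to scikit-learn's KMeans through its init parameter:

from sklearn.cluster import KMeans

# n_init=1 because we supply an explicit initialization.
kmeans = KMeans(n_clusters=4, init=np.array(centroids), n_init=1)
labels = kmeans.fit_predict(data)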

Applications of the K-means++ Algorithm

  • Image segmentation: K-means++ can be used to segment images into regions based on their color or texture features. This is useful in computer vision applications such as object recognition or tracking.
  • Customer segmentation: It can group customers into segments based on purchasing habits, demographic data, or other characteristics, which helps businesses target their marketing efforts more effectively.
  • Recommender systems: K-means++ can be used to recommend products or services to users based on their past purchases or preferences, which is useful in e-commerce and online advertising applications.

Note: Although the initialization in K-means++ is computationally more expensive than in standard K-means, the run-time to convergence is drastically reduced. This is because the initial centroids are likely to already lie in different clusters.
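
As a rough sanity check (not from the original article), you can compare how many Lloyd iterations each initialization needs on the dataset above via scikit-learn's n_iter_ attribute:

from sklearn.cluster import KMeans

for init in ('random', 'k-means++'):
    km = KMeans(n_clusters=4, init=init, n_init=1, random_state=0).fit(data)
    print(init, 'converged in', km.n_iter_, 'iterations')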


