# LICENSE HEADER MANAGED BY add-license-header # # Copyright 2018 Kornia Team # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # based on https://github.com/subhadarship/kmeans_pytorch from __future__ import annotations import torch from kornia.core import Tensor from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_SHAPE from kornia.geometry.linalg import euclidean_distance class KMeans: """Implements the kmeans clustering algorithm with euclidean distance as similarity measure. Args: num_clusters: number of clusters the data has to be assigned to cluster_centers: tensor of starting cluster centres can be passed instead of num_clusters tolerance: float value. the algorithm terminates if the shift in centers is less than tolerance max_iterations: number of iterations to run the algorithm for seed: number to set torch manual seed for reproducibility Example: >>> kmeans = kornia.contrib.KMeans(3, None, 10e-4, 100, 0) >>> kmeans.fit(torch.rand((1000, 5))) >>> predictions = kmeans.predict(torch.rand((10, 5))) """ def __init__( self, num_clusters: int, cluster_centers: Tensor | None, tolerance: float = 10e-4, max_iterations: int = 0, seed: int | None = None, ) -> None: KORNIA_CHECK(num_clusters != 0, "num_clusters can't be 0") # cluster_centers should have only 2 dimensions if cluster_centers is not None: KORNIA_CHECK_SHAPE(cluster_centers, ["C", "D"]) self.num_clusters = num_clusters self._cluster_centers = cluster_centers self.tolerance = tolerance self.max_iterations = max_iterations self._final_cluster_assignments: None | Tensor = None self._final_cluster_centers: None | Tensor = None if seed is not None: torch.manual_seed(seed) @property def cluster_centers(self) -> Tensor: if isinstance(self._final_cluster_centers, Tensor): return self._final_cluster_centers if isinstance(self._cluster_centers, Tensor): return self._cluster_centers else: raise TypeError("Model has not been fit to a dataset") @property def cluster_assignments(self) -> Tensor: if isinstance(self._final_cluster_assignments, Tensor): return self._final_cluster_assignments else: raise TypeError("Model has not been fit to a dataset") def _initialise_cluster_centers(self, X: Tensor, num_clusters: int) -> Tensor: """Chooses num_cluster points from X as the initial cluster centers. Args: X: 2D input tensor to be clustered num_clusters: number of desired cluster centers Returns: 2D Tensor with num_cluster rows """ num_samples: int = len(X) perm = torch.randperm(num_samples, device=X.device) idx = perm[:num_clusters] initial_state = X[idx] return initial_state def _pairwise_euclidean_distance(self, data1: Tensor, data2: Tensor) -> Tensor: """Compute pairwise squared distance between 2 sets of vectors. Args: data1: 2D tensor of shape N, D data2: 2D tensor of shape C, D Returns: 2D tensor of shape N, C """ # N*1*D A = data1[:, None, ...] # 1*C*D B = data2[None, ...] distance = euclidean_distance(A, B) return distance def fit(self, X: Tensor) -> None: """Fit iterative KMeans clustering till a threshold for shift in cluster centers or a maximum no of iterations have reached. Args: X: 2D input tensor to be clustered """ # noqa: D205 # X should have only 2 dimensions KORNIA_CHECK_SHAPE(X, ["N", "D"]) if self._cluster_centers is None: self._cluster_centers = self._initialise_cluster_centers(X, self.num_clusters) else: # X and cluster_centers should have same number of columns KORNIA_CHECK( X.shape[1] == self._cluster_centers.shape[1], f"Dimensions at position 1 of X and cluster_centers do not match. \ {X.shape[1]} != {self._cluster_centers.shape[1]}", ) # X = X.to(self.device) current_centers = self._cluster_centers previous_centers: Tensor | None = None iteration: int = 0 while True: # find distance between X and current_centers distance: Tensor = self._pairwise_euclidean_distance(X, current_centers) cluster_assignment = distance.argmin(-1) previous_centers = current_centers.clone() for index in range(self.num_clusters): selected = torch.nonzero(cluster_assignment == index).squeeze() selected = torch.index_select(X, 0, selected) # edge case when a certain cluster centre has no points assigned to it # just choose a random point as it's update if selected.shape[0] == 0: selected = X[torch.randint(len(X), (1,), device=X.device)] current_centers[index] = selected.mean(dim=0) # sum of distance of how much the newly computed clusters have moved from their previous positions center_shift = torch.sum(torch.sqrt(torch.sum((current_centers - previous_centers) ** 2, dim=1))) iteration = iteration + 1 if self.tolerance is not None and center_shift**2 < self.tolerance: break if self.max_iterations != 0 and iteration >= self.max_iterations: break self._final_cluster_assignments = cluster_assignment self._final_cluster_centers = current_centers def predict(self, x: Tensor) -> Tensor: """Find the cluster center closest to each point in x. Args: x: 2D tensor Returns: 1D tensor containing cluster id assigned to each data point in x """ # x and cluster_centers should have same number of columns KORNIA_CHECK( x.shape[1] == self.cluster_centers.shape[1], f"Dimensions at position 1 of x and cluster_centers do not match. \ {x.shape[1]} != {self.cluster_centers.shape[1]}", ) distance = self._pairwise_euclidean_distance(x, self.cluster_centers) cluster_assignment = distance.argmin(-1) return cluster_assignment