diff --git a/sklearn_extra/cluster/_k_medoids.py b/sklearn_extra/cluster/_k_medoids.py index f9d964df..f8e23a3e 100644 --- a/sklearn_extra/cluster/_k_medoids.py +++ b/sklearn_extra/cluster/_k_medoids.py @@ -155,6 +155,7 @@ def __init__( init="heuristic", max_iter=300, random_state=None, + **metrickw ): self.n_clusters = n_clusters self.metric = metric @@ -162,6 +163,7 @@ def __init__( self.init = init self.max_iter = max_iter self.random_state = random_state + self.metrickw=metrickw def _check_nonnegative_int(self, value, desc, strict=True): """Validates if value is a valid integer > 0""" @@ -235,7 +237,7 @@ def fit(self, X, y=None): % (self.n_clusters, X.shape[0]) ) - D = pairwise_distances(X, metric=self.metric) + D = pairwise_distances(X, metric=self.metric,**self.metrickwm) medoid_idxs = self._initialize_medoids( D, self.n_clusters, random_state_, X @@ -379,10 +381,10 @@ def transform(self, X): check_is_fitted(self, "cluster_centers_") Y = self.cluster_centers_ - kwargs = {} + if self.metric == "seuclidean": kwargs["V"] = np.var(np.vstack([X, Y]), axis=0, ddof=1) - DXY = pairwise_distances(X, Y=Y, metric=self.metric, **kwargs) + DXY = pairwise_distances(X, Y=Y, metric=self.metric, **self.metrickw) return DXY @@ -421,7 +423,7 @@ def predict(self, X): X, Y=self.cluster_centers_, metric=self.metric, - metric_kwargs=kwargs, + metric_kwargs=self.metrickw, ) return pd_argmin