# Gaussian Mixture Model
# Interface to scikit-learn implementation
from sklearn.mixture import GaussianMixture
import autopandas
class GMM:
    """ Gaussian Mixture Model generator wrapping scikit-learn's GaussianMixture.

    Remembers the ``columns`` and ``indexes`` of the training data so that
    sampled rows are returned as an ``autopandas.AutoData`` with the same
    column layout and index metadata.
    """

    def __init__(self, **kwargs):
        """ Gaussian Mixture Model.

        :param kwargs: Keyword arguments forwarded unchanged to
            ``sklearn.mixture.GaussianMixture``.
        """
        self.model = GaussianMixture(**kwargs)
        # Both stay None until fit() has completed successfully; sample()
        # uses ``columns is None`` to detect an untrained generator.
        self.columns = None
        self.indexes = None

    def fit(self, data, **kwargs):
        """ Train the generator with data.

        :param data: The training data. It must expose ``columns`` and
            ``indexes`` attributes (presumably an ``autopandas.AutoData``;
            a plain pandas DataFrame has no ``indexes``).
        :param kwargs: Extra keyword arguments forwarded to
            ``GaussianMixture.fit``.
        :return: self, so calls can be chained: ``GMM().fit(data).sample(n)``.
        """
        columns = data.columns
        indexes = data.indexes
        self.model.fit(data, **kwargs)
        # Record the data layout only after fitting succeeded, so a failed
        # fit does not leave the wrapper looking trained.
        self.columns = columns
        self.indexes = indexes
        return self

    def sample(self, n=1, **kwargs):
        """ Sample from trained GMM.

        :param n: Number of examples to sample.
        :param kwargs: Extra keyword arguments forwarded to
            ``GaussianMixture.sample``.
        :return: An ``autopandas.AutoData`` holding the generated rows, with
            the training data's columns and indexes.
        :raises RuntimeError: If fit() has not completed successfully yet.
        """
        # columns is only set once fit() succeeded (a DataFrame's columns are
        # never None), so this is independent of whether indexes is None.
        if self.columns is None:
            raise RuntimeError('You firstly need to train the GMM before sampling. Please use fit method.')
        # sklearn's GaussianMixture.sample returns a tuple (samples, component labels).
        gen_data, _ = self.model.sample(n, **kwargs)
        return autopandas.AutoData(gen_data, columns=self.columns, indexes=self.indexes)