Problem description
I modified scikit-learn's BernoulliRBM class to use groups of softmax visible units. In the process, I added an extra NumPy array, visible_config, as an attribute of the class. It is initialized in the constructor as follows:
self.visible_config = np.cumsum(np.concatenate((np.asarray([0]),
visible_config), axis=0))
where visible_config is a NumPy array passed to the constructor as an input. The code runs without errors when I call fit() directly to train the model. However, when I use GridSearchCV, I get the following error:
Cannot clone object SoftmaxRBM(batch_size=100, learning_rate=0.01, n_components=100, n_iter=100,
random_state=0, verbose=True, visible_config=[ 0 21 42 63]), as the constructor does not seem to set parameter visible_config
This seems to be a problem in the equality check between the instance of the class and its copy created by sklearn.base.clone, because visible_config does not get copied correctly. I'm not sure how to fix this. The documentation says that sklearn.base.clone uses deepcopy(), so shouldn't visible_config also get copied? Can someone please explain what I can try here? Thanks!
Recommended answer
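Without seeing your code, it's hard to tell exactly what goes wrong, but you are violating a scikit-learn API convention here. The constructor in an estimator should only set attributes to the values the user passes as arguments. All computation should occur in fit, and if fit needs to store the result of a computation, it should do so in an attribute with a trailing underscore (_). This convention is what makes clone and meta-estimators such as GridSearchCV work: clone rebuilds the estimator by passing the values from get_params() back to the constructor and then checks that each parameter was stored unchanged, so a constructor that transforms its input (as the np.cumsum call in the question does) fails that check.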
(*) If you ever see an estimator in the main codebase that violates this rule: that would be a bug, and patches are welcome.
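The convention described in the answer can be sketched roughly as follows. This is a minimal illustration, not the asker's actual code: the class names SoftmaxRBMSketch and BrokenRBMSketch and the attribute name visible_offsets_ are hypothetical, and fit here only computes the offsets and does no RBM training. It assumes visible_config holds the size of each softmax group.

```python
import numpy as np
from sklearn.base import BaseEstimator, clone


class SoftmaxRBMSketch(BaseEstimator):
    """Hypothetical estimator that follows the constructor convention."""

    def __init__(self, visible_config=None, n_components=100):
        # Store the arguments exactly as passed; no computation here.
        self.visible_config = visible_config
        self.n_components = n_components

    def fit(self, X, y=None):
        # Derived values are computed in fit and stored with a trailing underscore.
        sizes = np.asarray(self.visible_config)
        self.visible_offsets_ = np.cumsum(np.concatenate(([0], sizes)))
        return self


class BrokenRBMSketch(BaseEstimator):
    """Hypothetical estimator that transforms its argument in the constructor."""

    def __init__(self, visible_config=None):
        # Mirrors the pattern from the question: the stored value differs
        # from the value that was passed in.
        self.visible_config = np.cumsum(
            np.concatenate((np.asarray([0]), visible_config), axis=0)
        )


group_sizes = np.array([21, 21, 21])

# The conforming estimator can be cloned, and the clone can be fitted.
cloned = clone(SoftmaxRBMSketch(visible_config=group_sizes))
cloned.fit(np.zeros((4, 63)))
print(cloned.visible_offsets_)

# The non-conforming estimator is rejected by clone.
try:
    clone(BrokenRBMSketch(visible_config=group_sizes))
    clone_failed = False
except RuntimeError as exc:
    clone_failed = True
    print(exc)
```

With the first class, GridSearchCV can clone the estimator for each parameter combination, because get_params() returns exactly what the constructor received.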