
_split_train_data error #255

@alexanderchang1

Hi,

Can `ContextualizedBayesianNetworks` handle a target `Y`? The documentation for the `fit` method suggests it can, but passing `Y` raises a `TypeError`:

```python
from contextualized.easy import ContextualizedBayesianNetworks

cbn = ContextualizedBayesianNetworks(
    encoder_type='mlp', num_archetypes=16,
    n_bootstraps=2, archetype_dag_loss_type="DAGMA", archetype_alpha=0.,
    sample_specific_dag_loss_type="DAGMA", sample_specific_alpha=1e-1,
    learning_rate=1e-3)
cbn.fit(C, X, Y, max_epochs=10, es_verbose=True)
```

```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[29], line 8
      1 from contextualized.easy import ContextualizedBayesianNetworks
      3 cbn = ContextualizedBayesianNetworks(
      4     encoder_type='mlp', num_archetypes=16,
      5     n_bootstraps=2, archetype_dag_loss_type="DAGMA", archetype_alpha=0.,
      6     sample_specific_dag_loss_type="DAGMA", sample_specific_alpha=1e-1,
      7     learning_rate=1e-3)
----> 8 cbn.fit(C, X, Y, max_epochs=10, es_verbose=True)

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/contextualized/easy/wrappers/SKLearnWrapper.py:516, in SKLearnWrapper.fit(self, *args, **kwargs)
    514 for bootstrap in range(self.n_bootstraps):
    515     model = self.base_constructor(**organized_kwargs["model"])
--> 516     train_data, val_data = self._split_train_data(
    517         *args, **organized_kwargs["data"]
    518     )
    519     train_dataloader, val_dataloader = self._build_dataloaders(
    520         model,
    521         train_data,
    522         val_data,
    523         **organized_kwargs["data"],
    524     )
    525     # Makes a new trainer for each bootstrap fit - bad practice, but necessary here.

TypeError: _split_train_data() takes 3 positional arguments but 4 were given
```
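The traceback points to an arity mismatch: `SKLearnWrapper.fit` forwards every positional argument (`*args`) to `_split_train_data`, and "takes 3 positional arguments" (counting `self`) suggests the method reached here accepts only `C` and `X`. Below is a minimal sketch of that failure mode, assuming the Bayesian-network wrapper overrides `_split_train_data` without a `Y` parameter. `BaseWrapper` and `NetworkWrapper` are hypothetical stand-ins, not the library's actual classes.

```python
# Hypothetical reproduction of the failure mode in the traceback.
# BaseWrapper / NetworkWrapper are illustrative stand-ins, NOT contextualized's code.
import inspect


class BaseWrapper:
    """Generic wrapper whose fit() forwards all positional arguments unchanged."""

    def _split_train_data(self, C, X, Y=None, **kwargs):
        # Generic version: accepts an optional target Y.
        return (C, X, Y), (C, X, Y)

    def fit(self, *args, **kwargs):
        # Same pattern as SKLearnWrapper.fit in the traceback: *args passed straight through.
        train_data, val_data = self._split_train_data(*args, **kwargs)
        return train_data, val_data


class NetworkWrapper(BaseWrapper):
    """Assumed override for network models: only C and X are expected."""

    def _split_train_data(self, C, X, **kwargs):
        # No Y parameter, so a third positional argument raises TypeError.
        return (C, X), (C, X)


w = NetworkWrapper()
w.fit("C", "X")  # works: self, C, X

try:
    w.fit("C", "X", "Y")  # self, C, X, Y -> 4 positional arguments
except TypeError as err:
    print(err)  # "... takes 3 positional arguments but 4 were given"

# Inspect which arguments the bound method actually accepts.
print(inspect.signature(w._split_train_data))  # (C, X, **kwargs)
```

Calling `inspect.signature` on the real `cbn._split_train_data` in the same way would show which positional arguments the installed version accepts, and so whether `Y` is supported for this model.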
