easy.ContextualizedBayesianNetworks.predict breaks #266

@cnellington

Description

To replicate

import numpy as np
from contextualized.easy import ContextualizedBayesianNetworks

n_samples = 100
n_contexts = 5
n_features = 10
n_bootstraps = 3
C = np.random.uniform(-1, 1, size=(n_samples, n_contexts))
X = np.random.normal(0, 1, size=(n_samples, n_features))

cbn = ContextualizedBayesianNetworks(n_bootstraps=n_bootstraps)
cbn.fit(C, X, max_epochs=1)

y_pred = cbn.predict(C, X)  # Problematic line

# Extra tests: output shapes expected once predict works
assert y_pred.shape == (n_samples, n_features)
y_pred_avg = cbn.predict(C, X, individual_preds=False)
assert y_pred_avg.shape == (n_samples, n_features)
y_pred_individual = cbn.predict(C, X, individual_preds=True)
assert y_pred_individual.shape == (n_bootstraps, n_samples, n_features)
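
For reference, the shape contract the assertions above check can be sketched in plain NumPy. This is only an illustration of the expected behavior, not the library's implementation: `combine_predictions` is a hypothetical helper, and the random array is a stand-in for the per-bootstrap predictions, which are assumed to be stacked along the first axis.

```python
import numpy as np

n_bootstraps, n_samples, n_features = 3, 100, 10

# Stand-in for per-bootstrap predictions (assumed layout:
# bootstraps x samples x features). Not produced by the library.
rng = np.random.default_rng(0)
individual = rng.normal(size=(n_bootstraps, n_samples, n_features))


def combine_predictions(preds, individual_preds=False):
    """Hypothetical helper: return per-bootstrap predictions as-is,
    or average them over the bootstrap axis."""
    preds = np.asarray(preds)
    if individual_preds:
        return preds
    return preds.mean(axis=0)


averaged = combine_predictions(individual)
stacked = combine_predictions(individual, individual_preds=True)

assert averaged.shape == (n_samples, n_features)
assert stacked.shape == (n_bootstraps, n_samples, n_features)
```

If `predict` follows this contract, the default call and `individual_preds=False` should both return an array of shape `(n_samples, n_features)`.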
