
Updating Dimod Sampler #73

Open

anahitamansouri wants to merge 3 commits into dwavesystems:main from anahitamansouri:feature/dimod-conditional-sampling

Conversation

@anahitamansouri (Collaborator)

This PR adds:

  • A conditional sampling feature for the Dimod sampler.
  • Tests for this feature.

@anahitamansouri self-assigned this Apr 7, 2026
@anahitamansouri added the enhancement (New feature or request) label Apr 7, 2026
@kevinchern (Collaborator) left a comment

Did a quick first pass over the main function.

        super().__init__()

    def sample(self, x: torch.Tensor | None = None) -> torch.Tensor:
        """Sample from the dimod sampler and return the corresponding tensor.
Collaborator

Update the docstrings to reflect the behaviour of conditional sampling.
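
For instance, the docstring could spell out the conditional path along these lines (illustrative wording only; the one-sample-per-row behaviour and the (b, d) shape come from the discussion further down this thread):

    """Sample from the dimod sampler and return the corresponding tensor.

    Args:
        x: Optional tensor of partial spin configurations. When ``x`` is
            provided, the entries selected by the mask are clamped and
            exactly one sample is drawn per input row (``num_reads`` is
            forced to 1).

    Returns:
        A spin tensor of shape ``(b, d)``, one row per input row.
    """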

Comment on lines +118 to +124
    for i in range(x.shape[0]):
        # Fresh BQM
        bqm = dimod.BinaryQuadraticModel.from_ising(h, J)

        # Build conditioning dict
        conditioned = {node: int(x[i, j].item())
                       for j, node in enumerate(nodes) if mask[i, j]}
Collaborator

Can we iterate over `x` instead of `range(x.shape[0])`? Seems more natural.
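
For example (a sketch reusing `h`, `J`, `nodes`, and `mask` from the snippet above; `row` and `mask_row` are illustrative names, and `mask` is assumed to be indexed the same way as `x`):

    for row, mask_row in zip(x, mask):
        # Fresh BQM per input row
        bqm = dimod.BinaryQuadraticModel.from_ising(h, J)

        # Clamp the variables selected by this row's mask
        conditioned = {node: int(row[j].item())
                       for j, node in enumerate(nodes) if mask_row[j]}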

Comment thread on dwave/plugins/torch/samplers/dimod_sampler.py (outdated)
        results.append(full)
        continue

    # Sample one configuration per input
Collaborator

This needs to be documented in the docstring. It should also raise a warning if `num_reads` is supplied and overwritten.
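
Something like this could work (a sketch; the message mirrors the warning fragment that appears later in this thread):

    import warnings

    if "num_reads" in sample_kwargs:
        warnings.warn(
            "num_reads is ignored during conditional sampling and overridden to 1 "
            "(one sample per input row).",
            UserWarning,
        )
    sample_kwargs["num_reads"] = 1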

    # Sample one configuration per input
    sample_kwargs = dict(self._sampler_params)
    sample_kwargs["num_reads"] = 1
    sample_set = AggregatedSamples.spread(
Collaborator

What's the motivation for this line?

Collaborator Author

Oh, you mean calling `spread`? I was following the earlier use case. You're right, it's redundant; I'll remove it.

Comment on lines +163 to +169
    sample = sample_set.first.sample

    # Reconstruct full sample
    full = torch.empty(n_nodes, device=device)
    for j, node in enumerate(nodes):
        full[j] = conditioned[node] if node in conditioned else float(sample[node])
    results.append(full)
Collaborator

`sample`'s variable ordering may not be identical to that of `grbm.nodes`.
I'd add a test to verify correct ordering first.
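
A minimal, self-contained check of the assumption at the dimod level (illustrative; not the test added in this PR):

    import dimod

    # first.sample is a dict keyed by variable label, so reading values
    # out in a fixed node order is safe regardless of the sample set's
    # own variable order.
    ss = dimod.SampleSet.from_samples(
        ([-1, 1, -1], ["a", "b", "c"]), vartype=dimod.SPIN, energy=0.0
    )
    assert [ss.first.sample[v] for v in ["a", "b", "c"]] == [-1, 1, -1]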

Collaborator

@anahitamansouri missed this one?

Collaborator Author (@anahitamansouri, Apr 16, 2026)

No, I didn't. The order is the same as that of `grbm.nodes`. As you suggested, I have written a test to show it. You can see that your `sampleset_to_tensor` also uses this order.

@@ -22,7 +22,7 @@
from dwave.plugins.torch.samplers.dimod_sampler import DimodSampler


  The sample set returned from the latest sample call is stored in :func:`DimodSampler.sample_set`
- which is overwritten by subsequent calls.
+ which is overwritten by subsequent calls. When ``x`` is provided (conditional sampling), exactly
Collaborator

I'm wondering if there's a better way to document and handle this. The sample size parameter of a sampler is not guaranteed to be `num_reads`. @anahitamansouri, any thoughts?
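
One option is to only touch `num_reads` when the wrapped sampler documents it (a sketch using dimod's `Sampler.parameters` mapping; `build_sample_kwargs` is a hypothetical helper):

    def build_sample_kwargs(sampler, sampler_params):
        # Force a single read only when the wrapped sampler actually
        # accepts num_reads among its parameters.
        kwargs = dict(sampler_params)
        if "num_reads" in sampler.parameters:
            kwargs["num_reads"] = 1
        return kwargs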

Collaborator Author

Yeah, this was also my challenge.

"(one sample per input row).",
UserWarning,
)
sample_kwargs["num_reads"] = 1
Collaborator

@anahitamansouri do you think this is sufficient? I.e., can we safely ignore the `num_reads` parameter altogether and change the documentation to something like "only the first sample is retained ..."?

Alternatively, and perhaps the right way to do this, is to return tensors of shape $(b, N, d)$, where $b$ is the batch size, $N$ is the sample size (i.e., what we enforce to be $1$ here), and $d$ is the model dimension / number of qubits.

Collaborator Author

Agreed. As you remember, we discussed keeping it $(b, d)$ to be consistent with the other samplers. I will update the documentation.

Comment on lines +190 to +194
    self._sample_set = dimod.SampleSet.from_samples(
        (samples.cpu().numpy(), self._grbm.nodes),
        vartype=dimod.SPIN,
        energy=energies,
    )
Collaborator

Why explicitly reconstruct the sample set?

Collaborator Author (@anahitamansouri, Apr 16, 2026)

Setting `self._sample_set` is only necessary for consistency with your code: the class has this field and the `sample_set` function needs it. It is not necessary for the logic of the `sample` function. If you remember, I didn't have this in the earlier version and I told you I had not set it; this is one way I thought of addressing it. Do you mean this way of creating the samples is wrong? Do you have any suggestions? :)

Collaborator (@kevinchern, Apr 16, 2026)

Hmm... I think there are at least a couple of simplifications: 1) `self._sample_set = sample_set`, and note in the documentation that it's the latest sample set; 2) concatenate all the sample sets.
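
For option 2, something like this could work (a sketch, assuming the per-row sample sets are collected in a list during the conditional loop):

    import dimod

    def combine_sample_sets(sample_sets):
        # Merge the per-row sample sets into a single SampleSet so that
        # DimodSampler.sample_set exposes all rows, not just the last one.
        return dimod.concatenate(sample_sets)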


])

out = sampler.sample(x)
ss = sampler.sample_set
Collaborator (@kevinchern, Apr 15, 2026)

edit: A test that fails this:

        # Unconditional sampling
        if x is None:
            self._sample_set = AggregatedSamples.spread(
                self._sampler.sample_ising(h, J, **self._sampler_params)
            )
            return sampleset_to_tensor(nodes, self._sample_set, device)


    sample_set = self._sampler.sample(bqm, **sample_kwargs)

    # Extract sampled values
Collaborator (@kevinchern, Apr 16, 2026)

Raise an error here if the sample set has sample size > 1.
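
A sketch of the suggested guard, slotting in right after the sample call above (`len(sample_set)` counts the rows of the sample set):

    if len(sample_set) > 1:
        raise ValueError(
            "Expected exactly one sample per input row during conditional "
            f"sampling, got {len(sample_set)}."
        )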


    # Reconstruct full sample
    full = torch.empty(n_nodes, device=device)
    for j, node in enumerate(nodes):
Collaborator

Suggested change:
- for j, node in enumerate(nodes):
+ for node, idx in grbm._node_to_index.items():

    self._sample_set = AggregatedSamples.spread(
        self._sampler.sample_ising(h, J, **self._sampler_params)
    )
    return sampleset_to_tensor(nodes, self._sample_set, device)
Collaborator

We should not use `sampleset_to_tensor`, because that function is only safe when contained in the GRBM class.
Here, we should rely on `grbm._index_to_node`, `grbm._node_to_index`, etc. to guarantee a correct mapping between indices and nodes.
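
A sketch of what the reconstruction step could look like under that approach (hypothetical helper; `grbm._node_to_index` is assumed to map node labels to tensor indices, per the suggestion above):

    import torch

    def reconstruct_full_sample(grbm, conditioned, sample, device):
        # Place each value at the index the GRBM itself assigns to the
        # node, instead of relying on enumerate(nodes) agreeing with the
        # sample set's variable order.
        full = torch.empty(len(grbm._node_to_index), device=device)
        for node, idx in grbm._node_to_index.items():
            full[idx] = conditioned[node] if node in conditioned else float(sample[node])
        return full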

