Updating Dimod Sampler #73
Conversation
kevinchern
left a comment
Did a quick first pass over the main function.
    super().__init__()

    def sample(self, x: torch.Tensor | None = None) -> torch.Tensor:
        """Sample from the dimod sampler and return the corresponding tensor.
Update the docstrings to reflect the behaviour of conditional sampling.
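Something along these lines could work as the updated wording (a sketch of possible docstring text; the exact mask semantics are assumed from the diff below):

    def sample(self, x: torch.Tensor | None = None) -> torch.Tensor:
        """Sample from the dimod sampler and return the corresponding tensor.

        Args:
            x: Optional (b, d) tensor of partial spin assignments. When given,
                masked-in values are clamped and exactly one sample is drawn
                per input row (conditional sampling); ``num_reads`` is ignored.
        """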
    for i in range(x.shape[0]):
        # Fresh BQM
        bqm = dimod.BinaryQuadraticModel.from_ising(h, J)

        # Build conditioning dict
        conditioned = {node: int(x[i, j].item())
                       for j, node in enumerate(nodes) if mask[i, j]}
Can we iterate over x instead of range(x.shape[0])? Seems more natural
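A sketch of the suggested loop, assuming mask is a boolean tensor aligned row-for-row with x (all names taken from the surrounding diff):

    # Iterate over rows directly instead of indexing by position.
    for row, row_mask in zip(x, mask):
        bqm = dimod.BinaryQuadraticModel.from_ising(h, J)
        conditioned = {node: int(val.item())
                       for val, node, keep in zip(row, nodes, row_mask) if keep}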
        results.append(full)
        continue

    # Sample one configuration per input
This needs to be documented in the docstring. It should also raise a warning if num_reads is supplied and overwritten.
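For example, a guard along these lines before overriding the parameter (a sketch; names follow the surrounding diff):

    import warnings

    sample_kwargs = dict(self._sampler_params)
    if "num_reads" in sample_kwargs:
        warnings.warn(
            "num_reads is overridden during conditional sampling "
            "(one sample per input row).",
            UserWarning,
        )
    sample_kwargs["num_reads"] = 1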
    # Sample one configuration per input
    sample_kwargs = dict(self._sampler_params)
    sample_kwargs["num_reads"] = 1
    sample_set = AggregatedSamples.spread(
What's the motivation for this line?
Oh, you mean the call to spread? I was following the earlier use case. You're right, it's redundant; I'll remove it.
    sample = sample_set.first.sample

    # Reconstruct full sample
    full = torch.empty(n_nodes, device=device)
    for j, node in enumerate(nodes):
        full[j] = conditioned[node] if node in conditioned else float(sample[node])
    results.append(full)
sample's variable ordering may not be identical to that of grbm.nodes.
I'd add a test to verify correct ordering first.
The order is the same as that of grbm.nodes. As you suggested, I have written a test to show it; you can see that your sampleset_to_tensor also uses this order.
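A sketch of such an ordering test, using dimod.ExactSolver as a stand-in sampler (the node labels are illustrative):

    import dimod

    def test_sampleset_preserves_node_order():
        # The BQM preserves insertion order of its variables, so the sample
        # set's variable order should match the node order used to build it.
        nodes = ["a", "b", "c"]
        h = {n: 0.0 for n in nodes}
        J = {("a", "b"): -1.0, ("b", "c"): -1.0}
        bqm = dimod.BinaryQuadraticModel.from_ising(h, J)
        sample_set = dimod.ExactSolver().sample(bqm)
        assert list(sample_set.variables) == nodes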
@@ -22,7 +22,7 @@
from dwave.plugins.torch.samplers.dimod_sampler import DimodSampler
The sample set returned from the latest sample call is stored in :func:`DimodSampler.sample_set`
- which is overwritten by subsequent calls.
+ which is overwritten by subsequent calls. When ``x`` is provided (conditional sampling), exactly
I'm wondering if there's a better way to document and handle this. A sampler's sample-size parameter is not guaranteed to be named num_reads. @anahitamansouri any thoughts?
Yeah, this was also my challenge.
| "(one sample per input row).", | ||
| UserWarning, | ||
| ) | ||
| sample_kwargs["num_reads"] = 1 |
@anahitamansouri do you think this is sufficient? I.e., can we safely ignore the num_reads parameter altogether and change the documentation to roughly "only the first sample is retained ..."?
Alternatively, and perhaps the right way to do this, is to return tensors of shape
Agreed. As you remember, we discussed keeping it (b, d) to be consistent with the other samplers. I will update the documentation.
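So the conditional path keeps the same contract as the other samplers; a hypothetical shape check:

    import torch

    # Hypothetical usage: the output keeps the (b, d) shape of the input.
    x = torch.ones(4, 8)  # b=4 conditioned rows over d=8 spins
    out = sampler.sample(x)
    assert out.shape == x.shape  # one retained sample per input row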
    self._sample_set = dimod.SampleSet.from_samples(
        (samples.cpu().numpy(), self._grbm.nodes),
        vartype=dimod.SPIN,
        energy=energies,
    )
Why explicitly reconstruct the sample set?
Setting self._sample_set is only needed for consistency with your code, since the class has this field and the sample_set function relies on it; it is not necessary for the logic of the sample function. If you remember, I didn't have this in the earlier version and mentioned that I had not set it. This is one way I thought of addressing it. Do you mean this way of creating the samples is wrong? Do you have any suggestions? :)
Hmm... I think there are at least a couple of simplifications: 1) self._sample_set = sample_set, noting in the documentation that it holds only the latest sample set, or 2) concatenate all of the sample sets.
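Option 2 could look roughly like this (a sketch; per_row_bqms is an assumed name for the conditioned BQMs built in the loop above):

    import dimod

    # Collect every per-row sample set and merge them into one, so
    # sample_set reflects all conditional draws rather than only the last.
    collected = [self._sampler.sample(bqm, **sample_kwargs) for bqm in per_row_bqms]
    self._sample_set = dimod.concatenate(collected)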
    ])

    out = sampler.sample(x)
    ss = sampler.sample_set
edit: A test that fails this
    # Unconditional sampling
    if x is None:
        self._sample_set = AggregatedSamples.spread(
            self._sampler.sample_ising(h, J, **self._sampler_params)
        )
        return sampleset_to_tensor(nodes, self._sample_set, device)
    sample_set = self._sampler.sample(bqm, **sample_kwargs)

    # Extract sampled values
Raise an error here if the sample set has sample size > 1.
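Something like this, as a sketch (dimod.SampleSet supports len() for the number of samples):

    sample_set = self._sampler.sample(bqm, **sample_kwargs)
    # Guard against samplers that ignore num_reads and return several samples.
    if len(sample_set) > 1:
        raise ValueError(
            f"expected a single sample per input row, got {len(sample_set)}"
        )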
    # Reconstruct full sample
    full = torch.empty(n_nodes, device=device)
    for j, node in enumerate(nodes):
-    for j, node in enumerate(nodes):
+    for node, idx in grbm._node_to_index.items():
    self._sample_set = AggregatedSamples.spread(
        self._sampler.sample_ising(h, J, **self._sampler_params)
    )
    return sampleset_to_tensor(nodes, self._sample_set, device)
We should not use sampleset_to_tensor because that function is only safe when used inside the GRBM class.
Here, we should rely on grbm._index_to_node, grbm._node_to_index, etc. to guarantee a correct mapping between indices and nodes.
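A sketch of the mapping-based reconstruction (grbm._node_to_index is from the comment above; the remaining names are assumed from the diff):

    import torch

    # Index the output through the GRBM's own node-to-index mapping instead
    # of assuming the sample set's variable order matches `nodes`.
    sample = sample_set.first.sample  # dict keyed by node label
    full = torch.empty(len(grbm._node_to_index), device=device)
    for node, idx in grbm._node_to_index.items():
        full[idx] = float(sample[node])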
This PR adds: