diff --git a/mnist.go b/mnist.go index 5418363..37b6bf3 100644 --- a/mnist.go +++ b/mnist.go @@ -49,6 +49,11 @@ func (img RawImage) At(x, y int) color.Color { return color.Gray{img[y*Width+x]} } +func (img RawImage) AtGray(x, y int) color.Gray { + return color.Gray{img[y*Width+x]} +} + + // ReadImageFile opens the named image file (training or test), parses it and // returns all images in order. func ReadImageFile(name string) (rows, cols int, imgs []RawImage, err error) { diff --git a/util.go b/util.go index 1a015ed..95b77f6 100644 --- a/util.go +++ b/util.go @@ -57,6 +57,7 @@ type Sweeper struct { // Next returns the next image and its label in the data set. // If the end is reached, present is set to false. func (sw *Sweeper) Next() (image RawImage, label Label, present bool) { + sw.i++ if sw.i >= len(sw.set.Images) { return nil, 0, false } @@ -65,7 +66,7 @@ func (sw *Sweeper) Next() (image RawImage, label Label, present bool) { // Sweep creates a new sweep iterator over the data set func (s *Set) Sweep() *Sweeper { - return &Sweeper{set: s} + return &Sweeper{set: s, i: -1} } // Load reads both the training and the testing MNIST data sets, given