From fe6d32987492d06ce43e9c3ba49fe1bf494e070a Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld Date: Wed, 10 Jun 2015 16:09:38 -0700 Subject: [PATCH 1/6] Makes sure we don't pull the whole corpus into memory when training --- .../neuralnetwork/NeuralNetworkTrainer.java | 59 ++++++++++++------- 1 file changed, 39 insertions(+), 20 deletions(-) diff --git a/src/main/java/com/medallia/word2vec/neuralnetwork/NeuralNetworkTrainer.java b/src/main/java/com/medallia/word2vec/neuralnetwork/NeuralNetworkTrainer.java index a415c53..305266f 100644 --- a/src/main/java/com/medallia/word2vec/neuralnetwork/NeuralNetworkTrainer.java +++ b/src/main/java/com/medallia/word2vec/neuralnetwork/NeuralNetworkTrainer.java @@ -1,6 +1,8 @@ package com.medallia.word2vec.neuralnetwork; import com.google.common.collect.Iterables; +import com.google.common.collect.Iterators; +import com.google.common.collect.Lists; import com.google.common.collect.Multiset; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; @@ -11,12 +13,8 @@ import com.medallia.word2vec.huffman.HuffmanCoding.HuffmanNode; import com.medallia.word2vec.util.CallableVoid; -import java.util.ArrayList; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.Executors; +import java.util.*; +import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicInteger; /** Parent class for training word2vec's neural network */ @@ -51,7 +49,7 @@ public abstract class NeuralNetworkTrainer { /** * In the C version, this includes the token that replaces a newline character */ - int numTrainedTokens; + long numTrainedTokens; /* The following includes shared state that is updated per worker thread */ @@ -151,28 +149,49 @@ public interface NeuralNetworkModel { } /** @return Trained NN model */ - public NeuralNetworkModel train(Iterable> sentences) throws InterruptedException { - ListeningExecutorService ex = MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(config.numThreads)); - + public NeuralNetworkModel train(final Iterable> sentences) throws InterruptedException { + final ListeningExecutorService ex = + MoreExecutors.listeningDecorator( + new ThreadPoolExecutor(config.numThreads, config.numThreads, + 0L, TimeUnit.MILLISECONDS, + new ArrayBlockingQueue(8), + new ThreadPoolExecutor.CallerRunsPolicy())); + int numSentences = Iterables.size(sentences); numTrainedTokens += numSentences; - - // Partition the sentences evenly amongst the threads - Iterable>> partitioned = Iterables.partition(sentences, numSentences / config.numThreads + 1); - + + // Partition the sentences into batches + final Iterable>> batched = new Iterable>>() { + @Override public Iterator>> iterator() { + return new Iterator>>() { + private final Iterator> inner = sentences.iterator(); + + @Override + public boolean hasNext() { + return inner.hasNext(); + } + + @Override + public List> next() { + if(!hasNext()) + throw new NoSuchElementException(); + + return Lists.newArrayList(Iterators.limit(inner, 1024)); + } + }; + } + }; + try { listener.update(Stage.TRAIN_NEURAL_NETWORK, 0.0); for (int iter = config.iterations; iter > 0; iter--) { - List tasks = new ArrayList<>(); + List> futures = new ArrayList<>(64); int i = 0; - for (final List> batch : partitioned) { - tasks.add(createWorker(i, iter, batch)); + for (final List> batch : batched) { + futures.add(ex.submit(createWorker(i, iter, batch))); i++; } - List> futures = new ArrayList<>(tasks.size()); - for (CallableVoid task : tasks) - futures.add(ex.submit(task)); try { Futures.allAsList(futures).get(); } catch (ExecutionException e) { From 8e81df1f1c78714ce33d7723a539f74aa342e36f Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld Date: Mon, 22 Jun 2015 14:29:36 -0700 Subject: [PATCH 2/6] Made it build in Java 7. --- .../word2vec/neuralnetwork/NeuralNetworkTrainer.java | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/main/java/com/medallia/word2vec/neuralnetwork/NeuralNetworkTrainer.java b/src/main/java/com/medallia/word2vec/neuralnetwork/NeuralNetworkTrainer.java index 305266f..2350520 100644 --- a/src/main/java/com/medallia/word2vec/neuralnetwork/NeuralNetworkTrainer.java +++ b/src/main/java/com/medallia/word2vec/neuralnetwork/NeuralNetworkTrainer.java @@ -178,6 +178,11 @@ public List> next() { return Lists.newArrayList(Iterators.limit(inner, 1024)); } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } }; } }; From e5f7d44f41d76650c9b3ede6de8e53d66383335c Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld Date: Mon, 22 Jun 2015 14:31:49 -0700 Subject: [PATCH 3/6] Made * imports explicit --- .../word2vec/neuralnetwork/NeuralNetworkTrainer.java | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/main/java/com/medallia/word2vec/neuralnetwork/NeuralNetworkTrainer.java b/src/main/java/com/medallia/word2vec/neuralnetwork/NeuralNetworkTrainer.java index 2350520..5f8742e 100644 --- a/src/main/java/com/medallia/word2vec/neuralnetwork/NeuralNetworkTrainer.java +++ b/src/main/java/com/medallia/word2vec/neuralnetwork/NeuralNetworkTrainer.java @@ -13,8 +13,15 @@ import com.medallia.word2vec.huffman.HuffmanCoding.HuffmanNode; import com.medallia.word2vec.util.CallableVoid; -import java.util.*; -import java.util.concurrent.*; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; /** Parent class for training word2vec's neural network */ From 94db87da64e8d11ba98d41cecfe08f5b03389ebb Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld Date: Mon, 22 Jun 2015 14:34:52 -0700 Subject: [PATCH 4/6] Formatting --- .../medallia/word2vec/neuralnetwork/NeuralNetworkTrainer.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/com/medallia/word2vec/neuralnetwork/NeuralNetworkTrainer.java b/src/main/java/com/medallia/word2vec/neuralnetwork/NeuralNetworkTrainer.java index 5f8742e..9c61a71 100644 --- a/src/main/java/com/medallia/word2vec/neuralnetwork/NeuralNetworkTrainer.java +++ b/src/main/java/com/medallia/word2vec/neuralnetwork/NeuralNetworkTrainer.java @@ -169,7 +169,7 @@ public NeuralNetworkModel train(final Iterable> sentences) throws I // Partition the sentences into batches final Iterable>> batched = new Iterable>>() { - @Override public Iterator>> iterator() { + @Override public Iterator>> iterator() { return new Iterator>>() { private final Iterator> inner = sentences.iterator(); From 61e7fc7e24caaee179c65f690c2c3aa5d014fe78 Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld Date: Wed, 24 Jun 2015 15:58:40 -0700 Subject: [PATCH 5/6] Google's Iterable thing works just fine --- .../neuralnetwork/NeuralNetworkTrainer.java | 26 +------------------ 1 file changed, 1 insertion(+), 25 deletions(-) diff --git a/src/main/java/com/medallia/word2vec/neuralnetwork/NeuralNetworkTrainer.java b/src/main/java/com/medallia/word2vec/neuralnetwork/NeuralNetworkTrainer.java index 9c61a71..d1ac884 100644 --- a/src/main/java/com/medallia/word2vec/neuralnetwork/NeuralNetworkTrainer.java +++ b/src/main/java/com/medallia/word2vec/neuralnetwork/NeuralNetworkTrainer.java @@ -168,31 +168,7 @@ public NeuralNetworkModel train(final Iterable> sentences) throws I numTrainedTokens += numSentences; // Partition the sentences into batches - final Iterable>> batched = new Iterable>>() { - @Override public Iterator>> iterator() { - return new Iterator>>() { - private final Iterator> inner = sentences.iterator(); - - @Override - public boolean hasNext() { - return inner.hasNext(); - } - - @Override - public List> next() { - if(!hasNext()) - throw new NoSuchElementException(); - - return Lists.newArrayList(Iterators.limit(inner, 1024)); - } - - @Override - public void remove() { - throw new UnsupportedOperationException(); - } - }; - } - }; + final Iterable>> batched = Iterables.partition(sentences, 1024); try { listener.update(Stage.TRAIN_NEURAL_NETWORK, 0.0); From 464abcc909897feecf2b9bf4a5ba7a1ed9bb27a8 Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld Date: Tue, 7 Jul 2015 14:50:15 -0700 Subject: [PATCH 6/6] Changed queue size and added comment --- .../word2vec/neuralnetwork/NeuralNetworkTrainer.java | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/main/java/com/medallia/word2vec/neuralnetwork/NeuralNetworkTrainer.java b/src/main/java/com/medallia/word2vec/neuralnetwork/NeuralNetworkTrainer.java index d1ac884..99e72e0 100644 --- a/src/main/java/com/medallia/word2vec/neuralnetwork/NeuralNetworkTrainer.java +++ b/src/main/java/com/medallia/word2vec/neuralnetwork/NeuralNetworkTrainer.java @@ -157,11 +157,14 @@ public interface NeuralNetworkModel { /** @return Trained NN model */ public NeuralNetworkModel train(final Iterable> sentences) throws InterruptedException { + // Create an executor that runs as many threads as are defined in the config, and blocks if + // you're trying to run more. This is to make sure we don't read the entire corpus into + // memory. final ListeningExecutorService ex = MoreExecutors.listeningDecorator( new ThreadPoolExecutor(config.numThreads, config.numThreads, 0L, TimeUnit.MILLISECONDS, - new ArrayBlockingQueue(8), + new ArrayBlockingQueue(config.numThreads), new ThreadPoolExecutor.CallerRunsPolicy())); int numSentences = Iterables.size(sentences);