diff --git a/pom.xml b/pom.xml index aafcbcd..d59c1d7 100644 --- a/pom.xml +++ b/pom.xml @@ -3,7 +3,7 @@ 4.0.0 com.medallia.word2vec medallia-word2vec - 0.10.2 + 0.10.3-SNAPSHOT MIT License diff --git a/src/main/java/com/medallia/word2vec/Word2VecModel.java b/src/main/java/com/medallia/word2vec/Word2VecModel.java index 5fa6b25..58c8669 100644 --- a/src/main/java/com/medallia/word2vec/Word2VecModel.java +++ b/src/main/java/com/medallia/word2vec/Word2VecModel.java @@ -37,259 +37,269 @@ * @see {@link #forSearch()} */ public class Word2VecModel { - final List vocab; - final int layerSize; - final DoubleBuffer vectors; - private final static long ONE_GB = 1024 * 1024 * 1024; - - Word2VecModel(Iterable vocab, int layerSize, DoubleBuffer vectors) { - this.vocab = ImmutableList.copyOf(vocab); - this.layerSize = layerSize; - this.vectors = vectors; - } - - Word2VecModel(Iterable vocab, int layerSize, double[] vectors) { - this(vocab, layerSize, DoubleBuffer.wrap(vectors)); - } - - /** @return Vocabulary */ - public Iterable getVocab() { - return vocab; - } - - /** @return {@link Searcher} for searching */ - public Searcher forSearch() { - return new SearcherImpl(this); - } - - /** @return Serializable thrift representation */ - public Word2VecModelThrift toThrift() { - double[] vectorsArray; - if(vectors.hasArray()) { - vectorsArray = vectors.array(); - } else { - vectorsArray = new double[vectors.limit()]; - vectors.position(0); - vectors.get(vectorsArray); - } - - return new Word2VecModelThrift() - .setVocab(vocab) - .setLayerSize(layerSize) - .setVectors(Doubles.asList(vectorsArray)); - } - - /** @return {@link Word2VecModel} created from a thrift representation */ - public static Word2VecModel fromThrift(Word2VecModelThrift thrift) { - return new Word2VecModel( - thrift.getVocab(), - thrift.getLayerSize(), - Doubles.toArray(thrift.getVectors())); - } - - /** - * @return {@link Word2VecModel} read from a file in the text output format of the Word2Vec C - * open source project. - */ - public static Word2VecModel fromTextFile(File file) throws IOException { - List lines = Common.readToList(file); - return fromTextFile(file.getAbsolutePath(), lines); - } - - /** - * Forwards to {@link #fromBinFile(File, ByteOrder, ProfilingTimer)} with the default - * ByteOrder.LITTLE_ENDIAN and no ProfilingTimer - */ - public static Word2VecModel fromBinFile(File file) - throws IOException { - return fromBinFile(file, ByteOrder.LITTLE_ENDIAN, ProfilingTimer.NONE); - } - - /** - * Forwards to {@link #fromBinFile(File, ByteOrder, ProfilingTimer)} with no ProfilingTimer - */ - public static Word2VecModel fromBinFile(File file, ByteOrder byteOrder) - throws IOException { - return fromBinFile(file, byteOrder, ProfilingTimer.NONE); - } - - /** - * @return {@link Word2VecModel} created from the binary representation output - * by the open source C version of word2vec using the given byte order. - */ - public static Word2VecModel fromBinFile(File file, ByteOrder byteOrder, ProfilingTimer timer) - throws IOException { - - try ( - final FileInputStream fis = new FileInputStream(file); - final AC ac = timer.start("Loading vectors from bin file") - ) { - final FileChannel channel = fis.getChannel(); - timer.start("Reading gigabyte #1"); - MappedByteBuffer buffer = - channel.map( - FileChannel.MapMode.READ_ONLY, - 0, - Math.min(channel.size(), Integer.MAX_VALUE)); - buffer.order(byteOrder); - int bufferCount = 1; - // Java's NIO only allows memory-mapping up to 2GB. To work around this problem, we re-map - // every gigabyte. To calculate offsets correctly, we have to keep track how many gigabytes - // we've already skipped. That's what this is for. - - StringBuilder sb = new StringBuilder(); - char c = (char) buffer.get(); - while (c != '\n') { - sb.append(c); - c = (char) buffer.get(); - } - String firstLine = sb.toString(); - int index = firstLine.indexOf(' '); - Preconditions.checkState(index != -1, - "Expected a space in the first line of file '%s': '%s'", - file.getAbsolutePath(), firstLine); - - final int vocabSize = Integer.parseInt(firstLine.substring(0, index)); - final int layerSize = Integer.parseInt(firstLine.substring(index + 1)); - timer.appendToLog(String.format( - "Loading %d vectors with dimensionality %d", - vocabSize, - layerSize)); - - List vocabs = new ArrayList(vocabSize); - DoubleBuffer vectors = ByteBuffer.allocateDirect(vocabSize * layerSize * 8).asDoubleBuffer(); - - long lastLogMessage = System.currentTimeMillis(); - final float[] floats = new float[layerSize]; - for (int lineno = 0; lineno < vocabSize; lineno++) { - // read vocab - sb.setLength(0); - c = (char) buffer.get(); - while (c != ' ') { - // ignore newlines in front of words (some binary files have newline, - // some don't) - if (c != '\n') { - sb.append(c); - } - c = (char) buffer.get(); - } - vocabs.add(sb.toString()); - - // read vector - final FloatBuffer floatBuffer = buffer.asFloatBuffer(); - floatBuffer.get(floats); - for (int i = 0; i < floats.length; ++i) { - vectors.put(lineno * layerSize + i, floats[i]); - } - buffer.position(buffer.position() + 4 * layerSize); - - // print log - final long now = System.currentTimeMillis(); - if (now - lastLogMessage > 1000) { - final double percentage = ((double) (lineno + 1) / (double) vocabSize) * 100.0; - timer.appendToLog( - String.format("Loaded %d/%d vectors (%f%%)", lineno + 1, vocabSize, percentage)); - lastLogMessage = now; - } - - // remap file - if (buffer.position() > ONE_GB) { - final int newPosition = (int) (buffer.position() - ONE_GB); - final long size = Math.min(channel.size() - ONE_GB * bufferCount, Integer.MAX_VALUE); - timer.endAndStart( - "Reading gigabyte #%d. Start: %d, size: %d", - bufferCount, - ONE_GB * bufferCount, - size); - buffer = channel.map( - FileChannel.MapMode.READ_ONLY, - ONE_GB * bufferCount, - size); - buffer.order(byteOrder); - buffer.position(newPosition); - bufferCount += 1; - } - } - timer.end(); - - return new Word2VecModel(vocabs, layerSize, vectors); - } - } - - /** - * Saves the model as a bin file that's compatible with the C version of Word2Vec - */ - public void toBinFile(final OutputStream out) throws IOException { - final Charset cs = Charset.forName("UTF-8"); - final String header = String.format("%d %d\n", vocab.size(), layerSize); - out.write(header.getBytes(cs)); - - final double[] vector = new double[layerSize]; - final ByteBuffer buffer = ByteBuffer.allocate(4 * layerSize); - buffer.order(ByteOrder.LITTLE_ENDIAN); // The C version uses this byte order. - for(int i = 0; i < vocab.size(); ++i) { - out.write(String.format("%s ", vocab.get(i)).getBytes(cs)); - - vectors.position(i * layerSize); - vectors.get(vector); - buffer.clear(); - for(int j = 0; j < layerSize; ++j) - buffer.putFloat((float)vector[j]); - out.write(buffer.array()); - - out.write('\n'); - } - - out.flush(); - } - - /** - * @return {@link Word2VecModel} from the lines of the file in the text output format of the - * Word2Vec C open source project. - */ - @VisibleForTesting - static Word2VecModel fromTextFile(String filename, List lines) throws IOException { - List vocab = Lists.newArrayList(); - List vectors = Lists.newArrayList(); - int vocabSize = Integer.parseInt(lines.get(0).split(" ")[0]); - int layerSize = Integer.parseInt(lines.get(0).split(" ")[1]); - - Preconditions.checkArgument( - vocabSize == lines.size() - 1, - "For file '%s', vocab size is %s, but there are %s word vectors in the file", - filename, - vocabSize, - lines.size() - 1 - ); - - for (int n = 1; n < lines.size(); n++) { - String[] values = lines.get(n).split(" "); - vocab.add(values[0]); - - // Sanity check - Preconditions.checkArgument( - layerSize == values.length - 1, - "For file '%s', on line %s, layer size is %s, but found %s values in the word vector", - filename, - n, - layerSize, - values.length - 1 - ); - - for (int d = 1; d < values.length; d++) { - vectors.add(Double.parseDouble(values[d])); - } - } - - Word2VecModelThrift thrift = new Word2VecModelThrift() - .setLayerSize(layerSize) - .setVocab(vocab) - .setVectors(vectors); - return fromThrift(thrift); - } - - /** @return {@link Word2VecTrainerBuilder} for training a model */ - public static Word2VecTrainerBuilder trainer() { - return new Word2VecTrainerBuilder(); - } + final List vocab; + final int layerSize; + final DoubleBuffer vectors; + private final static long ONE_GB = 1024 * 1024 * 1024; + + Word2VecModel(Iterable vocab, int layerSize, DoubleBuffer vectors) { + this.vocab = ImmutableList.copyOf(vocab); + this.layerSize = layerSize; + this.vectors = vectors; + } + + Word2VecModel(Iterable vocab, int layerSize, double[] vectors) { + this(vocab, layerSize, DoubleBuffer.wrap(vectors)); + } + + /** + * @return Vocabulary + */ + public Iterable getVocab() { + return vocab; + } + + /** + * @return {@link Searcher} for searching + */ + public Searcher forSearch() { + return new SearcherImpl(this); + } + + /** + * @return Serializable thrift representation + */ + public Word2VecModelThrift toThrift() { + double[] vectorsArray; + if (vectors.hasArray()) { + vectorsArray = vectors.array(); + } else { + vectorsArray = new double[vectors.limit()]; + vectors.position(0); + vectors.get(vectorsArray); + } + + return new Word2VecModelThrift() + .setVocab(vocab) + .setLayerSize(layerSize) + .setVectors(Doubles.asList(vectorsArray)); + } + + /** + * @return {@link Word2VecModel} created from a thrift representation + */ + public static Word2VecModel fromThrift(Word2VecModelThrift thrift) { + return new Word2VecModel( + thrift.getVocab(), + thrift.getLayerSize(), + Doubles.toArray(thrift.getVectors())); + } + + /** + * @return {@link Word2VecModel} read from a file in the text output format of the Word2Vec C + * open source project. + */ + public static Word2VecModel fromTextFile(File file) throws IOException { + List lines = Common.readToList(file); + return fromTextFile(file.getAbsolutePath(), lines); + } + + /** + * Forwards to {@link #fromBinFile(File, ByteOrder, ProfilingTimer)} with the default + * ByteOrder.LITTLE_ENDIAN and no ProfilingTimer + */ + public static Word2VecModel fromBinFile(File file) + throws IOException { + return fromBinFile(file, ByteOrder.LITTLE_ENDIAN, ProfilingTimer.NONE); + } + + /** + * Forwards to {@link #fromBinFile(File, ByteOrder, ProfilingTimer)} with no ProfilingTimer + */ + public static Word2VecModel fromBinFile(File file, ByteOrder byteOrder) + throws IOException { + return fromBinFile(file, byteOrder, ProfilingTimer.NONE); + } + + /** + * @return {@link Word2VecModel} created from the binary representation output + * by the open source C version of word2vec using the given byte order. + */ + public static Word2VecModel fromBinFile(File file, ByteOrder byteOrder, ProfilingTimer timer) + throws IOException { + + try ( + final FileInputStream fis = new FileInputStream(file); + final AC ac = timer.start("Loading vectors from bin file") + ) { + final FileChannel channel = fis.getChannel(); + timer.start("Reading gigabyte #1"); + MappedByteBuffer buffer = + channel.map( + FileChannel.MapMode.READ_ONLY, + 0, + Math.min(channel.size(), Integer.MAX_VALUE)); + buffer.order(byteOrder); + int bufferCount = 1; + // Java's NIO only allows memory-mapping up to 2GB. To work around this problem, we re-map + // every gigabyte. To calculate offsets correctly, we have to keep track how many gigabytes + // we've already skipped. That's what this is for. + + StringBuilder sb = new StringBuilder(); + char c = (char) buffer.get(); + while (c != '\n') { + sb.append(c); + c = (char) buffer.get(); + } + String firstLine = sb.toString(); + int index = firstLine.indexOf(' '); + Preconditions.checkState(index != -1, + "Expected a space in the first line of file '%s': '%s'", + file.getAbsolutePath(), firstLine); + + final int vocabSize = Integer.parseInt(firstLine.substring(0, index)); + final int layerSize = Integer.parseInt(firstLine.substring(index + 1)); + timer.appendToLog(String.format( + "Loading %d vectors with dimensionality %d", + vocabSize, + layerSize)); + + List vocabs = new ArrayList(vocabSize); + DoubleBuffer vectors = DoubleBuffer.allocate(vocabSize * layerSize); + + long lastLogMessage = System.currentTimeMillis(); + final float[] floats = new float[layerSize]; + for (int lineno = 0; lineno < vocabSize; lineno++) { + // read vocab + sb.setLength(0); + c = (char) buffer.get(); + while (c != ' ') { + // ignore newlines in front of words (some binary files have newline, + // some don't) + if (c != '\n') { + sb.append(c); + } + c = (char) buffer.get(); + } + vocabs.add(sb.toString()); + + // read vector + final FloatBuffer floatBuffer = buffer.asFloatBuffer(); + floatBuffer.get(floats); + for (int i = 0; i < floats.length; ++i) { + vectors.put(lineno * layerSize + i, floats[i]); + } + buffer.position(buffer.position() + 4 * layerSize); + + // print log + final long now = System.currentTimeMillis(); + if (now - lastLogMessage > 1000) { + final double percentage = ((double) (lineno + 1) / (double) vocabSize) * 100.0; + timer.appendToLog( + String.format("Loaded %d/%d vectors (%f%%)", lineno + 1, vocabSize, percentage)); + lastLogMessage = now; + } + + // remap file + if (buffer.position() > ONE_GB) { + final int newPosition = (int) (buffer.position() - ONE_GB); + final long size = Math.min(channel.size() - ONE_GB * bufferCount, Integer.MAX_VALUE); + timer.endAndStart( + "Reading gigabyte #%d. Start: %d, size: %d", + bufferCount, + ONE_GB * bufferCount, + size); + buffer = channel.map( + FileChannel.MapMode.READ_ONLY, + ONE_GB * bufferCount, + size); + buffer.order(byteOrder); + buffer.position(newPosition); + bufferCount += 1; + } + } + timer.end(); + + return new Word2VecModel(vocabs, layerSize, vectors); + } + } + + /** + * Saves the model as a bin file that's compatible with the C version of Word2Vec + */ + public void toBinFile(final OutputStream out) throws IOException { + final Charset cs = Charset.forName("UTF-8"); + final String header = String.format("%d %d\n", vocab.size(), layerSize); + out.write(header.getBytes(cs)); + + final double[] vector = new double[layerSize]; + final ByteBuffer buffer = ByteBuffer.allocate(4 * layerSize); + buffer.order(ByteOrder.LITTLE_ENDIAN); // The C version uses this byte order. + for (int i = 0; i < vocab.size(); ++i) { + out.write(String.format("%s ", vocab.get(i)).getBytes(cs)); + + vectors.position(i * layerSize); + vectors.get(vector); + buffer.clear(); + for (int j = 0; j < layerSize; ++j) + buffer.putFloat((float) vector[j]); + out.write(buffer.array()); + + out.write('\n'); + } + + out.flush(); + } + + /** + * @return {@link Word2VecModel} from the lines of the file in the text output format of the + * Word2Vec C open source project. + */ + @VisibleForTesting + static Word2VecModel fromTextFile(String filename, List lines) throws IOException { + List vocab = Lists.newArrayList(); + List vectors = Lists.newArrayList(); + int vocabSize = Integer.parseInt(lines.get(0).split(" ")[0]); + int layerSize = Integer.parseInt(lines.get(0).split(" ")[1]); + + Preconditions.checkArgument( + vocabSize == lines.size() - 1, + "For file '%s', vocab size is %s, but there are %s word vectors in the file", + filename, + vocabSize, + lines.size() - 1 + ); + + for (int n = 1; n < lines.size(); n++) { + String[] values = lines.get(n).split(" "); + vocab.add(values[0]); + + // Sanity check + Preconditions.checkArgument( + layerSize == values.length - 1, + "For file '%s', on line %s, layer size is %s, but found %s values in the word vector", + filename, + n, + layerSize, + values.length - 1 + ); + + for (int d = 1; d < values.length; d++) { + vectors.add(Double.parseDouble(values[d])); + } + } + + Word2VecModelThrift thrift = new Word2VecModelThrift() + .setLayerSize(layerSize) + .setVocab(vocab) + .setVectors(vectors); + return fromThrift(thrift); + } + + /** + * @return {@link Word2VecTrainerBuilder} for training a model + */ + public static Word2VecTrainerBuilder trainer() { + return new Word2VecTrainerBuilder(); + } }