diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java index 037e7bea1d7..ad3ae0f4ced 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java @@ -447,7 +447,7 @@ protected void setApplyRowBlocksPerColumn(int nPart) { } public enum EncoderType { - Recode, FeatureHash, PassThrough, Bin, Dummycode, Omit, MVImpute, Composite, WordEmbedding, BagOfWords + Recode, FeatureHash, PassThrough, Bin, Dummycode, Omit, MVImpute, Composite, WordEmbedding, BagOfWords, UDF } /* diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderUDF.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderUDF.java index 756a1fdc5cb..79c21a650e4 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderUDF.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderUDF.java @@ -19,6 +19,9 @@ package org.apache.sysds.runtime.transform.encode; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; import java.util.List; import org.apache.sysds.api.DMLScript; @@ -46,7 +49,7 @@ public class ColumnEncoderUDF extends ColumnEncoder { //TODO pass execution context through encoder factory for arbitrary functions not just builtin //TODO integration into IPA to ensure existence of unoptimized functions - private final String _fName; + private String _fName; public int _domainSize = 1; protected ColumnEncoderUDF(int ptCols, String name) { @@ -165,4 +168,20 @@ protected double getCode(CacheBlock in, int row) { protected double[] getCodeCol(CacheBlock in, int startInd, int endInd, double[] tmp) { throw new DMLRuntimeException("UDF encoders only support full column access."); } + + @Override + public void writeExternal(ObjectOutput out) throws IOException { + super.writeExternal(out); + out.writeUTF(_fName != null ? _fName : ""); + out.writeInt(_domainSize); + } + + @Override + public void readExternal(ObjectInput in) throws IOException { + super.readExternal(in); + _fName = in.readUTF(); + if(_fName.isEmpty()) + _fName = null; + _domainSize = in.readInt(); + } } diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java index 1294d0e7e79..507cae36da4 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java @@ -267,6 +267,8 @@ else if(columnEncoder instanceof ColumnEncoderWordEmbedding) return EncoderType.WordEmbedding.ordinal(); else if(columnEncoder instanceof ColumnEncoderBagOfWords) return EncoderType.BagOfWords.ordinal(); + else if(columnEncoder instanceof ColumnEncoderUDF) + return EncoderType.UDF.ordinal(); throw new DMLRuntimeException("Unsupported encoder type: " + columnEncoder.getClass().getCanonicalName()); } @@ -276,19 +278,21 @@ public static ColumnEncoder createInstance(int type) { case Bin: return new ColumnEncoderBin(); case Dummycode: - return new ColumnEncoderDummycode(); - case FeatureHash: - return new ColumnEncoderFeatureHash(); - case PassThrough: - return new ColumnEncoderPassThrough(); - case Recode: - return new ColumnEncoderRecode(); - case WordEmbedding: - return new ColumnEncoderWordEmbedding(); - case BagOfWords: - return new ColumnEncoderBagOfWords(); - default: - throw new DMLRuntimeException("Unsupported encoder type: " + etype); + return new ColumnEncoderDummycode(); + case FeatureHash: + return new ColumnEncoderFeatureHash(); + case PassThrough: + return new ColumnEncoderPassThrough(); + case Recode: + return new ColumnEncoderRecode(); + case WordEmbedding: + return new ColumnEncoderWordEmbedding(); + case BagOfWords: + return new ColumnEncoderBagOfWords(); + case UDF: + return new ColumnEncoderUDF(); + default: + throw new DMLRuntimeException("Unsupported encoder type: " + etype); } } diff --git a/src/test/java/org/apache/sysds/test/functions/io/binary/SerializeTest.java b/src/test/java/org/apache/sysds/test/functions/io/binary/SerializeTest.java index afe15adf655..2391df6925c 100644 --- a/src/test/java/org/apache/sysds/test/functions/io/binary/SerializeTest.java +++ b/src/test/java/org/apache/sysds/test/functions/io/binary/SerializeTest.java @@ -25,6 +25,8 @@ import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.transform.encode.EncoderFactory; import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder; +import org.apache.sysds.runtime.transform.encode.ColumnEncoderComposite; +import org.apache.sysds.runtime.transform.encode.ColumnEncoderUDF; import org.apache.sysds.runtime.util.LocalFileUtils; import org.junit.Assert; import org.junit.Test; @@ -46,6 +48,9 @@ import java.io.ObjectOutput; import java.io.ObjectOutputStream; import java.util.HashMap; +import java.util.Collections; +import java.lang.reflect.Constructor; +import java.lang.reflect.Field; public class SerializeTest extends AutomatedTestBase { @@ -113,6 +118,11 @@ public void testWEEncoderSerialization(){ runSerializeWEEncoder(); } + @Test + public void testUDFEncoderSerialization(){ + runSerializeUDFEncoder(); + } + private void runSerializeTest( int rows, int cols, double sparsity ) { try @@ -188,6 +198,63 @@ private void runSerializeWEEncoder(){ } } + private void runSerializeUDFEncoder(){ + try (ByteArrayOutputStream bos = new ByteArrayOutputStream(); + ObjectOutput out = new ObjectOutputStream(bos)) { + final String udfName = "dummyUdf"; + final int colId = 2; + final int domainSize = 5; + + ColumnEncoderUDF udf = createUdf(colId, udfName, domainSize); + ColumnEncoderComposite composite = new ColumnEncoderComposite(Collections.singletonList(udf)); + MultiColumnEncoder encoder = new MultiColumnEncoder(Collections.singletonList(composite)); + + encoder.writeExternal(out); + out.flush(); + + ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray()); + ObjectInput in = new ObjectInputStream(bis); + MultiColumnEncoder encoderSer = new MultiColumnEncoder(); + encoderSer.readExternal(in); + in.close(); + + ColumnEncoderComposite decodedComposite = encoderSer.getColumnEncoders().get(0); + ColumnEncoderUDF decodedUdf = decodedComposite.getEncoder(ColumnEncoderUDF.class); + + Assert.assertNotNull(decodedUdf); + Assert.assertEquals(colId, decodedUdf.getColID()); + Assert.assertEquals(domainSize, decodedUdf._domainSize); + Assert.assertEquals(udfName, getUdfName(decodedUdf)); + } + catch(IOException | ClassNotFoundException e) { + throw new RuntimeException(e); + } + } + + private ColumnEncoderUDF createUdf(int colId, String name, int domainSize) { + try { + Constructor ctor = ColumnEncoderUDF.class.getDeclaredConstructor(int.class, String.class); + ctor.setAccessible(true); + ColumnEncoderUDF udf = ctor.newInstance(colId, name); + udf._domainSize = domainSize; + return udf; + } + catch(Exception e) { + throw new RuntimeException(e); + } + } + + private String getUdfName(ColumnEncoderUDF udf) { + try { + Field f = ColumnEncoderUDF.class.getDeclaredField("_fName"); + f.setAccessible(true); + return (String) f.get(udf); + } + catch(Exception e) { + throw new RuntimeException(e); + } + } + private void runSerializeDedupDenseTest( int rows, int cols ) { try