From bf71185e979a07711cdaba59d97b955e7a35c22a Mon Sep 17 00:00:00 2001 From: Eric Torreborre Date: Mon, 1 Sep 2014 17:46:34 +1000 Subject: [PATCH] added a wire format for Thrift. fixes #341 --- project/dependencies.scala | 7 +- .../nicta/scoobi/io/thrift/ThriftSchema.scala | 47 +++ .../scoobi/io/thrift/ThriftSerialiser.scala | 27 ++ .../com/nicta/scoobi/io/thrift/package.scala | 12 + .../com/nicta/scoobi/io/thrift/MyThrift.java | 387 ++++++++++++++++++ .../scoobi/io/thrift/ThriftSchemaSpec.scala | 24 ++ src/test/thrift/build | 7 + src/test/thrift/test.thrift | 5 + 8 files changed, 515 insertions(+), 1 deletion(-) create mode 100644 src/main/scala/com/nicta/scoobi/io/thrift/ThriftSchema.scala create mode 100644 src/main/scala/com/nicta/scoobi/io/thrift/ThriftSerialiser.scala create mode 100644 src/main/scala/com/nicta/scoobi/io/thrift/package.scala create mode 100644 src/test/java/com/nicta/scoobi/io/thrift/MyThrift.java create mode 100644 src/test/scala/com/nicta/scoobi/io/thrift/ThriftSchemaSpec.scala create mode 100644 src/test/thrift/build create mode 100644 src/test/thrift/test.thrift diff --git a/project/dependencies.scala b/project/dependencies.scala index 10c51612d..97ff15e36 100644 --- a/project/dependencies.scala +++ b/project/dependencies.scala @@ -22,7 +22,8 @@ object dependencies { lazy val dependencies = libraryDependencies ++= scoobi(scalaVersion.value) ++ hadoop(version.value) ++ - scalaz() ++ + thrift ++ + scalaz() ++ specs2() // Libraries @@ -54,6 +55,10 @@ object dependencies { "org.scalaz" %% "scalaz-typelevel" % scalazVersion intransitive(), "org.scalaz" %% "scalaz-xml" % scalazVersion intransitive()) + val thrift = Seq( + "org.apache.thrift" % "libthrift" % "0.9.1" + ) + def specs2(specs2Version: String = "2.4") = Seq( "org.specs2" %% "specs2-core" % specs2Version % "optional") ++ Seq( "org.specs2" %% "specs2-mock" % specs2Version , diff --git a/src/main/scala/com/nicta/scoobi/io/thrift/ThriftSchema.scala b/src/main/scala/com/nicta/scoobi/io/thrift/ThriftSchema.scala new file mode 100644 index 000000000..5333bad59 --- /dev/null +++ b/src/main/scala/com/nicta/scoobi/io/thrift/ThriftSchema.scala @@ -0,0 +1,47 @@ +package com.nicta.scoobi.io.thrift + +import java.io.{DataInput, DataOutput} + +import com.nicta.scoobi.Scoobi._ +import org.apache.hadoop.io.BytesWritable + +/** + * Schema for creating Thrift WireFormat and SeqSchema instances. + */ +object ThriftSchema { + + /* WARNING THIS MUST BE A DEF OR OR IT CAN TRIGGER CONCURRENCY ISSUES WITH SHARED THRIFT SERIALIZERS */ + def mkThriftFmt[A](implicit m: Manifest[A], ev: A <:< ThriftLike): WireFormat[A] = new WireFormat[A] { + // Call once when the implicit is created to avoid further reflection + val empty = m.runtimeClass.newInstance().asInstanceOf[A] + + def toWire(x: A, out: DataOutput) = { + val bytes = ThriftSerialiser().toBytes(x) + out.writeInt(bytes.length) + out.write(bytes) + } + + def fromWire(in: DataInput): A = { + val size = in.readInt() + val bytes = new Array[Byte](size) + in.readFully(bytes) + ThriftSerialiser().fromBytes(empty, bytes) + } + + override def toString = "ThriftObject" + } + + /* WARNING THIS MUST BE A DEF OR OR IT CAN TRIGGER CONCURRENCY ISSUES WITH SHARED THRIFT SERIALIZERS*/ + def mkThriftSchema[A](implicit m: Manifest[A], ev: A <:< ThriftLike) = new SeqSchema[A] { + type SeqType = BytesWritable + + // Call once when the implicit is created to avoid further reflection + val empty = m.runtimeClass.newInstance().asInstanceOf[A] + + def toWritable(x: A) = new BytesWritable(ThriftSerialiser().toBytes(x)) + + def fromWritable(x: BytesWritable): A = ThriftSerialiser().fromBytes(empty, x.getBytes) + + val mf: Manifest[SeqType] = implicitly + } +} \ No newline at end of file diff --git a/src/main/scala/com/nicta/scoobi/io/thrift/ThriftSerialiser.scala b/src/main/scala/com/nicta/scoobi/io/thrift/ThriftSerialiser.scala new file mode 100644 index 000000000..942393b85 --- /dev/null +++ b/src/main/scala/com/nicta/scoobi/io/thrift/ThriftSerialiser.scala @@ -0,0 +1,27 @@ +package com.nicta.scoobi.io.thrift + +import org.apache.thrift.{TDeserializer, TSerializer} +import org.apache.thrift.protocol.TCompactProtocol + +/** + * Util for converting a `ThriftLike` object to and from bytes. + * + * WARNING: This class is _not_ threadsafe and should be used with extreme caution! + * + * https://issues.apache.org/jira/browse/THRIFT-2218 + */ +case class ThriftSerialiser() { + + val serialiser = new TSerializer(new TCompactProtocol.Factory) + val deserialiser = new TDeserializer(new TCompactProtocol.Factory) + + def toBytes[A](a: A)(implicit ev: A <:< ThriftLike): Array[Byte] = + serialiser.serialize(ev(a)) + + def fromBytes[A](empty: A, bytes: Array[Byte])(implicit ev: A <:< ThriftLike): A = { + val e = ev(empty).deepCopy + e.clear() + deserialiser.deserialize(e, bytes) + e.asInstanceOf[A] + } +} \ No newline at end of file diff --git a/src/main/scala/com/nicta/scoobi/io/thrift/package.scala b/src/main/scala/com/nicta/scoobi/io/thrift/package.scala new file mode 100644 index 000000000..654b68970 --- /dev/null +++ b/src/main/scala/com/nicta/scoobi/io/thrift/package.scala @@ -0,0 +1,12 @@ +package com.nicta.scoobi.io + +import com.nicta.scoobi.Scoobi._ + +package object thrift { + + type ThriftLike = org.apache.thrift.TBase[_ <: org.apache.thrift.TBase[_, _], _ <: org.apache.thrift.TFieldIdEnum] + + implicit def ThriftWireFormat[A](implicit m: Manifest[A], ev: A <:< ThriftLike): WireFormat[A] =ThriftSchema.mkThriftFmt[A] + + implicit def ThriftSeqSchema[A](implicit m: Manifest[A], ev: A <:< ThriftLike): SeqSchema[A] = ThriftSchema.mkThriftSchema[A] +} \ No newline at end of file diff --git a/src/test/java/com/nicta/scoobi/io/thrift/MyThrift.java b/src/test/java/com/nicta/scoobi/io/thrift/MyThrift.java new file mode 100644 index 000000000..1dcda5489 --- /dev/null +++ b/src/test/java/com/nicta/scoobi/io/thrift/MyThrift.java @@ -0,0 +1,387 @@ +/** + * Autogenerated by Thrift Compiler (0.9.1) + * + * DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING + * @generated + */ +package com.nicta.scoobi.io.thrift; + +import org.apache.thrift.scheme.IScheme; +import org.apache.thrift.scheme.SchemeFactory; +import org.apache.thrift.scheme.StandardScheme; + +import org.apache.thrift.scheme.TupleScheme; +import org.apache.thrift.protocol.TTupleProtocol; +import org.apache.thrift.protocol.TProtocolException; +import org.apache.thrift.EncodingUtils; +import org.apache.thrift.TException; +import org.apache.thrift.async.AsyncMethodCallback; +import org.apache.thrift.server.AbstractNonblockingServer.*; +import java.util.List; +import java.util.ArrayList; +import java.util.Map; +import java.util.HashMap; +import java.util.EnumMap; +import java.util.Set; +import java.util.HashSet; +import java.util.EnumSet; +import java.util.Collections; +import java.util.BitSet; +import java.nio.ByteBuffer; +import java.util.Arrays; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class MyThrift implements org.apache.thrift.TBase, java.io.Serializable, Cloneable, Comparable { + private static final org.apache.thrift.protocol.TStruct STRUCT_DESC = new org.apache.thrift.protocol.TStruct("MyThrift"); + + private static final org.apache.thrift.protocol.TField ENTITY_FIELD_DESC = new org.apache.thrift.protocol.TField("entity", org.apache.thrift.protocol.TType.STRING, (short)1); + + private static final Map, SchemeFactory> schemes = new HashMap, SchemeFactory>(); + static { + schemes.put(StandardScheme.class, new MyThriftStandardSchemeFactory()); + schemes.put(TupleScheme.class, new MyThriftTupleSchemeFactory()); + } + + public String entity; // required + + /** The set of fields this struct contains, along with convenience methods for finding and manipulating them. */ + public enum _Fields implements org.apache.thrift.TFieldIdEnum { + ENTITY((short)1, "entity"); + + private static final Map byName = new HashMap(); + + static { + for (_Fields field : EnumSet.allOf(_Fields.class)) { + byName.put(field.getFieldName(), field); + } + } + + /** + * Find the _Fields constant that matches fieldId, or null if its not found. + */ + public static _Fields findByThriftId(int fieldId) { + switch(fieldId) { + case 1: // ENTITY + return ENTITY; + default: + return null; + } + } + + /** + * Find the _Fields constant that matches fieldId, throwing an exception + * if it is not found. + */ + public static _Fields findByThriftIdOrThrow(int fieldId) { + _Fields fields = findByThriftId(fieldId); + if (fields == null) throw new IllegalArgumentException("Field " + fieldId + " doesn't exist!"); + return fields; + } + + /** + * Find the _Fields constant that matches name, or null if its not found. + */ + public static _Fields findByName(String name) { + return byName.get(name); + } + + private final short _thriftId; + private final String _fieldName; + + _Fields(short thriftId, String fieldName) { + _thriftId = thriftId; + _fieldName = fieldName; + } + + public short getThriftFieldId() { + return _thriftId; + } + + public String getFieldName() { + return _fieldName; + } + } + + // isset id assignments + public static final Map<_Fields, org.apache.thrift.meta_data.FieldMetaData> metaDataMap; + static { + Map<_Fields, org.apache.thrift.meta_data.FieldMetaData> tmpMap = new EnumMap<_Fields, org.apache.thrift.meta_data.FieldMetaData>(_Fields.class); + tmpMap.put(_Fields.ENTITY, new org.apache.thrift.meta_data.FieldMetaData("entity", org.apache.thrift.TFieldRequirementType.DEFAULT, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.STRING))); + metaDataMap = Collections.unmodifiableMap(tmpMap); + org.apache.thrift.meta_data.FieldMetaData.addStructMetaDataMap(MyThrift.class, metaDataMap); + } + + public MyThrift() { + } + + public MyThrift( + String entity) + { + this(); + this.entity = entity; + } + + /** + * Performs a deep copy on other. + */ + public MyThrift(MyThrift other) { + if (other.isSetEntity()) { + this.entity = other.entity; + } + } + + public MyThrift deepCopy() { + return new MyThrift(this); + } + + @Override + public void clear() { + this.entity = null; + } + + public String getEntity() { + return this.entity; + } + + public MyThrift setEntity(String entity) { + this.entity = entity; + return this; + } + + public void unsetEntity() { + this.entity = null; + } + + /** Returns true if field entity is set (has been assigned a value) and false otherwise */ + public boolean isSetEntity() { + return this.entity != null; + } + + public void setEntityIsSet(boolean value) { + if (!value) { + this.entity = null; + } + } + + public void setFieldValue(_Fields field, Object value) { + switch (field) { + case ENTITY: + if (value == null) { + unsetEntity(); + } else { + setEntity((String)value); + } + break; + + } + } + + public Object getFieldValue(_Fields field) { + switch (field) { + case ENTITY: + return getEntity(); + + } + throw new IllegalStateException(); + } + + /** Returns true if field corresponding to fieldID is set (has been assigned a value) and false otherwise */ + public boolean isSet(_Fields field) { + if (field == null) { + throw new IllegalArgumentException(); + } + + switch (field) { + case ENTITY: + return isSetEntity(); + } + throw new IllegalStateException(); + } + + @Override + public boolean equals(Object that) { + if (that == null) + return false; + if (that instanceof MyThrift) + return this.equals((MyThrift)that); + return false; + } + + public boolean equals(MyThrift that) { + if (that == null) + return false; + + boolean this_present_entity = true && this.isSetEntity(); + boolean that_present_entity = true && that.isSetEntity(); + if (this_present_entity || that_present_entity) { + if (!(this_present_entity && that_present_entity)) + return false; + if (!this.entity.equals(that.entity)) + return false; + } + + return true; + } + + @Override + public int hashCode() { + return 0; + } + + @Override + public int compareTo(MyThrift other) { + if (!getClass().equals(other.getClass())) { + return getClass().getName().compareTo(other.getClass().getName()); + } + + int lastComparison = 0; + + lastComparison = Boolean.valueOf(isSetEntity()).compareTo(other.isSetEntity()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetEntity()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.entity, other.entity); + if (lastComparison != 0) { + return lastComparison; + } + } + return 0; + } + + public _Fields fieldForId(int fieldId) { + return _Fields.findByThriftId(fieldId); + } + + public void read(org.apache.thrift.protocol.TProtocol iprot) throws org.apache.thrift.TException { + schemes.get(iprot.getScheme()).getScheme().read(iprot, this); + } + + public void write(org.apache.thrift.protocol.TProtocol oprot) throws org.apache.thrift.TException { + schemes.get(oprot.getScheme()).getScheme().write(oprot, this); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder("MyThrift("); + boolean first = true; + + sb.append("entity:"); + if (this.entity == null) { + sb.append("null"); + } else { + sb.append(this.entity); + } + first = false; + sb.append(")"); + return sb.toString(); + } + + public void validate() throws org.apache.thrift.TException { + // check for required fields + // check for sub-struct validity + } + + private void writeObject(java.io.ObjectOutputStream out) throws java.io.IOException { + try { + write(new org.apache.thrift.protocol.TCompactProtocol(new org.apache.thrift.transport.TIOStreamTransport(out))); + } catch (org.apache.thrift.TException te) { + throw new java.io.IOException(te); + } + } + + private void readObject(java.io.ObjectInputStream in) throws java.io.IOException, ClassNotFoundException { + try { + read(new org.apache.thrift.protocol.TCompactProtocol(new org.apache.thrift.transport.TIOStreamTransport(in))); + } catch (org.apache.thrift.TException te) { + throw new java.io.IOException(te); + } + } + + private static class MyThriftStandardSchemeFactory implements SchemeFactory { + public MyThriftStandardScheme getScheme() { + return new MyThriftStandardScheme(); + } + } + + private static class MyThriftStandardScheme extends StandardScheme { + + public void read(org.apache.thrift.protocol.TProtocol iprot, MyThrift struct) throws org.apache.thrift.TException { + org.apache.thrift.protocol.TField schemeField; + iprot.readStructBegin(); + while (true) + { + schemeField = iprot.readFieldBegin(); + if (schemeField.type == org.apache.thrift.protocol.TType.STOP) { + break; + } + switch (schemeField.id) { + case 1: // ENTITY + if (schemeField.type == org.apache.thrift.protocol.TType.STRING) { + struct.entity = iprot.readString(); + struct.setEntityIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + default: + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + iprot.readFieldEnd(); + } + iprot.readStructEnd(); + + // check for required fields of primitive type, which can't be checked in the validate method + struct.validate(); + } + + public void write(org.apache.thrift.protocol.TProtocol oprot, MyThrift struct) throws org.apache.thrift.TException { + struct.validate(); + + oprot.writeStructBegin(STRUCT_DESC); + if (struct.entity != null) { + oprot.writeFieldBegin(ENTITY_FIELD_DESC); + oprot.writeString(struct.entity); + oprot.writeFieldEnd(); + } + oprot.writeFieldStop(); + oprot.writeStructEnd(); + } + + } + + private static class MyThriftTupleSchemeFactory implements SchemeFactory { + public MyThriftTupleScheme getScheme() { + return new MyThriftTupleScheme(); + } + } + + private static class MyThriftTupleScheme extends TupleScheme { + + @Override + public void write(org.apache.thrift.protocol.TProtocol prot, MyThrift struct) throws org.apache.thrift.TException { + TTupleProtocol oprot = (TTupleProtocol) prot; + BitSet optionals = new BitSet(); + if (struct.isSetEntity()) { + optionals.set(0); + } + oprot.writeBitSet(optionals, 1); + if (struct.isSetEntity()) { + oprot.writeString(struct.entity); + } + } + + @Override + public void read(org.apache.thrift.protocol.TProtocol prot, MyThrift struct) throws org.apache.thrift.TException { + TTupleProtocol iprot = (TTupleProtocol) prot; + BitSet incoming = iprot.readBitSet(1); + if (incoming.get(0)) { + struct.entity = iprot.readString(); + struct.setEntityIsSet(true); + } + } + } + +} diff --git a/src/test/scala/com/nicta/scoobi/io/thrift/ThriftSchemaSpec.scala b/src/test/scala/com/nicta/scoobi/io/thrift/ThriftSchemaSpec.scala new file mode 100644 index 000000000..766a5a355 --- /dev/null +++ b/src/test/scala/com/nicta/scoobi/io/thrift/ThriftSchemaSpec.scala @@ -0,0 +1,24 @@ +package com.nicta.scoobi.io.thrift + +import java.io._ + +import com.nicta.scoobi.Scoobi._ +import org.specs2.ScalaCheck +import org.specs2.mutable.Specification + +class ThriftSchemaSpec extends Specification with ScalaCheck { + + "WireFormat bidirectional" >> prop((s: String) => { + implicit val wf = implicitly[WireFormat[MyThrift]] + val a = new MyThrift(s) + val out = new ByteArrayOutputStream() + wf.toWire(a, new DataOutputStream(out)) + wf.fromWire(new DataInputStream(new ByteArrayInputStream(out.toByteArray))) ==== a + }) + + "SeqSchema bidirectional" >> prop((s: String) => { + implicit val ss = implicitly[SeqSchema[MyThrift]] + val a = new MyThrift(s) + ss.fromWritable(ss.toWritable(a)) ==== a + }) +} \ No newline at end of file diff --git a/src/test/thrift/build b/src/test/thrift/build new file mode 100644 index 000000000..fae57b41d --- /dev/null +++ b/src/test/thrift/build @@ -0,0 +1,7 @@ +#!/bin/sh -eu + +# This requires a copy of thrift to be installed because it can be run +# The expected files have been checked in for convenience + +DIR=$(dirname $0)/../../.. +thrift -r -out ${DIR}/src/test/java/ --gen java ${DIR}/src/test/thrift/test.thrift \ No newline at end of file diff --git a/src/test/thrift/test.thrift b/src/test/thrift/test.thrift new file mode 100644 index 000000000..8fc343734 --- /dev/null +++ b/src/test/thrift/test.thrift @@ -0,0 +1,5 @@ +namespace java com.nicta.scoobi.io.thrift + +struct MyThrift { + 1: string entity; +} \ No newline at end of file