diff --git a/src/clojure/zero_one/geni/core.clj b/src/clojure/zero_one/geni/core.clj
index a3b6b1d9..61a08854 100644
--- a/src/clojure/zero_one/geni/core.clj
+++ b/src/clojure/zero_one/geni/core.clj
@@ -354,8 +354,10 @@
 
 (import-vars
  [zero-one.geni.core.dataset-creation
+  array-type
   create-dataframe
   map->dataset
+  map-type
   records->dataset
   struct-field
   struct-type
@@ -467,6 +469,7 @@
 
 (import-vars
  [zero-one.geni.core.data-sources
+  ->kebab-columns
   read-avro!
   read-csv!
   read-edn!
diff --git a/src/clojure/zero_one/geni/core/data_sources.clj b/src/clojure/zero_one/geni/core/data_sources.clj
index 0352ed2c..c4ad3db7 100644
--- a/src/clojure/zero_one/geni/core/data_sources.clj
+++ b/src/clojure/zero_one/geni/core/data_sources.clj
@@ -33,7 +33,7 @@
 (defn remove-punctuations [string]
   (string/replace string #"[.,\/#!$%\^&\*;:{}=\`~()°]" ""))
 
-(defn normalise-column-names [dataset]
+(defn ->kebab-columns [dataset]
   (let [new-columns (->> dataset
                          .columns
                          (map remove-punctuations)
@@ -50,7 +50,7 @@
                    (cond-> (not (nil? schema))
                      (.schema (dataset-creation/->schema schema))))]
     (-> (.load reader path)
-        (cond-> (:kebab-columns options) normalise-column-names))))
+        (cond-> (:kebab-columns options) ->kebab-columns))))
 
 (defmulti read-avro! (fn [head & _] (class head)))
 (defmethod read-avro! :default
@@ -190,4 +190,4 @@
                      edn/read-string
                      (dataset-creation/records->dataset spark))]
      (-> dataset
-         (cond-> (:kebab-columns options) normalise-column-names)))))
+         (cond-> (:kebab-columns options) ->kebab-columns)))))
diff --git a/src/clojure/zero_one/geni/core/dataset_creation.clj b/src/clojure/zero_one/geni/core/dataset_creation.clj
index 033a16f2..5155d1f7 100644
--- a/src/clojure/zero_one/geni/core/dataset_creation.clj
+++ b/src/clojure/zero_one/geni/core/dataset_creation.clj
@@ -3,7 +3,7 @@
    [zero-one.geni.defaults]
    [zero-one.geni.interop :as interop])
   (:import
-   (org.apache.spark.sql.types ArrayType DataTypes)
+   (org.apache.spark.sql.types ArrayType DataType DataTypes)
    (org.apache.spark.ml.linalg VectorUDT
                                DenseVector
                                SparseVector)))
@@ -28,18 +28,56 @@
    nil        DataTypes/NullType})
 
+(defn- ->spark-type
+  "Returns `data-type` unchanged when it is already a Spark DataType instance,
+  otherwise looks it up in `data-type->spark-type`."
+  [data-type]
+  (if (instance? DataType data-type)
+    data-type
+    (data-type->spark-type data-type)))
+
 (defn struct-field [col-name data-type nullable]
-  (let [spark-type (data-type->spark-type data-type)]
+  (let [spark-type (->spark-type data-type)]
     (DataTypes/createStructField (name col-name) spark-type nullable)))
 
 (defn struct-type [& fields]
   (DataTypes/createStructType fields))
 
+(defn array-type
+  "Creates a Spark array type. `val-type` may be a Spark DataType instance
+  (e.g. the result of `struct-type`) or a value understood by
+  `data-type->spark-type`."
+  [val-type nullable]
+  (DataTypes/createArrayType (->spark-type val-type) nullable))
+
+(defn map-type
+  "Creates a Spark map type. `key-type` and `val-type` may each be a Spark
+  DataType instance or a value understood by `data-type->spark-type`."
+  [key-type val-type]
+  (DataTypes/createMapType (->spark-type key-type) (->spark-type val-type)))
+
-(defn ->schema [value]
+(defn ->schema
+  "Converts a data-oriented schema description into a Spark schema:
+
+  - a one-element vector `[v]` becomes an array type of `v`;
+  - a two-element vector `[k v]` becomes a map type from `k` to `v`;
+  - a map becomes a struct type with one nullable field per entry;
+  - any other value is returned unchanged.
+
+  Elements are converted recursively, so descriptions may be nested."
+  [value]
   (cond
-    (map? value) (->> value
-                      (map (fn [[k v]] (struct-field k v true)))
-                      (apply struct-type))
-    :else value))
+    (and (vector? value) (= 1 (count value)))
+    (array-type (->schema (first value)) true)
+
+    (and (vector? value) (= 2 (count value)))
+    (map-type (->schema (first value)) (->schema (second value)))
+
+    (map? value)
+    (->> value
+         (map (fn [[k v]] (struct-field k (->schema v) true)))
+         (apply struct-type))
+
+    :else
+    value))
 
 (defn create-dataframe
   ([rows schema] (create-dataframe @default-spark rows (->schema schema)))
diff --git a/test/zero_one/geni/data_sources_test.clj b/test/zero_one/geni/data_sources_test.clj
index b5f876bf..25e8bb99 100644
--- a/test/zero_one/geni/data_sources_test.clj
+++ b/test/zero_one/geni/data_sources_test.clj
@@ -12,6 +12,50 @@
 (def write-df
   (-> melbourne-df (g/select :Method :Type) (g/limit 5)))
 
+(facts "On data-oriented schema" :schema
+  (let [dummy-df  (-> melbourne-df
+                      (g/limit 2)
+                      g/->kebab-columns
+                      (g/select {:rooms (g/struct :rooms :bathroom)
+                                 :coord (g/array :longtitude :lattitude)
+                                 :prop  (g/map (g/lit "seller") :seller-g
+                                               (g/lit "price") :price)}))
+        temp-file (.toString (create-temp-file! "-complex.parquet"))]
+    (g/write-parquet! dummy-df temp-file {:mode "overwrite"})
+    (fact "correct dataframe baseline"
+      (g/dtypes dummy-df) => {:coord "ArrayType(DoubleType,true)"
+                              :prop  "MapType(StringType,StringType,true)"
+                              :rooms (str "StructType("
+                                          "StructField(rooms,LongType,true), "
+                                          "StructField(bathroom,DoubleType,true))")})
+    (fact "correct direct schema option"
+      (-> (g/read-parquet!
+            temp-file
+            {:schema (g/struct-type
+                       (g/struct-field :rooms
+                                       (g/struct-type
+                                         (g/struct-field :rooms :int true)
+                                         (g/struct-field :bathroom :float true))
+                                       true)
+                       (g/struct-field :coord (g/array-type :long true) true)
+                       (g/struct-field :prop (g/map-type :string :string) true))})
+          g/dtypes) => {:coord "ArrayType(LongType,true)"
+                        :prop  "MapType(StringType,StringType,true)"
+                        :rooms (str "StructType("
+                                    "StructField(rooms,IntegerType,true), "
+                                    "StructField(bathroom,FloatType,true))")})
+    (fact "correct data-oriented schema option"
+      (-> (g/read-parquet!
+            temp-file
+            {:schema {:coord [:short]
+                      :prop  [:string :string]
+                      :rooms {:rooms :float :bathroom :long}}})
+          g/dtypes) => {:coord "ArrayType(ShortType,true)"
+                        :prop  "MapType(StringType,StringType,true)"
+                        :rooms (str "StructType("
+                                    "StructField(rooms,FloatType,true), "
+                                    "StructField(bathroom,LongType,true))")})))
+
 (facts "On schema option" :schema
   (let [csv-path "test/resources/sample_csv_data.csv"
         selected [:InvoiceDate :Price]]