diff --git a/databricksx12/edi.py b/databricksx12/edi.py
index 3b3645f..b9cb5a7 100644
--- a/databricksx12/edi.py
+++ b/databricksx12/edi.py
@@ -58,9 +58,23 @@ def num_transactions(self):
     #
     def transaction_segments(self):
         from databricksx12.transaction import Transaction
-        return [Transaction(self.segments_by_position(i - int(x.element(1)),i+1), self.format_cls, self.fields, self.funcs) for i,x in self.segments_by_name_index("SE")]
-
-
+        return [Transaction(self.segments_by_position(i - int(x.element(1))+1,i+1), self.format_cls, self.fields, self.funcs) for i,x in self.segments_by_name_index("SE")]
+    """
+        Convert entire dataset into consumable row/column format
+        Preserves the following information:
+            *Segment names
+            *Row numbers for hierarchy
+            *Row length for easy query access
+            *Row data with split functionality to easily access row members
+    """
+    def toRows(self):
+        return [{"segment_name": x.segment_name()
+                 ,"segment_length": x.segment_len()
+                 ,"row_number": i
+                 ,"row_data": x.data
+                 ,"segment_element_delim_char": x.format_cls.ELEMENT_DELIM
+                 ,"segment_subelement_delim_char": x.format_cls.SUB_DELIM} for i,x in enumerate(self.data)]
+
     """
         spark dataframe can be built from json
     """
@@ -116,7 +130,7 @@ def element(self, element, sub_element=-1, dne="na/dne"):
     #
     # @returns number of elements in a segment
     #
-    def element_len(self):
+    def segment_len(self):
         return len(self.data.split(self.format_cls.ELEMENT_DELIM))
 
     #
diff --git a/databricksx12/transaction.py b/databricksx12/transaction.py
index 29399fe..845effc 100644
--- a/databricksx12/transaction.py
+++ b/databricksx12/transaction.py
@@ -2,7 +2,7 @@
 from databricksx12.format import *
 
 """
-    Base class for all types of transactions. Each transaction must be a subclass of this
+    Base class for all transactions (ST/SE Segments)
     Building a Spark DataFrame using toJson()
     - @param "funcs" define which functions to use to flatten a transaction.
     Default is to use all "fx_*" definitions
@@ -21,14 +21,23 @@ def __init__(self, segments, delim_cls = AnsiX12Delim, fields = None, funcs = No
             self.funcs = [x for x in dir(self) if x.startswith("fx_") and x not in funcs]
         self.fields = {**fields, **{x[3:]:getattr(self,x)() for x in self.funcs}}
 
+    #
+    # Returns number of claims in the transaction
+    #
     def claim_count(self):
-        pass #TODO
+        return len(self.segments_by_name("CLM"))
 
-"""
-    Parsing 837 for relevant information about a transaction.
+    #
+    # Returns claim level detail
+    #
+    def to_claims(self):
+        """
+        Returns (segment position, raw CLM segment text) pairs, e.g.
+        [(24, 'CLM*ABC11111*1800***22:B:1*Y*A*Y*Y'),
+        (60, 'CLM*ABC111112*984***22:B:1*Y*A*Y*Y'),
+        (94, 'CLM*ABC111113*1353***22:B:1*Y*A*Y*Y'),
+        (118, 'CLM*ABC111114*1968***22:B:1*Y*A*Y*Y')]
+        """
+        from databricksx12.hls.claim import Claim
+        return [(i, x.data) for i,x in self.segments_by_name_index("CLM")]
 
-"""
-class Claim(Transaction):
-
-    def claim_line_count(self):
-        pass #TODO
diff --git a/tests/test_pyspark.py b/tests/test_pyspark.py
index 409e15e..9c4483f 100644
--- a/tests/test_pyspark.py
+++ b/tests/test_pyspark.py
@@ -3,8 +3,20 @@
 
 class TestPyspark(PysparkBaseTest):
 
-    df = spark.read.text("sampledata/837/CC_837I_EDI.txt", wholetext=True)
+    df = spark.read.text("sampledata/837/*txt", wholetext=True)
 
-    def test_spark_df(self):
-        df.withColumn(
+    def test_transaction_count(self):
+        data = (TestPyspark.df.rdd
+                .map(lambda x: x.asDict().get("value"))
+                .map(lambda x: EDI(x))
+                .map(lambda x: {"transaction_count": x.num_transactions()})
+                ).toDF()
+        assert ( data.count() == 4) #4 rows
+        assert ( data.select(sum(data.transaction_count)).collect()[0][0] == 8) #8 ST/SE transactions
+
+    def test_tbd(self):
+        data = (TestPyspark.df.rdd
+                .map(lambda x: x.asDict().get("value"))
+                .map(lambda x: EDI(x))
+                )
 
diff --git a/tests/test_segment.py b/tests/test_segment.py
index 219c38b..51ad0bd 100644
--- a/tests/test_segment.py
+++ b/tests/test_segment.py
@@ -12,7 +12,7 @@ class TestSegment(PysparkBaseTest):
     #
     def test_segment_length(self):
         assert(len(TestSegment.segments) == 66)
-        assert( set([s.element_len() == len(s.data.split("*")) for s in TestSegment.segments]) == {True} )
+        assert( set([s.segment_len() == len(s.data.split("*")) for s in TestSegment.segments]) == {True} )
 
     def test_sub_element_length(self):
         assert(len(TestSegment.segments) == 66)