What I have:
country | sources                      | infer_from_source
--------+------------------------------+-----------------------------------
null    | ["LUX", "CZE", "CHN", "FRA"] | ["FALSE", "TRUE", "FALSE", "TRUE"]
"DEU"   | ["DEU"]                      | ["FALSE"]
What I want after a function:
country | sources                      | infer_from_source                  | inferred_country
--------+------------------------------+------------------------------------+-----------------
null    | ["LUX", "CZE", "CHN", "FRA"] | ["FALSE", "TRUE", "FALSE", "TRUE"] | ["CZE", "FRA"]
"DEU"   | ["DEU"]                      | ["FALSE"]                          | "DEU"
I need a function that does the following: if the country column is null, it extracts the countries from the sources array at the positions where the infer_from_source array holds "TRUE"; otherwise it returns the country value.
I created this function:
from pyspark.sql.types import BooleanType, IntegerType, StringType, FloatType, ArrayType
import pyspark.sql.functions as F

@F.udf
def determine_entity_country(country: str, sources: list,
                             infer_from_source: list):
    # A non-null country is returned unchanged.
    if country:
        return country
    if "TRUE" in infer_from_source:
        # index() gives the position of the first "TRUE" only.
        idx = infer_from_source.index("TRUE")
        return sources[idx]
    return None
But this yields the result below, because .index("TRUE") returns only the index of the first element that matches its argument:
country | sources                      | infer_from_source                  | inferred_country
--------+------------------------------+------------------------------------+-----------------
null    | ["LUX", "CZE", "CHN", "FRA"] | ["FALSE", "TRUE", "FALSE", "TRUE"] | "CZE"
"DEU"   | ["DEU"]                      | ["FALSE"]                          | "DEU"
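The first-match behaviour can be reproduced in plain Python, without Spark. A minimal sketch using the values from the first sample row (the variable names `first_idx` and `first_match` are only for illustration):

```python
# Values copied from the first row of the sample data.
sources = ["LUX", "CZE", "CHN", "FRA"]
infer_from_source = ["FALSE", "TRUE", "FALSE", "TRUE"]

# list.index() stops at the first match, so the second "TRUE" is never seen.
first_idx = infer_from_source.index("TRUE")
first_match = sources[first_idx]
print(first_idx, first_match)
```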
CodePudding user response:
Fixed it! It was simply a matter of using a list comprehension. The UDF now declares an array return type, so the country value is wrapped in a list as well:
@F.udf(returnType=ArrayType(StringType()))
def determine_entity_country(country: str, sources: list,
                             infer_from_source: list):
    # A non-null country wins; wrap it so the column is always array<string>.
    if country:
        return [country]
    if "TRUE" in infer_from_source:
        # Keep every source whose flag is "TRUE", not just the first one.
        return [src for src, flag in zip(sources, infer_from_source)
                if flag == "TRUE"]
    return None
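The selection logic can also be checked in plain Python before it is wrapped in a UDF. This is only a sketch: it assumes the flags are always the strings "TRUE" or "FALSE" and that both lists have the same length, it wraps a non-null country in a list so the result type is uniform, and `pick_flagged` is a hypothetical helper name that does not appear in the code above.

```python
from typing import List, Optional


def pick_flagged(country: Optional[str],
                 sources: List[str],
                 flags: List[str]) -> Optional[List[str]]:
    """Return [country] if it is set, else every source flagged "TRUE"."""
    if country:
        return [country]
    picked = [src for src, flag in zip(sources, flags) if flag == "TRUE"]
    # As in the UDF, a row without any "TRUE" flag yields None.
    return picked or None


print(pick_flagged(None, ["LUX", "CZE", "CHN", "FRA"],
                   ["FALSE", "TRUE", "FALSE", "TRUE"]))
print(pick_flagged("DEU", ["DEU"], ["FALSE"]))
```

Pairing the two lists with zip avoids building an intermediate list of indices.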
CodePudding user response:
You should avoid UDFs whenever the same result can be achieved with Spark built-in functions alone, especially in the case of PySpark UDFs.
Here's another way, using the higher-order functions transform and filter on arrays:
import pyspark.sql.functions as F

df1 = df.withColumn(
    "inferred_country",
    F.when(
        F.col("country").isNotNull(),
        F.array(F.col("country"))
    ).otherwise(
        F.expr("""filter(
            transform(sources, (x, i) -> IF(boolean(infer_from_source[i]), x, null)),
            x -> x is not null
        )""")
    )
)
df1.show()
#+-------+--------------------+--------------------+----------------+
#|country|             sources|   infer_from_source|inferred_country|
#+-------+--------------------+--------------------+----------------+
#|   null|[LUX, CZE, CHN, FRA]|[FALSE, TRUE, FAL...|      [CZE, FRA]|
#|    DEU|               [DEU]|             [FALSE]|           [DEU]|
#+-------+--------------------+--------------------+----------------+
Starting from Spark 3, the lambda function passed to filter can also take the element index, so transform is no longer needed:
df1 = df.withColumn(
    "inferred_country",
    F.when(
        F.col("country").isNotNull(),
        F.array(F.col("country"))
    ).otherwise(
        F.expr("filter(sources, (x, i) -> boolean(infer_from_source[i]))")
    )
)
