If I have the table below:
| a | b | id | m2000 | m2001 | m2002 | m2003 | m2004 | m2005 |
|---|---|---|---|---|---|---|---|---|
| a | world | 1 | 0 | 0 | 1 | 0 | 0 | 1 |
How do I create a new dataframe like the one below? It should check columns m2000 to m2014 and see whether any of these fields are 1, then build the table below, where the day/month parts (10/10 and 12/12) are static. 2002 and 2005 are used because, in the table above, they are the only two columns between m2000 and m2014 that contain 1.
| id | year | yearend |
|---|---|---|
| 1 | 10/10/2002 | 12/12/2005 |
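The rule described above can be sketched in plain Python for a single row. This is a minimal sketch, not Spark code: it assumes a row is represented as a dict and that year columns are named `m` followed by a four-digit year; the helper name `year_range` is hypothetical.

```python
def year_range(row, start_prefix="10/10/", end_prefix="12/12/"):
    """Return (id, year, yearend) for one row, or None if no year column is 1.

    Hypothetical helper: assumes year columns are named 'm' + four digits.
    """
    # Collect the years whose flag column equals 1, in column order
    years = [int(col[1:]) for col, value in row.items()
             if col.startswith("m") and col[1:].isdigit() and value == 1]
    if not years:
        return None
    # The earliest flagged year becomes 'year', the latest becomes 'yearend'
    return (row["id"], f"{start_prefix}{min(years)}", f"{end_prefix}{max(years)}")


row = {"a": "a", "b": "world", "id": "1",
       "m2000": 0, "m2001": 0, "m2002": 1, "m2003": 0, "m2004": 0, "m2005": 1}
print(year_range(row))
```

Note that in this sketch a row with a single 1 gets the same year in both fields; whether such rows should be kept or dropped is a choice the question leaves open.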
Code to create the first dataframe:
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
spark = SparkSession.builder.getOrCreate()
data2 = [("a", "world", "1", 0, 0, 1, 0, 0, 1)]
schema = StructType([
    StructField("a", StringType(), True),
    StructField("b", StringType(), True),
    StructField("id", StringType(), True),
    StructField("m2000", IntegerType(), True),
    StructField("m2001", IntegerType(), True),
    StructField("m2002", IntegerType(), True),
    StructField("m2003", IntegerType(), True),
    StructField("m2004", IntegerType(), True),
    StructField("m2005", IntegerType(), True),
])
df = spark.createDataFrame(data=data2, schema=schema)
df.printSchema()
df.show(truncate=False)
Answer:
Assume a more complete dataframe that also includes a row with no year set to 1 and a row with more than two 1s:
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
spark = SparkSession.builder.getOrCreate()
data2 = [("a", "world", "1", 0, 0, 1, 0, 0, 1),
         ("b", "world", "2", 0, 1, 0, 1, 0, 1),
         ("c", "world", "3", 0, 0, 0, 0, 0, 0)]
schema = StructType([
    StructField("a", StringType(), True),
    StructField("b", StringType(), True),
    StructField("id", StringType(), True),
    StructField("m2000", IntegerType(), True),
    StructField("m2001", IntegerType(), True),
    StructField("m2002", IntegerType(), True),
    StructField("m2003", IntegerType(), True),
    StructField("m2004", IntegerType(), True),
    StructField("m2005", IntegerType(), True),
])
df = spark.createDataFrame(data=data2, schema=schema)
df.printSchema()
df.show(truncate=False)
|   | a | b | id | m2000 | m2001 | m2002 | m2003 | m2004 | m2005 |
|---|---|---|---|---|---|---|---|---|---|
| 0 | a | world | 1 | 0 | 0 | 1 | 0 | 0 | 1 |
| 1 | b | world | 2 | 0 | 1 | 0 | 1 | 0 | 1 |
| 2 | c | world | 3 | 0 | 0 | 0 | 0 | 0 | 0 |
For convenience, I convert your dataframe to pandas, but the logic only uses simple loops that you can port to Spark.
import pandas as pd
pandas_df = df.toPandas()
Retrieve the list of year columns, skipping the first three columns (a, b and id):
years = list(pandas_df.columns)[3:]
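Slicing with `[3:]` relies on the year columns coming after exactly three other columns. If the column order might change, selecting the columns by name is more robust. A minimal sketch, assuming every year column is named `m` followed by four digits; `year_columns` is a hypothetical helper name:

```python
import re

# Assumed naming convention: 'm' followed by exactly four digits, e.g. 'm2003'
YEAR_COL = re.compile(r"m\d{4}")


def year_columns(columns):
    """Return the year-flag column names, keeping their original order."""
    return [c for c in columns if YEAR_COL.fullmatch(c)]


cols = ["a", "b", "id", "m2000", "m2001", "m2002", "m2003", "m2004", "m2005"]
print(year_columns(cols))
```

With the dataframe above this would be used as `years = year_columns(pandas_df.columns)`.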
Finally, the following code builds the required dataframe (see the inline comments):
tmp_df_data_list = []
# iterate over the rows of the dataframe
for _, row in pandas_df.iterrows():
    flagged_years = []
    # for each year column, check whether its value is 1
    for y in years:
        if row[y]:  # if it is 1, keep the column name
            flagged_years.append(y)
    if len(flagged_years) >= 2:
        # take the first occurrence as 'year' and the last as 'yearend', dropping the leading 'm'
        min_year = flagged_years[0][1:]
        max_year = flagged_years[-1][1:]
        tmp_df_data_list.append([row.id, '10/10/' + min_year, '12/12/' + max_year])
res_df = pd.DataFrame(tmp_df_data_list, columns=['id', 'year', 'yearend'])
Output will be:
|   | id | year | yearend |
|---|---|---|---|
| 0 | 1 | 10/10/2002 | 12/12/2005 |
| 1 | 2 | 10/10/2001 | 12/12/2005 |
Answer:
We can use PySpark's native functions to build an array of the years whose column value is 1. That array then gives the minimum and maximum years, which are prefixed with the static day and month.
Here's an example:
data_ls = [
("a", "world", "1", 0, 0, 1,0,0,1),
("b", "world", "2", 0, 1, 0,1,0,1),
("c", "world", "3", 0, 0, 0,0,0,0)
]
data_sdf = spark.sparkContext.parallelize(data_ls). \
toDF(['a', 'b', 'id', 'm2000', 'm2001', 'm2002', 'm2003', 'm2004', 'm2005'])
# +---+-----+---+-----+-----+-----+-----+-----+-----+
# |  a|    b| id|m2000|m2001|m2002|m2003|m2004|m2005|
# +---+-----+---+-----+-----+-----+-----+-----+-----+
# |  a|world|  1|    0|    0|    1|    0|    0|    1|
# |  b|world|  2|    0|    1|    0|    1|    0|    1|
# |  c|world|  3|    0|    0|    0|    0|    0|    0|
# +---+-----+---+-----+-----+-----+-----+-----+-----+
from pyspark.sql import functions as func
yearcols = [k for k in data_sdf.columns if k.startswith('m20')]
data_sdf. \
    withColumn('yearcol_structs',
               func.array(*[func.struct(func.lit(int(c[-4:])).alias('year'), func.col(c).alias('value'))
                            for c in yearcols]
                          )
               ). \
    withColumn('yearcol_1s',
               func.expr('transform(filter(yearcol_structs, x -> x.value = 1), f -> f.year)')
               ). \
    filter(func.size('yearcol_1s') >= 1). \
    withColumn('year_start', func.concat(func.lit('10/10/'), func.array_min('yearcol_1s'))). \
    withColumn('year_end', func.concat(func.lit('12/12/'), func.array_max('yearcol_1s'))). \
    show(truncate=False)
# +---+-----+---+-----+-----+-----+-----+-----+-----+------------------------------------------------------------------+------------------+----------+----------+
# |a  |b    |id |m2000|m2001|m2002|m2003|m2004|m2005|yearcol_structs                                                   |yearcol_1s        |year_start|year_end  |
# +---+-----+---+-----+-----+-----+-----+-----+-----+------------------------------------------------------------------+------------------+----------+----------+
# |a  |world|1  |0    |0    |1    |0    |0    |1    |[{2000, 0}, {2001, 0}, {2002, 1}, {2003, 0}, {2004, 0}, {2005, 1}]|[2002, 2005]      |10/10/2002|12/12/2005|
# |b  |world|2  |0    |1    |0    |1    |0    |1    |[{2000, 0}, {2001, 1}, {2002, 0}, {2003, 1}, {2004, 0}, {2005, 1}]|[2001, 2003, 2005]|10/10/2001|12/12/2005|
# +---+-----+---+-----+-----+-----+-----+-----+-----+------------------------------------------------------------------+------------------+----------+----------+
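To see what each step of the Spark chain computes, here is a plain-Python sketch that mirrors it for one row: build (year, value) pairs, keep the pairs whose value is 1, project the years, drop rows with no 1s, then take the minimum and maximum. It only illustrates the logic and does not use Spark; the function name `flagged_year_bounds` is hypothetical.

```python
def flagged_year_bounds(row, yearcols):
    """Plain-Python analogue of the array/filter/transform/array_min/array_max steps."""
    # Like func.array(func.struct(year, value) ...): one (year, value) pair per column
    yearcol_structs = [(int(c[-4:]), row[c]) for c in yearcols]
    # Like transform(filter(yearcol_structs, x -> x.value = 1), f -> f.year)
    yearcol_1s = [year for year, value in yearcol_structs if value == 1]
    # Like filter(func.size('yearcol_1s') >= 1): rows with no 1s are dropped
    if len(yearcol_1s) < 1:
        return None
    # Like func.array_min / func.array_max
    return min(yearcol_1s), max(yearcol_1s)


yearcols = ["m2000", "m2001", "m2002", "m2003", "m2004", "m2005"]
rows = [
    {"id": "1", "m2000": 0, "m2001": 0, "m2002": 1, "m2003": 0, "m2004": 0, "m2005": 1},
    {"id": "2", "m2000": 0, "m2001": 1, "m2002": 0, "m2003": 1, "m2004": 0, "m2005": 1},
    {"id": "3", "m2000": 0, "m2001": 0, "m2002": 0, "m2003": 0, "m2004": 0, "m2005": 0},
]
for r in rows:
    print(r["id"], flagged_year_bounds(r, yearcols))
```

Row 3 has no 1s, so it returns `None`, matching how the `size >= 1` filter removes it from the Spark result.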
