How to get rows where at least two distinct values are in a column?

Time:01-05

I have a log file, and I want to report the IP addresses that initiated connections using at least two different protocols, along with those protocols. I'm trying to get these results with both the DataFrame API and Spark SQL.

Here is a sample of my data:

+----------------+--------+--------+---------------+--------------+---------+-------------+------+-----+
|       Timestamp|Duration|Protocol|BytesOriginator|ResponderBytes|LocalHost|   RemoteHost| State|Flags|
+----------------+--------+--------+---------------+--------------+---------+-------------+------+-----+
|748162802.427995| 1.24383|    smtp|              ?|             ?|        1| 128.97.154.3|   REJ|    L|
|748162802.803033| 3.96513|    smtp|           1173|           328|        3|  128.8.142.5|    SF| null|
|748162804.817224| 1.02839|    nntp|             58|           129|        2|   140.98.2.1|    SF|    L|
|748162812.254572| 138.168|    nntp|         363238|          1200|        4| 128.49.4.103|    SF|    L|
|748162817.478016| 10.0858|    nntp|            230|           100|        4| 128.32.133.1|    SF|    N|
|748162833.453963| 2.16477|    smtp|           2524|           306|        5|192.48.232.17|    SF| null|
|748162836.735788| 13.1779|    smtp|          16479|           174|       16| 128.233.1.12|RSTRS3|    L|
|748162839.930331| 6.69767|    smtp|           3104|           371|        8|   139.91.1.1|    SF|    L|
|748162841.854151| 2.07407|    smtp|           1172|           380|        6|  128.8.142.5|    SF| null|
|748162854.814153| 131.659|    nntp|         319292|          1220|        4| 128.110.4.25|    SF|    L|
|748162866.207165| 51.8406|    nntp|         135714|           280|        4| 128.110.4.25|    SF| null|
|748162866.600750|0.402045|    smtp|              ?|             ?|        1| 128.97.154.3|   REJ|    L|
|748162869.790751| 172.363|    smtp|              0|             0|       16|132.230.6.100|    SF|    L|
|748162873.491682|  102.88|    nntp|            346|           180|        4| 128.32.136.1|    SF|   LN|
|748162875.237378| 5.32943|    nntp|             90|            85|        4| 128.32.133.1|    SF|    N|
+----------------+--------+--------+---------------+--------------+---------+-------------+------+-----+

I tried filtering my DataFrame, but I keep getting an error, and I'm not sure whether I should use a window function. With Spark SQL I can get the IPs so far, but not the protocols.

Here's what I did:

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, FloatType
import pyspark.sql.functions as F

spark = SparkSession.builder.getOrCreate()

custom_schema = StructType([
    StructField('Timestamp', StringType(), True),
    StructField('Duration', FloatType(), True),
    StructField('Protocol', StringType(), True),
    StructField('BytesOriginator', StringType(), True),
    StructField('ResponderBytes', StringType(), True),
    StructField('LocalHost', StringType(), True),
    StructField('RemoteHost', StringType(), True),
    StructField('State', StringType(), True),
    StructField('Flags', StringType(), True)
])

logs = spark.read.csv('lbl-conn-7.csv', header=False, sep=' ', schema=custom_schema)

# I get an error: countDistinct is an aggregate function, so it has to go
# inside groupBy().agg(), not directly inside filter()
logs.select('RemoteHost', 'Protocol').distinct().filter(F.countDistinct('Protocol') > 1).show()

logs.createOrReplaceTempView("mytable")
# spark.sql is the current entry point (replaces SQLContext(sc).sql)
df = spark.sql("SELECT RemoteHost FROM mytable GROUP BY RemoteHost HAVING COUNT(DISTINCT Protocol) > 1")
# This returns the IPs, but it doesn't show the protocols
df.show()

Answer:

You can group by RemoteHost and collect the set of distinct protocols each host used, then filter the result on the size of that array:

import pyspark.sql.functions as F

logs.groupBy("RemoteHost").agg(
    F.collect_set("Protocol").alias("Protocols")
).filter(
    F.size("Protocols") >= 2
).show()
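As a sanity check on what `collect_set` plus `size` computes, here is a plain-Python sketch of the same grouping logic (not Spark code) over the (RemoteHost, Protocol) pairs from the sample above. The helper name `hosts_with_multiple_protocols` is hypothetical. Note that in these 15 sample rows no RemoteHost actually appears with two different protocols, so the filter keeps nothing on that excerpt alone; the extra `nntp` row for 128.8.142.5 is made up purely to show a match.

```python
from collections import defaultdict

# (RemoteHost, Protocol) pairs copied from the sample rows in the question
SAMPLE = [
    ("128.97.154.3", "smtp"),
    ("128.8.142.5", "smtp"),
    ("140.98.2.1", "nntp"),
    ("128.49.4.103", "nntp"),
    ("128.32.133.1", "nntp"),
    ("192.48.232.17", "smtp"),
    ("128.233.1.12", "smtp"),
    ("139.91.1.1", "smtp"),
    ("128.8.142.5", "smtp"),
    ("128.110.4.25", "nntp"),
    ("128.110.4.25", "nntp"),
    ("128.97.154.3", "smtp"),
    ("132.230.6.100", "smtp"),
    ("128.32.136.1", "nntp"),
    ("128.32.133.1", "nntp"),
]


def hosts_with_multiple_protocols(pairs, min_protocols=2):
    """Hypothetical helper: plain-Python analogue of groupBy + collect_set + size filter."""
    protocols_by_host = defaultdict(set)          # like groupBy("RemoteHost")
    for host, protocol in pairs:
        protocols_by_host[host].add(protocol)     # like collect_set("Protocol")
    return {
        host: sorted(protocols)                   # sorted only for a stable output order
        for host, protocols in protocols_by_host.items()
        if len(protocols) >= min_protocols        # like size("Protocols") >= 2
    }


print(hosts_with_multiple_protocols(SAMPLE))
# → {}

# Hypothetical extra row: 128.8.142.5 also opens an nntp connection
print(hosts_with_multiple_protocols(SAMPLE + [("128.8.142.5", "nntp")]))
# → {'128.8.142.5': ['nntp', 'smtp']}
```

Spark's `collect_set` does not guarantee any element order, which is why the sketch sorts each set before returning it.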

Spark SQL equivalent query:

SELECT  RemoteHost, 
        collect_set(Protocol) AS Protocols
FROM    mytable 
GROUP BY  RemoteHost
HAVING  size(Protocols) >= 2  -- or: COUNT(DISTINCT Protocol) >= 2
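The `HAVING COUNT(DISTINCT ...)` pattern can be tried without a Spark cluster. The sketch below swaps in SQLite from Python's standard library; because `collect_set` and `size` are Spark functions, it assumes `group_concat(DISTINCT Protocol)` as a rough stand-in for `collect_set`. What it illustrates is that the protocols only appear in the output when they are aggregated in the SELECT list, which is what the question's query was missing. The table holds a trimmed sample plus one hypothetical row.

```python
import sqlite3

# SQLite stands in for Spark SQL here; group_concat(DISTINCT ...) plays the
# role of collect_set and returns a comma-separated string instead of an array.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE mytable (RemoteHost TEXT, Protocol TEXT)")
conn.executemany(
    "INSERT INTO mytable VALUES (?, ?)",
    [
        ("128.8.142.5", "smtp"),
        ("128.8.142.5", "smtp"),
        ("128.8.142.5", "nntp"),   # hypothetical row so one host qualifies
        ("128.110.4.25", "nntp"),
        ("128.110.4.25", "nntp"),
        ("139.91.1.1", "smtp"),
    ],
)

query = """
SELECT   RemoteHost,
         group_concat(DISTINCT Protocol) AS Protocols
FROM     mytable
GROUP BY RemoteHost
HAVING   COUNT(DISTINCT Protocol) >= 2
"""
rows = conn.execute(query).fetchall()
print(rows)  # one row for 128.8.142.5; the order inside Protocols is not guaranteed
conn.close()
```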

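If you want to keep all the columns, compute collect_set over a Window partitioned by RemoteHost instead:

import pyspark.sql.functions as F
from pyspark.sql import Window

# Attach each host's set of protocols to every row, keep rows whose host used
# at least two protocols, then drop the helper column
logs.withColumn(
    "Protocols",
    F.collect_set("Protocol").over(Window.partitionBy("RemoteHost"))
).filter(
    F.size("Protocols") >= 2
).drop("Protocols").show()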
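The window version keeps every original row for the qualifying hosts rather than collapsing them to one row per host. A plain-Python sketch of that two-pass idea follows; the helper name `rows_of_multi_protocol_hosts`, the trimmed set of columns, and the third row (including its timestamp) are hypothetical, added so that one host qualifies.

```python
from collections import defaultdict

# Trimmed rows with several columns kept; the third row is hypothetical.
ROWS = [
    {"Timestamp": "748162802.803033", "Protocol": "smtp", "RemoteHost": "128.8.142.5", "State": "SF"},
    {"Timestamp": "748162841.854151", "Protocol": "smtp", "RemoteHost": "128.8.142.5", "State": "SF"},
    {"Timestamp": "748162900.000000", "Protocol": "nntp", "RemoteHost": "128.8.142.5", "State": "SF"},
    {"Timestamp": "748162804.817224", "Protocol": "nntp", "RemoteHost": "140.98.2.1", "State": "SF"},
]


def rows_of_multi_protocol_hosts(rows, min_protocols=2):
    """Hypothetical helper mirroring collect_set(...).over(Window.partitionBy(...))."""
    # Pass 1: build each partition's (host's) set of protocols
    protocols_by_host = defaultdict(set)
    for row in rows:
        protocols_by_host[row["RemoteHost"]].add(row["Protocol"])
    # Pass 2: keep every original row, with all its columns, whose host qualifies;
    # the per-host set is never added to the rows, similar to .drop("Protocols")
    return [row for row in rows if len(protocols_by_host[row["RemoteHost"]]) >= min_protocols]


for row in rows_of_multi_protocol_hosts(ROWS):
    print(row["Timestamp"], row["Protocol"], row["RemoteHost"])
```

All three rows for 128.8.142.5 are printed, while 140.98.2.1 (one protocol) is filtered out.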