How can I avoid dropping rows with NaNs when using Pandas `where` method?

Time:01-20

I'm running into a problem when using the Pandas `where` method. I use `where` to check the rows of a DataFrame against specific conditions: rows that meet the conditions keep their values, and `where` correctly assigns NaN to the rows that don't. The problem occurs when some rows already contain NaN before `where` runs. Instead of being left intact, these rows are dropped entirely, so my DataFrame is altered in an unexpected and undesirable way. How can I fix this?
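For reference, a minimal sketch of how `where` behaves on its own, using hypothetical toy data (the frame `toy` and its values are made up for illustration): it keeps values where the condition is True and writes NaN where it is False, without dropping any rows, and a row that was already NaN simply stays NaN. (`mask` does the opposite, blanking the rows where the condition is True.)

```python
import numpy as np
import pandas as pd

# Hypothetical toy frame; the first row is already all-NaN
toy = pd.DataFrame(
    {"pos": [np.nan, 1.0, -1.0], "stop": [np.nan, 122.0, 128.0]},
    index=pd.date_range("2021-05-17", periods=3),
)

# Boolean condition on the same index; NaN == 1 evaluates to False
cond = toy["pos"] == 1

# where keeps values where cond is True and writes NaN where it is False
kept = toy.where(cond)
print(kept)
```

All three dates are still present in `kept`; only the values change.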

import numpy as np
import pandas as pd

High = {'High': np.array([126.93000031, 126.98999786, 124.91999817, 127.72000122,
       128.        , 127.94000244, 128.32000732, 127.38999939,
       127.63999939, 125.80000305, 125.34999847, 125.23999786,
       124.84999847, 126.16000366, 126.31999969, 128.46000671,
       127.75      ])}
Low = {'Low': np.array([125.16999817, 124.77999878, 122.86000061, 125.09999847,
       125.20999908, 125.94000244, 126.31999969, 126.41999817,
       125.08000183, 124.55000305, 123.94000244, 124.05000305,
       123.12999725, 123.84999847, 124.83000183, 126.20999908,
       126.51999664])}       
Close = {'Close': np.array([126.26999664, 124.84999847, 124.69000244, 127.30999756,
       125.43000031, 127.09999847, 126.90000153, 126.84999847,
       125.27999878, 124.61000061, 124.27999878, 125.05999756,
       123.54000092, 125.88999939, 125.90000153, 126.73999786,
       127.12999725])}        
index = pd.date_range(start = '2021-05-17', periods = 17)
df = pd.DataFrame(dict(High, **Low, **Close), index = index)

pos = {'pos': np.array([np.nan, np.nan,  1.,  1., -1., -1., -1., -1., -1., -1., -1.,  1.,  1.,
        1.,  1.,  1.,  1.])}
stop = {'stop': np.array([         np.nan,          np.nan, 122.86000061, 122.86000061,
       128.        , 128.        , 128.        , 128.        ,
       128.        , 125.80000305, 125.80000305, 124.05000305,
       124.05000305, 123.84999847, 123.84999847, 123.84999847,
       123.84999847])}
s = pd.DataFrame(dict(pos, **stop), index = index)

grouped = s.groupby(['pos','stop'])
grouped1 = grouped.apply(
    lambda g: g.where(
    (s['pos'] == 1) & (s['stop'] <= df['Low']) |
    (s['pos'] == -1) & (s['stop'] >= df['High']) 
    ))

s

            pos        stop
2021-05-17  NaN         NaN
2021-05-18  NaN         NaN
2021-05-19  1.0  122.860001
2021-05-20  1.0  122.860001
2021-05-21 -1.0  128.000000
2021-05-22 -1.0  128.000000
2021-05-23 -1.0  128.000000
2021-05-24 -1.0  128.000000
2021-05-25 -1.0  128.000000
2021-05-26 -1.0  125.800003
2021-05-27 -1.0  125.800003
2021-05-28  1.0  124.050003
2021-05-29  1.0  124.050003
2021-05-30  1.0  123.849998
2021-05-31  1.0  123.849998
2021-06-01  1.0  123.849998
2021-06-02  1.0  123.849998

grouped1

            pos        stop
2021-05-19  1.0  122.860001
2021-05-20  1.0  122.860001
2021-05-21 -1.0  128.000000
2021-05-22 -1.0  128.000000
2021-05-23  NaN         NaN
2021-05-24 -1.0  128.000000
2021-05-25 -1.0  128.000000
2021-05-26 -1.0  125.800003
2021-05-27 -1.0  125.800003
2021-05-28  1.0  124.050003
2021-05-29  NaN         NaN
2021-05-30  1.0  123.849998
2021-05-31  1.0  123.849998
2021-06-01  1.0  123.849998
2021-06-02  1.0  123.849998

The problem is that grouped1 is now missing the first two rows of s, the ones at 2021-05-17 and 2021-05-18. Am I misunderstanding something about the where method, or is this a bug? What's the best alternative for producing the desired result below?

grouped1
            pos        stop
2021-05-17  NaN         NaN
2021-05-18  NaN         NaN
2021-05-19  1.0  122.860001
2021-05-20  1.0  122.860001
2021-05-21 -1.0  128.000000
2021-05-22 -1.0  128.000000
2021-05-23  NaN         NaN
2021-05-24 -1.0  128.000000
2021-05-25 -1.0  128.000000
2021-05-26 -1.0  125.800003
2021-05-27 -1.0  125.800003
2021-05-28  1.0  124.050003
2021-05-29  NaN         NaN
2021-05-30  1.0  123.849998
2021-05-31  1.0  123.849998
2021-06-01  1.0  123.849998
2021-06-02  1.0  123.849998

Answer:

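The rows are not removed by where itself but by groupby, which by default leaves out any row whose group key contains NaN, so those rows never reach the function passed to apply. One workaround is to fill the NaNs with a value you would never otherwise get, such as -999, before grouping. The filled rows then form their own group, and since the condition is still evaluated on the original s (where any comparison with NaN is False), they fail it and are filled with NaN in the resulting grouped1 DataFrame: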

# groupby drops rows whose keys contain NaN, so replace NaN with a sentinel first
grouped = s.fillna(-999).groupby(['pos','stop'])
grouped1 = grouped.apply(
    lambda g: g.where(
        (s['pos'] == 1) & (s['stop'] <= df['Low']) |
        (s['pos'] == -1) & (s['stop'] >= df['High'])
    ))

Result:

>>> grouped1
            pos        stop
2021-05-17  NaN         NaN
2021-05-18  NaN         NaN
2021-05-19  1.0  122.860001
2021-05-20  1.0  122.860001
2021-05-21 -1.0  128.000000
2021-05-22 -1.0  128.000000
2021-05-23  NaN         NaN
2021-05-24 -1.0  128.000000
2021-05-25 -1.0  128.000000
2021-05-26 -1.0  125.800003
2021-05-27 -1.0  125.800003
2021-05-28  1.0  124.050003
2021-05-29  NaN         NaN
2021-05-30  1.0  123.849998
2021-05-31  1.0  123.849998
2021-06-01  1.0  123.849998
2021-06-02  1.0  123.849998
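Since the condition is built from the whole of s and df and nothing in it depends on the group, the groupby may not be needed at all. A minimal sketch of that alternative (a swapped-in approach, not the workaround above), using the first seven rows of the question's data: calling where directly on s keeps every row, and the rows whose pos is NaN come out as NaN because comparisons with NaN evaluate to False.

```python
import numpy as np
import pandas as pd

# First seven rows of the question's data (2021-05-17 .. 2021-05-23)
index = pd.date_range(start="2021-05-17", periods=7)
df = pd.DataFrame(
    {
        "High": [126.93000031, 126.98999786, 124.91999817, 127.72000122,
                 128.0, 127.94000244, 128.32000732],
        "Low": [125.16999817, 124.77999878, 122.86000061, 125.09999847,
                125.20999908, 125.94000244, 126.31999969],
    },
    index=index,
)
s = pd.DataFrame(
    {
        "pos": [np.nan, np.nan, 1.0, 1.0, -1.0, -1.0, -1.0],
        "stop": [np.nan, np.nan, 122.86000061, 122.86000061,
                 128.0, 128.0, 128.0],
    },
    index=index,
)

# Row-wise condition; any comparison involving NaN evaluates to False
cond = ((s["pos"] == 1) & (s["stop"] <= df["Low"])) | (
    (s["pos"] == -1) & (s["stop"] >= df["High"])
)

# No groupby: where keeps every row of s and blanks those where cond is False
result = s.where(cond)
print(result)
```

Applied to the full 17-row data from the question, the same two steps (build `cond`, then `s.where(cond)`) should reproduce the desired output shown there, including NaN at 2021-05-17, 2021-05-18, 2021-05-23 and 2021-05-29.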

For those interested: I initially thought that s.groupby(['pos','stop'], dropna=False) would handle the NaNs, because

for _, df_group in s.groupby(['pos','stop'], dropna=False): print(df_group)

displays all of the groups, including the NaN one. However, as soon as you add .apply, the rows with NaN are dropped again. For example, none of the NaN rows come through when you run:

s.groupby(['pos','stop'], dropna=False).apply(lambda g: g)

I would have expected this to return all rows, including those with NaN. My guess is that this is related to np.nan != np.nan, so the NaN group is somehow lost when .apply combines the results.
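The grouping half of that observation can be checked directly. A small sketch on hypothetical toy data, assuming pandas 1.1 or later (where `dropna` was added to `groupby`): by default, rows whose key contains NaN belong to no group, which is why they never reach the function passed to apply; with `dropna=False` they form a group of their own. It does not try to reproduce the apply behaviour described above, which may depend on the pandas version.

```python
import numpy as np
import pandas as pd

# Hypothetical toy version of s: two rows with NaN keys, two ordinary rows
s = pd.DataFrame(
    {"pos": [np.nan, np.nan, 1.0, -1.0], "stop": [np.nan, np.nan, 122.86, 128.0]}
)

# Default (dropna=True): rows whose key contains NaN are left out of every group
default_keys = [key for key, _ in s.groupby(["pos", "stop"])]
rows_default = int(s.groupby(["pos", "stop"]).size().sum())

# dropna=False: the (NaN, NaN) key becomes a group of its own
all_keys = [key for key, _ in s.groupby(["pos", "stop"], dropna=False)]
rows_all = int(s.groupby(["pos", "stop"], dropna=False).size().sum())

print(default_keys, rows_default)
print(all_keys, rows_all)
```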
