Stratified sample with design in pandas df

Time:02-02

I have a df with a column, strat, that identifies each row's stratum. I want to loop over the strata and pull rows into a new df, df_sample. If a stratum has few cases (fewer than 21 rows), I want all of its rows; otherwise I want a random 50% of them.

I've tried the code below, and it works, but I wonder whether there is a better solution. For example, calling pd.concat inside the loop may be slow once I switch to the real, much larger data.

import pandas as pd

df = pd.DataFrame({'ID': range(0, 120),
                   'strat': ['A', 'B', 'B', 'A', 'B', 'A', 'D', 'A', 'B', 'C',
                             'A', 'D', 'A', 'A', 'A', 'D', 'F', 'D', 'F', 'C',
                             'B', 'A', 'A', 'C', 'A', 'A', 'B', 'D', 'B', 'C',
                             'C', 'A', 'C', 'A', 'C', 'A', 'D', 'C', 'C', 'A',
                             'B', 'F', 'F', 'C', 'B', 'D', 'A', 'A', 'B', 'B',
                             'A', 'C', 'A', 'A', 'F', 'A', 'A', 'B', 'A', 'D',
                             'C', 'B', 'B', 'A', 'B', 'C', 'B', 'A', 'D', 'B',
                             'B', 'A', 'A', 'C', 'D', 'F', 'F', 'A', 'B', 'C',
                             'F', 'B', 'D', 'A', 'A', 'F', 'B', 'D', 'B', 'A',
                             'F', 'D', 'A', 'A', 'C', 'B', 'B', 'C', 'C', 'B',
                             'F', 'A', 'A', 'B', 'B', 'B', 'F', 'A', 'B', 'C',
                             'A', 'A', 'A', 'B', 'B', 'A', 'A', 'A', 'B', 'B']})
df_sample = pd.DataFrame()

for i in df.strat.unique():
    temp = df[df['strat'] == i]            # all rows of stratum i

    if len(temp) < 21:
        strat = temp.sample(len(temp))     # few cases: keep every row (shuffled)
    else:
        strat = temp.sample(frac=0.5)      # many cases: keep a random 50%

    df_sample = pd.concat([df_sample, strat])
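On the pd.concat concern: each call inside the loop builds a new DataFrame that copies everything accumulated so far. A common alternative, sketched here under assumptions (the toy DataFrame and the random_state values are hypothetical, chosen only so the result is checkable), collects each stratum's sample in a list and concatenates once at the end, keeping the question's thresholds (fewer than 21 rows: keep all; otherwise: 50%).

```python
import pandas as pd

# Hypothetical toy data: stratum 'A' is large, 'B' and 'C' are small
df = pd.DataFrame({'ID': range(40),
                   'strat': ['A'] * 30 + ['B'] * 6 + ['C'] * 4})

parts = []  # one sampled DataFrame per stratum
for _, temp in df.groupby('strat'):
    if len(temp) < 21:
        # Small stratum: keep every row (frac=1 only shuffles them)
        parts.append(temp.sample(frac=1, random_state=0))
    else:
        # Large stratum: keep a random half of the rows
        parts.append(temp.sample(frac=0.5, random_state=0))

# A single concat at the end instead of one per loop iteration
df_sample = pd.concat(parts)
print(df_sample['strat'].value_counts().sort_index())
```

With this data, 'A' contributes 15 rows while 'B' and 'C' keep all 6 and 4 rows respectively.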
    

Answer:

You can group by "strat" and count the entries in each stratum, then identify the strata with fewer than 21 entries and shuffle their rows. Then take the remaining strata (those with more than 20 entries) and sample 50% of each one. Finally, concatenate the two DataFrames:

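# Per-stratum row counts, and the strata with fewer than 21 rows
msk1 = df.groupby('strat')['strat'].count() < 21
less_than_21 = msk1.index[msk1]
# True for every row that belongs to a small stratum
msk2 = df['strat'].isin(less_than_21)
out = pd.concat((df[~msk2].groupby('strat').sample(frac=0.5),  # half of each large stratum
                 df[msk2].sample(msk2.sum())))                 # all rows of the small strata, shuffled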

Output:

      ID strat
110  110     A
72    72     A
46    46     A
31    31     A
92    92     A
..   ...   ...
18    18     F
9      9     C
23    23     C
42    42     F
82    82     D

[82 rows x 2 columns]
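Where the 82 rows come from: counting the question's strat list gives A=43, B=33, C=19, D=13 and F=12 rows (120 in total). C, D and F are kept whole, while A and B are halved. As far as I can tell from the pandas source, sample and groupby.sample turn frac into a row count with Python's built-in round(), which rounds exact halves to the nearest even number (21.5 becomes 22, 16.5 becomes 16). A small sketch of that arithmetic, where expected_rows is a hypothetical helper and not part of pandas:

```python
# Stratum sizes counted from the question's 'strat' list (120 rows in total)
counts = {'A': 43, 'B': 33, 'C': 19, 'D': 13, 'F': 12}

def expected_rows(n, threshold=21, frac=0.5):
    # Hypothetical helper: small strata are kept whole; large ones get
    # round(frac * n) rows, mirroring how pandas converts frac to a count
    return n if n < threshold else round(frac * n)

per_stratum = {k: expected_rows(n) for k, n in counts.items()}
print(per_stratum)                 # {'A': 22, 'B': 16, 'C': 19, 'D': 13, 'F': 12}
print(sum(per_stratum.values()))   # 82
```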

Answer:

Other solutions may be faster. Here is another one in case readability/maintainability is more important.

def sample_stratum(stratum):
    nrows = stratum.shape[0]
    if nrows < 21:
        output = stratum.sample(nrows)
    else:
        output = stratum.sample(frac=0.5)
    return output


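# reset_index(drop=True) discards the (strat, original row) index that apply builds by default;
# leave it out if you need to keep the original row labels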
sampled_df = df.groupby(by=['strat']).apply(sample_stratum).reset_index(drop=True)


#    ID strat
# 0   12     A
# 1    7     A
# 2   50     A
# 3   58     A
# 4    0     A
# ..  ..   ...
# 77  41     F
# 78  42     F
# 79  16     F
# 80  76     F
# 81  90     F
# [82 rows x 2 columns]

Answer:

Create a mask that marks the rows whose stratum has fewer than 21 rows, then process the two parts separately:

m = df.groupby('strat')['strat'].transform('size').lt(21)
df = pd.concat((df[~m].groupby('strat').sample(frac=0.5), 
                df[m].sample(frac=1)),
                ignore_index=True)
print (df)
    ID strat
0   71     A
1   31     A
2   72     A
3   39     A
4   83     A
..  ..   ...
77  37     C
78  85     F
79  19     C
80  34     C
81  73     C

[82 rows x 2 columns]
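All of these answers draw a different random sample on every run. Both DataFrame.sample and GroupBy.sample accept a random_state argument, so one way to make the mask-based approach reproducible is sketched below. The function name stratified_sample, the seed and the toy data are assumptions for illustration.

```python
import pandas as pd

# Hypothetical toy data: 'A' has 30 rows, 'B' has 5
df = pd.DataFrame({'ID': range(35), 'strat': ['A'] * 30 + ['B'] * 5})

def stratified_sample(frame, seed):
    # Hypothetical helper. True for rows whose stratum has fewer than 21 rows
    m = frame.groupby('strat')['strat'].transform('size').lt(21)
    return pd.concat((frame[~m].groupby('strat').sample(frac=0.5, random_state=seed),
                      frame[m].sample(frac=1, random_state=seed)),
                     ignore_index=True)

first = stratified_sample(df, seed=42)
second = stratified_sample(df, seed=42)
print(first.equals(second))  # True: same seed, same rows in the same order
```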

Alternative solution:

m = df['strat'].map(df['strat'].value_counts()).lt(21)
df = pd.concat((df[~m].groupby('strat').sample(frac=0.5), df[m].sample(frac=1)))
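The two mask constructions are interchangeable: transform('size') broadcasts each group's row count back onto its rows, while map(value_counts()) looks up the same count for each row's label. A quick check on hypothetical toy data:

```python
import pandas as pd

# Hypothetical toy data
df = pd.DataFrame({'strat': ['A', 'B', 'A', 'C', 'A', 'B']})

# Group size broadcast back to every row
size_per_row = df.groupby('strat')['strat'].transform('size')
# Same sizes looked up from value_counts(), whose index is the stratum label
count_per_row = df['strat'].map(df['strat'].value_counts())

print(size_per_row.tolist())                # [3, 2, 3, 1, 3, 2]
print((size_per_row == count_per_row).all())
```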