PyTorch features and classes from .npy files

Time:02-01

I am a complete beginner moving from TensorFlow to PyTorch. In TensorFlow, I can simply load features and labels from separate .npy files and train a CNN on them. It is as simple as this:

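import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Dense, Dropout, Flatten, Input
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical

def finetune_resnet(file_train_classes, file_train_features, name_model_to_save):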

    # Let's load features and classes first
    print("Loading, organizing and pre-processing features")
    num_classes = 12
    x_train=np.load(file_train_features)
    y_train=np.load(file_train_classes)

    #Defining train as 70% and validation 30% of the data  
    #The partition is stratified with a fixed random state
    #Therefore, for all networks, the partition will be the same
    x_train, x_validation, y_train, y_validation = train_test_split(x_train, y_train, test_size=0.30, stratify=y_train, random_state=42)

    print("transforming to categorical")
    y_train = to_categorical(y_train, num_classes)
    y_validation = to_categorical(y_validation, num_classes)

    y_train= tf.constant(y_train, shape=[y_train.shape[0], num_classes])
    y_validation= tf.constant(y_validation, shape=[y_validation.shape[0], num_classes])

    print("preprocessing data")
    #Preprocessing data
    x_train = x_train.astype('float32')
    x_validation=x_validation.astype('float32')
    x_train /= 255.
    x_validation /= 255.
    
    print("Setting up the network")
    #Parameters for network training
    batch_size = 32
    epochs=300
    sgd = SGD(learning_rate=0.01)  # `lr` is deprecated in TF 2.x
    trainAug = ImageDataGenerator(rotation_range=30,zoom_range=0.15,width_shift_range=0.2,height_shift_range=0.2,shear_range=0.15,horizontal_flip=True,fill_mode="nearest")
    print("Compiling the network")
    #Load model and prepare it for fine tuning
    baseModel = ResNet50(weights="imagenet", include_top=False,
                         input_tensor=Input(shape=(224, 224, 3)))
    # construct the head of the model that will be placed on top of
    # the base model
    headModel = baseModel.output
    headModel = Flatten(name="flatten")(headModel)
    headModel = Dense(512, activation="relu")(headModel)
    headModel = Dropout(0.5)(headModel)
    headModel = Dense(num_classes, activation="softmax")(headModel)

    # place the head FC model on top of the base model (this will become
    # the actual model we will train)
    model = Model(inputs=baseModel.input, outputs=headModel)
    model.compile(loss="categorical_crossentropy", optimizer=sgd, metrics=["accuracy"])

    trainAug.fit(x_train)
    # Fit the model on the batches generated by datagen.flow().
    print("[INFO] training head...")
    H = model.fit(trainAug.flow(x_train, y_train, batch_size=batch_size),
                  steps_per_epoch=x_train.shape[0] // batch_size, epochs=epochs,
                  validation_data=(x_validation, y_validation))

However, I have no idea how to do the same in PyTorch: load the training and testing data from .npy files, then train and evaluate on them. I checked a tutorial that loads training data from image folders, which is not what I want.

How can I train and test a ResNet-50 model in PyTorch, starting from ImageNet weights, with the training and test data loaded from .npy files?

P.S.: most PyTorch training loops take a <class 'torch.utils.data.dataloader.DataLoader'> as input. Is it possible to turn my training data, stored as NumPy arrays, into that format?

P.S.: you can try with my data here

Answer:

It seems like you need to create a custom Dataset.

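import numpy as np
import torch

class MyDataSet(torch.utils.data.Dataset):
  def __init__(self, x, y):
    super(MyDataSet, self).__init__()
    # x, y are the loaded arrays, e.g. np.load(file_train_features), np.load(file_train_classes)
    # store images as tensors: (N, H, W, C) uint8 -> (N, C, H, W) float32 in [0, 1]
    self._x = torch.from_numpy(np.asarray(x, dtype=np.float32) / 255.).permute(0, 3, 1, 2)
    # store labels as int64 class indices, which nn.CrossEntropyLoss expects
    self._y = torch.from_numpy(np.asarray(y, dtype=np.int64))

  def __len__(self):
    # a Dataset must know its size
    return self._x.shape[0]

  def __getitem__(self, index):
    # one image (C, H, W) and its scalar class index
    x = self._x[index]
    y = self._y[index]
    return x, y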

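You can then split MyDataSet into training and validation subsets, e.g. with the function torch.utils.data.random_split. Note that random_split is purely random, not stratified like the train_test_split(..., stratify=y_train) call in the question.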
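A minimal sketch of both options, assuming the labels are available as a 1-D integer array: `random_split` with a seeded generator gives a reproducible 70/30 split, and a hand-written per-class split keeps the class proportions the way `stratify=y_train` does in the question. The helper names are hypothetical; calling sklearn's `train_test_split` on `np.arange(len(y))` with `stratify=y` and wrapping the resulting indices in `Subset` would work as well.

```python
import numpy as np
import torch
from torch.utils.data import Dataset, Subset, random_split


def random_train_val_split(dataset: Dataset, val_fraction: float = 0.3, seed: int = 42):
    """Plain (non-stratified) random split with a fixed seed."""
    n_val = int(round(len(dataset) * val_fraction))
    n_train = len(dataset) - n_val
    generator = torch.Generator().manual_seed(seed)  # same split on every run
    return random_split(dataset, [n_train, n_val], generator=generator)


def stratified_train_val_split(dataset: Dataset, labels, val_fraction: float = 0.3, seed: int = 42):
    """Split each class separately so both subsets keep the class proportions."""
    labels = np.asarray(labels)
    rng = np.random.default_rng(seed)
    train_idx, val_idx = [], []
    for cls in np.unique(labels):
        idx = np.flatnonzero(labels == cls)  # positions of this class
        rng.shuffle(idx)
        n_val = int(round(len(idx) * val_fraction))
        val_idx.extend(idx[:n_val].tolist())
        train_idx.extend(idx[n_val:].tolist())
    return Subset(dataset, train_idx), Subset(dataset, val_idx)
```

Both functions return Dataset objects, so each subset can be passed directly to its own DataLoader (usually with `shuffle=True` for training only).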

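You might also find torch.utils.data.TensorDataset useful: it wraps tensors that share the same first dimension, so simple in-memory data needs no custom Dataset class.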
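For arrays that fit in memory, TensorDataset can replace the custom class entirely. The sketch below assumes the .npy files hold uint8 images shaped (N, H, W, C), as the Keras input shape (224, 224, 3) suggests, plus integer labels; `make_loader` is a hypothetical helper name. PyTorch convolution layers expect channels first, hence the `permute`.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_loader(x: np.ndarray, y: np.ndarray, batch_size: int = 32, shuffle: bool = True) -> DataLoader:
    """Wrap in-memory NumPy arrays in a TensorDataset and a DataLoader."""
    # images: (N, H, W, C) uint8 in [0, 255] -> (N, C, H, W) float32 in [0, 1]
    x_t = torch.from_numpy(x.astype(np.float32) / 255.0).permute(0, 3, 1, 2)
    # labels: int64 class indices (not one-hot)
    y_t = torch.from_numpy(y.astype(np.int64))
    return DataLoader(TensorDataset(x_t, y_t), batch_size=batch_size, shuffle=shuffle)
```

For example, `train_loader = make_loader(np.load(file_train_features), np.load(file_train_classes))`; for validation or test data, `shuffle=False` is the usual choice.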
