# zl程序教程 (source site)

# 您现在的位置是：首页 > 其它 (You are here: Home > Other)

# 当前栏目 (Current section)

# TensorFlow weight pruning

# TensorFlow weight
# 2023-09-14 09:10:02 时间 (published)
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from tensorflow.keras import layers
import numpy as np
from tensorflow.keras.layers import Conv2D,MaxPool2D,Flatten,Dense,Dropout

# Local copy of the MNIST archive: an .npz holding x_train/y_train/x_test/y_test.
path = "/home/qjm/Downloads/mnist.npz"
NUM_CLASSES = 10  # one output column per digit class

# np.load on an .npz returns an NpzFile that keeps the file open; use it as a
# context manager so the handle is released once the arrays have been read.
with np.load(path) as f:
    x_train, y_train = f['x_train'], f['y_train']
    x_test, y_test = f['x_test'], f['y_test']

# Convert to float32, divide by 255 (raw pixels are presumably 0-255), and add
# a trailing channel axis so the images match the Conv2D input (28, 28, 1).
x_train = (x_train.astype("float32") / 255.0).reshape(-1, 28, 28, 1)
x_test = (x_test.astype("float32") / 255.0).reshape(-1, 28, 28, 1)

# One-hot encode the integer labels (float64, shape (n, NUM_CLASSES)) for
# categorical_crossentropy. Indexing an identity matrix sizes the result from
# the data itself instead of hard-coding 60000/10000 rows.
y_train = np.eye(NUM_CLASSES)[y_train]
y_test = np.eye(NUM_CLASSES)[y_test]

for arr in (x_train, y_train, x_test, y_test):
    print(arr.shape)

# CNN: two (5x5 conv -> 3x3 stride-2 max-pool) stages, then dense layers
# ending in a 10-way softmax. Convs have no bias (use_bias=False).
model = tf.keras.models.Sequential([
    # 28x28x1 -> 28x28x6 ('same' padding, stride 1 keeps spatial size)
    Conv2D(filters=6, kernel_size=5, strides=(1, 1), padding='same',
           activation='relu', use_bias=False, input_shape=(28, 28, 1)),
    # 28x28x6 -> 14x14x6
    MaxPool2D(pool_size=(3, 3), strides=2, padding="same"),
    # 14x14x6 -> 14x14x16
    Conv2D(filters=16, kernel_size=5, strides=(1, 1), padding='same',
           activation='relu', use_bias=False),
    # 14x14x16 -> 7x7x16
    MaxPool2D(pool_size=(3, 3), strides=2, padding="same"),
    # 7x7x16 -> 784 features. No input_shape here: only the first layer's
    # input_shape is used, and the real input is (7, 7, 16), not (7, 7).
    Flatten(),
    Dense(120, activation='relu'),
    Dense(84, activation='relu'),
    Dropout(0.2),
    Dense(10, activation='softmax'),
])
  
# Polynomial sparsity ramp: 0% sparsity until step 200, rising to 50% by
# step 4000.
_sparsity_kwargs = dict(
    initial_sparsity=0.0,
    final_sparsity=0.5,
    begin_step=200,
    end_step=4000,
)
pruning_schedule = tfmot.sparsity.keras.PolynomialDecay(**_sparsity_kwargs)

# Wrap the model so magnitude pruning follows the schedule above.
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(
    model,
    pruning_schedule=pruning_schedule,
)

# Labels are one-hot vectors, hence categorical (not sparse) cross-entropy.
model_for_pruning.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# TensorBoard log directory for the pruning summaries.
log_dir = '/home/qjm/Desktop/model'
callbacks = [
    # tfmot requires this callback when training a prune_low_magnitude model.
    tfmot.sparsity.keras.UpdatePruningStep(),
    # Log sparsity and other metrics in TensorBoard.
    tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir),
    # Watches val_loss, so fit() must be given validation data.
    # NOTE(review): patience=10 with epochs=10 means this can never stop
    # training early; lower patience if early stopping is actually wanted.
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10),
]
# Hold out 10% of the training data so `val_loss` exists; without any
# validation data Keras only warns and EarlyStopping never acts.
model_for_pruning.fit(
    x_train,
    y_train,
    epochs=10,
    validation_split=0.1,
    callbacks=callbacks,
)
model_for_pruning.evaluate(x_test, y_test, verbose=2)
# summary() prints the table itself and returns None, so call it directly
# (wrapping it in print() emitted a stray "None").
model_for_pruning.summary()

# Remove the pruning wrappers so get_weights() yields each layer's own pruned
# kernel rather than the wrapper's extra mask/threshold/step variables.
stripped_model = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

# Indices 0 and 2 are the two Conv2D layers of the Sequential model. For each
# weight array print its shape and the fraction of entries that are zero.
for layer_index in (0, 2):
    weight = stripped_model.get_layer(index=layer_index).get_weights()
    for w in weight:
        print(w.shape)
        print(1 - np.count_nonzero(w) / w.size)