Skip to content

Instantly share code, notes, and snippets.

@alidastgheib
Created February 2, 2019 07:07
Show Gist options
  • Select an option

  • Save alidastgheib/45a9e8be7255fda3898928474c964214 to your computer and use it in GitHub Desktop.

Select an option

Save alidastgheib/45a9e8be7255fda3898928474c964214 to your computer and use it in GitHub Desktop.
import keras.backend as K
def samplewise_tf_dct(x_sample):
print('=======> input of "samplewise_tf_dct"')
print(type(x_sample))
print((x_sample.shape))
print(x_sample)
for idx_channel in range(x_sample.shape[-1]):
x_sample[:, :, idx_channel] = K.spectral.dct(K.transpose(K.spectral.dct(x_sample[:, :, idx_channel]))) # a 2D DCT
return x_sample
def dct_layer_function(x_batch):
print('=======> input of "dct_layer_function"')
print(type(x_batch))
print((x_batch.shape))
print(x_batch)
return K.map_fn(samplewise_tf_dct, x_batch)
#############################################################
#############################################################
from keras.models import Model
from keras import layers
from keras.layers import Lambda
input_of_net = layers.Input(shape=(27, 27, 3), name='input_of_net')
x = layers.Conv2D(32, (3, 3), strides=(2, 2), kernel_initializer='glorot_normal', name='block1_conv1')(input_of_net)
x = layers.BatchNormalization(name='block1_conv1_bn')(x)
x = layers.Activation('relu', name='block1_conv1_act')(x)
dct_layer = Lambda(function = dct_layer_function)
x = dct_layer(x)
#############################################################
#############################################################
=======> input of "dct_layer_function"
<class 'tensorflow.python.framework.ops.Tensor'>
(?, 13, 13, 32)
Tensor("block1_conv1_act_5/Relu:0", shape=(?, 13, 13, 32), dtype=float32)
=======> input of "samplewise_tf_dct"
<class 'tensorflow.python.framework.ops.Tensor'>
(13, 13, 32)
Tensor("lambda_6/map/while/TensorArrayReadV3:0", shape=(13, 13, 32), dtype=float32)
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-17-b2c158189414> in <module>()
11
12 dct_layer = Lambda(function = dct_layer_function)
---> 13 x = dct_layer(x)
<ipython-input-16-225a14073823> in samplewise_tf_dct(x_sample)
8
9 for idx_channel in range(x_sample.shape[-1]):
---> 10 x_sample[:, :, idx_channel] = K.spectral.dct(K.transpose(K.spectral.dct(x_sample[:, :, idx_channel]))) # a 2D DCT
11 return x_sample
12
AttributeError: module 'keras.backend' has no attribute 'spectral'
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment