xiaoxiaokeke 2019-11-11
logdir = './callbacks' if not os.path.exists(logdir): os.mkdir(logdir) output_model_file = os.path.join(logdir, "xxxx.h5") callbacks = [ tf.keras.callbacks.ModelCheckpoint(output_model_file, save_best_file = True) ] hist = model.fit_generator(xxxxx, callbacks = callbacks)
NotImplementedError: Layers with arguments in `__init__` must override `get_config`.
ValueError: Unknown loss function:loss
ValueError: Unknown layer: xxxlayer
在自定义网络层时重写get_config函数
我们主要看传入__init__接口时有哪些配置参数,然后在get_config内一一的将它们转为字典键值并且返回使用,以Mylayer为例:
class MyLayer(tf.keras.layers.Layer): def __init__(self, num_outputs, name="MyLayer", **kwargs): super(MyLayer, self).__init__(name=name, **kwargs) self.num_outputs = num_outputs def build(self, input_shape): self.kernel = self.add_variable("kernel", shape=[int(input_shape[-1]), self.num_outputs]) super().build(input_shape) def call(self, input): output = tf.matmul(input, self.kernel) return output def get_config(self): config = {"num_outputs":self.num_outputs} base_config = super(Mylayer, self).get_config() return dict(list(base_config.items()) + list(config.items()))
一般来说,父类的config也是需要一并保存的,其中base_config即是父类网络层实现的配置参数,最后把父类及继承类的config组装为字典形式即可解决该问题
然后 在加载模型的时候,建立一个字典,该字典的键是自定义网络层时设定该层的名字,其值为该自定义网络层的类名,该字典将用于加载模型时使用
如果还使用了自定义的loss,则把loss也加到_custom_objects中
_custom_objects = { "Mylayer" : Mylayer, "loss" : Myloss }
最后在load模型的时候把_custom_objects传入
model = tf.keras.models.load_model("path/to/your/model", custom_objects=_custom_objects)