V2EX = way to explore
V2EX 是一个关于分享和探索的地方
现在注册
已注册用户请  登录
a421
V2EX  ›  机器学习

请教:在安卓 tf lite 使用训练好的模型,出现 DataType error

  •  
  •   a421 · 2020-06-11 21:57:43 +08:00 · 1539 次点击
    这是一个创建于 1626 天前的主题,其中的信息可能已经有所发展或是发生改变。

    原本的目的是移植一个模型到安卓,遇到问题后,重新做了个简单的模型验证,出现同样的问题。

    python 训练的代码

    
    model = keras.Sequential([keras.layers.Dense(units=1, input_shape=[1])])
    model.compile(optimizer='sgd', loss='mean_squared_error')
    
    xs = np.array([-1.0, 0.0, 1.0, 2.0, 3.0, 4.0], dtype=np.float32)
    ys = np.array([-3.0, -1.0, 0.0, 3.0, 5.0, 7.0], dtype=np.float32)
    
    model.fit(xs, ys, epochs=500)
    keras_file = 'linear.h5'
    keras.models.save_model(model, keras_file)
    

    转换成 .tflite 后,在安卓使用

    Interpreter interpreter = new Interpreter(FileUtil.loadMappedFile(activity, "linear.tflite"));
    interpreter.allocateTensors();
    int probabilityTensorIndex = 0;
    int[] probabilityShape =
            interpreter.getOutputTensor(probabilityTensorIndex).shape(); //
    DataType probabilityDataType = interpreter.getOutputTensor(probabilityTensorIndex).dataType();
    TensorBuffer outputProbabilityBuffer = TensorBuffer.createFixedSize(probabilityShape, probabilityDataType);
    
    int inputTensorIndex = 0;
    DataType inputDataType = interpreter.getInputTensor(inputTensorIndex).dataType();
    int[] inputShape = interpreter.getInputTensor(inputTensorIndex).shape();
    TensorBuffer inputBuffer = TensorBuffer.createFixedSize(inputShape, inputDataType);
    final float[] input = {10};
    inputBuffer.loadArray(input);
    
    interpreter.run(inputBuffer, outputProbabilityBuffer);
    

    报错是

    I/tflite: Initialized TensorFlow Lite runtime.
    E/AndroidRuntime: FATAL EXCEPTION: inference
        Process: com.example.my1application, PID: 26839
        java.lang.IllegalArgumentException: DataType error: cannot resolve DataType of org.tensorflow.lite.support.tensorbuffer.TensorBufferFloat
            at org.tensorflow.lite.Tensor.dataTypeOf(Tensor.java:344)
            at org.tensorflow.lite.Tensor.throwIfTypeIsIncompatible(Tensor.java:397)
            at org.tensorflow.lite.Tensor.getInputShapeIfDifferent(Tensor.java:287)
            at org.tensorflow.lite.NativeInterpreterWrapper.run(NativeInterpreterWrapper.java:137)
            at org.tensorflow.lite.Interpreter.runForMultipleInputsOutputs(Interpreter.java:316)
            at org.tensorflow.lite.Interpreter.run(Interpreter.java:277)
            at com.example.my1application.DisplayMessageActivity$1.run(DisplayMessageActivity.java:114)
            at android.os.Handler.handleCallback(Handler.java:815)
            at android.os.Handler.dispatchMessage(Handler.java:104)
            at android.os.Looper.loop(Looper.java:207)
            at android.os.HandlerThread.run(HandlerThread.java:61)
    
    目前尚无回复
    关于   ·   帮助文档   ·   博客   ·   API   ·   FAQ   ·   实用小工具   ·   959 人在线   最高记录 6679   ·     Select Language
    创意工作者们的社区
    World is powered by solitude
    VERSION: 3.9.8.5 · 23ms · UTC 20:26 · PVG 04:26 · LAX 12:26 · JFK 15:26
    Developed with CodeLauncher
    ♥ Do have faith in what you're doing.