通过 Java API 在 GPU 或 CPU 上调用 TensorFlow 2.0 模型

运行环境设置(环境变量,需在启动 JVM 之前导出)

# Expose only GPU index 1 to processes started from this shell (e.g. the JVM).
export CUDA_VISIBLE_DEVICES=1

指定 GPU 显卡索引:上例只让进程看到编号为 1 的显卡。

# -1 hides every GPU from processes started from this shell, so TensorFlow runs on CPU only.
export CUDA_VISIBLE_DEVICES=-1

设置为 -1 时进程看不到任何 GPU,即不使用 GPU、只用 CPU 运行。

POM

使用 Maven profile 按环境切换 CPU 或 GPU 依赖包:dev(默认激活)引入 CPU 版的 `tensorflow`;prod(通过 `-Pprod` 激活)引入 `libtensorflow` 和 GPU 版 JNI 本地库 `libtensorflow_jni_gpu`。注意这里使用的是 TensorFlow Java 1.15.0 的 API。

<profiles>
    <!-- dev: active by default; pulls the CPU TensorFlow Java artifact. -->
    <profile>
        <id>dev</id>
        <activation>
            <activeByDefault>true</activeByDefault>
        </activation>
        <properties>
            <profileActive>dev</profileActive>
        </properties>
        <dependencies>
            <dependency>
                <groupId>org.tensorflow</groupId>
                <artifactId>tensorflow</artifactId>
                <version>1.15.0</version>
            </dependency>
        </dependencies>
    </profile>

    <!-- prod: enable with -Pprod; Java API plus the GPU JNI native library. -->
    <profile>
        <id>prod</id>
        <properties>
            <profileActive>prod</profileActive>
        </properties>
        <dependencies>
            <dependency>
                <groupId>org.tensorflow</groupId>
                <artifactId>libtensorflow</artifactId>
                <version>1.15.0</version>
            </dependency>

            <dependency>
                <groupId>org.tensorflow</groupId>
                <artifactId>libtensorflow_jni_gpu</artifactId>
                <version>1.15.0</version>
            </dependency>
        </dependencies>
    </profile>
</profiles>

加载模型

CPU 加载(使用默认会话配置):

session = SavedModelBundle
.load(modelConfig.getPath(), modelConfig.getTags())
.session();

GPU 加载,并通过配置项 `gpuIds` 指定使用的 GPU 卡:

// GPU settings: only the configured device ids are visible to this session, memory use is
// capped at 85% per process, and allow_growth lets allocations grow on demand up to that cap.
GPUOptions gpuOptions = GPUOptions.newBuilder()
        .setVisibleDeviceList(modelConfig.getGpuIds())
        .setPerProcessGpuMemoryFraction(0.85f)
        .setAllowGrowth(true)
        .build();

// Soft placement lets ops without a GPU kernel fall back to CPU; device-placement logging
// reports where each op is placed (handy to verify the GPU is used, but verbose).
ConfigProto configProto = ConfigProto.newBuilder()
        .setAllowSoftPlacement(true)
        .setLogDevicePlacement(true)
        .setGpuOptions(gpuOptions)
        .build();

// The session options must be passed as a serialized ConfigProto.
SavedModelBundle gpuBundle = SavedModelBundle.loader(modelConfig.getPath())
        .withTags(modelConfig.getTags())
        .withConfigProto(configProto.toByteArray())
        .load();
session = gpuBundle.session();

模型预测

/**
 * Runs one batch through the model and returns the first output column as doubles.
 *
 * <p>Every name in {@code model.getInputs()} must map to a non-empty batch in
 * {@code featureMap}; otherwise an empty list is returned. Each input is fed under
 * {@code modelConfig.getDefaultInOpPrefix() + inputName}, and the result is fetched from
 * {@code modelConfig.getDefaultOutOp()}.
 *
 * <p>NOTE(review): the output is assumed to be a float tensor of shape [batch, 1] — confirm
 * against the exported model signature.
 *
 * @param featureMap input name to batch of integer feature rows
 * @param model      model descriptor supplying the input names to feed
 * @return one score per output row, or an empty list when an input is missing or empty
 * @throws IllegalArgumentException if the output tensor is a scalar or cannot be copied
 *     into a {@code float[rows][1]} array
 */
private List<Double> innerPredict(Map<String, List<List<Integer>>> featureMap, PredictModel model) {
    Session.Runner runner = getRealSession().runner();
    // Tensors hold native memory; track every one so the finally block can close it.
    List<Tensor<?>> inputTensors = new ArrayList<>();
    List<Tensor<?>> outputTensors = new ArrayList<>();
    try {
        for (String inputName : model.getInputs()) {
            List<List<Integer>> value = featureMap.get(inputName);

            if (value == null) {
                log.warn("Input '{}' is missing", inputName);
                return Collections.emptyList();
            }

            if (value.size() == 0) {
                log.warn("Input '{}' size == 0", inputName);
                return Collections.emptyList();
            }

            Tensor<Integer> inputTensor = toTensor(value);
            inputTensors.add(inputTensor);
            runner.feed(modelConfig.getDefaultInOpPrefix() + inputName, inputTensor);
        }

        outputTensors = runner.fetch(modelConfig.getDefaultOutOp()).run();
        Tensor<?> output = outputTensors.get(0);

        // Size the destination from the actual output instead of the last input's batch size,
        // so differing input batch sizes (or no inputs) cannot produce a wrong-sized copy.
        long[] shape = output.shape();
        if (shape.length == 0) {
            throw new IllegalArgumentException("Output '" + modelConfig.getDefaultOutOp()
                    + "' is a scalar; expected shape [batch, 1]");
        }
        int rows = (int) shape[0];
        float[][] resultArray = new float[rows][1];
        output.copyTo(resultArray);

        List<Double> ret = new ArrayList<>(rows);
        for (float[] v : resultArray) {
            ret.add((double) v[0]);
        }
        return ret;
    } finally {
        inputTensors.forEach(Tensor::close);
        outputTensors.forEach(Tensor::close);
    }
}

/**
 * Converts a batch of integer feature rows into an int tensor.
 *
 * <p>The length of the first row picks the layout: length 1 yields a rank-1 tensor
 * {@code [batch]}; any other length yields a rank-2 tensor {@code [batch, length]}.
 * The list must be non-empty, because the first row is read unconditionally.
 *
 * @param featureValues batch of feature rows
 * @return tensor holding the batch
 */
private Tensor<Integer> toTensor(List<List<Integer>> featureValues) {
    final int width = featureValues.get(0).size();
    return width == 1
            ? singleFeatureToArray(featureValues)
            : multiFeatureToArray(featureValues, width);
}

/**
 * Builds a rank-1 int tensor with one element per row, taken from each row's first value.
 *
 * @param featureValues batch of rows; only element 0 of each row is read
 * @return tensor of length {@code featureValues.size()}
 */
private Tensor<Integer> singleFeatureToArray(List<List<Integer>> featureValues) {
    int[] column = featureValues.stream()
            .mapToInt(row -> row.get(0))
            .toArray();
    return Tensors.create(column);
}

/**
 * Packs feature rows into a rank-2 int tensor of shape {@code [featureValues.size(), featureLen]}.
 *
 * <p>Every row must contain exactly {@code featureLen} values. A shorter row would otherwise
 * be zero-padded silently, feeding wrong data to the model.
 *
 * @param featureValues batch of rows
 * @param featureLen    expected number of values per row
 * @return rank-2 int tensor holding the batch
 * @throws IllegalArgumentException if any row's length differs from {@code featureLen}
 */
private Tensor<Integer> multiFeatureToArray(List<List<Integer>> featureValues, int featureLen) {
    int[][] ret = new int[featureValues.size()][featureLen];
    int i = 0;
    for (List<Integer> featureValue : featureValues) {
        if (featureValue.size() != featureLen) {
            throw new IllegalArgumentException("Feature row " + i + " has length "
                    + featureValue.size() + ", expected " + featureLen);
        }
        int j = 0;
        for (int value : featureValue) {
            ret[i][j++] = value;
        }
        i++;
    }
    return Tensors.create(ret);
}

模型配置

/**
 * Settings used to load and invoke a TensorFlow SavedModel.
 * Accessors such as {@code getPath()} are generated by Lombok's {@code @Data}.
 */
@Data
public class ModelConfig {

    /** SavedModel export directory handed to {@code SavedModelBundle.load/loader}. */
    private String path;

    /** Tag(s) selecting the meta-graph to load. */
    private String tags = "serve";

    /** Prefix prepended to each model input name to form the feed op name. */
    private String defaultInOpPrefix = "serving_default_";

    /** Name of the op fetched as the prediction output. */
    private String defaultOutOp = "StatefulPartitionedCall";

    /**
     * Value passed to {@code GPUOptions.setVisibleDeviceList} when loading on GPU.
     * NOTE(review): confirm the runtime accepts the "-1" default there; set explicit ids in prod.
     */
    private String gpuIds = "-1";
}

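  • 注意单输入与多输入的处理:每行只有一个特征值时构造一维张量 [batch],有多个特征值时构造二维张量 [batch, featureLen]。
  • 注意内存释放:Tensor 占用本地(非 JVM 堆)内存,用完必须调用 close(),示例在 finally 中统一关闭输入和输出张量。
  • 注意 Op 名称、输入前缀等默认参数(如 serving_default_、StatefulPartitionedCall)必须与导出模型的签名一致。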
------ 本文结束------

本文标题:JavaAPI调用TensorFlow2.0模型GPU或CPU

文章作者:Perkins

发布时间:2020年06月08日

原始链接:https://perkins4j2.github.io/posts/41363/

许可协议: 署名-非商业性使用-禁止演绎 4.0 国际 转载请保留原文链接及作者。