aboutsummaryrefslogtreecommitdiff
path: root/platform/linux-generic/test/validation/api/ml/batch_add_gen.py
blob: 33515bd2f425049b32136c6a4f4755c43f4ad93a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Nokia
#

import onnx
from onnx import helper
from onnx import TensorProto

# Graph for the element-wise addition y = x1 + x2.
#
# Both inputs and the output share one element type (DOUBLE) and one shape,
# ["c", 3]. The string entry "c" is a symbolic dimension name rather than a
# fixed size (presumably the batch size -- confirm against the consuming test).
_elem_type = TensorProto.DOUBLE
_shape = ["c", 3]

# The single "Add" node reads x1 and x2 and writes y.
_add_node = helper.make_node("Add", ["x1", "x2"], ["y"], name="Batch Add")

graph = helper.make_graph(
    nodes=[_add_node],
    name="Batch Add",
    inputs=[
        helper.make_tensor_value_info(tensor_name, _elem_type, _shape)
        for tensor_name in ("x1", "x2")
    ],
    outputs=[helper.make_tensor_value_info("y", _elem_type, _shape)],
)

# Model-level settings: default-domain ("") operator set version 14, IR
# version 8, plus descriptive metadata stored in the model.
_model_attrs = {
    "opset_imports": [helper.make_opsetid("", 14)],
    "producer_name": "ODP validation tests",
    "model_version": 1,
    "doc_string": "y = x1 + x2",
    "ir_version": 8,
}

model = helper.make_model(graph, **_model_attrs)

# The output path is relative, so the file is written to the current
# working directory of the process running this script.
onnx.save(model, "batch_add.onnx")