wb253 / onnxruntime

tensor_data.go 7.83 KB
wb253 committed on 2023-12-27 17:58: fix bug
package onnxruntime

import (
	"errors"
	"fmt"
	"reflect"
	"unsafe"
)

// #cgo CFLAGS: -O2 -g
//
// #include "onnxruntime_wrapper.h"
import "C"

type FloatData interface {
	~float32 | ~float64
}

type IntData interface {
	~int8 | ~uint8 | ~int16 | ~uint16 | ~int32 | ~uint32 | ~int64 | ~uint64
}

// This is used as a type constraint for the generic Tensor type.
type TensorData interface {
	FloatData | IntData
}

// Returns the ONNX enum value used to indicate TensorData type T.
func GetTensorElementDataType[T TensorData]() C.ONNXTensorElementDataType {
	// Sadly, we can't do type assertions to get underlying types, so we need
	// to use reflect here instead.
	var v T
	kind := reflect.ValueOf(v).Kind()
	switch kind {
	case reflect.Float64:
		return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE
	case reflect.Float32:
		return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT
	case reflect.Int8:
		return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8
	case reflect.Uint8:
		return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8
	case reflect.Int16:
		return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16
	case reflect.Uint16:
		return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16
	case reflect.Int32:
		return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32
	case reflect.Uint32:
		return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32
	case reflect.Int64:
		return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64
	case reflect.Uint64:
		return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64
	}
	return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED
}
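
// Added usage sketch, not part of the original file: it shows how
// GetTensorElementDataType maps a Go element type to its ONNX enum value via
// the TensorData constraint above. The function name exampleElementTypes is
// hypothetical and exists only for illustration.
func exampleElementTypes() {
	floatType := GetTensorElementDataType[float32]()
	int64Type := GetTensorElementDataType[int64]()
	fmt.Printf("float32 -> %d, int64 -> %d\n", int(floatType), int(int64Type))
	// Named types with a supported underlying type also satisfy TensorData,
	// thanks to the ~ in the constraint definitions above.
	type myFloat float64
	doubleType := GetTensorElementDataType[myFloat]()
	fmt.Printf("myFloat (float64) -> %d\n", int(doubleType))
}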
// Wraps the ONNXTensorElementDataType enum in C.
type TensorElementDataType int

const (
	TensorElementDataTypeUndefined = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED
	TensorElementDataTypeFloat = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT
	TensorElementDataTypeUint8 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8
	TensorElementDataTypeInt8 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8
	TensorElementDataTypeUint16 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16
	TensorElementDataTypeInt16 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16
	TensorElementDataTypeInt32 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32
	TensorElementDataTypeInt64 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64
	TensorElementDataTypeString = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING
	TensorElementDataTypeBool = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL
	TensorElementDataTypeFloat16 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16
	TensorElementDataTypeDouble = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE
	TensorElementDataTypeUint32 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32
	TensorElementDataTypeUint64 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64

	// Not supported by onnxruntime (as of onnxruntime version 1.16.1)
	TensorElementDataTypeComplex64 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64
	// Not supported by onnxruntime (as of onnxruntime version 1.16.1)
	TensorElementDataTypeComplex128 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128

	// Non-IEEE floating-point format based on IEEE754 single-precision
	TensorElementDataTypeBFloat16 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16

	// 8-bit float types, introduced in onnx 1.14. See
	// https://onnx.ai/onnx/technical/float8.html
	TensorElementDataTypeFloat8E4M3FN = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN
	TensorElementDataTypeFloat8E4M3FNUZ = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ
	TensorElementDataTypeFloat8E5M2 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2
	TensorElementDataTypeFloat8E5M2FNUZ = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ
)

func (t TensorElementDataType) String() string {
	switch t {
	case TensorElementDataTypeUndefined:
		return "ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED"
	case TensorElementDataTypeFloat:
		return "ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT"
	case TensorElementDataTypeUint8:
		return "ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8"
	case TensorElementDataTypeInt8:
		return "ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8"
	case TensorElementDataTypeUint16:
		return "ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16"
	case TensorElementDataTypeInt16:
		return "ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16"
	case TensorElementDataTypeInt32:
		return "ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32"
	case TensorElementDataTypeInt64:
		return "ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64"
	case TensorElementDataTypeString:
		return "ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING"
	case TensorElementDataTypeBool:
		return "ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL"
	case TensorElementDataTypeFloat16:
		return "ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16"
	case TensorElementDataTypeDouble:
		return "ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE"
	case TensorElementDataTypeUint32:
		return "ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32"
	case TensorElementDataTypeUint64:
		return "ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64"
	case TensorElementDataTypeComplex64:
		return "ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64"
	case TensorElementDataTypeComplex128:
		return "ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128"
	case TensorElementDataTypeBFloat16:
		return "ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16"
	case TensorElementDataTypeFloat8E4M3FN:
		return "ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN"
	case TensorElementDataTypeFloat8E4M3FNUZ:
		return "ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ"
	case TensorElementDataTypeFloat8E5M2:
		return "ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2"
	case TensorElementDataTypeFloat8E5M2FNUZ:
		return "ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ"
	}
	return fmt.Sprintf("Unknown tensor element data type: %d", int(t))
}
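
// Added usage sketch, not part of the original file: TensorElementDataType
// satisfies fmt.Stringer through the String method above, so a value can be
// printed by name. The function name exampleElementTypeName is hypothetical.
func exampleElementTypeName() {
	dt := TensorElementDataType(TensorElementDataTypeFloat16)
	fmt.Println(dt) // prints ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16
}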
// This satisfies the ArbitraryTensor interface, but is intended to allow users
// to provide tensors of types that may not be supported by the generic
// typed Tensor[T] struct. Instead, CustomDataTensors are backed by a slice of
// bytes, using a user-provided shape and type from the
// ONNXTensorElementDataType enum.
type CustomDataTensor struct {
	data     []byte
	dataType C.ONNXTensorElementDataType
	shape    Shape
	ortValue *C.OrtValue
}

// Creates and returns a new CustomDataTensor using the given bytes as the
// underlying data slice. Apart from ensuring that the provided data slice is
// non-empty, this function mostly delegates validation of the provided data to
// the C onnxruntime library. For example, it is the caller's responsibility to
// ensure that the provided dataType and data slice are valid and correctly
// sized for the specified shape. If this returns successfully, the caller must
// call the returned tensor's Destroy() function to free it when no longer in
// use.
func NewCustomDataTensor(s Shape, data []byte,
	dataType TensorElementDataType) (*CustomDataTensor, error) {
	if !IsInitialized() {
		return nil, ErrorNotInitialized
	}
	e := s.Validate()
	if e != nil {
		return nil, fmt.Errorf("invalid tensor shape: %w", e)
	}
	if len(data) == 0 {
		return nil, errors.New("a CustomDataTensor requires at least one byte of data")
	}
	dt := C.ONNXTensorElementDataType(dataType)
	var ortValue *C.OrtValue
	status := C.CreateOrtTensorWithShape(unsafe.Pointer(&data[0]),
		C.size_t(len(data)), (*C.int64_t)(unsafe.Pointer(&s[0])),
		C.int64_t(len(s)), ortMemoryInfo, dt, &ortValue)
	if status != nil {
		return nil, fmt.Errorf("ORT API error creating tensor: %s",
			statusToError(status))
	}
	toReturn := CustomDataTensor{
		data:     data,
		dataType: dt,
		shape:    s.Clone(),
		ortValue: ortValue,
	}
	return &toReturn, nil
}
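
// Added usage sketch, not part of the original file: creating a float16 tensor
// via CustomDataTensor, since float16 has no corresponding Go type usable with
// the TensorData constraint. This assumes Shape can be built as a slice of
// int64 dimensions (as the pointer cast in NewCustomDataTensor implies) and
// that the onnxruntime environment has already been initialized. The function
// name exampleFloat16CustomTensor is hypothetical.
func exampleFloat16CustomTensor() (*CustomDataTensor, error) {
	// A 2x3 float16 tensor needs 2*3*2 = 12 bytes of backing data.
	shape := Shape{2, 3}
	raw := make([]byte, 12)
	t, e := NewCustomDataTensor(shape, raw, TensorElementDataTypeFloat16)
	if e != nil {
		return nil, fmt.Errorf("error creating float16 tensor: %w", e)
	}
	// The caller must eventually call t.Destroy() to release the OrtValue.
	return t, nil
}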
func (t *CustomDataTensor) Destroy() error {
	C.ReleaseOrtValue(t.ortValue)
	t.ortValue = nil
	t.data = nil
	t.shape = nil
	t.dataType = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED
	return nil
}

func (t *CustomDataTensor) DataType() C.ONNXTensorElementDataType {
	return t.dataType
}

func (t *CustomDataTensor) GetShape() Shape {
	return t.shape.Clone()
}

func (t *CustomDataTensor) GetInternals() *TensorInternalData {
	return &TensorInternalData{
		OrtValue: t.ortValue,
	}
}

// Sets all bytes in the data slice to 0.
func (t *CustomDataTensor) ZeroContents() {
	C.memset(unsafe.Pointer(&t.data[0]), 0, C.size_t(len(t.data)))
}

// Returns the same slice that was passed to NewCustomDataTensor.
func (t *CustomDataTensor) GetData() []byte {
	return t.data
}
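
// Added usage sketch, not part of the original file: the typical lifecycle of
// a CustomDataTensor once created. GetData returns the same backing slice that
// was passed to NewCustomDataTensor, so writes through it are visible to the
// underlying OrtValue. The function name exampleCustomTensorLifecycle is
// hypothetical.
func exampleCustomTensorLifecycle(t *CustomDataTensor) error {
	// Overwrite any stale contents before filling in new input data.
	t.ZeroContents()
	data := t.GetData()
	if len(data) > 0 {
		data[0] = 0xff // writes directly into the tensor's backing memory
	}
	fmt.Printf("Tensor shape %v, element type %s\n", t.GetShape(),
		TensorElementDataType(t.DataType()))
	// Release the underlying OrtValue when the tensor is no longer needed.
	return t.Destroy()
}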