// Package onnxruntime wraps parts of the onnxruntime C library via cgo,
// providing tensor data-type mappings and custom-data tensor support.
package onnxruntime
import (
"errors"
"fmt"
"reflect"
"unsafe"
)
// #cgo CFLAGS: -O2 -g
//
// #include "onnxruntime_wrapper.h"
import "C"
// FloatData is a type constraint matching the floating-point types usable as
// tensor elements. The ~ forms also admit user-defined types with these
// underlying types.
type FloatData interface {
	~float32 | ~float64
}
// IntData is a type constraint matching the fixed-width integer types usable
// as tensor elements. The ~ forms also admit user-defined types with these
// underlying types.
type IntData interface {
	~int8 | ~uint8 | ~int16 | ~uint16 | ~int32 | ~uint32 | ~int64 | ~uint64
}
// TensorData is used as a type constraint for the generic Tensor type: any
// float or integer type accepted by FloatData or IntData.
type TensorData interface {
	FloatData | IntData
}
// Returns the ONNX enum value used to indicate TensorData type T.
func GetTensorElementDataType[T TensorData]() C.ONNXTensorElementDataType {
	// Type assertions can't see through the ~underlying-type constraints in
	// TensorData, so reflection is used to recover the concrete kind of T.
	var zero T
	switch reflect.TypeOf(zero).Kind() {
	case reflect.Int8:
		return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8
	case reflect.Uint8:
		return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8
	case reflect.Int16:
		return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16
	case reflect.Uint16:
		return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16
	case reflect.Int32:
		return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32
	case reflect.Uint32:
		return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32
	case reflect.Int64:
		return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64
	case reflect.Uint64:
		return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64
	case reflect.Float32:
		return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT
	case reflect.Float64:
		return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE
	default:
		return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED
	}
}
// TensorElementDataType wraps the ONNXTensorElementDataType enum in C,
// identifying the element type stored in a tensor.
type TensorElementDataType int

// Go-side names for each value of the C ONNXTensorElementDataType enum.
const (
	TensorElementDataTypeUndefined = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED
	TensorElementDataTypeFloat = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT
	TensorElementDataTypeUint8 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8
	TensorElementDataTypeInt8 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8
	TensorElementDataTypeUint16 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16
	TensorElementDataTypeInt16 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16
	TensorElementDataTypeInt32 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32
	TensorElementDataTypeInt64 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64
	TensorElementDataTypeString = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING
	TensorElementDataTypeBool = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL
	TensorElementDataTypeFloat16 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16
	TensorElementDataTypeDouble = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE
	TensorElementDataTypeUint32 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32
	TensorElementDataTypeUint64 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64
	// Not supported by onnxruntime (as of onnxruntime version 1.16.1)
	TensorElementDataTypeComplex64 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64
	// Not supported by onnxruntime (as of onnxruntime version 1.16.1)
	TensorElementDataTypeComplex128 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128
	// Non-IEEE floating-point format based on IEEE754 single-precision
	TensorElementDataTypeBFloat16 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16
	// 8-bit float types, introduced in onnx 1.14. See
	// https://onnx.ai/onnx/technical/float8.html
	TensorElementDataTypeFloat8E4M3FN = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN
	TensorElementDataTypeFloat8E4M3FNUZ = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ
	TensorElementDataTypeFloat8E5M2 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2
	TensorElementDataTypeFloat8E5M2FNUZ = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ
)
// String returns the name of the underlying C enum value, matching the
// identifier used in the onnxruntime C headers. Unknown values produce a
// descriptive "Unknown tensor element data type" message rather than an
// error, so this is always safe to call.
func (t TensorElementDataType) String() string {
	switch t {
	case TensorElementDataTypeUndefined:
		return "ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED"
	case TensorElementDataTypeFloat:
		return "ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT"
	case TensorElementDataTypeUint8:
		return "ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8"
	case TensorElementDataTypeInt8:
		return "ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8"
	case TensorElementDataTypeUint16:
		return "ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16"
	case TensorElementDataTypeInt16:
		return "ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16"
	case TensorElementDataTypeInt32:
		return "ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32"
	case TensorElementDataTypeInt64:
		return "ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64"
	case TensorElementDataTypeString:
		return "ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING"
	case TensorElementDataTypeBool:
		return "ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL"
	case TensorElementDataTypeFloat16:
		return "ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16"
	case TensorElementDataTypeDouble:
		return "ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE"
	case TensorElementDataTypeUint32:
		return "ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32"
	case TensorElementDataTypeUint64:
		return "ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64"
	case TensorElementDataTypeComplex64:
		return "ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64"
	case TensorElementDataTypeComplex128:
		return "ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128"
	case TensorElementDataTypeBFloat16:
		return "ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16"
	case TensorElementDataTypeFloat8E4M3FN:
		return "ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN"
	case TensorElementDataTypeFloat8E4M3FNUZ:
		return "ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ"
	case TensorElementDataTypeFloat8E5M2:
		return "ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2"
	case TensorElementDataTypeFloat8E5M2FNUZ:
		return "ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ"
	}
	return fmt.Sprintf("Unknown tensor element data type: %d", int(t))
}
// This satisfies the ArbitraryTensor interface, but is intended to allow users
// to provide tensors of types that may not be supported by the generic
// typed Tensor[T] struct. Instead, CustomDataTensors are backed by a slice of
// bytes, using a user-provided shape and type from the
// ONNXTensorElementDataType enum.
type CustomDataTensor struct {
	// The caller-provided backing bytes; also referenced by the C OrtValue.
	data []byte
	// The caller-provided element type; not validated against len(data) here.
	dataType C.ONNXTensorElementDataType
	// A private copy of the shape passed to NewCustomDataTensor.
	shape Shape
	// The underlying ORT value; released by Destroy.
	ortValue *C.OrtValue
}
// Creates and returns a new CustomDataTensor using the given bytes as the
// underlying data slice. Apart from ensuring that the provided data slice is
// non-empty, this function mostly delegates validation of the provided data to
// the C onnxruntime library. For example, it is the caller's responsibility to
// ensure that the provided dataType and data slice are valid and correctly
// sized for the specified shape. If this returns successfully, the caller must
// call the returned tensor's Destroy() function to free it when no longer in
// use.
func NewCustomDataTensor(s Shape, data []byte,
dataType TensorElementDataType) (*CustomDataTensor, error) {
if !IsInitialized() {
return nil, ErrorNotInitialized
}
e := s.Validate()
if e != nil {
return nil, fmt.Errorf("invalid tensor shape: %w", e)
}
if len(data) == 0 {
return nil, errors.New("a CustomDataTensor requires at least one byte of data")
}
dt := C.ONNXTensorElementDataType(dataType)
var ortValue *C.OrtValue
status := C.CreateOrtTensorWithShape(unsafe.Pointer(&data[0]),
C.size_t(len(data)), (*C.int64_t)(unsafe.Pointer(&s[0])),
C.int64_t(len(s)), ortMemoryInfo, dt, &ortValue)
if status != nil {
return nil, fmt.Errorf("ORT API error creating tensor: %s",
statusToError(status))
}
toReturn := CustomDataTensor{
data: data,
dataType: dt,
shape: s.Clone(),
ortValue: ortValue,
}
return &toReturn, nil
}
// Destroy releases the underlying ORT value and clears all Go-side state.
// It always returns nil; the error return exists to satisfy the destroyable
// tensor interface.
func (t *CustomDataTensor) Destroy() error {
	// Free the C-side value first, then drop every Go reference so the
	// tensor can't accidentally be used (or keep memory alive) afterwards.
	C.ReleaseOrtValue(t.ortValue)
	t.dataType = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED
	t.shape = nil
	t.data = nil
	t.ortValue = nil
	return nil
}
// DataType returns the C enum value for this tensor's element type, as
// provided to NewCustomDataTensor.
func (t *CustomDataTensor) DataType() C.ONNXTensorElementDataType {
	return t.dataType
}
// GetShape returns a copy of the tensor's shape; mutating the returned
// Shape does not affect the tensor.
func (t *CustomDataTensor) GetShape() Shape {
	return t.shape.Clone()
}
// GetInternals exposes the tensor's underlying *C.OrtValue for use by other
// parts of this library (e.g. when binding session inputs/outputs).
func (t *CustomDataTensor) GetInternals() *TensorInternalData {
	return &TensorInternalData{
		OrtValue: t.ortValue,
	}
}
// Sets all bytes in the data slice to 0.
func (t *CustomDataTensor) ZeroContents() {
	// A plain range loop is compiled to an optimized memory clear and,
	// unlike the previous C.memset(&t.data[0], ...) call, does not panic
	// when the data slice is empty or nil (e.g. after Destroy).
	for i := range t.data {
		t.data[i] = 0
	}
}
// GetData returns the same slice that was passed to NewCustomDataTensor.
// Writes through the returned slice are visible to onnxruntime, since the
// OrtValue aliases this memory rather than copying it.
func (t *CustomDataTensor) GetData() []byte {
	return t.data
}