1 Star 0 Fork 1

wb253 / onnxruntime

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
shape.go 1.89 KB
一键复制 编辑 原始数据 按行查看 历史
wb253 提交于 2023-12-25 22:52 . fix bug
package onnxruntime
import "fmt"
// Shape holds the shape of the tensors used by the network inputs and
// outputs. Each element is the size of one dimension.
type Shape []int64
// NewShape returns a Shape made of the given dimensions.
//
// No copy is made: when called as NewShape(dims...), the returned Shape shares
// its backing array with dims, so later writes to dims are visible through it.
// Use Clone when an independent copy is needed. Calling NewShape with no
// arguments returns a nil Shape.
func NewShape(dimensions ...int64) Shape {
	var s Shape = dimensions
	return s
}
// FlattenedSize returns the total number of elements in a tensor with this
// shape: the product of all dimensions, or 0 when the shape has no dimensions.
//
// The product is not checked. It silently wraps around on int64 overflow and
// is zero or negative if any dimension is. If a shape comes from an untrusted
// source, call Validate before trusting the FlattenedSize.
func (s Shape) FlattenedSize() int64 {
	if len(s) == 0 {
		return 0
	}
	var product int64 = 1
	for _, dim := range s {
		product *= dim
	}
	return product
}
// Validate returns a non-nil error if the shape has no dimensions, contains a
// non-positive dimension, or describes more elements than fit in an int64.
//
// It may return ErrorZeroShapeLength (the shape has no dimensions), a
// *BadShapeDimensionError describing the first dimension that is <= 0, or
// ErrorShapeOverflow (the product of the dimensions exceeds the largest int64).
// In the future, this may return other types of errors if others become
// necessary.
//
// A shape that passes Validate has a FlattenedSize that is positive and exact,
// i.e. not the result of int64 wrap-around.
func (s Shape) Validate() error {
	// Largest int64 value, spelled out so this block needs no "math" import.
	const maxInt64 = 1<<63 - 1

	if len(s) == 0 {
		return ErrorZeroShapeLength
	}
	var flattenedSize int64 = 1
	for i, d := range s {
		if d <= 0 {
			return &BadShapeDimensionError{
				DimensionIndex: i,
				DimensionSize:  d,
			}
		}
		// Detect overflow before multiplying. Comparing the wrapped product
		// with the previous value misses many overflows; for example,
		// 2^32 * (2^32+1) wraps to exactly 2^32. Both operands are positive
		// here, so the division is safe and exact in its bound.
		if flattenedSize > maxInt64/d {
			return ErrorShapeOverflow
		}
		flattenedSize *= d
	}
	return nil
}
// Clone returns a copy of s backed by a newly allocated array, so modifying
// the copy never affects s and vice versa. The result is always non-nil, even
// when s is nil.
func (s Shape) Clone() Shape {
	dup := make(Shape, len(s))
	copy(dup, s)
	return dup
}
// String implements fmt.Stringer. It renders the dimensions exactly as fmt
// renders a []int64, e.g. "[1 2 3]".
func (s Shape) String() string {
	// Format the underlying slice type; formatting s itself would invoke this
	// method again and recurse forever.
	dims := []int64(s)
	return fmt.Sprint(dims)
}
// Equals reports whether s and other have the same number of dimensions and
// the same size in every dimension. A nil Shape and an empty Shape compare
// equal.
func (s Shape) Equals(other Shape) bool {
	if len(other) != len(s) {
		return false
	}
	for i, dim := range s {
		if other[i] != dim {
			return false
		}
	}
	return true
}
1
https://gitee.com/wb253/onnxruntime.git
git@gitee.com:wb253/onnxruntime.git
wb253
onnxruntime
onnxruntime
main

搜索帮助

53164aa7 5694891 3bd8fe86 5694891