1 Star 0 Fork 0

eecjimmy / go-location

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
db.go 4.84 KB
一键复制 编辑 原始数据 按行查看 历史
eecjimmy 提交于 2021-07-22 14:08 . first commit
package ip
import (
	"bytes"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net"
	"os"
	"path/filepath"
	"strconv"
	"strings"

	"golang.org/x/text/encoding/simplifiedchinese"
)
// Layout constants of the QQWry data file.
const (
	// recordLength is the size of one index entry: a 4-byte start IP followed
	// by a 3-byte offset to the data record.
	recordLength uint32 = 7

	// redirectMode1 marks a 3-byte pointer to a relocated country+area pair.
	redirectMode1 uint8 = 0x1

	// redirectMode2 marks a 3-byte pointer to a relocated country string; the
	// area continues right after the pointer.
	redirectMode2 uint8 = 0x2
)
// database is an open handle on the QQWry data file together with the index
// bounds read from the file's 8-byte header.
type database struct {
	// fp is the underlying data file; all access goes through Seek + Read.
	fp *os.File

	// beginIpOffset and lastIpOffset are the absolute file offsets of the
	// first and last index entries; totalIpNum is
	// (lastIpOffset-beginIpOffset)/recordLength, i.e. the index of the last
	// entry, which the lookup uses as an inclusive upper bound.
	beginIpOffset, lastIpOffset, totalIpNum uint32
}
// getDbFile returns the path of the QQWry data file in the OS temp directory.
// If that file cannot be stat'ed (typically because it does not exist yet),
// the embedded "data/qqwry.dat" asset is extracted there first.
//
// getDbFile has no error result, so a failure to load or write the asset
// terminates the process via log.Fatal. The file is written under a temporary
// name and renamed into place, so an interrupted or failed write never leaves
// a truncated ip.db behind that later calls would accept as valid.
func getDbFile() string {
	dir := os.TempDir()
	p := filepath.Join(dir, "ip.db")
	if _, err := os.Stat(p); err == nil {
		return p
	}
	data, err := Asset("data/qqwry.dat")
	if err != nil {
		// Writing a nil/partial asset would create an empty database that
		// every later call would silently reuse.
		log.Fatal(err)
	}
	if err = writeDbFileAtomically(dir, p, data); err != nil {
		log.Fatal(err)
	}
	return p
}

// writeDbFileAtomically writes data to a temp file in dir and renames it to dst.
func writeDbFileAtomically(dir, dst string, data []byte) error {
	tmp, err := ioutil.TempFile(dir, "ip.db.*")
	if err != nil {
		return err
	}
	name := tmp.Name()
	_, err = tmp.Write(data)
	if cerr := tmp.Close(); err == nil {
		err = cerr
	}
	if err == nil {
		// TempFile creates files with mode 0600; keep the original 0644.
		err = os.Chmod(name, 0644)
	}
	if err == nil {
		err = os.Rename(name, dst)
	}
	if err != nil {
		// Best-effort cleanup; the write/rename error is the one worth reporting.
		_ = os.Remove(name)
	}
	return err
}
// getDB opens the database file and parses its header.
//
// It is not a singleton: every call opens a fresh file handle (db.fp), which
// the caller is responsible for closing.
func getDB() (*database, error) {
	db := new(database)
	if err := db.initialize(); err != nil {
		return nil, err
	}
	return db, nil
}
// initialize opens the database file (extracting it first if needed, via
// getDbFile) and reads its 8-byte header: the absolute offsets of the first
// and last index entries, each a little-endian uint32. From these it derives
// totalIpNum, the index of the last entry.
//
// On failure it returns an error and leaves no file handle open.
func (db *database) initialize() (e error) {
	fp, err := os.Open(getDbFile())
	if err != nil {
		// %w keeps the underlying *os.PathError reachable via errors.Is/As.
		return fmt.Errorf("Open database file failed: %w", err)
	}
	db.fp = fp
	db.beginIpOffset = db.readBytesAsUint32(4)
	db.lastIpOffset = db.readBytesAsUint32(4)
	if db.lastIpOffset < db.beginIpOffset {
		// The subtraction below would underflow and send the binary search
		// far past the end of the file.
		_ = fp.Close()
		db.fp = nil
		return errors.New("invalid database header: last index offset precedes first")
	}
	db.totalIpNum = (db.lastIpOffset - db.beginIpOffset) / recordLength
	return nil
}
// query looks up ipStr, which may be any textual form accepted by net.ParseIP
// that denotes an IPv4 address. It returns the covering IP range together with
// its country and area strings converted from GBK to UTF-8.
//
// It returns an error when ipStr is not an IPv4 address or when no index
// range contains it.
func (db *database) query(ipStr string) (Location, error) {
	ip4 := net.ParseIP(ipStr).To4()
	if ip4 == nil {
		return Location{}, fmt.Errorf("invalid IPv4 address %q", ipStr)
	}
	pos, ok := db.findIndexRecord(binary.BigEndian.Uint32(ip4))
	if !ok {
		return Location{}, fmt.Errorf("no IP range contains %s", ipStr)
	}

	// Index entry: 4-byte start IP + 3-byte offset to the data record, whose
	// first 4 bytes hold the range's end IP.
	db.forward(pos)
	beginIP := long2ip(db.readBytesAsUint32(4))
	offset := db.readBytesAsUint32(3)
	db.forward(offset)
	endIP := long2ip(db.readBytesAsUint32(4))

	country, area := db.readCountryAndArea(offset)
	return Location{
		IP:      ipStr,
		BeginIP: beginIP,
		EndIP:   endIP,
		Country: toUTF8(country),
		Area:    toUTF8(area),
	}, nil
}

// findIndexRecord binary-searches index entries 0..totalIpNum (inclusive) for
// the one whose [start IP, end IP] range contains ip and returns its file
// offset; ok is false when no range contains ip. Assumes entries are sorted by
// start IP with non-overlapping ranges.
func (db *database) findIndexRecord(ip uint32) (pos uint32, ok bool) {
	lo, hi := uint32(0), db.totalIpNum
	for lo <= hi {
		mid := lo + (hi-lo)/2
		entry := db.beginIpOffset + mid*recordLength
		db.forward(entry)
		if ip < db.readBytesAsUint32(4) {
			if mid == 0 {
				// ip precedes the first range; hi = mid-1 would wrap around.
				return 0, false
			}
			hi = mid - 1
			continue
		}
		db.forward(db.readBytesAsUint32(3))
		if ip > db.readBytesAsUint32(4) {
			lo = mid + 1
			continue
		}
		return entry, true
	}
	return 0, false
}

// readCountryAndArea decodes the country and area of the data record at
// offset. The file position must be at offset+4, just past the record's
// 4-byte end IP. The byte found there selects the layout:
//   - redirectMode1: a 3-byte offset to a relocated country+area pair. That
//     country may itself be redirectMode2 + a 3-byte pointer, in which case
//     the area follows the pointer (4 bytes past the pair's offset).
//   - redirectMode2: a 3-byte offset to the country; the area follows the
//     pointer, at offset+8.
//   - anything else: an inline NUL-terminated country, then the area.
func (db *database) readCountryAndArea(offset uint32) (country, area []byte) {
	mode := db.readByte()
	switch mode {
	case redirectMode1:
		countryOffset := db.readBytesAsUint32(3)
		db.forward(countryOffset)
		inner := db.readByte()
		if inner == redirectMode2 {
			db.forward(db.readBytesAsUint32(3))
			country = db.getString()
			db.forward(countryOffset + 4)
		} else {
			country = db.finishString(inner)
		}
		area = db.getArea()
	case redirectMode2:
		db.forward(db.readBytesAsUint32(3))
		country = db.getString()
		db.forward(offset + 8)
		area = db.getArea()
	default:
		country = db.finishString(mode)
		area = db.getArea()
	}
	return country, area
}

// finishString completes a NUL-terminated string whose first byte has already
// been consumed; a first byte of 0 means the string is empty.
func (db *database) finishString(first byte) []byte {
	if first == 0 {
		return nil
	}
	return append([]byte{first}, db.getString()...)
}
// getString reads bytes from the current file position up to, but not
// including, the next NUL byte, leaving the position just past that NUL.
// readByte yields 0 when nothing could be read, which also ends the string.
func (db *database) getString() []byte {
	var out []byte
	for {
		ch := db.readByte()
		if ch == 0 {
			return out
		}
		out = append(out, ch)
	}
}
// getArea reads the area string of a record at the current file position.
//
// The leading byte selects the encoding:
//   - 0: the area is empty.
//   - redirectMode1 / redirectMode2: a 3-byte offset to the area string
//     follows. An offset of 0 is treated as empty, because offset 0 is the
//     file's 8-byte header and never holds string data.
//   - anything else: the first byte of an inline NUL-terminated string.
func (db *database) getArea() []byte {
	mode := db.readByte()
	switch mode {
	case 0:
		return []byte{}
	case redirectMode1, redirectMode2:
		areaOffset := db.readBytesAsUint32(3)
		if areaOffset == 0 {
			return []byte{}
		}
		db.forward(areaOffset)
		return db.getString()
	default:
		return append([]byte{mode}, db.getString()...)
	}
}
// readByte reads one byte at the current file position. A read failure is
// logged, and the returned byte is then whatever was read (0 if nothing was).
func (db *database) readByte() byte {
	var one [1]byte
	if _, err := db.fp.Read(one[:]); err != nil {
		log.Println("raed byte failed:", err)
	}
	return one[0]
}
// readBytesAsUint32 consumes n bytes at the current file position and decodes
// them as a little-endian unsigned integer. Fewer than 4 bytes (the format's
// 3-byte offsets) are zero-extended; when n > 4 only the first 4 bytes are
// decoded. Read errors are logged and unread bytes count as zero, so a value
// is always returned.
func (db *database) readBytesAsUint32(n int) uint32 {
	b := make([]byte, n)
	// Unlike a bare Read, io.ReadFull keeps reading until all n bytes arrive.
	if _, e := io.ReadFull(db.fp, b); e != nil {
		log.Println("raed bytes failed:", e)
	}
	if len(b) < 4 {
		// Little-endian: the missing high-order bytes go at the end.
		b = append(b, bytes.Repeat([]byte{0}, 4-len(b))...)
	}
	return binary.LittleEndian.Uint32(b)
}
// forward moves the file position to the absolute byte offset. A Seek failure
// is logged together with its cause rather than returned.
func (db *database) forward(offset uint32) {
	if _, e := db.fp.Seek(int64(offset), io.SeekStart); e != nil {
		log.Println("forward failed:", e)
	}
}
// ip2long converts a dotted-decimal IPv4 address ("a.b.c.d") to a uint32 with
// the first octet in the most significant byte.
//
// Input that is not exactly four decimal octets in 0..255 yields 0 instead of
// panicking or silently wrapping out-of-range octets.
func ip2long(ip string) uint32 {
	parts := strings.Split(ip, ".")
	if len(parts) != 4 {
		return 0
	}
	var n uint32
	for _, p := range parts {
		v, err := strconv.Atoi(p)
		if err != nil || v < 0 || v > 255 {
			return 0
		}
		n = n<<8 | uint32(v)
	}
	return n
}
// long2ip renders a uint32 IPv4 address (most significant byte = first octet)
// in dotted-decimal notation.
func long2ip(ip uint32) string {
	o1, o2, o3, o4 := byte(ip>>24), byte(ip>>16), byte(ip>>8), byte(ip)
	return fmt.Sprintf("%d.%d.%d.%d", o1, o2, o3, o4)
}
// toUTF8 decodes GBK-encoded bytes (the encoding of the database's country
// and area strings) into a UTF-8 string. It returns "" if decoding fails.
func toUTF8(gbk []byte) string {
	decoded, err := simplifiedchinese.GBK.NewDecoder().Bytes(gbk)
	if err != nil {
		return ""
	}
	return string(decoded)
}
// FindByStr looks up the location of the IP address given as a string.
// The database file is opened for this call only and closed before returning.
func FindByStr(ip string) (loc Location, e error) {
	db, err := getDB()
	if err != nil {
		return loc, err
	}
	// getDB opens a new *os.File on every call; close it here or every lookup
	// leaks a file descriptor. The file is only read, so the Close error is
	// not actionable.
	defer db.fp.Close()
	return db.query(ip)
}
// Find looks up the location of ip by passing its textual form
// (net.IP.String) to FindByStr.
func Find(ip net.IP) (loc Location, e error) {
	text := ip.String()
	loc, e = FindByStr(text)
	return loc, e
}
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/eecjimmy/go-location.git
git@gitee.com:eecjimmy/go-location.git
eecjimmy
go-location
go-location
master

搜索帮助

344bd9b3 5694891 D2dac590 5694891