泛型
Go 1.18 引入了泛型,允许编写适用于多种类型的代码。本章将详细介绍 Go 的泛型特性。
泛型基础
泛型函数
go
package main
import "fmt"
// 泛型函数
func Print[T any](value T) {
fmt.Println(value)
}
func main() {
Print(42)
Print("hello")
Print(3.14)
}
类型参数
go
// 比较两个值
func Equal[T comparable](a, b T) bool {
return a == b
}
func main() {
fmt.Println(Equal(1, 1)) // true
fmt.Println(Equal("a", "a")) // true
fmt.Println(Equal(1, 2)) // false
}
多个类型参数
go
func Pair[T, U any](first T, second U) (T, U) {
return first, second
}
func main() {
a, b := Pair(1, "hello")
fmt.Println(a, b)
}
类型约束
基本约束
go
// any 约束(任意类型)
func PrintAny[T any](value T) {
fmt.Println(value)
}
// comparable 约束(可比较类型)
func Max[T comparable](a, b T) T {
if a > b {
return a
}
return b
}
自定义约束
go
type Number interface {
int | int8 | int16 | int32 | int64 |
uint | uint8 | uint16 | uint32 | uint64 |
float32 | float64
}
func Sum[T Number](a, b T) T {
return a + b
}
func main() {
fmt.Println(Sum(1, 2)) // 3
fmt.Println(Sum(1.5, 2.5)) // 4.0
}
接口约束
go
type Stringer interface {
String() string
}
func PrintStringer[T Stringer](value T) {
fmt.Println(value.String())
}
type Person struct {
Name string
}
func (p Person) String() string {
return p.Name
}
func main() {
p := Person{Name: "张三"}
PrintStringer(p)
}
泛型类型
泛型结构体
go
type Stack[T any] struct {
items []T
}
func NewStack[T any]() *Stack[T] {
return &Stack[T]{
items: make([]T, 0),
}
}
func (s *Stack[T]) Push(item T) {
s.items = append(s.items, item)
}
func (s *Stack[T]) Pop() (T, bool) {
if len(s.items) == 0 {
var zero T
return zero, false
}
index := len(s.items) - 1
item := s.items[index]
s.items = s.items[:index]
return item, true
}
func main() {
stack := NewStack[int]()
stack.Push(1)
stack.Push(2)
stack.Push(3)
for {
item, ok := stack.Pop()
if !ok {
break
}
fmt.Println(item)
}
}
泛型接口
go
type Container[T any] interface {
Add(item T)
Remove() (T, bool)
Size() int
}
type List[T any] struct {
items []T
}
func (l *List[T]) Add(item T) {
l.items = append(l.items, item)
}
func (l *List[T]) Remove() (T, bool) {
if len(l.items) == 0 {
var zero T
return zero, false
}
item := l.items[0]
l.items = l.items[1:]
return item, true
}
func (l *List[T]) Size() int {
return len(l.items)
}
func main() {
var c Container[int] = &List[int]{}
c.Add(1)
c.Add(2)
fmt.Println(c.Size())
}
泛型方法
go
type Pair[T any] struct {
First T
Second T
}
func (p *Pair[T]) Swap() {
p.First, p.Second = p.Second, p.First
}
func main() {
p := Pair[int]{First: 1, Second: 2}
fmt.Println(p)
p.Swap()
fmt.Println(p)
}
类型推断
自动推断
go
func Print[T any](value T) {
fmt.Println(value)
}
func main() {
// 自动推断类型
Print(42)
Print("hello")
}
显式指定类型
go
func main() {
// 显式指定类型
Print[int](42)
Print[string]("hello")
}
泛型应用
通用排序
go
func Sort[T comparable](slice []T, less func(a, b T) bool) {
for i := 0; i < len(slice); i++ {
for j := i + 1; j < len(slice); j++ {
if less(slice[j], slice[i]) {
slice[i], slice[j] = slice[j], slice[i]
}
}
}
}
func main() {
nums := []int{3, 1, 4, 1, 5}
Sort(nums, func(a, b int) bool {
return a < b
})
fmt.Println(nums)
names := []string{"banana", "apple", "cherry"}
Sort(names, func(a, b string) bool {
return a < b
})
fmt.Println(names)
}
通用映射
go
func Map[T, U any](slice []T, fn func(T) U) []U {
result := make([]U, len(slice))
for i, item := range slice {
result[i] = fn(item)
}
return result
}
func main() {
nums := []int{1, 2, 3, 4, 5}
// 平方
squares := Map(nums, func(n int) int {
return n * n
})
fmt.Println(squares)
// 转字符串
strings := Map(nums, func(n int) string {
return fmt.Sprintf("%d", n)
})
fmt.Println(strings)
}
通用过滤
go
func Filter[T any](slice []T, predicate func(T) bool) []T {
result := []T{}
for _, item := range slice {
if predicate(item) {
result = append(result, item)
}
}
return result
}
func main() {
nums := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
// 偶数
evens := Filter(nums, func(n int) bool {
return n%2 == 0
})
fmt.Println(evens)
// 大于 5
greater := Filter(nums, func(n int) bool {
return n > 5
})
fmt.Println(greater)
}
通用归约
go
func Reduce[T, U any](slice []T, initial U, fn func(U, T) U) U {
result := initial
for _, item := range slice {
result = fn(result, item)
}
return result
}
func main() {
nums := []int{1, 2, 3, 4, 5}
// 求和
sum := Reduce(nums, 0, func(acc, n int) int {
return acc + n
})
fmt.Println("求和:", sum)
// 求积
product := Reduce(nums, 1, func(acc, n int) int {
return acc * n
})
fmt.Println("求积:", product)
}
泛型与接口
类型参数 vs 接口
go
// 使用泛型
func PrintGeneric[T any](value T) {
fmt.Println(value)
}
// 使用接口
func PrintInterface(value interface{}) {
fmt.Println(value)
}
func main() {
PrintGeneric(42)
PrintInterface(42)
}
何时使用泛型
go
// 使用泛型:类型安全
func Max[T comparable](a, b T) T {
if a > b {
return a
}
return b
}
// 使用接口:灵活性
func Print(value interface{}) {
fmt.Println(value)
}
泛型最佳实践
1. 保持简单
go
// 推荐
func Max[T comparable](a, b T) T {
if a > b {
return a
}
return b
}
// 不推荐:过度复杂
func ComplexFunc[T any, U comparable, V interface{ String() string }](a T, b U, c V) (T, U, V) {
return a, b, c
}
2. 合理命名
go
// 推荐
func Map[T, U any](slice []T, fn func(T) U) []U {}
// 不推荐
func Map[A, B any](slice []A, fn func(A) B) []B {}
3. 使用约束
go
// 推荐
func Sum[T Number](a, b T) T {
return a + b
}
// 不推荐:过于宽松
func Sum[T any](a, b T) T {
return a + b // 编译错误
}
练习
练习 1:泛型栈
实现一个泛型栈,支持任意类型。
答案
go
package main
import "fmt"
type Stack[T any] struct {
items []T
}
func NewStack[T any]() *Stack[T] {
return &Stack[T]{items: []T{}}
}
func (s *Stack[T]) Push(item T) {
s.items = append(s.items, item)
}
func (s *Stack[T]) Pop() (T, bool) {
if len(s.items) == 0 {
var zero T
return zero, false
}
index := len(s.items) - 1
item := s.items[index]
s.items = s.items[:index]
return item, true
}
func (s *Stack[T]) Peek() (T, bool) {
if len(s.items) == 0 {
var zero T
return zero, false
}
return s.items[len(s.items)-1], true
}
func (s *Stack[T]) Size() int {
return len(s.items)
}
func main() {
stack := NewStack[int]()
stack.Push(1)
stack.Push(2)
stack.Push(3)
for stack.Size() > 0 {
if item, ok := stack.Pop(); ok {
fmt.Println(item)
}
}
}
总结
- 泛型允许编写类型安全的通用代码
- 使用类型参数和类型约束
- 支持泛型函数、类型、方法
- 类型推断简化使用
- 合理使用,避免过度复杂
- 保持简单和清晰
下一章:测试