Skip to content
On this page

泛型

Go 1.18 引入了泛型,允许编写适用于多种类型的代码。本章将详细介绍 Go 的泛型特性。

泛型基础

泛型函数

go
package main

import "fmt"

// 泛型函数
func Print[T any](value T) {
    fmt.Println(value)
}

func main() {
    Print(42)
    Print("hello")
    Print(3.14)
}

类型参数

go
// 比较两个值
func Equal[T comparable](a, b T) bool {
    return a == b
}

func main() {
    fmt.Println(Equal(1, 1))      // true
    fmt.Println(Equal("a", "a"))   // true
    fmt.Println(Equal(1, 2))      // false
}

多个类型参数

go
func Pair[T, U any](first T, second U) (T, U) {
    return first, second
}

func main() {
    a, b := Pair(1, "hello")
    fmt.Println(a, b)
}

类型约束

基本约束

go
// any 约束(任意类型)
func PrintAny[T any](value T) {
    fmt.Println(value)
}

// comparable 约束(可比较类型)
func Max[T comparable](a, b T) T {
    if a > b {
        return a
    }
    return b
}

自定义约束

go
type Number interface {
    int | int8 | int16 | int32 | int64 |
    uint | uint8 | uint16 | uint32 | uint64 |
    float32 | float64
}

func Sum[T Number](a, b T) T {
    return a + b
}

func main() {
    fmt.Println(Sum(1, 2))      // 3
    fmt.Println(Sum(1.5, 2.5))  // 4.0
}

接口约束

go
type Stringer interface {
    String() string
}

func PrintStringer[T Stringer](value T) {
    fmt.Println(value.String())
}

type Person struct {
    Name string
}

func (p Person) String() string {
    return p.Name
}

func main() {
    p := Person{Name: "张三"}
    PrintStringer(p)
}

泛型类型

泛型结构体

go
type Stack[T any] struct {
    items []T
}

func NewStack[T any]() *Stack[T] {
    return &Stack[T]{
        items: make([]T, 0),
    }
}

func (s *Stack[T]) Push(item T) {
    s.items = append(s.items, item)
}

func (s *Stack[T]) Pop() (T, bool) {
    if len(s.items) == 0 {
        var zero T
        return zero, false
    }
    
    index := len(s.items) - 1
    item := s.items[index]
    s.items = s.items[:index]
    return item, true
}

func main() {
    stack := NewStack[int]()
    
    stack.Push(1)
    stack.Push(2)
    stack.Push(3)
    
    for {
        item, ok := stack.Pop()
        if !ok {
            break
        }
        fmt.Println(item)
    }
}

泛型接口

go
type Container[T any] interface {
    Add(item T)
    Remove() (T, bool)
    Size() int
}

type List[T any] struct {
    items []T
}

func (l *List[T]) Add(item T) {
    l.items = append(l.items, item)
}

func (l *List[T]) Remove() (T, bool) {
    if len(l.items) == 0 {
        var zero T
        return zero, false
    }
    item := l.items[0]
    l.items = l.items[1:]
    return item, true
}

func (l *List[T]) Size() int {
    return len(l.items)
}

func main() {
    var c Container[int] = &List[int]{}
    c.Add(1)
    c.Add(2)
    fmt.Println(c.Size())
}

泛型方法

go
type Pair[T any] struct {
    First  T
    Second T
}

func (p *Pair[T]) Swap() {
    p.First, p.Second = p.Second, p.First
}

func main() {
    p := Pair[int]{First: 1, Second: 2}
    fmt.Println(p)
    p.Swap()
    fmt.Println(p)
}

类型推断

自动推断

go
func Print[T any](value T) {
    fmt.Println(value)
}

func main() {
    // 自动推断类型
    Print(42)
    Print("hello")
}

显式指定类型

go
func main() {
    // 显式指定类型
    Print[int](42)
    Print[string]("hello")
}

泛型应用

通用排序

go
func Sort[T comparable](slice []T, less func(a, b T) bool) {
    for i := 0; i < len(slice); i++ {
        for j := i + 1; j < len(slice); j++ {
            if less(slice[j], slice[i]) {
                slice[i], slice[j] = slice[j], slice[i]
            }
        }
    }
}

func main() {
    nums := []int{3, 1, 4, 1, 5}
    Sort(nums, func(a, b int) bool {
        return a < b
    })
    fmt.Println(nums)
    
    names := []string{"banana", "apple", "cherry"}
    Sort(names, func(a, b string) bool {
        return a < b
    })
    fmt.Println(names)
}

通用映射

go
func Map[T, U any](slice []T, fn func(T) U) []U {
    result := make([]U, len(slice))
    for i, item := range slice {
        result[i] = fn(item)
    }
    return result
}

func main() {
    nums := []int{1, 2, 3, 4, 5}
    
    // 平方
    squares := Map(nums, func(n int) int {
        return n * n
    })
    fmt.Println(squares)
    
    // 转字符串
    strings := Map(nums, func(n int) string {
        return fmt.Sprintf("%d", n)
    })
    fmt.Println(strings)
}

通用过滤

go
func Filter[T any](slice []T, predicate func(T) bool) []T {
    result := []T{}
    for _, item := range slice {
        if predicate(item) {
            result = append(result, item)
        }
    }
    return result
}

func main() {
    nums := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
    
    // 偶数
    evens := Filter(nums, func(n int) bool {
        return n%2 == 0
    })
    fmt.Println(evens)
    
    // 大于 5
    greater := Filter(nums, func(n int) bool {
        return n > 5
    })
    fmt.Println(greater)
}

通用归约

go
func Reduce[T, U any](slice []T, initial U, fn func(U, T) U) U {
    result := initial
    for _, item := range slice {
        result = fn(result, item)
    }
    return result
}

func main() {
    nums := []int{1, 2, 3, 4, 5}
    
    // 求和
    sum := Reduce(nums, 0, func(acc, n int) int {
        return acc + n
    })
    fmt.Println("求和:", sum)
    
    // 求积
    product := Reduce(nums, 1, func(acc, n int) int {
        return acc * n
    })
    fmt.Println("求积:", product)
}

泛型与接口

类型参数 vs 接口

go
// 使用泛型
func PrintGeneric[T any](value T) {
    fmt.Println(value)
}

// 使用接口
func PrintInterface(value interface{}) {
    fmt.Println(value)
}

func main() {
    PrintGeneric(42)
    PrintInterface(42)
}

何时使用泛型

go
// 使用泛型:类型安全
func Max[T comparable](a, b T) T {
    if a > b {
        return a
    }
    return b
}

// 使用接口:灵活性
func Print(value interface{}) {
    fmt.Println(value)
}

泛型最佳实践

1. 保持简单

go
// 推荐
func Max[T comparable](a, b T) T {
    if a > b {
        return a
    }
    return b
}

// 不推荐:过度复杂
func ComplexFunc[T any, U comparable, V interface{ String() string }](a T, b U, c V) (T, U, V) {
    return a, b, c
}

2. 合理命名

go
// 推荐
func Map[T, U any](slice []T, fn func(T) U) []U {}

// 不推荐
func Map[A, B any](slice []A, fn func(A) B) []B {}

3. 使用约束

go
// 推荐
func Sum[T Number](a, b T) T {
    return a + b
}

// 不推荐:过于宽松
func Sum[T any](a, b T) T {
    return a + b  // 编译错误
}

练习

练习 1:泛型栈

实现一个泛型栈,支持任意类型。

答案
go
package main

import "fmt"

type Stack[T any] struct {
    items []T
}

func NewStack[T any]() *Stack[T] {
    return &Stack[T]{items: []T{}}
}

func (s *Stack[T]) Push(item T) {
    s.items = append(s.items, item)
}

func (s *Stack[T]) Pop() (T, bool) {
    if len(s.items) == 0 {
        var zero T
        return zero, false
    }
    index := len(s.items) - 1
    item := s.items[index]
    s.items = s.items[:index]
    return item, true
}

func (s *Stack[T]) Peek() (T, bool) {
    if len(s.items) == 0 {
        var zero T
        return zero, false
    }
    return s.items[len(s.items)-1], true
}

func (s *Stack[T]) Size() int {
    return len(s.items)
}

func main() {
    stack := NewStack[int]()
    stack.Push(1)
    stack.Push(2)
    stack.Push(3)
    
    for stack.Size() > 0 {
        if item, ok := stack.Pop(); ok {
            fmt.Println(item)
        }
    }
}

总结

  • 泛型允许编写类型安全的通用代码
  • 使用类型参数和类型约束
  • 支持泛型函数、类型、方法
  • 类型推断简化使用
  • 合理使用,避免过度复杂
  • 保持简单和清晰

下一章:测试

基于 MIT 许可发布