Skip to content

Instantly share code, notes, and snippets.

@zachmu
Last active January 14, 2026 11:43
Show Gist options
  • Select an option

  • Save zachmu/99e1d4c7701265004cbb898ff395de1f to your computer and use it in GitHub Desktop.

Select an option

Save zachmu/99e1d4c7701265004cbb898ff395de1f to your computer and use it in GitHub Desktop.
package main
import (
"log"
"sort"
)
// Sortable is the plain interface form of the ordering contract. It has no
// type-set element, so it may be used as the type of variables, fields, and
// parameters, not only as a type-parameter constraint.
type Sortable[T comparable] interface {
	// Less reports whether the receiver orders strictly before member.
	Less(member T) bool
}

// SortableConstraint combines Sortable with the built-in comparable
// constraint. Embedding comparable turns it into a type set, so it is legal
// only as a type-parameter constraint, never as the type of a variable.
type SortableConstraint[T comparable] interface {
	comparable
	Sortable[T]
}

// Name is a person's name, ordered by first name and then by last name.
type Name struct {
	First string
	Last  string
}

// Compile-time assertion that Name satisfies Sortable[Name].
var _ Sortable[Name] = Name{}

// Less orders names lexicographically: by First, and by Last only when the
// first names are equal.
func (n Name) Less(member Name) bool {
	if n.First != member.First {
		return n.First < member.First
	}
	return n.Last < member.Last
}
// SortableSet is a collection of distinct T values that can be listed in
// sorted order.
type SortableSet[T SortableConstraint[T]] interface {
	// Add inserts member; adding a value that is already present does not
	// create a second entry.
	Add(member T)
	// Contains reports whether member is in the set.
	Contains(member T) bool
	// Size returns the number of distinct members.
	Size() int
	// Sorted returns the members ordered by T's Less method.
	Sorted() []T
}
// MapSet is a SortableSet backed by a Go map with empty-struct values. Its
// zero value holds a nil map, and writing to a nil map panics, so construct
// instances with NewMapSet.
type MapSet[T SortableConstraint[T]] struct {
	members map[T]struct{}
}

// NewMapSet returns an empty map-backed set.
func NewMapSet[T SortableConstraint[T]]() SortableSet[T] {
	var set MapSet[T]
	set.members = make(map[T]struct{})
	return set
}

// Add inserts member. A value receiver is enough here: maps are reference
// types, so writes made through a copy of MapSet reach the same map.
func (m MapSet[T]) Add(member T) {
	m.members[member] = struct{}{}
}

// Size returns the number of distinct members.
func (m MapSet[T]) Size() int {
	return len(m.members)
}

// Contains reports whether member is in the set.
func (m MapSet[T]) Contains(member T) bool {
	_, ok := m.members[member]
	return ok
}

// Sorted returns a newly allocated slice of the members ordered by Less.
// Map iteration order is unspecified, so the slice must be sorted.
func (m MapSet[T]) Sorted() []T {
	out := make([]T, len(m.members))
	i := 0
	for member := range m.members {
		out[i] = member
		i++
	}
	less := func(a, b int) bool { return out[a].Less(out[b]) }
	sort.Slice(out, less)
	return out
}
// SliceSet is a SortableSet backed by a slice kept in insertion order.
// Membership is checked by a linear scan, so it is best suited to small sets.
//
// Every method uses a pointer receiver. Add must mutate the slice header, and
// keeping the receivers uniform gives *SliceSet one consistent method set.
type SliceSet[T SortableConstraint[T]] struct {
	members []T
}

// NewSliceSet returns an empty slice-backed set. Its constraint is spelled
// SortableConstraint[T], matching the SliceSet type parameter.
func NewSliceSet[T SortableConstraint[T]]() SortableSet[T] {
	return &SliceSet[T]{
		members: make([]T, 0),
	}
}

// Add appends member unless an equal (==) value is already stored.
func (s *SliceSet[T]) Add(member T) {
	if !s.Contains(member) {
		s.members = append(s.members, member)
	}
}

// Size returns the number of distinct members.
func (s *SliceSet[T]) Size() int {
	return len(s.members)
}

// Contains reports whether member is stored, comparing elements with ==.
func (s *SliceSet[T]) Contains(member T) bool {
	for _, m := range s.members {
		if m == member {
			return true
		}
	}
	return false
}

// Sorted returns a copy of the members ordered by Less. The set's own
// insertion order is left unchanged.
func (s *SliceSet[T]) Sorted() []T {
	sorted := make([]T, len(s.members))
	copy(sorted, s.members)
	sort.Slice(sorted, func(i, j int) bool {
		return sorted[i].Less(sorted[j])
	})
	return sorted
}
// main runs the same scenario against both SortableSet implementations and
// calls log.Fatal on the first check that fails.
func main() {
	name1 := Name{"John", "Doe"}
	name2 := Name{"Jane", "Doe"}
	name3 := Name{"Frank", "Reynolds"}

	// "Jane" < "John", so name2 must come first.
	expectedSorted := []Name{name2, name1}

	sets := []SortableSet[Name]{
		NewSliceSet[Name](),
		NewMapSet[Name](),
	}
	for _, set := range sets {
		// name2 is inserted twice; a set keeps only one copy.
		for _, n := range []Name{name1, name2, name2} {
			set.Add(n)
		}

		switch {
		case set.Size() != 2:
			log.Fatal("set size is not 2")
		case !set.Contains(name1):
			log.Fatal("set does not contain name1")
		case !set.Contains(name2):
			log.Fatal("set does not contain name2")
		case set.Contains(name3):
			log.Fatal("set contains name3")
		}

		got := set.Sorted()
		if len(got) != len(expectedSorted) {
			log.Fatal("sorted length does not match")
		}
		for i, n := range got {
			if n != expectedSorted[i] {
				log.Fatal("sorted does not match")
			}
		}

		log.Printf("%T passed", set)
	}
}
@iamdlfl
Copy link

iamdlfl commented Jul 10, 2024

This is a very handy resource. Thanks for figuring this out! I think there's a typo on line 129: should that be s.Size() != 3?

@zachmu
Copy link
Author

zachmu commented Jul 10, 2024

I think there's a typo on line 129, should that be s.Size() != 3?

Nope, it's a Set, which means the duplicate element (name2) is stored only once.

@iamdlfl
Copy link

iamdlfl commented Jul 10, 2024

Ah, I see: name2 is added twice, and name3 is never added. I would have noticed that if I had read line 140 (or line 127) more carefully. Whoops!

@jub0bs
Copy link

jub0bs commented Jan 14, 2026

@zachmu I'm currently working on a course about generics and iterators (whose material will be open source and available on GitHub) and reviewing posts on the topic; that's how I stumbled upon your post on DoltHub's blog.

I don't know whether this is still relevant to DoltHub, but you'd likely get better performance by eschewing the sort package and using the slices package instead (playground):

package main

import (
	"cmp"
	"log"
	"slices"
)

// Sortable is the plain interface form of the ordering contract. It has no
// type-set element, so it may be used as the type of a variable.
type Sortable[T comparable] interface {
	// Compare returns a negative number, zero, or a positive number when the
	// receiver orders before, equal to, or after member, in the style of
	// cmp.Compare.
	Compare(member T) int
}

// SortableConstraint adds the built-in comparable constraint to Sortable.
// The embedded comparable makes it a type set, usable only as a constraint.
type SortableConstraint[T comparable] interface {
	comparable
	Sortable[T]
}

// Name is a person's name, ordered by First and then by Last.
type Name struct {
	First string
	Last  string
}

// Compile-time assertion that Name satisfies Sortable[Name].
var _ Sortable[Name] = Name{}

// Compare orders names lexicographically by First, breaking ties on Last.
func (n Name) Compare(member Name) int {
	if byFirst := cmp.Compare(n.First, member.First); byFirst != 0 {
		return byFirst
	}
	return cmp.Compare(n.Last, member.Last)
}

// SortableSet is a collection of distinct T values that can be listed in
// sorted order.
type SortableSet[T SortableConstraint[T]] interface {
	// Add inserts member; adding a value that is already present does not
	// create a second entry.
	Add(member T)
	// Contains reports whether member is in the set.
	Contains(member T) bool
	// Size returns the number of distinct members.
	Size() int
	// Sorted returns the members in ascending order according to T's
	// Compare method.
	Sorted() []T
}

// MapSet is a SortableSet backed by a Go map with empty-struct values. Its
// zero value holds a nil map, and writing to a nil map panics, so construct
// instances with NewMapSet.
type MapSet[T SortableConstraint[T]] struct {
	members map[T]struct{}
}

// NewMapSet returns an empty map-backed set.
func NewMapSet[T SortableConstraint[T]]() SortableSet[T] {
	var set MapSet[T]
	set.members = make(map[T]struct{})
	return set
}

// Add inserts member. The value receiver still mutates the shared map,
// because maps are reference types.
func (m MapSet[T]) Add(member T) {
	m.members[member] = struct{}{}
}

// Size returns the number of distinct members.
func (m MapSet[T]) Size() int {
	return len(m.members)
}

// Contains reports whether member is present.
func (m MapSet[T]) Contains(member T) bool {
	_, ok := m.members[member]
	return ok
}

// Sorted copies the members into a new slice and orders them by Compare.
// Sorting is required because map iteration order is unspecified.
func (m MapSet[T]) Sorted() []T {
	out := make([]T, len(m.members))
	i := 0
	for member := range m.members {
		out[i] = member
		i++
	}
	slices.SortFunc(out, func(a, b T) int {
		return a.Compare(b)
	})
	return out
}

// SliceSet is a SortableSet backed by a slice kept in insertion order.
// Membership is checked by a linear scan, so it is best suited to small sets.
//
// Every method uses a pointer receiver. Add must mutate the slice header, and
// keeping the receivers uniform gives *SliceSet one consistent method set.
type SliceSet[T SortableConstraint[T]] struct {
	members []T
}

// NewSliceSet returns an empty slice-backed set. Its constraint is spelled
// SortableConstraint[T], matching the SliceSet type parameter.
func NewSliceSet[T SortableConstraint[T]]() SortableSet[T] {
	return &SliceSet[T]{
		members: make([]T, 0),
	}
}

// Add appends member unless an equal (==) value is already stored.
func (s *SliceSet[T]) Add(member T) {
	if !s.Contains(member) {
		s.members = append(s.members, member)
	}
}

// Size returns the number of distinct members.
func (s *SliceSet[_]) Size() int {
	return len(s.members)
}

// Contains reports whether member is stored, comparing elements with ==.
func (s *SliceSet[T]) Contains(member T) bool {
	return slices.Contains(s.members, member)
}

// Sorted returns a copy of the members ordered by Compare. The set's own
// insertion order is left unchanged.
func (s *SliceSet[T]) Sorted() []T {
	sorted := make([]T, len(s.members))
	copy(sorted, s.members)

	slices.SortFunc(sorted, T.Compare)

	return sorted
}

// main runs the same scenario against both SortableSet implementations and
// calls log.Fatal on the first check that fails.
func main() {
	name1 := Name{"John", "Doe"}
	name2 := Name{"Jane", "Doe"}
	name3 := Name{"Frank", "Reynolds"}

	// "Jane" < "John", so name2 must come first.
	expectedSorted := []Name{name2, name1}

	for _, s := range []SortableSet[Name]{NewSliceSet[Name](), NewMapSet[Name]()} {
		// Adding name2 twice must not create a second entry.
		s.Add(name1)
		s.Add(name2)
		s.Add(name2)

		if s.Size() != 2 {
			log.Fatal("set size is not 2")
		}

		membership := []struct {
			member Name
			want   bool
			msg    string
		}{
			{name1, true, "set does not contain name1"},
			{name2, true, "set does not contain name2"},
			{name3, false, "set contains name3"},
		}
		for _, m := range membership {
			if s.Contains(m.member) != m.want {
				log.Fatal(m.msg)
			}
		}

		sorted := s.Sorted()
		if len(sorted) != len(expectedSorted) {
			log.Fatal("sorted length does not match")
		}
		if !slices.Equal(sorted, expectedSorted) {
			log.Fatal("sorted does not match")
		}

		log.Printf("%T passed", s)
	}
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment