178 lines
4.0 KiB
Go
178 lines
4.0 KiB
Go
package interrupt_handler
|
|
|
|
import (
|
|
"os"
|
|
"os/signal"
|
|
"sync"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/onsi/ginkgo/v2/internal/parallel_support"
|
|
)
|
|
|
|
// ABORT_POLLING_INTERVAL is how often the abort-polling goroutine asks the
// parallel_support client whether another process has requested an abort.
// It is a variable rather than a constant, presumably so it can be shortened
// (e.g. in tests) — NOTE(review): confirm before changing it to a const.
var ABORT_POLLING_INTERVAL = time.Millisecond * 500
|
|
|
|
// InterruptCause records why an interrupt was registered.
type InterruptCause uint

const (
	// InterruptCauseInvalid is the zero value: no interrupt has been recorded.
	InterruptCauseInvalid InterruptCause = iota
	// InterruptCauseSignal means one of the watched OS signals was received.
	InterruptCauseSignal
	// InterruptCauseAbortByOtherProcess means the parallel_support client
	// reported that another process requested an abort.
	InterruptCauseAbortByOtherProcess
)

// InterruptLevel is how far interrupt escalation has progressed. Each
// interrupt event advances the level by one step in the order declared
// below; InterruptLevelBailOut is terminal.
type InterruptLevel uint

const (
	// InterruptLevelUninterrupted is the zero value: nothing has happened yet.
	InterruptLevelUninterrupted InterruptLevel = iota
	// InterruptLevelCleanupAndReport is reached after the first interrupt.
	InterruptLevelCleanupAndReport
	// InterruptLevelReportOnly is reached after the second interrupt.
	InterruptLevelReportOnly
	// InterruptLevelBailOut is reached after the third interrupt and never left.
	InterruptLevelBailOut
)

// String returns a human-readable description of the cause. Both
// InterruptCauseInvalid and any unrecognised value yield
// "INVALID_INTERRUPT_CAUSE".
func (ic InterruptCause) String() string {
	description := "INVALID_INTERRUPT_CAUSE"
	switch ic {
	case InterruptCauseSignal:
		description = "Interrupted by User"
	case InterruptCauseAbortByOtherProcess:
		description = "Interrupted by Other Ginkgo Process"
	}
	return description
}
|
|
|
|
type InterruptStatus struct {
|
|
Channel chan interface{}
|
|
Level InterruptLevel
|
|
Cause InterruptCause
|
|
}
|
|
|
|
func (s InterruptStatus) Interrupted() bool {
|
|
return s.Level != InterruptLevelUninterrupted
|
|
}
|
|
|
|
func (s InterruptStatus) Message() string {
|
|
return s.Cause.String()
|
|
}
|
|
|
|
func (s InterruptStatus) ShouldIncludeProgressReport() bool {
|
|
return s.Cause != InterruptCauseAbortByOtherProcess
|
|
}
|
|
|
|
type InterruptHandlerInterface interface {
|
|
Status() InterruptStatus
|
|
}
|
|
|
|
type InterruptHandler struct {
|
|
c chan interface{}
|
|
lock *sync.Mutex
|
|
level InterruptLevel
|
|
cause InterruptCause
|
|
client parallel_support.Client
|
|
stop chan interface{}
|
|
signals []os.Signal
|
|
requestAbortCheck chan interface{}
|
|
}
|
|
|
|
func NewInterruptHandler(client parallel_support.Client, signals ...os.Signal) *InterruptHandler {
|
|
if len(signals) == 0 {
|
|
signals = []os.Signal{os.Interrupt, syscall.SIGTERM}
|
|
}
|
|
handler := &InterruptHandler{
|
|
c: make(chan interface{}),
|
|
lock: &sync.Mutex{},
|
|
stop: make(chan interface{}),
|
|
requestAbortCheck: make(chan interface{}),
|
|
client: client,
|
|
signals: signals,
|
|
}
|
|
handler.registerForInterrupts()
|
|
return handler
|
|
}
|
|
|
|
func (handler *InterruptHandler) Stop() {
|
|
close(handler.stop)
|
|
}
|
|
|
|
func (handler *InterruptHandler) registerForInterrupts() {
|
|
// os signal handling
|
|
signalChannel := make(chan os.Signal, 1)
|
|
signal.Notify(signalChannel, handler.signals...)
|
|
|
|
// cross-process abort handling
|
|
var abortChannel chan interface{}
|
|
if handler.client != nil {
|
|
abortChannel = make(chan interface{})
|
|
go func() {
|
|
pollTicker := time.NewTicker(ABORT_POLLING_INTERVAL)
|
|
for {
|
|
select {
|
|
case <-pollTicker.C:
|
|
if handler.client.ShouldAbort() {
|
|
close(abortChannel)
|
|
pollTicker.Stop()
|
|
return
|
|
}
|
|
case <-handler.requestAbortCheck:
|
|
if handler.client.ShouldAbort() {
|
|
close(abortChannel)
|
|
pollTicker.Stop()
|
|
return
|
|
}
|
|
case <-handler.stop:
|
|
pollTicker.Stop()
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
go func(abortChannel chan interface{}) {
|
|
var interruptCause InterruptCause
|
|
for {
|
|
select {
|
|
case <-signalChannel:
|
|
interruptCause = InterruptCauseSignal
|
|
case <-abortChannel:
|
|
interruptCause = InterruptCauseAbortByOtherProcess
|
|
case <-handler.stop:
|
|
signal.Stop(signalChannel)
|
|
return
|
|
}
|
|
abortChannel = nil
|
|
|
|
handler.lock.Lock()
|
|
oldLevel := handler.level
|
|
handler.cause = interruptCause
|
|
if handler.level == InterruptLevelUninterrupted {
|
|
handler.level = InterruptLevelCleanupAndReport
|
|
} else if handler.level == InterruptLevelCleanupAndReport {
|
|
handler.level = InterruptLevelReportOnly
|
|
} else if handler.level == InterruptLevelReportOnly {
|
|
handler.level = InterruptLevelBailOut
|
|
}
|
|
if handler.level != oldLevel {
|
|
close(handler.c)
|
|
handler.c = make(chan interface{})
|
|
}
|
|
handler.lock.Unlock()
|
|
}
|
|
}(abortChannel)
|
|
}
|
|
|
|
func (handler *InterruptHandler) Status() InterruptStatus {
|
|
handler.lock.Lock()
|
|
status := InterruptStatus{
|
|
Level: handler.level,
|
|
Channel: handler.c,
|
|
Cause: handler.cause,
|
|
}
|
|
handler.lock.Unlock()
|
|
|
|
if handler.client != nil && handler.client.ShouldAbort() && !status.Interrupted() {
|
|
close(handler.requestAbortCheck)
|
|
<-status.Channel
|
|
return handler.Status()
|
|
}
|
|
|
|
return status
|
|
}
|