diff --git a/internal/rtsp/rtsp.go b/internal/rtsp/rtsp.go index 31c2c5dbf..aa7ac1b9d 100644 --- a/internal/rtsp/rtsp.go +++ b/internal/rtsp/rtsp.go @@ -99,13 +99,7 @@ func rtspHandler(rawURL string) (core.Producer, error) { conn.Backchannel = true conn.UserAgent = app.UserAgent - if rawQuery != "" { - query := streams.ParseQuery(rawQuery) - conn.Backchannel = query.Get("backchannel") == "1" - conn.Media = query.Get("media") - conn.Timeout = core.Atoi(query.Get("timeout")) - conn.Transport = query.Get("transport") - } + applyClientQuery(conn, rawURL, rawQuery) if log.Trace().Enabled() { conn.Listen(func(msg any) { @@ -143,6 +137,27 @@ func rtspHandler(rawURL string) (core.Producer, error) { return conn, nil } +func applyClientQuery(conn *rtsp.Conn, rawURL, rawQuery string) { + query := url.Values{} + + if uri, err := url.Parse(rawURL); err == nil && uri != nil { + for key, values := range uri.Query() { + query[key] = append([]string(nil), values...) + } + } + + if extra := streams.ParseQuery(rawQuery); extra != nil { + for key, values := range extra { + query[key] = append([]string(nil), values...) + } + } + + conn.Backchannel = query.Get("backchannel") == "1" + conn.Media = query.Get("media") + conn.Timeout = core.Atoi(query.Get("timeout")) + conn.Transport = query.Get("transport") +} + func tcpHandler(conn *rtsp.Conn) { var name string var closer func() @@ -208,6 +223,8 @@ func tcpHandler(conn *rtsp.Conn) { conn.PacketSize = uint16(core.Atoi(s)) } + conn.Repack = defaultConsumerRepack(conn.Connection.RemoteAddr, query.Get("repack")) + // param name like ffmpeg style https://ffmpeg.org/ffmpeg-protocols.html if s := query.Get("log_level"); s != "" { if lvl, err := zerolog.ParseLevel(s); err == nil { @@ -287,6 +304,23 @@ func tcpHandler(conn *rtsp.Conn) { _ = conn.Close() } +func defaultConsumerRepack(remoteAddr, raw string) bool { + switch strings.ToLower(raw) { + case "": + return isLoopback(remoteAddr) + case "1", "true", "yes", "on": + return true + case "0", "false", "no", "off": + return false + default: + return isLoopback(remoteAddr) + } +} + +func isLoopback(remoteAddr string) bool { + return strings.HasPrefix(remoteAddr, "127.") || strings.HasPrefix(remoteAddr, "[::1]") || strings.HasPrefix(remoteAddr, "localhost:") +} + func ParseQuery(query map[string][]string) []*core.Media { if v := query["mp4"]; v != nil { return []*core.Media{ diff --git a/internal/rtsp/rtsp_test.go b/internal/rtsp/rtsp_test.go new file mode 100644 index 000000000..6b8dd4174 --- /dev/null +++ b/internal/rtsp/rtsp_test.go @@ -0,0 +1,55 @@ +package rtsp + +import ( + "testing" + + pkgrtsp "github.com/AlexxIT/go2rtc/pkg/rtsp" + "github.com/stretchr/testify/require" +) + +func TestApplyClientQueryUsesURLQuery(t *testing.T) { + conn := &pkgrtsp.Conn{} + + applyClientQuery(conn, "rtsp://127.0.0.1:8554/test?timeout=20&transport=tcp&media=video", "") + + require.Equal(t, 20, conn.Timeout) + require.Equal(t, "tcp", conn.Transport) + require.Equal(t, "video", conn.Media) +} + +func TestApplyClientQueryRawQueryOverridesURLQuery(t *testing.T) { + conn := &pkgrtsp.Conn{} + + applyClientQuery(conn, "rtsp://127.0.0.1:8554/test?timeout=20&transport=tcp", "timeout=45#transport=udp#backchannel=1") + + require.Equal(t, 45, conn.Timeout) + require.Equal(t, "udp", conn.Transport) + require.True(t, conn.Backchannel) +} + +func TestApplyClientQueryAllowsEmptyURL(t *testing.T) { + conn := &pkgrtsp.Conn{} + + require.NotPanics(t, func() { + applyClientQuery(conn, "", "") + }) + + require.False(t, conn.Backchannel) + require.Zero(t, conn.Timeout) + require.Empty(t, conn.Transport) + require.Empty(t, conn.Media) +} + +func TestDefaultConsumerRepackLoopbackDefaultsOn(t *testing.T) { + require.True(t, defaultConsumerRepack("127.0.0.1:8554", "")) + require.True(t, defaultConsumerRepack("[::1]:8554", "")) +} + +func TestDefaultConsumerRepackRemoteDefaultsOff(t *testing.T) { + require.False(t, defaultConsumerRepack("192.168.2.3:46980", "")) +} + +func TestDefaultConsumerRepackAllowsOverride(t *testing.T) { + require.False(t, defaultConsumerRepack("127.0.0.1:8554", "off")) + require.True(t, defaultConsumerRepack("192.168.2.3:46980", "on")) +} diff --git a/internal/streams/stream.go b/internal/streams/stream.go index 984c73edd..99426a411 100644 --- a/internal/streams/stream.go +++ b/internal/streams/stream.go @@ -118,6 +118,21 @@ producers: s.mu.Unlock() } +func (s *Stream) stopAll() { + s.mu.Lock() + consumers := append([]core.Consumer(nil), s.consumers...) + producers := append([]*Producer(nil), s.producers...) + s.consumers = nil + s.mu.Unlock() + + for _, consumer := range consumers { + _ = consumer.Stop() + } + for _, producer := range producers { + producer.stop() + } +} + func (s *Stream) MarshalJSON() ([]byte, error) { var info = struct { Producers []*Producer `json:"producers"` diff --git a/internal/streams/stream_test.go b/internal/streams/stream_test.go index bc4c18bb1..729b21e8d 100644 --- a/internal/streams/stream_test.go +++ b/internal/streams/stream_test.go @@ -9,16 +9,20 @@ import ( ) func TestRecursion(t *testing.T) { + streams = map[string]*Stream{} + HandleFunc("rtsp", func(url string) (core.Producer, error) { return nil, nil }) + HandleFunc("test", func(url string) (core.Producer, error) { return nil, nil }) + // create stream with some source - stream1, err := New("from_yaml", "does_not_matter") - require.NoError(t, err) + stream1, err := New("from_yaml", "test://does_not_matter") + require.Nil(t, err) require.Len(t, streams, 1) // ask another unnamed stream that links go2rtc query, err := url.ParseQuery("src=rtsp://localhost:8554/from_yaml?video") - require.NoError(t, err) + require.Nil(t, err) stream2, err := GetOrPatch(query) - require.NoError(t, err) + require.Nil(t, err) // check stream is same require.Equal(t, stream1, stream2) @@ -28,14 +32,17 @@ func TestRecursion(t *testing.T) { } func TestTempate(t *testing.T) { + streams = map[string]*Stream{} + HandleFunc("rtsp", func(url string) (core.Producer, error) { return nil, nil }) // bypass HasProducer + HandleFunc("ffmpeg", func(url string) (core.Producer, error) { return nil, nil }) // config from yaml stream1, err := New("camera.from_hass", "ffmpeg:{input}#video=copy") - require.NoError(t, err) + require.Nil(t, err) // request from hass stream2, err := Patch("camera.from_hass", "rtsp://example.com") - require.NoError(t, err) + require.Nil(t, err) require.Equal(t, stream1, stream2) require.Equal(t, "ffmpeg:rtsp://example.com#video=copy", stream1.producers[0].url) diff --git a/internal/streams/streams.go b/internal/streams/streams.go index f3b8df03c..a84dda8a9 100644 --- a/internal/streams/streams.go +++ b/internal/streams/streams.go @@ -174,3 +174,18 @@ func GetAllSources() map[string][]string { streamsMu.Unlock() return sources } + +func StopAll() { + streamsMu.Lock() + unique := make(map[*Stream]struct{}, len(streams)) + for _, stream := range streams { + if stream != nil { + unique[stream] = struct{}{} + } + } + streamsMu.Unlock() + + for stream := range unique { + stream.stopAll() + } +} diff --git a/main.go b/main.go index 00c059e3e..5ff85ec96 100644 --- a/main.go +++ b/main.go @@ -121,5 +121,7 @@ func main() { } } - shell.RunUntilSignal() + sig := shell.WaitSignal() + println("exit with signal:", sig.String()) + streams.StopAll() } diff --git a/pkg/core/media.go b/pkg/core/media.go index 367d8cb82..58bdc708a 100644 --- a/pkg/core/media.go +++ b/pkg/core/media.go @@ -177,33 +177,30 @@ func UnmarshalMedia(md *sdp.MediaDescription) *Media { } func ParseQuery(query map[string][]string) (medias []*Media) { - // set media candidates from query list - for key, values := range query { - switch key { - case KindVideo, KindAudio: - for _, value := range values { - media := &Media{Kind: key, Direction: DirectionSendonly} - - for _, name := range strings.Split(value, ",") { - name = strings.ToUpper(name) - - // check aliases - switch name { - case "", "COPY": - name = CodecAny - case "MJPEG": - name = CodecJPEG - case "AAC": - name = CodecAAC - case "MP3": - name = CodecMP3 - } - - media.Codecs = append(media.Codecs, &Codec{Name: name}) + for _, key := range []string{KindVideo, KindAudio} { + values := query[key] + for _, value := range values { + media := &Media{Kind: key, Direction: DirectionSendonly} + + for _, name := range strings.Split(value, ",") { + name = strings.ToUpper(name) + + // check aliases + switch name { + case "", "COPY": + name = CodecAny + case "MJPEG": + name = CodecJPEG + case "AAC": + name = CodecAAC + case "MP3": + name = CodecMP3 } - medias = append(medias, media) + media.Codecs = append(media.Codecs, &Codec{Name: name}) } + + medias = append(medias, media) } } diff --git a/pkg/core/media_test.go b/pkg/core/media_test.go index f2f05e634..da19dfd45 100644 --- a/pkg/core/media_test.go +++ b/pkg/core/media_test.go @@ -44,6 +44,17 @@ func TestParseQuery(t *testing.T) { } } +func TestParseQueryMediaOrderIsStable(t *testing.T) { + query := url.Values{ + "audio": {""}, + "video": {""}, + } + medias := ParseQuery(query) + require.Len(t, medias, 2) + require.Equal(t, KindVideo, medias[0].Kind) + require.Equal(t, KindAudio, medias[1].Kind) +} + func TestClone(t *testing.T) { media1 := &Media{ Kind: KindVideo, diff --git a/pkg/rtsp/conn.go b/pkg/rtsp/conn.go index 2984c781c..b7817f0d3 100644 --- a/pkg/rtsp/conn.go +++ b/pkg/rtsp/conn.go @@ -27,6 +27,7 @@ type Conn struct { Media string OnClose func() error PacketSize uint16 + Repack bool SessionName string Timeout int Transport string // custom transport support, ex. RTSP over WebSocket @@ -107,7 +108,7 @@ func (c *Conn) Handle() (err error) { if c.Timeout == 0 { // polling frames from remote RTSP Server (ex Camera) - timeout = time.Second * 5 + timeout = defaultActiveProducerTimeout(c.URL) if len(c.Receivers) == 0 || c.Transport == "udp" { // if we only send audio to camera @@ -121,7 +122,7 @@ func (c *Conn) Handle() (err error) { case core.ModePassiveProducer: // polling frames from remote RTSP Client (ex FFmpeg) if c.Timeout == 0 { - timeout = time.Second * 15 + timeout = defaultPassiveProducerTimeout(c.RemoteAddr) } else { timeout = time.Second * time.Duration(c.Timeout) } @@ -151,6 +152,30 @@ func (c *Conn) Handle() (err error) { return } +func defaultActiveProducerTimeout(uri *url.URL) time.Duration { + if uri != nil { + host := uri.Hostname() + switch host { + case "127.0.0.1", "::1", "localhost": + return 20 * time.Second + } + } + + return 5 * time.Second +} + +func defaultPassiveProducerTimeout(remoteAddr string) time.Duration { + host, _, err := net.SplitHostPort(remoteAddr) + if err == nil { + switch host { + case "127.0.0.1", "::1", "localhost": + return 60 * time.Second + } + } + + return 15 * time.Second +} + func (c *Conn) handleKeepalive(ctx context.Context, d time.Duration) { ticker := time.NewTicker(d) for { diff --git a/pkg/rtsp/conn_test.go b/pkg/rtsp/conn_test.go new file mode 100644 index 000000000..b255fd45c --- /dev/null +++ b/pkg/rtsp/conn_test.go @@ -0,0 +1,23 @@ +package rtsp + +import ( + "net/url" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestDefaultActiveProducerTimeout(t *testing.T) { + require.Equal(t, 20*time.Second, defaultActiveProducerTimeout(&url.URL{Host: "127.0.0.1:8554"})) + require.Equal(t, 20*time.Second, defaultActiveProducerTimeout(&url.URL{Host: "localhost:8554"})) + require.Equal(t, 5*time.Second, defaultActiveProducerTimeout(&url.URL{Host: "192.168.2.238:554"})) + require.Equal(t, 5*time.Second, defaultActiveProducerTimeout(nil)) +} + +func TestDefaultPassiveProducerTimeout(t *testing.T) { + require.Equal(t, 60*time.Second, defaultPassiveProducerTimeout("127.0.0.1:8554")) + require.Equal(t, 60*time.Second, defaultPassiveProducerTimeout("[::1]:8554")) + require.Equal(t, 15*time.Second, defaultPassiveProducerTimeout("192.168.2.238:554")) + require.Equal(t, 15*time.Second, defaultPassiveProducerTimeout("not-an-addr")) +} diff --git a/pkg/rtsp/consumer.go b/pkg/rtsp/consumer.go index e6525d963..7c5bc9f6b 100644 --- a/pkg/rtsp/consumer.go +++ b/pkg/rtsp/consumer.go @@ -148,6 +148,12 @@ func (c *Conn) packetWriter(codec *core.Codec, channel, payloadType uint8) core. flushBuf() } + handlerFunc = c.wrapPacketHandler(codec, handlerFunc) + + return handlerFunc +} + +func (c *Conn) wrapPacketHandler(codec *core.Codec, handlerFunc core.HandlerFunc) core.HandlerFunc { if !codec.IsRTP() { switch codec.Name { case core.CodecH264: @@ -159,22 +165,67 @@ func (c *Conn) packetWriter(codec *core.Codec, channel, payloadType uint8) core. case core.CodecJPEG: handlerFunc = mjpeg.RTPPay(handlerFunc) } - } else if codec.Name == core.CodecPCML { + return handlerFunc + } + + if codec.Name == core.CodecPCML { handlerFunc = pcm.LittleToBig(handlerFunc) - } else if c.PacketSize != 0 { + return handlerFunc + } + + if c.Repack || c.PacketSize != 0 { switch codec.Name { case core.CodecH264: handlerFunc = h264.RTPPay(c.PacketSize, handlerFunc) + if c.Repack { + handlerFunc = waitH264Keyframe(handlerFunc) + } handlerFunc = h264.RTPDepay(codec, handlerFunc) case core.CodecH265: handlerFunc = h265.RTPPay(c.PacketSize, handlerFunc) + if c.Repack { + handlerFunc = waitH265Keyframe(handlerFunc) + } handlerFunc = h265.RTPDepay(codec, handlerFunc) + case core.CodecAAC: + handlerFunc = aac.RTPPay(handlerFunc) + handlerFunc = aac.RTPDepay(handlerFunc) } } return handlerFunc } +func waitH264Keyframe(handlerFunc core.HandlerFunc) core.HandlerFunc { + var synced bool + + return func(packet *rtp.Packet) { + if !synced { + if !h264.IsKeyframe(packet.Payload) { + return + } + synced = true + } + + handlerFunc(packet) + } +} + +func waitH265Keyframe(handlerFunc core.HandlerFunc) core.HandlerFunc { + var synced bool + + return func(packet *rtp.Packet) { + if !synced { + if !h265.IsKeyframe(packet.Payload) { + return + } + synced = true + } + + handlerFunc(packet) + } +} + func (c *Conn) writeInterleavedData(data []byte) error { if c.Transport != "udp" { _ = c.conn.SetWriteDeadline(time.Now().Add(Timeout)) diff --git a/pkg/rtsp/consumer_test.go b/pkg/rtsp/consumer_test.go new file mode 100644 index 000000000..f6d550c3b --- /dev/null +++ b/pkg/rtsp/consumer_test.go @@ -0,0 +1,126 @@ +package rtsp + +import ( + "testing" + + "github.com/AlexxIT/go2rtc/pkg/core" + "github.com/pion/rtp" + "github.com/stretchr/testify/require" +) + +func TestConnWrapPacketHandlerAACRepack(t *testing.T) { + codec := &core.Codec{ + Name: core.CodecAAC, + ClockRate: 16000, + PayloadType: 97, + FmtpLine: "streamtype=5;profile-level-id=1;mode=AAC-hbr;sizelength=13;indexlength=3;indexdeltalength=3;config=1408", + } + + var packets []*rtp.Packet + conn := &Conn{Repack: true} + handler := conn.wrapPacketHandler(codec, func(packet *rtp.Packet) { + clone := *packet + clone.Payload = append([]byte(nil), packet.Payload...) + packets = append(packets, &clone) + }) + + handler(&rtp.Packet{ + Header: rtp.Header{ + Version: 2, + SequenceNumber: 40000, + Timestamp: 123456, + SSRC: 77, + Marker: true, + }, + Payload: []byte{0x00, 0x10, 0x00, 0x40, 1, 2, 3, 4, 5, 6, 7, 8}, + }) + handler(&rtp.Packet{ + Header: rtp.Header{ + Version: 2, + SequenceNumber: 17, + Timestamp: 42, + SSRC: 88, + Marker: true, + }, + Payload: []byte{0x00, 0x10, 0x00, 0x40, 9, 10, 11, 12, 13, 14, 15, 16}, + }) + + require.Len(t, packets, 2) + require.Equal(t, uint16(0), packets[0].SequenceNumber) + require.Equal(t, uint16(1), packets[1].SequenceNumber) + require.Equal(t, uint32(0), packets[0].Timestamp) + require.Greater(t, packets[1].Timestamp, packets[0].Timestamp) + require.Zero(t, packets[0].SSRC) + require.Zero(t, packets[1].SSRC) +} + +func TestConnWrapPacketHandlerAACPassthrough(t *testing.T) { + codec := &core.Codec{ + Name: core.CodecAAC, + ClockRate: 16000, + PayloadType: 97, + } + + var packets []*rtp.Packet + conn := &Conn{} + handler := conn.wrapPacketHandler(codec, func(packet *rtp.Packet) { + clone := *packet + packets = append(packets, &clone) + }) + + handler(&rtp.Packet{ + Header: rtp.Header{ + Version: 2, + SequenceNumber: 1234, + Timestamp: 5678, + SSRC: 90, + Marker: true, + }, + Payload: []byte{1, 2, 3}, + }) + + require.Len(t, packets, 1) + require.Equal(t, uint16(1234), packets[0].SequenceNumber) + require.Equal(t, uint32(5678), packets[0].Timestamp) + require.Equal(t, uint32(90), packets[0].SSRC) +} + +func TestConnWrapPacketHandlerH264RepackWaitsForKeyframe(t *testing.T) { + codec := &core.Codec{ + Name: core.CodecH264, + ClockRate: 90000, + PayloadType: 96, + } + + var packets []*rtp.Packet + conn := &Conn{Repack: true} + handler := conn.wrapPacketHandler(codec, func(packet *rtp.Packet) { + clone := *packet + clone.Payload = append([]byte(nil), packet.Payload...) + packets = append(packets, &clone) + }) + + handler(&rtp.Packet{ + Header: rtp.Header{ + Version: 2, + SequenceNumber: 10, + Timestamp: 1000, + Marker: true, + }, + Payload: []byte{0x41, 0x9a}, + }) + + require.Len(t, packets, 0) + + handler(&rtp.Packet{ + Header: rtp.Header{ + Version: 2, + SequenceNumber: 11, + Timestamp: 4000, + Marker: true, + }, + Payload: []byte{0x65, 0x88}, + }) + + require.NotEmpty(t, packets) +} diff --git a/pkg/shell/command.go b/pkg/shell/command.go index b7c818995..f6b7a7d1c 100644 --- a/pkg/shell/command.go +++ b/pkg/shell/command.go @@ -2,7 +2,10 @@ package shell import ( "context" + "errors" "os/exec" + "sync" + "time" ) // Command like exec.Cmd, but with support: @@ -13,6 +16,8 @@ type Command struct { *exec.Cmd ctx context.Context cancel context.CancelFunc + done chan struct{} + mu sync.Mutex err error } @@ -21,7 +26,15 @@ func NewCommand(s string) *Command { args := QuoteSplit(s) cmd := exec.CommandContext(ctx, args[0], args[1:]...) cmd.SysProcAttr = procAttr - return &Command{cmd, ctx, cancel, nil} + cmd.Cancel = func() error { + return terminateCommand(cmd) + } + return &Command{ + Cmd: cmd, + ctx: ctx, + cancel: cancel, + done: make(chan struct{}), + } } func (c *Command) Start() error { @@ -30,7 +43,11 @@ func (c *Command) Start() error { } go func() { - c.err = c.Cmd.Wait() + err := c.Cmd.Wait() + c.mu.Lock() + c.err = err + c.mu.Unlock() + close(c.done) c.cancel() // release context resources }() @@ -38,7 +55,9 @@ func (c *Command) Start() error { } func (c *Command) Wait() error { - <-c.ctx.Done() + <-c.done + c.mu.Lock() + defer c.mu.Unlock() return c.err } @@ -50,10 +69,31 @@ func (c *Command) Run() error { } func (c *Command) Done() <-chan struct{} { - return c.ctx.Done() + return c.done } func (c *Command) Close() error { c.cancel() - return nil + + select { + case <-c.done: + return c.Wait() + case <-time.After(5 * time.Second): + } + + _ = killCommand(c.Cmd) + + select { + case <-c.done: + case <-time.After(time.Second): + c.mu.Lock() + err := c.err + c.mu.Unlock() + if err != nil { + return err + } + return errors.New("shell: close timeout") + } + + return c.Wait() } diff --git a/pkg/shell/command_test.go b/pkg/shell/command_test.go new file mode 100644 index 000000000..2fbd07e98 --- /dev/null +++ b/pkg/shell/command_test.go @@ -0,0 +1,24 @@ +package shell + +import ( + "runtime" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestCommandCloseWaitsForExit(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("shell command test uses POSIX sh") + } + + cmd := NewCommand(`sh -c "sleep 30"`) + require.NoError(t, cmd.Start()) + + start := time.Now() + err := cmd.Close() + require.Less(t, time.Since(start), 7*time.Second) + require.Error(t, err) + require.NotNil(t, cmd.ProcessState) +} diff --git a/pkg/shell/procattr.go b/pkg/shell/procattr.go index fffdc2a40..d285b8ae8 100644 --- a/pkg/shell/procattr.go +++ b/pkg/shell/procattr.go @@ -2,6 +2,35 @@ package shell -import "syscall" +import ( + "errors" + "os" + "os/exec" + "syscall" +) var procAttr *syscall.SysProcAttr + +func terminateCommand(cmd *exec.Cmd) error { + if cmd == nil || cmd.Process == nil { + return nil + } + + if err := cmd.Process.Signal(os.Interrupt); err != nil && !errors.Is(err, os.ErrProcessDone) { + return err + } + + return nil +} + +func killCommand(cmd *exec.Cmd) error { + if cmd == nil || cmd.Process == nil { + return nil + } + + if err := cmd.Process.Kill(); err != nil && !errors.Is(err, os.ErrProcessDone) { + return err + } + + return nil +} diff --git a/pkg/shell/procattr_linux.go b/pkg/shell/procattr_linux.go index cef1d1529..28c1dcc2b 100644 --- a/pkg/shell/procattr_linux.go +++ b/pkg/shell/procattr_linux.go @@ -1,6 +1,49 @@ package shell -import "syscall" +import ( + "errors" + "os/exec" + "syscall" +) // will stop child if parent died (even with SIGKILL) -var procAttr = &syscall.SysProcAttr{Pdeathsig: syscall.SIGTERM} +var procAttr = &syscall.SysProcAttr{ + Pdeathsig: syscall.SIGTERM, + Setpgid: true, +} + +func terminateCommand(cmd *exec.Cmd) error { + if cmd == nil || cmd.Process == nil { + return nil + } + + if pgid, err := syscall.Getpgid(cmd.Process.Pid); err == nil { + if err = syscall.Kill(-pgid, syscall.SIGTERM); err == nil || errors.Is(err, syscall.ESRCH) { + return nil + } + } + + if err := cmd.Process.Signal(syscall.SIGTERM); err != nil && !errors.Is(err, syscall.ESRCH) { + return err + } + + return nil +} + +func killCommand(cmd *exec.Cmd) error { + if cmd == nil || cmd.Process == nil { + return nil + } + + if pgid, err := syscall.Getpgid(cmd.Process.Pid); err == nil { + if err = syscall.Kill(-pgid, syscall.SIGKILL); err == nil || errors.Is(err, syscall.ESRCH) { + return nil + } + } + + if err := cmd.Process.Kill(); err != nil && !errors.Is(err, syscall.ESRCH) { + return err + } + + return nil +} diff --git a/pkg/shell/shell.go b/pkg/shell/shell.go index e04a58c49..df7d43b15 100644 --- a/pkg/shell/shell.go +++ b/pkg/shell/shell.go @@ -37,7 +37,11 @@ func QuoteSplit(s string) []string { } func RunUntilSignal() { + println("exit with signal:", WaitSignal().String()) +} + +func WaitSignal() os.Signal { sigs := make(chan os.Signal, 1) signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) - println("exit with signal:", (<-sigs).String()) + return <-sigs }