#!/usr/bin/env python

import argparse
from collections import namedtuple
from datetime import datetime
import logging
import os
import subprocess
import sys
import time


# Command-line interface. The description doubles as the manual for every
# output column, so it is rendered verbatim (RawTextHelpFormatter) instead of
# being re-wrapped by argparse.
parser = argparse.ArgumentParser(
    "tcpmonitor",
    description="""
Displays the TCP connection speed in Bps, and factors that may affect performance.
If an output field is red, it means that field may impact the tcp performance.

e.g.
show all tcp connections that the source port is 22:
tcpmonitor -s 22
show the tcp connection that the source port is 22 and the dst port is 55257
and show the result in every 2 seconds:
tcpmonitor -s 22 -d 55257 -i 2
output:
```
           datetime       src_ip:src_port       dst_ip:dst_port      outB/s       inB/s    SQ    cwnd_Bps snd_wnd_Bps
2026-02-06 10:46:40  11.164.240.247:55590     100.67.155.244:80        0.00        0.00 False     5213.76 None
2026-02-06 10:46:41  11.164.240.247:55590     100.67.155.244:80      677.00        0.00 False     5224.66 None
2026-02-06 10:46:42  11.164.240.247:55590     100.67.155.244:80      374.00        0.00 False     5274.26 None
```

datetime
the timestamp

src_ip:src_port 
the source ip address and the source port

dst_ip:dst_port
the dst ip address and dst port

outB/s
outbound bytes per second

inB/s
inbound bytes per second

SQ
True indicates that traffic is being throttled by the TCP small queue check. For
virtio_net NIC, you may set `virtio_net.napi_tx=0` to workaround this limit.

cwnd_Bps
The maximum bandwidth that the current congestion window can support. It is obtained
by dividing cwnd by rtt. If this value is too small, it indicates that the performance 
of the TCP connection is limited by the size of the congestion window. This situation 
may be due to a bottleneck in the intermediate link's bandwidth, or it may be that the 
TCP congestion control algorithm has not adjusted the window to its maximum size. Consider 
testing the bandwidth of the intermediate link or adjusting the relevant parameters of 
the congestion control window.

snd_wnd_Bps
The maximum bandwidth that the current TCP receiver window can support. It is calculated 
by dividing the send window (snd_wnd) by the round-trip time (RTT). If this value is too 
small, it indicates that the other end of the TCP connection does not have enough space to 
receive data, and you need to check the settings on the other end of the TCP connection to 
resolve the issue.
""",
    formatter_class=argparse.RawTextHelpFormatter,
)

# Optional port filters; when omitted the value stays None.
parser.add_argument(
    "-s",
    "--sport",
    type=int,
    help="tcp source port number",
)

parser.add_argument(
    "-d",
    "--dport",
    type=int,
    help="tcp dest port number",
)

# Sampling cadence and duration.
parser.add_argument(
    "-i",
    "--interval",
    type=int,
    help="delay in seconds between screen updates",
    default=1,
)

parser.add_argument(
    "-c",
    "--count",
    type=int,
    help="Specifies the maximum number of iterations, <=0 means run forever",
    default=0,
)

parser.add_argument(
    "-v",
    "--verbose",
    action="store_true",
    help="show verbose messages",
    default=False,
)

# Values of `enum sk_pacing` in the kernel's include/net/sock.h
SK_PACING_NONE, SK_PACING_NEEDED, SK_PACING_FQ = range(3)


class GlobalInfo:
    """Settings applied to every connection when a sample is evaluated.

    Args:
        interval: seconds between two consecutive samples.
        sysctl_tcp_limit_output_bytes: value of the tcp_limit_output_bytes
            sysctl, used by the small-queue check.
    """

    def __init__(
        self,
        interval,
        sysctl_tcp_limit_output_bytes,
    ):
        # The fields in this first group are not really global information:
        # they are either per connection or per packet. We can't get them
        # easily, so each one is given a reasonable default value.
        self.small_queue_check_factor = 1
        self.skb_truesize = 0
        self.tcp_tx_delay = 0
        self.sk_pacing_shift = 10  # 10 is the kernel default value
        self.sk_pacing_status = SK_PACING_NONE

        # Genuinely system-wide values.
        self.tcp_tx_delay_enabled = False
        self.sysctl_tcp_limit_output_bytes = sysctl_tcp_limit_output_bytes
        self.interval = interval


class TcpInfo:
    """Per-connection values parsed from one line of ss output.

    Every field starts as None and stays None until a value is assigned.
    """

    def __init__(self):
        # Cumulative byte counters.
        self.bytes_sent = self.bytes_received = None
        # Inputs of the small-queue check.
        self.wmem_alloc = self.pacing_rate = None
        # Window scale factors (send, receive).
        self.snd_wscale = self.rcv_wscale = None
        # Round-trip time and its variance.
        self.rtt = self.rtt_var = None
        # Congestion window and send window.
        self.cwnd = self.snd_wnd = None

    def __str__(self):
        # Only these four fields are part of the textual form; each entry is
        # rendered as " name=value", so the result starts with a space.
        shown = ("bytes_sent", "bytes_received", "wmem_alloc", "pacing_rate")
        return "".join(
            " %s=%s" % (name, getattr(self, name)) for name in shown
        )

    def __repr__(self):
        # Same text as __str__, so containers of TcpInfo print readably.
        return self.__str__()


TCP_LIMIT_OUTPUT_BYTES_PATH = "/proc/sys/net/ipv4/tcp_limit_output_bytes"


def get_tcp_limit_output_bytes():
    """Return the tcp_limit_output_bytes sysctl as an int.

    When the sysctl file does not exist, a huge value is returned instead so
    that the limit never becomes the deciding factor.
    """
    if not os.path.exists(TCP_LIMIT_OUTPUT_BYTES_PATH):
        # use a huge value so it won't limit anything
        return 2**63 - 1

    with open(TCP_LIMIT_OUTPUT_BYTES_PATH) as sysctl_file:
        raw_value = sysctl_file.read().strip()
    logging.debug("tcp_limit_output_bytes: %s", raw_value)
    return int(raw_value)



# Decimal multipliers for the unit letter ss may put in front of "bps".
_SS_RATE_MULTIPLIERS = {
    "K": 10 ** 3,
    "M": 10 ** 6,
    "G": 10 ** 9,
    "T": 10 ** 12,
}


def _parse_ss_rate(token):
    """Convert an ss rate token (``1234bps``, ``1.5Mbps``) to int bits/s.

    Returns None when the token is not a recognizable rate.
    """
    if not token.endswith("bps"):
        return None
    number = token[:-3]
    multiplier = 1
    if number and number[-1].upper() in _SS_RATE_MULTIPLIERS:
        multiplier = _SS_RATE_MULTIPLIERS[number[-1].upper()]
        number = number[:-1]
    try:
        # Plain integers take this path, which keeps large values exact.
        return int(number) * multiplier
    except ValueError:
        pass
    try:
        return int(float(number) * multiplier)
    except (ValueError, OverflowError):
        return None


def _parse_tcp_info(items):
    """Build a TcpInfo from the tokens that follow the two address columns."""
    tcp_info = TcpInfo()
    next_is_pacing_rate = False
    for item in items:
        if next_is_pacing_rate:
            # The rate is printed as a separate token right after the
            # "pacing_rate" keyword, so it must be consumed before any
            # prefix matching is attempted.
            next_is_pacing_rate = False
            rate = _parse_ss_rate(item)
            if rate is None:
                logging.warning("Unknown pacing rate format: %s", item)
            else:
                tcp_info.pacing_rate = rate
        elif item.startswith("bytes_sent:"):
            tcp_info.bytes_sent = int(item[len("bytes_sent:"):])
        elif item.startswith("bytes_received:"):
            # Slice (not index) so the whole number is read, not just its
            # first digit.
            tcp_info.bytes_received = int(item[len("bytes_received:"):])
        elif item.startswith("skmem"):
            # skmem:(r0,rb131072,t0,tb87040,...): the entry made of "t"
            # followed directly by a digit is taken as wmem_alloc ("tb..."
            # is a different entry and is skipped by the isdigit test).
            for sub_item in item[len("skmem:("):-1].split(","):
                if (
                    len(sub_item) > 1
                    and sub_item[0] == "t"
                    and sub_item[1].isdigit()
                ):
                    tcp_info.wmem_alloc = int(sub_item[1:])
        elif item.startswith("pacing_rate"):
            next_is_pacing_rate = True
        elif item.startswith("wscale:"):
            wscales = item[len("wscale:"):].split(",")
            if len(wscales) != 2:
                logging.warning("Unknown wscale format: %s", item)
                continue
            tcp_info.snd_wscale = int(wscales[0])
            tcp_info.rcv_wscale = int(wscales[1])
        elif item.startswith("cwnd:"):
            tcp_info.cwnd = int(item[len("cwnd:"):])
        elif item.startswith("rtt:"):
            rtts = item[len("rtt:"):].split("/")
            if len(rtts) != 2:
                logging.warning("Unknown rtt format: %s", item)
                continue
            tcp_info.rtt = float(rtts[0])
            tcp_info.rtt_var = float(rtts[1])
        elif item.startswith("snd_wnd:"):
            tcp_info.snd_wnd = int(item[len("snd_wnd:"):])
    return tcp_info


def run_ss(ss_cmd):
    """Run the ss command and parse its one-line-per-socket output.

    Args:
        ss_cmd: argument list handed to ``subprocess.check_output``.

    Returns:
        dict mapping ``"<src>-<dst>"`` (third and fourth whitespace-separated
        columns of each line) to a TcpInfo holding the recognized fields.
        Lines with fewer than four columns, or whose address columns contain
        no ``:``, are skipped with a warning.

    Raises:
        SystemExit: when the command cannot be found.
        Exception: any other failure of the command is re-raised unchanged.
    """
    try:
        result = subprocess.check_output(ss_cmd).decode()
    except Exception as e:
        logging.error("Call ss command failed: %s", e)
        if "No such file or directory" in str(e):
            logging.error("Please install the ss command")
            sys.exit(1)
        raise

    four_tuple_to_info = {}
    for line in result.split("\n"):
        line = line.strip()
        if not line:
            continue
        items = line.split()
        if len(items) < 4:
            logging.warning(
                "Invalid line: [%s], items count: %d",
                line,
                len(items),
            )
            continue
        src = items[2]
        if ":" not in src:
            logging.warning("Invalid line: [%s] src: [%s]", line, src)
            continue
        dst = items[3]
        if ":" not in dst:
            logging.warning("Invalid line: [%s] dst: [%s]", line, dst)
            continue

        four_tuple = "{src}-{dst}".format(src=src, dst=dst)
        four_tuple_to_info[four_tuple] = _parse_tcp_info(items[4:])

    return four_tuple_to_info


def calc_bytes_sent_per_second(curr_bytes_sent, prev_bytes_sent, interval):
    """Return the average outbound rate between two samples, in bytes/s.

    The counter difference is divided by ``interval`` converted to float, so
    the result is always a float.
    """
    sent_delta = curr_bytes_sent - prev_bytes_sent
    return sent_delta / float(interval)


def calc_bytes_received_per_second(
    curr_bytes_recevied, prev_bytes_received, interval
):
    """Return the average inbound rate between two samples, in bytes/s.

    The counter difference is divided by ``interval`` converted to float, so
    the result is always a float.
    """
    received_delta = curr_bytes_recevied - prev_bytes_received
    return received_delta / float(interval)


def calc_cwnd_Bps(cwnd, rtt_ms):
    """Return the rate one congestion window per round trip can sustain.

    Computed as ``1000 * cwnd / rtt_ms`` (the factor 1000 converts the
    millisecond RTT to seconds). A (near) zero RTT yields 0 rather than a
    division error.

    NOTE(review): ss appears to report cwnd in segments rather than bytes,
    in which case the result is segments/s unless the caller scales by the
    MSS -- confirm.
    """
    # rtt_ms is a float, so "is it zero" is tested against a tiny epsilon.
    return 0 if rtt_ms < 1e-10 else 1000 * cwnd / rtt_ms


def calc_snd_wnd_Bps(snd_wnd, rtt_ms):
    """Return the rate one send window per round trip can sustain.

    Computed as ``1000 * snd_wnd / rtt_ms`` (the factor 1000 converts the
    millisecond RTT to seconds). A (near) zero RTT yields 0 rather than a
    division error.
    """
    # rtt_ms is a float, so "is it zero" is tested against a tiny epsilon.
    return 0 if rtt_ms < 1e-10 else 1000 * snd_wnd / rtt_ms


def green_str(raw_str):
    """Wrap raw_str in the ANSI escape codes for green text plus reset."""
    green_template = "\033[32m%s\033[0m"
    return green_template % raw_str


def red_str(raw_str):
    """Wrap raw_str in the ANSI escape codes for red text plus reset."""
    red_template = "\033[31m%s\033[0m"
    return red_template % raw_str


# Simulate the tcp_small_queue_check function
# in net/ipv4/tcp_output.c
def calc_small_queue_check_throttled(
    wmem_alloc,
    pacing_rate,
    sk_pacing_shift,
    sk_pacing_status,
    sysctl_tcp_limit_output_bytes,
    small_queue_check_factor,
    tcp_tx_delay_enabled,
    tcp_tx_delay,
    skb_truesize,
):
    """Return True when wmem_alloc exceeds the computed small-queue limit.

    The limit is the larger of ``2 * skb_truesize`` and the pacing rate (in
    bytes/s) shifted right by ``sk_pacing_shift``; it is capped by
    ``sysctl_tcp_limit_output_bytes`` when pacing status is SK_PACING_NONE,
    then shifted left by ``small_queue_check_factor``. With tx delay enabled
    and non-zero, an allowance proportional to rate * delay is added.
    """
    # Convert bps to bytes per second
    rate_bytes = pacing_rate // 8

    queue_limit = max(2 * skb_truesize, rate_bytes >> sk_pacing_shift)
    unpaced = sk_pacing_status == SK_PACING_NONE
    if unpaced:
        queue_limit = min(queue_limit, sysctl_tcp_limit_output_bytes)
    queue_limit = queue_limit << small_queue_check_factor

    if not (tcp_tx_delay_enabled and tcp_tx_delay):
        return wmem_alloc > queue_limit

    # Comment copied from the kernel tcp_small_queue_check function:
    # TSQ is based on skb truesize sum (sk_wmem_alloc), so we
    # approximate our needs assuming an ~100% skb->truesize overhead.
    # USEC_PER_SEC is approximated by 2^20.
    # do_div(extra_bytes, USEC_PER_SEC/2) is replaced by a right shift.
    delay_allowance = (rate_bytes * tcp_tx_delay) >> (20 - 1)
    return wmem_alloc > queue_limit + delay_allowance


# A window-derived rate is flagged red when it falls below this fraction of
# the measured outbound rate.
Bps_window_factor = 0.5


def _colorize_rate(rate, threshold):
    """Format rate as an 11-wide column: green if >= threshold, else red."""
    text = "%11.2f" % rate
    return green_str(text) if rate >= threshold else red_str(text)


def show_tcp_connection(
    four_tuple,
    curr_info,
    prev_info,
    global_info,
    dt_str,
):
    """Print one output row for a single connection.

    Args:
        four_tuple: ``"<src_ip:port>-<dst_ip:port>"`` key of the connection.
        curr_info: TcpInfo from the current sample.
        prev_info: TcpInfo from the previous sample (used for rates).
        global_info: GlobalInfo with the interval and small-queue settings.
        dt_str: timestamp text printed in the first column.

    Columns whose inputs are missing (None) are printed as ``None``.
    """
    src_ip_port, dst_ip_port = four_tuple.split("-")
    columns = [
        "%s" % dt_str,
        "%21s" % src_ip_port,
        "%21s" % dst_ip_port,
    ]

    # Stays at -1 when the outbound rate is unknown, so that no window-based
    # column can be flagged red without a rate to compare against.
    Bps_threshold = -1

    # outB/s
    if curr_info.bytes_sent is None or prev_info.bytes_sent is None:
        columns.append("%11s" % "None")
    else:
        bytes_sent_per_second = calc_bytes_sent_per_second(
            curr_info.bytes_sent,
            prev_info.bytes_sent,
            global_info.interval,
        )
        columns.append("%11.2f" % bytes_sent_per_second)
        Bps_threshold = bytes_sent_per_second * Bps_window_factor

    # inB/s
    if curr_info.bytes_received is None or prev_info.bytes_received is None:
        columns.append("%11s" % "None")
    else:
        bytes_received_per_second = calc_bytes_received_per_second(
            curr_info.bytes_received,
            prev_info.bytes_received,
            global_info.interval,
        )
        columns.append("%11.2f" % bytes_received_per_second)

    # SQ: red when the small-queue check would throttle the socket.
    if curr_info.wmem_alloc is None or curr_info.pacing_rate is None:
        columns.append("%5s" % "None")
    else:
        sq = calc_small_queue_check_throttled(
            curr_info.wmem_alloc,
            curr_info.pacing_rate,
            global_info.sk_pacing_shift,
            global_info.sk_pacing_status,
            global_info.sysctl_tcp_limit_output_bytes,
            global_info.small_queue_check_factor,
            global_info.tcp_tx_delay_enabled,
            global_info.tcp_tx_delay,
            global_info.skb_truesize,
        )
        sq_raw_str = "%5s" % sq
        columns.append(red_str(sq_raw_str) if sq else green_str(sq_raw_str))

    # cwnd_Bps
    if curr_info.rtt is None or curr_info.cwnd is None:
        columns.append("%11s" % "None")
    else:
        cwnd_Bps = calc_cwnd_Bps(curr_info.cwnd, curr_info.rtt)
        columns.append(_colorize_rate(cwnd_Bps, Bps_threshold))

    # snd_wnd_Bps
    if curr_info.rtt is None or curr_info.snd_wnd is None:
        columns.append("%11s" % "None")
    else:
        # The kernel exports tcpi_snd_wnd as the peer's advertised window
        # *after* scaling, in bytes, so it must not be shifted by snd_wscale
        # again (doing so inflated this column by 2**snd_wscale).
        snd_wnd_Bps = calc_snd_wnd_Bps(curr_info.snd_wnd, curr_info.rtt)
        # Same ">=" rule as cwnd_Bps, so both columns color consistently.
        columns.append(_colorize_rate(snd_wnd_Bps, Bps_threshold))

    print(" ".join(columns))


def show_tcp_connections(curr, prev, global_info, dt_str):
    """Print a row for every connection present in both samples.

    Connections in ``curr`` that have no entry in ``prev`` are skipped.
    """
    for four_tuple, curr_info in curr.items():
        prev_info = prev.get(four_tuple)
        if prev_info is not None:
            show_tcp_connection(
                four_tuple, curr_info, prev_info, global_info, dt_str
            )


def show_header():
    """Print the column titles as one line.

    Each title carries its own leading padding; the pieces are concatenated
    without any extra separator.
    """
    titles = (
        "           datetime",
        "       src_ip:src_port",
        "       dst_ip:dst_port",
        "      outB/s",
        "       inB/s",
        "    SQ",
        "    cwnd_Bps",
        " snd_wnd_Bps",
    )
    print("".join(titles))


def main(args):
    """Poll ss periodically and print one row per connection per sample.

    Args:
        args: parsed command-line namespace providing ``interval``,
            ``count``, ``sport``, ``dport`` and, optionally, ``verbose``.

    Raises:
        ValueError: if ``args.interval`` is not positive.

    The loop takes ``count + 1`` samples: rows are only printed for
    connections that also appeared in the previous sample, so the first
    sample prints nothing and merely provides the baseline.
    """
    if args.interval <= 0:
        raise ValueError("interval should be larger than 0")

    # Without a configured logger the debug messages requested by --verbose
    # would never be emitted.
    if getattr(args, "verbose", False):
        logging.basicConfig(level=logging.DEBUG)

    # <= 0 means run forever.
    max_count = args.count if args.count > 0 else None

    cmd = [
        "ss",
        "--no-header",
        "--oneline",
        "-tnmi",
        "-o",
        "state",
        "established",
    ]

    # Append the optional port filter, e.g. "sport = :22 and dport = :80".
    port_filters = []
    if args.sport is not None:
        port_filters.append(
            ["sport", "=", ":{sport}".format(sport=args.sport)]
        )
    if args.dport is not None:
        port_filters.append(
            ["dport", "=", ":{dport}".format(dport=args.dport)]
        )
    for index, expression in enumerate(port_filters):
        if index:
            cmd.append("and")
        cmd.extend(expression)

    logging.debug("cmd: %s", cmd)

    global_info = GlobalInfo(
        interval=args.interval,
        sysctl_tcp_limit_output_bytes=get_tcp_limit_output_bytes(),
    )

    prev = {}
    show_header()
    samples_taken = 0
    while True:
        curr = run_ss(cmd)
        dt_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        show_tcp_connections(curr, prev, global_info, dt_str)
        prev = curr
        samples_taken += 1
        if max_count is not None and samples_taken > max_count:
            # Last sample shown: exit right away instead of sleeping once
            # more for no reason.
            break
        time.sleep(args.interval)


# Script entry point: parse the command line and start monitoring.
if __name__ == "__main__":
    args = parser.parse_args()
    main(args)
