#!/usr/bin/env python3

import asyncio
import base64
import binascii
import collections
import datetime
import hashlib
import hmac
import os
import random
import re
import runpy
import signal
import socket
import stat
import sys
import time
import traceback
import urllib.parse
import urllib.request

# Telegram datacenter addresses and the port they are reached on.
TG_DATACENTER_PORT = 443

TG_DATACENTERS_V4 = [
    "149.154.175.50",
    "149.154.167.51",
    "149.154.175.100",
    "149.154.167.91",
    "149.154.171.5",
]

TG_DATACENTERS_V6 = [
    "2001:b28:f23d:f001::a",
    "2001:67c:04e8:f002::a",
    "2001:b28:f23d:f003::a",
    "2001:67c:04e8:f004::a",
    "2001:b28:f23f:f005::a",
]

# Middle-proxy endpoints keyed by datacenter number (positive and negative
# keys). This list will be updated at runtime.
TG_MIDDLE_PROXIES_V4 = {
    1: [("149.154.175.50", 8888)],
    -1: [("149.154.175.50", 8888)],
    2: [("149.154.161.144", 8888)],
    -2: [("149.154.161.144", 8888)],
    3: [("149.154.175.100", 8888)],
    -3: [("149.154.175.100", 8888)],
    4: [("91.108.4.136", 8888)],
    -4: [("149.154.165.109", 8888)],
    5: [("91.108.56.168", 8888)],
    -5: [("91.108.56.182", 8888)],
}

TG_MIDDLE_PROXIES_V6 = {
    1: [("2001:b28:f23d:f001::d", 8888)],
    -1: [("2001:b28:f23d:f001::d", 8888)],
    2: [("2001:67c:04e8:f002::d", 80)],
    -2: [("2001:67c:04e8:f002::d", 80)],
    3: [("2001:b28:f23d:f003::d", 8888)],
    -3: [("2001:b28:f23d:f003::d", 8888)],
    4: [("2001:67c:04e8:f004::d", 8888)],
    -4: [("2001:67c:04e8:f004::d", 8888)],
    5: [("2001:b28:f23f:f005::d", 8888)],
    -5: [("2001:67c:04e8:f004::d", 8888)],
}

# 128-byte shared secret, written as four 64-hex-digit rows.
PROXY_SECRET = bytes.fromhex(
    "c4f9faca9678e6bb48ad6c7e2ce5c0d24430645d554addeb55419e034da62721"
    "d046eaab6e52ab14a95a443ecfb3463e79a05a66612adf9caeda8be9a80da698"
    "6fb0a6ff387af84d88ef3a6413713e5c3377f6e1a3d47d99f5e0c56eece8f05c"
    "54c490b079e31bef82ff0ee8f2b0a32756d249c5f21269816cb7061b265db212"
)

# Sizes and offsets of the fields inside the HANDSHAKE_LEN-byte client handshake.
SKIP_LEN = 8
PREKEY_LEN = 32
KEY_LEN = 32
IV_LEN = 16
HANDSHAKE_LEN = 64
# The protocol tag follows the skipped bytes, the pre-key and the IV (8 + 32 + 16 = 56).
PROTO_TAG_POS = SKIP_LEN + PREKEY_LEN + IV_LEN
# The datacenter index follows the 4-byte protocol tag (56 + 4 = 60).
DC_IDX_POS = PROTO_TAG_POS + 4

MIN_CERT_LEN = 1024

# 4-byte protocol tags.
PROTO_TAG_ABRIDGED = b"\xef\xef\xef\xef"
PROTO_TAG_INTERMEDIATE = b"\xee\xee\xee\xee"
PROTO_TAG_SECURE = b"\xdd\xdd\xdd\xdd"

CBC_PADDING = 16
PADDING_FILLER = b"\x04\x00\x00\x00"
MIN_MSG_LEN = 12 MAX_MSG_LEN = 2**24 STAT_DURATION_BUCKETS = [0.1, 0.5, 1, 2, 5, 15, 60, 300, 600, 1800, 2**31 - 1] my_ip_info = {"ipv4": None, "ipv6": None} used_handshakes = collections.OrderedDict() client_ips = collections.OrderedDict() last_client_ips = {} disable_middle_proxy = False is_time_skewed = False fake_cert_len = random.randrange(1024, 4096) mask_host_cached_ip = None last_clients_with_time_skew = {} last_clients_with_same_handshake = collections.Counter() proxy_start_time = 0 proxy_links = [] stats = collections.Counter() user_stats = collections.defaultdict(collections.Counter) config = {} def init_config(): global config # we use conf_dict to protect the original config from exceptions when reloading if len(sys.argv) < 2: conf_dict = runpy.run_module("config") elif len(sys.argv) == 2: # launch with own config conf_dict = runpy.run_path(sys.argv[1]) else: # undocumented way of launching conf_dict = {} conf_dict["PORT"] = int(sys.argv[1]) secrets = sys.argv[2].split(",") conf_dict["USERS"] = { "user%d" % i: secrets[i].zfill(32) for i in range(len(secrets)) } conf_dict["MODES"] = {"classic": False, "secure": True, "tls": True} if len(sys.argv) > 3: conf_dict["AD_TAG"] = sys.argv[3] if len(sys.argv) > 4: conf_dict["TLS_DOMAIN"] = sys.argv[4] conf_dict["MODES"] = {"classic": False, "secure": False, "tls": True} conf_dict = {k: v for k, v in conf_dict.items() if k.isupper()} conf_dict.setdefault("PORT", 3256) conf_dict.setdefault("USERS", {"tg": "00000000000000000000000000000000"}) conf_dict["AD_TAG"] = bytes.fromhex(conf_dict.get("AD_TAG", "")) for user, secret in conf_dict["USERS"].items(): if not re.fullmatch("[0-9a-fA-F]{32}", secret): fixed_secret = re.sub(r"[^0-9a-fA-F]", "", secret).zfill(32)[:32] print_err( "Bad secret for user %s, should be 32 hex chars, got %s. 
" % (user, secret) ) print_err("Changing it to %s" % fixed_secret) conf_dict["USERS"][user] = fixed_secret # load advanced settings # use middle proxy, necessary to show ad conf_dict.setdefault("USE_MIDDLE_PROXY", len(conf_dict["AD_TAG"]) == 16) # if IPv6 avaliable, use it by default conf_dict.setdefault("PREFER_IPV6", socket.has_ipv6) # disables tg->client trafic reencryption, faster but less secure conf_dict.setdefault("FAST_MODE", True) # enables some working modes modes = conf_dict.get("MODES", {}) if "MODES" not in conf_dict: modes.setdefault("classic", True) modes.setdefault("secure", True) modes.setdefault("tls", True) else: modes.setdefault("classic", False) modes.setdefault("secure", False) modes.setdefault("tls", False) legacy_warning = False if "SECURE_ONLY" in conf_dict: legacy_warning = True modes["classic"] = not bool(conf_dict["SECURE_ONLY"]) if "TLS_ONLY" in conf_dict: legacy_warning = True if conf_dict["TLS_ONLY"]: modes["classic"] = False modes["secure"] = False if not modes["classic"] and not modes["secure"] and not modes["tls"]: print_err("No known modes enabled, enabling tls-only mode") modes["tls"] = True if legacy_warning: print_err("Legacy options SECURE_ONLY or TLS_ONLY detected") print_err("Please use MODES in your config instead:") print_err("MODES = {") print_err(' "classic": %s,' % modes["classic"]) print_err(' "secure": %s,' % modes["secure"]) print_err(' "tls": %s' % modes["tls"]) print_err("}") conf_dict["MODES"] = modes # accept incoming connections only with proxy protocol v1/v2, useful for nginx and haproxy conf_dict.setdefault("PROXY_PROTOCOL", False) # set the tls domain for the proxy, has an influence only on starting message conf_dict.setdefault("TLS_DOMAIN", "www.google.com") # enable proxying bad clients to some host conf_dict.setdefault("MASK", True) # the next host to forward bad clients conf_dict.setdefault("MASK_HOST", conf_dict["TLS_DOMAIN"]) # set the home domain for the proxy, has an influence only on the log message 
conf_dict.setdefault("MY_DOMAIN", False) # the next host's port to forward bad clients conf_dict.setdefault("MASK_PORT", 443) # use upstream SOCKS5 proxy conf_dict.setdefault("SOCKS5_HOST", None) conf_dict.setdefault("SOCKS5_PORT", None) conf_dict.setdefault("SOCKS5_USER", None) conf_dict.setdefault("SOCKS5_PASS", None) if conf_dict["SOCKS5_HOST"] and conf_dict["SOCKS5_PORT"]: # Disable the middle proxy if using socks, they are not compatible conf_dict["USE_MIDDLE_PROXY"] = False # user tcp connection limits, the mapping from name to the integer limit # one client can create many tcp connections, up to 8 conf_dict.setdefault("USER_MAX_TCP_CONNS", {}) # expiration date for users in format of day/month/year conf_dict.setdefault("USER_EXPIRATIONS", {}) for user in conf_dict["USER_EXPIRATIONS"]: expiration = datetime.datetime.strptime( conf_dict["USER_EXPIRATIONS"][user], "%d/%m/%Y" ) conf_dict["USER_EXPIRATIONS"][user] = expiration # the data quota for user conf_dict.setdefault("USER_DATA_QUOTA", {}) # length of used handshake randoms for active fingerprinting protection, zero to disable conf_dict.setdefault("REPLAY_CHECK_LEN", 65536) # accept clients with bad clocks. 
This reduces the protection against replay attacks conf_dict.setdefault("IGNORE_TIME_SKEW", False) # length of last client ip addresses for logging conf_dict.setdefault("CLIENT_IPS_LEN", 131072) # delay in seconds between stats printing conf_dict.setdefault("STATS_PRINT_PERIOD", 600) # delay in seconds between middle proxy info updates conf_dict.setdefault("PROXY_INFO_UPDATE_PERIOD", 24 * 60 * 60) # delay in seconds between time getting, zero means disabled conf_dict.setdefault("GET_TIME_PERIOD", 10 * 60) # delay in seconds between getting the length of certificate on the mask host conf_dict.setdefault( "GET_CERT_LEN_PERIOD", random.randrange(4 * 60 * 60, 6 * 60 * 60) ) # max socket buffer size to the client direction, the more the faster, but more RAM hungry # can be the tuple (low, users_margin, high) for the adaptive case. If no much users, use high conf_dict.setdefault("TO_CLT_BUFSIZE", (16384, 100, 131072)) # max socket buffer size to the telegram servers direction, also can be the tuple conf_dict.setdefault("TO_TG_BUFSIZE", 65536) # keepalive period for clients in secs conf_dict.setdefault("CLIENT_KEEPALIVE", 10 * 60) # drop client after this timeout if the handshake fail conf_dict.setdefault("CLIENT_HANDSHAKE_TIMEOUT", random.randrange(5, 15)) # if client doesn't confirm data for this number of seconds, it is dropped conf_dict.setdefault("CLIENT_ACK_TIMEOUT", 5 * 60) # telegram servers connect timeout in seconds conf_dict.setdefault("TG_CONNECT_TIMEOUT", 10) # listen address for IPv4 conf_dict.setdefault("LISTEN_ADDR_IPV4", "0.0.0.0") # listen address for IPv6 conf_dict.setdefault("LISTEN_ADDR_IPV6", "::") # listen unix socket conf_dict.setdefault("LISTEN_UNIX_SOCK", "") # prometheus exporter listen port, use some random port here conf_dict.setdefault("METRICS_PORT", None) # prometheus listen addr ipv4 conf_dict.setdefault("METRICS_LISTEN_ADDR_IPV4", "0.0.0.0") # prometheus listen addr ipv6 conf_dict.setdefault("METRICS_LISTEN_ADDR_IPV6", None) # prometheus 
scrapers whitelist conf_dict.setdefault("METRICS_WHITELIST", ["127.0.0.1", "::1"]) # export proxy link to prometheus conf_dict.setdefault("METRICS_EXPORT_LINKS", False) # default prefix for metrics conf_dict.setdefault("METRICS_PREFIX", "mtprotoproxy_") # allow access to config by attributes config = type("config", (dict,), conf_dict)(conf_dict) def apply_upstream_proxy_settings(): # apply socks settings in place if config.SOCKS5_HOST and config.SOCKS5_PORT: import socks print_err( "Socket-proxy mode activated, it is incompatible with advertising and uvloop" ) socks.set_default_proxy( socks.PROXY_TYPE_SOCKS5, config.SOCKS5_HOST, config.SOCKS5_PORT, username=config.SOCKS5_USER, password=config.SOCKS5_PASS, ) if not hasattr(socket, "origsocket"): socket.origsocket = socket.socket socket.socket = socks.socksocket elif hasattr(socket, "origsocket"): socket.socket = socket.origsocket del socket.origsocket def try_use_cryptography_module(): from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.backends import default_backend class CryptographyEncryptorAdapter: __slots__ = ("encryptor", "decryptor") def __init__(self, cipher): self.encryptor = cipher.encryptor() self.decryptor = cipher.decryptor() def encrypt(self, data): return self.encryptor.update(data) def decrypt(self, data): return self.decryptor.update(data) def create_aes_ctr(key, iv): iv_bytes = int.to_bytes(iv, 16, "big") cipher = Cipher(algorithms.AES(key), modes.CTR(iv_bytes), default_backend()) return CryptographyEncryptorAdapter(cipher) def create_aes_cbc(key, iv): cipher = Cipher(algorithms.AES(key), modes.CBC(iv), default_backend()) return CryptographyEncryptorAdapter(cipher) return create_aes_ctr, create_aes_cbc def try_use_pycrypto_or_pycryptodome_module(): from Crypto.Cipher import AES from Crypto.Util import Counter def create_aes_ctr(key, iv): ctr = Counter.new(128, initial_value=iv) return AES.new(key, AES.MODE_CTR, counter=ctr) def create_aes_cbc(key, 
iv): return AES.new(key, AES.MODE_CBC, iv) return create_aes_ctr, create_aes_cbc def use_slow_bundled_cryptography_module(): import pyaes msg = "To make the program a *lot* faster, please install cryptography module: " msg += "pip install cryptography\n" print(msg, flush=True, file=sys.stderr) class BundledEncryptorAdapter: __slots__ = ("mode",) def __init__(self, mode): self.mode = mode def encrypt(self, data): encrypter = pyaes.Encrypter(self.mode, pyaes.PADDING_NONE) return encrypter.feed(data) + encrypter.feed() def decrypt(self, data): decrypter = pyaes.Decrypter(self.mode, pyaes.PADDING_NONE) return decrypter.feed(data) + decrypter.feed() def create_aes_ctr(key, iv): ctr = pyaes.Counter(iv) return pyaes.AESModeOfOperationCTR(key, ctr) def create_aes_cbc(key, iv): mode = pyaes.AESModeOfOperationCBC(key, iv) return BundledEncryptorAdapter(mode) return create_aes_ctr, create_aes_cbc try: create_aes_ctr, create_aes_cbc = try_use_cryptography_module() except ImportError: try: create_aes_ctr, create_aes_cbc = try_use_pycrypto_or_pycryptodome_module() except ImportError: create_aes_ctr, create_aes_cbc = use_slow_bundled_cryptography_module() def print_err(*params): print(*params, file=sys.stderr, flush=True) def ensure_users_in_user_stats(): global user_stats for user in config.USERS: user_stats[user].update() def init_proxy_start_time(): global proxy_start_time proxy_start_time = time.time() def update_stats(**kw_stats): global stats stats.update(**kw_stats) def update_user_stats(user, **kw_stats): global user_stats user_stats[user].update(**kw_stats) def update_durations(duration): global stats for bucket in STAT_DURATION_BUCKETS: if duration <= bucket: break update_stats(**{"connects_with_duration_le_%s" % str(bucket): 1}) def get_curr_connects_count(): global user_stats all_connects = 0 for user, stat in user_stats.items(): all_connects += stat["curr_connects"] return all_connects def get_to_tg_bufsize(): if isinstance(config.TO_TG_BUFSIZE, int): return 
config.TO_TG_BUFSIZE low, margin, high = config.TO_TG_BUFSIZE return high if get_curr_connects_count() < margin else low def get_to_clt_bufsize(): if isinstance(config.TO_CLT_BUFSIZE, int): return config.TO_CLT_BUFSIZE low, margin, high = config.TO_CLT_BUFSIZE return high if get_curr_connects_count() < margin else low class MyRandom(random.Random): def __init__(self): super().__init__() key = bytes([random.randrange(256) for i in range(32)]) iv = random.randrange(256**16) self.encryptor = create_aes_ctr(key, iv) self.buffer = bytearray() def getrandbits(self, k): numbytes = (k + 7) // 8 return int.from_bytes(self.getrandbytes(numbytes), "big") >> (numbytes * 8 - k) def getrandbytes(self, n): CHUNK_SIZE = 512 while n > len(self.buffer): data = int.to_bytes(super().getrandbits(CHUNK_SIZE * 8), CHUNK_SIZE, "big") self.buffer += self.encryptor.encrypt(data) result = self.buffer[:n] self.buffer = self.buffer[n:] return bytes(result) myrandom = MyRandom() class TgConnectionPool: MAX_CONNS_IN_POOL = 0 def __init__(self): self.pools = {} async def open_tg_connection(self, host, port, init_func=None): task = asyncio.open_connection(host, port, limit=get_to_clt_bufsize()) reader_tgt, writer_tgt = await asyncio.wait_for( task, timeout=config.TG_CONNECT_TIMEOUT ) set_keepalive(writer_tgt.get_extra_info("socket")) set_bufsizes( writer_tgt.get_extra_info("socket"), get_to_clt_bufsize(), get_to_tg_bufsize(), ) if init_func: return await asyncio.wait_for( init_func(host, port, reader_tgt, writer_tgt), timeout=config.TG_CONNECT_TIMEOUT, ) return reader_tgt, writer_tgt def register_host_port(self, host, port, init_func): if (host, port, init_func) not in self.pools: self.pools[(host, port, init_func)] = [] while ( len(self.pools[(host, port, init_func)]) < TgConnectionPool.MAX_CONNS_IN_POOL ): connect_task = asyncio.ensure_future( self.open_tg_connection(host, port, init_func) ) self.pools[(host, port, init_func)].append(connect_task) async def get_connection(self, host, port, 
init_func=None): # self.register_host_port(host, port, init_func) # ret = None # for task in self.pools[(host, port, init_func)][::]: # if task.done(): # if task.exception(): # self.pools[(host, port, init_func)].remove(task) # continue # reader, writer, *other = task.result() # if writer.transport.is_closing(): # self.pools[(host, port, init_func)].remove(task) # continue # if not ret: # self.pools[(host, port, init_func)].remove(task) # ret = (reader, writer, *other) # self.register_host_port(host, port, init_func) # if ret: # return ret return await self.open_tg_connection(host, port, init_func) tg_connection_pool = TgConnectionPool() class LayeredStreamReaderBase: __slots__ = ("upstream",) def __init__(self, upstream): self.upstream = upstream async def read(self, n): return await self.upstream.read(n) async def readexactly(self, n): return await self.upstream.readexactly(n) class LayeredStreamWriterBase: __slots__ = ("upstream",) def __init__(self, upstream): self.upstream = upstream def write(self, data, extra={}): return self.upstream.write(data) def write_eof(self): return self.upstream.write_eof() async def drain(self): return await self.upstream.drain() def close(self): return self.upstream.close() def abort(self): return self.upstream.transport.abort() def get_extra_info(self, name): return self.upstream.get_extra_info(name) @property def transport(self): return self.upstream.transport class FakeTLSStreamReader(LayeredStreamReaderBase): __slots__ = ("buf",) def __init__(self, upstream): self.upstream = upstream self.buf = bytearray() async def read(self, n, ignore_buf=False): if self.buf and not ignore_buf: data = self.buf self.buf = bytearray() return bytes(data) while True: tls_rec_type = await self.upstream.readexactly(1) if not tls_rec_type: return b"" if tls_rec_type not in [b"\x14", b"\x17"]: print_err("BUG: bad tls type %s in FakeTLSStreamReader" % tls_rec_type) return b"" version = await self.upstream.readexactly(2) if version != b"\x03\x03": 
print_err("BUG: unknown version %s in FakeTLSStreamReader" % version) return b"" data_len = int.from_bytes(await self.upstream.readexactly(2), "big") data = await self.upstream.readexactly(data_len) if tls_rec_type == b"\x14": continue return data async def readexactly(self, n): while len(self.buf) < n: tls_data = await self.read(1, ignore_buf=True) if not tls_data: return b"" self.buf += tls_data data, self.buf = self.buf[:n], self.buf[n:] return bytes(data) class FakeTLSStreamWriter(LayeredStreamWriterBase): __slots__ = () def __init__(self, upstream): self.upstream = upstream def write(self, data, extra={}): MAX_CHUNK_SIZE = 16384 + 24 for start in range(0, len(data), MAX_CHUNK_SIZE): end = min(start + MAX_CHUNK_SIZE, len(data)) self.upstream.write(b"\x17\x03\x03" + int.to_bytes(end - start, 2, "big")) self.upstream.write(data[start:end]) return len(data) class CryptoWrappedStreamReader(LayeredStreamReaderBase): __slots__ = ("decryptor", "block_size", "buf") def __init__(self, upstream, decryptor, block_size=1): self.upstream = upstream self.decryptor = decryptor self.block_size = block_size self.buf = bytearray() async def read(self, n): if self.buf: ret = bytes(self.buf) self.buf.clear() return ret else: data = await self.upstream.read(n) if not data: return b"" needed_till_full_block = -len(data) % self.block_size if needed_till_full_block > 0: data += self.upstream.readexactly(needed_till_full_block) return self.decryptor.decrypt(data) async def readexactly(self, n): if n > len(self.buf): to_read = n - len(self.buf) needed_till_full_block = -to_read % self.block_size to_read_block_aligned = to_read + needed_till_full_block data = await self.upstream.readexactly(to_read_block_aligned) self.buf += self.decryptor.decrypt(data) ret = bytes(self.buf[:n]) self.buf = self.buf[n:] return ret class CryptoWrappedStreamWriter(LayeredStreamWriterBase): __slots__ = ("encryptor", "block_size") def __init__(self, upstream, encryptor, block_size=1): self.upstream = upstream 
self.encryptor = encryptor self.block_size = block_size def write(self, data, extra={}): if len(data) % self.block_size != 0: print_err( "BUG: writing %d bytes not aligned to block size %d" % (len(data), self.block_size) ) return 0 q = self.encryptor.encrypt(data) return self.upstream.write(q) class MTProtoFrameStreamReader(LayeredStreamReaderBase): __slots__ = ("seq_no",) def __init__(self, upstream, seq_no=0): self.upstream = upstream self.seq_no = seq_no async def read(self, buf_size): msg_len_bytes = await self.upstream.readexactly(4) msg_len = int.from_bytes(msg_len_bytes, "little") # skip paddings while msg_len == 4: msg_len_bytes = await self.upstream.readexactly(4) msg_len = int.from_bytes(msg_len_bytes, "little") len_is_bad = msg_len % len(PADDING_FILLER) != 0 if not MIN_MSG_LEN <= msg_len <= MAX_MSG_LEN or len_is_bad: print_err("msg_len is bad, closing connection", msg_len) return b"" msg_seq_bytes = await self.upstream.readexactly(4) msg_seq = int.from_bytes(msg_seq_bytes, "little", signed=True) if msg_seq != self.seq_no: print_err("unexpected seq_no") return b"" self.seq_no += 1 data = await self.upstream.readexactly(msg_len - 4 - 4 - 4) checksum_bytes = await self.upstream.readexactly(4) checksum = int.from_bytes(checksum_bytes, "little") computed_checksum = binascii.crc32(msg_len_bytes + msg_seq_bytes + data) if computed_checksum != checksum: return b"" return data class MTProtoFrameStreamWriter(LayeredStreamWriterBase): __slots__ = ("seq_no",) def __init__(self, upstream, seq_no=0): self.upstream = upstream self.seq_no = seq_no def write(self, msg, extra={}): len_bytes = int.to_bytes(len(msg) + 4 + 4 + 4, 4, "little") seq_bytes = int.to_bytes(self.seq_no, 4, "little", signed=True) self.seq_no += 1 msg_without_checksum = len_bytes + seq_bytes + msg checksum = int.to_bytes(binascii.crc32(msg_without_checksum), 4, "little") full_msg = msg_without_checksum + checksum padding = PADDING_FILLER * ( (-len(full_msg) % CBC_PADDING) // len(PADDING_FILLER) ) 
return self.upstream.write(full_msg + padding) class MTProtoCompactFrameStreamReader(LayeredStreamReaderBase): __slots__ = () async def read(self, buf_size): msg_len_bytes = await self.upstream.readexactly(1) msg_len = int.from_bytes(msg_len_bytes, "little") extra = {"QUICKACK_FLAG": False} if msg_len >= 0x80: extra["QUICKACK_FLAG"] = True msg_len -= 0x80 if msg_len == 0x7F: msg_len_bytes = await self.upstream.readexactly(3) msg_len = int.from_bytes(msg_len_bytes, "little") msg_len *= 4 data = await self.upstream.readexactly(msg_len) return data, extra class MTProtoCompactFrameStreamWriter(LayeredStreamWriterBase): __slots__ = () def write(self, data, extra={}): SMALL_PKT_BORDER = 0x7F LARGE_PKT_BORGER = 256**3 if len(data) % 4 != 0: print_err( "BUG: MTProtoFrameStreamWriter attempted to send msg with len", len(data), ) return 0 if extra.get("SIMPLE_ACK"): return self.upstream.write(data[::-1]) len_div_four = len(data) // 4 if len_div_four < SMALL_PKT_BORDER: return self.upstream.write(bytes([len_div_four]) + data) elif len_div_four < LARGE_PKT_BORGER: return self.upstream.write( b"\x7f" + int.to_bytes(len_div_four, 3, "little") + data ) else: print_err("Attempted to send too large pkt len =", len(data)) return 0 class MTProtoIntermediateFrameStreamReader(LayeredStreamReaderBase): __slots__ = () async def read(self, buf_size): msg_len_bytes = await self.upstream.readexactly(4) msg_len = int.from_bytes(msg_len_bytes, "little") extra = {} if msg_len > 0x80000000: extra["QUICKACK_FLAG"] = True msg_len -= 0x80000000 data = await self.upstream.readexactly(msg_len) return data, extra class MTProtoIntermediateFrameStreamWriter(LayeredStreamWriterBase): __slots__ = () def write(self, data, extra={}): if extra.get("SIMPLE_ACK"): return self.upstream.write(data) else: return self.upstream.write(int.to_bytes(len(data), 4, "little") + data) class MTProtoSecureIntermediateFrameStreamReader(LayeredStreamReaderBase): __slots__ = () async def read(self, buf_size): msg_len_bytes = 
await self.upstream.readexactly(4) msg_len = int.from_bytes(msg_len_bytes, "little") extra = {} if msg_len > 0x80000000: extra["QUICKACK_FLAG"] = True msg_len -= 0x80000000 data = await self.upstream.readexactly(msg_len) if msg_len % 4 != 0: cut_border = msg_len - (msg_len % 4) data = data[:cut_border] return data, extra class MTProtoSecureIntermediateFrameStreamWriter(LayeredStreamWriterBase): __slots__ = () def write(self, data, extra={}): MAX_PADDING_LEN = 4 if extra.get("SIMPLE_ACK"): # TODO: make this unpredictable return self.upstream.write(data) else: padding_len = myrandom.randrange(MAX_PADDING_LEN) padding = myrandom.getrandbytes(padding_len) padded_data_len_bytes = int.to_bytes(len(data) + padding_len, 4, "little") return self.upstream.write(padded_data_len_bytes + data + padding) class ProxyReqStreamReader(LayeredStreamReaderBase): __slots__ = () async def read(self, msg): RPC_PROXY_ANS = b"\x0d\xda\x03\x44" RPC_CLOSE_EXT = b"\xa2\x34\xb6\x5e" RPC_SIMPLE_ACK = b"\x9b\x40\xac\x3b" RPC_UNKNOWN = b"\xdf\xa2\x30\x57" data = await self.upstream.read(1) if len(data) < 4: return b"" ans_type = data[:4] if ans_type == RPC_CLOSE_EXT: return b"" if ans_type == RPC_PROXY_ANS: ans_flags, conn_id, conn_data = data[4:8], data[8:16], data[16:] return conn_data if ans_type == RPC_SIMPLE_ACK: conn_id, confirm = data[4:12], data[12:16] return confirm, {"SIMPLE_ACK": True} if ans_type == RPC_UNKNOWN: return b"", {"SKIP_SEND": True} print_err("unknown rpc ans type:", ans_type) return b"", {"SKIP_SEND": True} class ProxyReqStreamWriter(LayeredStreamWriterBase): __slots__ = ("remote_ip_port", "our_ip_port", "out_conn_id", "proto_tag") def __init__(self, upstream, cl_ip, cl_port, my_ip, my_port, proto_tag): self.upstream = upstream if ":" not in cl_ip: self.remote_ip_port = b"\x00" * 10 + b"\xff\xff" self.remote_ip_port += socket.inet_pton(socket.AF_INET, cl_ip) else: self.remote_ip_port = socket.inet_pton(socket.AF_INET6, cl_ip) self.remote_ip_port += int.to_bytes(cl_port, 4, 
"little") if ":" not in my_ip: self.our_ip_port = b"\x00" * 10 + b"\xff\xff" self.our_ip_port += socket.inet_pton(socket.AF_INET, my_ip) else: self.our_ip_port = socket.inet_pton(socket.AF_INET6, my_ip) self.our_ip_port += int.to_bytes(my_port, 4, "little") self.out_conn_id = myrandom.getrandbytes(8) self.proto_tag = proto_tag def write(self, msg, extra={}): RPC_PROXY_REQ = b"\xee\xf1\xce\x36" EXTRA_SIZE = b"\x18\x00\x00\x00" PROXY_TAG = b"\xae\x26\x1e\xdb" FOUR_BYTES_ALIGNER = b"\x00\x00\x00" FLAG_NOT_ENCRYPTED = 0x2 FLAG_HAS_AD_TAG = 0x8 FLAG_MAGIC = 0x1000 FLAG_EXTMODE2 = 0x20000 FLAG_PAD = 0x8000000 FLAG_INTERMEDIATE = 0x20000000 FLAG_ABRIDGED = 0x40000000 FLAG_QUICKACK = 0x80000000 if len(msg) % 4 != 0: print_err("BUG: attempted to send msg with len %d" % len(msg)) return 0 flags = FLAG_HAS_AD_TAG | FLAG_MAGIC | FLAG_EXTMODE2 if self.proto_tag == PROTO_TAG_ABRIDGED: flags |= FLAG_ABRIDGED elif self.proto_tag == PROTO_TAG_INTERMEDIATE: flags |= FLAG_INTERMEDIATE elif self.proto_tag == PROTO_TAG_SECURE: flags |= FLAG_INTERMEDIATE | FLAG_PAD if extra.get("QUICKACK_FLAG"): flags |= FLAG_QUICKACK if msg.startswith(b"\x00" * 8): flags |= FLAG_NOT_ENCRYPTED full_msg = bytearray() full_msg += RPC_PROXY_REQ + int.to_bytes(flags, 4, "little") + self.out_conn_id full_msg += self.remote_ip_port + self.our_ip_port + EXTRA_SIZE + PROXY_TAG full_msg += bytes([len(config.AD_TAG)]) + config.AD_TAG + FOUR_BYTES_ALIGNER full_msg += msg return self.upstream.write(full_msg) def try_setsockopt(sock, level, option, value): try: sock.setsockopt(level, option, value) except OSError as E: pass def set_keepalive(sock, interval=40, attempts=5): sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) if hasattr(socket, "TCP_KEEPIDLE"): try_setsockopt(sock, socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, interval) if hasattr(socket, "TCP_KEEPINTVL"): try_setsockopt(sock, socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, interval) if hasattr(socket, "TCP_KEEPCNT"): try_setsockopt(sock, 
socket.IPPROTO_TCP, socket.TCP_KEEPCNT, attempts) def set_ack_timeout(sock, timeout): if hasattr(socket, "TCP_USER_TIMEOUT"): try_setsockopt( sock, socket.IPPROTO_TCP, socket.TCP_USER_TIMEOUT, timeout * 1000 ) def set_bufsizes(sock, recv_buf, send_buf): try_setsockopt(sock, socket.SOL_SOCKET, socket.SO_RCVBUF, recv_buf) try_setsockopt(sock, socket.SOL_SOCKET, socket.SO_SNDBUF, send_buf) def set_instant_rst(sock): INSTANT_RST = b"\x01\x00\x00\x00\x00\x00\x00\x00" if hasattr(socket, "SO_LINGER"): try_setsockopt(sock, socket.SOL_SOCKET, socket.SO_LINGER, INSTANT_RST) def gen_x25519_public_key(): # generates some number which has square root by modulo P P = 2**255 - 19 n = myrandom.randrange(P) return int.to_bytes((n * n) % P, length=32, byteorder="little") async def connect_reader_to_writer(reader, writer): BUF_SIZE = 8192 try: while True: data = await reader.read(BUF_SIZE) if not data: if not writer.transport.is_closing(): writer.write_eof() await writer.drain() return writer.write(data) await writer.drain() except (OSError, asyncio.IncompleteReadError) as e: pass async def handle_bad_client(reader_clt, writer_clt, handshake): BUF_SIZE = 8192 CONNECT_TIMEOUT = 5 global mask_host_cached_ip update_stats(connects_bad=1) if writer_clt.transport.is_closing(): return set_bufsizes(writer_clt.get_extra_info("socket"), BUF_SIZE, BUF_SIZE) if not config.MASK or handshake is None: while await reader_clt.read(BUF_SIZE): # just consume all the data pass return writer_srv = None try: host = mask_host_cached_ip or config.MASK_HOST task = asyncio.open_connection(host, config.MASK_PORT, limit=BUF_SIZE) reader_srv, writer_srv = await asyncio.wait_for(task, timeout=CONNECT_TIMEOUT) if not mask_host_cached_ip: mask_host_cached_ip = writer_srv.get_extra_info("peername")[0] writer_srv.write(handshake) await writer_srv.drain() srv_to_clt = connect_reader_to_writer(reader_srv, writer_clt) clt_to_srv = connect_reader_to_writer(reader_clt, writer_srv) task_srv_to_clt = 
asyncio.ensure_future(srv_to_clt) task_clt_to_srv = asyncio.ensure_future(clt_to_srv) await asyncio.wait( [task_srv_to_clt, task_clt_to_srv], return_when=asyncio.FIRST_COMPLETED ) task_srv_to_clt.cancel() task_clt_to_srv.cancel() if writer_clt.transport.is_closing(): return # if the server closed the connection with RST or FIN-RST, copy them to the client if not writer_srv.transport.is_closing(): # workaround for uvloop, it doesn't fire exceptions on write_eof sock = writer_srv.get_extra_info("socket") raw_sock = socket.socket(sock.family, sock.type, sock.proto, sock.fileno()) try: raw_sock.shutdown(socket.SHUT_WR) except OSError as E: set_instant_rst(writer_clt.get_extra_info("socket")) finally: raw_sock.detach() else: set_instant_rst(writer_clt.get_extra_info("socket")) except ConnectionRefusedError as E: return except (OSError, asyncio.TimeoutError) as E: return finally: if writer_srv is not None: writer_srv.transport.abort() async def handle_fake_tls_handshake(handshake, reader, writer, peer): global used_handshakes global client_ips global last_client_ips global last_clients_with_time_skew global last_clients_with_same_handshake global fake_cert_len TIME_SKEW_MIN = -20 * 60 TIME_SKEW_MAX = 10 * 60 TLS_VERS = b"\x03\x03" TLS_CIPHERSUITE = b"\x13\x01" TLS_CHANGE_CIPHER = b"\x14" + TLS_VERS + b"\x00\x01\x01" TLS_APP_HTTP2_HDR = b"\x17" + TLS_VERS DIGEST_LEN = 32 DIGEST_HALFLEN = 16 DIGEST_POS = 11 SESSION_ID_LEN_POS = DIGEST_POS + DIGEST_LEN SESSION_ID_POS = SESSION_ID_LEN_POS + 1 tls_extensions = b"\x00\x2e" + b"\x00\x33\x00\x24" + b"\x00\x1d\x00\x20" tls_extensions += gen_x25519_public_key() + b"\x00\x2b\x00\x02\x03\x04" digest = handshake[DIGEST_POS : DIGEST_POS + DIGEST_LEN] if digest[:DIGEST_HALFLEN] in used_handshakes: last_clients_with_same_handshake[peer[0]] += 1 return False sess_id_len = handshake[SESSION_ID_LEN_POS] sess_id = handshake[SESSION_ID_POS : SESSION_ID_POS + sess_id_len] for user in config.USERS: secret = bytes.fromhex(config.USERS[user]) 
msg = ( handshake[:DIGEST_POS] + b"\x00" * DIGEST_LEN + handshake[DIGEST_POS + DIGEST_LEN :] ) computed_digest = hmac.new(secret, msg, digestmod=hashlib.sha256).digest() xored_digest = bytes(digest[i] ^ computed_digest[i] for i in range(DIGEST_LEN)) digest_good = xored_digest.startswith(b"\x00" * (DIGEST_LEN - 4)) if not digest_good: continue timestamp = int.from_bytes(xored_digest[-4:], "little") client_time_is_ok = TIME_SKEW_MIN < time.time() - timestamp < TIME_SKEW_MAX # some clients fail to read unix time and send the time since boot instead client_time_is_small = timestamp < 60 * 60 * 24 * 1000 accept_bad_time = ( config.IGNORE_TIME_SKEW or is_time_skewed or client_time_is_small ) if not client_time_is_ok and not accept_bad_time: last_clients_with_time_skew[peer[0]] = (time.time() - timestamp) // 60 continue http_data = myrandom.getrandbytes(fake_cert_len) srv_hello = TLS_VERS + b"\x00" * DIGEST_LEN + bytes([sess_id_len]) + sess_id srv_hello += TLS_CIPHERSUITE + b"\x00" + tls_extensions hello_pkt = b"\x16" + TLS_VERS + int.to_bytes(len(srv_hello) + 4, 2, "big") hello_pkt += b"\x02" + int.to_bytes(len(srv_hello), 3, "big") + srv_hello hello_pkt += TLS_CHANGE_CIPHER + TLS_APP_HTTP2_HDR hello_pkt += int.to_bytes(len(http_data), 2, "big") + http_data computed_digest = hmac.new( secret, msg=digest + hello_pkt, digestmod=hashlib.sha256 ).digest() hello_pkt = ( hello_pkt[:DIGEST_POS] + computed_digest + hello_pkt[DIGEST_POS + DIGEST_LEN :] ) writer.write(hello_pkt) await writer.drain() if config.REPLAY_CHECK_LEN > 0: while len(used_handshakes) >= config.REPLAY_CHECK_LEN: used_handshakes.popitem(last=False) used_handshakes[digest[:DIGEST_HALFLEN]] = True if config.CLIENT_IPS_LEN > 0: while len(client_ips) >= config.CLIENT_IPS_LEN: client_ips.popitem(last=False) if peer[0] not in client_ips: client_ips[peer[0]] = True last_client_ips[peer[0]] = True reader = FakeTLSStreamReader(reader) writer = FakeTLSStreamWriter(writer) return reader, writer return False async def 
handle_proxy_protocol(reader, peer=None): PROXY_SIGNATURE = b"PROXY " PROXY_MIN_LEN = 6 PROXY_TCP4 = b"TCP4" PROXY_TCP6 = b"TCP6" PROXY_UNKNOWN = b"UNKNOWN" PROXY2_SIGNATURE = b"\x0d\x0a\x0d\x0a\x00\x0d\x0a\x51\x55\x49\x54\x0a" PROXY2_MIN_LEN = 16 PROXY2_AF_UNSPEC = 0x0 PROXY2_AF_INET = 0x1 PROXY2_AF_INET6 = 0x2 header = await reader.readexactly(PROXY_MIN_LEN) if header.startswith(PROXY_SIGNATURE): # proxy header v1 header += await reader.readuntil(b"\r\n") _, proxy_fam, *proxy_addr = header[:-2].split(b" ") if proxy_fam in (PROXY_TCP4, PROXY_TCP6): if len(proxy_addr) == 4: src_addr = proxy_addr[0].decode("ascii") src_port = int(proxy_addr[2].decode("ascii")) return (src_addr, src_port) elif proxy_fam == PROXY_UNKNOWN: return peer return False header += await reader.readexactly(PROXY2_MIN_LEN - PROXY_MIN_LEN) if header.startswith(PROXY2_SIGNATURE): # proxy header v2 proxy_ver = header[12] if proxy_ver & 0xF0 != 0x20: return False proxy_len = int.from_bytes(header[14:16], "big") proxy_addr = await reader.readexactly(proxy_len) if proxy_ver == 0x21: proxy_fam = header[13] >> 4 if proxy_fam == PROXY2_AF_INET: if proxy_len >= (4 + 2) * 2: src_addr = socket.inet_ntop(socket.AF_INET, proxy_addr[:4]) src_port = int.from_bytes(proxy_addr[8:10], "big") return (src_addr, src_port) elif proxy_fam == PROXY2_AF_INET6: if proxy_len >= (16 + 2) * 2: src_addr = socket.inet_ntop(socket.AF_INET6, proxy_addr[:16]) src_port = int.from_bytes(proxy_addr[32:34], "big") return (src_addr, src_port) elif proxy_fam == PROXY2_AF_UNSPEC: return peer elif proxy_ver == 0x20: return peer return False async def handle_handshake(reader, writer): global used_handshakes global client_ips global last_client_ips global last_clients_with_same_handshake TLS_START_BYTES = b"\x16\x03\x01" if writer.transport.is_closing() or writer.get_extra_info("peername") is None: return False peer = writer.get_extra_info("peername")[:2] if not peer: peer = ("unknown ip", 0) if config.PROXY_PROTOCOL: ip = peer[0] if peer 
else "unknown ip" peer = await handle_proxy_protocol(reader, peer) if not peer: print_err("Client from %s sent bad proxy protocol headers" % ip) await handle_bad_client(reader, writer, None) return False is_tls_handshake = True handshake = b"" for expected_byte in TLS_START_BYTES: handshake += await reader.readexactly(1) if handshake[-1] != expected_byte: is_tls_handshake = False break if is_tls_handshake: handshake += await reader.readexactly(2) tls_handshake_len = int.from_bytes(handshake[-2:], "big") if tls_handshake_len < 512: is_tls_handshake = False if is_tls_handshake: handshake += await reader.readexactly(tls_handshake_len) tls_handshake_result = await handle_fake_tls_handshake( handshake, reader, writer, peer ) if not tls_handshake_result: await handle_bad_client(reader, writer, handshake) return False reader, writer = tls_handshake_result handshake = await reader.readexactly(HANDSHAKE_LEN) else: if not config.MODES["classic"] and not config.MODES["secure"]: await handle_bad_client(reader, writer, handshake) return False handshake += await reader.readexactly(HANDSHAKE_LEN - len(handshake)) dec_prekey_and_iv = handshake[SKIP_LEN : SKIP_LEN + PREKEY_LEN + IV_LEN] dec_prekey, dec_iv = dec_prekey_and_iv[:PREKEY_LEN], dec_prekey_and_iv[PREKEY_LEN:] enc_prekey_and_iv = handshake[SKIP_LEN : SKIP_LEN + PREKEY_LEN + IV_LEN][::-1] enc_prekey, enc_iv = enc_prekey_and_iv[:PREKEY_LEN], enc_prekey_and_iv[PREKEY_LEN:] if dec_prekey_and_iv in used_handshakes: last_clients_with_same_handshake[peer[0]] += 1 await handle_bad_client(reader, writer, handshake) return False for user in config.USERS: secret = bytes.fromhex(config.USERS[user]) dec_key = hashlib.sha256(dec_prekey + secret).digest() decryptor = create_aes_ctr(key=dec_key, iv=int.from_bytes(dec_iv, "big")) enc_key = hashlib.sha256(enc_prekey + secret).digest() encryptor = create_aes_ctr(key=enc_key, iv=int.from_bytes(enc_iv, "big")) decrypted = decryptor.decrypt(handshake) proto_tag = decrypted[PROTO_TAG_POS : 
PROTO_TAG_POS + 4] if proto_tag not in ( PROTO_TAG_ABRIDGED, PROTO_TAG_INTERMEDIATE, PROTO_TAG_SECURE, ): continue if proto_tag == PROTO_TAG_SECURE: if is_tls_handshake and not config.MODES["tls"]: continue if not is_tls_handshake and not config.MODES["secure"]: continue else: if not config.MODES["classic"]: continue dc_idx = int.from_bytes( decrypted[DC_IDX_POS : DC_IDX_POS + 2], "little", signed=True ) if config.REPLAY_CHECK_LEN > 0: while len(used_handshakes) >= config.REPLAY_CHECK_LEN: used_handshakes.popitem(last=False) used_handshakes[dec_prekey_and_iv] = True if config.CLIENT_IPS_LEN > 0: while len(client_ips) >= config.CLIENT_IPS_LEN: client_ips.popitem(last=False) if peer[0] not in client_ips: client_ips[peer[0]] = True last_client_ips[peer[0]] = True reader = CryptoWrappedStreamReader(reader, decryptor) writer = CryptoWrappedStreamWriter(writer, encryptor) return reader, writer, proto_tag, user, dc_idx, enc_key + enc_iv, peer await handle_bad_client(reader, writer, handshake) return False async def do_direct_handshake(proto_tag, dc_idx, dec_key_and_iv=None): RESERVED_NONCE_FIRST_CHARS = [b"\xef"] RESERVED_NONCE_BEGININGS = [ b"\x48\x45\x41\x44", b"\x50\x4F\x53\x54", b"\x47\x45\x54\x20", b"\xee\xee\xee\xee", b"\xdd\xdd\xdd\xdd", b"\x16\x03\x01\x02", ] RESERVED_NONCE_CONTINUES = [b"\x00\x00\x00\x00"] global my_ip_info global tg_connection_pool dc_idx = abs(dc_idx) - 1 if my_ip_info["ipv6"] and (config.PREFER_IPV6 or not my_ip_info["ipv4"]): if not 0 <= dc_idx < len(TG_DATACENTERS_V6): return False dc = TG_DATACENTERS_V6[dc_idx] else: if not 0 <= dc_idx < len(TG_DATACENTERS_V4): return False dc = TG_DATACENTERS_V4[dc_idx] try: reader_tgt, writer_tgt = await tg_connection_pool.get_connection( dc, TG_DATACENTER_PORT ) except ConnectionRefusedError as E: print_err( "Got connection refused while trying to connect to", dc, TG_DATACENTER_PORT ) return False except ConnectionAbortedError as E: print_err( "The Telegram server connection is bad: %d (%s %s) %s" % 
(dc_idx, E) ) return False except (OSError, asyncio.TimeoutError) as E: print_err("Unable to connect to", dc, TG_DATACENTER_PORT) return False while True: rnd = bytearray(myrandom.getrandbytes(HANDSHAKE_LEN)) if rnd[:1] in RESERVED_NONCE_FIRST_CHARS: continue if rnd[:4] in RESERVED_NONCE_BEGININGS: continue if rnd[4:8] in RESERVED_NONCE_CONTINUES: continue break rnd[PROTO_TAG_POS : PROTO_TAG_POS + 4] = proto_tag if dec_key_and_iv: rnd[SKIP_LEN : SKIP_LEN + KEY_LEN + IV_LEN] = dec_key_and_iv[::-1] rnd = bytes(rnd) dec_key_and_iv = rnd[SKIP_LEN : SKIP_LEN + KEY_LEN + IV_LEN][::-1] dec_key, dec_iv = dec_key_and_iv[:KEY_LEN], dec_key_and_iv[KEY_LEN:] decryptor = create_aes_ctr(key=dec_key, iv=int.from_bytes(dec_iv, "big")) enc_key_and_iv = rnd[SKIP_LEN : SKIP_LEN + KEY_LEN + IV_LEN] enc_key, enc_iv = enc_key_and_iv[:KEY_LEN], enc_key_and_iv[KEY_LEN:] encryptor = create_aes_ctr(key=enc_key, iv=int.from_bytes(enc_iv, "big")) rnd_enc = rnd[:PROTO_TAG_POS] + encryptor.encrypt(rnd)[PROTO_TAG_POS:] writer_tgt.write(rnd_enc) await writer_tgt.drain() reader_tgt = CryptoWrappedStreamReader(reader_tgt, decryptor) writer_tgt = CryptoWrappedStreamWriter(writer_tgt, encryptor) return reader_tgt, writer_tgt def get_middleproxy_aes_key_and_iv( nonce_srv, nonce_clt, clt_ts, srv_ip, clt_port, purpose, clt_ip, srv_port, middleproxy_secret, clt_ipv6=None, srv_ipv6=None, ): EMPTY_IP = b"\x00\x00\x00\x00" if not clt_ip or not srv_ip: clt_ip = EMPTY_IP srv_ip = EMPTY_IP s = bytearray() s += ( nonce_srv + nonce_clt + clt_ts + srv_ip + clt_port + purpose + clt_ip + srv_port ) s += middleproxy_secret + nonce_srv if clt_ipv6 and srv_ipv6: s += clt_ipv6 + srv_ipv6 s += nonce_clt md5_sum = hashlib.md5(s[1:]).digest() sha1_sum = hashlib.sha1(s).digest() key = md5_sum[:12] + sha1_sum iv = hashlib.md5(s[2:]).digest() return key, iv async def middleproxy_handshake(host, port, reader_tgt, writer_tgt): """The most logic of middleproxy handshake, launched in pool""" START_SEQ_NO = -2 NONCE_LEN = 16 
RPC_HANDSHAKE = b"\xf5\xee\x82\x76" RPC_NONCE = b"\xaa\x87\xcb\x7a" # pass as consts to simplify code RPC_FLAGS = b"\x00\x00\x00\x00" CRYPTO_AES = b"\x01\x00\x00\x00" RPC_NONCE_ANS_LEN = 32 RPC_HANDSHAKE_ANS_LEN = 32 writer_tgt = MTProtoFrameStreamWriter(writer_tgt, START_SEQ_NO) key_selector = PROXY_SECRET[:4] crypto_ts = int.to_bytes(int(time.time()) % (256**4), 4, "little") nonce = myrandom.getrandbytes(NONCE_LEN) msg = RPC_NONCE + key_selector + CRYPTO_AES + crypto_ts + nonce writer_tgt.write(msg) await writer_tgt.drain() reader_tgt = MTProtoFrameStreamReader(reader_tgt, START_SEQ_NO) ans = await reader_tgt.read(get_to_clt_bufsize()) if len(ans) != RPC_NONCE_ANS_LEN: raise ConnectionAbortedError("bad rpc answer length") rpc_type, rpc_key_selector, rpc_schema, rpc_crypto_ts, rpc_nonce = ( ans[:4], ans[4:8], ans[8:12], ans[12:16], ans[16:32], ) if ( rpc_type != RPC_NONCE or rpc_key_selector != key_selector or rpc_schema != CRYPTO_AES ): raise ConnectionAbortedError("bad rpc answer") # get keys tg_ip, tg_port = writer_tgt.upstream.get_extra_info("peername")[:2] my_ip, my_port = writer_tgt.upstream.get_extra_info("sockname")[:2] use_ipv6_tg = ":" in tg_ip if not use_ipv6_tg: if my_ip_info["ipv4"]: # prefer global ip settings to work behind NAT my_ip = my_ip_info["ipv4"] tg_ip_bytes = socket.inet_pton(socket.AF_INET, tg_ip)[::-1] my_ip_bytes = socket.inet_pton(socket.AF_INET, my_ip)[::-1] tg_ipv6_bytes = None my_ipv6_bytes = None else: if my_ip_info["ipv6"]: my_ip = my_ip_info["ipv6"] tg_ip_bytes = None my_ip_bytes = None tg_ipv6_bytes = socket.inet_pton(socket.AF_INET6, tg_ip) my_ipv6_bytes = socket.inet_pton(socket.AF_INET6, my_ip) tg_port_bytes = int.to_bytes(tg_port, 2, "little") my_port_bytes = int.to_bytes(my_port, 2, "little") enc_key, enc_iv = get_middleproxy_aes_key_and_iv( nonce_srv=rpc_nonce, nonce_clt=nonce, clt_ts=crypto_ts, srv_ip=tg_ip_bytes, clt_port=my_port_bytes, purpose=b"CLIENT", clt_ip=my_ip_bytes, srv_port=tg_port_bytes, 
middleproxy_secret=PROXY_SECRET, clt_ipv6=my_ipv6_bytes, srv_ipv6=tg_ipv6_bytes, ) dec_key, dec_iv = get_middleproxy_aes_key_and_iv( nonce_srv=rpc_nonce, nonce_clt=nonce, clt_ts=crypto_ts, srv_ip=tg_ip_bytes, clt_port=my_port_bytes, purpose=b"SERVER", clt_ip=my_ip_bytes, srv_port=tg_port_bytes, middleproxy_secret=PROXY_SECRET, clt_ipv6=my_ipv6_bytes, srv_ipv6=tg_ipv6_bytes, ) encryptor = create_aes_cbc(key=enc_key, iv=enc_iv) decryptor = create_aes_cbc(key=dec_key, iv=dec_iv) SENDER_PID = b"IPIPPRPDTIME" PEER_PID = b"IPIPPRPDTIME" # TODO: pass client ip and port here for statistics handshake = RPC_HANDSHAKE + RPC_FLAGS + SENDER_PID + PEER_PID writer_tgt.upstream = CryptoWrappedStreamWriter( writer_tgt.upstream, encryptor, block_size=16 ) writer_tgt.write(handshake) await writer_tgt.drain() reader_tgt.upstream = CryptoWrappedStreamReader( reader_tgt.upstream, decryptor, block_size=16 ) handshake_ans = await reader_tgt.read(1) if len(handshake_ans) != RPC_HANDSHAKE_ANS_LEN: raise ConnectionAbortedError("bad rpc handshake answer length") handshake_type, handshake_flags, handshake_sender_pid, handshake_peer_pid = ( handshake_ans[:4], handshake_ans[4:8], handshake_ans[8:20], handshake_ans[20:32], ) if handshake_type != RPC_HANDSHAKE or handshake_peer_pid != SENDER_PID: raise ConnectionAbortedError("bad rpc handshake answer") return reader_tgt, writer_tgt, my_ip, my_port async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port): global my_ip_info global tg_connection_pool use_ipv6_tg = my_ip_info["ipv6"] and (config.PREFER_IPV6 or not my_ip_info["ipv4"]) if use_ipv6_tg: if dc_idx not in TG_MIDDLE_PROXIES_V6: return False addr, port = myrandom.choice(TG_MIDDLE_PROXIES_V6[dc_idx]) else: if dc_idx not in TG_MIDDLE_PROXIES_V4: return False addr, port = myrandom.choice(TG_MIDDLE_PROXIES_V4[dc_idx]) try: ret = await tg_connection_pool.get_connection(addr, port, middleproxy_handshake) reader_tgt, writer_tgt, my_ip, my_port = ret except ConnectionRefusedError as E: 
print_err( "The Telegram server %d (%s %s) is refusing connections" % (dc_idx, addr, port) ) return False except ConnectionAbortedError as E: print_err( "The Telegram server connection is bad: %d (%s %s) %s" % (dc_idx, addr, port, E) ) return False except (OSError, asyncio.TimeoutError) as E: print_err( "Unable to connect to the Telegram server %d (%s %s)" % (dc_idx, addr, port) ) return False writer_tgt = ProxyReqStreamWriter( writer_tgt, cl_ip, cl_port, my_ip, my_port, proto_tag ) reader_tgt = ProxyReqStreamReader(reader_tgt) return reader_tgt, writer_tgt async def tg_connect_reader_to_writer(rd, wr, user, rd_buf_size, is_upstream): try: while True: data = await rd.read(rd_buf_size) if isinstance(data, tuple): data, extra = data else: extra = {} if extra.get("SKIP_SEND"): continue if not data: wr.write_eof() await wr.drain() return else: if is_upstream: update_user_stats( user, octets_from_client=len(data), msgs_from_client=1 ) else: update_user_stats( user, octets_to_client=len(data), msgs_to_client=1 ) wr.write(data, extra) await wr.drain() except (OSError, asyncio.IncompleteReadError) as e: # print_err(e) pass async def handle_client(reader_clt, writer_clt): set_keepalive( writer_clt.get_extra_info("socket"), config.CLIENT_KEEPALIVE, attempts=3 ) set_ack_timeout(writer_clt.get_extra_info("socket"), config.CLIENT_ACK_TIMEOUT) set_bufsizes( writer_clt.get_extra_info("socket"), get_to_tg_bufsize(), get_to_clt_bufsize() ) update_stats(connects_all=1) try: clt_data = await asyncio.wait_for( handle_handshake(reader_clt, writer_clt), timeout=config.CLIENT_HANDSHAKE_TIMEOUT, ) except asyncio.TimeoutError: update_stats(handshake_timeouts=1) return if not clt_data: return reader_clt, writer_clt, proto_tag, user, dc_idx, enc_key_and_iv, peer = clt_data cl_ip, cl_port = peer update_user_stats(user, connects=1) connect_directly = not config.USE_MIDDLE_PROXY or disable_middle_proxy if connect_directly: if config.FAST_MODE: tg_data = await do_direct_handshake( proto_tag, 
dc_idx, dec_key_and_iv=enc_key_and_iv ) else: tg_data = await do_direct_handshake(proto_tag, dc_idx) else: tg_data = await do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port) if not tg_data: return reader_tg, writer_tg = tg_data if connect_directly and config.FAST_MODE: class FakeEncryptor: def encrypt(self, data): return data class FakeDecryptor: def decrypt(self, data): return data reader_tg.decryptor = FakeDecryptor() writer_clt.encryptor = FakeEncryptor() if not connect_directly: if proto_tag == PROTO_TAG_ABRIDGED: reader_clt = MTProtoCompactFrameStreamReader(reader_clt) writer_clt = MTProtoCompactFrameStreamWriter(writer_clt) elif proto_tag == PROTO_TAG_INTERMEDIATE: reader_clt = MTProtoIntermediateFrameStreamReader(reader_clt) writer_clt = MTProtoIntermediateFrameStreamWriter(writer_clt) elif proto_tag == PROTO_TAG_SECURE: reader_clt = MTProtoSecureIntermediateFrameStreamReader(reader_clt) writer_clt = MTProtoSecureIntermediateFrameStreamWriter(writer_clt) else: return tg_to_clt = tg_connect_reader_to_writer( reader_tg, writer_clt, user, get_to_clt_bufsize(), False ) clt_to_tg = tg_connect_reader_to_writer( reader_clt, writer_tg, user, get_to_tg_bufsize(), True ) task_tg_to_clt = asyncio.ensure_future(tg_to_clt) task_clt_to_tg = asyncio.ensure_future(clt_to_tg) update_user_stats(user, curr_connects=1) tcp_limit_hit = ( user in config.USER_MAX_TCP_CONNS and user_stats[user]["curr_connects"] > config.USER_MAX_TCP_CONNS[user] ) user_expired = ( user in config.USER_EXPIRATIONS and datetime.datetime.now() > config.USER_EXPIRATIONS[user] ) user_data_quota_hit = user in config.USER_DATA_QUOTA and ( user_stats[user]["octets_to_client"] + user_stats[user]["octets_from_client"] > config.USER_DATA_QUOTA[user] ) if (not tcp_limit_hit) and (not user_expired) and (not user_data_quota_hit): start = time.time() await asyncio.wait( [task_tg_to_clt, task_clt_to_tg], return_when=asyncio.FIRST_COMPLETED ) update_durations(time.time() - start) update_user_stats(user, 
curr_connects=-1) task_tg_to_clt.cancel() task_clt_to_tg.cancel() writer_tg.transport.abort() async def handle_client_wrapper(reader, writer): try: await handle_client(reader, writer) except (asyncio.IncompleteReadError, asyncio.CancelledError): pass except (ConnectionResetError, TimeoutError, BrokenPipeError): pass except Exception: traceback.print_exc() finally: writer.transport.abort() def make_metrics_pkt(metrics): pkt_body_list = [] used_names = set() for name, m_type, desc, val in metrics: name = config.METRICS_PREFIX + name if name not in used_names: pkt_body_list.append("# HELP %s %s" % (name, desc)) pkt_body_list.append("# TYPE %s %s" % (name, m_type)) used_names.add(name) if isinstance(val, dict): tags = [] for tag, tag_val in val.items(): if tag == "val": continue tag_val = tag_val.replace('"', r"\"") tags.append('%s="%s"' % (tag, tag_val)) pkt_body_list.append("%s{%s} %s" % (name, ",".join(tags), val["val"])) else: pkt_body_list.append("%s %s" % (name, val)) pkt_body = "\n".join(pkt_body_list) + "\n" pkt_header_list = [] pkt_header_list.append("HTTP/1.1 200 OK") pkt_header_list.append("Connection: close") pkt_header_list.append("Content-Length: %d" % len(pkt_body)) pkt_header_list.append("Content-Type: text/plain; version=0.0.4; charset=utf-8") pkt_header_list.append( "Date: %s" % time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime()) ) pkt_header = "\r\n".join(pkt_header_list) pkt = pkt_header + "\r\n\r\n" + pkt_body return pkt async def handle_metrics(reader, writer): global stats global user_stats global my_ip_info global proxy_start_time global proxy_links global last_clients_with_time_skew global last_clients_with_same_handshake client_ip = writer.get_extra_info("peername")[0] if client_ip not in config.METRICS_WHITELIST: writer.close() return try: metrics = [] metrics.append( ["uptime", "counter", "proxy uptime", time.time() - proxy_start_time] ) metrics.append( [ "connects_bad", "counter", "connects with bad secret", stats["connects_bad"], ] ) 
metrics.append( ["connects_all", "counter", "incoming connects", stats["connects_all"]] ) metrics.append( [ "handshake_timeouts", "counter", "number of timed out handshakes", stats["handshake_timeouts"], ] ) if config.METRICS_EXPORT_LINKS: for link in proxy_links: link_as_metric = link.copy() link_as_metric["val"] = 1 metrics.append( [ "proxy_link_info", "counter", "the proxy link info", link_as_metric, ] ) bucket_start = 0 for bucket in STAT_DURATION_BUCKETS: bucket_end = bucket if bucket != STAT_DURATION_BUCKETS[-1] else "+Inf" metric = { "bucket": "%s-%s" % (bucket_start, bucket_end), "val": stats["connects_with_duration_le_%s" % str(bucket)], } metrics.append( ["connects_by_duration", "counter", "connects by duration", metric] ) bucket_start = bucket_end user_metrics_desc = [ ["user_connects", "counter", "user connects", "connects"], ["user_connects_curr", "gauge", "current user connects", "curr_connects"], [ "user_octets", "counter", "octets proxied for user", "octets_from_client+octets_to_client", ], [ "user_msgs", "counter", "msgs proxied for user", "msgs_from_client+msgs_to_client", ], [ "user_octets_from", "counter", "octets proxied from user", "octets_from_client", ], ["user_octets_to", "counter", "octets proxied to user", "octets_to_client"], ["user_msgs_from", "counter", "msgs proxied from user", "msgs_from_client"], ["user_msgs_to", "counter", "msgs proxied to user", "msgs_to_client"], ] for m_name, m_type, m_desc, stat_key in user_metrics_desc: for user, stat in user_stats.items(): if "+" in stat_key: val = 0 for key_part in stat_key.split("+"): val += stat[key_part] else: val = stat[stat_key] metric = {"user": user, "val": val} metrics.append([m_name, m_type, m_desc, metric]) pkt = make_metrics_pkt(metrics) writer.write(pkt.encode()) await writer.drain() except Exception: traceback.print_exc() finally: writer.close() async def stats_printer(): global user_stats global last_client_ips global last_clients_with_time_skew global 
last_clients_with_same_handshake while True: await asyncio.sleep(config.STATS_PRINT_PERIOD) print("Stats for", time.strftime("%d.%m.%Y %H:%M:%S")) for user, stat in user_stats.items(): print( "%s: %d connects (%d current), %.2f MB, %d msgs" % ( user, stat["connects"], stat["curr_connects"], (stat["octets_from_client"] + stat["octets_to_client"]) / 1000000, stat["msgs_from_client"] + stat["msgs_to_client"], ) ) print(flush=True) if last_client_ips: print("New IPs:") for ip in last_client_ips: print(ip) print(flush=True) last_client_ips.clear() if last_clients_with_time_skew: print("Clients with time skew (possible replay-attackers):") for ip, skew_minutes in last_clients_with_time_skew.items(): print("%s, clocks were %d minutes behind" % (ip, skew_minutes)) print(flush=True) last_clients_with_time_skew.clear() if last_clients_with_same_handshake: print("Clients with duplicate handshake (likely replay-attackers):") for ip, times in last_clients_with_same_handshake.items(): print("%s, %d times" % (ip, times)) print(flush=True) last_clients_with_same_handshake.clear() async def make_https_req(url, host="core.telegram.org"): """Make request, return resp body and headers.""" SSL_PORT = 443 url_data = urllib.parse.urlparse(url) HTTP_REQ_TEMPLATE = ( "\r\n".join(["GET %s HTTP/1.1", "Host: %s", "Connection: close"]) + "\r\n\r\n" ) reader, writer = await asyncio.open_connection(url_data.netloc, SSL_PORT, ssl=True) req = HTTP_REQ_TEMPLATE % (urllib.parse.quote(url_data.path), host) writer.write(req.encode("utf8")) data = await reader.read() writer.close() headers, body = data.split(b"\r\n\r\n", 1) return headers, body def gen_tls_client_hello_msg(server_name): msg = bytearray() msg += b"\x16\x03\x01\x02\x00\x01\x00\x01\xfc\x03\x03" + myrandom.getrandbytes(32) msg += b"\x20" + myrandom.getrandbytes(32) msg += b"\x00\x22\x4a\x4a\x13\x01\x13\x02\x13\x03\xc0\x2b\xc0\x2f\xc0\x2c\xc0\x30\xcc\xa9" msg += 
b"\xcc\xa8\xc0\x13\xc0\x14\x00\x9c\x00\x9d\x00\x2f\x00\x35\x00\x0a\x01\x00\x01\x91" msg += b"\xda\xda\x00\x00\x00\x00" msg += int.to_bytes(len(server_name) + 5, 2, "big") msg += int.to_bytes(len(server_name) + 3, 2, "big") + b"\x00" msg += int.to_bytes(len(server_name), 2, "big") + server_name.encode("ascii") msg += b"\x00\x17\x00\x00\xff\x01\x00\x01\x00\x00\x0a\x00\x0a\x00\x08\xaa\xaa\x00\x1d\x00" msg += b"\x17\x00\x18\x00\x0b\x00\x02\x01\x00\x00\x23\x00\x00\x00\x10\x00\x0e\x00\x0c\x02" msg += b"\x68\x32\x08\x68\x74\x74\x70\x2f\x31\x2e\x31\x00\x05\x00\x05\x01\x00\x00\x00\x00" msg += b"\x00\x0d\x00\x14\x00\x12\x04\x03\x08\x04\x04\x01\x05\x03\x08\x05\x05\x01\x08\x06" msg += b"\x06\x01\x02\x01\x00\x12\x00\x00\x00\x33\x00\x2b\x00\x29\xaa\xaa\x00\x01\x00\x00" msg += b"\x1d\x00\x20" + gen_x25519_public_key() msg += b"\x00\x2d\x00\x02\x01\x01\x00\x2b\x00\x0b\x0a\xba\xba\x03\x04\x03\x03\x03\x02\x03" msg += b"\x01\x00\x1b\x00\x03\x02\x00\x02\x3a\x3a\x00\x01\x00\x00\x15" msg += int.to_bytes(517 - len(msg) - 2, 2, "big") msg += b"\x00" * (517 - len(msg)) return bytes(msg) async def get_encrypted_cert(host, port, server_name): async def get_tls_record(reader): try: record_type = (await reader.readexactly(1))[0] tls_version = await reader.readexactly(2) if tls_version != b"\x03\x03": return 0, b"" record_len = int.from_bytes(await reader.readexactly(2), "big") record = await reader.readexactly(record_len) return record_type, record except asyncio.IncompleteReadError: return 0, b"" reader, writer = await asyncio.open_connection(host, port) writer.write(gen_tls_client_hello_msg(server_name)) await writer.drain() record1_type, record1 = await get_tls_record(reader) if record1_type != 22: return b"" record2_type, record2 = await get_tls_record(reader) if record2_type != 20: return b"" record3_type, record3 = await get_tls_record(reader) if record3_type != 23: return b"" if len(record3) < MIN_CERT_LEN: record4_type, record4 = await get_tls_record(reader) if record4_type != 23: 
return b"" msg = ( "The MASK_HOST %s sent some TLS record before certificate record, this makes the " + "proxy more detectable" ) % config.MASK_HOST print_err(msg) return record4 return record3 async def get_mask_host_cert_len(): global fake_cert_len GET_CERT_TIMEOUT = 10 MASK_ENABLING_CHECK_PERIOD = 60 while True: try: if not config.MASK: # do nothing await asyncio.sleep(MASK_ENABLING_CHECK_PERIOD) continue task = get_encrypted_cert( config.MASK_HOST, config.MASK_PORT, config.TLS_DOMAIN ) cert = await asyncio.wait_for(task, timeout=GET_CERT_TIMEOUT) if cert: if len(cert) < MIN_CERT_LEN: msg = ( "The MASK_HOST %s returned several TLS records, this is not supported" % config.MASK_HOST ) print_err(msg) elif len(cert) != fake_cert_len: fake_cert_len = len(cert) print_err( "Got cert from the MASK_HOST %s, its length is %d" % (config.MASK_HOST, fake_cert_len) ) else: print_err( "The MASK_HOST %s is not TLS 1.3 host, this is not recommended" % config.MASK_HOST ) except ConnectionRefusedError: print_err( "The MASK_HOST %s is refusing connections, this is not recommended" % config.MASK_HOST ) except (TimeoutError, asyncio.TimeoutError): print_err( "Got timeout while getting TLS handshake from MASK_HOST %s" % config.MASK_HOST ) except Exception as E: print_err("Failed to connect to MASK_HOST %s: %s" % (config.MASK_HOST, E)) await asyncio.sleep(config.GET_CERT_LEN_PERIOD) async def get_srv_time(): TIME_SYNC_ADDR = "https://core.telegram.org/getProxySecret" MAX_TIME_SKEW = 30 global disable_middle_proxy global is_time_skewed want_to_reenable_advertising = False while True: try: headers, secret = await make_https_req(TIME_SYNC_ADDR) for line in headers.split(b"\r\n"): if not line.startswith(b"Date: "): continue line = line[len("Date: ") :].decode() srv_time = datetime.datetime.strptime(line, "%a, %d %b %Y %H:%M:%S %Z") now_time = datetime.datetime.utcnow() is_time_skewed = (now_time - srv_time).total_seconds() > MAX_TIME_SKEW if ( is_time_skewed and config.USE_MIDDLE_PROXY and 
not disable_middle_proxy ): print_err("Time skew detected, please set the clock") print_err("Server time:", srv_time, "your time:", now_time) print_err("Disabling advertising to continue serving") print_err("Putting down the shields against replay attacks") disable_middle_proxy = True want_to_reenable_advertising = True elif not is_time_skewed and want_to_reenable_advertising: print_err("Time is ok, reenabling advertising") disable_middle_proxy = False want_to_reenable_advertising = False except Exception as E: print_err("Error getting server time", E) await asyncio.sleep(config.GET_TIME_PERIOD) async def clear_ip_resolving_cache(): global mask_host_cached_ip min_sleep = myrandom.randrange(60 - 10, 60 + 10) max_sleep = myrandom.randrange(120 - 10, 120 + 10) while True: mask_host_cached_ip = None await asyncio.sleep(myrandom.randrange(min_sleep, max_sleep)) async def update_middle_proxy_info(): async def get_new_proxies(url): PROXY_REGEXP = re.compile(r"proxy_for\s+(-?\d+)\s+(.+):(\d+)\s*;") ans = {} headers, body = await make_https_req(url) fields = PROXY_REGEXP.findall(body.decode("utf8")) if fields: for dc_idx, host, port in fields: if host.startswith("[") and host.endswith("]"): host = host[1:-1] dc_idx, port = int(dc_idx), int(port) if dc_idx not in ans: ans[dc_idx] = [(host, port)] else: ans[dc_idx].append((host, port)) return ans PROXY_INFO_ADDR = "https://core.telegram.org/getProxyConfig" PROXY_INFO_ADDR_V6 = "https://core.telegram.org/getProxyConfigV6" PROXY_SECRET_ADDR = "https://core.telegram.org/getProxySecret" global TG_MIDDLE_PROXIES_V4 global TG_MIDDLE_PROXIES_V6 global PROXY_SECRET while True: try: v4_proxies = await get_new_proxies(PROXY_INFO_ADDR) if not v4_proxies: raise Exception("no proxy data") TG_MIDDLE_PROXIES_V4 = v4_proxies except Exception as E: print_err("Error updating middle proxy list:", E) try: v6_proxies = await get_new_proxies(PROXY_INFO_ADDR_V6) if not v6_proxies: raise Exception("no proxy data (ipv6)") TG_MIDDLE_PROXIES_V6 = 
v6_proxies except Exception as E: print_err("Error updating middle proxy list for IPv6:", E) try: headers, secret = await make_https_req(PROXY_SECRET_ADDR) if not secret: raise Exception("no secret") if secret != PROXY_SECRET: PROXY_SECRET = secret print_err("Middle proxy secret updated") except Exception as E: print_err("Error updating middle proxy secret, using old", E) await asyncio.sleep(config.PROXY_INFO_UPDATE_PERIOD) def init_ip_info(): global my_ip_info global disable_middle_proxy def get_ip_from_url(url): TIMEOUT = 5 try: with urllib.request.urlopen(url, timeout=TIMEOUT) as f: if f.status != 200: raise Exception("Invalid status code") return f.read().decode().strip() except Exception: return None IPV4_URL1 = "http://v4.ident.me/" IPV4_URL2 = "http://ipv4.icanhazip.com/" IPV6_URL1 = "http://v6.ident.me/" IPV6_URL2 = "http://ipv6.icanhazip.com/" my_ip_info["ipv4"] = get_ip_from_url(IPV4_URL1) or get_ip_from_url(IPV4_URL2) my_ip_info["ipv6"] = get_ip_from_url(IPV6_URL1) or get_ip_from_url(IPV6_URL2) # the server can return ipv4 address instead of ipv6 if my_ip_info["ipv6"] and ":" not in my_ip_info["ipv6"]: my_ip_info["ipv6"] = None if my_ip_info["ipv6"] and (config.PREFER_IPV6 or not my_ip_info["ipv4"]): print_err("IPv6 found, using it for external communication") if config.USE_MIDDLE_PROXY: if not my_ip_info["ipv4"] and not my_ip_info["ipv6"]: print_err("Failed to determine your ip, advertising disabled") disable_middle_proxy = True def print_tg_info(): global my_ip_info global proxy_links print_default_warning = False if config.PORT == 3256: print("The default port 3256 is used, this is not recommended", flush=True) if not config.MODES["classic"] and not config.MODES["secure"]: print( "Since you have TLS only mode enabled the best port is 443", flush=True ) print_default_warning = True if not config.MY_DOMAIN: ip_addrs = [ip for ip in my_ip_info.values() if ip] if not ip_addrs: ip_addrs = ["YOUR_IP"] else: ip_addrs = [config.MY_DOMAIN] proxy_links = [] for 
user, secret in sorted(config.USERS.items(), key=lambda x: x[0]): for ip in ip_addrs: if config.MODES["classic"]: params = {"server": ip, "port": config.PORT, "secret": secret} params_encodeded = urllib.parse.urlencode(params, safe=":") classic_link = "tg://proxy?{}".format(params_encodeded) proxy_links.append({"user": user, "link": classic_link}) print("{}: {}".format(user, classic_link), flush=True) if config.MODES["secure"]: params = {"server": ip, "port": config.PORT, "secret": "dd" + secret} params_encodeded = urllib.parse.urlencode(params, safe=":") dd_link = "tg://proxy?{}".format(params_encodeded) proxy_links.append({"user": user, "link": dd_link}) print("{}: {}".format(user, dd_link), flush=True) if config.MODES["tls"]: tls_secret = "ee" + secret + config.TLS_DOMAIN.encode().hex() # the base64 links is buggy on ios # tls_secret = bytes.fromhex("ee" + secret) + config.TLS_DOMAIN.encode() # tls_secret_base64 = base64.urlsafe_b64encode(tls_secret) params = {"server": ip, "port": config.PORT, "secret": tls_secret} params_encodeded = urllib.parse.urlencode(params, safe=":") tls_link = "tg://proxy?{}".format(params_encodeded) proxy_links.append({"user": user, "link": tls_link}) print("{}: {}".format(user, tls_link), flush=True) if secret in [ "00000000000000000000000000000000", "0123456789abcdef0123456789abcdef", "00000000000000000000000000000001", ]: msg = "The default secret {} is used, this is not recommended".format( secret ) print(msg, flush=True) random_secret = "".join( myrandom.choice("0123456789abcdef") for i in range(32) ) print("You can change it to this random secret:", random_secret, flush=True) print_default_warning = True if config.TLS_DOMAIN == "www.google.com": print( "The default TLS_DOMAIN www.google.com is used, this is not recommended", flush=True, ) msg = "You should use random existing domain instead, bad clients are proxied there" print(msg, flush=True) print_default_warning = True if print_default_warning: print_err("Warning: one or more 
default settings detected") def setup_files_limit(): try: import resource soft_fd_limit, hard_fd_limit = resource.getrlimit(resource.RLIMIT_NOFILE) resource.setrlimit(resource.RLIMIT_NOFILE, (hard_fd_limit, hard_fd_limit)) except (ValueError, OSError): print( "Failed to increase the limit of opened files", flush=True, file=sys.stderr ) except ImportError: pass def setup_asyncio(): # get rid of annoying "socket.send() raised exception" log messages asyncio.constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES = 100 def setup_signals(): if hasattr(signal, "SIGUSR1"): def debug_signal(signum, frame): import pdb pdb.set_trace() signal.signal(signal.SIGUSR1, debug_signal) if hasattr(signal, "SIGUSR2"): def reload_signal(signum, frame): init_config() ensure_users_in_user_stats() apply_upstream_proxy_settings() print("Config reloaded", flush=True, file=sys.stderr) print_tg_info() signal.signal(signal.SIGUSR2, reload_signal) def try_setup_uvloop(): if config.SOCKS5_HOST and config.SOCKS5_PORT: # socks mode is not compatible with uvloop return try: import uvloop asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) print_err("Found uvloop, using it for optimal performance") except ImportError: pass def remove_unix_socket(path): try: if stat.S_ISSOCK(os.stat(path).st_mode): os.unlink(path) except (FileNotFoundError, NotADirectoryError): pass def loop_exception_handler(loop, context): exception = context.get("exception") transport = context.get("transport") if exception: if isinstance(exception, TimeoutError): if transport: transport.abort() return if isinstance(exception, OSError): IGNORE_ERRNO = { 10038, # operation on non-socket on Windows, likely because fd == -1 121, # the semaphore timeout period has expired on Windows } FORCE_CLOSE_ERRNO = { 113, # no route to host } if exception.errno in IGNORE_ERRNO: return elif exception.errno in FORCE_CLOSE_ERRNO: if transport: transport.abort() return loop.default_exception_handler(context) def create_servers(loop): servers = [] 
reuse_port = hasattr(socket, "SO_REUSEPORT") has_unix = hasattr(socket, "AF_UNIX") if config.LISTEN_ADDR_IPV4: task = asyncio.start_server( handle_client_wrapper, config.LISTEN_ADDR_IPV4, config.PORT, limit=get_to_tg_bufsize(), reuse_port=reuse_port, ) servers.append(loop.run_until_complete(task)) if config.LISTEN_ADDR_IPV6 and socket.has_ipv6: task = asyncio.start_server( handle_client_wrapper, config.LISTEN_ADDR_IPV6, config.PORT, limit=get_to_tg_bufsize(), reuse_port=reuse_port, ) servers.append(loop.run_until_complete(task)) if config.LISTEN_UNIX_SOCK and has_unix: remove_unix_socket(config.LISTEN_UNIX_SOCK) task = asyncio.start_unix_server( handle_client_wrapper, config.LISTEN_UNIX_SOCK, limit=get_to_tg_bufsize() ) servers.append(loop.run_until_complete(task)) os.chmod(config.LISTEN_UNIX_SOCK, 0o666) if config.METRICS_PORT is not None: if config.METRICS_LISTEN_ADDR_IPV4: task = asyncio.start_server( handle_metrics, config.METRICS_LISTEN_ADDR_IPV4, config.METRICS_PORT ) servers.append(loop.run_until_complete(task)) if config.METRICS_LISTEN_ADDR_IPV6 and socket.has_ipv6: task = asyncio.start_server( handle_metrics, config.METRICS_LISTEN_ADDR_IPV6, config.METRICS_PORT ) servers.append(loop.run_until_complete(task)) return servers def create_utilitary_tasks(loop): tasks = [] stats_printer_task = asyncio.Task(stats_printer(), loop=loop) tasks.append(stats_printer_task) if config.USE_MIDDLE_PROXY: middle_proxy_updater_task = asyncio.Task(update_middle_proxy_info(), loop=loop) tasks.append(middle_proxy_updater_task) if config.GET_TIME_PERIOD: time_get_task = asyncio.Task(get_srv_time(), loop=loop) tasks.append(time_get_task) get_cert_len_task = asyncio.Task(get_mask_host_cert_len(), loop=loop) tasks.append(get_cert_len_task) clear_resolving_cache_task = asyncio.Task(clear_ip_resolving_cache(), loop=loop) tasks.append(clear_resolving_cache_task) return tasks def main(): init_config() ensure_users_in_user_stats() apply_upstream_proxy_settings() init_ip_info() 
print_tg_info() setup_asyncio() setup_files_limit() setup_signals() try_setup_uvloop() init_proxy_start_time() if sys.platform == "win32": loop = asyncio.ProactorEventLoop() else: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) loop.set_exception_handler(loop_exception_handler) utilitary_tasks = create_utilitary_tasks(loop) for task in utilitary_tasks: asyncio.ensure_future(task) servers = create_servers(loop) try: loop.run_forever() except KeyboardInterrupt: pass if hasattr(asyncio, "all_tasks"): tasks = asyncio.all_tasks(loop) else: # for compatibility with Python 3.6 tasks = asyncio.Task.all_tasks(loop) for task in tasks: task.cancel() for server in servers: server.close() loop.run_until_complete(server.wait_closed()) has_unix = hasattr(socket, "AF_UNIX") if config.LISTEN_UNIX_SOCK and has_unix: remove_unix_socket(config.LISTEN_UNIX_SOCK) loop.close() if __name__ == "__main__": main()