commit 9dd690c5f404e467ce646109b74a384f3c8a8365 Author: Marvin Martinson Date: Sun Dec 2 21:56:27 2018 +0200 Push diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..61efa50 --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +# ignore backup files +._* +# ignore the ESP32 MicroPython binary +esp32*.bin +# ignore the development config file +config-dev.json diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..66cefce --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2017 RoboKoding LTD + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..5af2a5f --- /dev/null +++ b/Makefile @@ -0,0 +1,35 @@ +#!/bin/bash + +SERIAL_PORT=/dev/ttyUSB0 +#SERIAL_PORT=/dev/tty.SLAB_USBtoUART +#SERIAL_PORT=/dev/tty.wchusbserial1410 + +all: flash delay libs config update reset +upload: config update reset console + +delay: + sleep 3 + +reset: + esptool.py -p $(SERIAL_PORT) --after hard_reset read_mac + picocom -b 115200 $(SERIAL_PORT) + +libs: + ampy -p $(SERIAL_PORT) put uwebsockets.py + ampy -p $(SERIAL_PORT) put debounce.py + +update: + ampy -p $(SERIAL_PORT) put hal.py + ampy -p $(SERIAL_PORT) put main.py + ampy -p $(SERIAL_PORT) put boot.py + +config: + ampy -p $(SERIAL_PORT) put config.json + +flash: + esptool.py -p $(SERIAL_PORT) --chip esp32 -b 115200 erase_flash + esptool.py -p $(SERIAL_PORT) --chip esp32 -b 115200 write_flash --flash_mode dio 0x1000 esp32-*.bin + +console: + echo "Ctrl-A + Ctrl-Q to close Picocom" + picocom -b 115200 $(SERIAL_PORT) diff --git a/README.md b/README.md new file mode 100644 index 0000000..ea12f11 --- /dev/null +++ b/README.md @@ -0,0 +1,27 @@ +# sumorobot-firmware + +The software that is running on the SumoRobots + +Code + +# Instructions +* Change the SERIAL_PORT in the Makefile +* Add your WiFi networks to the config.json file +* Install [Python](https://www.python.org/downloads/) +* Install [esptool](https://github.com/espressif/esptool) (to flash MicroPython to the ESP32) +* Install [ampy](https://github.com/adafruit/ampy) (for uploading files) +* Download [the MicroPython binary](http://micropython.org/download#esp32) to this directory +* Upload the MicroPython binary and the SumoRobot firmware to your ESP32 (open a terminal and type: make all) + +# Support +If you find our work useful, please consider donating : ) +[![Donate using Liberapay](https://liberapay.com/assets/widgets/donate.svg)](https://liberapay.com/robokoding/donate) + + +# TODOS +* variable motor speed control, then more interesting for kids + + + +# Credits +* [K-SPACE MTÜ](https://k-space.ee/) diff --git a/boot.py b/boot.py new file mode 100644 index 0000000..c8f417b --- /dev/null +++ b/boot.py @@ -0,0 +1,61 @@ +import ujson +import network +from hal import * +from utime import sleep_ms +from machine import Timer, Pin +import ubinascii + +print("Press Ctrl-C to stop boot script...") +sleep_ms(200) + +#Pin(25, Pin.OUT).value(0) + + +# Open and parse the config file +with open("config.json", "r") as config_file: + config = ujson.load(config_file) + +# if not config["sumo_id"]: +# config["sumo_id"] = ubinascii.hexlify(network.WLAN().config('mac'),':').decode().replace(":","")[6:] +# with open("config.part", "w") as config_file: +# config_file.write(ujson.dumps(config)) +# os.rename("config.part", "config.json") + +config["sumo_id"] = ubinascii.hexlify(network.WLAN().config('mac'),':').decode().replace(":","")[6:] + +sleep_ms(500) + +robotName = "Sumo-"+config["sumo_id"] +# Initialize the SumoRobot object +sumorobot = Sumorobot(config) + +# Indiacte booting with blinking status LED +#timer = Timer(0) +#sumorobot.toggle_led() +#timer.init(period=2000, mode=Timer.PERIODIC, callback=sumorobot.toggle_led) + +# Connect to WiFi +wlan = network.WLAN(network.STA_IF) + +# Activate the WiFi interface +wlan.active(True) + +wlan.config(dhcp_hostname=robotName) + +# If not already connected +if not wlan.isconnected(): + # Scan for WiFi networks + networks = wlan.scan() + # Go trough all scanned WiFi networks + for network in networks: + # Extract the networks SSID + ssid = network[0].decode("utf-8") + # Check if the SSID is in the config file + if ssid in config["wifis"].keys(): + # Start to connect to the pre-configured network + wlan.connect(ssid, config["wifis"][ssid]) + break + +# Clean up +import gc +gc.collect() diff --git a/config.json b/config.json new file mode 100644 index 0000000..fc19260 --- /dev/null +++ b/config.json @@ -0,0 +1,25 @@ +{ + "paly/stop_button_pin": 22, + "charging_pin":4, + "left_motor_pin":13, + "right_motor_pin":15, + "led_power_pin": 25, + "ultrasonic_echo_pin":14, + "ultrasonic_trigger_pin":12, + "line_left_pin":32, + "line_middle_pin":35, + "line_right_pin":34, + "sumo_id": "", + "firmware_version": 0.3, + "left_servo_tuning": 33, + "right_servo_tuning": 33, + "ultrasonic_distance": 40, + "left_line_threshold": 4500, + "middle_line_threshold": 4500, + "right_line_threshold": 4500, + "motors_reverse": 1, + "sumo_server": "sumo.koodur.com:80", + "wifis": { + "SSID": "PAROOL" + } +} diff --git a/debounce.py b/debounce.py new file mode 100644 index 0000000..60d65c1 --- /dev/null +++ b/debounce.py @@ -0,0 +1,57 @@ +# +# inspired by: https://forum.micropython.org/viewtopic.php?t=1938#p10931 +# +import micropython + +try: + from machine import Timer + timer_init = lambda t, p, cb: t.init(period=p, callback=cb) +except ImportError: + from pyb import Timer + timer_init = lambda t, p, cb: t.init(freq=1000 // p, callback=cb) + +# uncomment when debugging callback problems +#micropython.alloc_emergency_exception_buf(100) + + +class DebouncedSwitch: + def __init__(self, sw, cb, arg=None, delay=50, tid=4): + self.sw = sw + # Create references to bound methods beforehand + # http://docs.micropython.org/en/latest/pyboard/library/micropython.html#micropython.schedule + self._sw_cb = self.sw_cb + self._tim_cb = self.tim_cb + self._set_cb = getattr(self.sw, 'callback', None) or self.sw.irq + self.delay = delay + self.tim = Timer(tid) + self.callback(cb, arg) + + def sw_cb(self, pin=None): + self._set_cb(None) + timer_init(self.tim, self.delay, self._tim_cb) + + def tim_cb(self, tim): + tim.deinit() + if self.sw(): + micropython.schedule(self.cb, self.arg) + self._set_cb(self._sw_cb if self.cb else None) + + def callback(self, cb, arg=None): + self.tim.deinit() + self.cb = cb + self.arg = arg + self._set_cb(self._sw_cb if cb else None) + + +def test_pyb(ledno=1): + import pyb + sw = pyb.Switch() + led = pyb.LED(ledno) + return DebouncedSwitch(sw, lambda l: l.toggle(), led) + + +def test_machine(swpin=2, ledpin=16): + from machine import Pin + sw = Pin(swpin, Pin.IN) + led = Pin(ledpin, Pin.OUT) + return DebouncedSwitch(sw, lambda l: l.value(not l.value()), led) diff --git a/hal.py b/hal.py new file mode 100644 index 0000000..7e4f36b --- /dev/null +++ b/hal.py @@ -0,0 +1,257 @@ +import os +import ujson +from utime import sleep_us, sleep_ms +from machine import Pin, PWM, ADC, time_pulse_us, deepsleep +import random + +# LEDs +STATUS = 0 +#OPPONENT = 1 +#LEFT_LINE = 2 +#RIGHT_LINE = 3 + +# Directions +STOP = 0 +LEFT = 1 +MIDDLE = 6 +RIGHT = 2 +SEARCH = 3 +FORWARD = 4 +BACKWARD = 5 + + +#states +MOVING = 0 +STANDBY = 1 + +class Sumorobot(object): + # Constructor + def __init__(self, config = None): + # Config file + self.config = config + + self.state = STANDBY + + self.name = "Sumo-"+self.config["sumo_id"] + + # Ultrasonic distance sensor + self.echo = Pin(self.config["ultrasonic_echo_pin"], Pin.IN) + self.trigger = Pin(self.config["ultrasonic_trigger_pin"], Pin.OUT) + + # Servo PWM-s + self.pwm_left = PWM(Pin(self.config["left_motor_pin"]), freq=50, duty=0) + self.pwm_right = PWM(Pin(self.config["right_motor_pin"]), freq=50, duty=0) + + # Bottom status LED + self.led_power = Pin(self.config["led_power_pin"], Pin.OUT) + self.charging = Pin(self.config["charging_pin"]); + + self.playStop = Pin(self.config["paly/stop_button_pin"], Pin.IN, Pin.PULL_UP) + + + self.led_power.value(0) + + self.adc_line_left = ADC(Pin(32)) + self.adc_line_middle = ADC(Pin(35)) + self.adc_line_right = ADC(Pin(34)) + + # Set reference voltage to 3.3V + self.adc_line_left.atten(ADC.ATTN_11DB) + self.adc_line_right.atten(ADC.ATTN_11DB) + self.adc_line_middle.atten(ADC.ATTN_11DB) + + # To smooth out ultrasonic sensor value + self.opponent_score = 0 + + # For terminating sleep + self.terminate = False + + # For search mode + self.search = 0 + self.search_counter = 0 + + # Memorise previous servo speeds + self.prev_speed = {LEFT: 0, RIGHT: 0} + + #saving line sensor valus, to read once in 50ms loop + self.line_left = 0 + self.line_right = 0; + self.line_middle = 0; + + self.speedForward = 100 if self.config["motors_reverse"] == 0 else -100 + self.speedReverse = -100 if self.config["motors_reverse"] == 0 else 100 + + + # Function to get distance (cm) from the object in front of the SumoRobot + def get_opponent_distance(self): + # Send a pulse + self.trigger.value(0) + sleep_us(5) + self.trigger.value(1) + sleep_us(10) + self.trigger.value(0) + # Wait for the pulse and calculate the distance + return (time_pulse_us(self.echo, 1, 30000) / 2) / 29.1 + + # Function to get boolean if there is something in front of the SumoRobot + def is_opponent(self): + # Get the opponent distance + self.opponent_distance = self.get_opponent_distance() + # When the opponent is close and the ping actually returned + if self.opponent_distance < self.config["ultrasonic_distance"] and self.opponent_distance > 0: + # When not maximum score + if self.opponent_score < 5: + # Increase the opponent score + self.opponent_score += 1 + # When no opponent was detected + else: + # When not lowest score + if self.opponent_score > 0: + # Decrease the opponent score + self.opponent_score -= 1 + + # When the sensor saw something more than 2 times + opponent = True if self.opponent_score > 2 else False + + # Trigger opponent LED + #self.set_led(OPPONENT, opponent) + + return opponent + + # Function to update line calibration and write it to the config file + def calibrate_line(self): + # Read the line sensor values + self.config["left_line_threshold"] = self.adc_line_left.read() + self.config["right_line_threshold"] = self.adc_line_right.read() + self.config["middle_line_threshold"] = self.adc_line_middle.read() + # Update the config file + with open("config.part", "w") as config_file: + config_file.write(ujson.dumps(config)) + os.rename("config.part", "config.json") + + # Function to get light inensity from the phototransistors + def get_line(self, dir): + # Check if the direction is valid + assert dir in (LEFT, RIGHT, MIDDLE) + + # Return the given line sensor value + if dir == LEFT: + return self.adc_line_left.read() + elif dir == RIGHT: + return self.adc_line_right.read() + elif dir == MIDDLE: + return self.adc_line_middle.read() + + def is_line(self, dir): + # Check if the direction is valid + assert dir in (LEFT, RIGHT, MIDDLE) + + # Return the given line sensor value, storing it to variable, not to ask double in 50ms time loop + if dir == LEFT: + self.line_left = self.get_line(LEFT) + line = abs(self.line_left - self.config["left_line_threshold"]) > 1000 + #self.set_led(LEFT_LINE, line) + return line + elif dir == RIGHT: + self.line_right = self.get_line(RIGHT) + line = abs(self.line_right - self.config["right_line_threshold"]) > 1000 + #self.set_led(RIGHT_LINE, line) + return line + elif dir == MIDDLE: + self.line_middle = self.get_line(MIDDLE) + line = abs(self.line_middle - self.config["middle_line_threshold"]) > 1000 + return line + + def set_servo(self, dir, speed): + # Check if the direction is valid + assert dir in (LEFT, RIGHT) + # Check if the speed is valid + assert speed <= 100 and speed >= -100 + + # When the speed didn't change + if speed == self.prev_speed[dir]: + return + + # Record the new speed + self.prev_speed[dir] = speed + + # Set the given servo speed + if dir == LEFT: + if speed == 0: + self.pwm_left.duty(0) + else: + # -100 ... 100 to 33 .. 102 + self.pwm_left.duty(int(33 + self.config["left_servo_tuning"] + speed * 33 / 100)) + elif dir == RIGHT: + if speed == 0: + self.pwm_right.duty(0) + else: + # -100 ... 100 to 33 .. 102 + self.pwm_right.duty(int(33 + self.config["right_servo_tuning"] + speed * 33 / 100)) + + def move(self, dir): + # Check if the direction is valid + assert dir in (SEARCH, STOP, RIGHT, LEFT, BACKWARD, FORWARD) + # Go to the given direction + + if dir == STOP: + self.set_state(STANDBY) + else: + self.set_state(MOVING) + + if dir == STOP: + self.set_servo(LEFT, 0) + self.set_servo(RIGHT, 0) + elif dir == LEFT: + self.set_servo(LEFT, self.speedReverse) + self.set_servo(RIGHT, self.speedReverse) + elif dir == RIGHT: + self.set_servo(LEFT, self.speedForward) + self.set_servo(RIGHT, self.speedForward) + elif dir == SEARCH: + # Change search mode after X seconds + if self.search_counter == 50: + self.search = random.randrange(0,3) + self.search_counter = 0 + #self.search = 0 if self.search > 2 else self.search + # When in search mode + if self.search == 0: + # Go forward + self.set_servo(LEFT, self.speedForward) + self.set_servo(RIGHT, self.speedReverse) + elif self.search == 1: + # Go left + self.set_servo(LEFT, self.speedReverse) + self.set_servo(RIGHT, self.speedReverse) + elif self.search == 2: + self.set_servo(LEFT, self.speedForward) + self.set_servo(RIGHT, self.speedForward) + # Increase search counter + self.search_counter += 1 + elif dir == FORWARD: + self.set_servo(LEFT, self.speedForward) + self.set_servo(RIGHT, self.speedReverse) + elif dir == BACKWARD: + self.set_servo(LEFT, self.speedReverse) + self.set_servo(RIGHT, self.speedForward) + + + def sleep(self, delay): + # Check for valid delay + assert delay > 0 + + # Split the delay into 50ms chunks + for j in range(0, delay, 50): + # Check for forceful termination + if self.terminate: + # Terminate the delay + return + else: + sleep_ms(50) + + def set_state(self,value): + assert value in (MOVING,STANDBY) + self.state = value + + def get_state(self): + return self.state diff --git a/main.py b/main.py new file mode 100644 index 0000000..a82e7e7 --- /dev/null +++ b/main.py @@ -0,0 +1,179 @@ +import _thread +import ubinascii +import ujson +import uwebsockets +import os +from debounce import DebouncedSwitch + +# Code to execute +ast = "" +executeCode = False + +# Scope, info to be sent to the client +scope = dict() + +def buttoncallback(p=True): + global executeCode + + sumorobot.terminate = executeCode + sumorobot.led_power.value(not executeCode) + sleep_ms(50) + executeCode = not executeCode + print(executeCode) + +def writeCodeTofile(data): + with open("code.part", "w") as code_file: + code_file.write(ujson.dumps(data)) + sleep_ms(50) + os.rename("code.part", "code") + sleep_ms(50) + + +def step(): + global scope + + while True: + + # Update scope + scope = dict( + line_left = sumorobot.line_left, + line_right = sumorobot.line_right, + line_middle = sumorobot.line_middle, + opponent = sumorobot.get_opponent_distance(), + battery_voltage = 0, + state = executeCode + ) + + # Execute code + if(executeCode): + exec(ast) + + #sumorobot.playStop.irq(trigger=Pin.IRQ_RISING, handler=buttoncallback) + sw = DebouncedSwitch(sumorobot.playStop, buttoncallback, "dummy") + + # When robot was stopped + if sumorobot.terminate: + # Disable forceful termination of delays in code + sumorobot.terminate = False + # Stop the robot + sumorobot.move(STOP) + # Leave time to process WebSocket commands + sleep_ms(50) + +def ws_handler(): + global executeCode + global ast + global has_wifi_connection + + while True: + # When WiFi has just been reconnected + if wlan.isconnected() and not has_wifi_connection: + #conn = uwebsockets.connect(url) + #sumorobot.set_led(STATUS, True) + has_wifi_connection = True + # When WiFi has just been disconnected + elif not wlan.isconnected() and has_wifi_connection: + #sumorobot.set_led(STATUS, False) + has_wifi_connection = False + elif not wlan.isconnected(): + # Continue to wait for a WiFi connection + continue + + try: # Try to read from the WebSocket + data = conn.recv() + except Exception as e: # Socket timeout, no data received + # Continue to try to read data + #print(e) + + #if not conn.open: + conn = uwebsockets.connect(url) + continue + + # When an empty frame was received + if not data: + # Continue to receive data + continue + elif b'forward' in data: + ast = "" + sumorobot.move(FORWARD) + elif b'backward' in data: + ast = "" + sumorobot.move(BACKWARD) + elif b'right' in data: + ast = "" + sumorobot.move(RIGHT) + elif b'left' in data: + ast = "" + sumorobot.move(LEFT) + elif b'ping' in data: + #conn.send(ujson.dumps({"cmd":"sensord", "data":ujson.dumps(scope)})) + conn.send(ujson.dumps({"cmd":"sensors", "data":ujson.loads(repr(scope).replace("'", '"').replace("False","false").replace("True","true"))})) + elif b'code' in data: + executeCode = False + try: + data = ujson.loads(data) + writeCodeTofile(data['val']) + conn.send(ujson.dumps({"cmd":"code_upload", "status":True})) + except Exception as e: + conn.send(ujson.dumps({"cmd":"code_upload", "status":False})) + continue + data['val'] = data['val'].replace(";;", "\n") + print(data['val']) + ast = compile(data['val'], "snippet", "exec") + elif b'start' in data: + #buttoncallback() + executeCode = True + sumorobot.led_power.value(1) + sleep_ms(50) + data = ujson.loads(data) + writeCodeTofile(data['val']) + data['val'] = data['val'].replace(";;", "\n") + ast = compile(data['val'], "snippet", "exec") + elif b'stop' in data: + #ast = "" + sumorobot.led_power.value(0) + executeCode = False + sumorobot.move(STOP) + # for terminating delays in code + sumorobot.terminate = True + elif b'calibrate_line' in data: + sumorobot.led_power.value(1) + sleep_ms(50) + sumorobot.calibrate_line() + sumorobot.led_power.value(0) + elif b'Gone' in data: + print("server said 410 Gone, attempting to reconnect...") + #conn = uwebsockets.connect(url) + else: + print("unknown cmd:", data) + +# Wait for WiFi to get connected +while not wlan.isconnected(): + sleep_ms(100) + +# Connect to the websocket +url = "ws://%s/p2p/sumo-%s/browser/" % (config['sumo_server'], config['sumo_id']) +conn = uwebsockets.connect(url) + +# Set X seconds timeout for socket reads +conn.settimeout(3) + +# Stop bootup blinking +#timer.deinit() + +# WiFi is connected +has_wifi_connection = True +# Indicate that the WebSocket is connected +#sumorobot.set_led(STATUS, True) + +if('code' in os.listdir()): + with open("code", "r") as code_file: + data = ujson.load(code_file).replace(";;", "\n") + ast = compile(data, "snippet", "exec") + #print(ast) + + +# Start the code processing thread +_thread.start_new_thread(step, ()) +# Start the Websocket processing thread +_thread.start_new_thread(ws_handler, ()) diff --git a/uwebsockets.py b/uwebsockets.py new file mode 100644 index 0000000..5ce1b3c --- /dev/null +++ b/uwebsockets.py @@ -0,0 +1,241 @@ +""" +Websockets client for micropython + +Based very heavily on +https://github.com/aaugustin/websockets/blob/master/websockets/client.py +""" + +#import usocket as socket +import os +import ure as re +import urandom as random +import ustruct as struct +import usocket as socket +import ubinascii as binascii +from ucollections import namedtuple + +# Opcodes +OP_CONT = const(0x0) +OP_TEXT = const(0x1) +OP_BYTES = const(0x2) +OP_CLOSE = const(0x8) +OP_PING = const(0x9) +OP_PONG = const(0xa) + +# Close codes +CLOSE_OK = const(1000) +CLOSE_GOING_AWAY = const(1001) +CLOSE_PROTOCOL_ERROR = const(1002) +CLOSE_DATA_NOT_SUPPORTED = const(1003) +CLOSE_BAD_DATA = const(1007) +CLOSE_POLICY_VIOLATION = const(1008) +CLOSE_TOO_BIG = const(1009) +CLOSE_MISSING_EXTN = const(1010) +CLOSE_BAD_CONDITION = const(1011) + +URL_RE = re.compile(r'ws://([A-Za-z0-9\-\.]+)(?:\:([0-9]+))?(/.+)?') +URI = namedtuple('URI', ('hostname', 'port', 'path')) + +def urlparse(uri): + match = URL_RE.match(uri) + if match: + return URI(match.group(1), int(match.group(2)), match.group(3)) + else: + raise ValueError("Invalid URL: %s" % uri) + +class Websocket: + is_client = False + + def __init__(self, sock): + self._sock = sock + self.open = True + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + self.close() + + def settimeout(self, timeout): + self._sock.settimeout(timeout) + + def read_frame(self, max_size=None): + # Frame header + byte1, byte2 = struct.unpack('!BB', self._sock.read(2)) + + # Byte 1: FIN(1) _(1) _(1) _(1) OPCODE(4) + fin = bool(byte1 & 0x80) + opcode = byte1 & 0x0f + + # Byte 2: MASK(1) LENGTH(7) + mask = bool(byte2 & (1 << 7)) + length = byte2 & 0x7f + + if length == 126: # Magic number, length header is 2 bytes + length, = struct.unpack('!H', self._sock.read(2)) + elif length == 127: # Magic number, length header is 8 bytes + length, = struct.unpack('!Q', self._sock.read(8)) + + if mask: # Mask is 4 bytes + mask_bits = self._sock.read(4) + + try: + data = self._sock.read(length) + except MemoryError: + # We can't receive this many bytes, close the socket + self.close(code=CLOSE_TOO_BIG) + return True, OP_CLOSE, None + + if mask: + data = bytes(b ^ mask_bits[i % 4] + for i, b in enumerate(data)) + + return fin, opcode, data + + def write_frame(self, opcode, data=b''): + fin = True + mask = self.is_client # messages sent by client are masked + + length = len(data) + + # Frame header + # Byte 1: FIN(1) _(1) _(1) _(1) OPCODE(4) + byte1 = 0x80 if fin else 0 + byte1 |= opcode + + # Byte 2: MASK(1) LENGTH(7) + byte2 = 0x80 if mask else 0 + + if length < 126: # 126 is magic value to use 2-byte length header + byte2 |= length + self._sock.write(struct.pack('!BB', byte1, byte2)) + + elif length < (1 << 16): # Length fits in 2-bytes + byte2 |= 126 # Magic code + self._sock.write(struct.pack('!BBH', byte1, byte2, length)) + + elif length < (1 << 64): + byte2 |= 127 # Magic code + self._sock.write(struct.pack('!BBQ', byte1, byte2, length)) + + else: + raise ValueError() + + if mask: # Mask is 4 bytes + mask_bits = struct.pack('!I', random.getrandbits(32)) + self._sock.write(mask_bits) + + data = bytes(b ^ mask_bits[i % 4] + for i, b in enumerate(data)) + + self._sock.write(data) + + def recv(self): + assert self.open + + while self.open: + try: + fin, opcode, data = self.read_frame() + except ValueError: + self._close() + return + + if not fin: + raise NotImplementedError() + + if opcode == OP_TEXT: + return data + elif opcode == OP_BYTES: + return data + elif opcode == OP_CLOSE: + self._close() + return + elif opcode == OP_PONG: + # Ignore this frame, keep waiting for a data frame + continue + elif opcode == OP_PING: + # We need to send a pong frame + self.write_frame(OP_PONG, data) + # And then wait to receive + continue + elif opcode == OP_CONT: + # This is a continuation of a previous frame + raise NotImplementedError(opcode) + else: + raise ValueError(opcode) + + def send(self, buf): + assert self.open + + if isinstance(buf, str): + opcode = OP_TEXT + buf = buf.encode('utf-8') + elif isinstance(buf, bytes): + opcode = OP_BYTES + else: + raise TypeError() + + self.write_frame(opcode, buf) + + def close(self, code=CLOSE_OK, reason=''): + if not self.open: + return + + buf = struct.pack('!H', code) + reason.encode('utf-8') + + self.write_frame(OP_CLOSE, buf) + self._close() + + def _close(self): + self.open = False + self._sock.close() + +class WebsocketClient(Websocket): + is_client = True + +def connect(uri): + """ + Connect a websocket. + """ + + # Parse the given WebSocket URI + uri = urlparse(uri) + assert uri + + # Connect the socket + sock = socket.socket() + addr = socket.getaddrinfo(uri.hostname, uri.port) + sock.connect(addr[0][4]) + + # Sec-WebSocket-Key is 16 bytes of random base64 encoded + key = binascii.b2a_base64(os.urandom(16))[:-1] + + # WebSocket initiation headers + headers = [ + b'GET %s HTTP/1.1' % uri.path or '/', + b'Upgrade: websocket', + b'Connection: Upgrade', + b'Host: %s:%s' % (uri.hostname, uri.port), + b'Origin: http://%s:%s' % (uri.hostname, uri.port), + b'Sec-WebSocket-Key: ' + key, + b'Sec-WebSocket-Version: 13', + b'', + b'' + ] + + # Concatenate the headers and add new lines + data = b'\r\n'.join(headers) + + # Send the WebSocket initiation packet + sock.send(data) + + # Check for the WebSocket response header + header = sock.readline()[:-2] + assert header == b'HTTP/1.1 101 Switching Protocols', header + + # We don't (currently) need these headers + # FIXME: should we check the return key? + while header: + header = sock.readline()[:-2] + + return WebsocketClient(sock)