From f787ef79cd5c944ad913975a43abf3ba3102c961 Mon Sep 17 00:00:00 2001 From: Mike Walters Date: Wed, 19 Jun 2024 18:16:47 +0100 Subject: [PATCH] gateware.usb: Implement ClearFeature(ENDPOINT_HALT). --- applets/clear_endpoint_halt_test.py | 223 +++++++++++++++++++++ luna/gateware/usb/request/standard.py | 28 ++- luna/gateware/usb/usb2/control.py | 2 + luna/gateware/usb/usb2/endpoint.py | 10 + luna/gateware/usb/usb2/endpoints/stream.py | 18 ++ luna/gateware/usb/usb2/request.py | 24 ++- luna/gateware/usb/usb2/transfer.py | 6 +- pyproject.toml | 2 +- 8 files changed, 303 insertions(+), 10 deletions(-) create mode 100755 applets/clear_endpoint_halt_test.py diff --git a/applets/clear_endpoint_halt_test.py b/applets/clear_endpoint_halt_test.py new file mode 100755 index 000000000..0e7637c97 --- /dev/null +++ b/applets/clear_endpoint_halt_test.py @@ -0,0 +1,223 @@ +#!/usr/bin/env python3 +# +# This file is part of LUNA. +# +# Copyright (c) 2024 Great Scott Gadgets +# SPDX-License-Identifier: BSD-3-Clause + +import logging +import os +import time +import usb1 + +from amaranth import Elaboratable, Module, Signal + +from luna import top_level_cli, configure_default_logging +from luna.usb2 import USBDevice, USBStreamInEndpoint, USBStreamOutEndpoint +from luna.gateware.stream.generator import StreamSerializer +from luna.gateware.usb.request.control import ControlRequestHandler +from luna.gateware.usb.stream import USBInStreamInterface + +from usb_protocol.types import USBRequestRecipient, USBRequestType +from usb_protocol.emitters import DeviceDescriptorCollection + +# use pid.codes Test PID +VID = 0x1209 +PID = 0x0001 + +BULK_ENDPOINT_NUMBER = 1 +MAX_BULK_PACKET_SIZE = 512 + +COUNTER_MAX = 251 +GET_OUT_COUNTER_VALID = 0 + +out_counter_valid = Signal(reset=1) + +class VendorRequestHandler(ControlRequestHandler): + + REQUEST_SET_LEDS = 0 + + def elaborate(self, platform): + m = Module() + + interface = self.interface + setup = self.interface.setup + + # Transmitter for small-constant-response requests + m.submodules.transmitter = transmitter = \ + StreamSerializer(data_length=1, domain="usb", stream_type=USBInStreamInterface, max_length_width=1) + # + # Vendor request handlers. + with m.FSM(domain="usb"): + with m.State('IDLE'): + vendor = setup.type == USBRequestType.VENDOR + with m.If( + setup.received & \ + (setup.type == USBRequestType.VENDOR) & \ + (setup.recipient == USBRequestRecipient.INTERFACE) & \ + (setup.index == 0)): + with m.Switch(setup.request): + with m.Case(GET_OUT_COUNTER_VALID): + m.d.comb += interface.claim.eq(1) + m.next = 'GET_OUT_COUNTER_VALID' + pass + + with m.State('GET_OUT_COUNTER_VALID'): + m.d.comb += interface.claim.eq(1) + self.handle_simple_data_request(m, transmitter, out_counter_valid, length=1) + + return m + + +class ClearHaltTestDevice(Elaboratable): + + + def create_descriptors(self): + + descriptors = DeviceDescriptorCollection() + + with descriptors.DeviceDescriptor() as d: + d.idVendor = VID + d.idProduct = PID + + d.iManufacturer = "LUNA" + d.iProduct = "Clear Endpoint Halt Test" + d.iSerialNumber = "no serial" + + d.bNumConfigurations = 1 + + + with descriptors.ConfigurationDescriptor() as c: + + with c.InterfaceDescriptor() as i: + i.bInterfaceNumber = 0 + + with i.EndpointDescriptor() as e: + e.bEndpointAddress = 0x80 | BULK_ENDPOINT_NUMBER + e.wMaxPacketSize = MAX_BULK_PACKET_SIZE + + with i.EndpointDescriptor() as e: + e.bEndpointAddress = BULK_ENDPOINT_NUMBER + e.wMaxPacketSize = MAX_BULK_PACKET_SIZE + + + return descriptors + + + def elaborate(self, platform): + m = Module() + + m.submodules.car = platform.clock_domain_generator() + + ulpi = platform.request(platform.default_usb_connection) + m.submodules.usb = usb = USBDevice(bus=ulpi) + + descriptors = self.create_descriptors() + control_ep = usb.add_standard_control_endpoint(descriptors) + + control_ep.add_request_handler(VendorRequestHandler()) + + stream_in_ep = USBStreamInEndpoint( + endpoint_number=BULK_ENDPOINT_NUMBER, + max_packet_size=MAX_BULK_PACKET_SIZE + ) + usb.add_endpoint(stream_in_ep) + + stream_out_ep = USBStreamOutEndpoint( + endpoint_number=BULK_ENDPOINT_NUMBER, + max_packet_size=MAX_BULK_PACKET_SIZE + ) + usb.add_endpoint(stream_out_ep) + + # Generate a counter on the IN endpoint. + in_counter = Signal(8) + with m.If(stream_in_ep.stream.ready): + m.d.usb += in_counter.eq(in_counter + 1) + with m.If(in_counter == COUNTER_MAX): + m.d.usb += in_counter.eq(0) + + # Expect a counter on the OUT endpoint, and verify that it is contiguous. + prev_out_counter = Signal(8, reset=COUNTER_MAX) + with m.If(stream_out_ep.stream.valid): + out_counter = stream_out_ep.stream.payload + counter_increase = out_counter == (prev_out_counter + 1) + counter_wrap = (out_counter == 0) & (prev_out_counter == COUNTER_MAX) + with m.If(~counter_increase & ~counter_wrap): + m.d.usb += out_counter_valid.eq(0) + + m.d.usb += prev_out_counter.eq(out_counter) + + m.d.comb += [ + stream_in_ep.stream.valid .eq(1), + stream_in_ep.stream.payload .eq(in_counter), + + stream_out_ep.stream.ready .eq(1), + ] + + # Connect our device as a high speed device by default. + m.d.comb += [ + usb.connect .eq(1), + usb.full_speed_only .eq(1 if os.getenv('LUNA_FULL_ONLY') else 0), + ] + + return m + +def test_clear_halt(): + with usb1.USBContext() as context: + device = context.openByVendorIDAndProductID(VID, PID) + + # Read the first packet which should have a DATA0 PID, next we expect DATA1. + packet = device.bulkRead(BULK_ENDPOINT_NUMBER, MAX_BULK_PACKET_SIZE) + # Send clear halt, this resets both sides to DATA0. + device.clearHalt(usb1.ENDPOINT_IN | BULK_ENDPOINT_NUMBER) + # Read another packet. If the PID doesn't match what we epxect, + # then the host will assume it was a retransmission of the last one and drop it. + packet += device.bulkRead(BULK_ENDPOINT_NUMBER, MAX_BULK_PACKET_SIZE) + + # Check that the counter is contiguous across all received data, making sure we didn't drop a packet. + for i in range(1, len(packet)): + if packet[i] == packet[i-1] + 1: + pass + elif packet[i] == 0 and packet[i-1] == COUNTER_MAX: + pass + else: + print(f"IN test fail {i} {packet[i]} {packet[i-1]}") + return + + print("IN OK") + + # Generate three packets worth of counter data, the gateware will verify that it is contiguous. + data = bytes(i % (COUNTER_MAX+1) for i in range(MAX_BULK_PACKET_SIZE*3)) + # Send DATA0, device should expect DATA1 next. + device.bulkWrite(BULK_ENDPOINT_NUMBER, data[:MAX_BULK_PACKET_SIZE]) + # Reset both sides to DATA0. + device.clearHalt(usb1.ENDPOINT_OUT | BULK_ENDPOINT_NUMBER) + # Send two packets. If the first packet doesn't match, + # it'll be dropped and another is required to let the gateware check the counter. + device.bulkWrite(BULK_ENDPOINT_NUMBER, data[MAX_BULK_PACKET_SIZE:]) + + # Read back the out_counter_valid register to check for success. + request_type = usb1.REQUEST_TYPE_VENDOR | usb1.RECIPIENT_INTERFACE | usb1.ENDPOINT_IN + if device.controlRead(request_type, GET_OUT_COUNTER_VALID, 0, 0, 1)[0] == 1: + print("OUT OK") + else: + print("OUT FAIL") + + +if __name__ == "__main__": + configure_default_logging() + + # If our environment is suggesting we rerun tests without rebuilding, do so. + if os.getenv('LUNA_RERUN_TEST'): + logging.info("Running speed test without rebuilding...") + + # Otherwise, rebuild. + else: + device = top_level_cli(ClearHaltTestDevice) + + # Give the device a moment to connect. + if device is not None: + logging.info("Giving the device time to connect...") + time.sleep(5) + + test_clear_halt() \ No newline at end of file diff --git a/luna/gateware/usb/request/standard.py b/luna/gateware/usb/request/standard.py index aaf62f3f6..8f844bd67 100644 --- a/luna/gateware/usb/request/standard.py +++ b/luna/gateware/usb/request/standard.py @@ -14,7 +14,7 @@ from amaranth import * from amaranth.hdl.ast import Value, Const -from usb_protocol.types import USBStandardRequests, USBRequestType +from usb_protocol.types import USBStandardFeatures, USBStandardRequests, USBRequestRecipient, USBRequestType from usb_protocol.emitters import DeviceDescriptorCollection from ..usb2.request import RequestHandlerInterface, USBRequestHandler @@ -139,6 +139,8 @@ def elaborate(self, platform): with m.Case(USBStandardRequests.GET_STATUS): m.next = 'GET_STATUS' + with m.Case(USBStandardRequests.CLEAR_FEATURE): + m.next = 'CLEAR_FEATURE' with m.Case(USBStandardRequests.SET_ADDRESS): m.next = 'SET_ADDRESS' with m.Case(USBStandardRequests.SET_CONFIGURATION): @@ -158,6 +160,30 @@ def elaborate(self, platform): # TODO: copy the remote wakeup and bus-powered attributes from bmAttributes of the relevant descriptor? self.handle_simple_data_request(m, transmitter, 0, length=2) + with m.State('CLEAR_FEATURE'): + # Provide an response to the STATUS stage. + with m.If(self.interface.status_requested): + + # If our stall condition is met, stall; otherwise, send a ZLP [USB 8.5.3]. + # For now, we only implement clearing ENDPOINT_HALT. + stall_condition = \ + (self.interface.setup.recipient != USBRequestRecipient.ENDPOINT) | \ + (self.interface.setup.value != USBStandardFeatures.ENDPOINT_HALT) + with m.If(stall_condition): + m.d.comb += self.interface.handshakes_out.stall.eq(1) + with m.Else(): + m.d.comb += self.send_zlp() + + # Accept the relevant value after the packet is ACK'd... + with m.If(self.interface.handshakes_in.ack): + m.d.comb += [ + self.interface.clear_endpoint_halt.enable .eq(1), + self.interface.clear_endpoint_halt.direction.eq(self.interface.setup.index[7]), + self.interface.clear_endpoint_halt.number .eq(self.interface.setup.index[0:4]), + ] + + # ... and then return to idle. + m.next = 'IDLE' # SET_ADDRESS -- The host is trying to assign us an address. with m.State('SET_ADDRESS'): diff --git a/luna/gateware/usb/usb2/control.py b/luna/gateware/usb/usb2/control.py index 8b23dfe88..312fb5614 100644 --- a/luna/gateware/usb/usb2/control.py +++ b/luna/gateware/usb/usb2/control.py @@ -178,6 +178,8 @@ def elaborate(self, platform): interface.address_changed .eq(request_handler.address_changed), interface.new_address .eq(request_handler.new_address), + interface.clear_endpoint_halt_out .eq(request_handler.clear_endpoint_halt), + request_handler.active_config .eq(interface.active_config), interface.config_changed .eq(request_handler.config_changed), interface.new_config .eq(request_handler.new_config), diff --git a/luna/gateware/usb/usb2/endpoint.py b/luna/gateware/usb/usb2/endpoint.py index 63a4c7356..eadacded8 100644 --- a/luna/gateware/usb/usb2/endpoint.py +++ b/luna/gateware/usb/usb2/endpoint.py @@ -13,6 +13,7 @@ from .packet import DataCRCInterface, InterpacketTimerInterface, TokenDetectorInterface from .packet import HandshakeExchangeInterface +from .request import ClearEndpointHaltInterface from ..stream import USBInStreamInterface, USBOutStreamInterface from ...utils.bus import OneHotMultiplexer @@ -90,6 +91,9 @@ def __init__(self): self.config_changed = Signal() self.new_config = Signal(8) + self.clear_endpoint_halt_out = Signal(ClearEndpointHaltInterface) + self.clear_endpoint_halt_in = Signal(ClearEndpointHaltInterface) + self.rx = USBOutStreamInterface() self.rx_complete = Signal() self.rx_ready_for_response = Signal() @@ -213,6 +217,8 @@ def elaborate(self, platform): shared.handshakes_in .connect(interface.handshakes_in), shared.tokenizer .connect(interface.tokenizer), + interface.clear_endpoint_halt_in .eq(shared.clear_endpoint_halt_out), + # Rx interface. shared.rx .connect(interface.rx), interface.rx_complete .eq(shared.rx_complete), @@ -259,6 +265,10 @@ def elaborate(self, platform): # ... and our timer start signals. self.or_join_interface_signals(m, lambda interface : interface.timer.start) + self.or_join_interface_signals(m, lambda interface : interface.clear_endpoint_halt_out.enable) + self.or_join_interface_signals(m, lambda interface : interface.clear_endpoint_halt_out.direction) + self.or_join_interface_signals(m, lambda interface : interface.clear_endpoint_halt_out.number) + # Finally, connect up our transmit PID select. conditional = m.If diff --git a/luna/gateware/usb/usb2/endpoints/stream.py b/luna/gateware/usb/usb2/endpoints/stream.py index 399eb65e6..33f9d667d 100644 --- a/luna/gateware/usb/usb2/endpoints/stream.py +++ b/luna/gateware/usb/usb2/endpoints/stream.py @@ -81,6 +81,11 @@ def elaborate(self, platform): # Create our transfer manager, which will be used to sequence packet transfers for our stream. m.submodules.tx_manager = tx_manager = USBInTransferManager(self._max_packet_size) + # Check there has been a ClearFeature(ENDPOINT_HALT) request address to this endpoint. + clear_endpoint_halt = \ + self.interface.clear_endpoint_halt_in.enable & \ + self.interface.clear_endpoint_halt_in.direction & \ + (self.interface.clear_endpoint_halt_in.number == self._endpoint_number) m.d.comb += [ # Always generate ZLPs; in order to pass along when stream packets terminate. @@ -94,6 +99,9 @@ def elaborate(self, platform): tx_manager.flush .eq(self.flush), tx_manager.discard .eq(self.discard), + # ... and data-toggle reset on clear endpoint halt... + tx_manager.reset_sequence .eq(clear_endpoint_halt), + # ... and our output stream... interface.tx .stream_eq(tx_manager.packet_stream), interface.tx_pid_toggle .eq(tx_manager.data_pid), @@ -414,6 +422,16 @@ def elaborate(self, platform): with m.If(data_response_requested & data_accepted): m.d.usb += expected_data_toggle.eq(~expected_data_toggle) + # If there has been a ClearFeature(ENDPOINT_HALT) request address to this endpoint... + clear_endpoint_halt = \ + self.interface.clear_endpoint_halt_in.enable & \ + ~self.interface.clear_endpoint_halt_in.direction & \ + (self.interface.clear_endpoint_halt_in.number == self._endpoint_number) + + with m.If(clear_endpoint_halt): + # ... reset the expected data toggle. + m.d.usb += expected_data_toggle.eq(0) + return m diff --git a/luna/gateware/usb/usb2/request.py b/luna/gateware/usb/usb2/request.py index 31af19d69..e509649d3 100644 --- a/luna/gateware/usb/usb2/request.py +++ b/luna/gateware/usb/usb2/request.py @@ -11,6 +11,7 @@ from amaranth import Signal, Module, Elaboratable, Cat from amaranth.lib.coding import Encoder +from amaranth.lib.data import Struct from amaranth.hdl.rec import Record, DIR_FANOUT from . import USBSpeed @@ -21,6 +22,12 @@ from ..request import SetupPacket +class ClearEndpointHaltInterface(Struct): + enable: 1 + direction: 1 + number: 4 + + class RequestHandlerInterface: """ Record representing a connection between a control endpoint and a request handler. @@ -73,6 +80,8 @@ def __init__(self): self.config_changed = Signal() self.new_config = Signal(8) + self.clear_endpoint_halt = Signal(ClearEndpointHaltInterface) + self.rx = USBOutStreamInterface() self.rx_expected = Signal() self.rx_ready_for_response = Signal() @@ -378,16 +387,17 @@ def elaborate(self, platform): def _connect_interface_outputs(interface): m.d.comb += [ - shared.tx .stream_eq(interface.tx), + shared.tx .stream_eq(interface.tx), - shared.tx_data_pid .eq(interface.tx_data_pid), + shared.tx_data_pid .eq(interface.tx_data_pid), - shared.handshakes_out .eq(interface.handshakes_out), + shared.handshakes_out .eq(interface.handshakes_out), - shared.address_changed .eq(interface.address_changed), - shared.new_address .eq(interface.new_address), - shared.config_changed .eq(interface.config_changed), - shared.new_config .eq(interface.new_config), + shared.address_changed .eq(interface.address_changed), + shared.new_address .eq(interface.new_address), + shared.config_changed .eq(interface.config_changed), + shared.new_config .eq(interface.new_config), + shared.clear_endpoint_halt.eq(interface.clear_endpoint_halt), ] # The encoder provides the index of the single interface that claims the diff --git a/luna/gateware/usb/usb2/transfer.py b/luna/gateware/usb/usb2/transfer.py index 81feef54e..41d6c7262 100644 --- a/luna/gateware/usb/usb2/transfer.py +++ b/luna/gateware/usb/usb2/transfer.py @@ -115,7 +115,7 @@ def elaborate(self, platform): # Handle our PID-sequence reset. # Note that we store the _inverse_ of our data PID, as we'll toggle our DATA PID - # before sending. + # before sending. However, if it has already been toggled then this is overridden below. with m.If(self.reset_sequence): m.d.usb += self.data_pid.eq(~self.start_with_data1) @@ -268,6 +268,10 @@ def elaborate(self, platform): m.d.usb += self.data_pid[0].eq(~self.data_pid[0]), m.next = "WAIT_FOR_DATA" + # If we get a clear halt request while in this state, reset to initial PID. + with m.Elif(self.reset_sequence): + m.d.usb += self.data_pid.eq(self.start_with_data1) + # Otherwise, once we get an IN token, move to sending a packet. with m.Elif(in_token_received): diff --git a/pyproject.toml b/pyproject.toml index 07d6b523a..e6e712387 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ dependencies = [ "pyusb>1.1.1", "pyvcd>=0.2.4", "amaranth~=0.4.0", - "usb-protocol~=0.9", + "usb-protocol~=0.9.1", ] [project.optional-dependencies]