/*
 * usb.c
 * -----
 * 23/4/23
 * USB midi interface
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <kernel.h>
#include <swis.h>
#include "modhdr.h"
#include "module.h"
#include "usb.h"


// Fall back to the GNU-style packed attribute when no __packed macro has
// already been defined. Used on the descriptor structures below.
#ifndef __packed
#define __packed  __attribute__ ((packed)) // Structures without gaps
#endif

// generic descriptor, they all have these 2 fields
// Used to walk a descriptor list without knowing each entry's full type.
typedef struct __packed
{
  unsigned char  bLength;         // size of this descriptor in bytes; also the step to the next one
  unsigned char  bDescriptorType; // usb_check_descriptors acts on 1 (device), 4 (interface), 5 (endpoint)
} USB_Descriptor;

// String descriptor as fetched by usb_string: the common 2 byte header
// followed by up to STR_LEN 16 bit characters.
typedef struct __packed
{
  unsigned char  bLength;          // total descriptor size in bytes, header included
  unsigned char  bDescriptorType;
  unsigned short bString[STR_LEN]; // 16 bit characters; element 0 of the index 0 reply is used as the language id
} USB_StringDescriptor;

// Device descriptor. usb_check_descriptors copies the vendor/product/version
// ids and the three string indices from here into the port record.
typedef struct __packed sDeviceDescriptor
{
  unsigned char  bLength;
  unsigned char  bDescriptorType;
  unsigned short bcdUSBversion;
  unsigned char  bDeviceClass;
  unsigned char  bDeviceSubClass;
  unsigned char  bDeviceProtocol;
  unsigned char  bMaxPacketSize;
  unsigned short idVendor;          // reported as VID
  unsigned short idProduct;         // reported as PID
  unsigned short bcdDevice;         // reported as VERS
  unsigned char  iManufacturer;     // string index, 0 is treated as "no string"
  unsigned char  iProduct;          // string index, 0 is treated as "no string"
  unsigned char  iSerialNumber;     // string index, 0 is treated as "no string"
  unsigned char  bNumConfigurations;
}
  tDeviceDescriptor;

// Interface descriptor. An interface with class 1 and subclass 3 is taken
// to be MIDI Streaming by usb_check_descriptors.
typedef struct __packed sInterfaceDescriptor
{
  unsigned char  bLength;
  unsigned char  bDescriptorType;
  unsigned char  bInterfaceNumber;   // stored in the port and used in the DeviceFS open string
  unsigned char  bAlternateSetting;
  unsigned char  bNumEndpoints;
  unsigned char  bInterfaceClass;
  unsigned char  bInterfaceSubClass;
  unsigned char  bInterfaceProtocol;
  unsigned char  iInterface;
}
  tInterfaceDescriptor;

// Endpoint descriptor. Only the address, attributes and packet size are
// read by this file.
typedef struct __packed sEndpointDescriptor
{
  unsigned char  bLength;
  unsigned char  bDescriptorType;
  unsigned char  bEndpointAddress; // bit 7 set is treated as an input endpoint
  unsigned char  bmAttributes;     // values 0-3 are reported as Control/Isochronous/Bulk/Interrupt
  unsigned short wMaxPacketSize;
  unsigned char  bInterval;
  unsigned char  bRefresh;
  unsigned char  bSyncAddress;
}
  tEndpointDescriptor;

// Per device information block. One is embedded in every USBServiceAnswer
// returned to usb_find, and usb_new_device is handed one directly.
// The descriptors to scan start 'descoff' bytes from the start of this block.
typedef struct USBServiceCall
{
  unsigned short sclen;       // sum length of the block including the appended descriptors
  unsigned short descoff;     // offset to the descriptors
  char           devname[NAME_LEN]; // device name as appears in DeviceFS e.g.USB37  NULL terminated
  unsigned char  bus;         // bus number, 0-255
  unsigned char  devaddr;     // usb address,1-127
  unsigned char  hostaddr;    // usb address of upstream port 0-127
  unsigned char  hostport;    // port on host address
  unsigned char  speed;       // device speed
  unsigned char  spare1;      // spare, set NULL
  unsigned char  spare2;      // spare, set NULL
  unsigned char  spare3;      // spare, set NULL
  tDeviceDescriptor ddesc;    // device descriptor
                              // followed immediately by zero or more
                              // descriptors (with no word alignment gaps)
} USBServiceCall;

// Node of the singly linked list handed back in R2 by the service call
// issued in usb_find. usb_find walks the list and then frees every node.
typedef struct USBServiceAnswer USBServiceAnswer;
struct USBServiceAnswer
{
  USBServiceAnswer  *link;    // pointer to next, or NULL for no more
  USBServiceCall     svc;     // data as per 'Attach' service call
};

// Module wide USB state: the port table (usb.port[]) and the number of
// ports found by usb_find (usb.ports).
usb_t usb;


/*
 * usb_string
 * ----------
 * Given a device string eg. "USB12" and string index,
 * returns a pointer to the associated ascii string.
 */
static char *usb_string(char *device, int i)
{
  _kernel_swi_regs r;
  static char str[STR_LEN];

  str[0] = 0;

  if ((i < 1) || (i >= 0x100))
    return str; // "fail1"

  USB_StringDescriptor lang;
  r.r[0] = DEVFS_USB_CTRL;
  r.r[1] = (int)device;
  r.r[2] = 0;
  r.r[3] = 0x80 | (6<<8) | (0x300<<16); // request type | request | value
  r.r[4] = 0 | (sizeof(USB_StringDescriptor)<<16); // index | length
  r.r[5] = (int)&lang;
  r.r[6] = 0;
  if(_kernel_swi(DeviceFS_CallDevice, &r, &r))
    return str; // "fail2"

  USB_Descriptor size;
  r.r[3] += i << 16; // add index to value
  r.r[4] = lang.bString[0] | (sizeof(USB_Descriptor)<<16); // index | length
  r.r[5] = (int)&size;
  if(_kernel_swi(DeviceFS_CallDevice, &r, &r))
    return str; // "fail3"

  USB_StringDescriptor usd;
  r.r[4] = lang.bString[0] | ((size.bLength)<<16); // index | length
  r.r[5] = (int)&usd;
  if(_kernel_swi(DeviceFS_CallDevice, &r, &r))
    return str; // "fail4"

  // extract ascii text from the unicode string descriptor
  char *s = str;
  unsigned short int *uni = usd.bString;
  unsigned char len = (usd.bLength - 2) / 2; // convert descriptor length to string length
  while((len-- > 0) && ((s - str) < (STR_LEN-1)) && (*uni < 128) && (*uni > 31))
    *s++ = *uni++;
  *s = 0;

  return str;
}


/*
 * usb_check_descriptors
 * ---------------------
 * Checks the descriptors of a single device for a midi interface and if found,
 * sets the port data accordingly.
 *  svc  device block; its descriptors start svc->descoff bytes in and the
 *       whole block is svc->sclen bytes long
 *  u    port record to fill in
 * The scan stops at a zero length descriptor or at the end of the block,
 * whichever comes first. Only endpoints that belong to the MIDI Streaming
 * interface are stored; endpoints of any other interface are ignored.
 * If several MIDI Streaming interfaces are present the last one wins.
 * Returns 1 if port found, else returns 0. Note that 'u' may have been
 * partly written even when 0 is returned.
 */
int usb_check_descriptors(USBServiceCall *svc, usb_port_t *u)
{
  const unsigned char *base = (const unsigned char *)svc;
  unsigned int total = svc->sclen;   // size of the whole block
  unsigned int off = svc->descoff;   // offset of the descriptor being examined
  const tDeviceDescriptor *devi = NULL;
  int ends = 0;     // endpoints stored for the current candidate
  int midi = 0;     // a MIDI Streaming interface has been seen
  int collect = 0;  // the interface currently being scanned is the MIDI one

  // search descriptors of device
  while(off + sizeof(USB_Descriptor) <= total)
  {
    const USB_Descriptor *desc = (const USB_Descriptor *)(base + off);

    // a zero length ends the list, a descriptor running off the block is junk
    if((desc->bLength == 0) || (off + desc->bLength > total))
      break;

    switch(desc->bDescriptorType)
    {
      case 1: // device descriptor
        if(desc->bLength >= sizeof(tDeviceDescriptor))
          devi = (const tDeviceDescriptor *)desc;
        break;

      case 4: // interface descriptor
        // endpoints that follow belong to this interface, not the previous one
        collect = 0;
        if(devi && (desc->bLength >= sizeof(tInterfaceDescriptor)))
        {
          const tInterfaceDescriptor *iface = (const tInterfaceDescriptor *)desc;

          // This is where checks could go for known devices with vendor specific protocols.
          // e.g. the original M-Audio MidiSport 1x1 (that I happen to have)
          if((iface->bInterfaceClass == 1) && (iface->bInterfaceSubClass == 3)) // MIDI Streaming
          {
            midi = 1; // found a possible candidate
            collect = 1;
            u->open = 0;
            strcpy(u->device, svc->devname);
            u->vendor = devi->idVendor;
            u->product_id = devi->idProduct;
            u->version = devi->bcdDevice;
            u->iMan = devi->iManufacturer;
            u->iPro = devi->iProduct;
            u->iSer = devi->iSerialNumber;
            u->interface = iface->bInterfaceNumber;
            ends = 0;
          }
        }
        break;

      case 5: // endpoint descriptor
        // 7 bytes covers every field up to and including bInterval
        if(collect && (ends < MAX_ENDS) && (desc->bLength >= 7))
        {
          const tEndpointDescriptor *endp = (const tEndpointDescriptor *)desc;
          // always store input endpoint first so we know which is which
          int e = (endp->bEndpointAddress & 0x80) ? (ends & ~1) : (ends | 1);
          if(e < MAX_ENDS) // 'ends | 1' can step past the table when MAX_ENDS is odd
          {
            u->endpoint[e].addr = endp->bEndpointAddress;
            u->endpoint[e].attr = endp->bmAttributes;
            u->endpoint[e].size = endp->wMaxPacketSize;
            ends++;
          }
        }
        break;

      default: // class specific and other descriptors are skipped
        break;
    }
    if(midi)
      u->ends = ends;
    off += desc->bLength; // point to next descriptor
  }

  return (ends > 0);
}


/*
 * usb_find
 * --------
 * Searches for suitable usb midi devices.
 * Returns the number of midi devices found.
 */
int usb_find(void)
{
  _kernel_swi_regs regs;
  USBServiceAnswer *list, *item;
  int ports = 0;

  regs.r[0] = 1;
  regs.r[1] = 0xd2;
  regs.r[2] = 0;
  _kernel_swi(OS_ServiceCall, &regs, &regs);
  item = list = (USBServiceAnswer *)regs.r[2];

  // search devices
  while(item != 0)
  {
    ports += usb_check_descriptors(&item->svc, &usb.port[ports]);
    item = item->link; // point to next device
  }

  // free memory blocks
  while(list != 0)
  {
    item = list->link;
    regs.r[0] = 7;
    regs.r[2] = (int)list;
    _kernel_swi(OS_Module, &regs, &regs);
    list = item;
  }

  usb.ports = ports;
  return ports;
}


/*
 * usb_open_abort
 * --------------
 * Closes the first 'count' endpoint streams of a port after a failed open,
 * so that no file handles are left behind. Errors from the close are ignored.
 */
static void usb_open_abort(usb_port_t *u, int count)
{
  int i;

  for(i=0; i<count; i++)
  {
    _kernel_swi_regs regs;

    // never pass handle 0 to the close call
    if(u->endpoint[i].handle.fileswitch == 0)
      continue;

    regs.r[0] = 0;
    regs.r[1] = u->endpoint[i].handle.fileswitch;
    _kernel_swi(OS_Find, &regs, &regs);
    u->endpoint[i].handle.fileswitch = 0;
  }
}


/*
 * usb_open
 * --------
 * Opens input and output streams
 * Fetches the manufacturer, product and serial number strings, opens
 * endpoint 0 (input) and endpoint 1 (output) of the port through DeviceFS and
 * records the handles for each, then marks the port open and wakes up the
 * receiver.
 * Returns NULL on success. On failure returns the error from the failing SWI;
 * any stream opened so far is closed again and the port is left not open.
 */
_kernel_oserror *usb_open(int port)
{
  usb_port_t *u = &usb.port[port];
  _kernel_swi_regs regs;
  _kernel_oserror* err = NULL;
  char str[128]; // 59 fixed characters, up to three full width numbers and the device name
  int i;

  // fetch string descriptions
  if(u->iMan != 0)
    strcpy(u->manufacturer, usb_string(u->device, u->iMan));
  else
    *u->manufacturer = 0;
  if(u->iPro != 0)
    strcpy(u->product, usb_string(u->device, u->iPro));
  else
    *u->product = 0;
  if(u->iSer != 0)
    strcpy(u->serial_no, usb_string(u->device, u->iSer));
  else
    *u->serial_no = 0;

  for(i=0; i<2; i++)
  {
    handle_t *h = &u->endpoint[i].handle;
    // Open the endpoints, input will be first, then output
    // note. The interface needs specifying for composite devices where the midi interface
    //       is not the first. (I have an usb audio & midi unit where midi is interface 3)
    snprintf(str, sizeof(str), "devices#interface%d;endpoint%d;size%d;noblock;nopad:%s",
             u->interface, u->endpoint[i].addr, u->endpoint[i].size + 1, u->device);

    regs.r[0] = (i == 0) ? 0x43 : 0x83; // input, existing, no path : output, create, no path
    regs.r[1] = (int)str;
    if((err = _kernel_swi(OS_Find, &regs, &regs)) != NULL)
    {
      usb_open_abort(u, i); // this one failed, close the earlier one(s)
      return err;
    }

    h->fileswitch = regs.r[0]; // just a number

    regs.r[2] = regs.r[0];
    regs.r[0] = DEVFS_HANDLES2;
    regs.r[1] = (int)u->device;
    if((err = _kernel_swi(DeviceFS_CallDevice, &regs, &regs)) != NULL)
    {
      usb_open_abort(u, i + 1); // includes the stream just opened
      return err;
    }

    h->buffer    = regs.r[3]; // just a number
    // these 3 are pointers to data structures
    h->devicefs  = regs.r[4];
    h->usbstream = regs.r[5];
    h->driver    = regs.r[6]; // device driver handle (same for all endpoints)

    regs.r[0] = regs.r[3];
    if((err = _kernel_swi(Buffer_InternalInfo, &regs, &regs)) != NULL)
    {
      usb_open_abort(u, i + 1);
      return err;
    }
    h->buff_id = regs.r[0];
    mod.buffer.service = regs.r[1]; // same for all buffers
    mod.buffer.pw = regs.r[2];      // same for all buffers
  }
  u->open = 1;
  usb_wakeup_rx(port);

  return err;
}


/*
 * usb_close
 * ---------
 * Closes input and output streams
 * Each stream has its buffer purged (when the buffer service routine is
 * known) and is then closed. Both streams are always attempted, even if the
 * first close fails. A stream that closed has its handle zeroed so that a
 * later retry does not close the same handle twice.
 * Returns NULL if the port was not open or if everything closed, in which
 * case the port is marked not open. Otherwise returns the first error and
 * the port stays marked open.
 */
_kernel_oserror *usb_close(int port)
{
  usb_port_t *u = &usb.port[port];
  struct buffer_s *b = &mod.buffer;
  _kernel_oserror *first = NULL;
  int i;

  if(!u->open)
    return NULL;

  for(i=0; i<2; i++)
  {
    _kernel_swi_regs regs;
    _kernel_oserror *err;

    // already closed by an earlier, partly failed, call
    if(u->endpoint[i].handle.fileswitch == 0)
      continue;

    regs.r[0] = BUF_PURGE;
    regs.r[1] = u->endpoint[i].handle.buff_id;
    if(b->service)
      module_call(b->service, &regs, b->pw);

    regs.r[0] = 0;
    regs.r[1] = u->endpoint[i].handle.fileswitch;
    err = _kernel_swi(OS_Find, &regs, &regs);
    if(err == NULL)
      u->endpoint[i].handle.fileswitch = 0;
    else if(first == NULL)
      first = err; // remember it, but carry on with the other stream
  }

  if(first == NULL)
    u->open = 0;

  return first;
}


/*
 * usb_remove
 * ----------
 * Clears all the port data
 */
void usb_remove(int port)
{
  memset(&usb.port[port], 0, sizeof(usb_port_t));
}


/*
 * usb_wakeup_rx
 * -------------
 * Requests the remote device to send a packet
 */
void usb_wakeup_rx(int port)
{
  _kernel_swi_regs r;
  handle_t *h = &usb.port[port].endpoint[IN].handle;

  r.r[0] = DEVFS_WAKEUP_RX;
  r.r[1] = h->driver;
  r.r[2] = h->usbstream;
  r.r[3] = 1<<30;
  _kernel_swi(DeviceFS_CallDevice, &r, &r);
}

// The following 4 functions handle devices being connected and removed.

/*
 * usb_check_device
 * ----------------
 * Called from the service call handler.
 * Compares the device string against the device of every open port.
 * Returns 1 on the first match, or 0 if no open port uses that device.
 */
int usb_check_device(char *name)
{
  int slot;

  if((mod.debug & (1<<DBG_MOD)) && mod.log)
    fprintf(mod.log, "check device\n");

  for(slot=0; slot<MAX_PORTS; slot++)
  {
    if(!usb.port[slot].open)
      continue; // closed ports cannot match
    if(strcmp(name, usb.port[slot].device) == 0)
      return 1;
  }

  if((mod.debug & (1<<DBG_MOD)) && mod.log)
    fprintf(mod.log, "no name match\n");

  return 0;
}


/*
 * usb_dead_device
 * ---------------
 * Called from the service call callback in response to usb_check_device.
 * Looks for the first open port whose driver handle (taken from endpoint 0)
 * equals the one given, then closes it, clears its data and removes its
 * midi driver registration.
 * Returns 1 if such a port was found, else 0.
 */
int usb_dead_device(int driver_handle)
{
  _kernel_swi_regs regs;
  int slot;

  if((mod.debug & (1<<DBG_MOD)) && mod.log)
    fprintf(mod.log, "dead device\n");

  for(slot=0; slot<MAX_PORTS; slot++)
  {
    if(!usb.port[slot].open)
      continue;
    if(driver_handle != usb.port[slot].endpoint[0].handle.driver)
      continue;

    usb_close(slot);
    usb_remove(slot);
    // NOTE(review): only R1 is set up for this SWI - confirm it ignores R0
    regs.r[1] = mod.driver_number[slot];
    _kernel_swi(MIDISupport_RemoveDriver, &regs, &regs);
    mod.registered[slot] = FALSE;
    return 1;
  }

  if((mod.debug & (1<<DBG_MOD)) && mod.log)
    fprintf(mod.log, "no ports to kill\n");

  return 0;
}


/*
 * usb_new_device
 * --------------
 * Called from the service call handler.
 * Returns non zero if the new device has a valid midi port. Stores all the details in 'temp'
 */
usb_port_t temp;
int usb_new_device(int svc)
{
  int i = usb_check_descriptors((USBServiceCall *)svc, &temp);

  if(mod.debug & (1<<DBG_MOD))
    if(mod.log)fprintf(mod.log, "new device %d\n", i);

  return i;
}


/*
 * usb_open_device
 * ---------------
 * Called from the service call callback in response to usb_new_device.
 * Copies 'temp' into the first port that is not open and tries to open it.
 * Returns the port number (index) if the port was opened, else returns -1
 * (every port is in use, or the open failed).
 */
int usb_open_device(void)
{
  int slot;

  if((mod.debug & (1<<DBG_MOD)) && mod.log)
    fprintf(mod.log, "open device\n");

  // MIDI Streaming port found, look for the first empty slot
  for(slot=0; slot<MAX_PORTS; slot++)
    if(!usb.port[slot].open)
      break;

  if(slot < MAX_PORTS)
  {
    usb.port[slot] = temp;
    if(usb_open(slot) == NULL)
      return slot; // new midi port opened
  }

  if((mod.debug & (1<<DBG_MOD)) && mod.log)
    fprintf(mod.log, "no devices opened\n");

  return -1; // no spare slots, or failed to open
}


/*
 * usb_report
 * ----------
 * Displays connected port info
 *  full  non zero: detailed report of every open port (ids, strings,
 *                  interface and each endpoint)
 *        zero:     one line per port giving the product name, or a
 *                  placeholder when the port is unconnected or unnamed
 * The endpoint transfer type is taken from the low two bits of the attributes
 * so that a name is still shown when other attribute bits are set.
 */
void usb_report(int full)
{
  static const char * const attr_str[] = {"Control","Isochronous","Bulk","Interrupt"};
  usb_port_t *u;
  int i, j;

  // display the available midi ports
  printf("\n");
  for(i=0; i<MAX_PORTS; i++)
  {
    u = &usb.port[i];
    if(full) // detailed report
    {
      if(u->open)
      {
        printf("Device: %s: (USBPort%d)\n", u->device, i);
        printf("  VID=%X   PID=%X   VERS=%X\n", u->vendor, u->product_id, u->version);
        printf("  Manufacturer: %s\n", u->manufacturer);
        printf("  Product: %s\n", u->product);
        printf("  Serial number: %s\n", u->serial_no);
        printf("  Interface %d\n", u->interface);
        for(j=0; j<u->ends; j++)
        {
          printf("  Endpoint %d %s\n", j + 1, (u->endpoint[j].addr & 0x80) ? "Input" : "Output");
          printf("    address: %d\n", u->endpoint[j].addr & 0x7f);
          // the mask keeps the index inside the 4 entry table
          printf("    attributes: %d %s\n", u->endpoint[j].attr, attr_str[u->endpoint[j].attr & 3]);
          printf("    size: %d\n", u->endpoint[j].size);
/*
          handle_t *h = &u->endpoint[j].handle;

          printf("    Handles:\n");
          printf("      fileswitch: %08X\n", h->fileswitch);
          printf("      buffer:     %08X\n", h->buffer);
          printf("      devicefs:   %08X\n", h->devicefs);
          printf("      usbstream:  %08X\n", h->usbstream);
          printf("      driver:     %08X\n", h->driver);

          unsigned int *data;
          int k;

          printf("    DeviceFS handle->\n    ");
          data = (unsigned int *)h->devicefs;
          for(k=0; k<16; k++)
            printf("%08X%s", data[k], ((k & 7) != 7) ? " " : "\n    ");

          printf("    USB Stream handle->\n    ");
          data = (unsigned int *)h->usbstream;
          for(k=0; k<16; k++)
            printf("%08X%s", data[k], ((k & 7) != 7) ? " " : "\n    ");

          printf("    Drivers handle->\n    ");
          data = (unsigned int *)h->driver;
          for(k=0; k<16; k++)
            printf("%08X%s", data[k], ((k & 7) != 7) ? " " : "\n    ");
*/
        }
        printf("\n");
      }
    }
    else // short report
    {
      char *str = "(no product name)";
      if(!u->open)
        str = "(unconnected)";
      else if(u->product != NULL)
        if(*u->product != 0)
          str = u->product;
      printf("USBPort%d : %s\n", i, str);
    }
  }
}



